/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "OneFlow/JIT.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "OneFlow/OneFlowDialect.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "oneflow/core/framework/op_interpreter/jit_op_interpreter.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
J
refine  
jackalcooper 已提交
23
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/framework/user_op_def.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "oneflow/core/kernel/user_kernel.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/device/device_context_adapter.h"
namespace {

using namespace mlir;
J
reefine  
jackalcooper 已提交
35
using namespace ::oneflow::user_op;
using namespace ::oneflow;

// Translates an MLIR TensorType into a oneflow BlobDesc (shape + data type).
// Only f32 element types are handled; any other element type dumps the type
// and aborts via LOG(FATAL).
std::unique_ptr<BlobDesc> GetBlobDescFromMlirTensorType(TensorType tensor_type) {
  auto dtype = DataType::kInvalidDataType;
  if (tensor_type.getElementType().isF32()) {
    dtype = DataType::kFloat;
  } else {
    tensor_type.dump();
    LOG(FATAL) << "fail to get BlobDesc from TensorType";
  }
  // Build the shape on the stack: the previous heap-allocated Shape was copied
  // into the BlobDesc and never deleted, leaking on every call.
  const Shape shape({tensor_type.getShape().begin(), tensor_type.getShape().end()});
  return std::make_unique<BlobDesc>(shape, dtype);
}

// Builds the ParallelContext describing a single-device (1-of-1) placement.
ParallelContext GetSingleDeviceParallelContext() {
  ParallelContext single_device_ctx;
  single_device_ctx.set_parallel_id(0);
  single_device_ctx.set_parallel_num(1);
  return single_device_ctx;
}

// Walks the (key, size) segments that partition a flat operand/result list per
// bn (blob name) and records "<bn>_<i>" -> mlir::Value for every element of
// every segment.  `values` must contain at least the sum of all segment sizes;
// inserting the same indexed bn twice is a fatal error.
void InsertLbnSegmentIntoMapping(const ::mlir::ArrayAttr& lbn_segment_keys,
                                 const ::mlir::ArrayAttr& lbn_segment_sizes, ValueRange values,
                                 std::unordered_map<std::string, mlir::Value>& value_mapping_) {
  auto operand_it = values.begin();
  for (const auto& bn_size_pair : llvm::zip(lbn_segment_keys, lbn_segment_sizes)) {
    const auto& bn = std::get<0>(bn_size_pair).dyn_cast<StringAttr>().getValue().str();
    const auto& length = std::get<1>(bn_size_pair).dyn_cast<IntegerAttr>().getInt();
    // int64_t matches IntegerAttr::getInt() and avoids a signed/unsigned
    // comparison in the loop condition.
    for (int64_t i = 0; i < length; i++) {
      const auto indexed_bn = bn + "_" + std::to_string(i);
      CHECK(value_mapping_.emplace(indexed_bn, *operand_it).second) << "indexed_bn: " << indexed_bn;
      operand_it += 1;
    }
  }
}
// Debug pass: walks the function and prints every op result that has no uses.
// It does not mutate the IR; despite the name, nothing is returned yet —
// presumably a later revision will return these leaves — TODO confirm.
class ReturnAllLeaveResultPass : public ReturnAllLeaveResultPassBase<ReturnAllLeaveResultPass> {
  void runOnFunction() override {
    auto CollectNotUsedResults = [&](Operation* op) {
      for (auto result : op->getOpResults()) {
        if (result.use_empty()) {
          llvm::errs() << "use_empty: ";
          result.dump();
        }
      }
      return WalkResult::advance();
    };
    getFunction()->walk(CollectNotUsedResults);
  }
};
struct JITKernelLaunchContext {
J
jackalcooper 已提交
88
  const OpKernel* kernel;
J
refine  
jackalcooper 已提交
89
  KernelComputeContext* compute_ctx;
J
jackalcooper 已提交
90 91
  JITKernelLaunchContext(const OpKernel* kernel, KernelComputeContext* compute_ctx)
      : kernel(kernel), compute_ctx(compute_ctx) {}
J
reefine  
jackalcooper 已提交
92 93
};

// Instantiates the StreamContext registered for the stream's device type.
StreamContext* GetStreamCxtFromStreamId(const StreamId& stream_id) {
  return NewObj<int, StreamContext, const StreamId&>(stream_id.device_id().device_type(),
                                                     stream_id);
}

// Lazily creates (once per process) the compute StreamContext for GPU 0 of
// this rank and returns the cached pointer on every subsequent call.
StreamContext* GetComputeStreamCxt() {
  static const int64_t kGpuOrdinal = 0;
  static const DeviceId compute_device(GlobalProcessCtx::Rank(), DeviceType::kGPU, kGpuOrdinal);
  static StreamContext* const compute_stream_ctx =
      GetStreamCxtFromStreamId(StreamId(compute_device, 0));
  return compute_stream_ctx;
}

// Lazily adapts the compute StreamContext into a DeviceCtx; the adapter is
// created once and cached for the lifetime of the process.
DeviceCtx* GetComputeDeviceCxt() {
  static auto* const adapter = CHECK_NOTNULL(NewDeviceCtxAdapter(GetComputeStreamCxt()));
  return adapter;
}

// Builds a JITKernelLaunchContext for `kernel_conf`.  The kernel, its compute
// context and the launch context itself are parked in function-local static
// vectors: the returned raw pointer is handed to JIT-compiled code, so these
// objects must stay alive for the rest of the process (an intentional
// process-lifetime cache, not a leak to fix).
// NOTE(review): the static vectors are mutated without locking — confirm all
// callers run on a single thread.
JITKernelLaunchContext* GetKernelLaunchContext(const KernelConf& kernel_conf) {
  static std::vector<std::shared_ptr<const OpKernel>> managed_kernels;
  static std::vector<std::shared_ptr<KernelComputeContext>> managed_compute_contexts;
  static std::vector<std::shared_ptr<JITKernelLaunchContext>> managed_jit_kernel_launch_contexts;
  managed_kernels.emplace_back(one::ir::GetKernel(kernel_conf));
  managed_compute_contexts.emplace_back(
      one::ir::GetKernelComputeContext(GetComputeDeviceCxt(), GetComputeStreamCxt(), kernel_conf));
  auto jit_kernel_launch_ctx = std::make_shared<JITKernelLaunchContext>(
      managed_kernels.back().get(), managed_compute_contexts.back().get());
  managed_jit_kernel_launch_contexts.emplace_back(jit_kernel_launch_ctx);
  return jit_kernel_launch_ctx.get();
}

// C-interface entry point invoked from JIT-compiled MLIR: runs the kernel
// against the compute context captured in `ctx`.  Must stay extern "C" with
// this exact name — generated code resolves the symbol by name.
extern "C" void _mlir_ciface_LaunchOneFlowKernel(JITKernelLaunchContext* ctx) {
  ctx->kernel->Compute(ctx->compute_ctx);
}

// Pass that prepares kernel launch state for every oneflow user op in the
// function: declares an external `LaunchOneFlowKernel(i8*)` in the parent
// module, then for each user op rebuilds its OperatorConf from MLIR
// attributes, derives per-bn BlobDescs from the op's operand/result tensor
// types, generates a KernelConf and instantiates the kernel.
class CreateComputeCtxPass : public CreateComputeCtxPassBase<CreateComputeCtxPass> {
  void runOnFunction() override {
    ModuleOp top_module = getFunction()->getParentOfType<ModuleOp>();
    mlir::MLIRContext& context = getContext();
    auto jit_interpreter =
        dynamic_cast<::oneflow::one::JitInterpreter*>(::oneflow::one::GetJitInterpreter().get());
    auto importer = jit_interpreter->GetImporter();
    Builder builder(&context);
    // Declare the external kernel-launch function: one opaque i8* argument
    // (the JITKernelLaunchContext), no results.
    auto func_type = builder.getFunctionType(
        LLVM::LLVMPointerType::get(IntegerType::get(&context, 8)), llvm::None);
    auto function = mlir::FuncOp::create(getFunction()->getLoc(), "LaunchOneFlowKernel", func_type);
    top_module.push_back(function);
    auto CollectLowering = [&](Operation* op) {
      // Handle both the generic oneflow.user op and concrete ops carrying an
      // op_type_name attribute.
      if (llvm::dyn_cast<mlir::oneflow::UserOp>(op) || op->hasAttr("op_type_name")) {
        mlir::oneflow::UserOpAdaptor user_op_adaptor(op->getOperands(), op->getAttrDictionary());
        llvm::errs() << "lowering op to launch kernel: ";
        user_op_adaptor.op_name().dump();
        ::oneflow::OperatorConf op_conf;
        auto user_conf = op_conf.mutable_user_conf();
        // Rebuild the OperatorConf from MLIR; abort the walk if any piece
        // fails to convert.
        if (succeeded(ConvertUserOpInputs(op, user_op_adaptor, user_conf))
            && succeeded(ConvertUserOpOutputs(op, user_op_adaptor, user_conf))
            && succeeded(importer.ConvertUserOpAttributes(op, user_op_adaptor, op_conf))
            && succeeded(ConvertCtrlInputs(op, op_conf))) {
          // pass
        } else {
          return WalkResult::interrupt();
        }
        auto oneflow_op = CHECK_JUST(ConstructOp(op_conf));
        std::unordered_map<std::string, mlir::Value> value_mapping_;  // "a0" => %result
        InsertLbnSegmentIntoMapping(user_op_adaptor.input_lbn_segment_keys(),
                                    user_op_adaptor.input_lbn_segment_sizes(), op->getOperands(),
                                    value_mapping_);
        InsertLbnSegmentIntoMapping(user_op_adaptor.output_lbn_segment_keys(),
                                    user_op_adaptor.output_lbn_segment_sizes(), op->getResults(),
                                    value_mapping_);
        HashMap<std::string, std::unique_ptr<BlobDesc>> lbi2logical_blob_desc_;
        static ParallelContext parallel_ctx = GetSingleDeviceParallelContext();
        // Resolve a BlobDesc per indexed bn from its mapped MLIR value's
        // tensor type.  A bn with no mapped value gets an invalid-dtype
        // placeholder; that is only legitimate for the tmp buffer — anything
        // else is fatal.
        auto GetBlobDesc4BnInOp = [&](const std::string& bn) -> BlobDesc* {
          if (lbi2logical_blob_desc_.find(bn) == lbi2logical_blob_desc_.end()) {
            auto value_it = value_mapping_.find(bn);
            if (value_it == value_mapping_.end()) {
              auto blob_desc = std::make_unique<BlobDesc>(DataType::kInvalidDataType);
              CHECK(lbi2logical_blob_desc_.emplace(bn, std::move(blob_desc)).second);
              if (bn != "tmp_buffer_0") {
                op->dump();
                LOG(FATAL) << "value not found in MLIR op for indexed bn: " << bn;
              }
            } else {
              auto found =
                  GetBlobDescFromMlirTensorType(value_it->second.getType().cast<TensorType>());
              CHECK(lbi2logical_blob_desc_.emplace(bn, std::move(found)).second);
            }
          }
          return lbi2logical_blob_desc_.at(bn).get();
        };
        KernelConf kernel_conf;
        oneflow_op->GenKernelConf(GetBlobDesc4BnInOp, &parallel_ctx, &kernel_conf);
        one::ir::GetKernel(kernel_conf);
      }
      return WalkResult::advance();
    };
    getFunction()->walk(CollectLowering);
  }
};

}  // namespace

namespace mlir {
namespace oneflow {
// Factory for the ReturnAllLeaveResultPass defined in this file's anonymous
// namespace; used to add the pass to a PassManager pipeline.
std::unique_ptr<Pass> createReturnAllLeaveResultPass() {
  return std::make_unique<ReturnAllLeaveResultPass>();
}

// Factory for the CreateComputeCtxPass defined in this file's anonymous
// namespace; used to add the pass to a PassManager pipeline.
std::unique_ptr<Pass> createCreateComputeCtxPass() {
  return std::make_unique<CreateComputeCtxPass>();
}

}  // namespace oneflow
}  // namespace mlir

namespace oneflow {

namespace one {

namespace ir {

using namespace mlir;

// Creates an empty MLIR module for JIT tracing, pre-loading every dialect the
// importer emits ops from (oneflow, std, llvm).
OwningOpRef<ModuleOp> CreateJitModule(MLIRContext* context) {
  context->loadDialect<mlir::oneflow::OneFlowDialect>();
  context->loadDialect<StandardOpsDialect>();
  context->loadDialect<LLVM::LLVMDialect>();
  // Anchor the module at an empty file location; ops added later carry their
  // own locations.
  OwningOpRef<ModuleOp> module(
      ModuleOp::create(FileLineColLoc::get(context, "", /*line=*/0, /*column=*/0)));
  return module;
}

// Appends the MLIR value previously recorded for (key, index) to operand_vec.
// `lbn` is currently unused.
// NOTE(review): getValue() on an empty llvm::Optional asserts/is UB — this
// assumes GetResultByBnAndIndex always finds a mapping here; confirm callers
// guarantee the input tensor was mapped before this is reached.
LogicalResult JitImporter::AppendDataInOperand(const std::string& key, const int32_t index,
                                               const std::string& lbn,
                                               std::vector<::mlir::Value>& operand_vec) {
  operand_vec.push_back(GetResultByBnAndIndex(key, index).getValue());
  return success();
}
// Appends the current placement as MLIR attributes: "device_name" (string
// array from the parallel conf) and, when present, the "hierarchy" dims.
// The placement comes from the importer's parallel_desc_, not from `op`,
// which is currently unused.
LogicalResult JitImporter::AddDeviceName(const ::oneflow::OperatorConf& op,
                                         std::vector<NamedAttribute>& attr_vec) {
  const ::oneflow::ParallelConf& pc = parallel_desc_->parallel_conf();
  std::vector<llvm::StringRef> device_vec = {pc.device_name().begin(), pc.device_name().end()};
  attr_vec.push_back(
      GetBuilder().getNamedAttr("device_name", GetBuilder().getStrArrayAttr(device_vec)));
  if (pc.has_hierarchy()) {
    attr_vec.push_back(GetBuilder().getNamedAttr(
        "hierarchy",
        GetBuilder().getI64ArrayAttr({pc.hierarchy().dim().begin(), pc.hierarchy().dim().end()})));
  }
  return success();
}
// Looks up the MLIR tensor type inferred for a logical blob name.  The map is
// keyed by the blob-name part only (the op-name part of the lbn is dropped),
// and the entry must already exist — .at() throws otherwise.
Type JitImporter::GetTensorTypeOfLbn(const std::string& lbn) {
  LogicalBlobId lbi = GenLogicalBlobId(lbn);
  return result_type_mapping_.at(lbi.blob_name());
}
std::shared_ptr<MirroredTensor> JitImporter::MakeIntermediateTensor(
J
jackalcooper 已提交
253 254
    const std::string& lbn, Value result,
    const std::shared_ptr<const ParallelDesc>& parallel_desc) {
J
jackalcooper 已提交
255
  auto tensor_type = result.getType().cast<TensorType>();
J
jackalcooper 已提交
256
  auto dtype = DataType::kInvalidDataType;
J
jackalcooper 已提交
257
  if (tensor_type.getElementType().isF32()) {
J
jackalcooper 已提交
258 259 260 261 262 263
    dtype = DataType::kFloat;
  } else {
    result.dump();
    LOG(FATAL) << "fail to creat tensor";
  }
  const auto& device = CHECK_JUST(Device::MakeDeviceByParallelDesc(*parallel_desc));
J
jackalcooper 已提交
264
  auto shape_from_mlir = new Shape({tensor_type.getShape().begin(), tensor_type.getShape().end()});
J
refine  
jackalcooper 已提交
265 266
  auto shape = std::make_shared<Shape>();
  shape.reset(shape_from_mlir);
J
jackalcooper 已提交
267 268 269
  auto tensor = MirroredTensor::MakeTensor(shape, dtype, device, /* is_lazy */ true,
                                           /* requires_grad= */ false, /* is_leaf= */ true)
                    .GetPtrOrThrow();
J
jackalcooper 已提交
270
  // TODO: refactor intermediate_tensors_. Same type of op has identical name. For instance, matmul3
J
jackalcooper 已提交
271 272 273 274
  CHECK(intermediate_tensors_.emplace(lbn, tensor).second)
      << "Intermediate tensor already created, lbn: " << lbn;
  CHECK(result_mapping_.emplace(tensor.get(), result).second)
      << "Intermediate tensor already mapped to mlir value, lbn: " << lbn;
J
jackalcooper 已提交
275 276 277 278
  return tensor;
}
// Creates an intermediate tensor for each data output of `created_op` (named
// by its "output_lbns" attribute) and stores them positionally into *outputs_.
// `op_conf` is currently unused.
LogicalResult JitImporter::InsertOpResults(const ::oneflow::OperatorConf& op_conf,
                                           Operation* created_op) {
  auto output_lbns = created_op->getAttrOfType<ArrayAttr>("output_lbns");
  CHECK_EQ(output_lbns.size(), outputs_->size());
  for (auto data_out : llvm::enumerate(GetDataOutputResults(created_op))) {
    auto lbn = output_lbns[data_out.index()].dyn_cast<StringAttr>().getValue().str();
    auto tensor = MakeIntermediateTensor(lbn, data_out.value(), parallel_desc_);
    (*outputs_)[data_out.index()] = tensor;
  }
  return success();
}
// Resolves the declared attribute type of `attr_name` on user op
// `op_type_name` from the global user-op registry.  Both an unknown op type
// and an unknown attribute name are fatal.
::oneflow::AttrType JitImporter::QueryAttrType(const std::string& op_type_name,
                                               const std::string& attr_name) {
  const user_op::OpRegistryResult* val =
      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name);
  CHECK(val) << " Cannot find op_type_name: " << op_type_name;
  user_op::UserOpDefWrapper op_def(val->op_def);
  CHECK(op_def.IsAttrName(attr_name)) << attr_name << " not a attr name for op: " << op_type_name;
  return op_def.GetAttrType(attr_name);
}

// Returns the module's FuncOp named `func_name`, creating it on first use.
// On creation it: builds argument types from the JIT forward args' shapes and
// dtypes, maps each arg tensor to its MLIR block argument in result_mapping_,
// and leaves the builder's insertion point at the start of the new entry
// block.  `inputs` is currently unused; `outputs` is stashed for later use by
// InsertOpResults.
mlir::FuncOp JitImporter::GetOrInsertFunc(const std::string& func_name, const TensorTuple& inputs,
                                          TensorTuple* outputs) {
  outputs_ = outputs;
  SymbolTable symbol_table(GetModule());
  FuncOp found_func = symbol_table.lookup<FuncOp>(func_name);
  if (found_func) {
    return found_func;
  } else {
    // Convert each forward-arg tensor's dtype/shape into a ranked MLIR tensor
    // type.
    auto arg_tensors = GetJitForwardArgs();
    auto arg_types = llvm::SmallVector<Type, 8>();
    for (const auto& arg_tensor : arg_tensors) {
      auto mlir_dtype = GetTypeFromOneFlowDataType(arg_tensor->dtype()->data_type());
      auto mlir_tensor_type =
          RankedTensorType::get(ArrayRef<int64_t>(arg_tensor->shape()->dim_vec().begin(),
                                                  arg_tensor->shape()->dim_vec().end()),
                                mlir_dtype.getValue());
      arg_types.push_back(mlir_tensor_type);
    }
    auto func_type = GetBuilder().getFunctionType(arg_types, llvm::NoneType());
    FuncOp function = mlir::FuncOp::create(GetRootLocation(), func_name, func_type);
    auto entryBlock = function.addEntryBlock();
    CHECK_EQ(arg_tensors.size(), function.body().getArguments().size());
    // Remember which block argument stands for each input tensor.
    for (auto argument_pair : llvm::zip(arg_tensors, function.body().getArguments())) {
      CHECK(result_mapping_.emplace(std::get<0>(argument_pair).get(), std::get<1>(argument_pair))
                .second);
    }
    GetBuilder().setInsertionPointToStart(entryBlock);
    GetModule().push_back(function);
    return function;
  }
}

// Converts a BlobDesc (shape + oneflow dtype) into a ranked MLIR tensor type,
// or llvm::None when the dtype has no MLIR equivalent.
llvm::Optional<TensorType> JitImporter::GetMlirTensorTypeFromBlobDesc(const BlobDesc& blob_desc) {
  const auto element_type = GetTypeFromOneFlowDataType(blob_desc.data_type());
  if (!element_type) { return llvm::None; }
  const auto& dim_vec = blob_desc.shape().dim_vec();
  return RankedTensorType::get(ArrayRef<int64_t>(dim_vec.begin(), dim_vec.end()),
                               element_type.getValue());
}

// Prepares per-op state before the op is created in MLIR:
//  1. fills operand_mapping_ (indexed bn -> already-emitted mlir::Value) for
//     every input tensor,
//  2. runs the oneflow op's logical output-shape inference against BlobDescs
//     derived from those mapped values,
//  3. fills result_type_mapping_ (blob name -> MLIR tensor type), consumed
//     later by GetTensorTypeOfLbn.
void JitImporter::CreateOperandMapping(const ::oneflow::OperatorConf& op_conf,
                                       const std::shared_ptr<const ParallelDesc> parallel_desc,
                                       const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                                       const TensorTuple& inputs) {
  operand_mapping_.clear();
  input_arg_tuple_ = input_arg_tuple;
  inputs_ = inputs;
  result_type_mapping_.clear();
  HashMap<std::string, std::unique_ptr<BlobDesc>> lbi2logical_blob_desc_;
  auto op = CHECK_JUST(ConstructOp(op_conf));
  for (auto pair : llvm::zip(input_arg_tuple->indexed_bns(),
                             input_arg_tuple->indexed_arg_name_and_index(), inputs)) {
    const auto& indexed_bn = std::get<0>(pair);
    const auto& indexed_arg_name_and_index = std::get<1>(pair);
    const auto& tensor = std::get<2>(pair);
    if (auto result = GetResultByBnAndIndex(indexed_arg_name_and_index.first,
                                            indexed_arg_name_and_index.second)) {
      // CHECK, not assert: the previous assert() wrapped the emplace itself,
      // so under NDEBUG the side effect vanished and the map stayed empty in
      // release builds.
      CHECK(operand_mapping_.emplace(indexed_bn, result.getValue()).second);
    } else {
      LOG(FATAL) << "result not found, indexed_bn: " << indexed_bn << ", tensor: " << tensor.get()
                 << ", shape: " << tensor->shape()->DebugStr()
                 << ", dtype: " << tensor->dtype()->name();
    }
  }
  // TODO: refine here
  // Input bns resolve to their mapped MLIR value's tensor type; bns with no
  // mapping (outputs) get an invalid-dtype placeholder for inference to fill.
  auto GetLogicalBlobDesc4BnInOp = [&](const std::string& bn) -> BlobDesc* {
    if (lbi2logical_blob_desc_.find(bn) == lbi2logical_blob_desc_.end()) {
      auto operand_it = operand_mapping_.find(bn);
      if (operand_it == operand_mapping_.end()) {
        auto blob_desc = std::make_unique<BlobDesc>(DataType::kInvalidDataType);
        CHECK(lbi2logical_blob_desc_.emplace(bn, std::move(blob_desc)).second);
      } else {
        auto found = GetBlobDescFromMlirTensorType(operand_it->second.getType().cast<TensorType>());
        CHECK(lbi2logical_blob_desc_.emplace(bn, std::move(found)).second);
      }
    }
    return lbi2logical_blob_desc_.at(bn).get();
  };
  CHECK_JUST(op->InferLogicalOutBlobDescs(GetLogicalBlobDesc4BnInOp, *parallel_desc));
  for (auto& kv : lbi2logical_blob_desc_) {
    CHECK(
        result_type_mapping_.emplace(kv.first, GetMlirTensorTypeFromBlobDesc(*kv.second).getValue())
            .second);
  }
}
// Returns the mlir::Value recorded for the input tensor at (arg name, index),
// or llvm::None if that tensor has not been mapped yet.
llvm::Optional<mlir::Value> JitImporter::GetResultByBnAndIndex(const std::string& bn,
                                                               const int32_t index) {
  auto idx = input_arg_tuple_->TensorTupleIndex4ArgNameAndIndex(bn, index);
  // Borrow the shared_ptr rather than copying it (atomic refcount bump); only
  // the raw pointer is used as the lookup key.
  const auto& tensor = inputs_[idx];
  auto result_it = result_mapping_.find(tensor.get());
  if (result_it == result_mapping_.end()) {
    return llvm::None;
  } else {
    return result_it->second;
  }
}

// Finalizes the traced module (appends a terminator at the current insertion
// point) and runs the lowering pipeline per function: canonicalize, report
// unused leaf results, then create kernel compute contexts.
LogicalResult JitImporter::LowerToOneFlowKernel() {
  GetBuilder().create<ReturnOp>(GetModule()->getLoc());
  mlir::PassManager pm(GetModule()->getContext());
  pm.addNestedPass<mlir::FuncOp>(::mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::FuncOp>(::mlir::oneflow::createReturnAllLeaveResultPass());
  pm.addNestedPass<mlir::FuncOp>(::mlir::oneflow::createCreateComputeCtxPass());
  return pm.run(GetModule());
}

}  // namespace ir

}  // namespace one

}  // namespace oneflow