diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp
index 51eb1948f8d6c18f9d7bb082edf30830b380cb0d..0275cf3655e5f61d9bf8e840bfd35be9f6df82f5 100644
--- a/oneflow/core/kernel/user_kernel.cpp
+++ b/oneflow/core/kernel/user_kernel.cpp
@@ -787,6 +787,13 @@ const user_op::OpKernel* GetKernel(const KernelConf& kernel_conf) {
   return kernel_reg_val->create_fn(&create_ctx);
 }
 
+user_op::KernelComputeContext* GetKernelComputeContext(DeviceCtx* device_ctx,
+                                                       StreamContext* stream_ctx,
+                                                       const KernelConf& kernel_conf) {
+  auto ctx = new UserKernelComputeContext(device_ctx, stream_ctx, kernel_conf);
+  return static_cast<user_op::KernelComputeContext*>(ctx);
+}
+
 }  // namespace ir
 
 }  // namespace one
diff --git a/oneflow/core/kernel/user_kernel.h b/oneflow/core/kernel/user_kernel.h
index bb59e356ba9d791574fca9b7e5c55e7d95aac6f1..c9db486b26aa6fce58812081243290c3989e9be3 100644
--- a/oneflow/core/kernel/user_kernel.h
+++ b/oneflow/core/kernel/user_kernel.h
@@ -68,6 +68,9 @@ namespace one {
 namespace ir {
 
 const user_op::OpKernel* GetKernel(const KernelConf& kernel_conf);
+user_op::KernelComputeContext* GetKernelComputeContext(DeviceCtx* device_ctx,
+                                                       StreamContext* stream_ctx,
+                                                       const KernelConf& kernel_conf);
 
 }  // namespace ir
 
diff --git a/oneflow/ir/oneflow-extension/extension.cpp b/oneflow/ir/oneflow-extension/extension.cpp
index 94a3ad9b16ad3e9e07747e1729f6444081d62b4f..d0a18075631d1ada2fa1b34d07d8881539c145f5 100644
--- a/oneflow/ir/oneflow-extension/extension.cpp
+++ b/oneflow/ir/oneflow-extension/extension.cpp
@@ -170,6 +170,7 @@ class MlirJitCpuKernel final : public user_op::OpKernel {
 
  private:
   void Compute(user_op::KernelComputeContext* ctx) const override {
+    ctx->getcallback()(ctx);
     WithMlirContext(
         ctx, {},
         [&ctx](mlir::MLIRContext* mlir_ctx) {
diff --git a/oneflow/ir/oneflow-jit/JIT.cpp b/oneflow/ir/oneflow-jit/JIT.cpp
index 6737d6ce789022ae018db199d64044cc49a4e65a..9ba5f976699237411fb7a52cf482281d244bbeec 100644
--- a/oneflow/ir/oneflow-jit/JIT.cpp
+++ b/oneflow/ir/oneflow-jit/JIT.cpp
@@ -26,6 +26,8 @@ limitations under the License.
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "oneflow/core/kernel/user_kernel.h"
+#include "oneflow/core/rpc/include/global_process_ctx.h"
+#include "oneflow/core/device/device_context_adapter.h"
 
 namespace {
 
@@ -85,11 +87,39 @@ class ReturnAllLeaveResultPass : public ReturnAllLeaveResultPassBase<ReturnAllLeaveResultPass>
-  static std::vector<std::shared_ptr<const user_op::OpKernel>> created_kernels;
-  static std::vector<std::shared_ptr<user_op::KernelComputeContext>> created;
+StreamContext* GetStreamCxtFromStreamId(const StreamId& stream_id) {
+  StreamContext* stream_ctx =
+      NewObj<int, StreamContext>(stream_id.device_id().device_type(), stream_id);
+  return stream_ctx;
+}
+
+StreamContext* GetComputeStreamCxt() {
+  static int64_t GPU0 = 0;
+  static DeviceId device_id(GlobalProcessCtx::Rank(), DeviceType::kGPU, GPU0);
+  static StreamContext* stream_ctx = GetStreamCxtFromStreamId(StreamId(device_id, 0));
+  return stream_ctx;
+}
+
+DeviceCtx* GetComputeDeviceCxt() {
+  static auto device_ctx = CHECK_NOTNULL(NewDeviceCtxAdapter(GetComputeStreamCxt()));
+  return device_ctx;
+}
+
+JITKernelLaunchContext* GetKernelLaunchContext(const KernelConf& kernel_conf) {
+  static std::vector<std::shared_ptr<const user_op::OpKernel>> managed_kernels;
+  static std::vector<std::shared_ptr<user_op::KernelComputeContext>> managed_compute_contexts;
+  static std::vector<std::shared_ptr<JITKernelLaunchContext>> managed_jit_kernel_launch_contexts;
+  managed_kernels.emplace_back(one::ir::GetKernel(kernel_conf));
+  managed_compute_contexts.emplace_back(
+      one::ir::GetKernelComputeContext(GetComputeDeviceCxt(), GetComputeStreamCxt(), kernel_conf));
+  auto jit_kernel_launch_ctx = std::make_shared<JITKernelLaunchContext>(
+      managed_kernels.back().get(), managed_compute_contexts.back().get());
+  managed_jit_kernel_launch_contexts.emplace_back(jit_kernel_launch_ctx);
+  return jit_kernel_launch_ctx.get();
+}
 
 extern "C" void _mlir_ciface_LaunchOneFlowKernel(JITKernelLaunchContext* ctx) {