Commit d02a218a authored by jackalcooper

add functions for kernel

Parent f4e8a306
......@@ -787,6 +787,13 @@ const user_op::OpKernel* GetKernel(const KernelConf& kernel_conf) {
  return kernel_reg_val->create_fn(&create_ctx);
}

user_op::KernelComputeContext* GetKernelComputeContext(DeviceCtx* device_ctx,
                                                       StreamContext* stream_ctx,
                                                       const KernelConf& kernel_conf) {
  auto ctx = new UserKernelComputeContext(device_ctx, stream_ctx, kernel_conf);
  return static_cast<user_op::KernelComputeContext*>(ctx);
}

}  // namespace ir
}  // namespace one
......
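For orientation, a minimal usage sketch of the two new helpers (not part of this commit; kernel_conf, device_ctx, and stream_ctx are assumed to be supplied by the caller):

  const user_op::OpKernel* kernel = one::ir::GetKernel(kernel_conf);
  user_op::KernelComputeContext* compute_ctx =
      one::ir::GetKernelComputeContext(device_ctx, stream_ctx, kernel_conf);
  kernel->Compute(compute_ctx);  // user_op::OpKernel's Compute entry point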
......@@ -68,6 +68,9 @@ namespace one {
namespace ir {

const user_op::OpKernel* GetKernel(const KernelConf& kernel_conf);

user_op::KernelComputeContext* GetKernelComputeContext(DeviceCtx* device_ctx,
                                                       StreamContext* stream_ctx,
                                                       const KernelConf& kernel_conf);

}  // namespace ir
......
......@@ -170,6 +170,7 @@ class MlirJitCpuKernel final : public user_op::OpKernel {
 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    ctx->getcallback()(ctx);
    WithMlirContext(
        ctx, {},
        [&ctx](mlir::MLIRContext* mlir_ctx) {
......
......@@ -26,6 +26,8 @@ limitations under the License.
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "oneflow/core/kernel/user_kernel.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/device/device_context_adapter.h"
namespace {
......@@ -85,11 +87,39 @@ class ReturnAllLeaveResultPass : public ReturnAllLeaveResultPassBase<ReturnAllLe
struct JITKernelLaunchContext {
  const OpKernel* kernel;
  KernelComputeContext* compute_ctx;
  JITKernelLaunchContext(const OpKernel* kernel, KernelComputeContext* compute_ctx)
      : kernel(kernel), compute_ctx(compute_ctx) {}
};
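// Editorial note: bundling the kernel with its compute context behind a single
// pointer lets JIT-compiled code drive a launch through one C-ABI argument;
// see _mlir_ciface_LaunchOneFlowKernel below.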
KernelComputeContext* GetKernelComputeContext(const ::oneflow::UserOpConf& user_op_conf) {
  static std::vector<std::shared_ptr<const OpKernel>> created_kernels;
  static std::vector<std::shared_ptr<KernelComputeContext>> created;

StreamContext* GetStreamCxtFromStreamId(const StreamId& stream_id) {
  StreamContext* stream_ctx =
      NewObj<int, StreamContext, const StreamId&>(stream_id.device_id().device_type(), stream_id);
  return stream_ctx;
}
StreamContext* GetComputeStreamCxt() {
  static int64_t GPU0 = 0;
  static DeviceId device_id(GlobalProcessCtx::Rank(), DeviceType::kGPU, GPU0);
  static StreamContext* stream_ctx = GetStreamCxtFromStreamId(StreamId(device_id, 0));
  return stream_ctx;
}
DeviceCtx* GetComputeDeviceCxt() {
  static auto device_ctx = CHECK_NOTNULL(NewDeviceCtxAdapter(GetComputeStreamCxt()));
  return device_ctx;
}
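// Editorial sketch (not in the commit): the helpers above are memoized via
// function-local statics, so every JIT launch shares one compute stream
// (GPU 0 of the current rank, stream 0) and one device context wrapping it:
//   StreamContext* a = GetComputeStreamCxt();
//   StreamContext* b = GetComputeStreamCxt();
//   CHECK_EQ(a, b);  // same cached object on every call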
JITKernelLaunchContext* GetKernelLaunchContext(const KernelConf& kernel_conf) {
  // Keep every created kernel, compute context and launch context alive for
  // the lifetime of the process, so the raw pointers handed out below remain
  // valid while JIT-compiled code holds them.
  static std::vector<std::shared_ptr<const OpKernel>> managed_kernels;
  static std::vector<std::shared_ptr<KernelComputeContext>> managed_compute_contexts;
  static std::vector<std::shared_ptr<JITKernelLaunchContext>> managed_jit_kernel_launch_contexts;
  managed_kernels.emplace_back(one::ir::GetKernel(kernel_conf));
  managed_compute_contexts.emplace_back(
      one::ir::GetKernelComputeContext(GetComputeDeviceCxt(), GetComputeStreamCxt(), kernel_conf));
  auto jit_kernel_launch_ctx = std::make_shared<JITKernelLaunchContext>(
      managed_kernels.back().get(), managed_compute_contexts.back().get());
  managed_jit_kernel_launch_contexts.emplace_back(jit_kernel_launch_ctx);
  return jit_kernel_launch_ctx.get();
}
extern "C" void _mlir_ciface_LaunchOneFlowKernel(JITKernelLaunchContext* ctx) {
......
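A hedged sketch of the intended launch sequence (the body of the C-interface entry point is elided above, so the final step is an assumption inferred from the struct members):

  // Host side: materialize a launch context for the kernel described by kernel_conf.
  JITKernelLaunchContext* launch_ctx = GetKernelLaunchContext(kernel_conf);
  // JIT side: generated code receives the pointer through the C-ABI entry point,
  // which presumably dispatches kernel->Compute(compute_ctx).
  _mlir_ciface_LaunchOneFlowKernel(launch_ctx);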