Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
d02a218a
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
d02a218a
编写于
11月 02, 2021
作者:
J
jackalcooper
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add functions for kernel
上级
f4e8a306
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
44 addition
and
3 deletion
+44
-3
oneflow/core/kernel/user_kernel.cpp
oneflow/core/kernel/user_kernel.cpp
+7
-0
oneflow/core/kernel/user_kernel.h
oneflow/core/kernel/user_kernel.h
+3
-0
oneflow/ir/oneflow-extension/extension.cpp
oneflow/ir/oneflow-extension/extension.cpp
+1
-0
oneflow/ir/oneflow-jit/JIT.cpp
oneflow/ir/oneflow-jit/JIT.cpp
+33
-3
未找到文件。
oneflow/core/kernel/user_kernel.cpp
浏览文件 @
d02a218a
...
...
@@ -787,6 +787,13 @@ const user_op::OpKernel* GetKernel(const KernelConf& kernel_conf) {
return
kernel_reg_val
->
create_fn
(
&
create_ctx
);
}
user_op
::
KernelComputeContext
*
GetKernelComputeContext
(
DeviceCtx
*
device_ctx
,
StreamContext
*
stream_ctx
,
const
KernelConf
&
kernel_conf
)
{
auto
ctx
=
new
UserKernelComputeContext
(
device_ctx
,
stream_ctx
,
kernel_conf
);
return
static_cast
<
user_op
::
KernelComputeContext
*>
(
ctx
);
}
}
// namespace ir
}
// namespace one
...
...
oneflow/core/kernel/user_kernel.h
浏览文件 @
d02a218a
...
...
@@ -68,6 +68,9 @@ namespace one {
namespace
ir
{
const
user_op
::
OpKernel
*
GetKernel
(
const
KernelConf
&
kernel_conf
);
user_op
::
KernelComputeContext
*
GetKernelComputeContext
(
DeviceCtx
*
device_ctx
,
StreamContext
*
stream_ctx
,
const
KernelConf
&
kernel_conf
);
}
// namespace ir
...
...
oneflow/ir/oneflow-extension/extension.cpp
浏览文件 @
d02a218a
...
...
@@ -170,6 +170,7 @@ class MlirJitCpuKernel final : public user_op::OpKernel {
private:
void
Compute
(
user_op
::
KernelComputeContext
*
ctx
)
const
override
{
ctx
->
getcallback
()(
ctx
);
WithMlirContext
(
ctx
,
{},
[
&
ctx
](
mlir
::
MLIRContext
*
mlir_ctx
)
{
...
...
oneflow/ir/oneflow-jit/JIT.cpp
浏览文件 @
d02a218a
...
...
@@ -26,6 +26,8 @@ limitations under the License.
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "oneflow/core/kernel/user_kernel.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/device/device_context_adapter.h"
namespace
{
...
...
@@ -85,11 +87,39 @@ class ReturnAllLeaveResultPass : public ReturnAllLeaveResultPassBase<ReturnAllLe
struct
JITKernelLaunchContext
{
const
OpKernel
*
kernel
;
KernelComputeContext
*
compute_ctx
;
JITKernelLaunchContext
(
const
OpKernel
*
kernel
,
KernelComputeContext
*
compute_ctx
)
:
kernel
(
kernel
),
compute_ctx
(
compute_ctx
)
{}
};
KernelComputeContext
*
GetKernelComputeContext
(
const
::
oneflow
::
UserOpConf
&
user_op_conf
)
{
static
std
::
vector
<
std
::
shared_ptr
<
const
OpKernel
>>
created_kernels
;
static
std
::
vector
<
std
::
shared_ptr
<
KernelComputeContext
>>
created
;
StreamContext
*
GetStreamCxtFromStreamId
(
const
StreamId
&
stream_id
)
{
StreamContext
*
stream_ctx
=
NewObj
<
int
,
StreamContext
,
const
StreamId
&>
(
stream_id
.
device_id
().
device_type
(),
stream_id
);
return
stream_ctx
;
}
StreamContext
*
GetComputeStreamCxt
()
{
static
int64_t
GPU0
=
0
;
static
DeviceId
device_id
(
GlobalProcessCtx
::
Rank
(),
DeviceType
::
kGPU
,
GPU0
);
static
StreamContext
*
stream_ctx
=
GetStreamCxtFromStreamId
(
StreamId
(
device_id
,
0
));
return
stream_ctx
;
}
DeviceCtx
*
GetComputeDeviceCxt
()
{
static
auto
device_ctx
=
CHECK_NOTNULL
(
NewDeviceCtxAdapter
(
GetComputeStreamCxt
()));
return
device_ctx
;
}
JITKernelLaunchContext
*
GetKernelLaunchContext
(
const
KernelConf
&
kernel_conf
)
{
static
std
::
vector
<
std
::
shared_ptr
<
const
OpKernel
>>
managed_kernels
;
static
std
::
vector
<
std
::
shared_ptr
<
KernelComputeContext
>>
managed_compute_contexts
;
static
std
::
vector
<
std
::
shared_ptr
<
JITKernelLaunchContext
>>
managed_jit_kernel_launch_contexts
;
managed_kernels
.
emplace_back
(
one
::
ir
::
GetKernel
(
kernel_conf
));
managed_compute_contexts
.
emplace_back
(
one
::
ir
::
GetKernelComputeContext
(
GetComputeDeviceCxt
(),
GetComputeStreamCxt
(),
kernel_conf
));
auto
jit_kernel_launch_ctx
=
std
::
make_shared
<
JITKernelLaunchContext
>
(
managed_kernels
.
back
().
get
(),
managed_compute_contexts
.
back
().
get
());
managed_jit_kernel_launch_contexts
.
emplace_back
(
jit_kernel_launch_ctx
);
return
jit_kernel_launch_ctx
.
get
();
}
extern
"C"
void
_mlir_ciface_LaunchOneFlowKernel
(
JITKernelLaunchContext
*
ctx
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录