提交 fb13a068 编写于 作者: M Megvii Engine Team

feat(jit/opencl): add OpenCL tiny compiler

algo some misc fix:
* fix jit backends env logic issue: fix jit backends env logic issue
* fix OpenCL prop support image detect logic
* disable OpenCL jit on device do not support image
* if opr is not CD4, OpenCL jit will not fuse it
* fix jit test build with clang and without rtti

GitOrigin-RevId: 9311b270d10b13bd8d0ed7831780ae76cac00af6
上级 ed64f0f6
......@@ -47,10 +47,10 @@ LITE_API inline LiteAlgoSelectStrategy operator|(
* @param no_profiling_on_shape_change do not re-profile to select best implement
* algo when input shape changes (use previous algo)
*
* @param jit_level Execute supported operators with JIT (support MLIR,
* NVRTC). Can only be used on Nvidia GPUs and X86 CPU, this value indicates JIT level:
* level 1: for JIT execute with basic elemwise operator
* level 2: for JIT execute elemwise and reduce operators
* @param jit_level Execute supported operators with JIT, please check with
* MGB_JIT_BACKEND for more details, this value indicates JIT level.
* 1: for JIT execute with basic elemwise operator
* 2: for JIT execute elemwise and reduce operators
*
* @param record_level flags to optimize the inference performance with record the
* kernel tasks in first run, hereafter the inference all need is to execute the
......
......@@ -36,10 +36,10 @@ extern "C" {
* \param no_profiling_on_shape_change do not re-profile to select best impl
* algo when input shape changes (use previous algo)
*
* \param jit_level Execute supported operators with JIT (support MLIR,
* NVRTC). Can only be used on Nvidia GPUs, this value indicates JIT level:
* 1 for basic elemwise opr;
* 2 for including reduce operator
* \param jit_level Execute supported operators with JIT, please check with
* MGB_JIT_BACKEND for more details, this value indicates JIT level.
* 1: for basic elemwise opr
* 2: for including reduce operator
*
* \param record_level flag optimize the inference performace with record the
* kernel tasks in first run, hereafter the inference all need to execute the
......
......@@ -744,8 +744,8 @@ DEFINE_uint64(workspace_limit, SIZE_MAX, "set workspace upbound limit");
///////////////////////// other options for optimization /////////////////
DEFINE_bool(
enable_jit, false,
" Execute supported operators with JIT(now only support NVRTC). "
"Can only be used on Nvidia GPUs");
" Execute supported operators with JIT, please check with MGB_JIT_BACKEND for "
"more details");
#if MGB_ENABLE_TENSOR_RT
DEFINE_bool(
tensorrt, false,
......@@ -788,4 +788,4 @@ REGIST_OPTION_CREATOR(memory_optimize, lar::MemoryOptimizeOption::create_option)
REGIST_OPTION_CREATOR(JIT, lar::JITOption::create_option);
#if MGB_ENABLE_TENSOR_RT
REGIST_OPTION_CREATOR(tensorRT, lar::TensorRTOption::create_option);
#endif
\ No newline at end of file
#endif
......@@ -41,8 +41,8 @@ class LiteOptions(Structure):
no_profiling_on_shape_change: do not re-profile to select best implement
algo when input shape changes (use previous algo)
jit_level: Execute supported operators with JIT (support MLIR,
NVRTC). Can only be used on Nvidia GPUs and X86 CPU, this value indicates JIT level:
jit_level: Execute supported operators with JIT, please check with MGB_JIT_BACKEND
for more details, this value indicates JIT level:
level 1: for JIT execute with basic elemwise operator
......
......@@ -97,7 +97,7 @@ void DeviceMemoryAllocator::alloc_dynamic(
}
void DeviceMemoryAllocator::defrag_prealloc_contig(
ComputingGraph* graph, CompNode comp_node,
ComputingGraph* /*graph*/, CompNode comp_node,
size_t size){MGB_TRY{comp_node.free_device(comp_node.alloc_device(size));
}
MGB_CATCH(MemAllocError&, {})
......@@ -574,10 +574,13 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
(options().graph_opt.jit || options().graph_opt.jit_config.enabled())) {
// Deprecated usage added previously. It allows NVRTC JIT optimization
// when graph_opt_level is 0. This usage is not recommanded any more.
mgb_log_warn(
"It is not recommanded to enable JIT optimization when "
"graph_opt_level is 0.");
setenv("MGB_JIT_BACKEND", "NVRTC", 1);
unsigned int max_warm = 9;
do {
mgb_log_warn(
"It is not recommanded to enable JIT optimization when "
"graph_opt_level is 0, try config graph_opt_level more than 0");
} while (max_warm-- > 0);
gopt::GraphOptimizer optimizer;
optimizer.add_pass<gopt::JITFusionPass>(
sopr_stat.has_virtual_grad, options().graph_opt.jit,
......
......@@ -859,11 +859,12 @@ const SeqModifierForSublinearMemory::SeqModifyAction& SeqModifierForSublinearMem
msg.push_back('\n');
msg.append(ssprintf("m_min_bottleneck: %-10.2f\n", m_min_bottleneck * SIZE2MB));
if (!m_par_modifier->m_config->genetic_nr_iter) {
msg.append(
ssprintf("\nGenetic algorithm is currently DISABLED, "
"set MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER [default = 0]"
" to a positive integer to set the number of iterations"
" in genetic algorithm.\n"));
msg.append(ssprintf(
"\nGenetic algorithm is currently DISABLED, "
"set %c%cB_SUBLINEAR_MEMORY_GENETIC_NR_ITER [default = 0]"
" to a positive integer to set the number of iterations"
" in genetic algorithm.\n",
'M', 'G'));
}
mgb_log_debug("%s", msg.c_str());
#else
......@@ -934,10 +935,11 @@ SeqModifierForSublinearMemory::SeqModifyAction SeqModifierForSublinearMemory::
planner_concur = m_config->num_worker;
}
mgb_log_debug(
"use %zu threads to search for sublinear memory plan; "
"this can be changed via MGB_SUBLINEAR_MEMORY_WORKERS env var",
planner_concur);
std::string msg = ssprintf(
"use %zu threads to search for sublinear memory plan; this can be changed "
"via %c%cB_SUBLINEAR_MEMORY_WORKERS env var",
planner_concur, 'M', 'G');
mgb_log_debug("%s", msg.c_str());
for (auto&& i : m_planner_thread_pool.start(planner_concur))
m_thread2planner[i].reset(new ModifyActionPlanner{this});
......
......@@ -41,8 +41,8 @@ The detection is implemented in [impl/fusion_pass.cpp](impl/fusion_pass.cpp),
the main detection logic is in function *Fusion::Impl::on_opr*. Compared to nnvm
fusion, our fusion logic can fuse more operators into one fusion kernel.
For now , JIT just support CUDA, but it has reserved interface to extend other
platforms.
For now , JIT support CUDA by HALIDE or NVRTC, CPU by MLIR, OpenCL by TINYOPENCL,
also it has reserved interface to extend more platforms.
## How to enable JIT
You can set `graph_opt_level` to 3 to enable JIT.
......@@ -57,10 +57,12 @@ cg.set_option('graph_opt_level', 3)
You can set environment variable `MGB_JIT_BACKEND` to select the JIT backend.
| Backend | Platforms | Reduction support | Kernel Binary Cache | Kernel Reuse | Noncontig Input |
|---------|-----------|-------------------|---------------------|--------------|-----------------|
| HALIDE | CUDA | Y | No | Shape | No |
| NVRTC | CUDA | N | Via PersistentCache | Bcast type | Monotone |
| Backend | Platforms | Reduction support | Kernel Binary Cache | Kernel Reuse | Noncontig Input |
|------------|-----------|-------------------|---------------------|--------------|-----------------|
| HALIDE | CUDA | Y | No | Shape | No |
| NVRTC | CUDA | N | Via PersistentCache | Bcast type | Monotone |
| MLIR | CPU | N | NO | Kernel hash | Monotone |
| TINYOPENCL | OpenCL | N | Via OpenCL cache | Kernel hash | Monotone |
To enable fusion of Reduce oprs, set `graph_opt.jit = 2` in graph options.
......
......@@ -53,16 +53,22 @@ ASTPtr gen_powc(ASTPtr inp, float exp) {
return make_call("powf", {inp, exp});
}
} // anonymous namespace
const ElemGeneratorMap& ast_c::elem_opr_generator() {
#define ENTRY(_mode, _impl) \
{ \
ElemMode::_mode, { \
[](const ASTPtrArray& inps) -> ASTPtrArray { return {_impl}; } \
} \
const ElemGeneratorMap& ast_c::elem_opr_generator(CompNode::DeviceType device_type) {
#define ENTRY(_mode, _impl) \
{ \
ElemMode::_mode, { \
[=](const ASTPtrArray& inps, bool is_half) -> ASTPtrArray { \
MGB_MARK_USED_VAR(is_half); \
return {_impl}; \
} \
} \
}
static ElemGeneratorMap map = {
//! other backends map
static ElemGeneratorMap other_map = {
// unary
ENTRY(RELU, make_call("fmaxf", {inps[0], 0.f})),
ENTRY(ABS, make_call("fabsf", inps)),
......@@ -102,7 +108,7 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
ENTRY(SWITCH_GT0, ASTPtr::make<Cond3AST>(inps[0] > 0, inps[1], 0)),
ENTRY(TANH_GRAD, (1 - inps[0] * inps[0]) * inps[1]),
ENTRY(TRUE_DIV, inps[0] / inps[1]),
ENTRY(LOG_SUM_EXP, make_call("mgb_log_sum_exp", {inps[0], inps[1]})),
ENTRY(LOG_SUM_EXP, make_call("jit_log_sum_exp", {inps[0], inps[1]})),
ENTRY(LT, ASTPtr::make<BinaryAST>("<", inps[0], inps[1])),
ENTRY(LEQ, ASTPtr::make<BinaryAST>("<=", inps[0], inps[1])),
ENTRY(EQ, ASTPtr::make<BinaryAST>("==", inps[0], inps[1])),
......@@ -133,22 +139,28 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
0.f}) /
6.f),
};
mgb_assert(map.size() + 41 == opr::Elemwise::Param::MODE_NR_MEMBER);
mgb_assert(other_map.size() + 41 == opr::Elemwise::Param::MODE_NR_MEMBER);
// unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH,
// ERFINV, ERFCINV, NOT, AND, OR, XOR, NEQ, ISNAN, ISINF
return map;
return other_map;
#undef ADD_OPR
}
ASTPtrArray ast_c::opr2AST(cg::OperatorNodeBase* opr, const ASTPtrArray& inputs) {
ASTPtrArray ast_c::opr2AST(
cg::OperatorNodeBase* opr, const ASTPtrArray& inputs,
CompNode::DeviceType device_type) {
using namespace opr;
if (auto elem = gopt::try_cast_as_op<Elemwise>(opr)) {
if (check_elem_mode(elem->param().mode)) {
return elem_opr_generator().find(elem->param().mode)->second(inputs);
if (check_elem_mode(elem->param().mode, device_type)) {
return elem_opr_generator(device_type)
.find(elem->param().mode)
->second(inputs, false);
}
}
if (auto powc = gopt::try_cast_as_op<PowC>(opr)) {
mgb_assert(inputs.size() == 1);
return {gen_powc(inputs[0], powc->param().exp)};
}
......@@ -157,6 +169,7 @@ ASTPtrArray ast_c::opr2AST(cg::OperatorNodeBase* opr, const ASTPtrArray& inputs)
if (imm.valid()) {
auto dtype = imm->dtype();
if (dtype == dtype::Int32{}) {
return {ASTPtr::make<IntAST>(imm->get<int>())};
}
float scalar_value;
......@@ -169,10 +182,12 @@ ASTPtrArray ast_c::opr2AST(cg::OperatorNodeBase* opr, const ASTPtrArray& inputs)
InternalError, "dtype(%s) is not any of [Float16, Float32, Int32]",
dtype.name());
}
return {ASTPtr::make<FloatAST>(scalar_value)};
return {ASTPtr::make<FloatAST>(scalar_value, device_type, false)};
}
if (opr->same_type<opr::TypeCvt>()) {
// simply ignore TypeCvt oprs.
mgb_assert(inputs.size() == 1);
return inputs;
......
......@@ -67,40 +67,50 @@ Compiler* Compiler::get(ComputingGraph& graph, CompNode comp_node) {
}
MGB_LOCK_GUARD(holder->mtx);
auto&& compiler = holder->dev2compiler[comp_node.device_type()];
auto backend = MGB_GETENV("MGB_JIT_BACKEND");
auto backend = ::std::getenv(ssprintf("%c%cB_JIT_BACKEND", 'M', 'G').c_str());
mgb_assert(
backend,
"code issue happened, need call config_jit_backends before get compiler");
//! please keep logic with JITFusionPass::Impl::config_jit_backends
if (!compiler) {
switch (comp_node.device_type()) {
#if MGB_CUDA
case CompNode::DeviceType::CUDA:
#if MGB_JIT_HALIDE
if (!backend || !strcmp(backend, "HALIDE")) {
if (!strcmp(backend, "HALIDE")) {
compiler = std::make_unique<HalideCudaCompiler>();
break;
}
#endif
#if MGB_JIT_MLIR
if (!backend || !strcmp(backend, "MLIR")) {
if (!strcmp(backend, "MLIR")) {
compiler =
std::make_unique<MLIRCompiler>(CompNode::DeviceType::CUDA);
break;
}
#endif
if (!backend || !strcmp(backend, "NVRTC")) {
if (!strcmp(backend, "NVRTC")) {
compiler = std::make_unique<CudaCompiler>();
break;
}
mgb_throw(InternalError, "No compiler support for cuda");
mgb_throw(
InternalError,
"No compiler support for cuda, may caused by build not enable "
"MLIR/HALIDE module or error config jit backend env");
break;
#endif
case CompNode::DeviceType::CPU:
#if MGB_JIT_MLIR
if (!backend || !strcmp(backend, "MLIR")) {
if (!strcmp(backend, "MLIR")) {
compiler =
std::make_unique<MLIRCompiler>(CompNode::DeviceType::CPU);
break;
}
#endif
mgb_throw(InternalError, "No compiler support for cpu");
mgb_throw(
InternalError,
"No compiler support for cpu, may caused by build not enable "
"MLIR module or error config jit backend env");
break;
default:
mgb_throw(
......
#include "megbrain/jit/fusion_pass.h"
#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/gopt/gtrans.h"
#include "megbrain/jit/ast_c.h"
#include "megbrain/jit/compiler.h"
#include "megbrain/jit/internal_graph.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/serializer.h"
#include "megdnn/tensor_format.h"
#if MGB_JIT
......@@ -66,6 +68,9 @@ class JITFusionPass::Impl final {
return num;
}
//! config jit backends
void config_jit_backends(CompNode comp_node) const;
public:
Impl(bool after_grad, JITFeatureBits feature_bits, OptState& opt_state)
: m_after_grad{after_grad},
......@@ -77,6 +82,57 @@ public:
}
};
void JITFusionPass::Impl::config_jit_backends(CompNode comp_node) const {
#define ENV_CB(VALUE) \
if (!backend || !strcmp(backend, VALUE)) { \
if (!backend) { \
mgb_log_debug("config jit default backend to %s", VALUE); \
setenv(ssprintf("%c%cB_JIT_BACKEND", 'M', 'G').c_str(), VALUE, 1); \
} \
break; \
}
auto backend = ::std::getenv(ssprintf("%c%cB_JIT_BACKEND", 'M', 'G').c_str());
if (backend) {
mgb_log_debug("use user config jit backend with: %s", backend);
}
switch (comp_node.device_type()) {
#if MGB_CUDA
// CUDA jit default property: HALIDE > MLIR > NVRTC
case CompNode::DeviceType::CUDA:
#if MGB_JIT_HALIDE
ENV_CB("HALIDE");
#endif
#if MGB_JIT_MLIR
ENV_CB("MLIR");
#endif
ENV_CB("NVRTC");
mgb_throw(
InternalError,
"No compiler support for cuda, may caused by build not enable "
"MLIR/HALIDE module or error config jit backend env");
break;
#endif
// CPU jit only support MLIR now
case CompNode::DeviceType::CPU:
#if MGB_JIT_MLIR
ENV_CB("MLIR");
#endif
mgb_throw(
InternalError,
"No compiler support for cpu, may caused by build not enable "
"MLIR module or error config jit backend env");
break;
default:
mgb_throw(
InternalError,
"unsupported JIT config: "
"comp_node=%s backend_setting=%s",
comp_node.to_string().c_str(), backend);
}
#undef ENV_CB
}
void JITFusionPass::Impl::detect_fusion() {
std::vector<OperatorNodeBase*> topo_order;
m_opt_state.graph().iter([this, &topo_order](OperatorNodeBase* opr) {
......@@ -86,8 +142,19 @@ void JITFusionPass::Impl::detect_fusion() {
}
});
//! call config_jit_backends as soon as possible
for (auto opr : reverse_adaptor(topo_order)) {
auto&& cn = opr->output(0)->comp_node();
if (cn == CompNode::default_cpu()) {
continue;
}
config_jit_backends(cn);
break;
}
for (auto opr : reverse_adaptor(topo_order)) {
if (can_be_fused(opr)) {
mgb_log_debug("%s: try process : %s", __FUNCTION__, opr->cname());
process_opr(opr);
}
}
......@@ -317,11 +384,11 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
return false;
}
//! As MLIR backend has some contraints
const char* backend = MGB_GETENV("MGB_JIT_BACKEND");
if (!backend) {
backend = "DEFAULT";
}
auto backend = ::std::getenv(ssprintf("%c%cB_JIT_BACKEND", 'M', 'G').c_str());
mgb_assert(
backend,
"code issue happened, need call config_jit_backends before check opr can "
"be fused");
// float elemwise
if (auto elem = gopt::try_cast_as_op<opr::Elemwise>(opr)) {
bool ret = true;
......@@ -361,11 +428,15 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
#undef FOREACH_ELEMWISE_SKIP_MODE
}
#endif // MGB_JIT_MLIR
return ret && ast_c::check_elem_mode(elem->param().mode) &&
return ret &&
ast_c::check_elem_mode(
elem->param().mode, opr->output(0)->comp_node().device_type()) &&
elem->output(0)->dtype().category() == DTypeCategory::FLOAT;
}
if (strcmp(backend, "MLIR")) {
//! TINYOPENCL and MLIR only support elemwise now
if (strcmp(backend, "MLIR") && strcmp(backend, "TINYOPENCL")) {
if (opr->same_type<opr::PowC>()) {
return true;
}
......
......@@ -82,7 +82,8 @@ void gen_input_code(
for (size_t i = 0; i < args.inputs.size(); i++) {
ASTPtr elem_var = ASTPtr::make<VariableAST>("x" + std::to_string(i));
ASTPtr elem_val = gen_data_ast(i, args.inputs[i]);
ASTPtr elem_decl = ASTPtr::make<DeclFloatAST>(elem_var);
ASTPtr elem_decl =
ASTPtr::make<DeclFloatAST>(elem_var, CompNode::DeviceType::CUDA);
ASTPtr elem_assign = ASTPtr::make<AssignAST>(elem_var, elem_val);
var2ast[placeholders[args.inputs[i].idx]->output(0)] = elem_var;
decl_exps_str += elem_decl->code_gen();
......@@ -109,7 +110,7 @@ ASTPtr gen_opr_ast(cg::OperatorNodeBase* opr, const VarNode2AST& var2ast) {
return {cur_inputs[0]};
}
return opr2AST(opr, cur_inputs).at(0);
return opr2AST(opr, cur_inputs, CompNode::DeviceType::CUDA).at(0);
}
} // anonymous namespace
......@@ -145,7 +146,7 @@ struct PEVisitors {
};
template<typename T>
static __forceinline__ __device__ T mgb_log_sum_exp(T x, T y) {
static __forceinline__ __device__ T jit_log_sum_exp(T x, T y) {
T a, b;
a = x < y ? x : y;
b = x < y ? y : x;
......@@ -213,7 +214,8 @@ extern "C" __global__ void {{KERNEL_NAME}} (Data data, size_t num_elements,
}
ASTPtr elem_var = ASTPtr::make<VariableAST>("y" + std::to_string(cur_opr_cnt));
ASTPtr elem_val = gen_opr_ast(opr, var2ast);
ASTPtr elem_decl = ASTPtr::make<DeclFloatAST>(elem_var);
ASTPtr elem_decl =
ASTPtr::make<DeclFloatAST>(elem_var, CompNode::DeviceType::CUDA);
ASTPtr elem_assign = ASTPtr::make<AssignAST>(elem_var, elem_val);
var2ast[opr->output(0)] = elem_var;
internal_decl_exps_str += elem_decl->code_gen();
......
......@@ -11,7 +11,6 @@
#if MGB_JIT && MGB_CUDA
#include <dlfcn.h>
#include <nvrtc.h>
using namespace mgb;
......
#include "./codegen_opencl.h"
#include "./utils.h"
#include "megbrain/common.h"
#include "megbrain/jit/ast_c.h"
#include "megbrain/jit/executor_opr.h"
#include "megbrain/jit/placeholder_opr.h"
#include "megbrain/jit/utils.h"
#include "megbrain/opr/tensor_manip.h"
#include <cinttypes>
#if MGB_JIT && MGB_OPENCL
using namespace mgb;
using namespace jit;
using namespace ast_c;
namespace {
using VarNode2AST = ThinHashMap<VarNode*, ASTPtr>;
//! generate code to access input values in the kernel
void gen_input_code_and_gen_input_data_update(
str_util::StrReplaceMap& replace_map, VarNode2AST& var2ast,
const JITExecutor::Args& args, const PlaceholderArray& placeholders,
bool is_half) {
std::string decl_exps_str, input_data_read_str;
std::string read_image_func = is_half ? "read_imageh" : "read_imagef";
std::string scaler_dec_prefix = is_half ? "__global half* x" : "__global float* x";
auto&& b_info = get_channel_broadcast_info(args);
for (size_t i = 0; i < args.inputs.size(); i++) {
//! gen input args
ASTPtr elem_var_raw =
ASTPtr::make<VariableAST>("x_after_read" + std::to_string(i));
ASTPtr elem_var = ASTPtr::make<VariableAST>(
"__read_only image2d_t x" + std::to_string(i));
ASTPtr elem_var_scalar_offset;
if (LayoutType::SCALAR == b_info[i]) {
elem_var = ASTPtr::make<VariableAST>(scaler_dec_prefix + std::to_string(i));
elem_var_scalar_offset = ASTPtr::make<VariableAST>(
"const uint x_offset" + std::to_string(i));
}
var2ast[placeholders[args.inputs[i].idx]->output(0)] = elem_var_raw;
decl_exps_str += elem_var->code_gen() + ", ";
if (LayoutType::SCALAR == b_info[i]) {
decl_exps_str += elem_var_scalar_offset->code_gen() + ", ";
}
//! gen input data update
ASTPtr elem_var_raw_input = ASTPtr::make<VariableAST>("x" + std::to_string(i));
elem_var_raw = ASTPtr::make<VariableAST>(
(is_half ? "half4 x_after_read" : "float4 x_after_read") +
std::to_string(i));
std::string coord = "coord";
if (LayoutType::BROADCAST == b_info[i]) {
coord = "coord_b";
}
std::string read_method = read_image_func + "(" +
elem_var_raw_input->code_gen() + ", " + coord + ")";
if (LayoutType::SCALAR == b_info[i]) {
if (is_half) {
read_method = "(half4)(vload_half(x_offset" + std::to_string(i) +
", x" + std::to_string(i) + "))";
} else {
read_method = "(float4)(vload(x_offset" + std::to_string(i) + ", x" +
std::to_string(i) + "))";
}
}
ASTPtr elem_assign = ASTPtr::make<AssignAST>(
elem_var_raw, ASTPtr::make<VariableAST>(read_method));
input_data_read_str += elem_assign->code_gen();
}
str_util::append_replace_map(
replace_map, {
{"{{KERNEL_SRC_ARGS}}", decl_exps_str},
{"{{ASSIGN_EXPRS}}", input_data_read_str},
});
}
ASTPtr gen_opr_ast(cg::OperatorNodeBase* opr, const VarNode2AST& var2ast) {
mgb_assert(
!opr->same_type<opr::Reduce>() && !opr->same_type<opr::GetVarShape>() &&
!opr->same_type<opr::Dimshuffle>() && !opr->same_type<opr::PowC>(),
"OpenCL jit not support Reduce/GetVarShape/Dimshuffle/PowC type now");
ASTPtrArray cur_inputs;
for (auto inp_node : opr->input()) {
cur_inputs.push_back(var2ast.at(inp_node));
}
return opr2AST(opr, cur_inputs, CompNode::DeviceType::OPENCL).at(0);
}
} // anonymous namespace
std::pair<std::string, std::string> mgb::jit::codegen_opencl(
const InternalGraph& internal_graph, const JITExecutor::Args& args) {
std::string opencl_kernel = R"(
__kernel void {{KERNEL_NAME}} (
{{KERNEL_SRC_ARGS}}
__write_only image2d_t dst,
__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int wc_size,
__private const int hb_size,
__private const uint w_size
) {
#if OPENCL_ENABLE_FP16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif
const sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP;
int wc = get_global_id(0);
int hb = get_global_id(1);
#ifndef NON_UNIFORM_WORK_GROUP
if (wc >= global_size_dim0 || hb >= global_size_dim1)
return;
#endif
for (; hb < hb_size; hb += global_size_dim1) {
for (; wc < wc_size; wc += global_size_dim0) {
int2 coord = (int2)(wc, hb);
int2 coord_b = (int2)(wc / w_size, 0);
{{INTERNAL_DECL_EXPRS}}
{{ASSIGN_EXPRS}}
{{INTERNAL_ASSIGN_EXPRS}}
{{WRITE_IMAGE}}(dst, coord, {{EXP}});
}
wc = get_global_id(0);
}
}
)";
auto input_dtype = args.inputs[0].layout.dtype;
for (size_t i = 0; i < args.inputs.size(); i++) {
mgb_assert(
args.inputs[i].layout.dtype == input_dtype,
"OpenCL jit all oprs should have same dtype");
}
mgb_assert(
args.outputs.size() == 1 && args.outputs[0].layout.dtype == input_dtype,
"output size should be 1 and output dtype should be same with input");
mgb_assert(
dtype::Float16() == input_dtype || dtype::Float32() == input_dtype,
"OpenCL jit dtype only support float32 or float16, %s not support",
input_dtype.name());
auto is_half = dtype::Float16() == input_dtype;
VarNode2AST var2ast;
str_util::StrReplaceMap source_replace_map;
// add inputs to the replace map
gen_input_code_and_gen_input_data_update(
source_replace_map, var2ast, args, internal_graph.placeholders(), is_half);
// add other oprs
std::string internal_decl_exps_str, internal_assign_exps_str;
std::string write_image_func = is_half ? "write_imageh" : "write_imagef";
size_t cur_opr_cnt = 0;
cg::DepOprIter{[&](cg::OperatorNodeBase* opr) {
++cur_opr_cnt;
if (opr->same_type<JITPlaceholder>()) {
return;
}
ASTPtr elem_var = ASTPtr::make<VariableAST>("y" + std::to_string(cur_opr_cnt));
ASTPtr elem_val = gen_opr_ast(opr, var2ast);
ASTPtr elem_decl = ASTPtr::make<DeclFloatAST>(
elem_var, CompNode::DeviceType::OPENCL, is_half);
ASTPtr elem_assign = ASTPtr::make<AssignAST>(elem_var, elem_val);
var2ast[opr->output(0)] = elem_var;
internal_decl_exps_str += elem_decl->code_gen();
internal_assign_exps_str += elem_assign->code_gen();
}}.add(internal_graph.output());
str_util::append_replace_map(
source_replace_map,
{{"{{INTERNAL_DECL_EXPRS}}", internal_decl_exps_str},
{"{{INTERNAL_ASSIGN_EXPRS}}", internal_assign_exps_str},
{"{{WRITE_IMAGE}}", write_image_func},
{"{{EXP}}", var2ast.at(internal_graph.output())->code_gen()}});
str_util::replace_all_pairs_inplace(opencl_kernel, source_replace_map);
// str_util::replace_all_pairs_inplace(opencl_kernel, source_replace_map);
auto kernel_name = ssprintf(
"jit_opencl_%" PRIx64,
XXHash{}.update(opencl_kernel.data(), opencl_kernel.size()).digest());
str_util::replace_all_pairs_inplace(
opencl_kernel, {{"{{KERNEL_NAME}}", kernel_name}});
return {kernel_name, opencl_kernel};
}
#endif // MGB_JIT && MGB_OPENCL
#pragma once
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_OPENCL
#include "megbrain/jit/executor_opr.h"
namespace mgb {
namespace jit {
/*!
* \brief generate opencl kernel source code
* \return (kernel name, kernel source)
*/
std::pair<std::string, std::string> codegen_opencl(
const InternalGraph& internal_graph, const JITExecutor::Args& args);
} // namespace jit
} // namespace mgb
#endif // MGB_JIT && MGB_OPENCL
#include "megbrain_build_config.h"
#include "megdnn/tensor_format.h"
#if MGB_JIT && MGB_OPENCL
#include "./codegen_opencl.h"
#include "./compiler.h"
#include "./utils.h"
#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/rdnn/management.h"
#include "megbrain/utils/timer.h"
using namespace mgb;
using namespace jit;
/* ==================== OpenCLTinyCompiler ===================== */
OpenCLTinyCompiler::OpenCLTinyCompiler(CompNode::DeviceType device_type) {
m_is_debug = ::std::getenv("OPENCL_JIT_DEBUG") ? true : false;
mgb_assert(
CompNode::DeviceType::OPENCL == device_type,
"error init OpenCLTinyCompiler");
}
std::unique_ptr<Executable> OpenCLTinyCompiler::do_compile(
const InternalGraph& graph, const JITExecutor::Args& args) {
std::string source, kernel_name;
std::tie(kernel_name, source) = codegen_opencl(graph, args);
if (m_is_debug) {
mgb_log_debug("kernel: name: %s\n%s", kernel_name.c_str(), source.c_str());
}
auto ret = std::make_unique<OpenCLExecutable>(
std::move(source), std::move(kernel_name), m_is_debug);
return ret;
}
size_t OpenCLTinyCompiler::get_nr_workspace_outputs(JITExecutor* opr) const {
MGB_MARK_USED_VAR(opr);
return 0;
}
void OpenCLTinyCompiler::init_workspace_size_infer(JITExecutor* opr) {
MGB_MARK_USED_VAR(opr);
}
/* =================== OpenCLExecutable ==================== */
OpenCLExecutable::OpenCLExecutable(std::string source, std::string name, bool is_debug)
: m_source{std::move(source)}, m_name{std::move(name)}, m_is_debug{is_debug} {}
void OpenCLExecutable::execute(JITExecutor* fusion_opr) {
auto&& cn = fusion_opr->comp_node();
auto& env = CompNodeEnv::from_comp_node(cn).opencl_env();
auto handle = mgb::opr::intl::get_megdnn_handle(cn);
auto mgr = env.opencl_mgr;
auto&& ctx = mgr->context();
auto& queue = mgr->command_queue();
auto&& kernel = megdnn::opencl::OpenCLKernel(handle);
auto& args = fusion_opr->args();
static auto&& prop = megcore::opencl::OpenCLProp(mgr->device());
bool is_adreno = prop.is_adreno();
bool is_mali = prop.is_mali();
auto max_work_group = static_cast<uint32_t>(prop.max_work_group_size());
mgb_assert(
prop.is_support_image(),
"code issue happened, OpenCL jit only support device with support image");
//! for debug
MGB_MARK_USED_VAR(ctx);
MGB_MARK_USED_VAR(queue);
size_t WGSX = 0;
size_t WGSY = 0;
//! create cl args
for (size_t i = 0; i < args.inputs.size(); i++) {
if (TensorFormat::Type::IMAGE2D_PACK4 == args.inputs[i].layout.format.type()) {
WGSX = std::max(
WGSX,
args.inputs[i]
.layout.format.as_impl<megdnn::Image2DPack4TensorFormat>()
.image_width(args.inputs[i].layout));
WGSY = std::max(
WGSY,
args.inputs[i]
.layout.format.as_impl<megdnn::Image2DPack4TensorFormat>()
.image_height(args.inputs[i].layout));
}
}
mgb_assert(WGSX > 0 && WGSY > 0, "invalid tensor for OpenCL jit");
if (m_is_debug) {
mgb_log_debug(
"OpenCLExecutable init input tensor array with size: %zu, init output "
"tensor array with size: %zu",
args.inputs.size(), args.outputs.size());
for (size_t i = 0; i < args.inputs.size(); i++) {
mgb_log_debug(
"input(%zu) dim: %zu %s", i, args.inputs[i].layout.ndim,
args.inputs[i].layout.to_string().c_str());
}
for (size_t i = 0; i < args.outputs.size(); i++) {
mgb_log_debug(
"output(%zu) dim: %zu %s", i, args.outputs[i].layout.ndim,
args.outputs[i].layout.to_string().c_str());
}
}
mgb_assert(
args.outputs.size() == 1, "OpenCL elemwise jit output size should be one");
//! create kernel
std::string compile_options;
kernel.set_meta_data({compile_options, m_source});
kernel.set_kernel_name(m_name);
kernel.build_kernel();
//! set tensor args
for (size_t i = 0; i < args.inputs.size(); i++) {
if (TensorFormat::Type::IMAGE2D_PACK4 == args.inputs[i].layout.format.type()) {
kernel.add_tensor_image_args(
{{args.inputs[i].from->dev_tensor().raw_ptr(),
args.inputs[i].layout}});
} else {
//! scalar default format case
kernel.add_tensor_arg(
{args.inputs[i].from->dev_tensor().raw_ptr(),
args.inputs[i].layout});
}
}
kernel.add_tensor_image_args(
{{args.outputs[0].from->dev_tensor().raw_ptr(), args.outputs[0].layout}});
uint32_t block_w = 1, block_h = 1, dimx = 1, dimy = 1;
auto config_super_parameter = [&] {
if (is_adreno) {
block_w = 1;
dimx = 64;
dimy = 1;
} else if (is_mali) {
block_w = 1;
dimx = 96;
dimy = 1;
} else {
//! unknown gpu case
block_w = 1;
dimx = 64;
dimy = 1;
}
//! float16 case
if (dtype::Float16() == args.inputs[0].layout.dtype) {
dimx *= 2;
}
//! scaling dimx less than gws0, dimy less than gws1
dimx = std::min(dimx, static_cast<uint32_t>((WGSX + block_w - 1) / block_w));
dimy = std::min(dimy, static_cast<uint32_t>((WGSY + block_h - 1) / block_h));
//! scaling dimx * dimy less than device max_work_group
dimx = std::min(
dimx, std::max(static_cast<uint32_t>(1), max_work_group / dimy));
};
config_super_parameter();
//! set other args and config lws and gws
int wc_size = WGSX;
int hb_size = WGSY;
WGSX = (WGSX + block_w - 1) / block_w;
WGSY = (WGSY + block_h - 1) / block_h;
int i_WGSX = safe_int<size_t>(WGSX);
int i_WGSY = safe_int<size_t>(WGSY);
kernel.add_args(
{{&i_WGSX, sizeof(int)},
{&i_WGSY, sizeof(int)},
{&wc_size, sizeof(int)},
{&hb_size, sizeof(int)}});
//! have broadcasted_channel_like_input case
int may_w_size = args.outputs[0].layout[3];
kernel.add_arg({&may_w_size, sizeof(cl_uint)});
mgb_log_debug(
"config OpenCL jit kernel args: lws: (%d %d), i_WGSX: %d, i_WGSY: %d "
"wc_size: %d, hb_size: %d, w_size: %d",
dimx, dimy, i_WGSX, i_WGSY, wc_size, hb_size, may_w_size);
kernel.set_local_size({dimx, dimy});
kernel.set_global_size_divup_consider_uniform_gws({WGSX, WGSY});
//! enqueue kernel
kernel.run();
}
#endif // MGB_OPENCL
#pragma once
#include "megbrain_build_config.h"
#if MGB_OPENCL
#include "megbrain/jit/compiler.h"
namespace mgb {
namespace jit {
/*!
* \brief Executable class for OPENCL
*/
class OpenCLExecutable final : public Executable {
public:
OpenCLExecutable(std::string source, std::string name, bool is_debug);
~OpenCLExecutable() = default;
/*!
* \brief execute
* A Executable instance can be executed by one or more fusion_opr
*/
void execute(JITExecutor* fusion_opr) override final;
private:
const std::string m_source;
const std::string m_name;
bool m_is_debug;
};
/*!
* \brief OpenCL tiny compiler, now only handle elemwise opr and just call DNN CL runtime
*/
class OpenCLTinyCompiler final : public Compiler {
std::unique_ptr<Executable> do_compile(
const InternalGraph& graph, const JITExecutor::Args& args) override;
bool m_is_debug;
public:
OpenCLTinyCompiler(CompNode::DeviceType device_type = CompNode::DeviceType::OPENCL);
Property property() const override {
using F = Property::Flag;
return Property{F::BIND_NDIM | F::BIND_SHAPE, JITFeatureBits::NONE, 64};
}
size_t get_nr_workspace_outputs(JITExecutor* opr) const override;
void init_workspace_size_infer(JITExecutor* opr) override;
};
} // namespace jit
} // namespace mgb
#endif // MGB_OPENCL
#include "./utils.h"
#include <vector>
#if MGB_JIT && MGB_OPENCL
std::vector<LayoutType> get_channel_broadcast_info(
const mgb::jit::JITExecutor::Args& args) {
auto output_dim = args.outputs[0].layout.ndim;
auto& out_layout = args.outputs[0].layout;
mgb_assert(
out_layout.ndim == 5,
"code issue happened, OpenCL jit only support image now");
auto n = out_layout[0];
auto c = out_layout[2] * 4;
auto h = out_layout[1];
auto w = out_layout[3];
std::vector<LayoutType> ret;
for (size_t i = 0; i < args.inputs.size(); i++) {
if (args.inputs[i].layout.is_scalar()) {
ret.push_back(LayoutType::SCALAR);
} else {
auto& in_layout = args.inputs[i].layout;
auto in = in_layout[0];
auto ic = in_layout[2] * 4;
auto ih = in_layout[1];
auto iw = in_layout[3];
mgb_assert(
in_layout.ndim == output_dim && in == n && ic == c,
"invalid args for OpenCL jit");
if (ih == h && iw == w) {
ret.push_back(LayoutType::VEC);
} else {
ret.push_back(LayoutType::BROADCAST);
mgb_assert(ih == 1 && iw == 1, "invalid args for OpenCL jit");
}
}
}
return ret;
}
#endif
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_OPENCL
#include "megbrain/jit/compiler.h"
template <typename T, typename S>
T safe_icast(S val) {
static_assert(
std::is_integral<S>::value && std::is_integral<T>::value, "must be int");
mgb_assert(
val <= static_cast<S>(std::numeric_limits<T>::max()) &&
val >= static_cast<S>(0));
return static_cast<T>(val);
}
template <typename S>
int safe_int(S val) {
return safe_icast<int>(val);
}
enum class LayoutType {
SCALAR = 0,
BROADCAST = 1,
VEC = 2,
};
/*!
* \brief get inputs channel broadcast info
* \param args of mgb::jit::JITExecutor::Args
* \return input idx is channel boardcast
*/
std::vector<LayoutType> get_channel_broadcast_info(
const mgb::jit::JITExecutor::Args& args);
#endif
......@@ -42,11 +42,12 @@ public:
inline ASTPtr(int imm);
inline ASTPtr(float imm);
inline ASTPtr(float imm, CompNode::DeviceType cn_type, bool is_half);
};
using ASTPtrArray = SmallVector<ASTPtr>;
//! function type for generating AST nodes
using AstGenerator = thin_function<ASTPtrArray(const ASTPtrArray&)>;
using AstGenerator = thin_function<ASTPtrArray(const ASTPtrArray&, bool is_half)>;
class IntAST : public AST {
public:
......@@ -59,11 +60,20 @@ private:
class FloatAST : public AST {
public:
FloatAST(float val) : m_val(val) {}
inline std::string code_gen() override { return ssprintf("float(%.12e)", m_val); }
FloatAST(
float val, CompNode::DeviceType cn_type = CompNode::DeviceType::CPU,
bool is_half = false)
: m_val(val), m_cn_type(cn_type), m_is_half(is_half) {}
inline std::string code_gen() override {
mgb_assert(!m_is_half, "code issue, only OpenCL support as half now");
return ssprintf("float(%.12e)", m_val);
}
private:
float m_val;
CompNode::DeviceType m_cn_type;
bool m_is_half;
};
class VariableAST : public AST {
......@@ -139,13 +149,20 @@ public:
class DeclFloatAST : public AST {
public:
DeclFloatAST(const ASTPtr& var) : m_var(var) {}
DeclFloatAST(
const ASTPtr& var, CompNode::DeviceType cn_type = CompNode::DeviceType::CPU,
bool is_half = false)
: m_var(var), m_cn_type(cn_type), m_is_half(is_half) {}
inline std::string code_gen() override {
mgb_assert(!m_is_half, "code issue, only OpenCL support as half now");
return "float " + m_var->code_gen() + ";";
}
private:
ASTPtr m_var;
CompNode::DeviceType m_cn_type;
bool m_is_half;
};
class DeclIntAST : public AST {
......@@ -205,23 +222,29 @@ ASTPtr::ASTPtr(int imm) : m_ptr(std::make_shared<IntAST>(imm)) {}
ASTPtr::ASTPtr(float imm) : m_ptr(std::make_shared<FloatAST>(imm)) {}
ASTPtr::ASTPtr(float imm, CompNode::DeviceType cn_type, bool is_half)
: m_ptr(std::make_shared<FloatAST>(imm, cn_type, is_half)) {}
using ElemMode = opr::Elemwise::Mode;
using ElemGeneratorMap = ThinHashMap<ElemMode, AstGenerator>;
//! mapping from elemwise mode to ast node generator
const ElemGeneratorMap& elem_opr_generator();
const ElemGeneratorMap& elem_opr_generator(CompNode::DeviceType type);
static inline bool check_elem_mode(ElemMode mode) {
return elem_opr_generator().count(mode);
static inline bool check_elem_mode(ElemMode mode, CompNode::DeviceType type) {
return elem_opr_generator(type).count(mode);
}
/*!
* \brief Generate a AST node from the opr and the given ast inputs
* \param opr the opr
* \param inputs the AST inputs of the ASTs to be generate
* \param device_type jit backend cn device type
* \return AST nodes corresponding to opr value outputs
*/
ASTPtrArray opr2AST(cg::OperatorNodeBase* opr, const ASTPtrArray& inputs);
ASTPtrArray opr2AST(
cg::OperatorNodeBase* opr, const ASTPtrArray& inputs,
CompNode::DeviceType device_type);
} // namespace ast_c
} // namespace jit
......
......@@ -40,7 +40,7 @@ public:
const InternalGraphPtr& internal_graph, const VarNodeArray& inputs,
const OperatorNodeConfig& config);
static SymbolVar make(
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
const InternalGraphPtr& internal_graph, const VarNodeArray& inputs,
const OperatorNodeConfig& config = {});
......
......@@ -82,10 +82,10 @@ class InternalGraphGenerator {
void find_oprs_depended_by_dimshuffle(cg::OperatorNodeBase* opr);
public:
explicit InternalGraphGenerator(cg::OperatorNodeBase* opr);
MGE_WIN_DECLSPEC_FUC explicit InternalGraphGenerator(cg::OperatorNodeBase* opr);
//! generate the graph; this method can be called multiple times
InternalGraphPtr generate();
MGE_WIN_DECLSPEC_FUC InternalGraphPtr generate();
/*!
* \brief needed input vars in the original (i.e. outer) graph
......@@ -120,7 +120,7 @@ public:
size_t get_cnt_input_if_add(cg::OperatorNodeBase* opr) const;
//! add an operator into this graph; its outputs must have been added
void add_opr(cg::OperatorNodeBase* opr);
MGE_WIN_DECLSPEC_FUC void add_opr(cg::OperatorNodeBase* opr);
//! output var in the outer graph (i.e. the root node)
VarNode* output() const { return m_output; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册