提交 39ba3021 编写于 作者: M Megvii Engine Team

feat(src): add adaptor between custom op and graph runtime

GitOrigin-RevId: 834c0ac72bd47559ed4873ea3800c64aa732e2cf
上级 c9134717
/**
* \file src/opr/impl/custom_opnode.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/opr/custom_opnode.h"
namespace mgb {
namespace opr {
// Expands the framework macro that supplies the dynamic-type implementation
// for CustomOpNode (see MGB_DYN_TYPE_OBJ_FINAL_IMPL for what is generated).
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpNode);
/*!
 * \brief Infer the comp node of every output from the comp nodes of the inputs.
 *
 * The input comp nodes are converted to custom::Device, handed to the custom
 * op's infer_output_device() together with the op param, and the answer is
 * converted back. All outputs must end up on the same comp node; that comp
 * node is remembered in m_comp_node.
 */
void CustomOpNode::infer_output_comp_node(void) {
    SmallVector<CompNode> input_comp_nodes(input_num());
    for (int i = 0; i < input_num(); ++i) {
        input_comp_nodes[i] = input(i)->comp_node();
    }

    SmallVector<CompNode> output_comp_nodes =
            custom::to_builtin<CompNode, custom::Device>(
                    m_op->infer_output_device(
                            custom::to_custom<CompNode, custom::Device>(
                                    input_comp_nodes),
                            m_param));

    // The list is produced by the custom op's inference code: make sure it is
    // long enough before indexing it. The non-empty check also protects the
    // unconditional [0] access below when the op declares no output.
    mgb_assert(!output_comp_nodes.empty() &&
                       output_comp_nodes.size() >=
                               static_cast<size_t>(output_num()),
               "custom op device inference returned too few results");

    for (int i = 0; i < output_num(); ++i) {
        mgb_assert(output_comp_nodes[i] == output_comp_nodes[0],
                   "only single comp node operator is supported");
        output(i)->comp_node(output_comp_nodes[i]);
    }
    m_comp_node = output_comp_nodes[0];
}
void CustomOpNode::infer_output_dtype(void) {
SmallVector<DType> input_dtypes(input_num());
for (int i=0; i<input_num(); ++i) {
input_dtypes[i] = input(i)->dtype();
}
SmallVector<DType> output_dtypes = custom::to_builtin<megdnn::DType, custom::DType>(
m_op->infer_output_dtype(
custom::to_custom<megdnn::DType, custom::DType>(input_dtypes), m_param
)
);
for (int i=0; i<output_num(); ++i) {
output(i)->dtype(output_dtypes[i]);
}
}
void CustomOpNode::infer_output_format(void) {
SmallVector<TensorFormat> input_formats(input_num());
for (int i=0; i<input_num(); ++i) {
input_formats[i] = input(i)->format();
}
SmallVector<TensorFormat> output_formats = custom::to_builtin<TensorFormat, custom::Format>(
m_op->infer_output_format(
custom::to_custom<TensorFormat, custom::Format>(input_formats), m_param
)
);
for (int i=0; i<output_num(); ++i) {
output(i)->format(output_formats[i]);
}
}
void CustomOpNode::infer_output_shape(void) {
SmallVector<TensorShape> input_shapes(input_num());
for (int i=0; i<input_num(); ++i) {
input_shapes[i] = input(i)->shape();
}
SmallVector<TensorShape> output_shapes = custom::to_builtin<TensorShape, custom::Shape>(
m_op->infer_output_shape(
custom::to_custom<TensorShape, custom::Shape>(input_shapes), m_param
)
);
for (int i=0; i<output_num(); ++i) {
output(i)->shape(output_shapes[i]);
}
}
/*!
 * \brief Compute all output shapes for the given input shapes.
 * \param input_shapes shapes of the inputs, in input order
 * \param[out] output_shapes overwritten with the shapes returned by the
 *      custom op's infer_output_shape()
 */
void CustomOpNode::infer_output_shape(const TensorShapeArray &input_shapes,
                                      TensorShapeArray &output_shapes) {
    // builtin shapes -> custom shapes -> user inference -> builtin shapes
    auto custom_inputs =
            custom::to_custom<TensorShape, custom::Shape>(input_shapes);
    auto custom_outputs = m_op->infer_output_shape(custom_inputs, m_param);
    output_shapes =
            custom::to_builtin<TensorShape, custom::Shape>(custom_outputs);
}
// called by computing_graph for each output varnode
bool CustomOpNode::infer_desc(size_t out_idx, TensorShape &output_shape,
const StaticInferInpVal &input_vals) {
TensorShapeArray input_shapes(input_vals.val.size());
TensorShapeArray output_shapes(output_num());
for (size_t i = 0; i < input_shapes.size(); ++ i) {
input_shapes[i] = input_vals.val[i].shape();
}
infer_output_shape(input_shapes, output_shapes);
output_shape = output_shapes.at(out_idx);
return true;
}
// Graph hook: forwards to infer_output_dtype(), which sets the dtype of
// every output var.
void CustomOpNode::init_output_dtype() {
infer_output_dtype();
}
// Graph hook: forwards to infer_output_format(), which sets the tensor format
// of every output var.
void CustomOpNode::init_output_format() {
infer_output_format();
}
// Graph hook: forwards to infer_output_comp_node(), which sets the comp node
// of every output var and records it in m_comp_node.
void CustomOpNode::init_output_comp_node() {
infer_output_comp_node();
}
/*!
 * \brief Dispatch the custom op's compute function on m_comp_node.
 *
 * The dispatched task signals BeforeKernel, activates the comp node, wraps the
 * input/output device tensors as custom::Tensor, calls the custom op's
 * compute(), synchronizes, and finally signals AfterKernel.
 */
void CustomOpNode::do_execute(ExecEnv &env) {
    auto kernel = [this]() {
        this->owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(
                this, m_comp_node);
        m_comp_node.activate();

        // collect the device tensors backing the input and output vars
        SmallVector<DeviceTensorND> in_tensors;
        SmallVector<DeviceTensorND> out_tensors;
        for (size_t idx = 0; idx < input_num(); ++idx) {
            in_tensors.push_back(input(idx)->dev_tensor());
        }
        for (size_t idx = 0; idx < output_num(); ++idx) {
            out_tensors.push_back(output(idx)->dev_tensor());
        }

        // expose them through the custom-op tensor type and run the kernel
        std::vector<custom::Tensor> custom_in =
                custom::to_custom<DeviceTensorND, custom::Tensor>(in_tensors);
        std::vector<custom::Tensor> custom_out =
                custom::to_custom<DeviceTensorND, custom::Tensor>(out_tensors);
        m_op->compute(custom_in, m_param, custom_out);

        // NOTE: the original author questioned whether a global sync is
        // reasonable here; it is kept as is.
        CompNode::sync_all();

        this->owner_graph()->event().signal_inplace<cg::event::AfterKernel>(
                this, m_comp_node);
    };
    env.dispatch_on_comp_node(m_comp_node, kernel);
}
void CustomOpNode::init_output_static_infer_desc() {
using namespace std::placeholders;
using namespace cg::static_infer;
m_out_shape.resize(output_num());
auto &&mgr = owner_graph()->static_infer_manager();
DepVal dep;
if (true) { // need design a function to allow user to decide it
for (auto input_var: input())
dep.push_back({input_var, DepType::SHAPE});
}
else {
for (auto input_var: input())
dep.push_back({input_var, DepType::VALUE});
}
for (size_t i = 0; i < output_num(); ++ i) {
mgr.register_shape_infer(output(i), {
dep.empty() ? SourceType::CONSTANT : SourceType::DEP, dep,
std::bind(&CustomOpNode::infer_desc, this, i, _1, _2)
});
}
}
void CustomOpNode::init_output_mem_plan(bool dynamic) {
for (auto output_var: output()) {
if (cg::is_static_var_storage(output_var) == !dynamic
&& !output_var->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC))
output_var->init_mem_plan();
}
}
// Intentionally a no-op: this operator sets up no force-dynamic-mem-alloc
// imply chain (marked as TODO in the class declaration).
void CustomOpNode::init_rt_force_dynamic_mem_alloc_imply_chain() {
}
void CustomOpNode::add_input_layout_constraint() {
for (auto &&input_var: input()) {
input_var->add_layout_constraint_contiguous();
}
}
// Intentionally a no-op: no readonly input-to-output memory forwarding is
// set up for custom ops.
void CustomOpNode::mem_plan_fwd_in2out_readonly() {
}
// Intentionally a no-op: no writable input-to-output memory forwarding is
// set up for custom ops.
void CustomOpNode::mem_plan_fwd_in2out_writable() {
}
// Returns a value-initialized OprEventCallback, i.e. no callbacks are
// installed by this operator.
cg::OperatorNodeBase::OprEventCallback CustomOpNode::get_opr_event_callback() {
return {};
}
void CustomOpNode::on_output_comp_node_stream_changed() {
for (auto output_var: output()) {
if (output_var->comp_node() != m_comp_node) {
mgb_assert(output_var->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
output_var->comp_node(m_comp_node);
}
}
}
// Builds the node property by delegating to the base class implementation.
// The commented-out lines below are a disabled draft that would have added a
// DEV_VALUE dependency for every input and the SINGLE_COMP_NODE flag.
cg::OperatorNodeBase::NodeProp* CustomOpNode::do_make_node_prop() const {
// auto ret = &const_cast<OperatorNodeBase::NodeProp&>(node_prop());
// for (auto &&inp_var: input())
// ret->add_dep_type(inp_var, NodeProp::DepType::DEV_VALUE);
// ret->add_flag(NodeProp::Flag::SINGLE_COMP_NODE);
// return ret;
return OperatorNodeBase::do_make_node_prop();
}
bool CustomOpNode::update_priority() const {
if (output_num() == 1
&& output()[0]->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE)) {
node_prop().attribute().priority
= std::numeric_limits<decltype(NodeProp::Attribute::priority)>::min();
return true;
}
return false;
}
// Construct a graph operator that wraps the custom op `op`.
//   op     : custom op description; its op_type() becomes the operator's
//            default name
//   inputs : input vars; inputs[0] provides the owner graph
//   param  : parameter values, copied into m_param
//   config : operator node config forwarded to the base class
//
// NOTE(review): inputs[0] and op are dereferenced in the base-class
// initializer, i.e. before the length assertion in the body can run; an empty
// `inputs` or a null `op` is presumably undefined behaviour -- confirm that
// all callers guarantee both.
CustomOpNode::CustomOpNode(const std::shared_ptr<const custom::CustomOp> &op,
VarNodeArray inputs,
const custom::Param &param,
const OperatorNodeConfig &config):
OperatorNodeBase(inputs[0]->owner_graph(), config, op->op_type(), inputs), m_op(op), m_param(param) {
// the op declares how many inputs it takes; the caller must supply exactly
// that many
mgb_assert(input_num() == inputs.size(), "wrong input tensors list length");
for (int i=0; i < input_num(); ++i)
add_input({inputs[i]});
// one output var per declared output, named after the op's output info
for (int i=0; i<output_num(); ++i)
add_output(output_info(i).name());
// std::is_empty<...>::value is a compile-time constant
if (!std::is_empty<custom::Param>::value) {
// Equivalence key: the op's runtime id followed by every (name, value
// string) pair of the param, padded with spaces to a multiple of
// sizeof(step) and registered in step-sized chunks.
using step = unsigned long;
size_t STEP_SIZE = sizeof(step);
std::string hash_str = std::to_string(op->runtime_id());
for (auto &&val: param.raw()) {
hash_str += val.first;
hash_str += val.second.str();
}
if (hash_str.size() % STEP_SIZE != 0)
hash_str += std::string(STEP_SIZE - (hash_str.size() % STEP_SIZE), ' ');
// NOTE(review): hash_str is a local. If PODHash keeps the pointer instead of
// reading the value right away, it dangles once the constructor returns;
// the cast also reads the char buffer as unsigned long -- TODO confirm.
for (size_t pos=0; pos <hash_str.size(); pos+=STEP_SIZE)
add_equivalence_component<PODHash<step>>(reinterpret_cast<const step*>(hash_str.c_str()+pos));
}
}
/*!
 * \brief Create a CustomOpNode and insert it into the graph owning inputs[0].
 * \param op custom op description, must not be null
 * \param inputs input vars, must not be empty
 * \param param parameter values of the op
 * \param config operator node config
 * \return the output vars of the inserted operator
 */
VarNodeArray CustomOpNode::make(const std::shared_ptr<const custom::CustomOp> &op,
                                VarNodeArray inputs,
                                const custom::Param &param,
                                const OperatorNodeConfig &config) {
    // Both are dereferenced below and in the constructor; fail with a clear
    // message instead of running into undefined behaviour.
    mgb_assert(op != nullptr, "custom op must not be null");
    mgb_assert(!inputs.empty(),
               "custom op requires at least one input to find the owner graph");

    auto &&outputs = inputs[0]->owner_graph()->insert_opr(
            std::make_unique<CustomOpNode>(op, inputs, param, config))->output();
    return outputs;
}
/*!
 * \brief SymbolVar flavour of make(): create a CustomOpNode and insert it into
 *      the graph owning inputs[0].
 * \param op custom op description, must not be null
 * \param inputs input symbol vars, must not be empty
 * \param param parameter values of the op
 * \param config operator node config
 * \return the output vars of the inserted operator, as symbol vars
 */
SymbolVarArray CustomOpNode::make(const std::shared_ptr<const custom::CustomOp> &op,
                                  SymbolVarArray inputs,
                                  const custom::Param &param,
                                  const OperatorNodeConfig &config) {
    // Both are dereferenced below and in the constructor; fail with a clear
    // message instead of running into undefined behaviour.
    mgb_assert(op != nullptr, "custom op must not be null");
    mgb_assert(!inputs.empty(),
               "custom op requires at least one input to find the owner graph");

    VarNodeArray input_vars(inputs.size());
    for (size_t i = 0; i < input_vars.size(); ++i)
        input_vars[i] = inputs[i].node();

    auto &&outputs = inputs[0].node()->owner_graph()->insert_opr(
            std::make_unique<CustomOpNode>(op, input_vars, param, config))->output();

    SymbolVarArray ret(outputs.size());
    for (size_t i = 0; i < ret.size(); ++i)
        ret[i] = outputs[i];
    return ret;
}
// Runtime id of the wrapped custom op.
custom::RunTimeId CustomOpNode::runtime_id() const {
return m_op->runtime_id();
}
// Tag of the wrapped op's param info; the serializer dumps it and checks it
// again on load.
uint32_t CustomOpNode::param_tag(void) const {
return m_op->param_info().tag();
}
// Mutable access to the parameter values stored in this node.
custom::Param& CustomOpNode::param(void) {
return m_param;
}
// Read access to the parameter values; returns a copy.
custom::Param CustomOpNode::param(void) const {
return m_param;
}
// The functions below share their names with the custom op implementation
// and simply forward to the wrapped op.
// Type name of the wrapped op.
std::string CustomOpNode::op_type(void) const {
return m_op->op_type();
}
// Description of the wrapped op.
std::string CustomOpNode::op_desc(void) const {
return m_op->op_desc();
}
// Number of inputs declared by the wrapped op.
int CustomOpNode::input_num(void) const {
return m_op->input_num();
}
// Number of outputs declared by the wrapped op.
int CustomOpNode::output_num(void) const {
return m_op->output_num();
}
// Info of the idx-th input as declared by the wrapped op.
custom::ArgInfo CustomOpNode::input_info(size_t idx) const {
return m_op->input_info(idx);
}
// Info of the idx-th output as declared by the wrapped op.
custom::ArgInfo CustomOpNode::output_info(size_t idx) const {
return m_op->output_info(idx);
}
}
}
/**
* \file src/opr/impl/custom_opnode.sereg.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/opr/custom_opnode.h"
#include "megbrain/serialization/sereg.h"
namespace mgb {
namespace serialization {
/*!
 * \brief Serialize a CustomOpNode.
 *
 * Three length-prefixed buffers are written, in this order: the op type name,
 * the 32-bit param tag, and the param bytes. custom_loader() reads them back
 * in the same order.
 */
void custom_dumper(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
    const auto& node = opr.cast_final_safe<opr::CustomOpNode>();

    const std::string type_name = node.op_type();
    ctx.dump_buf_with_len(type_name.c_str(), type_name.size());

    const uint32_t param_tag = node.param_tag();
    ctx.dump_buf_with_len(&param_tag, sizeof(param_tag));

    const std::string param_bytes = node.param().to_bytes();
    ctx.dump_buf_with_len(param_bytes.c_str(), param_bytes.size());
}
/*!
 * \brief Deserialize a CustomOpNode written by custom_dumper().
 *
 * Reads the op type name, looks the op up in the CustomOpManager, checks the
 * dumped param tag against the tag of the registered op, restores the param
 * from its bytes and builds the operator.
 *
 * \param ctx load context positioned at the dumped operator
 * \param inputs input vars of the operator
 * \param config operator node config
 * \return the owner operator of the first output var
 */
mgb::cg::OperatorNodeBase *custom_loader(OprLoadContext& ctx,
                                         const cg::VarNodeArray& inputs,
                                         const OperatorNodeConfig& config) {
    std::string op_type = ctx.load_buf_with_len();
    auto *op_manager = custom::CustomOpManager::inst();
    auto op = op_manager->find(op_type);
    mgb_assert(op != nullptr, "custom op %s is not registered",
               op_type.c_str());

    // The tag buffer comes from the model file: validate its length, then copy
    // the bytes out instead of dereferencing a casted char pointer (which
    // ignores alignment and aliasing rules).
    std::string tag_str = ctx.load_buf_with_len();
    mgb_assert(tag_str.size() == sizeof(uint32_t),
               "Wrong Param TAG size of Op %s, should be %zu, but load %zu\n",
               op_type.c_str(), sizeof(uint32_t), tag_str.size());
    uint32_t tag = 0;
    tag_str.copy(reinterpret_cast<char*>(&tag), sizeof(tag));
    mgb_assert(
            tag == op->param_info().tag(),
            "Wrong Param TAG of Op %s, should be %u, but load %u\n",
            op_type.c_str(), op->param_info().tag(), tag
    );

    custom::Param param(op->param_info());
    std::string bytes = ctx.load_buf_with_len();
    param.from_bytes(bytes);

    // the operator is reached through its first output var, so there has to
    // be one
    auto outputs = opr::CustomOpNode::make(op, inputs, param, config);
    mgb_assert(!outputs.empty(), "custom op %s has no output",
               op_type.c_str());
    return outputs[0]->owner_opr();
}
}
}
// CUSTOM_OP_SEREG_REG(cls): defines, in an anonymous namespace, a struct
// _OprReg<cls> whose static entry() registers custom_dumper / custom_loader
// for operator class `cls` through MGB_SEREG_OPR_INTL_CALL_ADD, and then
// passes that struct to MGB_SEREG_OPR_INTL_CALL_ENTRY.
// `cls` is token-pasted into an identifier, so it must be an unqualified
// name; hence the alias for opr::CustomOpNode below.
#define CUSTOM_OP_SEREG_REG(cls) \
namespace { \
struct _OprReg##cls { \
static void entry() { \
MGB_SEREG_OPR_INTL_CALL_ADD( \
cls, \
::mgb::serialization::custom_dumper, \
::mgb::serialization::custom_loader); \
} \
}; \
} \
MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls)
// NOTE(review): this using-directive is at file scope of a file that is
// #included elsewhere, so it presumably affects the including translation
// unit as well -- confirm this is intended.
using namespace mgb;
using CustomOpNode = opr::CustomOpNode;
// register the serializer pair for CustomOpNode
CUSTOM_OP_SEREG_REG(CustomOpNode);
/**
* \file src/opr/include/megbrain/opr/custom_opnode.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/custom/custom.h"
#include "megbrain/custom/manager.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/graph/symbol_var.h"
#include "megbrain/graph/helper.h"
#include "megbrain/graph/event.h"
#include "megbrain/serialization/sereg.h"
namespace mgb {
namespace opr {
// Short names for the computing-graph types used by CustomOpNode; they are
// declared in namespace mgb::opr.
using VarNode = cg::VarNode;
using VarNodeArray = cg::VarNodeArray;
using SymbolVar = cg::SymbolVar;
using SymbolVarArray = cg::SymbolVarArray;
using StaticInferInpVal = cg::StaticInferInpVal;
using OperatorNodeConfig = cg::OperatorNodeConfig;
// Graph operator that adapts a custom::CustomOp to the graph runtime: the
// inference hooks and the execution hook forward to the wrapped op.
MGB_DEFINE_OPR_CLASS(CustomOpNode, cg::OperatorNodeBase) // {
// the wrapped custom op (shared, immutable)
const std::shared_ptr<const custom::CustomOp> m_op;
// parameter values of this node
custom::Param m_param;
// comp node shared by all outputs; set by infer_output_comp_node() and used
// to dispatch the kernel
CompNode m_comp_node;
// resized to output_num() in init_output_static_infer_desc()
TensorShapeArray m_out_shape;
// inference helpers: query the wrapped op and write the result to the
// output vars
void infer_output_comp_node(void);
void infer_output_dtype(void);
void infer_output_format(void);
void infer_output_shape(void);
// shape inference on explicit shapes; output_shapes is overwritten
void infer_output_shape(const TensorShapeArray &input_shapes, TensorShapeArray &output_shapes);
// called by computing_graph for each output varnode
bool infer_desc(size_t out_idx, TensorShape &output_shape, const StaticInferInpVal &input_vals);
void init_output_dtype() override final;
void init_output_format() override final;
void init_output_comp_node() override final;
void do_execute(ExecEnv &env) override final;
void init_output_static_infer_desc() override final;
void init_output_mem_plan(bool dynamic) override final;
// TODO: a dynamic mem alloc flag in the op implementation, if any, is
// ignored for now
void init_rt_force_dynamic_mem_alloc_imply_chain() override final;
// TODO: only contiguous inputs are supported
void add_input_layout_constraint() override final;
// TODO: not implemented for now
void mem_plan_fwd_in2out_readonly() override final;
// TODO: not implemented for now
void mem_plan_fwd_in2out_writable() override final;
// TODO: returns a default-constructed object
OprEventCallback get_opr_event_callback() override final;
// TODO
void on_output_comp_node_stream_changed() override final;
// TODO
NodeProp* do_make_node_prop() const override final;
// TODO: default implementation
bool update_priority() const override final;
public:
// construct from a custom op, its input vars, its param and a node config
CustomOpNode(const std::shared_ptr<const custom::CustomOp> &op,
VarNodeArray inputs, const custom::Param &param,
const OperatorNodeConfig &config);
// create the operator, insert it into the graph owning inputs[0] and return
// its output vars
static VarNodeArray make(const std::shared_ptr<const custom::CustomOp> &op,
VarNodeArray inputs, const custom::Param &param,
const OperatorNodeConfig &config);
// same as above for symbol vars
static SymbolVarArray make(const std::shared_ptr<const custom::CustomOp> &op,
SymbolVarArray inputs, const custom::Param &param,
const OperatorNodeConfig &config);
// accessors forwarding to the wrapped op / the stored param
custom::RunTimeId runtime_id(void) const;
uint32_t param_tag(void) const;
custom::Param& param(void);
custom::Param param(void) const;
std::string op_type(void) const;
std::string op_desc(void) const;
int input_num(void) const;
int output_num(void) const;
custom::ArgInfo input_info(size_t idx) const;
custom::ArgInfo output_info(size_t idx) const;
};
} // namespace opr
}
...@@ -29,6 +29,7 @@ namespace mgb{void call_sereg(){}} ...@@ -29,6 +29,7 @@ namespace mgb{void call_sereg(){}}
#include "../../opr/impl/tensor_gen.sereg.h" #include "../../opr/impl/tensor_gen.sereg.h"
#include "../../opr/impl/tensor_manip.sereg.h" #include "../../opr/impl/tensor_manip.sereg.h"
#include "../../opr/impl/utility.sereg.h" #include "../../opr/impl/utility.sereg.h"
#include "../../opr/impl/custom_opnode.sereg.h"
#if MGB_ENABLE_TENSOR_RT #if MGB_ENABLE_TENSOR_RT
#include "../../tensorrt/impl/tensorrt_opr.sereg.h" #include "../../tensorrt/impl/tensorrt_opr.sereg.h"
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册