提交 cdb692d2 编写于 作者: M Megvii Engine Team

refactor(imperative): add TODO tag for some functions

GitOrigin-RevId: e295a1fa5537f13bc65f9e82b44a3f9cd56992a6
上级 90dd0716
......@@ -7,13 +7,20 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .._imperative_rt.ops._custom import _install, _uninstall, _get_custom_op_list, _make_custom_op
from .._imperative_rt.ops._custom import (
_get_custom_op_list,
_install,
_make_custom_op,
_uninstall,
)
__all__ = ["load"]
def _gen_custom_op_maker(custom_op_name):
    """Return a factory bound to ``custom_op_name``.

    The returned callable accepts only keyword arguments and forwards them,
    collected into a single dict, to the native ``_make_custom_op``
    constructor for the named custom op.
    """

    def _maker(**attrs):
        # ``attrs`` is passed positionally as one dict, matching the
        # native binding's (name, params) signature.
        return _make_custom_op(custom_op_name, attrs)

    return _maker
......
......@@ -95,6 +95,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs
for (auto i_shape: i_shapes) {
if (i_shape.ndim == 0) {
success = false;
break;
}
}
......@@ -187,14 +188,11 @@ void apply_on_device_tensornd(const OpDef& def,
auto cn = output.comp_node();
cn.activate();
}
// [TODO] sync should be modified
CompNode::sync_all();
auto&& op = static_cast<const CustomOpDef&>(def);
op.compute(inputs, outputs);
// for (auto &&output: (*outputs)) {
// auto cn = output.comp_node();
// cn.sync(); // cannot sync ??????????
// }
CompNode::sync_all();
}
......@@ -224,19 +222,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
}
VarNodeArray apply_on_var_node(const OpDef &def, const cg::VarNodeArray &inputs) {
    // Lower a CustomOpDef application onto the computing graph by
    // instantiating the underlying CustomOpNode directly on the input
    // VarNodeArray. The previous SymbolVar round-trip (wrapping every
    // input in a SymbolVar and unwrapping every output again) was
    // redundant, so CustomOpNode::make is called on var-nodes directly.
    //
    // NOTE(review): this span contained interleaved pre-/post-image diff
    // residue (duplicate, unterminated statements); reconstructed here as
    // the post-commit version indicated by the hunk header.
    auto&& op = static_cast<const CustomOpDef&>(def);
    OperatorNodeConfig config;
    VarNodeArray outputs = opr::CustomOpNode::make(
            op.impl(), inputs, op.param(), config);
    return outputs;
}
......@@ -273,6 +263,7 @@ bool is_same_st(const OpDef& lhs, const OpDef& rhs) {
return a.param() == b.param() && a.runtime_id() == b.runtime_id();
}
// [TODO] to be implemented
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
mgb_assert(false, "Custom OpDef Props Function is not IMPLEMENTED now");
// can be implement with param schema
......
......@@ -140,7 +140,8 @@ void CustomOpNode::do_execute(ExecEnv &env) {
std::vector<custom::Tensor> custom_inputs = custom::to_custom<DeviceTensorND, custom::Tensor>(inputs);
std::vector<custom::Tensor> custom_outputs = custom::to_custom<DeviceTensorND, custom::Tensor>(outputs);
m_op->compute(custom_inputs, m_param, custom_outputs);
CompNode::sync_all(); // whether reasonable
// [TODO] sync should be modified
CompNode::sync_all();
this->owner_graph()->event().signal_inplace<cg::event::AfterKernel>(
this, m_comp_node
......@@ -157,7 +158,8 @@ void CustomOpNode::init_output_static_infer_desc() {
auto &&mgr = owner_graph()->static_infer_manager();
DepVal dep;
if (true) { // need design a function to allow user to decide it
    // [TODO] need to design an interface that allows the user to decide this
if (true) {
for (auto input_var: input())
dep.push_back({input_var, DepType::SHAPE});
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册