提交 1c090461 编写于 作者: M Megvii Engine Team

fix(mge/dtr): filter bad ops

GitOrigin-RevId: 380262dca0ea11bfa28fca50f4643963bd9ad7d9
上级 300b048a
......@@ -763,7 +763,14 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
break;
}
}
if (!is_inplace && !cross_cn) {
// FIXME: do not use opname as identifier
auto get_name = [](const OpDef& opdef) {
if (auto attr = opdef.try_cast_final<OprAttr>()) {
return attr->type.c_str();
}
return opdef.dyn_typeinfo()->name;
};
if (!is_inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs);
size_t detach_cnt = 0;
for (auto output : cmd.outputs) {
......
......@@ -308,6 +308,13 @@ private:
//! whether the warning message has been printed
bool warn_printed = false;
bool is_bad_op(std::string op_name) {
return std::find(op_blacklist.begin(), op_blacklist.end(), op_name) != op_blacklist.end();
}
std::vector<std::string> op_blacklist = {"CollectiveComm", "InplaceAdd",
"ParamPackSplit", "ParamPackConcat", "GaussianRNG"};
} m_dtr;
//! automatically evict an optimal tensor
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册