提交 354b478c 编写于 作者: M Megvii Engine Team

fix(opr): remove constant value infer for const SharedDeviceTensor

GitOrigin-RevId: 8fa47b35adbeb8935038b494f6bb2b55df7dd152
上级 78d7d400
......@@ -963,11 +963,21 @@ void mixin::WeightPreprocessExecutor::record_preprocessed_weight(
bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess(
const cg::OperatorNodeBase& opr) const {
bool param_merged = opr.input(1)
->owner_opr()
->same_type<opr::MultipleDeviceTensorHolder>();
return opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE) &&
(cg::is_const_var_value(opr.input(1)) || param_merged);
if (!opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE))
return false;
if (cg::is_const_var_value(opr.input(1)))
return true;
auto* input_opr = opr.input(1)->owner_opr();
if (input_opr->same_type<opr::MultipleDeviceTensorHolder>() ||
input_opr->same_type<opr::MultipleDeviceTensorWithFormatHolder>())
return true;
auto* sdt = input_opr->try_cast_final<opr::SharedDeviceTensor>();
if (sdt && sdt->const_value())
return true;
auto* sdtf = input_opr->try_cast_final<opr::SharedDeviceTensorWithFormat>();
if (sdtf && sdtf->const_value())
return true;
return false;
}
/* ==================== ConvolutionForward ==================== */
......
......@@ -307,20 +307,6 @@ void intl::SharedDeviceTensorBase::init_output_comp_node() {
comp_node(m_dev_data->comp_node());
}
bool intl::SharedDeviceTensorBase::fill_in_static_infer(DeviceTensorND* dest) {
if (m_const_value) {
if (dest) {
if (m_static_infer.empty()) {
m_static_infer.comp_node(CompNode::default_cpu())
.copy_from(*m_dev_data);
}
*dest = m_static_infer;
}
return true;
}
return false;
}
cg::static_infer::SourceType SharedDeviceTensor::static_infer_src_type() const {
return cg::static_infer::SourceType::CONSTANT;
}
......@@ -886,24 +872,6 @@ void intl::MultipleDeviceTensorHolderBase::init_output_static_infer_desc() {
};
mgr.register_shape_infer(output(i),
{SourceType::CONSTANT, {}, infer_shp});
auto infer_val = [this, i](DeviceTensorND& dest, const InpVal&) {
if (m_host_values.empty()) {
m_host_values.resize(m_values.size());
}
if (m_host_values[i].empty()) {
m_host_values[i]
.comp_node(CompNode::default_cpu())
.copy_from(*m_values[i]);
}
if (!m_host_values[i].empty()) {
dest = m_host_values[i];
return true;
}
return false;
};
mgr.register_value_infer(output(i),
{SourceType::CONSTANT, {}, infer_val});
}
}
......
......@@ -75,12 +75,14 @@ class DeviceTensorHolder: public HostIONodeBase {
*/
MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // {
std::shared_ptr<DeviceTensorND> m_dev_data;
DeviceTensorND m_static_infer;
bool m_const_value;
const TensorShape& get_output_shape() override;
bool fill_in_static_infer(DeviceTensorND* dest) override;
bool fill_in_static_infer(DeviceTensorND* dest) override {
MGB_MARK_USED_VAR(dest);
return false;
}
void init_output_comp_node() override;
......@@ -131,8 +133,6 @@ private:
void init_output_comp_node() override;
void init_output_static_infer_desc() override;
NodeProp* do_make_node_prop() const override;
SmallVector<DeviceTensorND> m_host_values;
};
} // namespace intl
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册