提交 354b478c 编写于 作者: M Megvii Engine Team

fix(opr): remove constant value infer for const SharedDeviceTensor

GitOrigin-RevId: 8fa47b35adbeb8935038b494f6bb2b55df7dd152
上级 78d7d400
...@@ -963,11 +963,21 @@ void mixin::WeightPreprocessExecutor::record_preprocessed_weight( ...@@ -963,11 +963,21 @@ void mixin::WeightPreprocessExecutor::record_preprocessed_weight(
bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess( bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess(
const cg::OperatorNodeBase& opr) const { const cg::OperatorNodeBase& opr) const {
bool param_merged = opr.input(1) if (!opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE))
->owner_opr() return false;
->same_type<opr::MultipleDeviceTensorHolder>(); if (cg::is_const_var_value(opr.input(1)))
return opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE) && return true;
(cg::is_const_var_value(opr.input(1)) || param_merged); auto* input_opr = opr.input(1)->owner_opr();
if (input_opr->same_type<opr::MultipleDeviceTensorHolder>() ||
input_opr->same_type<opr::MultipleDeviceTensorWithFormatHolder>())
return true;
auto* sdt = input_opr->try_cast_final<opr::SharedDeviceTensor>();
if (sdt && sdt->const_value())
return true;
auto* sdtf = input_opr->try_cast_final<opr::SharedDeviceTensorWithFormat>();
if (sdtf && sdtf->const_value())
return true;
return false;
} }
/* ==================== ConvolutionForward ==================== */ /* ==================== ConvolutionForward ==================== */
......
...@@ -307,20 +307,6 @@ void intl::SharedDeviceTensorBase::init_output_comp_node() { ...@@ -307,20 +307,6 @@ void intl::SharedDeviceTensorBase::init_output_comp_node() {
comp_node(m_dev_data->comp_node()); comp_node(m_dev_data->comp_node());
} }
bool intl::SharedDeviceTensorBase::fill_in_static_infer(DeviceTensorND* dest) {
if (m_const_value) {
if (dest) {
if (m_static_infer.empty()) {
m_static_infer.comp_node(CompNode::default_cpu())
.copy_from(*m_dev_data);
}
*dest = m_static_infer;
}
return true;
}
return false;
}
cg::static_infer::SourceType SharedDeviceTensor::static_infer_src_type() const { cg::static_infer::SourceType SharedDeviceTensor::static_infer_src_type() const {
return cg::static_infer::SourceType::CONSTANT; return cg::static_infer::SourceType::CONSTANT;
} }
...@@ -886,24 +872,6 @@ void intl::MultipleDeviceTensorHolderBase::init_output_static_infer_desc() { ...@@ -886,24 +872,6 @@ void intl::MultipleDeviceTensorHolderBase::init_output_static_infer_desc() {
}; };
mgr.register_shape_infer(output(i), mgr.register_shape_infer(output(i),
{SourceType::CONSTANT, {}, infer_shp}); {SourceType::CONSTANT, {}, infer_shp});
auto infer_val = [this, i](DeviceTensorND& dest, const InpVal&) {
if (m_host_values.empty()) {
m_host_values.resize(m_values.size());
}
if (m_host_values[i].empty()) {
m_host_values[i]
.comp_node(CompNode::default_cpu())
.copy_from(*m_values[i]);
}
if (!m_host_values[i].empty()) {
dest = m_host_values[i];
return true;
}
return false;
};
mgr.register_value_infer(output(i),
{SourceType::CONSTANT, {}, infer_val});
} }
} }
......
...@@ -75,12 +75,14 @@ class DeviceTensorHolder: public HostIONodeBase { ...@@ -75,12 +75,14 @@ class DeviceTensorHolder: public HostIONodeBase {
*/ */
MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // {
std::shared_ptr<DeviceTensorND> m_dev_data; std::shared_ptr<DeviceTensorND> m_dev_data;
DeviceTensorND m_static_infer;
bool m_const_value; bool m_const_value;
const TensorShape& get_output_shape() override; const TensorShape& get_output_shape() override;
bool fill_in_static_infer(DeviceTensorND* dest) override; bool fill_in_static_infer(DeviceTensorND* dest) override {
MGB_MARK_USED_VAR(dest);
return false;
}
void init_output_comp_node() override; void init_output_comp_node() override;
...@@ -131,8 +133,6 @@ private: ...@@ -131,8 +133,6 @@ private:
void init_output_comp_node() override; void init_output_comp_node() override;
void init_output_static_infer_desc() override; void init_output_static_infer_desc() override;
NodeProp* do_make_node_prop() const override; NodeProp* do_make_node_prop() const override;
SmallVector<DeviceTensorND> m_host_values;
}; };
} // namespace intl } // namespace intl
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册