提交 a66d4b8b 编写于 作者: M Megvii Engine Team

fix(mge/parampacksplit): fix param pack split mem forward

GitOrigin-RevId: 8c001b73ffbd086f0cfff7cac2a4c1037bfcecfb
上级 bd3b9cb6
......@@ -308,6 +308,7 @@ class trace:
def _apply_graph_options(self, graph):
graph.options.seq_opt.enable_seq_comp_node_opt = False
# sublinear
if self._sublinear_memory_config is not None:
graph.options.enable_sublinear_memory_opt = True
......
......@@ -1496,6 +1496,14 @@ void ParamPackSplit::init_output_dtype() {
// already initialized in constructor
}
void ParamPackSplit::init_rt_force_dynamic_mem_alloc_imply_chain() {
for (size_t i = 0; i < output().size(); ++i) {
auto s = input(0), t = output(i);
s->add_rt_force_dynamic_mem_alloc_imply_chain(t);
t->add_rt_force_dynamic_mem_alloc_imply_chain(s);
}
}
void ParamPackSplit::mem_plan_fwd_in2out_readonly() {
mgb_assert(m_offsets.size() == output().size() * 2);
for (size_t i = 0; i < output().size(); i++) {
......@@ -1516,16 +1524,19 @@ void ParamPackSplit::init_output_static_infer_desc() {
using namespace std::placeholders;
auto&& mgr = owner_graph()->static_infer_manager();
DepVal shp_deps{{input(0), DepType::SHAPE}};
for (size_t i = 0; i < output().size(); i++) {
auto ov = output(i);
mgr.register_shape_infer(
ov, {SourceType::DEP, shp_deps,
ov, {SourceType::CONSTANT, {},
std::bind(&ParamPackSplit::infer_shape, this, i, _1, _2)});
}
}
void ParamPackSplit::scn_do_execute() {
int inp_size = input(0)->shape().total_nr_elems();
mgb_assert(inp_size == m_offsets.back(), "input shape should match offsets");
}
#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ParamPackSplit) {
mgb_assert(out_grad.size() == opr.output().size());
......
......@@ -591,7 +591,7 @@ MGB_DEFINE_OPR_CLASS(ParamPackSplit, cg::SingleCNOperatorNodeBase) // {
TensorShapeArray m_shapes;
std::vector<dt_int32> m_offsets;
void scn_do_execute() override{};
void scn_do_execute() override;
void init_output_static_infer_desc() override;
bool infer_shape(size_t index, TensorShape &dest,
const cg::static_infer::InpVal &inp);
......@@ -615,6 +615,8 @@ public:
const TensorShapeArray& get_output_shapes() const {
return m_shapes;
}
void init_rt_force_dynamic_mem_alloc_imply_chain() override;
};
/*!
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册