未验证 提交 2a4eb485 编写于 作者: Z Zhaolong Xing 提交者: GitHub

CHERRY_PICK: from #2070 (#2073)

1. the split op's bug will triger memory optimize pass failed.
test=develop
上级 1b1af3d6
......@@ -49,7 +49,9 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
"equal",
"lod_reset",
"concat",
"graph_op"};
"graph_op",
"feed",
"fetch"};
for (auto* tmp : node->inlinks) {
CHECK(tmp->IsStmt());
std::string op_type = tmp->AsStmt().op_info()->Type();
......@@ -76,36 +78,23 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
std::vector<Node*> requires(inputs.begin(), inputs.end());
requires.insert(requires.end(), outputs.begin(), outputs.end());
auto& stmt = op_node->AsStmt();
// The feed and fetch op's inputs and outputs will not be reused.
if (stmt.op_info()->Type() == "feed" ||
stmt.op_info()->Type() == "fetch") {
for (auto* node : op_node->outlinks) {
CHECK(node->IsArg());
std::string var_name = node->AsArg().name;
TargetType target_type = node->AsArg().type->target();
if (is_host(target_type)) target_type = TARGET(kHost);
(*lifecycles)[TargetToStr(target_type)].emplace(
var_name, std::make_pair(0, std::numeric_limits<int>::max()));
}
} else {
for (Node* node : requires) {
CHECK(node->IsArg());
auto& arg = node->AsArg();
if (arg.is_weight || arg.is_persist) continue;
if (!valid_var(node)) continue;
std::string var_name = arg.name;
TargetType target_type = node->AsArg().type->target();
if (is_host(target_type)) target_type = TARGET(kHost);
for (Node* node : requires) {
CHECK(node->IsArg());
auto& arg = node->AsArg();
if (arg.is_weight || arg.is_persist) continue;
if (!valid_var(node)) continue;
std::string var_name = arg.name;
TargetType target_type = node->AsArg().type->target();
if (is_host(target_type)) target_type = TARGET(kHost);
if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) {
(*lifecycles)[TargetToStr(target_type)].emplace(
var_name, std::make_pair(max_lifecycle_, max_lifecycle_));
} else {
int cur_life =
(*lifecycles)[TargetToStr(target_type)][var_name].second;
(*lifecycles)[TargetToStr(target_type)][var_name].second =
std::max(max_lifecycle_, cur_life);
}
if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) {
(*lifecycles)[TargetToStr(target_type)].emplace(
var_name, std::make_pair(max_lifecycle_, max_lifecycle_));
} else {
int cur_life =
(*lifecycles)[TargetToStr(target_type)][var_name].second;
(*lifecycles)[TargetToStr(target_type)][var_name].second =
std::max(max_lifecycle_, cur_life);
}
}
++max_lifecycle_;
......@@ -167,6 +156,7 @@ void MemoryOptimizePass::MakeReusePlan(
void MemoryOptimizePass::PerformReusePlan(
SSAGraph* graph,
const std::unordered_map<std::string, std::string>& reuse_table) {
int node_append_idx = 0;
for (auto& op_node : graph->StmtTopologicalOrder()) {
if (!op_node->IsStmt()) continue;
auto& stmt = op_node->AsStmt();
......@@ -190,7 +180,9 @@ void MemoryOptimizePass::PerformReusePlan(
std::string name = input_node->AsArg().name;
if (reuse_table.count(name) && reuse_table.at(name) != name) {
auto replace_name = reuse_table.at(name);
input_node->AsArg().name = replace_name;
input_node->AsArg().name =
replace_name + "(" + std::to_string(node_append_idx) + ")";
node_append_idx++;
}
}
......@@ -212,7 +204,9 @@ void MemoryOptimizePass::PerformReusePlan(
std::string name = out_node->AsArg().name;
if (reuse_table.count(name) && reuse_table.at(name) != name) {
auto replace_name = reuse_table.at(name);
out_node->AsArg().name = replace_name;
out_node->AsArg().name =
replace_name + "(" + std::to_string(node_append_idx) + ")";
node_append_idx++;
}
}
......
......@@ -69,6 +69,7 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
auto input = opdesc.Input("X").front();
auto outs = opdesc.Output("Out");
param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.output.clear();
for (auto var : outs) {
param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册