diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index 360c9270782083169e4d4dc5ffdb29116f5bd893..7fc8eff3d31c9e8cceabb7f29f63373711f498d2 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -193,6 +193,8 @@ std::unique_ptr CinnCompiler::CompileGraph( CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors}; auto frontend_program = symbol(); ProgramPass::Apply(&frontend_program, target, {"Decomposer"}); + auto fetch_ids = symbol.GetFetchIds(); + ::cinn::frontend::ApplyPass(&frontend_program, fetch_ids, "RemoveIdentity"); auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>( frontend_program, target); VLOG(1) << "-- The " << compiled_num << "-th compilation (" @@ -201,7 +203,6 @@ std::unique_ptr CinnCompiler::CompileGraph( ApplyPass(cinn_graph.get(), "OpFusion"); auto scope = BuildScope(target, cinn_graph); - auto fetch_ids = symbol.GetFetchIds(); VLOG(4) << "All fetch var ids in CINN: " << string::join_strings(fetch_ids, ',');