未验证 提交 88f1a98e 编写于 作者: W wanghuancoder 提交者: GitHub

[IR] fix getfaketensorlist bug (#57063)

* fix getfaketensorlist bug

* refine
上级 a2d61455
......@@ -359,12 +359,12 @@ std::vector<std::shared_ptr<phi::TensorBase>> GetFakeTensorList(
} else if (input_type.isa<ir::VectorType>()) {
auto vec_inner_types = input_type.dyn_cast<ir::VectorType>().data();
for (size_t i = 0; i < vec_inner_types.size(); ++i) {
if (vec_inner_types[0].isa<dialect::AllocatedDenseTensorType>()) {
if (vec_inner_types[i].isa<dialect::AllocatedDenseTensorType>()) {
vec_res.push_back(build_fake_dense_tensor(
vec_inner_types[0].dyn_cast<dialect::AllocatedDenseTensorType>()));
} else if (vec_inner_types[0].isa<dialect::AllocatedSelectedRowsType>()) {
vec_inner_types[i].dyn_cast<dialect::AllocatedDenseTensorType>()));
} else if (vec_inner_types[i].isa<dialect::AllocatedSelectedRowsType>()) {
vec_res.push_back(build_fake_selected_rows(
vec_inner_types[0].dyn_cast<dialect::AllocatedSelectedRowsType>()));
vec_inner_types[i].dyn_cast<dialect::AllocatedSelectedRowsType>()));
}
}
}
......
......@@ -635,7 +635,9 @@ class PartialProgramLayer:
filter(_need_aggregation, self._outputs.tolist())
)
for _var in to_processed_vars:
_insert_aggregation_ops_for_var(target_program, _var)
target_program: paddle.static.Program
target_var = target_program.global_block().var(_var.name)
_insert_aggregation_ops_for_var(target_program, target_var)
@switch_to_static_graph
def _append_backward_desc(self, main_program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册