未验证 提交 32e3353f 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat] Fix always copy by paddle.to_tensor from PR #33335(#33590)

上级 78260ff3
......@@ -247,28 +247,27 @@ class PartialProgramLayer(layers.Layer):
flatten_inputs = flatten(inputs)
# Convert variable into VarBase and feed in training data.
input_vars = []
expected_place = framework._current_expected_place()
for i, value in enumerate(flatten_inputs):
if isinstance(value, np.ndarray):
var = core.VarBase(
value=value,
name=self._inputs[i].desc.name(),
persistable=False,
place=framework._current_expected_place(),
place=expected_place,
zero_copy=True)
elif isinstance(value, core.VarBase):
value.name = self._inputs[i].desc.name()
if value.stop_gradient:
# NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times
# into CUDAPlace when it's as input of multi Ops. so we move it in advance
# to avoid this problem.
var = paddle.to_tensor(
value,
dtype=value.dtype,
place=framework._current_expected_place(),
stop_gradient=True)
# NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times
# into CUDAPlace when it's as input of multi Ops. so we move it in advance
# to avoid this problem.
if value.stop_gradient and not value.place._equals(
expected_place):
var = value._copy_to(expected_place, False)
var.stop_gradient = True
var.name = value.name
else:
var = value
var.name = self._inputs[i].desc.name()
else:
continue
input_vars.append(var)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册