未验证 提交 4b1d28fc 编写于 作者: X xiongkun 提交者: GitHub

【SetItem】Fix setitem for function call. (#56810)

* fix error

* fix setitem

* fix bgs

* fix
上级 1b8619c7
......@@ -152,8 +152,8 @@ class NameloadJstTransformer(BaseTransformer):
"""
Can't convert name of function call, bacause this will affect CallTransformer.
"""
node.args = [self.generic_visit(arg) for arg in node.args]
node.func = self.generic_visit(node.func)
node.args = [self.visit(arg) for arg in node.args]
node.func = self.visit(node.func)
return node
def visit_Attribute(self, node):
......
......@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
class TestSetItemBase(unittest.TestCase):
......@@ -231,5 +233,31 @@ class TestCase14(TestSetItemBase):
return (y,)
class TestCase15(TestSetItemBase):
# Test gradient of value tensor
def init_func(self):
def foo(x, H, W):
B, _, _, C = x.shape
pad_list = paddle.zeros([4], dtype="int32")
pad_list[3] = H // 2
pad_list[1] = W // 2
# 问题在这里,进去F.pad以后,pad_list是初始变量而非赋值后的变量
# 在修改前,赋值前后的变量是同一个,没有问题
# 修改后,期望接收赋值后的变量,接收赋值前变量结果是不对的
x = F.pad(x, pad_list, data_format="NHWC")
return x
return foo
def run_dygraph(self, func):
# 注释这句看结果diff
x = paddle.ones((1, 6, 6, 3))
H = paddle.full([1], 6, dtype='int32')
W = paddle.full([1], 6, dtype='int32')
y = func(x, H, W)
return (y,)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册