From 4ffd33958edacf9cac8695c8073ff42aa59b2350 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Wed, 21 Jul 2021 10:13:48 +0800 Subject: [PATCH] [Cherry-pick][Dy2Stat]Support Nest sequtial container (#34246) #34262 * support Nest sequtial container * rename model path --- .../dygraph_to_static/convert_call_func.py | 2 +- .../dygraph_to_static/test_container.py | 38 +++++++++++++++---- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index a621f68c654..b62c16989fb 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -88,7 +88,7 @@ def is_unsupported(func): for v in m.__dict__.values(): func_in_dict = func == v if isinstance(func_in_dict, (list, numpy.ndarray)): - func_in_dict = any(func_in_dict) + func_in_dict = numpy.array(func_in_dict).any() if func_in_dict: translator_logger.log( 2, diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_container.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_container.py index 647c9e9672c..2c82f5c6990 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_container.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_container.py @@ -47,10 +47,30 @@ class SequentialNet(paddle.nn.Layer): return out +class NestSequentialNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + group1 = paddle.nn.Sequential( + paddle.nn.Linear(10, 10), + paddle.nn.Sigmoid(), ) + group2 = paddle.nn.Sequential( + paddle.nn.Linear(10, 3), + paddle.nn.ReLU(), ) + self.layers = paddle.nn.Sequential(group1, group2) + + def forward(self, x): + return self.layers(x) + + class TestSequential(unittest.TestCase): def setUp(self): paddle.set_device('cpu') self.seed = 2021 + self._init_config() + + def _init_config(self): + self.net = SequentialNet(BufferLayers, 10, 3) + self.model_path = './sequential_net' def _init_seed(self): paddle.seed(self.seed) @@ -58,13 +78,12 @@ class TestSequential(unittest.TestCase): def _run(self, to_static): self._init_seed() - net = SequentialNet(BufferLayers, 10, 3) if to_static: - net = paddle.jit.to_static(net) + self.net = paddle.jit.to_static(self.net) x = paddle.rand([16, 10], 'float32') - out = net(x) + out = self.net(x) if to_static: - load_out = self._test_load(net, x) + load_out = self._test_load(self.net, x) self.assertTrue( np.allclose(load_out, out), msg='load_out is {}\st_out is {}'.format(load_out, out)) @@ -80,12 +99,17 @@ class TestSequential(unittest.TestCase): msg='dygraph_res is {}\nstatic_res is {}'.format(dy_out, st_out)) def _test_load(self, net, x): - model_path = './sequential_net' - paddle.jit.save(net, model_path) - load_net = paddle.jit.load(model_path) + paddle.jit.save(net, self.model_path) + load_net = paddle.jit.load(self.model_path) out = load_net(x) return out +class TestNestSequential(TestSequential): + def _init_config(self): + self.net = NestSequentialNet() + self.model_path = './nested_sequential_net' + + if __name__ == '__main__': unittest.main() -- GitLab