未验证 提交 b7fac0f9 编写于 作者: H HydrogenSulfate 提交者: GitHub

fix paddle.summary's bug when outputs contains non-tensor (#34160)

* fix paddle.summary's bug when output contains non-tensor
上级 02cc3c5e
......@@ -262,8 +262,10 @@ def summary_string(model, input_size=None, dtypes=None, input=None):
def _get_output_shape(output):
if isinstance(output, (list, tuple)):
output_shape = [_get_output_shape(o) for o in output]
else:
elif hasattr(output, 'shape'):
output_shape = list(output.shape)
else:
output_shape = []
return output_shape
def register_hook(layer):
......
......@@ -68,6 +68,28 @@ class LeNetDygraph(paddle.nn.Layer):
return x
class ModelInner(paddle.nn.Layer):
def __init__(self):
super(ModelInner, self).__init__()
self.fc = paddle.nn.Linear(3, 4)
def forward(self, x):
y = self.fc(x)
return y, 0
class ModelOutter(paddle.nn.Layer):
def __init__(self):
super(ModelOutter, self).__init__()
self.module1 = ModelInner()
self.module2 = paddle.nn.Linear(4, 5)
def forward(self, x):
y, dummpy = self.module1(x)
y = self.module2(y)
return y, 3
class LeNetListInput(LeNetDygraph):
def forward(self, inputs):
x = inputs[0]
......@@ -607,6 +629,9 @@ class TestModelFunction(unittest.TestCase):
model.summary(input_size=[(20)])
model.summary(input_size=(20), dtype='float32')
def test_summary_non_tensor(self):
paddle.summary(ModelOutter(), input_size=(-1, 3))
def test_summary_nlp(self):
def _get_param_from_state_dict(state_dict):
params = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册