提交 fa471435 编写于 作者: G gaotingquan

fix export_model to support dygraph

上级 c5cf3c15
...@@ -38,7 +38,7 @@ class Net(paddle.nn.Layer): ...@@ -38,7 +38,7 @@ class Net(paddle.nn.Layer):
self.pre_net = net(class_dim=class_dim) self.pre_net = net(class_dim=class_dim)
self.to_static = to_static self.to_static = to_static
# 请根据实际需求修改shape # Please modify the 'shape' according to actual needs
@to_static(input_spec=[ @to_static(input_spec=[
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype='float32') shape=[None, 3, 224, 224], dtype='float32')
...@@ -56,8 +56,10 @@ def main(): ...@@ -56,8 +56,10 @@ def main():
net = architectures.__dict__[args.model] net = architectures.__dict__[args.model]
model = Net(net, to_static, args.class_dim) model = Net(net, to_static, args.class_dim)
para_state_dict = paddle.io.load_program_state(args.pretrained_model)
load_dygraph_pretrain(model, args.pretrained_model, True) # Please set 'load_static_weights' to 'True' or 'False' according to the 'pretrained_model'
load_dygraph_pretrain(
model, path=args.pretrained_model, load_static_weights=True)
paddle.jit.save(model, args.output_path) paddle.jit.save(model, args.output_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册