提交 5ac48b77 编写于 作者: G gaotingquan

fix export_model to support dygraph

上级 fa471435
...@@ -22,11 +22,15 @@ from paddle.jit import to_static ...@@ -22,11 +22,15 @@ from paddle.jit import to_static
def parse_args(): def parse_args():
def str2bool(v):
return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str) parser.add_argument("-m", "--model", type=str)
parser.add_argument("-p", "--pretrained_model", type=str) parser.add_argument("-p", "--pretrained_model", type=str)
parser.add_argument("-o", "--output_path", type=str) parser.add_argument("-o", "--output_path", type=str)
parser.add_argument("--class_dim", type=int, default=1000) parser.add_argument("--class_dim", type=int, default=1000)
parser.add_argument("--load_static_weights", type=str2bool, default=True)
# parser.add_argument("--img_size", type=int, default=224) # parser.add_argument("--img_size", type=int, default=224)
return parser.parse_args() return parser.parse_args()
...@@ -57,9 +61,10 @@ def main(): ...@@ -57,9 +61,10 @@ def main():
model = Net(net, to_static, args.class_dim) model = Net(net, to_static, args.class_dim)
# Please set 'load_static_weights' to 'True' or 'False' according to the 'pretrained_model'
load_dygraph_pretrain( load_dygraph_pretrain(
model, path=args.pretrained_model, load_static_weights=True) model.pre_net,
path=args.pretrained_model,
load_static_weights=args.load_static_weights)
paddle.jit.save(model, args.output_path) paddle.jit.save(model, args.output_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册