diff --git a/tools/export_model.py b/tools/export_model.py index 00a1e5803e9fb80f76c5373cda49fe1592a8b658..29a540d68b34ae8641c33163b9703142ccc42d00 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -22,11 +22,15 @@ from paddle.jit import to_static def parse_args(): + def str2bool(v): + return v.lower() in ("true", "t", "1") + parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str) parser.add_argument("-p", "--pretrained_model", type=str) parser.add_argument("-o", "--output_path", type=str) parser.add_argument("--class_dim", type=int, default=1000) + parser.add_argument("--load_static_weights", type=str2bool, default=True) # parser.add_argument("--img_size", type=int, default=224) return parser.parse_args() @@ -57,9 +61,10 @@ def main(): model = Net(net, to_static, args.class_dim) - # Please set 'load_static_weights' to 'True' or 'False' according to the 'pretrained_model' load_dygraph_pretrain( - model, path=args.pretrained_model, load_static_weights=True) + model.pre_net, + path=args.pretrained_model, + load_static_weights=args.load_static_weights) paddle.jit.save(model, args.output_path)