diff --git a/x2paddle/convert.py b/x2paddle/convert.py index cadd9fd9d6a0b246c3efced681031e9596587566..792d8dc7f2ca746902591950ca7d84f84aebcb2c 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -300,7 +300,7 @@ def onnx2paddle(model_path, from x2paddle.decoder.onnx_decoder import ONNXDecoder from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper - model = ONNXDecoder(model_path, input_shape_dict, enable_onnx_checker) + model = ONNXDecoder(model_path, enable_onnx_checker, input_shape_dict) mapper = ONNXOpMapper(model) mapper.paddle_graph.build() logging.info("Model optimizing ...") diff --git a/x2paddle/decoder/onnx_decoder.py b/x2paddle/decoder/onnx_decoder.py index f51be83573ec17e3c6cf11be41c6a600752b9430..7fe4696d49685ce0c5d8a8b8fe7ff5de051edea2 100755 --- a/x2paddle/decoder/onnx_decoder.py +++ b/x2paddle/decoder/onnx_decoder.py @@ -173,7 +173,7 @@ class ONNXGraphDataNode(GraphNode): class ONNXGraph(Graph): - def __init__(self, onnx_model, input_shape_dict): + def __init__(self, onnx_model, input_shape_dict=None): super(ONNXGraph, self).__init__(onnx_model) self.fixed_input_shape = {} if input_shape_dict is not None: @@ -395,7 +395,7 @@ class ONNXGraph(Graph): class ONNXDecoder(object): - def __init__(self, onnx_model, input_shape_dict, enable_onnx_checker): + def __init__(self, onnx_model, enable_onnx_checker, input_shape_dict=None): onnx_model = onnx.load(onnx_model) print('model ir_version: {}, op version: {}'.format( onnx_model.ir_version, onnx_model.opset_import[0].version))