From 171264247436bcf0cf3670530ba7255754ddbb28 Mon Sep 17 00:00:00 2001 From: jiachx Date: Sat, 4 Sep 2021 22:05:16 +0800 Subject: [PATCH] =?UTF-8?q?fix=20caffe=E7=9A=84pooling=E5=B1=82=E5=AF=B9?= =?UTF-8?q?=E4=BA=8Ekernel=20size=E5=92=8Cstride=20size=E7=9A=84=E8=BD=AC?= =?UTF-8?q?=E6=8D=A2=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- x2paddle/decoder/caffe_shape_inference.py | 16 ++++----- .../op_mapper/caffe2paddle/caffe_op_mapper.py | 36 ++++++++++--------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/x2paddle/decoder/caffe_shape_inference.py b/x2paddle/decoder/caffe_shape_inference.py index 213a8ff..9c0c48a 100644 --- a/x2paddle/decoder/caffe_shape_inference.py +++ b/x2paddle/decoder/caffe_shape_inference.py @@ -19,25 +19,25 @@ from functools import reduce def get_kernel_parameters(params): [k_h, k_w] = [1, 1] - if isinstance(params.kernel_size, numbers.Number): + if params.kernel_h > 0 or params.kernel_w > 0: + k_h = params.kernel_h + k_w = params.kernel_w + elif isinstance(params.kernel_size, numbers.Number): [k_h, k_w] = [params.kernel_size] * 2 elif len(params.kernel_size) > 0: k_h = params.kernel_h if params.kernel_h > 0 else params.kernel_size[0] k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[ len(params.kernel_size) - 1] - elif params.kernel_h > 0 or params.kernel_w > 0: - k_h = params.kernel_h - k_w = params.kernel_w [s_h, s_w] = [1, 1] - if isinstance(params.stride, numbers.Number): + if params.stride_h > 0 or params.stride_w > 0: + s_h = params.stride_h + s_w = params.stride_w + elif isinstance(params.stride, numbers.Number): [s_h, s_w] = [params.stride] * 2 elif len(params.stride) > 0: s_h = params.stride_h if params.stride_h > 0 else params.stride[0] s_w = params.stride_w if params.stride_w > 0 else params.stride[len( params.stride) - 1] - elif params.stride_h > 0 or params.stride_w > 0: - s_h = params.stride_h - s_w = params.stride_w [p_h, p_w] = [0, 0] if isinstance(params.pad, numbers.Number): [p_h, p_w] = [params.pad] * 2 diff --git a/x2paddle/op_mapper/caffe2paddle/caffe_op_mapper.py b/x2paddle/op_mapper/caffe2paddle/caffe_op_mapper.py index c43f059..e05f60c 100644 --- a/x2paddle/op_mapper/caffe2paddle/caffe_op_mapper.py +++ b/x2paddle/op_mapper/caffe2paddle/caffe_op_mapper.py @@ -72,15 +72,15 @@ def _get_kernel_parameters(kind, params): k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[ len(params.kernel_size) - 1] [s_h, s_w] = [1, 1] - if isinstance(params.stride, numbers.Number): + if params.stride_h > 0 or params.stride_w > 0: + s_h = params.stride_h + s_w = params.stride_w + elif isinstance(params.stride, numbers.Number): [s_h, s_w] = [params.stride] * 2 elif len(params.stride) > 0: s_h = params.stride_h if params.stride_h > 0 else params.stride[0] s_w = params.stride_w if params.stride_w > 0 else params.stride[len( params.stride) - 1] - elif params.stride_h > 0 or params.stride_w > 0: - s_h = params.stride_h - s_w = params.stride_w [p_h, p_w] = [0, 0] if isinstance(params.pad, numbers.Number): [p_h, p_w] = [params.pad] * 2 @@ -168,7 +168,8 @@ class CaffeOpMapper(): return False def directly_map(self, node): - assert len(node.layer.bottom) == 1, 'directly_map error with multi inputs' + assert len( + node.layer.bottom) == 1, 'directly_map error with multi inputs' op_info = self.directly_map_ops[node.layer_type] input = self.graph.get_input_node(node, 0) paddle_op = op_info[0] @@ -193,7 +194,8 @@ class CaffeOpMapper(): outputs=layer_outputs) else: self.paddle_graph.add_layer( - kernel=paddle_op, inputs={"x": input.name}, + kernel=paddle_op, + inputs={"x": input.name}, outputs=[node.name]) def Input(self, node): @@ -203,8 +205,7 @@ class CaffeOpMapper(): outputs=[node.layer_name], data=node.name) shape = list(node.layer.input_param.shape[0].dim)[1:] - self.inputs_info[node.name] = [[-1] + shape, - "float32"] + self.inputs_info[node.name] = [[-1] + shape, "float32"] def MemoryData(self, node): params = node.layer.memory_data_param @@ -619,7 +620,8 @@ class CaffeOpMapper(): num_parameters=num_parameters) def Eltwise(self, node): - if len(node.layer.bottom) == 3 and node.layer.eltwise_param.operation == 1: + if len(node.layer. + bottom) == 3 and node.layer.eltwise_param.operation == 1: inputs_dict = {} input0 = self.graph.get_input_node(node, idx=0, copy=True) input1 = self.graph.get_input_node(node, idx=1, copy=True) @@ -630,18 +632,18 @@ class CaffeOpMapper(): inputs_dict['x'] = input0_name inputs_dict['y'] = input1_name self.paddle_graph.add_layer( - "paddle.add", inputs=inputs_dict, - outputs=[node.layer_name+"_1"]) + "paddle.add", + inputs=inputs_dict, + outputs=[node.layer_name + "_1"]) inputs_dict = {} - inputs_dict['x'] = node.layer_name+"_1" + inputs_dict['x'] = node.layer_name + "_1" inputs_dict['y'] = input2_name self.paddle_graph.add_layer( - "paddle.add", inputs=inputs_dict, - outputs=[node.layer_name]) + "paddle.add", inputs=inputs_dict, outputs=[node.layer_name]) return - assert len( - node.layer.bottom) == 2, "The count of Eltwise node\'s input is not 2." + assert len(node.layer. + bottom) == 2, "The count of Eltwise node\'s input is not 2." params = node.layer.eltwise_param mode = params.operation inputs = [] @@ -925,7 +927,7 @@ class CaffeOpMapper(): offset_real = [0] * len(input_shape) if hasattr(params, "offset") and len(params.offset) > 0: offset_origin = list(params.offset) - if len(offset_origin)==1 : + if len(offset_origin) == 1: offset = offset_origin * (len(input_shape) - axis) assert (len(input_shape) - axis ) == len(offset), "invalid offset[%s] in crop layer" % ( -- GitLab