提交 17126424 编写于 作者: J jiachx

fix caffe的pooling层对于kernel size和stride size的转换错误

上级 5098d192
...@@ -19,25 +19,25 @@ from functools import reduce ...@@ -19,25 +19,25 @@ from functools import reduce
def get_kernel_parameters(params): def get_kernel_parameters(params):
[k_h, k_w] = [1, 1] [k_h, k_w] = [1, 1]
if isinstance(params.kernel_size, numbers.Number): if params.kernel_h > 0 or params.kernel_w > 0:
k_h = params.kernel_h
k_w = params.kernel_w
elif isinstance(params.kernel_size, numbers.Number):
[k_h, k_w] = [params.kernel_size] * 2 [k_h, k_w] = [params.kernel_size] * 2
elif len(params.kernel_size) > 0: elif len(params.kernel_size) > 0:
k_h = params.kernel_h if params.kernel_h > 0 else params.kernel_size[0] k_h = params.kernel_h if params.kernel_h > 0 else params.kernel_size[0]
k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[ k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[
len(params.kernel_size) - 1] len(params.kernel_size) - 1]
elif params.kernel_h > 0 or params.kernel_w > 0:
k_h = params.kernel_h
k_w = params.kernel_w
[s_h, s_w] = [1, 1] [s_h, s_w] = [1, 1]
if isinstance(params.stride, numbers.Number): if params.stride_h > 0 or params.stride_w > 0:
s_h = params.stride_h
s_w = params.stride_w
elif isinstance(params.stride, numbers.Number):
[s_h, s_w] = [params.stride] * 2 [s_h, s_w] = [params.stride] * 2
elif len(params.stride) > 0: elif len(params.stride) > 0:
s_h = params.stride_h if params.stride_h > 0 else params.stride[0] s_h = params.stride_h if params.stride_h > 0 else params.stride[0]
s_w = params.stride_w if params.stride_w > 0 else params.stride[len( s_w = params.stride_w if params.stride_w > 0 else params.stride[len(
params.stride) - 1] params.stride) - 1]
elif params.stride_h > 0 or params.stride_w > 0:
s_h = params.stride_h
s_w = params.stride_w
[p_h, p_w] = [0, 0] [p_h, p_w] = [0, 0]
if isinstance(params.pad, numbers.Number): if isinstance(params.pad, numbers.Number):
[p_h, p_w] = [params.pad] * 2 [p_h, p_w] = [params.pad] * 2
......
...@@ -72,15 +72,15 @@ def _get_kernel_parameters(kind, params): ...@@ -72,15 +72,15 @@ def _get_kernel_parameters(kind, params):
k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[ k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[
len(params.kernel_size) - 1] len(params.kernel_size) - 1]
[s_h, s_w] = [1, 1] [s_h, s_w] = [1, 1]
if isinstance(params.stride, numbers.Number): if params.stride_h > 0 or params.stride_w > 0:
s_h = params.stride_h
s_w = params.stride_w
elif isinstance(params.stride, numbers.Number):
[s_h, s_w] = [params.stride] * 2 [s_h, s_w] = [params.stride] * 2
elif len(params.stride) > 0: elif len(params.stride) > 0:
s_h = params.stride_h if params.stride_h > 0 else params.stride[0] s_h = params.stride_h if params.stride_h > 0 else params.stride[0]
s_w = params.stride_w if params.stride_w > 0 else params.stride[len( s_w = params.stride_w if params.stride_w > 0 else params.stride[len(
params.stride) - 1] params.stride) - 1]
elif params.stride_h > 0 or params.stride_w > 0:
s_h = params.stride_h
s_w = params.stride_w
[p_h, p_w] = [0, 0] [p_h, p_w] = [0, 0]
if isinstance(params.pad, numbers.Number): if isinstance(params.pad, numbers.Number):
[p_h, p_w] = [params.pad] * 2 [p_h, p_w] = [params.pad] * 2
...@@ -168,7 +168,8 @@ class CaffeOpMapper(): ...@@ -168,7 +168,8 @@ class CaffeOpMapper():
return False return False
def directly_map(self, node): def directly_map(self, node):
assert len(node.layer.bottom) == 1, 'directly_map error with multi inputs' assert len(
node.layer.bottom) == 1, 'directly_map error with multi inputs'
op_info = self.directly_map_ops[node.layer_type] op_info = self.directly_map_ops[node.layer_type]
input = self.graph.get_input_node(node, 0) input = self.graph.get_input_node(node, 0)
paddle_op = op_info[0] paddle_op = op_info[0]
...@@ -193,7 +194,8 @@ class CaffeOpMapper(): ...@@ -193,7 +194,8 @@ class CaffeOpMapper():
outputs=layer_outputs) outputs=layer_outputs)
else: else:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel=paddle_op, inputs={"x": input.name}, kernel=paddle_op,
inputs={"x": input.name},
outputs=[node.name]) outputs=[node.name])
def Input(self, node): def Input(self, node):
...@@ -203,8 +205,7 @@ class CaffeOpMapper(): ...@@ -203,8 +205,7 @@ class CaffeOpMapper():
outputs=[node.layer_name], outputs=[node.layer_name],
data=node.name) data=node.name)
shape = list(node.layer.input_param.shape[0].dim)[1:] shape = list(node.layer.input_param.shape[0].dim)[1:]
self.inputs_info[node.name] = [[-1] + shape, self.inputs_info[node.name] = [[-1] + shape, "float32"]
"float32"]
def MemoryData(self, node): def MemoryData(self, node):
params = node.layer.memory_data_param params = node.layer.memory_data_param
...@@ -619,7 +620,8 @@ class CaffeOpMapper(): ...@@ -619,7 +620,8 @@ class CaffeOpMapper():
num_parameters=num_parameters) num_parameters=num_parameters)
def Eltwise(self, node): def Eltwise(self, node):
if len(node.layer.bottom) == 3 and node.layer.eltwise_param.operation == 1: if len(node.layer.
bottom) == 3 and node.layer.eltwise_param.operation == 1:
inputs_dict = {} inputs_dict = {}
input0 = self.graph.get_input_node(node, idx=0, copy=True) input0 = self.graph.get_input_node(node, idx=0, copy=True)
input1 = self.graph.get_input_node(node, idx=1, copy=True) input1 = self.graph.get_input_node(node, idx=1, copy=True)
...@@ -630,18 +632,18 @@ class CaffeOpMapper(): ...@@ -630,18 +632,18 @@ class CaffeOpMapper():
inputs_dict['x'] = input0_name inputs_dict['x'] = input0_name
inputs_dict['y'] = input1_name inputs_dict['y'] = input1_name
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.add", inputs=inputs_dict, "paddle.add",
outputs=[node.layer_name+"_1"]) inputs=inputs_dict,
outputs=[node.layer_name + "_1"])
inputs_dict = {} inputs_dict = {}
inputs_dict['x'] = node.layer_name+"_1" inputs_dict['x'] = node.layer_name + "_1"
inputs_dict['y'] = input2_name inputs_dict['y'] = input2_name
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.add", inputs=inputs_dict, "paddle.add", inputs=inputs_dict, outputs=[node.layer_name])
outputs=[node.layer_name])
return return
assert len( assert len(node.layer.
node.layer.bottom) == 2, "The count of Eltwise node\'s input is not 2." bottom) == 2, "The count of Eltwise node\'s input is not 2."
params = node.layer.eltwise_param params = node.layer.eltwise_param
mode = params.operation mode = params.operation
inputs = [] inputs = []
...@@ -925,7 +927,7 @@ class CaffeOpMapper(): ...@@ -925,7 +927,7 @@ class CaffeOpMapper():
offset_real = [0] * len(input_shape) offset_real = [0] * len(input_shape)
if hasattr(params, "offset") and len(params.offset) > 0: if hasattr(params, "offset") and len(params.offset) > 0:
offset_origin = list(params.offset) offset_origin = list(params.offset)
if len(offset_origin)==1 : if len(offset_origin) == 1:
offset = offset_origin * (len(input_shape) - axis) offset = offset_origin * (len(input_shape) - axis)
assert (len(input_shape) - axis assert (len(input_shape) - axis
) == len(offset), "invalid offset[%s] in crop layer" % ( ) == len(offset), "invalid offset[%s] in crop layer" % (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册