提交 17126424 编写于 作者: J jiachx

fix caffe的pooling层对于kernel size和stride size的转换错误

上级 5098d192
......@@ -19,25 +19,25 @@ from functools import reduce
def get_kernel_parameters(params):
[k_h, k_w] = [1, 1]
if isinstance(params.kernel_size, numbers.Number):
if params.kernel_h > 0 or params.kernel_w > 0:
k_h = params.kernel_h
k_w = params.kernel_w
elif isinstance(params.kernel_size, numbers.Number):
[k_h, k_w] = [params.kernel_size] * 2
elif len(params.kernel_size) > 0:
k_h = params.kernel_h if params.kernel_h > 0 else params.kernel_size[0]
k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[
len(params.kernel_size) - 1]
elif params.kernel_h > 0 or params.kernel_w > 0:
k_h = params.kernel_h
k_w = params.kernel_w
[s_h, s_w] = [1, 1]
if isinstance(params.stride, numbers.Number):
if params.stride_h > 0 or params.stride_w > 0:
s_h = params.stride_h
s_w = params.stride_w
elif isinstance(params.stride, numbers.Number):
[s_h, s_w] = [params.stride] * 2
elif len(params.stride) > 0:
s_h = params.stride_h if params.stride_h > 0 else params.stride[0]
s_w = params.stride_w if params.stride_w > 0 else params.stride[len(
params.stride) - 1]
elif params.stride_h > 0 or params.stride_w > 0:
s_h = params.stride_h
s_w = params.stride_w
[p_h, p_w] = [0, 0]
if isinstance(params.pad, numbers.Number):
[p_h, p_w] = [params.pad] * 2
......
......@@ -72,15 +72,15 @@ def _get_kernel_parameters(kind, params):
k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[
len(params.kernel_size) - 1]
[s_h, s_w] = [1, 1]
if isinstance(params.stride, numbers.Number):
if params.stride_h > 0 or params.stride_w > 0:
s_h = params.stride_h
s_w = params.stride_w
elif isinstance(params.stride, numbers.Number):
[s_h, s_w] = [params.stride] * 2
elif len(params.stride) > 0:
s_h = params.stride_h if params.stride_h > 0 else params.stride[0]
s_w = params.stride_w if params.stride_w > 0 else params.stride[len(
params.stride) - 1]
elif params.stride_h > 0 or params.stride_w > 0:
s_h = params.stride_h
s_w = params.stride_w
[p_h, p_w] = [0, 0]
if isinstance(params.pad, numbers.Number):
[p_h, p_w] = [params.pad] * 2
......@@ -168,7 +168,8 @@ class CaffeOpMapper():
return False
def directly_map(self, node):
assert len(node.layer.bottom) == 1, 'directly_map error with multi inputs'
assert len(
node.layer.bottom) == 1, 'directly_map error with multi inputs'
op_info = self.directly_map_ops[node.layer_type]
input = self.graph.get_input_node(node, 0)
paddle_op = op_info[0]
......@@ -193,7 +194,8 @@ class CaffeOpMapper():
outputs=layer_outputs)
else:
self.paddle_graph.add_layer(
kernel=paddle_op, inputs={"x": input.name},
kernel=paddle_op,
inputs={"x": input.name},
outputs=[node.name])
def Input(self, node):
......@@ -203,8 +205,7 @@ class CaffeOpMapper():
outputs=[node.layer_name],
data=node.name)
shape = list(node.layer.input_param.shape[0].dim)[1:]
self.inputs_info[node.name] = [[-1] + shape,
"float32"]
self.inputs_info[node.name] = [[-1] + shape, "float32"]
def MemoryData(self, node):
params = node.layer.memory_data_param
......@@ -619,7 +620,8 @@ class CaffeOpMapper():
num_parameters=num_parameters)
def Eltwise(self, node):
if len(node.layer.bottom) == 3 and node.layer.eltwise_param.operation == 1:
if len(node.layer.
bottom) == 3 and node.layer.eltwise_param.operation == 1:
inputs_dict = {}
input0 = self.graph.get_input_node(node, idx=0, copy=True)
input1 = self.graph.get_input_node(node, idx=1, copy=True)
......@@ -630,18 +632,18 @@ class CaffeOpMapper():
inputs_dict['x'] = input0_name
inputs_dict['y'] = input1_name
self.paddle_graph.add_layer(
"paddle.add", inputs=inputs_dict,
outputs=[node.layer_name+"_1"])
"paddle.add",
inputs=inputs_dict,
outputs=[node.layer_name + "_1"])
inputs_dict = {}
inputs_dict['x'] = node.layer_name+"_1"
inputs_dict['x'] = node.layer_name + "_1"
inputs_dict['y'] = input2_name
self.paddle_graph.add_layer(
"paddle.add", inputs=inputs_dict,
outputs=[node.layer_name])
"paddle.add", inputs=inputs_dict, outputs=[node.layer_name])
return
assert len(
node.layer.bottom) == 2, "The count of Eltwise node\'s input is not 2."
assert len(node.layer.
bottom) == 2, "The count of Eltwise node\'s input is not 2."
params = node.layer.eltwise_param
mode = params.operation
inputs = []
......@@ -925,7 +927,7 @@ class CaffeOpMapper():
offset_real = [0] * len(input_shape)
if hasattr(params, "offset") and len(params.offset) > 0:
offset_origin = list(params.offset)
if len(offset_origin)==1 :
if len(offset_origin) == 1:
offset = offset_origin * (len(input_shape) - axis)
assert (len(input_shape) - axis
) == len(offset), "invalid offset[%s] in crop layer" % (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册