提交 8f9bd9b6 编写于 作者: C Channingss

update Reshape&elementwise_map

上级 9fe56c11
......@@ -350,7 +350,6 @@ class ONNXGraph(Graph):
node.out_shapes.append(value_info['shape'])
else:
node.out_shapes.append([])
print(layer.name, node.out_shapes)
class ONNXDecoder(object):
......
......@@ -40,6 +40,21 @@ def _const_weight_or_none(node):
return None
def _is_static_shape(shape):
negtive_dims = 0
error_dims = 0
for dim in shape:
if dim < 0:
negtive_dims += 1
if dim != -1:
error_dims += 1
if negtive_dims > 1:
return False
if error_dims > 0:
return False
return True
def _get_same_padding(in_size, kernel_size, stride):
new_size = int(math.ceil(in_size * 1.0 / stride))
pad_size = (new_size - 1) * stride + kernel_size - in_size
......@@ -230,11 +245,35 @@ class OpSet9():
val_y = self.graph.get_input_node(node, idx=1, copy=True)
val_y_shape = val_y.out_shapes[0]
val_x_shape = val_x.out_shapes[0]
inputs = {}
if len(val_x_shape) < len(val_y_shape):
val_x, val_y = val_y, val_x
val_y_shape, val_x_shape = val_x_shape, val_y_shape
if node.layer_type in ['Mul', 'Add']:
val_x, val_y = val_y, val_x
val_y_shape, val_x_shape = val_x_shape, val_y_shape
inputs = {'x': val_x, 'y': val_y}
elif node.layer_type in ['Sub', 'Div', 'Pow']:
val_x_expand = val_x.layer_name + '_expand'
x_value = _const_weight_or_none(val_x)
if (val_x_shape == [1] or len(val_x_shape) == 0) and x_value:
attr = {
'shape': val_y_shape,
'dtype': string(val_x.dtype),
'value': x_value
if len(val_x_shape) == 0 else x_value[0]
}
node.fluid_code.add_layer(
'fill_constant',
inputs=None,
output=val_x_expand,
param_attr=attr)
val_x_shape = val_y_shape
inputs = {'x': val_x_expand, 'y': val_y}
else:
assert 'Unsupported situation happened.'
else:
inputs = {'x': val_x, 'y': val_y}
print(node.layer_name)
print(val_x_shape, val_y_shape)
str_y_shape = ','.join(str(e) for e in val_y_shape)
str_x_shape = ','.join(str(e) for e in val_x_shape)
slice_idx = 0
......@@ -244,7 +283,6 @@ class OpSet9():
slice_idx += 1
else:
break
attr = {"name": string(node.layer_name)}
if slice_idx < len(val_y_shape) and slice_idx > 0:
val_y_reshaped = val_y_shape[slice_idx:]
var_y_reshaped = val_y.layer_name + '_reshaped'
......@@ -257,13 +295,12 @@ class OpSet9():
inputs=val_y,
output=var_y_reshaped,
param_attr=attr_reshaped)
inputs = {'x': val_x, 'y': var_y_reshaped}
inputs['y'] = var_y_reshaped
node.fluid_code.add_layer(
op_type, inputs=inputs, output=node, param_attr=attr)
op_type, inputs=inputs, output=node, param_attr=None)
else:
inputs = {'x': val_x, 'y': val_y}
node.fluid_code.add_layer(
op_type, inputs=inputs, output=node, param_attr=attr)
op_type, inputs=inputs, output=node, param_attr=None)
@print_mapping_info
def place_holder(self, node):
......@@ -941,7 +978,8 @@ class OpSet9():
inputs={'x': val_x},
output=node,
param_attr={'shape': shape_value.tolist()})
elif len(node.out_shapes[0]) > 0:
elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[
0]):
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册