未验证 提交 bf7fd504 编写于 作者: Y yeliang2258 提交者: GitHub

fix split op in onnx (#645)

* fix expand op in onnx

* remove useless info

* fix split and add GatherND

* fix

* test revert

* update

* update

* reverse

* reverse expand
上级 b229cbd0
......@@ -793,6 +793,14 @@ class OpSet9():
self.paddle_graph.add_layer(
'paddle.multiply', inputs=inputs_dict, outputs=[node.name])
@print_mapping_info
def GatherND(self, node):
x = self.graph.get_input_node(node, idx=0, copy=True)
index = self.graph.get_input_node(node, idx=1, copy=True)
inputs = {'x': x.name, 'index': index.name}
self.paddle_graph.add_layer(
"paddle.gather_nd", inputs=inputs, outputs=[node.name])
@print_mapping_info
def Gather(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
......@@ -1345,28 +1353,50 @@ class OpSet9():
if split is None:
split = len(node.outputs)
axis = node.get_attr('axis', 0)
layer_attrs = {
'num_or_sections': split,
'axis': axis,
}
outputs_list = list()
if isinstance(split, list) or isinstance(split, tuple):
if len(split) == 1:
outputs_list.append(node.name)
else:
for i in range(len(split)):
if split is None:
split_num = len(node.layer.output)
layer_attrs = {
'num_or_sections': split_num,
'axis': axis,
}
outputs_list = list()
for i in range(len(node.layer.output)):
if hasattr(node, 'index'):
outputs_list.append("{}_p{}".format(node.layer_name, i))
else:
outputs_list.append("{}".format(node.layer_name))
if split_num > 1:
self.paddle_graph.add_layer(
'paddle.split',
inputs={"x": val_x.name},
outputs=outputs_list,
**layer_attrs)
else:
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": val_x.name},
outputs=outputs_list,
dtype=string(val_x.dtype))
else:
if len(node.outputs) == 1:
outputs_list.append(node.name)
layer_attrs = {
'num_or_sections': split,
'axis': axis,
}
outputs_list = list()
if isinstance(split, list) or isinstance(split, tuple):
if len(split) == 1:
outputs_list.append(node.name)
else:
for i in range(len(split)):
outputs_list.append("{}_p{}".format(node.layer_name, i))
else:
for i in range(len(node.outputs)):
outputs_list.append("{}_p{}".format(node.layer_name, i))
self.paddle_graph.add_layer(
'paddle.split',
inputs={"x": val_x.name},
outputs=outputs_list,
**layer_attrs)
outputs_list.append(node.name)
self.paddle_graph.add_layer(
'paddle.split',
inputs={"x": val_x.name},
outputs=outputs_list,
**layer_attrs)
@print_mapping_info
def Reshape(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册