From bf7fd504372396a8265d1a8ae1981a1832433821 Mon Sep 17 00:00:00 2001 From: yeliang2258 <30516196+yeliang2258@users.noreply.github.com> Date: Mon, 18 Oct 2021 10:43:21 +0800 Subject: [PATCH] fix split op in onnx (#645) * fix expand op in onnx * remove useless info * fix split and add GatherND * fix * test revert * update * update * reverse * reverse expand --- .../op_mapper/onnx2paddle/opset9/opset.py | 68 +++++++++++++------ 1 file changed, 49 insertions(+), 19 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index e984e10..9ffe233 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -793,6 +793,14 @@ class OpSet9(): self.paddle_graph.add_layer( 'paddle.multiply', inputs=inputs_dict, outputs=[node.name]) + @print_mapping_info + def GatherND(self, node): + x = self.graph.get_input_node(node, idx=0, copy=True) + index = self.graph.get_input_node(node, idx=1, copy=True) + inputs = {'x': x.name, 'index': index.name} + self.paddle_graph.add_layer( + "paddle.gather_nd", inputs=inputs, outputs=[node.name]) + @print_mapping_info def Gather(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) @@ -1345,28 +1353,50 @@ class OpSet9(): if split is None: split = len(node.outputs) axis = node.get_attr('axis', 0) - layer_attrs = { - 'num_or_sections': split, - 'axis': axis, - } - outputs_list = list() - if isinstance(split, list) or isinstance(split, tuple): - if len(split) == 1: - outputs_list.append(node.name) - else: - for i in range(len(split)): + if split is None: + split_num = len(node.layer.output) + layer_attrs = { + 'num_or_sections': split_num, + 'axis': axis, + } + outputs_list = list() + for i in range(len(node.layer.output)): + if hasattr(node, 'index'): outputs_list.append("{}_p{}".format(node.layer_name, i)) + else: + outputs_list.append("{}".format(node.layer_name)) + if split_num > 1: + self.paddle_graph.add_layer( + 'paddle.split', + inputs={"x": val_x.name}, + outputs=outputs_list, + **layer_attrs) + else: + self.paddle_graph.add_layer( + "paddle.cast", + inputs={"x": val_x.name}, + outputs=outputs_list, + dtype=string(val_x.dtype)) + else: - if len(node.outputs) == 1: - outputs_list.append(node.name) + layer_attrs = { + 'num_or_sections': split, + 'axis': axis, + } + outputs_list = list() + if isinstance(split, list) or isinstance(split, tuple): + if len(split) == 1: + outputs_list.append(node.name) + else: + for i in range(len(split)): + outputs_list.append("{}_p{}".format(node.layer_name, i)) else: - for i in range(len(node.outputs)): - outputs_list.append("{}_p{}".format(node.layer_name, i)) - self.paddle_graph.add_layer( - 'paddle.split', - inputs={"x": val_x.name}, - outputs=outputs_list, - **layer_attrs) + outputs_list.append(node.name) + self.paddle_graph.add_layer( + 'paddle.split', + inputs={"x": val_x.name}, + outputs=outputs_list, + **layer_attrs) @print_mapping_info def Reshape(self, node): -- GitLab