#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.decoder.tf_decoder import TFGraph
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
import math
import inspect
import numpy
import sys


# Compute the [pad_before, pad_after] padding for TF "SAME" mode along one dimension.
def get_same_padding(in_size, kernel_size, stride):
    new_size = int(math.ceil(in_size * 1.0 / stride))
    pad_size = (new_size - 1) * stride + kernel_size - in_size
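    # Worked example (illustrative): in_size=7, kernel_size=3, stride=2 gives
    # new_size = ceil(7 / 2) = 4 and pad_size = (4 - 1) * 2 + 3 - 7 = 2,
    # so the function returns [1, 1].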
    if pad_size < 0:
        pad_size = 0
    pad0 = int(pad_size / 2)
    pad1 = pad_size - pad0
    return [pad0, pad1]


class TFOpMapperNHWC(OpMapper):
    directly_map_ops = {
        'Relu': ['relu'],
        'Relu6': ['relu6'],
        'Shape': ['shape'],
        'Abs': ['abs'],
        'Sigmoid': ['sigmoid'],
        'Exp': ['exp'],
        'Rsqrt': ['rsqrt'],
        'Sqrt': ['sqrt'],
        'swish_f32': ['swish'],
        'Tanh': ['tanh'],
        'Softplus': ['softplus'],
        'LeakyRelu': ['leaky_relu', {
            'alpha': 'alpha'
        }]
    }
    elementwise_ops = {
        'Add': 'elementwise_add',
        'AddV2': 'elementwise_add',
        'RealDiv': 'elementwise_div',
        'Sub': 'elementwise_sub',
        'Maximum': 'elementwise_max',
        'Mul': 'elementwise_mul',
        'FloorDiv': 'elementwise_floordiv'
    }

    def __init__(self, decoder):
        super(TFOpMapperNHWC, self).__init__()
        self.decoder = decoder
        self.graph = decoder.tf_graph
        self.weights = dict()
        self.batch_node = None
        self.omit_nodes = list()
        self.used_custom_layers = dict()

        not_placeholder = list()
        for name in self.graph.input_nodes:
            if self.graph.get_node(name).layer_type != "Placeholder":
                not_placeholder.append(name)
        for name in not_placeholder:
            idx = self.graph.input_nodes.index(name)
            del self.graph.input_nodes[idx]

        unsupported_ops = set()
        sys.stderr.write("Total nodes: {}\n".format(len(self.graph.topo_sort)))
        for i, node_name in enumerate(self.graph.topo_sort):
            sys.stderr.write("\rConverting node {} ...     ".format(i + 1))
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if op in self.directly_map_ops:
                if len(unsupported_ops) > 0:
                    continue
                self.directly_map(node)
            elif op in self.elementwise_ops:
                if len(unsupported_ops) > 0:
                    continue
                self.elementwise_map(node)
            elif hasattr(self, op):
                if len(unsupported_ops) > 0:
                    continue
                func = getattr(self, op)
                try:
                    func(node)
                except Exception as e:
                    unsupported_ops.add(op)
                    print(e)
            else:
                unsupported_ops.add(op)
        if len(unsupported_ops) > 0:
            print("========= {} OPs are not supported yet ===========".format(
                len(unsupported_ops)))
            for op in unsupported_ops:
                print("========== {} ============".format(op))
            sys.exit(-1)
        sys.stderr.write("\nDone!\n")

    def add_omit_nodes(self, in_node_name, out_node_name):
        in_node = self.graph.get_node(in_node_name)
        out_node = self.graph.get_node(out_node_name)
        index = in_node.outputs.index(out_node_name)
        del in_node.outputs[index]
        index = out_node.inputs.index(in_node_name)
        del out_node.inputs[index]
        self.omit_nodes.append(in_node.layer_name)

    def directly_map(self, node):
        assert node.layer_type in self.directly_map_ops
        op_info = self.directly_map_ops[node.layer_type]
        input = self.graph.get_node(node.layer.input[0], copy=True)
        attr = dict()
        for param in op_info[1:]:
            tf_param_name = list(param.keys())[0]
            pd_param_name = list(param.values())[0]
            tf_param = node.get_attr(tf_param_name)
            attr[pd_param_name] = tf_param

        if len(input.out_shapes[0]) == 4 and op_info[0] != 'shape':
            attr1 = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                'transpose', inputs=input, output=node, param_attr=attr1)
            input = node
            node.fluid_code.add_layer(
                op_info[0], inputs=input, output=node, param_attr=attr)
            input = node
            attr2 = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                'transpose', inputs=input, output=node, param_attr=attr2)
        else:
            node.fluid_code.add_layer(
                op_info[0], inputs=input, output=node, param_attr=attr)

    def elementwise_map(self, node):
        assert node.layer_type in self.elementwise_ops
        op_type = self.elementwise_ops[node.layer_type]
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {"x": x, "y": y}
        node.fluid_code.add_layer(
            op_type, inputs=inputs, output=node, param_attr=None)

    def Placeholder(self, node):
        shape = node.out_shapes[0]
        assert len(shape) != 0, "Unknown shape of input nodes[{}].".format(
            node.layer_name)
        dtype = node.dtype
        if shape[0] < 0:
            self.batch_node = node
        attr = {
            'dtype': string(dtype),
            'shape': shape,
            'name': string(node.layer_name),
            'append_batch_size': False
        }

        node.fluid_code.add_layer(
            "data", inputs=None, output=node, param_attr=attr)

    def Const(self, node):
        shape = node.out_shapes[0]
        dtype = node.dtype
        value = node.value
        initializer = "Constant(0.0)"
        if len(shape) == 0:
            assert value.size == 1, "Unexpected situation happened"
            shape = [1]
            initializer = "Constant({})".format(value)

        self.weights[node.layer_name] = node.value

        attr = {
            'dtype': string(dtype),
            'shape': shape,
            'name': string(node.layer_name),
            'default_initializer': initializer
        }
        node.fluid_code.add_layer(
            "create_parameter", inputs=None, output=node, param_attr=attr)

    def Transpose(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        perm = self.graph.get_node(node.layer.input[1], copy=True)
        assert perm.layer_type == "Const", "Perm of transpose OP should be Const"
        del self.weights[perm.layer_name.replace('/', '_')]
        perm.fluid_code.clear()
        perm = perm.value.tolist()

        attr = {'perm': perm}
        node.fluid_code.add_layer(
            "transpose", inputs=input, output=node, param_attr=attr)

    def MaxPool(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)

        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

        if not channel_first:
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
            strides = [strides[i] for i in [0, 3, 1, 2]]
            k_size = [k_size[i] for i in [0, 3, 1, 2]]
            input = node

        attr = {
            "pool_size": k_size[2:4],
            "pool_type": string("max"),
            "pool_stride": strides[2:4],
            "pool_padding": string(pad_mode)
        }
        node.fluid_code.add_layer(
            "pool2d", inputs=input, output=node, param_attr=attr)

        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                "transpose", inputs=node, output=node, param_attr=attr)

    def Conv2D(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)

        k_size = kernel.out_shapes[0]
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

        if kernel.layer_type == 'Const':
            self.add_omit_nodes(kernel.layer_name, node.layer_name)
            kernel_value = kernel.value
        else:
            # Resolve a non-Const kernel through the decoder; otherwise
            # kernel_value would be undefined below.
            kernel_value = self.decoder.infer_tensor(kernel)
        # TF stores conv kernels as HWIO; Paddle expects OIHW.
        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel_value, (3, 2, 0, 1))

        if not channel_first:
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
            input = node

        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": k_size[3],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "padding": string(pad_mode)
        }

        if hasattr(node, 'dilation') and attr['dilation'] == [1, 1]:
            if len(node.dilation) == 1:
                attr['dilation'] = [1, node.dilation[0]]
        node.fluid_code.add_layer(
            "conv2d", inputs=input, output=node, param_attr=attr)
        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                "transpose", inputs=node, output=node, param_attr=attr)

    def BiasAdd(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        bias = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {"x": input, "y": bias}
        node.fluid_code.add_layer(
            "elementwise_add", inputs=inputs, output=node, param_attr=None)

    def FusedBatchNorm(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        gamma = self.graph.get_node(node.layer.input[1], copy=True)
        beta = self.graph.get_node(node.layer.input[2], copy=True)
        moving_mean = self.graph.get_node(node.layer.input[3], copy=True)
        moving_var = self.graph.get_node(node.layer.input[4], copy=True)
        data_format = node.get_attr("data_format").decode()
        channel_first = data_format == "NCHW"

        assert gamma.layer_type == "Const"
        assert beta.layer_type == "Const"
        assert moving_mean.layer_type == "Const"
        assert moving_var.layer_type == "Const"
        self.add_omit_nodes(gamma.layer_name, node.layer_name)
        self.add_omit_nodes(beta.layer_name, node.layer_name)
        self.add_omit_nodes(moving_mean.layer_name, node.layer_name)
        self.add_omit_nodes(moving_var.layer_name, node.layer_name)

        if not channel_first:
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
            input = node

        attr = {
            "epsilon": node.get_attr("epsilon"),
            "param_attr": string(gamma.layer_name),
            "bias_attr": string(beta.layer_name),
            "moving_mean_name": string(moving_mean.layer_name),
            "moving_variance_name": string(moving_var.layer_name),
            "is_test": True
        }

        node.fluid_code.add_layer(
            "batch_norm", inputs=input, output=node, param_attr=attr)

        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                "transpose", inputs=node, output=node, param_attr=attr)

    def DepthwiseConv2dNative(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        assert kernel.layer_type == "Const", "Kernel of DepthwiseConv2DNative should be Const"
        self.add_omit_nodes(kernel.layer_name, node.layer_name)

        in_shape = input.out_shapes[0]
        k_size = kernel.out_shapes[0]
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (2, 3, 0, 1))
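        # Layout note (assumed from TF convention): depthwise kernels are
        # [H, W, in_channels, channel_multiplier]; the (2, 3, 0, 1) transpose
        # yields [in_channels, channel_multiplier, H, W] for Paddle's grouped
        # conv2d weights.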

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
            input = node

        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": in_shape[1],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "groups": k_size[3] * in_shape[1],
            "use_cudnn": False,
            "padding": string(pad_mode)
        }
        node.fluid_code.add_layer(
            "conv2d", inputs=input, output=node, param_attr=attr)

        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                "transpose", inputs=node, output=node, param_attr=attr)

    def Reshape(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        param = self.graph.get_node(node.layer.input[1], copy=True)
        if param.layer_type == "Const":
            self.add_omit_nodes(param.layer_name, node.layer_name)
            shape = param.value.tolist()
        else:
            shape = param
        inputs = {"x": input, "shape": shape}
        node.fluid_code.add_layer(
            "reshape", inputs=inputs, output=node, param_attr=None)
        if param.layer_type != "Const":
            out_shape = numpy.array(node.out_shapes[0])
            if (out_shape > 0).any():
                out_shape[out_shape < 0] = 0
                attr = {'shape': out_shape.tolist()}
                node.fluid_code.add_layer(
                    "reshape", inputs=node, output=node, param_attr=attr)

    def AvgPool(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)

        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

        if not channel_first:
            strides = [strides[i] for i in [0, 3, 1, 2]]
            k_size = [k_size[i] for i in [0, 3, 1, 2]]
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
            input = node

        attr = {
            "pool_size": k_size[2:4],
            "pool_type": string("avg"),
            "pool_stride": strides[2:4],
            "pool_padding": string(pad_mode)
        }
        node.fluid_code.add_layer(
            "pool2d", inputs=input, output=node, param_attr=attr)

        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                "transpose", inputs=node, output=node, param_attr=attr)

    def SplitV(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        num_sections = self.graph.get_node(node.layer.input[1], copy=True)
        dim = self.graph.get_node(node.layer.input[2], copy=True)
        assert num_sections.layer_type == "Const"
        assert dim.layer_type == "Const"
        self.add_omit_nodes(num_sections.layer_name, node.layer_name)
        self.add_omit_nodes(dim.layer_name, node.layer_name)
        dim = dim.value
        attr = {
            "num_or_sections": num_sections.value.tolist(),
            "dim": dim
        }
        node.fluid_code.add_layer(
            "split", inputs=input, output=node, param_attr=attr)

    def ConcatV2(self, node):
        inputs = [
            self.graph.get_node(
                name, copy=True) for name in node.layer.input[:-1]
        ]
        axis = self.graph.get_node(node.layer.input[-1], copy=True)
        assert axis.layer_type == "Const"
        self.add_omit_nodes(axis.layer_name, node.layer_name)
        axis = axis.value
        if axis < 0:
            axis += len(inputs[0].out_shapes[0])
        attr = {"axis": axis}
        node.fluid_code.add_layer(
            "concat", inputs=inputs, output=node, param_attr=attr)

    def Tile(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        expand_times = self.graph.get_node(node.layer.input[1], copy=True)
        if expand_times.layer_type == "Const":
            self.add_omit_nodes(expand_times.layer_name, node.layer_name)
            expand_times = expand_times.value.tolist()
        inputs = {"x": input, "expand_times": expand_times}
        node.fluid_code.add_layer(
            "expand", inputs=inputs, output=node, param_attr=None)

    def Pack(self, node):
        inputs = [
            self.graph.get_node(
                name, copy=True) for name in node.layer.input
        ]
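        # If any input reports a concrete last dimension, every input is first
        # reshaped to match it so stack sees consistent static shapes
        # (a 0 in the target shape keeps the original dimension).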
        reshape_shape = list()
        for input_node in inputs:
            k_size = input_node.out_shapes[0]
            if len(k_size) and k_size[-1] != -1:
                reshape_shape = [0] * len(k_size)
                reshape_shape[-1] = k_size[-1]
                break
        if len(reshape_shape):
            for i, input_node in enumerate(inputs):
                node.fluid_code.add_layer(
                    "reshape",
                    inputs=input_node,
                    output='tmp_{}'.format(i),
                    param_attr={"shape": reshape_shape})
        axis = node.get_attr("axis")
        attr = {"axis": axis}
        if len(reshape_shape):
            inputs = ['tmp_{}'.format(i) for i in range(len(inputs))]
        node.fluid_code.add_layer(
            "stack", inputs=inputs, output=node, param_attr=attr)

    def Pad(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        paddings = self.graph.get_node(node.layer.input[1], copy=True)
        assert paddings.layer_type == "Const", "Padding should be Const"
        self.add_omit_nodes(paddings.layer_name, node.layer_name)
        paddings = paddings.value.flatten().tolist()
        data_format = input.tf_data_format
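        # The flattened paddings list holds four [before, after] pairs, one
        # per input dimension; pad2d can take over only when the batch and
        # channel pairs are all zero, leaving just the spatial padding.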

        if len(input.out_shapes[0]) == 4:
            new_padding = None
            if input.tf_data_format == "NHWC":
                if paddings[0] + paddings[1] + paddings[6] + paddings[7] == 0:
                    new_padding = paddings[2:6]
            else:
                if paddings[0] + paddings[1] + paddings[2] + paddings[3] == 0:
                    new_padding = paddings[4:]
            if new_padding is not None:
                if input.tf_data_format == "NHWC":
                    attr = {"perm": [0, 3, 1, 2]}
                    node.fluid_code.add_layer(
                        "transpose", inputs=input, output=node, param_attr=attr)
                    input = node
                attr = {"paddings": new_padding}
                node.fluid_code.add_layer(
                    "pad2d", inputs=input, output=node, param_attr=attr)
                if input.tf_data_format == "NHWC":
                    attr = {"perm": [0, 2, 3, 1]}
                    node.fluid_code.add_layer(
                        "transpose", inputs=node, output=node, param_attr=attr)

                return

        attr = {"paddings": paddings}
        node.fluid_code.add_layer(
            "pad", inputs=input, output=node, param_attr=attr)

    def Range(self, node):
        start = self.graph.get_node(node.layer.input[0], copy=True)
        limit = self.graph.get_node(node.layer.input[1], copy=True)
        delta = self.graph.get_node(node.layer.input[2], copy=True)

        if start.layer_type == "Const":
            self.add_omit_nodes(start.layer_name, node.layer_name)
            start = start.value
        if limit.layer_type == "Const":
            self.add_omit_nodes(limit.layer_name, node.layer_name)
            limit = limit.value
        if delta.layer_type == "Const":
            self.add_omit_nodes(delta.layer_name, node.layer_name)
            delta = delta.value

        dtype = node.dtype
        inputs = {
            "start": start,
            "end": limit,
            "step": delta,
        }
        attr = {"dtype": string(node.dtype)}
        node.fluid_code.add_layer(
            "range", inputs=inputs, output=node, param_attr=attr)

    def Mean(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        dims = reduce_idx.value.tolist()
        keep_dims = node.get_attr("keep_dims")

        attr = {"dim": dims, "keep_dim": keep_dims}
        node.fluid_code.add_layer(
            "reduce_mean", inputs=input, output=node, param_attr=attr)

    def MatMul(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        transpose_a = node.get_attr('transpose_a')
        transpose_b = node.get_attr('transpose_b')
        inputs = {"x": x, "y": y}
        # fix paddle shape infer problem
        # should be removed after paddle 1.6
        if x.out_shapes[0][-1] < 0 and y.out_shapes[0][0] > 0:
            shape = x.out_shapes[0]
            shape[-1] = y.out_shapes[0][0]
            attr = {"shape": shape}
            node.fluid_code.add_layer(
                "reshape", inputs=x, output=x, param_attr=attr)
        attr = {"transpose_x": transpose_a, "transpose_y": transpose_b}
        node.fluid_code.add_layer(
            "matmul", inputs=inputs, output=node, param_attr=attr)

    def ArgMax(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        axis = self.graph.get_node(node.layer.input[1], copy=True)
        assert axis.layer_type == "Const", "ArgMax only support Const parameter"
        self.add_omit_nodes(axis.layer_name, node.layer_name)
        axis = axis.value
        attr = {"axis": axis}
        node.fluid_code.add_layer(
            "argmax", inputs=input, output=node, param_attr=attr)

    def StridedSlice(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        begin = self.graph.get_node(node.layer.input[1], copy=True)
        end = self.graph.get_node(node.layer.input[2], copy=True)
        strides = self.graph.get_node(node.layer.input[3], copy=True)
        assert begin.layer_type == "Const"
        assert end.layer_type == "Const"
        assert strides.layer_type == "Const"
        self.add_omit_nodes(begin.layer_name, node.layer_name)
        self.add_omit_nodes(end.layer_name, node.layer_name)
        self.add_omit_nodes(strides.layer_name, node.layer_name)
        strides = strides.value.tolist()
        assert len(set(strides)) == 1 and strides[
            0] == 1, "Only support strides be 1 in StridedSlice OP"

        begin = begin.value.tolist()
        end = end.value.tolist()

        for i in range(len(end)):
            if end[i] == 0:
                end[i] = 999999

        begin_mask = node.get_attr('begin_mask')
        end_mask = node.get_attr('end_mask')
        ellipsis_mask = node.get_attr('ellipsis_mask')
        new_axis_mask = node.get_attr('new_axis_mask')
        shrink_axis_mask = node.get_attr('shrink_axis_mask')

        assert ellipsis_mask == 0, "(OP:{} Name:{}) Only support ellipsis_mask be 0 [now: {}] in StridedSlice OP".format(
            node.layer_type, node.layer.name, ellipsis_mask)
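        # Illustrative example of the mask handling below: with
        # begin=[1, 2], end=[3, 4], begin_mask=0b10 and end_mask=0b10,
        # bit 1 being set means dimension 1 takes the full range, so the
        # emitted slice uses starts=[1, 0], ends=[3, 999999].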

        # TODO: the mask decoding below is not fully validated; use with care
        new_begin = list()
        new_end = list()
        new_axes = list()
        shrink_axes = list()
        for i, item in enumerate(begin):
            mask = (new_axis_mask >> i) & 1
            if mask != 0:
                new_axes.append(i)
                continue

            mask = (shrink_axis_mask >> i) & 1
            if mask != 0:
                shrink_axes.append(i)

            mask = (begin_mask >> i) & 1
            if mask != 0:
                new_begin.append(0)
            else:
                new_begin.append(item)

            mask = (end_mask >> i) & 1
            if mask != 0:
                new_end.append(999999)
            else:
                new_end.append(end[i])

        attr = {
            "axes": [i for i in range(len(new_begin))],
            "starts": new_begin,
            "ends": new_end
        }
        node.fluid_code.add_layer(
            "slice", inputs=input, output=node, param_attr=attr)
        if len(new_axes) > 0:
            attr = {"axes": new_axes}
            node.fluid_code.add_layer(
                "unsqueeze", inputs=node, output=node, param_attr=attr)
        if len(shrink_axes) > 0:
            if len(input.out_shapes[0]) + len(new_axes) <= 1:
                pass
            else:
                attr = {"axes": shrink_axes}
                node.fluid_code.add_layer(
                    "squeeze", inputs=node, output=node, param_attr=attr)

    def Slice(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        begin = self.graph.get_node(node.layer.input[1], copy=True)
        size = self.graph.get_node(node.layer.input[2], copy=True)
        if begin.layer_type == "Const":
            self.add_omit_nodes(begin.layer_name, node.layer_name)
            begin = begin.value.tolist()
        else:
            shape = begin.out_shapes[0]
            attr = {"shape": shape}
            node.fluid_code.add_layer(
                "reshape", inputs=begin, output=begin, param_attr=attr)
        if size.layer_type == "Const":
            self.add_omit_nodes(size.layer_name, node.layer_name)
            size = size.value.tolist()
        else:
            shape = size.out_shapes[0]
            attr = {"shape": shape}
            node.fluid_code.add_layer(
                "reshape", inputs=size, output=size, param_attr=attr)
        inputs = {"x": input, "offsets": begin, "shape": size}
        node.fluid_code.add_layer(
            "crop_tensor", inputs=inputs, output=node, param_attr=None)

    def Conv2DBackpropInput(self, node):
        out_shape = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        input = self.graph.get_node(node.layer.input[2], copy=True)

        assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"
        assert out_shape.layer_type == "Const", "Out_shape of Conv2DBackpropInput should be Const"

        self.add_omit_nodes(kernel.layer_name, node.layer_name)
        # Record the omitted node before replacing out_shape with its value;
        # a plain list has no layer_name.
        self.add_omit_nodes(out_shape.layer_name, node.layer_name)
        out_shape = out_shape.value.tolist()

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape
        k_size = kernel.out_shapes[0]
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel).shape

        pad_mode = node.get_attr("padding").decode()
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        channel_first = data_format == "NCHW"
        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (3, 2, 0, 1))
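        # Layout note (assumed from TF convention): deconv kernels are
        # [H, W, output_channels, input_channels], so k_size[2] below is the
        # number of output filters.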
        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
            input = node
        else:
            self.data_format_propagation(node)

        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": k_size[2],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "padding": string(pad_mode),
            "output_size": out_shape[1:3]
        }
        node.fluid_code.add_layer(
            "conv2d_transpose", inputs=input, output=node, param_attr=attr)

        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                "transpose", inputs=node, output=node, param_attr=attr)

    def Max(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep_dims = node.get_attr("keep_dims")
        dim = reduce_idx.value.tolist()

        attr = {"dim": dim, "keep_dim": keep_dims}
        node.fluid_code.add_layer(
            "reduce_max", inputs=input, output=node, param_attr=attr)

    def Sum(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep_dims = node.get_attr("keep_dims")
        dim = reduce_idx.value.tolist()

        attr = {"dim": dim, "keep_dim": keep_dims}
        node.fluid_code.add_layer(
            "reduce_sum", inputs=input, output=node, param_attr=attr)

    def Cast(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        dtype = node.dtype_map[node.get_attr('DstT')]
        attr = {"dtype": string(dtype)}
        node.fluid_code.add_layer(
            "cast", inputs=input, output=node, param_attr=attr)

    def Split(self, node):
        dim = self.graph.get_node(node.layer.input[0], copy=True)
        input = self.graph.get_node(node.layer.input[1], copy=True)
        assert dim.layer_type == "Const"
        self.add_omit_nodes(dim.layer_name, node.layer_name)
        num_split = node.get_attr('num_split')
        dim = dim.value

        attr = {"num_or_sections": num_split, "dim": dim}
        node.fluid_code.add_layer(
            "split", inputs=input, output=node, param_attr=attr)

    def Squeeze(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        squeeze_dims = node.get_attr('squeeze_dims')
        attr = {"axes": squeeze_dims}
        node.fluid_code.add_layer(
            "squeeze", inputs=input, output=node, param_attr=attr)

    def Softmax(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        axis = node.get_attr("axis")
        attr = {"axis": axis}
        node.fluid_code.add_layer(
            "softmax", inputs=input, output=node, param_attr=attr)

    def ResizeNearestNeighbor(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
        if resize_shape.layer_type == "Const":
            self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
            resize_shape = resize_shape.value.tolist()
        else:
            shape = resize_shape.out_shapes[0]
            attr = {"shape": shape}
            node.fluid_code.add_layer(
                "reshape",
                inputs=resize_shape,
                output=resize_shape,
                param_attr=attr)

        align_corners = node.get_attr("align_corners")
        attr = {"perm": [0, 3, 1, 2]}
        node.fluid_code.add_layer(
            "transpose", inputs=input, output=node, param_attr=attr)
        inputs = {"input": node, "out_shape": resize_shape}
        attr = {"align_corners": align_corners}
        node.fluid_code.add_layer(
            "resize_nearest", inputs=inputs, output=node, param_attr=attr)
        attr = {"perm": [0, 2, 3, 1]}
        node.fluid_code.add_layer(
            "transpose", inputs=node, output=node, param_attr=attr)

    def ResizeBilinear(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
        if resize_shape.layer_type == "Const":
            self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
            resize_shape = resize_shape.value.tolist()
        else:
844 845 846 847 848 849 850
            shape = resize_shape.out_shapes[0]
            attr = {"shape": shape}
            node.fluid_code.add_layer(
                "reshape",
                inputs=resize_shape,
                output=resize_shape,
                param_attr=attr)
        align_corners = node.get_attr("align_corners")
        attr = {"perm": [0, 3, 1, 2]}
        node.fluid_code.add_layer(
            "transpose", inputs=input, output=node, param_attr=attr)
        inputs = {"input": node, "out_shape": resize_shape}
        attr = {
            "align_corners": align_corners,
            "align_mode": 1
        }
        node.fluid_code.add_layer(
            "resize_bilinear", inputs=inputs, output=node, param_attr=attr)
        attr = {"perm": [0, 2, 3, 1]}
        node.fluid_code.add_layer(
            "transpose", inputs=node, output=node, param_attr=attr)

    def GreaterEqual(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {"x": x, "y": y}
        node.fluid_code.add_layer(
            "greater_equal", inputs=inputs, output=node, param_attr=None)

    def RandomUniform(self, node):
        shape = self.graph.get_node(node.layer.input[0], copy=True)
        if shape.layer_type == "Const":
            self.add_omit_nodes(shape.layer_name, node.layer_name)
            shape = shape.value.tolist()
        attr = {"min": 0.0, "max": 0.9999}
        node.fluid_code.add_layer(
            "uniform_random", inputs=shape, output=node, param_attr=attr)

    def SquaredDifference(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {"x": x, "y": y}
        node.fluid_code.add_layer(
            "elementwise_sub", inputs=inputs, output=node, param_attr=None)
        inputs = {"x": node, "y": node}
        node.fluid_code.add_layer(
            "elementwise_mul", inputs=inputs, output=node, param_attr=None)

    def ExpandDims(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        if y.layer_type == 'Const':
            self.add_omit_nodes(y.layer_name, node.layer_name)
            dim = y.value.tolist()
            attr = {'axes': [dim]}
        else:
            attr = {'axes': y}
        node.fluid_code.add_layer(
            "unsqueeze", inputs=x, output=node, param_attr=attr)

    def BatchToSpaceND(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        if hasattr(node, 'skip') and node.skip:
            node.fluid_code.add_layer(
                "=", inputs=x, output=node, param_attr=None)
        else:
            raise Exception("BatchToSpaceND is not supported")

    def SpaceToBatchND(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        if hasattr(node, 'skip') and node.skip:
            node.fluid_code.add_layer(
                "=", inputs=x, output=node, param_attr=None)
        else:
            raise Exception("SpaceToBatchND is not supported")