# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import contextlib
import functools
import itertools
import json
import os
import typing
import warnings
import weakref

import numpy as np

from ..core._imperative_rt import GraphProfiler, common
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
    TensorWeakRef,
    apply,
    set_compiled,
    set_tracing,
    skip_tracing,
    unset_compiled,
    unset_tracing,
)
from ..core._imperative_rt.ops import CollectiveComm, RemoteRecv, RemoteSend
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core.ops.builtin import BackwardGraph, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from .sublinear_memory_config import SublinearMemoryConfig


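# When the MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE environment variable is set,
# graph input nodes are created with static shapes (see the G.InputNode calls
# in _compile below).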
def _input_node_use_static_shape():
    return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None


class TraceMismatchError(RuntimeError):
    pass


active_trace = None


def is_tracing():
    if active_trace is None:
        return False
    else:
        return not skip_tracing


@contextlib.contextmanager
def exclude_from_trace():
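    """
    Context manager: ops executed inside this block run eagerly and are not
    recorded into (or replayed from) the active trace.

    Usage sketch (illustrative only):

    .. code-block:: python

        @trace
        def f(x):
            x = x + 1
            with exclude_from_trace():
                # data-dependent control flow is evaluated eagerly every call
                if x.numpy().sum() > 0:
                    x = x - 1
            return x
    """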
    global skip_tracing
    if skip_tracing:
        yield
        return
    try:
        skip_tracing = True
        unset_tracing()
        if active_trace is not None:
            active_trace._begin_excluded_region()
        yield
    finally:
        skip_tracing = False
        set_tracing()


class TensorInfo:
    __slots__ = (
        # collected attributes
        "external",
        "data_read",
        "shape_read",
        "value_read",
        "exported",
        "device",
        "dtype",
        "shape",
        "is_const",
        "bound_data",
        # resources for execution
        "varnode",
        "data_setter",
        "shape_reader",
        "value_reader",
        "data_reader",
    )

    def __init__(self):
        self.exported = None
        self.data_read = None
        self.shape_read = None
        self.value_read = None
        self.bound_data = None

        self.data_setter = None
        self.shape_reader = None
        self.value_reader = None
        self.data_reader = None


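# Ops with communication/transfer side effects: their execution order must be
# preserved, so the compiled graph chains their varnodes with explicit links.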
_io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}


class trace:
    """
    Wraps a callable and provides:

    * tracing via :meth:`.trace` and :meth:`.dump`
    * accelerated evaluation via :meth:`.__call__`

    :param function: the function to be traced.
    :param symbolic: whether to apply symbolic execution for tracing. Default: False
    :param capture_as_const: capture global vars or closures as const value. Default: False
    :param sublinear_memory_config: configuration for sublinear memory optimization.
        If not None, it enables sublinear memory optimization with given setting.
    :param profiling: whether to profile compiled trace. Default: False
    :param opt_level: optimization level for compiling trace.
    :param symbolic_shape: whether to use symbolic shape for tracing. Default: True
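
    Example (a minimal sketch; assumes ``trace`` is imported from
    :mod:`megengine.jit`):

    .. code-block:: python

        @trace(symbolic=True)
        def f(x):
            return x * 2

        y = f(x)  # the first call records the sequence of ops
        y = f(x)  # subsequent calls replay the compiled graph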
    """

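    # Support both direct wrapping, ``trace(fn, ...)``, and decorator usage
    # with keyword arguments, ``@trace(...)``, via functools.partial.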
    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        function,
        symbolic=False,
        capture_as_const=False,
        sublinear_memory_config: SublinearMemoryConfig = None,
        profiling: bool = False,
        opt_level: int = None,
        symbolic_shape: bool = True,
    ):
        self.__wrapped__ = function
        self._symbolic = symbolic
        self._capture_as_const = capture_as_const
        self._sublinear_memory_config = sublinear_memory_config
        self._profiling = profiling
        self._profiler = None
        self._graph_opt_level = opt_level
        self._symbolic_shape = symbolic_shape
        self._output_handles = set()

        self._reset()

    def _reset(self):
        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []
        self._pc = 0
        self._graph = None
        self._need_reset_nodes = None
        self._lazy_eval_graph = None
        self._lazy_eval_tensors = {}
        self._lazy_eval_links = None
        self._active_tensors = {}
        self._tensor_remaps = None
        self._inputs_to_restore = None
        self._arg_bindings = None
        self._kwarg_bindings = None
        self._output_bindings = None
        self._output_names = None

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()
        self._tinfo.append(info)
        return handle, info

    def _apply_op(self, op, args):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if (isinstance(op_, str) and op_ == "Const") or (op != op_):
            raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x.__class__ is CompiledTensorProxy
                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    if x.__class__ is CompiledTensorProxy:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        if not np.array_equal(x.numpy(), info.bound_data.numpy()):
                            raise TraceMismatchError(
                                "const capture violated: got "
                                "a different tensor this time"
                            )
                else:
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                if x.mixin_handle == -1:
                    if x._handle not in self._tensor_remaps:
                        raise TraceMismatchError(
                            "unexpected capture: trying to use an external tensor as "
                            "input, but that input was an internal tensor last time"
                        )
                    else:
                        x.mixin_handle = self._tensor_remaps[
                            x._handle
                        ]._CompiledTensorProxy__handle
                if x.mixin_handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to a data flow "
                        "graph node is different from last time"
                    )

        self._pc += 1
        outputs = []
        for h in ohandles:
            info = self._tinfo[h]
            y = RawTensor(info.varnode)
            y._compiled_info = CompiledTensorProxy(h)
            y.mixin_handle = h
            outputs += [y]
            self._active_tensors[h] = TensorWeakRef(y)
        self._output_handles.update(ohandles)
        return outputs

    def _apply_const(self, value, dtype, device):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        assert isinstance(op_, str) and op_ == "Const"

        eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
        if not eq:
            raise TraceMismatchError(
                "const tensor violated: got a different tensor this time"
            )

        self._pc += 1
        (h,) = ohandles
        outputs = [self._tinfo[h].bound_data]
        return outputs

    def _record_op(self, op, inputs, outputs):
        if skip_tracing:
            for x in inputs:
                h = getattr(x, "mixin_handle", -1)
                if h >= 0:
                    self._tinfo[h].data_read = True
            return

        ihandles = []
        for x in inputs:
            h = getattr(x, "mixin_handle", -1)
            if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
                h, info = self._new_handle()
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.shape
                if self._capture_as_const:
                    info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False)

            ihandles.append(h)

        ohandles = []
        for x in outputs:
            h, info = self._new_handle()
            ohandles.append(h)
            info.external = False
            x.mixin_handle = h
            x.recording = True
            x._trace_mixin_info = info
            self._active_tensors[h] = TensorWeakRef(x)
            if self._symbolic:
                self._lazy_eval_tensors[h] = TensorWeakRef(x)

        self._seq.append((op, tuple(ihandles), tuple(ohandles)))

    def _record_const(self, outputs):
        if skip_tracing:
            (x,) = outputs
            h = getattr(x, "mixin_handle", -1)
            if h >= 0:
                self._tinfo[h].data_read = True
            return

        (x,) = outputs
        h, info = self._new_handle()
        ohandles = [h]
        info.external = True
        info.device = x.device
        info.dtype = x.dtype
        info.shape = x.shape
        info.bound_data = x
        info.is_const = True
        x.mixin_handle = h
        x.recording = True
        x._trace_mixin_info = info
        if self._symbolic:
            self._lazy_eval_tensors[h] = TensorWeakRef(x)
        self._seq.append(("Const", tuple(), tuple(ohandles)))

    def _set_active(self, active: bool):
        global active_trace
        if active:
            if active_trace:
                raise NotImplementedError("sorry, not implemented: nested trace")
            active_trace = self
        else:
            assert active_trace is self
            active_trace = None

    def _init_trace(self, symbolic: bool):
        if symbolic:
            self._lazy_eval_graph = G.Graph()
            self._apply_graph_options(self._lazy_eval_graph)
            self._lazy_eval_links = ()

    def _take_escaped_tensors(self):
        escaped_tensors = tuple(
            filter(lambda x: x() is not None, self._active_tensors.values())
        )
        self._active_tensors.clear()
        return escaped_tensors

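    # Compile and run the deferred graph built up during symbolic tracing,
    # then rebind each surviving lazy tensor to its concretely computed value.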
    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
        lazy_eval_tensors = list(
            filter(lambda x: x() is not None, lazy_eval_tensors.values())
        )
        readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
        self._apply_graph_options(lazy_eval_graph)
        # FIXME
        if self._graph_opt_level is not None:
            lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
        else:
            lazy_eval_graph.options.graph_opt_level = 2
        lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
        lazy_eval_graph.compile(*lazy_eval_links, *readers)
        lazy_eval_graph()
        for r, x in zip(readers, lazy_eval_tensors):
            x()._handle = RawTensor(r.op.get_value())._handle
            x()._reset_varnode()

    @contextlib.contextmanager
    def _setup(self):
        interrupted = False

        def do_enter():
            set_tracing()
            self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
            self._set_active(True)
            if self._untraced:
                self._init_trace(self._symbolic)
            else:
                set_compiled()
                if self._graph is None:
                    self._compile()
                self._graph.execute()

        def do_finalize():
            escaped_tensors = self._take_escaped_tensors()
            if self._untraced:
                for x in escaped_tensors:
                    if x():
                        info = self._tinfo[x().mixin_handle]
                        info.data_read = True
                        x().mixin_handle = -1
                        x().recording = False
                if self._inputs_to_restore:
                    for x in self._inputs_to_restore:
                        x.mixin_handle = -1
                        x.recording = False
                if self._symbolic and (
                    self._lazy_eval_tensors or self._lazy_eval_links
                ):
                    # eval lazy eval tensors
                    self._lazy_eval(
                        self._lazy_eval_graph,
                        self._lazy_eval_tensors,
                        self._lazy_eval_links,
                    )
                    self._lazy_eval_graph = None
                    self._lazy_eval_tensors = None
                    self._lazy_eval_links = None
                self._untraced = False
            else:
                # compiled_tensor leaks
                if self._pc == len(self._seq):
                    for x in escaped_tensors:
                        try:
                            assign_raw_tensor(x(), RawTensor(x()._dev_tensor()))
                        except RuntimeError:
                            # TraceMismatchError thrown in do_exit
                            pass
                    self._graph.wait()
                    self._reset_exec_env()

            # reset status
            self._pc = 0
            self._tensor_remaps = None
            self._set_active(False)
            set_symbolic_shape(self._save_symbolic_shape)
            unset_compiled()
            unset_tracing()

        def do_exit():
            unset_tracing()
            if not self._untraced and self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            if not self._symbolic or not self._untraced:
                for x in self._active_tensors.values():
                    if x() is not None:
                        x()._dev_tensor()
                        x()._reset_varnode()
                        x().mixin_handle = -1
                        x().recording = False
                        x()._trace_mixin_info = None

        try:
            do_enter()
            yield
            do_exit()
        except:
            interrupted = True
            raise
        finally:
            do_finalize()
            if interrupted:
                self._reset()

    def _begin_excluded_region(self):
        if self._capture_as_const:
            raise RuntimeError(
                "exclude_from_trace cannot be used with capture_as_const"
            )
        if self._untraced:
            # conditionally reading a compiled tensor in excluded region
            # is permitted, so we have to assume every tensor might be read
            for x in self._active_tensors.values():
                info = self._tinfo[x().mixin_handle]
                info.exported = True
                info.data_read = True
                x()._dev_tensor()

    def _apply_graph_options(self, graph):

        graph.options.no_force_inplace = True
        graph.options.seq_opt.enable_seq_comp_node_opt = False
        # graph opt level
        # if self._graph_opt_level is not None:
        #     graph.options.graph_opt_level = self._graph_opt_level
        # FIXME
        graph.options.graph_opt_level = 0
        # sublinear
        if self._sublinear_memory_config is not None:
            graph.options.enable_sublinear_memory_opt = True
            sublinear_config = graph.options.sublinear_mem_config
            sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
            sublinear_config.genetic_nr_iter = (
                self._sublinear_memory_config.genetic_nr_iter
            )
            sublinear_config.genetic_pool_size = (
                self._sublinear_memory_config.genetic_pool_size
            )
            sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
            sublinear_config.num_worker = self._sublinear_memory_config.num_worker
        # profile
        if self._profiling:
            self._profiler = GraphProfiler(graph)
        if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")):
            graph.options.var_sanity_check_first_run = False

    def _compile(self):
        graph = self._graph = G.Graph()
        graph.options.async_exec_level = 0b100
        self._apply_graph_options(graph)
        # graph.options.graph_opt_level = 0
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        in_out_links = ()
        io_links = ()
        readers = []

        if self._capture_as_const:
            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                info = self._tinfo[h]
                opnode = info.data_setter = G.InputNode(
                    device=info.device,
                    dtype=info.dtype,
                    shape=info.shape or (1,),
                    graph=graph,
                    use_static_shape=_input_node_use_static_shape(),
                )
                need_reset_nodes.append(opnode)
                info.varnode = opnode.outputs[0]
                in_out_links += opnode.outputs[1:]

        cnt_data, cnt_value, cnt_shape = 0, 0, 0
        for op, ihandles, ohandles in self._seq:
            if isinstance(op, str) and op == "Const":
                assert len(ihandles) == 0
                (h,) = ohandles
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    assert info.bound_data
                    info.varnode = graph.make_const(
                        info.bound_data.numpy(),
                        info.bound_data.dtype,
                        info.bound_data.device,
                    )
                continue

            require_links = type(op) in _io_op_types
            ivars = []
            for i, h in enumerate(ihandles):
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data:
                        if hasattr(info, "is_const") and info.is_const:
                            info.varnode = graph.make_const(
                                info.bound_data.numpy(),
                                info.bound_data.dtype,
                                info.bound_data.device,
                            )
                        else:
                            info.varnode = graph.make_const(
                                info.bound_data._dev_tensor()
                                # info.bound_data.numpy()
                            )
                    else:
                        opnode = info.data_setter = G.InputNode(
                            *in_out_links,
                            device=info.device,
                            dtype=info.dtype,
                            shape=info.shape or (1,),
                            graph=graph,
                            use_static_shape=_input_node_use_static_shape(),
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *in_out_links = opnode.outputs
                if require_links and i == 0 and len(io_links) > 0:
                    opnode = G.VirtualDepNode(
                        [info.varnode, *io_links], str(io_links[0].device)
                    )
                    info.varnode = opnode.outputs[0]
                    io_links = (info.varnode,)

                ivars.append(info.varnode)

            if isinstance(op, BackwardGraph):
                ovars = G.apply_backward_varnode(op, *ivars)
            else:
                ovars = G.apply_normal_varnode(op, *ivars)

            if require_links and len(ovars) > 0:
                io_links = (ovars[0],)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    nonlocal in_out_links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    in_out_links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    cnt_data += 1
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.value_read:
                    cnt_value += 1
                    opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                    add_reader(opnode)
                if info.shape_read:
                    cnt_shape += 1
                    opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                    add_reader(opnode)

        # FIXME
        if self._graph_opt_level is not None:
            graph.options.graph_opt_level = self._graph_opt_level
        else:
            graph.options.graph_opt_level = 2
        graph._set_priority_to_id([*readers, *in_out_links, *io_links])
        graph.compile(*readers, *in_out_links, *io_links)

    def _reset_exec_env(self):
        for opnode in self._need_reset_nodes:
            opnode.reset()

    def __call__(self, *args, **kwargs):
        if is_tracing():
            return self.__wrapped__(*args, **kwargs)
        with self._setup():
            if self._capture_as_const:
                self._process_inputs(*args, **kwargs)
            outputs = self.__wrapped__(*args, **kwargs)
            transform = False
            if outputs is not None:
                if not isinstance(outputs, collections.abc.Sequence):
                    transform = True
                    outputs = (outputs,)
                for o in outputs:
                    if o._copied:
                        self._active_tensors[o.mixin_handle] = TensorWeakRef(o)
                        if self._untraced and self._symbolic:
                            self._lazy_eval_tensors[o.mixin_handle] = TensorWeakRef(o)
            if self._capture_as_const:
                self._process_outputs(outputs)
            if transform:
                outputs = outputs[0]
            return outputs

    def dump(
        self,
        file,
        *,
        arg_names=None,
        output_names=None,
        append=False,
        optimize_for_inference=True,
        **kwargs
    ):
        r"""
        Serializes trace to file system.

        :param file: output file, could be file object or filename.
        :param arg_names: names of the input tensors in the traced function.
        :param output_names: names of the output tensors in the traced function,
            use the default name if not specified.
        :param append: whether output is appended to ``file``.
            Only works when ``file`` is str.
        :param optimize_for_inference: enable optimizations,
            will skip all optimize options if this is False. Default: True

        :Keyword Arguments:

            * enable_io16xc32 --
                whether to use float16 for I/O between oprs and use
                float32 as internal computation precision. Note the output var would be
                changed to float16.
            * enable_ioc16 --
                whether to use float16 for both I/O and computation
                precision.

            * enable_hwcd4 --
                whether to use NHWCD4 data layout. This is faster on some
                OpenCL backend.
            * enable_nchw88 --
                whether to use NCHW88 data layout, currently
                used in X86 AVX backend.
            * enable_nchw44 --
                whether to use NCHW44 data layout, currently
                used in arm backend.
            * enable_nchw44_dot --
                whether to use NCHW44_dot data layout, currently
                used in armv8.2+dotprod backend.
            * enable_nchw4 --
                whether to use NCHW4 data layout, currently
                used in nvidia backend (based on cudnn).
            * enable_nchw32 --
                whether to use NCHW32 data layout, currently
                used in nvidia backend with tensorcore (based on cudnn).
            * enable_chwn4 --
                whether to use CHWN4 data layout, currently
                used in nvidia backend with tensorcore.

            * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
                into one opr.
            * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
                input for inference on nvidia backend (this optimization pass will
                result in mismatch of the precision of output of training and
                inference)
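
        A sketch of typical usage (``net`` and ``data`` are illustrative):

        .. code-block:: python

            @trace(symbolic=True, capture_as_const=True)
            def pred(data):
                return net(data)

            pred(data)  # run at least once before dumping
            pred.dump("model.mge", arg_names=["data"])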
        """
        if not self._capture_as_const:
            raise ValueError(
                "you must specify capture_as_const=True at __init__ to use dump"
            )
        if self._untraced:
            raise RuntimeError("should run at least once before calling dump")
        if self._output_names and output_names:
            raise TypeError(
                "cannot specify output_names when output is already in dict format"
            )
        if output_names and not isinstance(output_names, collections.abc.Sequence):
            output_names = (output_names,)
        if output_names and len(output_names) != len(self._output_bindings):
            raise ValueError(
                "wrong number of output_names, should be {} values".format(
                    len(self._output_bindings)
                )
            )
        if arg_names is None:
            arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
        if arg_names and not isinstance(arg_names, collections.abc.Sequence):
            arg_names = (arg_names,)
        if arg_names and len(arg_names) != len(self._arg_bindings):
            raise ValueError(
                "wrong number of arg_names, should be {} values".format(
                    len(self._arg_bindings)
                )
            )
        output_names = output_names or self._output_names

        dumped_device = as_device("xpux")

        h2v = {}
        graph = G.Graph()
        # only graph_opt_level takes effect in dump
        self._apply_graph_options(graph)

        for i, h in enumerate(self._arg_bindings):
            info = self._tinfo[h]
            h2v[h] = graph.make_h2d(
                dtype=info.dtype,
                device=dumped_device,
                shape=info.shape or (1,),
                name=arg_names[i] if arg_names else None,
            )
        for k, h in self._kwarg_bindings.items():
            info = self._tinfo[h]
            h2v[h] = graph.make_h2d(
                dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
            )

        for op, ihandles, ohandles in self._seq:
            if isinstance(op, str) and op == "Const":
                assert len(ihandles) == 0
                (h,) = ohandles
                info = self._tinfo[h]
                if h not in h2v:
                    assert info.external
                    assert info.bound_data
                    h2v[h] = graph.make_const(
                        info.bound_data.numpy(), dtype=info.dtype, device=info.device,
                    )
                continue
            ivars = []
            for h in ihandles:
                info = self._tinfo[h]
                if h not in h2v:
                    assert info.external
                    assert info.bound_data
                    h2v[h] = graph.make_const(
                        info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
                    )
                ivars.append(h2v[h])
            ovars = G.apply_normal_varnode(op, *ivars)
            assert len(ovars) == len(ohandles)
            h2v.update(zip(ohandles, ovars))

        dest_vars = []
        for i, h in enumerate(self._output_bindings):
            v = h2v[h]
            if output_names:
                v.name = output_names[i]
            dest_vars.append(v)

        if optimize_for_inference:
            dest_vars = G.optimize_for_inference(dest_vars, **kwargs)

        if isinstance(file, str):
            permission = "ab" if append else "wb"
            file = open(file, permission)
        dump_content, dump_info = G.dump_graph(dest_vars)
        file.write(dump_content)
        return dump_info

    def _process_inputs(self, *args, **kwargs):
        if self._untraced:
            self._inputs_to_restore = []

            def record_input(x):
                if x is None:
                    return
                h, info = self._new_handle()
                info.external = False
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.numpy().shape
                x.mixin_handle = h
                x.recording = True
                x._trace_mixin_info = info
                self._inputs_to_restore.append(x)
                return h

            self._arg_bindings = []
            for i, x in enumerate(args):
                if not isinstance(x, RawTensor):
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                self._arg_bindings.append(record_input(x))

            self._kwarg_bindings = {}
            for k, x in kwargs.items():
                if isinstance(x, RawTensor):
                    self._kwarg_bindings[k] = record_input(x)
        else:
            if len(args) != len(self._arg_bindings):
                raise TraceMismatchError("positional argument length mismatch")

            self._tensor_remaps = {}

            for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
                if not isinstance(x, RawTensor):
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("args[%d].dtype different from last time" % i)
                if x.device != info.device:
                    raise TypeError("args[%d].device different from last time" % i)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x._handle] = CompiledTensorProxy(h)

            kwargs_tensors = {}
            for k, x in kwargs.items():
                if isinstance(x, RawTensor):
                    kwargs_tensors[k] = x
            if set(kwargs_tensors) != set(self._kwarg_bindings):
                too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
                too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
                if too_many:
                    raise TraceMismatchError(
                        "keyword arguments found to be tensor this time "
                        "but were non-tensor previously: %s" % " ".join(too_many)
                    )
                if too_few:
                    raise TraceMismatchError(
                        "keyword arguments found to be non-tensor this time "
                        "but were tensor previously: %s" % " ".join(too_few)
                    )
            for k, h in self._kwarg_bindings.items():
                x = kwargs_tensors[k]
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("kwargs[%s].dtype different from last time" % k)
                if x.device != info.device:
                    raise TypeError("kwargs[%s].device different from last time" % k)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x._handle] = CompiledTensorProxy(h)

    def _process_outputs(self, outputs):
        output_names = None
        if isinstance(outputs, collections.abc.Mapping):
            output_names, outputs = zip(*sorted(outputs.items()))
        elif not isinstance(outputs, collections.abc.Sequence):
            outputs = (outputs,)

        if not self._untraced:
            if output_names != self._output_names:
                too_many = set(output_names) - set(self._output_names)
                too_few = set(self._output_names) - set(output_names)
                if too_many:
                    raise TraceMismatchError(
                        "output has more keys than last time: %s" % " ".join(too_many)
                    )
                if too_few:
                    raise TraceMismatchError(
                        "output has fewer keys than last time: %s" % " ".join(too_few)
                    )
            if len(outputs) != len(self._output_bindings):
                raise TraceMismatchError("output size differs from last time")
        else:
            self._output_names = output_names
            self._output_bindings = []

        for i, x in enumerate(outputs):
            if not isinstance(x, RawTensor):
                raise TypeError("every item of return value should be tensor")
            if self._untraced:
                h = x.mixin_handle
                if h < 0:
                    raise RuntimeError("output is not computed from inputs")
                self._output_bindings.append(h)
            else:
                h = x.mixin_handle
                if h not in self._output_handles:
                    raise RuntimeError("output is not computed from inputs")
                if h != self._output_bindings[i]:
                    raise TraceMismatchError(
                        "retval[%s] is a different tensor than last time"
                        % (output_names and output_names[i] or i)
                    )

    def get_profile(self):
        """
        Get profiling result for compiled trace.

        :return: a json compatible object.
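
        Example (a sketch; ``fun`` and ``x`` are illustrative):

        .. code-block:: python

            f = trace(fun, profiling=True)
            f(x)  # first call records the trace
            f(x)  # second call runs the compiled graph and is profiled
            stats = f.get_profile()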
        """
        if not self._profiler:
            raise RuntimeError("trace is not set with profiling=True")
        return json.loads(self._profiler.get())

    def trace(self, *args, **kwargs):
        raise NotImplementedError(
            "trace is deemed unbeneficial with the new "
            "tracing mechanism. You should always use __call__."
        )


class CompiledTensorProxy:
    """
    Duck-typed RawTensor
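    that reads shape, value and device data back through the reader nodes
    recorded for its handle in the active trace.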
    """

    def __init__(self, handle):
        self.__handle = handle
        self._isscalar = False
        self.__info = active_trace._tinfo[handle]
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self._isscalar:
            return ()
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape
            elif self.__info.data_read:
                self.__shape = self._dev_tensor().shape
            else:
                # c++ will throw TraceReadError
                return None
        return self.__shape

    def numpy(self):
        if self.__value is None:
            if self.__info.value_read:
                self.__value = self.__info.value_reader.get_value()
            elif self.__info.data_read:
                self.__value = self._dev_tensor().numpy()
            else:
                # c++ will throw TraceReadError
                return None
            if self._isscalar:
                self.__value = self.__value.squeeze()
        return self.__value

    def _dev_tensor(self):
        if self.__data is None:
            if not self.__info.data_read:
                # c++ will throw TraceReadError
                return None
            self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        if self.__info.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
        if self.__info.value_read and self.__value is not None:
            self.__info.value_reader.drop_value()
        if self.__info.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()


def assign_raw_tensor(lhs, rhs):
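    # re-initialize lhs in place so every existing reference to it observes rhs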
    lhs.__init__(rhs)


def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ivars = []
    for x in args:
        var = getattr(x, "_varnode", None)
        if var:
            ivars.append(var)
        else:
            data_setter = G.InputNode(
                device=x.device,
                dtype=x.dtype,
                shape=x.numpy().shape or (1,),
                graph=graph,
                use_static_shape=True,
            )
            var = data_setter.outputs[0]
            ivars.append(var)
            data_setter.set_value(x._dev_tensor())

    require_links = type(op) in _io_op_types

    if require_links and active_trace._lazy_eval_links:
        assert len(ivars) > 0, "op should have at least one input"
        opnode = G.VirtualDepNode(
            [ivars[0], *active_trace._lazy_eval_links],
            str(active_trace._lazy_eval_links[0].device),
        )
        ivars[0] = opnode.outputs[0]
        active_trace._lazy_eval_links = (ivars[0],)

    if isinstance(op, BackwardGraph):
        ovars = G.apply_backward_varnode(op, *ivars)
    else:
        ovars = G.apply_normal_varnode(op, *ivars)
    outputs = [RawTensor(o) for o in ovars]

    if require_links:
        active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)

    return outputs


def apply_const_symbolic_mode(value, dtype, device):
    graph = active_trace._lazy_eval_graph
    # don't need to unset tracing
    # because varnode construction will ignore tracing flag
    ret = RawTensor(graph.make_const(value, dtype=dtype, device=device))
    if np.array(value).ndim == 0:
        setscalar(ret)
    return (ret,)


def apply_compiled_mode(op: OpDef, *args: RawTensor):
    if skip_tracing:
        args = [
            RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
            for x in args
        ]
        unset_tracing()
        ret = apply(op, *args)
        set_tracing()
        return ret
    return active_trace._apply_op(op, args)


def apply_const_compiled_mode(value, dtype, device, is_const, no_cache):
    if skip_tracing:
        unset_tracing()
        ret = RawTensor(value, dtype, device, False)
        set_tracing()
        return ret
    return active_trace._apply_const(value, dtype, device)


def apply_with_tracing(op: OpDef, *args: RawTensor):
    if active_trace._symbolic:
        outputs = apply_symbolic_mode(op, *args)
    else:
        unset_tracing()
        outputs = apply(op, *args)
        set_tracing()

    active_trace._record_op(op, args, outputs)
    return list(outputs)


def apply_const_with_tracing(value, dtype, device, is_const, no_cache):
    if active_trace._symbolic:
        outputs = apply_const_symbolic_mode(value, dtype, device)
    else:
        unset_tracing()
        outputs = (RawTensor(value, dtype, device, False),)
        set_tracing()
    active_trace._record_const(outputs)
    return list(outputs)