From d3bfb0e983915734d9768f327a48a13edf425de6 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 8 Jan 2021 22:11:20 +0800
Subject: [PATCH] fix(mge): fix trace exit code and reformat

GitOrigin-RevId: 145c06b7e7a7f98f40f0e7acc1b555c16f27e2ba
---
 .../megengine/distributed/functional.py     |  2 -
 imperative/python/megengine/jit/tracing.py  | 83 +++++++++++--------
 imperative/python/src/tensor.cpp            | 26 ++++--
 imperative/python/src/trace.cpp             |  1 -
 imperative/python/src/trace_info.h          |  4 +-
 imperative/python/test/unit/test_tracing.py | 47 ++++++++++-
 6 files changed, 117 insertions(+), 46 deletions(-)

diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py
index 8b9e77c7f..32a24ff14 100644
--- a/imperative/python/megengine/distributed/functional.py
+++ b/imperative/python/megengine/distributed/functional.py
@@ -292,8 +292,6 @@ def remote_recv(
     op = RemoteRecv()
     op.key = key
     op.cn = device
-    if isinstance(shape, Tensor):
-        shape = shape.numpy()
     op.shape = shape
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index 85acc0dc1..a9c66c998 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -191,19 +191,20 @@ class trace:
         if len(ihandles) != len(args):
             raise TraceMismatchError("op input size different from last time")
 
+        # check all inputs of the current op
         for h, x in zip(ihandles, args):
             info = self._tinfo[h]
             if info.external:
                 if (
-                    x.__class__ is CompiledTensorProxy
-                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
+                    x._compiled_info is not None
+                    and not self._tinfo[x._mixin_handle].exported
                 ):
                     raise TraceMismatchError(
                         "failed to capture: input was an external tensor "
                         "last time, got an internal tensor this time"
                     )
                 if info.bound_data:
-                    if x.__class__ is CompiledTensorProxy:
+                    if x._compiled_info is not None:
                         raise TraceMismatchError(
                             "const capture violated: was an external tensor "
                             "last time, got an internal tensor this time"
@@ -225,17 +226,17 @@ class trace:
                         )
                     info.data_setter.set_value(x._dev_tensor())
             else:
-                if x.mixin_handle == -1:
+                if x._mixin_handle == -1:
                     if x._handle not in self._tensor_remaps:
                         raise TraceMismatchError(
                             "unexpected capture: trying to use an external tensor as "
                             "input, but that input was an internal tensor last time"
                         )
                     else:
-                        x.mixin_handle = self._tensor_remaps[
+                        x._mixin_handle = self._tensor_remaps[
                             x._handle
                         ]._CompiledTensorProxy__handle
-                if x.mixin_handle != h:
+                if x._mixin_handle != h:
                     raise TraceMismatchError(
                         "mis-wiring: input edge to an data flow "
                         "graph node is different from last time"
@@ -245,9 +246,10 @@ class trace:
         outputs = []
         for h in ohandles:
             info = self._tinfo[h]
+            # generate output tensor and create compiled info
             y = RawTensor(info.varnode)
             y._compiled_info = CompiledTensorProxy(h)
-            y.mixin_handle = h
+            y._mixin_handle = h
             outputs += [y]
             self._active_tensors[h] = TensorWeakRef(y)
         self._output_handles.update(ohandles)
@@ -260,6 +262,7 @@ class trace:
             raise TraceMismatchError("trace should end here, but more op observed")
         record = self._seq[self._pc]
         op_, ihandles, ohandles = record
+        # Const op is represented by a str
         assert isinstance(op_, str) and op_ == "Const"
 
        eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
@@ -273,17 +276,18 @@ class trace:
             outputs = [self._tinfo[h].bound_data]
         return outputs
 
+    # run in the first step to record information for the trace
     def _record_op(self, op, inputs, outputs):
         if skip_tracing:
             for x in inputs:
-                h = getattr(x, "mixin_handle", -1)
+                h = getattr(x, "_mixin_handle", -1)
                 if h >= 0:
                     self._tinfo[h].data = True
             return
 
         ihandles = []
         for x in inputs:
-            h = getattr(x, "mixin_handle", -1)
+            h = getattr(x, "_mixin_handle", -1)
             if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
                 h, info = self._new_handle()
                 info.external = True
@@ -300,8 +304,8 @@ class trace:
             h, info = self._new_handle()
             ohandles.append(h)
             info.external = False
-            x.mixin_handle = h
-            x.recording = True
+            x._mixin_handle = h
+            x._recording = True
             x._trace_mixin_info = info
             self._active_tensors[h] = TensorWeakRef(x)
             if self._symbolic:
@@ -312,7 +316,7 @@ class trace:
     def _record_const(self, outputs):
         if skip_tracing:
             (x,) = outputs
-            h = getattr(x, "mixin_handle", -1)
+            h = getattr(x, "_mixin_handle", -1)
             if h >= 0:
                 self._tinfo[h].data_read = True
             return
@@ -326,8 +330,8 @@ class trace:
         info.shape = x.shape
         info.bound_data = x
         info.is_const = True
-        x.mixin_handle = h
-        x.recording = True
+        x._mixin_handle = h
+        x._recording = True
         x._trace_mixin_info = info
         if self._symbolic:
             self._lazy_eval_tensors[h] = TensorWeakRef(x)
@@ -371,6 +375,7 @@ class trace:
         lazy_eval_graph.compile(*lazy_eval_links, *readers)
         lazy_eval_graph()
         for r, x in zip(readers, lazy_eval_tensors):
+            # get values from lazy_eval_graph and assign to lazy_eval tensor
             x()._handle = RawTensor(r.op.get_value())._handle
             x()._reset_varnode()
 
@@ -395,14 +400,14 @@ class trace:
         if self._untraced:
             for x in escaped_tensors:
                 if x():
-                    info = self._tinfo[x().mixin_handle]
+                    info = self._tinfo[x()._mixin_handle]
                     info.data_read = True
-                    x().mixin_handle = -1
-                    x().recording = False
+                    x()._mixin_handle = -1
+                    x()._recording = False
             if self._inputs_to_restore:
                 for x in self._inputs_to_restore:
-                    x.mixin_handle = -1
-                    x.recording = False
+                    x._mixin_handle = -1
+                    x._recording = False
             if self._symbolic and (
                 self._lazy_eval_tensors or self._lazy_eval_links
             ):
@@ -441,12 +446,13 @@ class trace:
         if not self._untraced and self._pc != len(self._seq):
             raise TraceMismatchError("premature end")
         if not self._symbolic or not self._untraced:
+            # reset output tensors
             for x in self._active_tensors.values():
                 if x() is not None:
                     x()._dev_tensor()
                     x()._reset_varnode()
-                    x().mixin_handle = -1
-                    x().recording = False
+                    x()._mixin_handle = -1
+                    x()._recording = False
                     x()._trace_mixin_info = None
 
         try:
@@ -470,10 +476,14 @@ class trace:
             # conditionally reading a compiled tensor in excluded region
             # is permitted, so we have to assume every tensor might be read
             for x in self._active_tensors.values():
-                info = self._tinfo[x().mixin_handle]
-                info.exported = True
-                info.data_read = True
-                x()._dev_tensor()
+                if x():
+                    info = self._tinfo[x()._mixin_handle]
+                    info.exported = True
+                    info.data_read = True
+        else:
+            for x in self._active_tensors.values():
+                if x():
+                    x()._dev_tensor()
 
     def _apply_graph_options(self, graph):
@@ -528,7 +538,6 @@ class trace:
                 info.varnode = opnode.outputs[0]
                 in_out_links += opnode.outputs[1:]
 
-        cnt_data, cnt_value, cnt_shape = 0, 0, 0
         for op, ihandles, ohandles in self._seq:
             if isinstance(op, str) and op == "Const":
                 assert len(ihandles) == 0
@@ -604,16 +613,13 @@ class trace:
                     # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
-                    cnt_data += 1
                     info.shape_read = False
                     opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                     add_reader(opnode)
                 if info.value_read:
-                    cnt_value += 1
                     opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                     add_reader(opnode)
                 if info.shape_read:
-                    cnt_shape += 1
                     opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                     add_reader(opnode)
 
@@ -637,15 +643,17 @@ class trace:
                 self._process_inputs(*args, **kwargs)
             outputs = self.__wrapped__(*args, **kwargs)
             transform = False
+            # outputs can be None
             if outputs is not None:
                 if not isinstance(outputs, collections.abc.Sequence):
                     transform = True
                     outputs = (outputs,)
                 for o in outputs:
+                    # if outputs are copied, use the newest info in the trace data structure
                     if o._copied:
-                        self._active_tensors[o.mixin_handle] = TensorWeakRef(o)
+                        self._active_tensors[o._mixin_handle] = TensorWeakRef(o)
                         if self._untraced and self._symbolic:
-                            self._lazy_eval_tensors[o.mixin_handle] = TensorWeakRef(o)
+                            self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o)
             if self._capture_as_const:
                 self._process_outputs(outputs)
             if transform:
@@ -819,8 +827,8 @@ class trace:
             info.device = x.device
             info.dtype = x.dtype
             info.shape = x.numpy().shape
-            x.mixin_handle = h
-            x.recording = True
+            x._mixin_handle = h
+            x._recording = True
             x._trace_mixin_info = info
             self._inputs_to_restore.append(x)
             return h
@@ -914,12 +922,12 @@ class trace:
             if not isinstance(x, RawTensor):
                 raise TypeError("every item of return value should be tensor")
             if self._untraced:
-                h = x.mixin_handle
+                h = x._mixin_handle
                 if h < 0:
                     raise RuntimeError("output is not computed from inputs")
                 self._output_bindings.append(h)
             else:
-                h = x.mixin_handle
+                h = x._mixin_handle
                 if h not in self._output_handles:
                     raise RuntimeError("output is not computed from inputs")
                 if h != self._output_bindings[i]:
@@ -938,6 +946,11 @@ class trace:
             raise RuntimeError("trace is not set with profiling=True")
         return json.loads(self._profiler.get())
 
+    def __del__(self):
+        for x in self._tinfo:
+            if getattr(x, "bound_data", None):
+                x.bound_data = None
+
     def trace(self, *args, **kwargs):
         raise NotImplementedError(
             "trace is deemed unbeneficial with the new "
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 558e6229a..eec79e744 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -291,7 +291,11 @@ PyObject* TensorWrapper::copied() {
 
 #define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
     PyObject* TensorWrapper::member() { \
-        return m_tensor->m_trace_info.member; \
+        if (m_tensor->m_trace_info.member) { \
+            return m_tensor->m_trace_info.member; \
+        } else { \
+            Py_RETURN_NONE; \
+        } \
     } \
     void TensorWrapper::set_##member(PyObject* dest) { \
         if (dest == Py_None) { \
@@ -322,6 +326,7 @@ void TensorWrapper::set_handle(PyObject* dest) {
 
 
 PyObject* TensorWrapper::shape() {
+    // in traced (compiled) mode, read the shape from compiled_info
     if (m_tensor->m_trace_info.compiled_info != nullptr) {
         if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
             return PyTuple_New(0);
@@ -332,15 +337,18 @@ PyObject* TensorWrapper::shape() {
         }
         return shp;
     }
+
+    // inside trace, if the tensor shape is needed by other operations, set shape_read = true
     if (m_tensor->m_trace_info.recording && !skip_tracing) {
         PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read",
                                py::cast(true).release().ptr());
     }
+
     if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
         return PyTuple_New(0);
     }
     TensorShape shape;
-    if (m_tensor->m_var) {
+    if (m_tensor->m_var) {  // get shape from m_var
         auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
         auto *tshp = mgr.infer_shape_fallible(m_tensor->m_var);
         if (!tshp) {
@@ -389,9 +397,11 @@ PyObject* TensorWrapper::numpy() {
         }
         return np_val;
     }
+
     if (m_tensor->m_trace_info.recording && !skip_tracing) {
         PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "value_read", py::cast(true).release().ptr());
     }
+
     if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
         auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
         auto&& type = mgr.get_infer_type(m_tensor->m_var);
@@ -411,12 +421,14 @@ PyObject* TensorWrapper::numpy() {
         }
         return np_val.release().ptr();
     }
+
     auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
     auto arr = py::reinterpret_steal<py::object>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
     if (!arr) {
         PyErr_SetString(PyExc_ValueError, "tensor invalid");
         return nullptr;
     }
+
     if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
         mgb_assert(PyArray_Check(arr.ptr()));
         return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
@@ -428,7 +440,7 @@ PyObject* TensorWrapper::varnode() {
     if (m_tensor->m_var) {
         return py::cast(m_tensor->m_var).release().ptr();
     }
-    return py::none().release().ptr();
+    Py_RETURN_NONE;
 }
 
 void TensorWrapper::reset(PyObject* tensor) {
@@ -465,9 +477,13 @@ PyObject* TensorWrapper::_dev_tensor(){
         if (dev_tensor == Py_None) {
             throw TraceReadError("raw data of this tensor is not read in trace");
         }
+
+        // set m_handle to make it a real tensor
         auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
         auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
         m_tensor->m_handle = std::move(SharedHandle(sh));
+
+        // compiled_info is no longer needed once m_handle is set
         Py_DECREF(m_tensor->m_trace_info.compiled_info);
         m_tensor->m_trace_info.compiled_info = nullptr;
 
@@ -753,8 +769,8 @@ void init_tensor(py::module m) {
         .def<&TensorWrapper::reset_varnode>("_reset_varnode")
         .def_getset<&TensorWrapper::varnode>("_varnode")
         .def_getset<&TensorWrapper::copied>("_copied")
-        .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle")
-        .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("recording")
+        .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle")
+        .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording")
         .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
         .def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info")
         .def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
diff --git a/imperative/python/src/trace.cpp b/imperative/python/src/trace.cpp
index 09d1c1753..e99538633 100644
--- a/imperative/python/src/trace.cpp
+++ b/imperative/python/src/trace.cpp
@@ -55,7 +55,6 @@ apply_result_t apply_trace(ApplyContext& ctx) {
 
         auto args = py::tuple(ctx.nargs + 1);
         args[0] = py::cast(ctx.op);
-        py::tuple args(ctx.nargs);
         for (size_t i = 0; i < ctx.nargs; i++) {
             args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
         }
diff --git a/imperative/python/src/trace_info.h b/imperative/python/src/trace_info.h
index d61d9024c..e41ecc6b5 100644
--- a/imperative/python/src/trace_info.h
+++ b/imperative/python/src/trace_info.h
@@ -19,7 +19,9 @@ struct TraceInfo {
     bool recording = false;
     bool copied = false;
 
+    // refers to CompiledTensorProxy in tracing.py; used from the second trace step on
     PyObject* compiled_info = nullptr;
+    // refers to TensorInfo in tracing.py; only used in the first trace step
     PyObject* trace_mixin_info = nullptr;
 
     TraceInfo() = default;
@@ -37,7 +39,7 @@ struct TraceInfo {
         return *this;
     }
 
-    ~TraceInfo() { 
+    ~TraceInfo() {
         Py_XDECREF(trace_mixin_info);
         Py_XDECREF(compiled_info);
     }
diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py
index 4eaeb3da2..ca0a5328b 100644
--- a/imperative/python/test/unit/test_tracing.py
+++ b/imperative/python/test/unit/test_tracing.py
@@ -14,14 +14,17 @@ import pytest
 
 import megengine.core.tensor.megbrain_graph as G
 import megengine.functional as F
+import megengine.optimizer as optim
 import megengine.utils.comp_graph_tools as cgtools
-from megengine import tensor
+from megengine import Parameter, tensor
+from megengine.autodiff import GradManager
 from megengine.core._trace_option import set_symbolic_shape
 from megengine.core.ops import builtin as ops
 from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.utils import isscalar
 from megengine.functional import exp, log
 from megengine.jit import exclude_from_trace, trace
+from megengine.module import Module
 from megengine.random import normal, uniform
@@ -39,8 +42,48 @@ def test_trace():
         np.testing.assert_equal(f(x).numpy(), y)
 
 
+def test_output_copy_trace():
+    class Simple(Module):
+        def __init__(self):
+            super().__init__()
+            self.a = Parameter([1.0], dtype=np.float32)
+
+        def forward(self, x):
+            x = x * self.a
+            # will result in a copy of the output in grad
+            x = F.exp(x)
+            return x
+
+    net = Simple()
+
+    gm = GradManager().attach(net.parameters())
+    opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
+    data = tensor(np.arange(4).reshape(2, 2), dtype="float32")
+
+    @trace(symbolic=False)
+    def train_f1(d):
+        with gm:
+            loss = net(d)
+            gm.backward(loss)
+            opt.step().clear_grad()
+        return loss
+
+    @trace(symbolic=True)
+    def train_f2(d):
+        with gm:
+            loss = net(d)
+            gm.backward(loss)
+            opt.step().clear_grad()
+        return loss
+
+    for i in range(2):
+        y1 = train_f1(data).numpy()
+        y2 = train_f2(data).numpy()
+        np.testing.assert_equal(y1, y2)
+
+
 def test_exclude_from_trace():
-    for symbolic in [False]:
+    for symbolic in [False, True]:
 
         @trace(symbolic=symbolic)
         def f(x):
--
GitLab
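
For reference, below is a minimal usage sketch of the traced training step that the new test_output_copy_trace test exercises. It only uses the public APIs already imported by that test (Module, Parameter, GradManager, optim.SGD, trace); the class and variable names are illustrative and are not part of the patch itself.

import numpy as np

import megengine.functional as F
import megengine.optimizer as optim
from megengine import Parameter, tensor
from megengine.autodiff import GradManager
from megengine.jit import trace
from megengine.module import Module


class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter([1.0], dtype=np.float32)

    def forward(self, x):
        # the exp after the multiply produces a copied output in grad,
        # which is the situation covered by test_output_copy_trace
        return F.exp(x * self.a)


net = Simple()
gm = GradManager().attach(net.parameters())
opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)


@trace(symbolic=True)
def train_step(d):
    with gm:
        loss = net(d)
        gm.backward(loss)
        opt.step().clear_grad()
    return loss


data = tensor(np.arange(4).reshape(2, 2), dtype="float32")
for _ in range(2):  # the first call records the trace, the second replays it
    print(train_step(data).numpy())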