diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py
index fbd29a96a54b39b8234f1acf3fc08a335961bc70..1063a7678acc4ffbc401c647b3da043e9c2a0240 100644
--- a/imperative/python/megengine/core/tensor/megbrain_graph.py
+++ b/imperative/python/megengine/core/tensor/megbrain_graph.py
@@ -450,6 +450,9 @@ def _unwrap(x):
 
 
 def apply_normal_varnode(op: OpDef, *args: VarNode):
+    # for PyOp like RemoteSend/Recv
+    if getattr(op, "op", None):
+        op = op.op
     outputs = _imperative_rt.invoke_op(op, _unwrap(args))
     return _wrap(outputs)
 
diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py
index 32a24ff147e66f13da4e2023bcca1784722ba102..8b9e77c7fc53db37bb55542ec51b73b52eecbd71 100644
--- a/imperative/python/megengine/distributed/functional.py
+++ b/imperative/python/megengine/distributed/functional.py
@@ -292,6 +292,8 @@ def remote_recv(
     op = RemoteRecv()
     op.key = key
     op.cn = device
+    if isinstance(shape, Tensor):
+        shape = shape.numpy()
     op.shape = shape
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index 67b18f3ffc0bc541f52a73124f96c56e85f63891..3c969d829209b18034667261bb8bb1436503d48b 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -234,20 +234,21 @@ class trace:
                     )
                     info.data_setter.set_value(x._dev_tensor())
             else:
-                pass
-                # if x.__class__ is not CompiledTensorProxy:
-                #     if x not in self._tensor_remaps:
-                #         raise TraceMismatchError(
-                #             "unexpected capture: trying to use an external tensor as "
-                #             "input, but that input was an internal tensor last time"
-                #         )
-                #     else:
-                #         x = self._tensor_remaps[x]
-                #     if x._CompiledTensorProxy__handle != h:
-                #         raise TraceMismatchError(
-                #             "mis-wiring: input edge to an data flow "
-                #             "graph node is different from last time"
-                #         )
+                if x.mixin_handle == -1:
+                    if x._handle not in self._tensor_remaps:
+                        raise TraceMismatchError(
+                            "unexpected capture: trying to use an external tensor as "
+                            "input, but that input was an internal tensor last time"
+                        )
+                    else:
+                        x.mixin_handle = self._tensor_remaps[
+                            x._handle
+                        ]._CompiledTensorProxy__handle
+                if x.mixin_handle != h:
+                    raise TraceMismatchError(
+                        "mis-wiring: input edge to an data flow "
+                        "graph node is different from last time"
+                    )
 
         self._pc += 1
         outputs = []
@@ -268,14 +269,11 @@ class trace:
             op_, ihandles, ohandles = record
             assert isinstance(op_, str) and op_ == "Const"
 
-            # TODO : assert on const value
-            # eq = value == self._tinfo[ohandles[0]].bound_data.numpy()
-            # if not isinstance(eq, bool):
-            #     eq = all(eq)
-            # if not eq:
-            #     raise TraceMismatchError(
-            #         "const tensor violated: got a different tensor this time"
-            #     )
+            eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
+            if not eq:
+                raise TraceMismatchError(
+                    "const tensor violated: got a different tensor this time"
+                )
 
             self._pc += 1
             (h,) = ohandles
@@ -750,7 +748,6 @@ class trace:
                 dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
             )
 
-        set_tracing()
         for op, ihandles, ohandles in self._seq:
             if isinstance(op, str) and op == "Const":
                 assert len(ihandles) == 0
@@ -776,7 +773,6 @@ class trace:
                 ovars = G.apply_normal_varnode(op, *ivars)
             assert len(ovars) == len(ohandles)
             h2v.update(zip(ohandles, ovars))
-        unset_tracing()
 
         dest_vars = []
         for i, h in enumerate(self._output_bindings):
@@ -843,7 +839,7 @@ class trace:
                 if x.device != info.device:
                     raise TypeError("args[%d].device different from last time" % i)
                 info.data_setter.set_value(x._dev_tensor())
-            self._tensor_remaps[x] = CompiledTensorProxy(h)
+            self._tensor_remaps[x._handle] = CompiledTensorProxy(h)
 
         kwargs_tensors = {}
         for k, x in kwargs.items():
@@ -870,7 +866,7 @@ class trace:
                 if x.device != info.device:
                     raise TypeError("kwargs[%s].device different from last time" % k)
                 info.data_setter.set_value(x._dev_tensor())
-            self._tensor_remaps[x] = CompiledTensorProxy(h)
+            self._tensor_remaps[x._handle] = CompiledTensorProxy(h)
 
     def _process_outputs(self, outputs):
         output_names = None
@@ -1000,8 +996,8 @@ class CompiledTensorProxy:
     def __del__(self):
        if self.__tensor.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
-        # if self.__tensor.value_read and self.__value is not None:
-        #     self.__info.value_reader.drop_value()
+        if self.__tensor.value_read and self.__value is not None:
+            self.__info.value_reader.drop_value()
         if self.__tensor.data_read and self.__data is not None:
             self.__info.data_reader.drop_value()
 
@@ -1047,7 +1043,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
     outputs = [RawTensor(o) for o in ovars]
 
     if require_links:
-        active_trace._lazy_eval_links = (outputs[0]._varnode,)
+        active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)
 
     active_trace._lazy_eval_tensors.update([TensorWeakRef(o) for o in outputs])
     return outputs
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index fffa80fce02c6fd829a93115a57f748bd7bd6bc9..ab85a1ebc422f5b2f204d6e27fb4809f5ac0e599 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -760,7 +760,14 @@ void init_tensor(py::module m) {
     m.attr("skip_tracing") = &skip_tracing;
 
     py::class_<SharedHandle>(m, "SharedHandle")
-            .def(py::init<const SharedHandle&>());
+            .def(py::init<const SharedHandle&>())
+            .def("__eq__", [](SharedHandle &thish, SharedHandle &thath) {
+                return (thish.get() == thath.get());
+            })
+            .def("__hash__", [](SharedHandle &sh) {
+                return reinterpret_cast<int64_t>(sh.get());
+            })
+            ;
 
     m.def("set_tracing", &set_tracing);
     m.def("unset_tracing", &unset_tracing);
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index 6a1498ad0141db552ddd9396ac717a6fe92ed1a0..0761d973d421caa03d1835d505ee81e537a92790 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -141,7 +141,6 @@ def test_regression_1762():
 )
 @pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
-@pytest.mark.skip(reason="FIXME: remote_send/recv")
 def test_remote_grad():
     @dist.launcher
     def worker():
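Note on the const check re-enabled in imperative/python/megengine/jit/tracing.py: wrapping the comparison in np.atleast_1d / np.all collapses the scalar and the array case into a single boolean, which the old commented-out version handled with the isinstance(eq, bool) / all(eq) dance. A minimal standalone sketch of the same idea (const_matches is an illustrative name, not MegEngine API):

```python
import numpy as np

def const_matches(value, recorded):
    # np.atleast_1d promotes Python scalars to 1-element arrays, so the
    # elementwise == followed by np.all yields a single bool for scalars,
    # vectors and higher-rank arrays alike.
    return bool(np.all(np.atleast_1d(value) == np.asarray(recorded)))

assert const_matches(3, 3)
assert const_matches([1, 2], np.array([1, 2]))
assert not const_matches([1, 2], np.array([1, 3]))
```

Relatedly, the SharedHandle __eq__/__hash__ bindings added in src/tensor.cpp compare and hash the underlying raw handle, which is what lets trace._tensor_remaps be keyed by x._handle (a SharedHandle) instead of by the tensor object itself.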