Commit de0742be authored by Megvii Engine Team

refactor(mge): reopen passed assertions

GitOrigin-RevId: e0276e73e31ddba1e35a1561abe3f178eedd509a
Parent a90c937d
......@@ -450,6 +450,9 @@ def _unwrap(x):
def apply_normal_varnode(op: OpDef, *args: VarNode):
# for PyOp like RemoteSend/Recv
if getattr(op, "op", None):
op = op.op
outputs = _imperative_rt.invoke_op(op, _unwrap(args))
return _wrap(outputs)
......
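
The change above makes apply_normal_varnode unwrap Python-level op wrappers (PyOps such as RemoteSend/Recv keep the real OpDef on an `op` attribute) before calling invoke_op. A minimal sketch of the same unwrap-before-dispatch idea, with illustrative names rather than the real MegEngine types:

    def unwrap_op(op):
        # PyOp wrappers carry the real definition on their .op attribute;
        # plain op definitions pass through unchanged.
        inner = getattr(op, "op", None)
        return inner if inner is not None else op

    class _Wrapper:
        def __init__(self, op):
            self.op = op

    assert unwrap_op(_Wrapper("RemoteSendDef")) == "RemoteSendDef"
    assert unwrap_op("AddDef") == "AddDef"
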
......@@ -292,6 +292,8 @@ def remote_recv(
op = RemoteRecv()
op.key = key
op.cn = device
if isinstance(shape, Tensor):
shape = shape.numpy()
op.shape = shape
op.dtype = dtype
op.addr, op.port = get_mm_server_addr()
......
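
The new lines above normalize the `shape` argument of remote_recv: when a device Tensor is passed, it is converted to a host-side NumPy array before being stored on the op. A rough sketch of that normalization pattern, assuming only a `.numpy()` method on tensor-like inputs (the helper name is hypothetical):

    import numpy as np

    def normalize_shape(shape):
        # Accept a tuple/list, a NumPy array, or anything exposing .numpy()
        # (as a device Tensor does) and return a host-side NumPy array.
        if hasattr(shape, "numpy"):
            shape = shape.numpy()
        return np.asarray(shape, dtype=np.int32)

    assert normalize_shape((2, 3)).tolist() == [2, 3]
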
......@@ -234,20 +234,21 @@ class trace:
)
info.data_setter.set_value(x._dev_tensor())
else:
pass
# if x.__class__ is not CompiledTensorProxy:
# if x not in self._tensor_remaps:
# raise TraceMismatchError(
# "unexpected capture: trying to use an external tensor as "
# "input, but that input was an internal tensor last time"
# )
# else:
# x = self._tensor_remaps[x]
# if x._CompiledTensorProxy__handle != h:
# raise TraceMismatchError(
# "mis-wiring: input edge to an data flow "
# "graph node is different from last time"
# )
if x.mixin_handle == -1:
if x._handle not in self._tensor_remaps:
raise TraceMismatchError(
"unexpected capture: trying to use an external tensor as "
"input, but that input was an internal tensor last time"
)
else:
x.mixin_handle = self._tensor_remaps[
x._handle
]._CompiledTensorProxy__handle
if x.mixin_handle != h:
raise TraceMismatchError(
"mis-wiring: input edge to an data flow "
"graph node is different from last time"
)
self._pc += 1
outputs = []
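
The reopened block above replaces the previously commented-out checks: an input captured from outside the trace (mixin_handle == -1) must have been registered in self._tensor_remaps, which is now keyed by the tensor's shared handle, and the remapped handle has to match the handle recorded for this step. A standalone sketch of the same remap-and-verify pattern, with a hypothetical RemapError in place of TraceMismatchError:

    class RemapError(Exception):
        pass

    def resolve_external_input(handle, expected_handle, remaps):
        # remaps maps an external tensor's shared handle to the compiled
        # handle it was bound to; a missing entry or a mismatch means the
        # current run no longer matches the recorded sequence.
        if handle not in remaps:
            raise RemapError("unexpected capture: input was internal last time")
        if remaps[handle] != expected_handle:
            raise RemapError("mis-wiring: input differs from the last trace")
        return remaps[handle]

    assert resolve_external_input("h1", 7, {"h1": 7}) == 7
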
......@@ -268,14 +269,11 @@ class trace:
op_, ihandles, ohandles = record
assert isinstance(op_, str) and op_ == "Const"
# TODO : assert on const value
# eq = value == self._tinfo[ohandles[0]].bound_data.numpy()
# if not isinstance(eq, bool):
# eq = all(eq)
# if not eq:
# raise TraceMismatchError(
# "const tensor violated: got a different tensor this time"
# )
eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
if not eq:
raise TraceMismatchError(
"const tensor violated: got a different tensor this time"
)
self._pc += 1
(h,) = ohandles
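
The reopened const check compares the constant recorded during tracing with the value seen on this run. Because `value == bound_data` may evaluate to either a plain bool (scalar consts) or a NumPy array (tensor consts), the new code wraps the comparison in np.atleast_1d and reduces it with np.all. A small NumPy-only illustration of why that wrapper is needed:

    import numpy as np

    def consts_match(value, recorded):
        # == between a scalar and an array may return a bool or an ndarray;
        # np.atleast_1d plus np.all reduces both cases to a single bool.
        return bool(np.all(np.atleast_1d(value) == np.atleast_1d(recorded)))

    assert consts_match(3, 3)
    assert consts_match([1, 2], np.array([1, 2]))
    assert not consts_match([1, 2], np.array([1, 3]))
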
......@@ -750,7 +748,6 @@ class trace:
dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
)
set_tracing()
for op, ihandles, ohandles in self._seq:
if isinstance(op, str) and op == "Const":
assert len(ihandles) == 0
......@@ -776,7 +773,6 @@ class trace:
ovars = G.apply_normal_varnode(op, *ivars)
assert len(ovars) == len(ohandles)
h2v.update(zip(ohandles, ovars))
unset_tracing()
dest_vars = []
for i, h in enumerate(self._output_bindings):
......@@ -843,7 +839,7 @@ class trace:
if x.device != info.device:
raise TypeError("args[%d].device different from last time" % i)
info.data_setter.set_value(x._dev_tensor())
self._tensor_remaps[x] = CompiledTensorProxy(h)
self._tensor_remaps[x._handle] = CompiledTensorProxy(h)
kwargs_tensors = {}
for k, x in kwargs.items():
......@@ -870,7 +866,7 @@ class trace:
if x.device != info.device:
raise TypeError("kwargs[%s].device different from last time" % k)
info.data_setter.set_value(x._dev_tensor())
self._tensor_remaps[x] = CompiledTensorProxy(h)
self._tensor_remaps[x._handle] = CompiledTensorProxy(h)
def _process_outputs(self, outputs):
output_names = None
......@@ -1000,8 +996,8 @@ class CompiledTensorProxy:
def __del__(self):
if self.__tensor.shape_read and self.__shape is not None:
self.__info.shape_reader.drop_value()
# if self.__tensor.value_read and self.__value is not None:
# self.__info.value_reader.drop_value()
if self.__tensor.value_read and self.__value is not None:
self.__info.value_reader.drop_value()
if self.__tensor.data_read and self.__data is not None:
self.__info.data_reader.drop_value()
......@@ -1047,7 +1043,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
outputs = [RawTensor(o) for o in ovars]
if require_links:
active_trace._lazy_eval_links = (outputs[0]._varnode,)
active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)
active_trace._lazy_eval_tensors.update([TensorWeakRef(o) for o in outputs])
return outputs
......
......@@ -760,7 +760,14 @@ void init_tensor(py::module m) {
m.attr("skip_tracing") = &skip_tracing;
py::class_<SharedHandle>(m, "SharedHandle")
.def(py::init<const SharedHandle&>());
.def(py::init<const SharedHandle&>())
.def("__eq__", [](SharedHandle &thish, SharedHandle &thath) {
return (thish.get() == thath.get());
})
.def("__hash__", [](SharedHandle &sh) {
return reinterpret_cast<int64_t>(sh.get());
})
;
m.def("set_tracing", &set_tracing);
m.def("unset_tracing", &unset_tracing);
......
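
The pybind11 additions above give SharedHandle value semantics in Python: two wrappers compare equal when they refer to the same underlying handle, and the hash is derived from that pointer, so a SharedHandle can serve as a dictionary key. That is what allows the trace code earlier in this commit to key self._tensor_remaps by x._handle instead of by the tensor object. A minimal Python sketch of the same __eq__/__hash__ contract (HandleLike is an illustrative stand-in, not the actual binding):

    class HandleLike:
        """Illustrative wrapper whose identity is the raw pointer it holds."""

        def __init__(self, ptr):
            self._ptr = ptr

        def __eq__(self, other):
            return isinstance(other, HandleLike) and self._ptr == other._ptr

        def __hash__(self):
            return hash(self._ptr)

    remaps = {HandleLike(0x1234): "compiled proxy"}
    # A different wrapper object around the same pointer still finds the entry.
    assert remaps[HandleLike(0x1234)] == "compiled proxy"
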
......@@ -141,7 +141,6 @@ def test_regression_1762():
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
@pytest.mark.skip(reason="FIXME: remote_send/recv")
def test_remote_grad():
@dist.launcher
def worker():
......