Commit 2ad8c5e1 authored by Megvii Engine Team

fix(mge/io_remote): fix remote send/recv gradient at trace

GitOrigin-RevId: 7886efd0c124b1a6f60046c9f876e457eb683b1d
Parent f470df4f
@@ -16,7 +16,7 @@ import numpy as np
 import megengine as mge

-from ..ops.builtin import Elemwise, OpDef
+from ..ops.builtin import Elemwise, OpDef, RemoteSend
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
 from ..tensor.function import Function
@@ -84,6 +84,9 @@ class Grad:
         # ops forms the computational graph
         self.ops = []

+        # save remote_send output for backward
+        self.remote_send_cache = []
+
         self._attached_tensors = weakref.WeakSet()
         self._enabled = True
@@ -144,6 +147,7 @@ class Grad:
                 o.clear()
         for i in self._attached_tensors:
             i._extra_data.pop(self, None)
+        self.remote_send_cache = []

     def __exit__(self, *_):
         self._exit()
@@ -398,6 +402,8 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
         return
     opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
+    if isinstance(op, RemoteSend):
+        manager.remote_send_cache.append(opnode)
     opnode.backward = backward
     outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
...
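Context for the change above: a RemoteSend has no output that the local rank keeps using, so its tracer opnode could otherwise be released before backward ran; caching it in remote_send_cache keeps the send alive until Grad exits. A minimal sketch of the usage pattern this enables (not the repository's test; it assumes an already initialized two-rank process group, and the remote_recv argument list varies across MegEngine versions):

    # Sketch only: assumes an initialized 2-rank process group (launcher/setup elided).
    import numpy as np
    import megengine as mge
    import megengine.distributed as dist
    from megengine.autodiff import GradManager

    def pipeline_step(rank):
        w = mge.Parameter(np.ones((1,), dtype="float32"))
        x = mge.tensor(np.ones((1,), dtype="float32"))
        gm = GradManager().attach([w])
        with gm:
            if rank == 0:
                y = x * w
                dist.functional.remote_send(y, dest_rank=1)
                # No local loss on the sending rank: backward() is called with no
                # argument and the gradient comes back through the send/recv pair.
                gm.backward()
            else:
                # remote_recv's argument list (shape/dtype/device) differs across
                # MegEngine versions; the values here are assumptions.
                y = dist.functional.remote_recv(0, (1,), "float32")
                gm.backward((y * w).mean())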
@@ -588,7 +588,7 @@ class trace:
             graph.options.graph_opt_level = self._graph_opt_level
         else:
             graph.options.graph_opt_level = 2
-        graph.compile(*readers)
+        graph.compile(*readers, *links)

     def _reset_exec_env(self):
         for opnode in self._need_reset_nodes:
...
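The graph.compile(*readers, *links) change matters for the same reason: under symbolic trace a RemoteSend has no reader on the local rank, so it has to be handed to compile as an extra link or it would be dropped from the executed graph. Roughly, this is the shape of a traced step whose only local effect is a send (illustrative only; it assumes the same initialized process group as above):

    import megengine.distributed as dist
    from megengine.jit import trace

    @trace(symbolic=True)
    def send_step(x):
        # Nothing on this rank reads the result of remote_send, so the compiled
        # graph must keep the operator alive explicitly (the "links" above).
        dist.functional.remote_send(x * 2, dest_rank=1)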
@@ -111,7 +111,6 @@ def test_remote_grad():
         gm = GradManager().attach(m.parameters())
         opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)

-        @trace(symbolic=True)
         def train_func(x):
             with gm:
                 if rank != 0:
@@ -120,18 +119,22 @@ def test_remote_grad():
                     )
                 y = m(x)
                 if rank != size - 1:
-                    y = dist.functional.remote_send(y, dest_rank=rank + 1)
-                if rank == size - 1:
+                    dist.functional.remote_send(y, dest_rank=rank + 1)
+                    gm.backward()
+                else:
                     y = y.mean()
                     gm.backward(y)
-                else:
-                    gm.backward()
             opt.step().clear_grad()

-        for i in range(3):
-            train_func(x)
-        for param in m.parameters():
-            param.numpy()
+        train_funcs = [
+            train_func,
+            trace(symbolic=False)(train_func),
+            trace(symbolic=True)(train_func),
+        ]
+        for func in train_funcs:
+            for i in range(3):
+                func(x)
+            sync()

     worker()
...
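The test now exercises the same body eagerly and under both trace modes by applying trace as a plain function rather than a decorator. Stripped of the distributed parts, the pattern looks like this (names are illustrative):

    import numpy as np
    import megengine as mge
    from megengine.jit import trace

    def step(x):
        return (x * 2).mean()

    x = mge.tensor(np.random.rand(4).astype("float32"))
    for fn in [step, trace(symbolic=False)(step), trace(symbolic=True)(step)]:
        for _ in range(3):
            fn(x)  # each variant is executed three times, like the test loop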
@@ -266,11 +266,20 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
         const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
         const OperatorNodeConfig& config) {
     auto&& opr = opr_.cast_final_safe<RemoteRecv>();
-    return RemoteRecv::make(opr.key(), *opr.owner_graph(),
-                            opr.group_client(), config, inputs[0]->shape(),
-                            inputs[0]->dtype())
-            .node()
-            ->owner_opr();
+    if (inputs.size() == 1) {
+        return RemoteRecv::make(opr.key(), inputs[0], *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    } else {
+        mgb_assert(inputs.size() == 0, "recv should have 1 or 0 input");
+        return RemoteRecv::make(opr.key(), *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    }
 }

 MGB_REG_OPR_SHALLOW_COPY(RemoteRecv, opr_shallow_copy_remote_recv);
...
@@ -94,6 +94,9 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
             const OperatorNodeConfig& config, const TensorShape& shape,
             DType dtype);

+    const TensorShape& shape() const { return m_shape; }
+    const DType& dtype() const { return m_dtype; }
+
     private:
         const TensorShape m_shape;
         const DType m_dtype;
...