Commit d7f0fc8f authored by Megvii Engine Team

test(mge/distributed): add gather scatter reduce broadcast grad test

GitOrigin-RevId: 8245e11f1dd1c9c8f43d553924972f20a028d39d
Parent b5ec9dfe
......@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import platform
import weakref
......@@ -151,9 +152,7 @@ def test_remote_grad(trace_mode):
         def train_func(x):
             with gm:
                 if rank != 0:
-                    x = dist.functional.remote_recv(
-                        rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
-                    )
+                    x = dist.functional.remote_recv(rank - 1)
                 y = m(x)
                 if rank != size - 1:
                     dist.functional.remote_send(y, dest_rank=rank + 1)
......@@ -170,3 +169,113 @@ def test_remote_grad(trace_mode):
         train_func(x)
 
     worker()
+
+
+@pytest.mark.require_ngpu(3)
+@pytest.mark.isolated_distributed
+@pytest.mark.parametrize(
+    "trace_mode", [True, False, None], ids=["symbolic", "trace", "no_trace"]
+)
+def test_gather_grad(trace_mode):
+    @dist.launcher(n_gpus=3)
+    def worker():
+        m = M.Linear(10, 10)
+        x = F.ones([3, 10], dtype="float32")
+
+        def func():
+            with GradManager().attach(m.parameters()) as gm:
+                y = m(x)
+                y = F.distributed.gather(y)
+                if dist.get_rank() == 0:
+                    loss = (2 * y + 1).mean()
+                    gm.backward(loss)
+                else:
+                    gm.backward()
+
+        if trace_mode is not None:
+            func = trace(symbolic=trace_mode)(func)
+        func()
+
+    worker()
+
+
+@pytest.mark.require_ngpu(3)
+@pytest.mark.isolated_distributed
+@pytest.mark.parametrize(
+    "trace_mode", [True, False, None], ids=["symbolic", "trace", "no_trace"]
+)
+def test_scatter_grad(trace_mode):
+    @dist.launcher(n_gpus=3)
+    def worker():
+        x = F.ones([3, 10], dtype="float32")
+        m = M.Linear(10, 10)
+
+        def func():
+            with GradManager().attach(m.parameters()) as gm:
+                if dist.get_rank() == 0:
+                    y = m(x)
+                else:
+                    y = x
+                y = F.distributed.scatter(y)
+                gm.backward(y)
+
+        if trace_mode is not None:
+            func = trace(symbolic=trace_mode)(func)
+        func()
+
+    worker()
+
+
+@pytest.mark.require_ngpu(3)
+@pytest.mark.isolated_distributed
+@pytest.mark.parametrize(
+    "trace_mode", [True, False, None], ids=["symbolic", "trace", "no_trace"]
+)
+def test_reduce_grad(trace_mode):
+    @dist.launcher(n_gpus=3)
+    def worker():
+        m = M.Linear(10, 10)
+        x = F.ones([3, 10], dtype="float32")
+
+        def func():
+            with GradManager().attach(m.parameters()) as gm:
+                y = m(x)
+                y = F.distributed.reduce_sum(y)
+                if dist.get_rank() == 0:
+                    loss = (2 * y + 1).mean()
+                    gm.backward(loss)
+                else:
+                    gm.backward()
+
+        if trace_mode is not None:
+            func = trace(symbolic=trace_mode)(func)
+        func()
+
+    worker()
+
+
+@pytest.mark.require_ngpu(3)
+@pytest.mark.isolated_distributed
+@pytest.mark.parametrize(
+    "trace_mode", [True, False, None], ids=["symbolic", "trace", "no_trace"]
+)
+def test_broadcast_grad(trace_mode):
+    @dist.launcher(n_gpus=3)
+    def worker():
+        x = F.ones([3, 10], dtype="float32")
+        m = M.Linear(10, 10)
+
+        def func():
+            with GradManager().attach(m.parameters()) as gm:
+                if dist.get_rank() == 0:
+                    y = m(x)
+                else:
+                    y = x
+                y = F.distributed.broadcast(y)
+                gm.backward(y)
+
+        if trace_mode is not None:
+            func = trace(symbolic=trace_mode)(func)
+        func()
+
+    worker()
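Note on the backward semantics these tests exercise: gather/scatter and reduce_sum/broadcast are adjoint pairs, i.e. the backward pass of one collective is the forward pass of its dual, which is why only the root rank computes a loss. A minimal single-process NumPy sketch of the duality (illustrative stand-ins, not MegEngine's collectives):

import numpy as np

# Single-process stand-ins: "parts" holds one tensor per rank; rank 0 is root.
def gather(parts):
    # forward: concatenate the per-rank tensors on the root
    return np.concatenate(parts, axis=0)

def scatter(whole, n_ranks):
    # forward: split the root tensor evenly across ranks
    return np.split(whole, n_ranks, axis=0)

parts = [np.full((3, 10), r, dtype="float32") for r in range(3)]

# Backward of gather is a scatter of the output gradient ...
out = gather(parts)                    # (9, 10), lives on the root
dout = np.ones_like(out)               # gradient w.r.t. gather's output
dparts = scatter(dout, 3)              # what each rank receives in backward
assert all(dp.shape == p.shape for dp, p in zip(dparts, parts))

# ... and backward of reduce_sum is a broadcast of the root gradient.
reduced = sum(parts)                   # reduce_sum forward, root only
dreduced = np.ones_like(reduced)       # gradient on the root
dparts = [dreduced for _ in range(3)]  # backward: broadcast to every rank
assert all(dp.shape == p.shape for dp, p in zip(dparts, parts))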
......@@ -67,7 +67,7 @@ def test_dist_grad():
         grad.wrt(x, callback=save_to(x))
         # need a placeholder to trace operator
         remote_send(x, 1)
-        recv_x = remote_recv(1, x_np.shape, x_np.dtype)
+        recv_x = remote_recv(1)
         y = recv_x * recv_x
         grad([y], [as_tensor(np.ones_like(x_np))])
......@@ -75,7 +75,7 @@ def test_dist_grad():
     elif rank == 1:
         grad = Grad()
-        recv_x = remote_recv(0, x_np.shape, x_np.dtype)
+        recv_x = remote_recv(0)
         remote_send(recv_x, 0)
         grad([], [])
......
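Note on the remote_recv change above: the explicit shape and dtype arguments are gone, so the receiving rank now learns both from the sender. A minimal sketch of the idea, assuming the transport serializes tensor metadata together with the payload (a pure-Python stand-in, not MegEngine's actual channel):

import pickle
from multiprocessing import Pipe

import numpy as np

# Stand-ins for remote_send / remote_recv: pickling the array carries its
# shape and dtype alongside the payload, so the receiver needs neither.
def send(conn, arr):
    conn.send_bytes(pickle.dumps(arr))

def recv(conn):
    return pickle.loads(conn.recv_bytes())

left, right = Pipe()
x = np.random.rand(1, 4).astype("float32")
send(left, x)
y = recv(right)  # no shape/dtype arguments needed
assert y.shape == x.shape and y.dtype == x.dtype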
......@@ -44,6 +44,8 @@ def test_reduce_sum(shape):
         output = reduce_sum(inp)
         if rank == 0:
             assert np.allclose(output.numpy(), expect[rank])
+        else:
+            assert output is None
 
     x = np.random.random_sample(shape).astype("float32")
     y = np.random.random_sample(shape).astype("float32")
......@@ -177,6 +179,8 @@ def test_gather(shape):
         output = gather(inp)
         if rank == 0:
             assert np.allclose(output.numpy(), expect[rank])
+        else:
+            assert output is None
 
     x = np.random.random_sample(shape).astype("float32")
     y = np.random.random_sample(shape).astype("float32")
......@@ -236,7 +240,7 @@ def test_io_remote(shape):
             remote_send(x, 1)
             sync()
         else:  # remote recv
-            y = remote_recv(0, shape, np.float32)
+            y = remote_recv(0)
             assert y.device == get_default_device()
             np.testing.assert_almost_equal(val, y.numpy())
......
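Note on the new assertions in this file: rooted collectives such as reduce_sum and gather return a tensor only on the root rank; every other rank gets None, so callers must guard any use of the result with a rank check. A single-process mock of the convention (reduce_sum_mock is hypothetical, not part of the API):

import numpy as np

def reduce_sum_mock(parts, rank, root=0):
    # The reduced tensor materializes only on the root; other ranks get None.
    return sum(parts) if rank == root else None

parts = [np.full((2, 2), r, dtype="float32") for r in range(3)]
for rank in range(3):
    out = reduce_sum_mock(parts, rank)
    if rank == 0:
        np.testing.assert_allclose(out, sum(parts))
    else:
        assert out is None  # mirrors the assertions above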