Commit 9d928e7f authored by Megvii Engine Team

refactor(mge/distributed): sync interpreter for distributed launcher

GitOrigin-RevId: 8a88c272a1eae8e633eb68dcbb77a00a1c943f0e
Parent 4e9be159
......@@ -6,59 +6,105 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import multiprocessing as mp
from .group import init_process_group
from ..core._imperative_rt import sync
from .group import group_barrier, init_process_group
from .helper import get_device_count_by_fork
from .server import Server
from .util import get_free_ports
def _run_wrapped(func, master_ip, port, world_size, rank, dev, args, kwargs):
def _run_wrapped(
func, is_multimachine, master_ip, port, world_size, rank, dev, args, kwargs
):
"""Init distributed process group and run wrapped function."""
init_process_group(
master_ip=master_ip, port=port, world_size=world_size, rank=rank, device=dev
)
if is_multimachine:
group_barrier()
func(*args, **kwargs)
sync()
if is_multimachine:
group_barrier()
def launcher(func):
"""Decorator for launching multiple processes in single-machine multi-gpu training."""
class launcher:
"""Decorator for launching multiple processes in single-machine multi-gpu training.
:param func: the function you want to launch in distributed mode.
:param n_gpus: how many devices each node.
:param world_size: how many devices totally.
:param rank_start: start number for rank.
:param master_ip: ip address for master node (where the rank 0 is).
:param port: server port for distributed server.
"""
n_gpus = get_device_count_by_fork("gpu")
def __new__(cls, *args, **kwargs):
if not args:
return functools.partial(cls, **kwargs)
return super().__new__(cls)
def wrapper(*args, **kwargs):
master_ip = "localhost"
server = Server()
port = server.py_server_port
def __init__(
self,
func,
n_gpus=None,
world_size=None,
rank_start=0,
master_ip="localhost",
port=0,
):
self.func = func
self.n_gpus = n_gpus if n_gpus is not None else get_device_count_by_fork("gpu")
self.world_size = world_size if world_size is not None else self.n_gpus
self.rank_start = rank_start
self.master_ip = master_ip
self.port = port
# the master node creates the distributed server
if self.rank_start == 0:
self.server = Server(self.port)
self.port = self.server.py_server_port
else:
assert self.port != 0, "you have to assign a port for the distributed server"
def __call__(self, *args, **kwargs):
procs = []
for rank in range(n_gpus):
for dev in range(self.n_gpus):
p = mp.Process(
target=_run_wrapped,
args=(func, master_ip, port, n_gpus, rank, rank, args, kwargs),
args=(
self.func,
self.world_size > self.n_gpus,
self.master_ip,
self.port,
self.world_size,
dev + self.rank_start,
dev,
args,
kwargs,
),
)
p.start()
procs.append(p)
ranks = [rank for rank in range(n_gpus)]
devs = list(range(self.n_gpus))
while len(ranks) > 0:
while len(devs) > 0:
left = []
# check all processes in one second
time_to_wait = 1.0 / len(ranks)
for rank in ranks:
procs[rank].join(time_to_wait)
code = procs[rank].exitcode
time_to_wait = 1.0 / len(devs)
for dev in devs:
procs[dev].join(time_to_wait)
code = procs[dev].exitcode
# terminate processes if one of them has failed
if code != 0 and code != None:
for i in ranks:
for i in devs:
procs[i].terminate()
assert (
code == 0 or code == None
), "subprocess {} exit with code {}".format(rank, code)
), "subprocess {} exit with code {}".format(dev + self.rank_start, code)
if code == None:
left.append(rank)
ranks = left
return wrapper
left.append(dev)
devs = left
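
The `__new__` override is what lets `launcher` work both as a bare decorator and as a decorator with keyword arguments: called with no positional arguments it returns a `functools.partial(cls, **kwargs)` that is later applied to the decorated function. Below is a minimal usage sketch, assuming a two-GPU machine; the `all_reduce_sum` worker body is illustrative and not part of this diff.

import numpy as np
import megengine.distributed as dist
from megengine import tensor
from megengine.functional.distributed import all_reduce_sum

@dist.launcher(n_gpus=2)          # keyword form; a bare "@dist.launcher" would use all detected GPUs
def worker(data):
    rank = dist.get_rank()        # rank is set up by _run_wrapped via init_process_group
    out = all_reduce_sum(tensor(data[rank]))
    np.testing.assert_allclose(out.numpy(), data[0] + data[1], rtol=1e-6)

x = np.random.rand(4).astype("float32")
y = np.random.rand(4).astype("float32")
worker((x, y))                    # spawns one process per device, joins them, and asserts on exit codes

For multi-machine training, every node would construct the launcher with `world_size`, `rank_start`, `master_ip` and `port`; only the node with `rank_start == 0` creates the `Server`, and since `world_size > n_gpus` the launcher brackets the wrapped function with `group_barrier()` calls so that all nodes enter and leave together.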
......@@ -6,7 +6,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import os
import platform
import re
......
......@@ -6,7 +6,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import platform
import numpy as np
......
......@@ -17,7 +17,6 @@ import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine.autodiff import GradManager
from megengine.core._imperative_rt.imperative import sync
from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace
......@@ -135,6 +134,5 @@ def test_remote_grad():
for func in train_funcs:
for i in range(3):
func(x)
sync()
worker()
......@@ -17,7 +17,6 @@ import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
from megengine.core._imperative_rt import TensorAttr, imperative
from megengine.core._imperative_rt.imperative import sync
from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.raw_tensor import as_raw_tensor
......@@ -65,47 +64,31 @@ def save_to(self, name="grad"):
def test_dist_grad():
world_size = 2
x_np = np.random.rand(10).astype("float32")
server = dist.Server()
port = server.py_server_port
def worker0():
dist.init_process_group("localhost", port, world_size, 0, 0)
mge.device.set_default_device("gpu0")
grad = Grad()
x = as_tensor(x_np)
grad.wrt(x, callback=save_to(x))
# need a placeholder to trace operator
send_x = remote_send(x, 1)
recv_x = remote_recv(1, x_np.shape, x_np.dtype, "gpu0")
y = recv_x * recv_x
grad([y], [as_tensor(np.ones_like(x_np))])
np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2)
def worker1():
dist.init_process_group("localhost", port, world_size, 1, 1)
mge.device.set_default_device("gpu1")
grad = Grad()
recv_x = remote_recv(0, x_np.shape, x_np.dtype, "gpu1")
send_x = remote_send(recv_x, 0)
grad([], [])
# sync because grad has a send operator
sync()
send_x.device._cn._sync_all()
import multiprocessing as mp
p0 = mp.Process(target=worker0)
p1 = mp.Process(target=worker1)
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
@dist.launcher
def worker():
rank = dist.get_rank()
if rank == 0:
grad = Grad()
x = as_tensor(x_np)
grad.wrt(x, callback=save_to(x))
# need a placeholder to trace operator
send_x = remote_send(x, 1)
recv_x = remote_recv(1, x_np.shape, x_np.dtype)
y = recv_x * recv_x
grad([y], [as_tensor(np.ones_like(x_np))])
np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2)
elif rank == 1:
grad = Grad()
recv_x = remote_recv(0, x_np.shape, x_np.dtype)
send_x = remote_send(recv_x, 0)
grad([], [])
worker()
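
The refactored test no longer calls `sync()` or `_sync_all()` by hand: as the commit title says, the launcher itself now synchronizes the interpreter, so pending asynchronous work (such as the send issued during rank 1's backward pass) is flushed before each worker process exits. The sketch below restates the per-process ordering from `_run_wrapped` above for the single-machine case; the helper name is illustrative.

from megengine.core._imperative_rt import sync
import megengine.distributed as dist

def run_one_rank(func, master_ip, port, world_size, rank, dev, args, kwargs):
    # each spawned process first joins the process group ...
    dist.init_process_group(master_ip=master_ip, port=port,
                            world_size=world_size, rank=rank, device=dev)
    # ... then runs the user function, which may leave async ops (e.g. remote_send) pending ...
    func(*args, **kwargs)
    # ... and finally synchronizes the interpreter before the process exits.
    sync()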
def test_grad():
......
......@@ -6,7 +6,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import platform
import numpy as np
......@@ -16,6 +15,7 @@ import megengine as mge
import megengine.distributed as dist
from megengine import Parameter, Tensor, tensor
from megengine.device import get_default_device, set_default_device
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.distributed import (
all_gather,
all_reduce_max,
......@@ -38,20 +38,16 @@ from megengine.functional.distributed import (
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_reduce_sum():
world_size = 2
server = dist.Server()
port = server.py_server_port
def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size:
return
dist.init_process_group("localhost", port, world_size, rank, rank)
inp = tensor(data)
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
inp = tensor(data[rank])
output = reduce_sum(inp)
if rank == 0:
assert np.allclose(output.numpy(), expect)
assert np.allclose(output.numpy(), expect[rank])
else:
assert np.allclose(output.numpy(), 0)
......@@ -59,16 +55,9 @@ def test_reduce_sum():
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
z = x + y
p0 = mp.Process(target=worker, args=(0, x, z, port))
p1 = mp.Process(target=worker, args=(1, y, None, port))
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
data = (x, y)
expect = (z, None)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
check(shape)
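
The remaining collective tests below follow the same convention: the launcher forwards identical arguments to every process and discards worker return values, so per-rank inputs and expected outputs are packed into tuples, indexed with `dist.get_rank()`, and verified with in-process assertions (a failing assertion makes that rank exit non-zero, which the launcher's join loop turns into a test failure). A bare-bones sketch of the convention, assuming two GPUs:

import numpy as np
import megengine.distributed as dist

@dist.launcher(n_gpus=2)
def check_per_rank(data, expect):
    rank = dist.get_rank()
    # every rank receives the same tuples; indexing by rank selects its own slice
    np.testing.assert_allclose(data[rank], expect[rank])

x = np.ones(3, dtype="float32")
check_per_rank((x, x * 2), (x, x * 2))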
......@@ -80,33 +69,22 @@ def test_reduce_sum():
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_broadcast():
world_size = 2
server = dist.Server()
port = server.py_server_port
def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size:
return
dist.init_process_group("localhost", port, world_size, rank, rank)
inp = tensor(data)
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
inp = tensor(data[rank])
output = broadcast(inp)
assert np.allclose(output.numpy(), expect)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = x + 1
p0 = mp.Process(target=worker, args=(0, x, x, port))
p1 = mp.Process(target=worker, args=(1, y, x, port))
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
data = (x, y)
expect = (x, x)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
check(shape)
......@@ -118,34 +96,23 @@ def test_broadcast():
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_all_gather():
world_size = 2
server = dist.Server()
port = server.py_server_port
def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size:
return
dist.init_process_group("localhost", port, world_size, rank, rank)
inp = tensor(data)
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
inp = tensor(data[rank])
output = all_gather(inp)
assert np.allclose(output.numpy(), expect)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
z = np.concatenate((x, y))
p0 = mp.Process(target=worker, args=(0, x, z, port))
p1 = mp.Process(target=worker, args=(1, y, z, port))
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
data = (x, y)
expect = (z, z)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
check(shape)
......@@ -157,34 +124,23 @@ def test_all_gather():
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_reduce_scatter_sum():
world_size = 2
server = dist.Server()
port = server.py_server_port
def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size:
return
dist.init_process_group("localhost", port, world_size, rank, rank)
inp = tensor(data)
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
inp = tensor(data[rank])
output = reduce_scatter_sum(inp)
assert np.allclose(output.numpy(), expect)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
z = x + y
p0 = mp.Process(target=worker, args=(0, x, z[: shape[0] // 2], port))
p1 = mp.Process(target=worker, args=(1, y, z[shape[0] // 2 :], port))
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
data = (x, y)
expect = (z[: shape[0] // 2], z[shape[0] // 2 :])
worker(data, expect)
for shape in [(2, 4), (8, 10), (88, 44)]:
check(shape)
......@@ -196,34 +152,23 @@ def test_reduce_scatter_sum():
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_all_reduce_sum():
world_size = 2
server = dist.Server()
port = server.py_server_port
def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size:
return
dist.init_process_group("localhost", port, world_size, rank, rank)
inp = tensor(data)
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
inp = tensor(data[rank])
output = all_reduce_sum(inp)
assert np.allclose(output.numpy(), expect)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
z = x + y
p0 = mp.Process(target=worker, args=(0, x, z, port))
p1 = mp.Process(target=worker, args=(1, y, z, port))
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
data = (x, y)
expect = (z, z)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
check(shape)
......@@ -235,34 +180,23 @@ def test_all_reduce_sum():
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_all_reduce_max():
world_size = 2
server = dist.Server()
port = server.py_server_port
def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size:
return
dist.init_process_group("localhost", port, world_size, rank, rank)
inp = tensor(data)
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
inp = tensor(data[rank])
output = all_reduce_max(inp)
assert np.allclose(output.numpy(), expect)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
z = np.maximum(x, y)
p0 = mp.Process(target=worker, args=(0, x, z, port))
p1 = mp.Process(target=worker, args=(1, y, z, port))
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
data = (x, y)
expect = (z, z)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
check(shape)
......@@ -274,34 +208,23 @@ def test_all_reduce_max():
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_all_reduce_min():
world_size = 2
server = dist.Server()
port = server.py_server_port
def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size:
return
dist.init_process_group("localhost", port, world_size, rank, rank)
inp = tensor(data)
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
inp = tensor(data[rank])
output = all_reduce_min(inp)
assert np.allclose(output.numpy(), expect)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
z = np.minimum(x, y)
p0 = mp.Process(target=worker, args=(0, x, z, port))
p1 = mp.Process(target=worker, args=(1, y, z, port))
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
data = (x, y)
expect = (z, z)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
check(shape)
......@@ -313,20 +236,16 @@ def test_all_reduce_min():
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_gather():
world_size = 2
server = dist.Server()
port = server.py_server_port
def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size:
return
dist.init_process_group("localhost", port, world_size, rank, rank)
inp = tensor(data)
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
inp = tensor(data[rank])
output = gather(inp)
if rank == 0:
assert np.allclose(output.numpy(), expect)
assert np.allclose(output.numpy(), expect[rank])
else:
assert np.allclose(output.numpy(), 0)
......@@ -334,16 +253,9 @@ def test_gather():
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
z = np.concatenate((x, y))
p0 = mp.Process(target=worker, args=(0, x, z, port))
p1 = mp.Process(target=worker, args=(1, y, None, port))
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
data = (x, y)
expect = (z, None)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
check(shape)
......@@ -355,33 +267,22 @@ def test_gather():
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_scatter():
world_size = 2
server = dist.Server()
port = server.py_server_port
def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size:
return
dist.init_process_group("localhost", port, world_size, rank, rank)
inp = tensor(data)
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
inp = tensor(data[rank])
output = scatter(inp)
assert np.allclose(output.numpy(), expect)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = x + 1
p0 = mp.Process(target=worker, args=(0, x, x[: shape[0] // 2], port))
p1 = mp.Process(target=worker, args=(1, y, x[shape[0] // 2 :], port))
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
data = (x, y)
expect = (x[: shape[0] // 2], x[shape[0] // 2 :])
worker(data, expect)
for shape in [(2, 3), (8, 10), (100, 77)]:
check(shape)
......@@ -393,35 +294,24 @@ def test_scatter():
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_all_to_all():
world_size = 2
server = dist.Server()
port = server.py_server_port
def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size:
return
dist.init_process_group("localhost", port, world_size, rank, rank)
inp = tensor(data)
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
inp = tensor(data[rank])
output = all_to_all(inp)
assert np.allclose(output.numpy(), expect)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
a = np.concatenate((x[: shape[0] // 2], y[: shape[0] // 2]))
b = np.concatenate((x[shape[0] // 2 :], y[shape[0] // 2 :]))
p0 = mp.Process(target=worker, args=(0, x, a, port))
p1 = mp.Process(target=worker, args=(1, y, b, port))
p0.start()
p1.start()
p0.join(10)
p1.join(10)
assert p0.exitcode == 0 and p1.exitcode == 0
data = (x, y)
expect = (a, b)
worker(data, expect)
for shape in [(2, 3), (8, 10), (100, 77)]:
check(shape)
......@@ -433,33 +323,21 @@ def test_all_to_all():
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_io_remote():
world_size = 2
server = dist.Server()
port = server.py_server_port
val = np.random.rand(4, 5).astype(np.float32)
def worker(rank):
if mge.get_device_count("gpu") < world_size:
return
@dist.launcher(n_gpus=2)
def worker():
rank = dist.get_rank()
if rank == 0: # remote send
dist.init_process_group("localhost", port, world_size, rank, rank)
x = Tensor(val, device="gpu0")
y = remote_send(x, 1)
assert y.numpy()[0] == 0
else: # remote recv
dist.init_process_group("localhost", port, world_size, rank, rank)
y = remote_recv(0, val.shape, val.dtype)
assert y.device == "gpu1"
np.testing.assert_almost_equal(val, y.numpy())
procs = []
for rank in range(world_size):
p = mp.Process(target=worker, args=(rank,))
p.start()
procs.append(p)
for p in procs:
p.join(10)
assert p.exitcode == 0
worker()
......@@ -7,7 +7,6 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import multiprocessing as mp
import platform
import numpy as np
......@@ -17,6 +16,7 @@ import megengine as mge
import megengine.distributed as dist
from megengine import Tensor
from megengine.core._trace_option import use_symbolic_shape
from megengine.distributed.helper import get_device_count_by_fork
from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
......@@ -28,6 +28,7 @@ _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_syncbn():
nr_chan = 8
......@@ -41,15 +42,14 @@ def test_syncbn():
server = dist.Server()
port = server.py_server_port
def worker(rank, data, yv_expect, running_mean, running_var):
if mge.get_device_count("gpu") < nr_ranks:
return
dist.init_process_group("localhost", port, nr_ranks, rank, rank)
@dist.launcher(n_gpus=2)
def worker(data, yv_expect, running_mean, running_var):
rank = dist.get_rank()
bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
for i in range(steps):
yv = bn(Tensor(data[i]))
yv = bn(Tensor(data[rank][i]))
_assert_allclose(yv.numpy(), yv_expect)
_assert_allclose(yv.numpy(), yv_expect[rank])
_assert_allclose(bn.running_mean.numpy(), running_mean)
_assert_allclose(bn.running_var.numpy(), running_var)
......@@ -77,24 +77,9 @@ def test_syncbn():
for j in range(steps):
data[i].append(xv[j][:, :, :, i * 8 : i * 8 + 8])
procs = []
for rank in range(nr_ranks):
p = mp.Process(
target=worker,
args=(
rank,
data[rank],
yv_expect[:, :, :, rank * 8 : rank * 8 + 8],
running_mean,
running_var,
),
)
p.start()
procs.append(p)
yv_expect = [yv_expect[:, :, :, i * 8 : i * 8 + 8] for i in range(nr_ranks)]
for p in procs:
p.join(10)
assert p.exitcode == 0
worker(data, yv_expect, running_mean, running_var)
def test_batchnorm():
......
import multiprocessing as mp
import platform
import numpy as np
......