# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import platform
import queue

import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit
from megengine.device import get_default_device
from megengine.distributed.helper import param_pack_concat, param_pack_split


def _assert_q_empty(q):
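    """Assert that nothing arrives on ``q`` within one second."""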
    try:
        res = q.get(timeout=1)
    except Exception as e:
        assert isinstance(e, queue.Empty)
    else:
        assert False, "queue is not empty"


def _assert_q_val(q, val):
    ret = q.get()
    assert ret == val


@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("backend", ["nccl"])
@pytest.mark.isolated_distributed
def test_init_process_group(backend):
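    """Spawn ``world_size`` processes and check that ``init_process_group``
    exposes the expected rank, world size, backend and server addresses."""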
    world_size = 2
    server = dist.Server()
    port = server.py_server_port

    def worker(rank):
        dist.init_process_group("localhost", port, world_size, rank, rank, backend)
        assert dist.is_distributed()
        assert dist.get_rank() == rank
        assert dist.get_world_size() == world_size
        assert dist.get_backend() == backend

        py_server_addr = dist.get_py_server_addr()
        assert py_server_addr[0] == "localhost"
        assert py_server_addr[1] == port

        mm_server_addr = dist.get_mm_server_addr()
        assert mm_server_addr[0] == "localhost"
        assert mm_server_addr[1] > 0

        assert isinstance(dist.get_client(), dist.Client)

    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(rank,))
        p.start()
        procs.append(p)

    for p in procs:
        p.join(20)
        assert p.exitcode == 0


@pytest.mark.require_ngpu(3)
@pytest.mark.isolated_distributed
def test_new_group():
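    """Create a sub-group from ranks [2, 0] and check its size, key, per-rank
    index and computing node."""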
    world_size = 3
    ranks = [2, 0]

    @dist.launcher
    def worker():
        rank = dist.get_rank()
        if rank in ranks:
            group = dist.new_group(ranks)
            assert group.size == 2
            assert group.key == "2,0"
            assert group.rank == ranks.index(rank)
            dt = get_default_device()[:-1]
            assert group.comp_node == "{}{}:2".format(dt, rank)

    worker()


@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_group_barrier():
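    """Use a queue to check that statements after a ``group_barrier`` on rank 0
    are not observed by rank 1 before rank 1 reaches the same barrier."""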
    world_size = 2
    server = dist.Server()
    port = server.py_server_port

    def worker(rank, q):
        dist.init_process_group("localhost", port, world_size, rank, rank)
        dist.group_barrier()
        if rank == 0:
            dist.group_barrier()
            q.put(0)  # to be observed in rank 1
        else:
            _assert_q_empty(q)  # q.put(0) is not executed in rank 0
            dist.group_barrier()
            _assert_q_val(q, 0)  # q.put(0) executed in rank 0

    Q = mp.Queue()
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(rank, Q))
        p.start()
        procs.append(p)

    for p in procs:
        p.join(20)
        assert p.exitcode == 0


@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_synchronized():
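    """Verify the ordering guaranteed by ``dist.synchronized``: the statement
    after the decorated call on rank 0 does not run until rank 1 has executed
    its own call."""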
    world_size = 2
    server = dist.Server()
    port = server.py_server_port

    @dist.synchronized
    def func(rank, q):
        q.put(rank)

    def worker(rank, q):
        dist.init_process_group("localhost", port, world_size, rank, rank)
        dist.group_barrier()
        if rank == 0:
            func(0, q)  # q.put(0)
            q.put(2)
        else:
            _assert_q_val(q, 0)  # func executed in rank 0
            _assert_q_empty(q)  # q.put(2) is not executed
            func(1, q)
            _assert_q_val(
                q, 1
            )  # func in rank 1 executed earlier than q.put(2) in rank 0
            _assert_q_val(q, 2)  # q.put(2) executed in rank 0

    Q = mp.Queue()
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(rank, Q))
        p.start()
        procs.append(p)

    for p in procs:
        p.join(20)
        assert p.exitcode == 0


@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_user_set_get():
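    """Every rank sets the same key and reads it back through the distributed
    key-value client."""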
    @dist.launcher
    def worker():
        # set in race condition
        dist.get_client().user_set("foo", 1)
        # get in race condition
        ret = dist.get_client().user_get("foo")
        assert ret == 1

    worker()


def test_oprmm_hashable():
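    """Builtin collective/param-pack ops compare and hash by value."""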
    lhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit())
    rhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit())
    assert lhs == rhs
    assert hash(lhs) == hash(rhs)


def test_param_pack_split():
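    # The offsets list appears to hold consecutive (begin, end) pairs into the
    # flat tensor: b takes a[0:1] and c takes a[1:10], reshaped to the shapes
    # given in the third argument.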
    a = mge.Tensor(np.ones((10,), np.int32))
    b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
    assert np.allclose(b.numpy(), a.numpy()[1])
    assert np.allclose(c.numpy(), a.numpy()[1:].reshape(3, 3))


def test_param_pack_concat():
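    # param_pack_concat flattens the inputs and packs them into one contiguous
    # buffer, so the result should match a plain numpy concatenation.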
    a = mge.Tensor(np.ones((1,), np.int32))
    b = mge.Tensor(np.ones((3, 3), np.int32))
    offsets_val = [0, 1, 1, 10]
    offsets = mge.Tensor(offsets_val, np.int32)
    c = param_pack_concat([a, b], offsets, offsets_val)
    assert np.allclose(np.concatenate([a.numpy(), b.numpy().flatten()]), c.numpy())


@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("early_return", [False, True], ids=["common", "early_return"])
@pytest.mark.parametrize("output_size", [10, 10000], ids=["small_size", "large_size"])
@pytest.mark.isolated_distributed
def test_collect_results(early_return, output_size):
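    """``dist.launcher`` gathers each rank's return value; a rank that exits
    early contributes ``None``."""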
    @dist.launcher
    def worker():
        if early_return:
            exit(0)
        return [dist.get_rank()] * output_size

    results = worker()
    world_size = len(results)
    assert world_size > 0
    expects = (
        [None] * world_size
        if early_return
        else [[dev] * output_size for dev in range(world_size)]
    )
    assert results == expects


@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_user_set_pop():
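    """Rank 1 pops a key that every rank has set and sees the stored value."""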
    @dist.launcher
    def worker():
        # set in race condition
        dist.get_client().user_set("foo", 1)
        if dist.get_rank() == 1:
            ret = dist.get_client().user_pop("foo")
            assert ret == 1

    worker()