# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import multiprocessing as mp
from collections import defaultdict
from typing import Callable
from weakref import WeakSet

import numpy as np

from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
from megengine.device import get_default_device, get_device_count

from ..core._imperative_rt.core2 import apply
from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..functional.tensor import copy
from ..tensor import Tensor
from ..utils.future import Future
from . import group as _group
from .functional import _bcast_param, all_reduce_sum, broadcast
from .group import WORLD, Group, group_barrier, is_distributed, override_backend


def param_pack_split(inp: Tensor, offsets: list, shapes: list):
    r"""
    Splits the input tensor into a list of tensors according to the given offsets
    and shapes; only used for ``parampack``.

    :param inp: input tensor.
    :param offsets: offsets of the outputs, of length `2 * n`,
            where `n` is the number of tensors to split into,
            in the format `[begin0, end0, begin1, end1, ...]`.
    :param shapes: shapes of the output tensors.
    :return: split tensors.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        from megengine.distributed.helper import param_pack_split

        a = tensor(np.ones((10,), np.int32))
        b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
        print(b.numpy())
        print(c.numpy())

    Outputs:

    .. testoutput::

        [1]
        [[1 1 1]
         [1 1 1]
         [1 1 1]]

    """
    op = ParamPackSplit()
    op.offsets = offsets
    op.shapes = [s or (1,) for s in shapes]
    outputs = apply(op, inp)
    for s, x in zip(shapes, outputs):
        if not s:
            x._setscalar()
    return outputs


def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
    r"""
    Returns the concatenated tensor; only used for ``parampack``.

    :param inps: input tensors.
    :param offsets: offsets as a tensor on the device.
    :param offsets_val: offsets of the inputs, of length `2 * n`,
            in the format `[begin0, end0, begin1, end1, ...]`.
    :return: concatenated tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        from megengine.distributed.helper import param_pack_concat

        a = tensor(np.ones((1,), np.int32))
        b = tensor(np.ones((3, 3), np.int32))
        offsets_val = [0, 1, 1, 10]
        offsets = tensor(offsets_val, np.int32)
        c = param_pack_concat([a, b], offsets, offsets_val)
        print(c.numpy())

    Outputs:

    .. testoutput::

        [1 1 1 1 1 1 1 1 1 1]

    """
    op = ParamPackConcat()
    op.offsets = offsets_val
    return apply(op, *inps, offsets)[0]


def get_offsets(shapes):
    """Computes flat ``[begin0, end0, begin1, end1, ...]`` element offsets for the given shapes."""
    offsets = []
    offset = 0
    for shape in shapes:
        offsets.append(offset)
        offset += int(np.prod(shape))
        offsets.append(offset)
    return offsets
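
# Example: for the shapes used in the docstring examples above, ``get_offsets``
# yields the flat begin/end element offsets consumed by ``param_pack_concat`` and
# ``param_pack_split``:
#
#     get_offsets([(1,), (3, 3)])  # -> [0, 1, 1, 10]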


_enable_p2p_cache = None


def _check_enable_p2p():
    """Checks (and caches) whether GPU peer-to-peer access is reported by ``nvidia-smi``."""
    global _enable_p2p_cache
    if _enable_p2p_cache is not None:
        return _enable_p2p_cache
    cmd = ["nvidia-smi", "topo", "-p2p", "w"]
    import subprocess

    output = subprocess.run(cmd, stdout=subprocess.PIPE).stdout
    if output.count(b"OK") > 1:
        _enable_p2p_cache = True
        return True
    else:
        _enable_p2p_cache = False
        return False


def pack_allreduce_split(pack_list, shapes, group, reduce_method):
    offsets_val = get_offsets(shapes)
    offsets = Tensor(offsets_val)
    packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
    packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
    if reduce_method == "mean":
        packed_grads /= group.size
    grads = param_pack_split(packed_grads, offsets_val, shapes)
    return grads
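
# Illustrative usage sketch (the gradient names ``g0``/``g1`` are hypothetical): fuse two
# gradients into one flat tensor, all-reduce it in a single collective, then split it
# back into per-parameter gradients:
#
#     shapes = [g0._tuple_shape, g1._tuple_shape]
#     g0_sum, g1_sum = pack_allreduce_split([g0, g1], shapes, WORLD, "sum")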


class TensorFuture(Future):
    def device(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def numpy(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def shape(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def dtype(self):
        raise RuntimeError("Sorry, this tensor is not ready")


def synchronized(func: Callable):
    """
    Decorator. The decorated function synchronizes all processes (via a group barrier)
    when it finishes. Specifically, this is used to prevent data races during ``hub.load``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_distributed():
            return func(*args, **kwargs)

        ret = func(*args, **kwargs)
        group_barrier()
        return ret

    return wrapper
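
# Usage sketch (``fetch_checkpoint`` is a hypothetical function): every rank runs the
# body, then all ranks meet at a group barrier before returning, so no rank races ahead
# of a slower one (e.g. during ``hub.load``).
#
#     @synchronized
#     def fetch_checkpoint(url):
#         ...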


def _check_device_initialized(device_type: str, rank: int):
    try:
        test = Tensor(1, device=(device_type + str(rank)))
        inited = False
        del test
    except:
        inited = True
    errmsg = "The CUDA environment was set up before the worker process was forked. Please do not use any CUDA function or variable before forking."
    if inited:
        raise RuntimeError(errmsg)


def bcast_list_(inps: list, group: Group = WORLD):
    """
    Broadcasts tensors within the given group.

    :param inps: input tensors.
    :param group: communication group.
    """
    for inp in inps:
        inp._reset(_bcast_param(inp, group))
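
# Usage sketch (``model`` is a hypothetical Module): broadcast every parameter within the
# WORLD group so all processes start from identical values, e.g. right after building
# the model.
#
#     bcast_list_(list(model.parameters()))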


class AllreduceCallback:
    """
    Allreduce callback with tensor fusion optimization.

    :param reduce_method: the method used to reduce gradients, either "sum" or "mean".
    :param group: communication group.
    :param backend: override the distributed backend used in the allreduce.
    """

    def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None):
        reduce_method = reduce_method.lower()
        assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
        self._reduce_method = reduce_method
        self._group = group
        self._marked_gm = WeakSet()
        self._param_pack_thd = 10 * 1024 * 1024
        self._reset()
        if backend is None:
            assert _group._sd, "please call init_process_group first"
            backend = _group._sd.backend
        if backend == "auto":
            if group.is_single_machine and not _check_enable_p2p():
                backend = "shm"
            else:
                backend = "nccl"
        self._backend = backend

    def _reset(self):
        self._params = []
        self._gradients_dict = dict()
        self._futures_dict = dict()
        self._packing_list = defaultdict(list)
        self._packing_size = defaultdict(int)
        self._grad_origin_device = dict()

    def _pack(self, dtype):
        if len(self._packing_list[dtype]) == 0:
            return
        grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
        shapes = [p._tuple_shape for p in self._packing_list[dtype]]
        with override_backend(self._backend):
            reduced_grads = pack_allreduce_split(
                grad_list, shapes, self._group, self._reduce_method
            )
        for param, grad in zip(self._packing_list[dtype], reduced_grads):
            self._gradients_dict[param] = grad
        self._packing_list[dtype] = []
        self._packing_size[dtype] = 0

    def __call__(self, param, grad):
        gm = get_backwarding_grad_manager()
        assert isinstance(gm, GradManager)
        if gm not in self._marked_gm:
            gm._register_after_backward_callback(self._flush)
            self._marked_gm.add(gm)
        self._params.append(param)
        self._futures_dict[param] = TensorFuture(ack=False)
        self._gradients_dict[param] = grad
        self._grad_origin_device[param] = str(grad.device)

        dtype_str = str(np.dtype(param.dtype))
        dtype_size = np.dtype(param.dtype).itemsize
        self._packing_list[dtype_str].append(param)
        self._packing_size[dtype_str] += int(np.prod(param._tuple_shape)) * dtype_size
        if self._packing_size[dtype_str] > self._param_pack_thd:
            self._pack(dtype_str)
        return self._futures_dict[param]

    def _flush(self):
        for dtype in sorted(self._packing_list.keys()):
            self._pack(dtype)
        for param in self._params:
            grad = self._gradients_dict[param]
            grad = copy(grad, self._grad_origin_device[param])
            self._futures_dict[param].set(grad)
        self._reset()


make_allreduce_cb = AllreduceCallback
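
# Usage sketch (``gm`` and ``net`` are illustrative names): registering the callback on a
# GradManager all-reduces gradients across the WORLD group, with tensor fusion, as they
# are produced during backward.
#
#     gm = GradManager()
#     gm.attach(net.parameters(), callbacks=[make_allreduce_cb("mean")])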