# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import multiprocessing as mp
from collections import defaultdict
from typing import Callable
from weakref import WeakSet

import numpy as np

from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
from megengine.device import get_default_device, get_device_count

from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..core.tensor.core import apply
from ..functional.utils import copy
from ..tensor import Tensor
from ..utils.future import Future
from .functional import all_reduce_sum, broadcast
from .group import WORLD, Group, group_barrier, is_distributed


def param_pack_split(inp: Tensor, offsets: list, shapes: list):
    r"""
    Splits the input tensor into a list of tensors according to the given
    offsets and shapes; only used for ``parampack``.

    :param inp: input tensor.
    :param offsets: offsets of outputs, of length ``2 * n``,
            where ``n`` is the number of output tensors,
            in the format ``[begin0, end0, begin1, end1, ...]``.
    :param shapes: tensor shapes of outputs.
    :return: split tensors.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        from megengine.distributed.helper import param_pack_split

        a = tensor(np.ones((10,), np.int32))
        b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
        print(b.numpy())
        print(c.numpy())

    Outputs:

    .. testoutput::

        [1]
        [[1 1 1]
         [1 1 1]
         [1 1 1]]

    """
    op = ParamPackSplit()
    op.offsets = offsets
    op.shapes = [s or (1,) for s in shapes]
    outputs = apply(op, inp)
    for s, x in zip(shapes, outputs):
        if not s:
            x._isscalar = True
    return outputs


def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
    r"""
    Returns a tensor concatenated from the input tensors; only used for ``parampack``.

    :param inps: input tensors.
    :param offsets: offsets stored as a tensor on the device, holding the same values as ``offsets_val``.
    :param offsets_val: offsets of inputs, of length ``2 * n``,
            in the format ``[begin0, end0, begin1, end1, ...]``.
    :return: concatenated tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        from megengine.distributed.helper import param_pack_concat

        a = tensor(np.ones((1,), np.int32))
        b = tensor(np.ones((3, 3), np.int32))
        offsets_val = [0, 1, 1, 10]
        offsets = tensor(offsets_val, np.int32)
        c = param_pack_concat([a, b], offsets, offsets_val)
        print(c.numpy())

    Outputs:

    .. testoutput::

        [1 1 1 1 1 1 1 1 1 1]

    """
    op = ParamPackConcat()
    op.offsets = offsets_val
    return apply(op, *inps, offsets)[0]


def get_offsets(shapes):
    offsets = []
    offset = 0
    for shape in shapes:
        offsets.append(offset)
        offset += int(np.prod(shape))
        offsets.append(offset)
    return offsets
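
# Worked example (illustrative only, not part of the public API):
#
#     >>> get_offsets([(1,), (3, 3)])
#     [0, 1, 1, 10]
#
# i.e. the [begin, end) element ranges of each tensor inside the flattened pack,
# matching the offsets used in the param_pack_split/param_pack_concat examples above.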


def pack_allreduce_split(pack_list, shapes, group, reduce_method):
    offsets_val = get_offsets(shapes)
    offsets = Tensor(offsets_val)
    packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
    packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
    if reduce_method == "mean":
        packed_grads /= group.size
    grads = param_pack_split(packed_grads, offsets_val, shapes)
    return grads
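
# Usage sketch (assumes an initialized distributed group; names are illustrative):
# fuse a list of same-dtype gradient tensors into one buffer, all-reduce it once,
# then split it back into per-parameter gradients. This is what
# AllreduceCallback._pack does below.
#
#     # grads: list of Tensors, shapes: [g.shape for g in grads]
#     # reduced = pack_allreduce_split(grads, shapes, WORLD, "mean")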


class TensorFuture(Future):
    def device(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def numpy(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def shape(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def dtype(self):
        raise RuntimeError("Sorry, this tensor is not ready")


def synchronized(func: Callable):
    """
    Decorator. Decorated function will synchronize when finished.
149 150 151 152 153 154 155 156 157 158 159 160
    Specifically, we use this to prevent data race during hub.load"""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_distributed():
            return func(*args, **kwargs)

        ret = func(*args, **kwargs)
        group_barrier()
        return ret

    return wrapper
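
# Usage sketch (``fetch_weights`` is a hypothetical function, for illustration only):
#
#     @synchronized
#     def fetch_weights(url):
#         ...  # e.g. download to a shared cache directory
#
# Every process runs fetch_weights and then waits at group_barrier(), so no process
# can read the cache before the writer has finished; in non-distributed mode the
# wrapper is a plain pass-through.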


def _get_device_count_worker(queue, device_type):
    num = get_device_count(device_type)
    queue.put(num)


def get_device_count_by_fork(device_type: str):
    """
    Get device count in fork thread.
171 172 173
    See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork
    for more information.
    """
    q = mp.Queue()
    p = mp.Process(target=_get_device_count_worker, args=(q, device_type))
    p.start()
    p.join()
    return q.get()
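
# Usage sketch (assumes a "gpu" device type is recognized by get_device_count):
#
#     # n_gpus = get_device_count_by_fork("gpu")
#
# The query runs in a child process, so the parent never initializes CUDA and can
# still safely fork workers afterwards.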


def bcast_list_(inps: list, group: Group = WORLD):
    """
    Broadcast tensors between given group.
184 185 186 187 188 189

    :param inps: input tensors.
    :param group: communication group.
    """
    for inp in inps:
        inp._reset(broadcast(inp, group))
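
# Usage sketch (``model`` is hypothetical; assumes the distributed group is initialized):
#
#     # bcast_list_(list(model.parameters()), WORLD)
#
# After the call every rank holds the values broadcast from the group's root rank,
# because each tensor is reset in place to the broadcast result.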


class AllreduceCallback:
    """
    Allreduce Callback with tensor fusion optimization.
195 196 197 198 199 200

    :param reduce_method: the method to reduce gradiants.
    :param group: communication group.
    """

    def __init__(self, reduce_method: str, group: Group = WORLD):
        reduce_method = reduce_method.lower()
        assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
        self._reduce_method = reduce_method
        self._group = group
        self._marked_gm = WeakSet()
        self._param_pack_thd = 10 * 1024 * 1024
        self._reset()

    def _reset(self):
        self._params = []
        self._gradients_dict = dict()
        self._futures_dict = dict()
        self._packing_list = defaultdict(list)
        self._packing_size = defaultdict(int)
        self._grad_origin_device = dict()

    def _pack(self, dtype):
        if len(self._packing_list[dtype]) == 0:
            return
        grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
        shapes = [p.shape for p in self._packing_list[dtype]]
        reduced_grads = pack_allreduce_split(
            grad_list, shapes, self._group, self._reduce_method
        )
        for param, grad in zip(self._packing_list[dtype], reduced_grads):
            self._gradients_dict[param] = grad
        self._packing_list[dtype] = []
        self._packing_size[dtype] = 0

    def __call__(self, param, grad):
        param = param.__wrapped__
        gm = get_backwarding_grad_manager()
        assert isinstance(gm, GradManager)
        if gm not in self._marked_gm:
            gm._register_after_backward_callback(self._flush)
            self._marked_gm.add(gm)
        self._params.append(param)
        self._futures_dict[param] = TensorFuture(ack=False)
        self._gradients_dict[param] = grad
        self._grad_origin_device[param] = str(grad.device)

        dtype_str = str(np.dtype(param.dtype))
        dtype_size = np.dtype(param.dtype).itemsize
        self._packing_list[dtype_str].append(param)
        self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size
        if self._packing_size[dtype_str] > self._param_pack_thd:
            self._pack(dtype_str)
        return self._futures_dict[param]

    def _flush(self):
        for dtype in sorted(self._packing_list.keys()):
            self._pack(dtype)
        for param in self._params:
            grad = self._gradients_dict[param]
            grad = copy(grad, self._grad_origin_device[param])
            self._futures_dict[param].set(grad)
        self._reset()


make_allreduce_cb = AllreduceCallback
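

# Usage sketch (assumes GradManager.attach accepts per-parameter callbacks, as in
# typical data-parallel training; names are illustrative):
#
#     # gm = GradManager()
#     # gm.attach(model.parameters(), callbacks=[make_allreduce_cb("mean", WORLD)])
#
# During backward each gradient is handed to the callback, bucketed by dtype, and
# fused + all-reduced once a bucket exceeds the 10 MiB threshold; the remaining
# buckets are flushed by _flush when backward finishes.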