# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import multiprocessing as mp
from collections import defaultdict
from typing import Callable
from weakref import WeakSet

import numpy as np

from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
from megengine.device import get_default_device, get_device_count

from ..functional.param_pack import get_offsets, pack_allreduce_split
from ..functional.utils import copy
from ..utils.future import Future
from .functional import all_reduce_sum, broadcast
from .group import WORLD, Group, group_barrier, is_distributed


class TensorFuture(Future):
    """Placeholder for a gradient tensor whose allreduce has not finished yet."""

    def device(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def numpy(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def shape(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def dtype(self):
        raise RuntimeError("Sorry, this tensor is not ready")


def synchronized(func: Callable):
    """Decorator. Decorated function will synchronize when finished.
    Specifically, we use this to prevent data race during hub.load"""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_distributed():
            return func(*args, **kwargs)

        ret = func(*args, **kwargs)
        group_barrier()
        return ret

    return wrapper
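
# Usage sketch (illustrative only; `fetch_weights` is a hypothetical function name):
#
#     @synchronized
#     def fetch_weights(url):
#         ...  # every process runs the body, then all of them meet at group_barrier()
#              # before any of them returns, so none races ahead of the others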


def _get_device_count_worker(queue, device_type):
    num = get_device_count(device_type)
    queue.put(num)


def get_device_count_by_fork(device_type: str):
    """Get device count in fork thread.
    See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork
    for more information.
    """
    q = mp.Queue()
    p = mp.Process(target=_get_device_count_worker, args=(q, device_type))
    p.start()
    p.join()
    return q.get()
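
# Usage sketch (illustrative; assumes "gpu" is a valid device_type string on this build):
#
#     n_gpus = get_device_count_by_fork("gpu")   # probes CUDA in a child process,
#     assert n_gpus >= 1                         # so the parent stays fork-safe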


def bcast_list_(inps: list, group: Group = WORLD):
    """Broadcast tensors between given group.

    :param inps: input tensors.
    :param group: communication group.
    """
    for inp in inps:
        inp._reset(broadcast(inp, group))
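
# Usage sketch (illustrative; `model` stands for any megengine Module):
#
#     # make every rank start training from identical parameter values
#     bcast_list_(list(model.parameters()))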


class AllreduceCallback:
    """Allreduce Callback with tensor fusion optimization.

    :param reduce_method: the method to reduce gradiants.
    :param group: communication group.
    """

    def __init__(self, reduce_method: str, group: Group = WORLD):
        reduce_method = reduce_method.lower()
        assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
        self._reduce_method = reduce_method
        self._group = group
        self._marked_gm = WeakSet()
        # fuse gradients per dtype and flush a pack once it grows past 10 MiB
        self._param_pack_thd = 10 * 1024 * 1024
        self._reset()

    def _reset(self):
        self._params = []
        self._gradients_dict = dict()
        self._futures_dict = dict()
        self._packing_list = defaultdict(list)
        self._packing_size = defaultdict(int)
        self._grad_origin_device = dict()

    def _pack(self, dtype):
        # concatenate all pending gradients of this dtype, allreduce them in one
        # collective call, then split the reduced result back per parameter
        grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
        shapes = [p.shape for p in self._packing_list[dtype]]
        reduced_grads = pack_allreduce_split(
            grad_list, shapes, self._group, self._reduce_method
        )
        for param, grad in zip(self._packing_list[dtype], reduced_grads):
            self._gradients_dict[param] = grad
        self._packing_list[dtype] = []
        self._packing_size[dtype] = 0

    def __call__(self, param, grad):
        # called during backward with each parameter and its freshly computed gradient
        gm = get_backwarding_grad_manager()
        assert isinstance(gm, GradManager)
        # register the flush hook on each GradManager only once
        if gm not in self._marked_gm:
            gm._register_after_backward_callback(self._flush)
            self._marked_gm.add(gm)
        # remember the parameter, its gradient, the device the gradient lives on,
        # and a future the caller can wait on until the allreduce completes
        self._params.append(param)
        self._futures_dict[param] = TensorFuture(ack=False)
        self._gradients_dict[param] = grad
        self._grad_origin_device[param] = str(grad.device)

        # group gradients by dtype; once a dtype's pending bytes exceed the packing
        # threshold, fuse and allreduce them right away
        dtype_str = str(np.dtype(param.dtype))
        dtype_size = np.dtype(param.dtype).itemsize
        self._packing_list[dtype_str].append(param)
        self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size
        if self._packing_size[dtype_str] > self._param_pack_thd:
            self._pack(dtype_str)
        return self._futures_dict[param]

    def _flush(self):
        # allreduce anything still pending, then hand each reduced gradient back on
        # its original device by fulfilling the corresponding future
        for dtype in sorted(self._packing_list.keys()):
            self._pack(dtype)
        for param in self._params:
            grad = self._gradients_dict[param]
            grad = copy(grad, self._grad_origin_device[param])
            self._futures_dict[param].set(grad)
        self._reset()


make_allreduce_cb = AllreduceCallback
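
# Usage sketch (illustrative; `model`, `data`, `label` and `loss_fn` are hypothetical):
#
#     gm = GradManager()
#     gm.attach(model.parameters(), callbacks=[make_allreduce_cb("mean")])
#     with gm:
#         loss = loss_fn(model(data), label)
#         gm.backward(loss)  # gradients are fused per dtype, allreduced, then unpacked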