# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import multiprocessing as mp
from collections import defaultdict
from typing import Callable
from weakref import WeakSet

import numpy as np

from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
from megengine.device import get_default_device, get_device_count

from ..functional.param_pack import get_offsets, pack_allreduce_split
from ..functional.utils import copy
from ..utils.future import Future
from .functional import all_reduce_sum, broadcast
from .group import WORLD, group_barrier, is_distributed


class TensorFuture(Future):
    """Placeholder returned while a gradient is still being all-reduced."""

    def device(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def numpy(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def shape(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def dtype(self):
        raise RuntimeError("Sorry, this tensor is not ready")


def synchronized(func: Callable):
    """Decorator. Decorated function will synchronize when finished.
    Specifically, we use this to prevent data race during hub.load"""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_distributed():
            return func(*args, **kwargs)

        ret = func(*args, **kwargs)
        group_barrier()
        return ret

    return wrapper
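
# Illustrative usage sketch (not part of this module): ``synchronized`` can wrap
# any function whose side effects must be visible on every rank before anyone
# proceeds, e.g. a weight-download step during ``hub.load``.
#
#     @synchronized
#     def download_weights(url, path):  # hypothetical helper
#         ...  # write the file; every rank then waits at group_barrier()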


def get_device_count_by_fork(device_type: str):
    """Get the device count in a forked subprocess, so that querying devices
    does not initialize the device context in the current process."""
    q = mp.Queue()

    def worker(queue):
        num = get_device_count(device_type)
        queue.put(num)

    p = mp.Process(target=worker, args=(q,))
    p.start()
    p.join()
    return q.get()
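
# Illustrative example (assumes at least one device of the given type exists):
#
#     ngpus = get_device_count_by_fork("gpu")  # counted in a child process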


def bcast_list_(params, group):
    """Broadcast each tensor in ``params`` across ``group``, updating it in place."""
    for p in params:
        p._reset(broadcast(p, group))
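
# Illustrative example, assuming ``model`` is a megengine Module and the
# distributed group has already been initialized: make every rank start from
# the same parameter values before training.
#
#     bcast_list_(list(model.parameters()), WORLD)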


class AllreduceCallback:
    """GradManager callback that all-reduces gradients across a process group.

    Gradients are grouped into per-dtype buckets and packed into flat buffers
    before being all-reduced, so that many small tensors share one collective
    operation. ``reduce_method`` must be ``"sum"`` or ``"mean"``.
    """

    def __init__(self, reduce_method, group=WORLD):
        reduce_method = reduce_method.lower()
        assert reduce_method in ["sum", "mean"]
        self._reduce_method = reduce_method
        self._group = group
        self._marked_gm = WeakSet()
        # pack gradients until a dtype bucket exceeds 10 MiB, then reduce it
        self._param_pack_thd = 10 * 1024 * 1024
        self._reset()

    def _reset(self):
        self._params = []
        self._gradients_dict = dict()
        self._futures_dict = dict()
        self._packing_list = defaultdict(list)
        self._packing_size = defaultdict(int)
        self._grad_origin_device = dict()

    def _pack(self, dtype):
        """Flatten, all-reduce and split back the pending gradients of one dtype bucket."""
        grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
        shapes = [p.shape for p in self._packing_list[dtype]]
        reduced_grads = pack_allreduce_split(
            grad_list, shapes, self._group, self._reduce_method
        )
        for param, grad in zip(self._packing_list[dtype], reduced_grads):
            self._gradients_dict[param] = grad
        self._packing_list[dtype] = []
        self._packing_size[dtype] = 0

    def __call__(self, param, grad):
        gm = get_backwarding_grad_manager()
        assert isinstance(gm, GradManager)
        if gm not in self._marked_gm:
            # flush all pending buckets once this GradManager finishes backward
            gm._register_after_backward_callback(self._flush)
            self._marked_gm.add(gm)
        self._params.append(param)
        self._futures_dict[param] = TensorFuture(ack=False)
        self._gradients_dict[param] = grad
        self._grad_origin_device[param] = str(grad.device)

        dtype_str = str(np.dtype(param.dtype))
        dtype_size = np.dtype(param.dtype).itemsize
        self._packing_list[dtype_str].append(param)
        self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size
        if self._packing_size[dtype_str] > self._param_pack_thd:
            # this dtype bucket is large enough: reduce it eagerly
            self._pack(dtype_str)
        return self._futures_dict[param]

    def _flush(self):
        # reduce any remaining partially filled buckets, then fulfil the futures
        for dtype in sorted(self._packing_list.keys()):
            self._pack(dtype)
        for param in self._params:
            grad = self._gradients_dict[param]
            # copy the reduced gradient back to the device the original grad came from
            grad = copy(grad, self._grad_origin_device[param])
            self._futures_dict[param].set(grad)
        self._reset()


make_allreduce_cb = AllreduceCallback
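
# A minimal sketch of how this callback is typically attached (illustrative;
# ``model``, ``loss_fn``, ``data`` and ``label`` are assumed to exist and the
# distributed group to be initialized):
#
#     gm = GradManager()
#     gm.attach(model.parameters(), callbacks=[make_allreduce_cb("mean")])
#     with gm:
#         loss = loss_fn(model(data), label)
#         gm.backward(loss)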