# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import multiprocessing as mp
from collections import defaultdict
from typing import Callable

import numpy as np

from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
from megengine.device import get_default_device, get_device_count

from ..functional.param_pack import get_offsets, pack_allreduce_split
from ..functional.utils import copy
from ..utils.future import Future
from .functional import all_reduce_sum, broadcast
from .group import WORLD, group_barrier, is_distributed


class FakeTensor(Future):
    """Future placeholder for a gradient that is still being all-reduced; accessing
    it as a tensor raises an error."""

    def device(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def numpy(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def shape(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def dtype(self):
        raise RuntimeError("Sorry, this tensor is not ready")


def synchronized(func: Callable):
    """Decorator. The decorated function synchronizes all ranks with a group barrier
    when it finishes. Specifically, we use this to prevent data races during
    ``hub.load``."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_distributed():
            return func(*args, **kwargs)

        ret = func(*args, **kwargs)
        group_barrier()
        return ret

    return wrapper
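
# Usage sketch (illustrative only; ``prepare_cache`` is a hypothetical function, not
# part of this module): every rank runs the decorated body, and the trailing
# group_barrier() guarantees all ranks have finished before any caller proceeds.
#
#     @synchronized
#     def prepare_cache(path):
#         ...  # e.g. download or unpack files shared by all ranks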


def get_device_count_by_fork(device_type: str):
    """Return the number of available devices of ``device_type``, queried in a child
    process so that the current process never initializes the device itself (e.g. it
    never creates a CUDA context before forking workers)."""
    q = mp.Queue()

    def worker(queue):
        num = get_device_count(device_type)
        queue.put(num)

    p = mp.Process(target=worker, args=(q,))
    p.start()
    p.join()
    return q.get()
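
# Usage sketch (illustrative; assumes a "gpu" device type is recognized by
# get_device_count): querying the count in a forked child keeps the parent process
# free of any device context, which matters if the parent later forks workers itself.
#
#     n_gpus = get_device_count_by_fork("gpu")
#     # launch one training process per device, e.g. via multiprocessing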


def bcast_params_(params, group):
    """Broadcast each tensor in ``params`` across ``group`` and overwrite it in place
    with the broadcast result."""
    for p in params:
        p._reset(broadcast(p, group))


class AllreduceCallback:
    """GradManager callback that all-reduces parameter gradients across ``group``
    during the backward pass. Gradients of the same dtype are packed together and a
    pack is all-reduced as soon as its size exceeds ``_param_pack_thd`` bytes; any
    remainder is flushed when the backward pass finishes.

    :param reduce_method: either "sum" or "mean".
    :param group: the communication group, defaults to ``WORLD``.
    """

    def __init__(self, reduce_method, group=WORLD):
        reduce_method = reduce_method.lower()
        assert reduce_method in ["sum", "mean"]
        self._reduce_method = reduce_method
        self._group = group
        self._marked_gm = set()
        self._param_pack_thd = 10 * 1024 * 1024
        self._reset()

    def _reset(self):
        self._params = []
        self._gradients_dict = dict()
        self._futures_dict = dict()
        self._packing_list = defaultdict(list)
        self._packing_size = defaultdict(int)
        self._grad_origin_device = dict()

    def _pack(self, dtype):
        """All-reduce the gradients collected for ``dtype`` as one packed tensor, then
        split the result back to the individual parameters."""
        grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
        shapes = [p.shape for p in self._packing_list[dtype]]
        reduced_grads = pack_allreduce_split(
            grad_list, shapes, self._group, self._reduce_method
        )
        for param, grad in zip(self._packing_list[dtype], reduced_grads):
            self._gradients_dict[param] = grad
        self._packing_list[dtype] = []
        self._packing_size[dtype] = 0

    def __call__(self, param, grad):
        gm = get_backwarding_grad_manager()
        assert isinstance(gm, GradManager)
        # flush the remaining packs once this GradManager finishes its backward pass
        if gm not in self._marked_gm:
            gm.register_after_backward_callback(self._flush)
            self._marked_gm.add(gm)
        self._params.append(param)
        self._futures_dict[param] = FakeTensor(ack=False)
        self._gradients_dict[param] = grad
        self._grad_origin_device[param] = str(grad.device)

        # collect the gradient into the pack of its dtype and all-reduce the pack
        # once its total size exceeds the packing threshold
        dtype_str = str(np.dtype(param.dtype))
        dtype_size = np.dtype(param.dtype).itemsize
        self._packing_list[dtype_str].append(param)
        self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size
        if self._packing_size[dtype_str] > self._param_pack_thd:
            self._pack(dtype_str)
        return self._futures_dict[param]

    def _flush(self):
        # all-reduce any packs still below the threshold, then move each reduced
        # gradient back to its original device and fulfill the corresponding future
        for dtype in sorted(self._packing_list.keys()):
            self._pack(dtype)
        for param in self._params:
            grad = self._gradients_dict[param]
            grad = copy(grad, self._grad_origin_device[param])
            self._futures_dict[param].set(grad)
        self._reset()


make_allreduce_cb = AllreduceCallback
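
# Usage sketch (illustrative; ``model`` and ``data`` are assumed to come from the
# surrounding training code): the callback is attached to a GradManager so that each
# parameter's gradient is all-reduced across the group during the backward pass.
#
#     gm = GradManager()
#     gm.attach(model.parameters(), callbacks=[make_allreduce_cb("mean", WORLD)])
#     with gm:
#         loss = model(data)  # forward pass producing a scalar loss (hypothetical)
#         gm.backward(loss)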