# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import platform

import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
from megengine.device import get_device_count
from megengine.quantization import QuantMode, create_qparams
from megengine.quantization.observer import (
    ExponentialMovingAverageObserver,
    HistogramObserver,
    MinMaxObserver,
    Observer,
    PassiveObserver,
    SyncExponentialMovingAverageObserver,
    SyncMinMaxObserver,
)


def test_observer():
    """Constructing the base ``Observer`` directly is expected to raise ``TypeError``.

    Presumably this is because ``Observer`` is an abstract base class; only
    concrete subclasses should be instantiable.
    """
    with pytest.raises(TypeError):
        Observer("qint8")


def test_min_max_observer():
    """``MinMaxObserver`` records the exact min/max of a single observed tensor."""
    x = np.random.rand(3, 3, 3, 3).astype("float32")
    np_min, np_max = x.min(), x.max()
    x = mge.tensor(x)
    m = MinMaxObserver()
    m(x)
    np.testing.assert_allclose(m.min_val.numpy(), np_min)
    np.testing.assert_allclose(m.max_val.numpy(), np_max)


def test_exponential_moving_average_observer():
    """``ExponentialMovingAverageObserver`` blends min/max across observations.

    After observing ``x1`` then ``x2`` with momentum ``t``, the stored
    statistics are expected to equal ``t * stat(x1) + (1 - t) * stat(x2)``.
    """
    t = np.random.rand()
    x1 = np.random.rand(3, 3, 3, 3).astype("float32")
    x2 = np.random.rand(3, 3, 3, 3).astype("float32")
    expected_min = x1.min() * t + x2.min() * (1 - t)
    expected_max = x1.max() * t + x2.max() * (1 - t)
    m = ExponentialMovingAverageObserver(momentum=t)
    m(mge.tensor(x1, dtype=np.float32))
    m(mge.tensor(x2, dtype=np.float32))
    np.testing.assert_allclose(m.min_val.numpy(), expected_min, atol=1e-5)
    np.testing.assert_allclose(m.max_val.numpy(), expected_max, atol=1e-5)


def test_histogram_observer():
    """``HistogramObserver`` also tracks the exact min/max of the observed tensor."""
    x = np.random.rand(3, 3, 3, 3).astype("float32")
    np_min, np_max = x.min(), x.max()
    x = mge.tensor(x)
    m = HistogramObserver()
    m(x)
    np.testing.assert_allclose(m.min_val.numpy(), np_min)
    np.testing.assert_allclose(m.max_val.numpy(), np_max)


def test_passive_observer():
    """``PassiveObserver`` stores qparams set from outside and hands them back unchanged."""
    # NOTE: ``SYMMERTIC`` is the spelling this code relies on for the enum
    # member; do not "correct" it here without checking the library.
    qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", mge.tensor(1.0))
    m = PassiveObserver("qint8")
    m.set_qparams(qparams)
    assert m.orig_scale == 1.0
    assert m.scale.numpy() == 1.0
    assert m.get_qparams().dtype_meta == qparams.dtype_meta
    assert m.get_qparams().scale == qparams.scale
    assert m.get_qparams() == qparams


@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_sync_min_max_observer():
    """``SyncMinMaxObserver`` reduces min/max across all distributed workers.

    Each rank observes only its own shard of ``x``. After synchronization,
    every rank is expected to report the min/max of the full array.
    """
    world_size = get_device_count("gpu")
    shard = 3  # samples per rank along the leading axis
    x = np.random.rand(shard * world_size, 3, 3, 3).astype("float32")
    np_min, np_max = x.min(), x.max()

    # assumes dist.launcher starts one worker per GPU counted above — confirm
    @dist.launcher
    def worker():
        rank = dist.get_rank()
        m = SyncMinMaxObserver()
        y = mge.tensor(x[rank * shard : (rank + 1) * shard])
        m(y)
        np.testing.assert_allclose(m.min_val.numpy(), np_min)
        np.testing.assert_allclose(m.max_val.numpy(), np_max)

    worker()


@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_sync_exponential_moving_average_observer():
    """``SyncExponentialMovingAverageObserver`` blends globally reduced min/max.

    Each rank observes its own shard of ``x1`` and then of ``x2``. The
    synchronized statistics are expected to equal
    ``t * stat(x1) + (1 - t) * stat(x2)`` over the full arrays.
    """
    world_size = get_device_count("gpu")
    shard = 3  # samples per rank along the leading axis
    t = np.random.rand()
    x1 = np.random.rand(shard * world_size, 3, 3, 3).astype("float32")
    x2 = np.random.rand(shard * world_size, 3, 3, 3).astype("float32")
    expected_min = x1.min() * t + x2.min() * (1 - t)
    expected_max = x1.max() * t + x2.max() * (1 - t)

    # assumes dist.launcher starts one worker per GPU counted above — confirm
    @dist.launcher
    def worker():
        rank = dist.get_rank()
        m = SyncExponentialMovingAverageObserver(momentum=t)
        y1 = mge.tensor(x1[rank * shard : (rank + 1) * shard])
        y2 = mge.tensor(x2[rank * shard : (rank + 1) * shard])
        m(y1)
        m(y2)
        np.testing.assert_allclose(m.min_val.numpy(), expected_min, atol=1e-6)
        np.testing.assert_allclose(m.max_val.numpy(), expected_max, atol=1e-6)

    worker()