# test_observer.py: tests for megengine.quantization.observer.
import platform

import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
from megengine.distributed.helper import get_device_count_by_fork
from megengine.quantization import QuantMode, create_qparams
from megengine.quantization.observer import (
    ExponentialMovingAverageObserver,
    HistogramObserver,
    MinMaxObserver,
    Observer,
    PassiveObserver,
    SyncExponentialMovingAverageObserver,
    SyncMinMaxObserver,
)


def test_observer():
    """The ``Observer`` base class cannot be constructed directly.

    Instantiating it with a dtype string must raise ``TypeError``.
    """
    pytest.raises(TypeError, Observer, "qint8")


def test_min_max_observer():
    """MinMaxObserver's min_val/max_val equal the min/max of the observed tensor."""
    x = np.random.rand(3, 3, 3, 3).astype("float32")
    # Reference extrema are computed on the NumPy array before wrapping it.
    np_min, np_max = x.min(), x.max()
    x = mge.tensor(x)
    m = MinMaxObserver()
    m(x)
    np.testing.assert_allclose(m.min_val.numpy(), np_min)
    np.testing.assert_allclose(m.max_val.numpy(), np_max)


def test_exponential_moving_average_observer():
    """ExponentialMovingAverageObserver blends successive extrema by momentum.

    After observing ``x1`` and then ``x2`` with momentum ``t``, the observer's
    min_val/max_val are expected to equal ``x1_ext * t + x2_ext * (1 - t)``.
    """
    t = np.random.rand()
    x1 = np.random.rand(3, 3, 3, 3).astype("float32")
    x2 = np.random.rand(3, 3, 3, 3).astype("float32")
    expected_min = x1.min() * t + x2.min() * (1 - t)
    expected_max = x1.max() * t + x2.max() * (1 - t)
    m = ExponentialMovingAverageObserver(momentum=t)
    m(mge.tensor(x1, dtype=np.float32))
    m(mge.tensor(x2, dtype=np.float32))
    # Loose absolute tolerance: the observer accumulates in float32.
    np.testing.assert_allclose(m.min_val.numpy(), expected_min, atol=1e-5)
    np.testing.assert_allclose(m.max_val.numpy(), expected_max, atol=1e-5)


def test_histogram_observer():
    """HistogramObserver tracks the exact min/max of a single observed batch."""
    data = np.random.rand(3, 3, 3, 3).astype("float32")
    lo, hi = data.min(), data.max()
    observer = HistogramObserver()
    observer(mge.tensor(data))
    for observed, reference in ((observer.min_val, lo), (observer.max_val, hi)):
        np.testing.assert_allclose(observed.numpy(), reference)


def test_passive_observer():
    """PassiveObserver adopts qparams supplied via ``set_qparams``.

    The observer is given symmetric qint8 qparams with scale 1.0. Its
    ``orig_scale``, its ``scale`` and the qparams returned by ``get_qparams``
    must all reflect exactly that input.
    """
    # NOTE(review): `SYMMERTIC` presumably mirrors the enum member's spelling in
    # megengine.quantization; do not rename it here in isolation.
    qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", mge.tensor(1.0))
    m = PassiveObserver("qint8")
    m.set_qparams(qparams)
    assert m.orig_scale == 1.0
    assert m.scale.numpy() == 1.0
    # The returned qparams must match both field by field and as a whole.
    assert m.get_qparams().dtype_meta == qparams.dtype_meta
    assert m.get_qparams().scale == qparams.scale
    assert m.get_qparams() == qparams


@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_sync_min_max_observer():
    """SyncMinMaxObserver reports the global min/max across all ranks.

    The input holds ``3 * world_size`` rows. Each rank observes only its own
    3-row shard. Afterwards every rank's min_val/max_val must equal the min/max
    of the full, unsharded array.
    """
    # Assumes dist.launcher starts one rank per GPU counted here; TODO confirm.
    world_size = get_device_count_by_fork("gpu")
    x = np.random.rand(3 * world_size, 3, 3, 3).astype("float32")
    np_min, np_max = x.min(), x.max()

    @dist.launcher
    def worker():
        rank = dist.get_rank()
        m = SyncMinMaxObserver()
        # Each rank sees only its own slice of the first axis.
        y = mge.tensor(x[rank * 3 : (rank + 1) * 3])
        m(y)
        np.testing.assert_allclose(m.min_val.numpy(), np_min)
        np.testing.assert_allclose(m.max_val.numpy(), np_max)

    worker()


@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_sync_exponential_moving_average_observer():
    """SyncExponentialMovingAverageObserver blends global extrema by momentum.

    Each rank observes its own 3-row shard of ``x1`` and then of ``x2``. The
    synchronized min_val/max_val must equal
    ``global_ext(x1) * t + global_ext(x2) * (1 - t)`` on every rank.
    """
    # Assumes dist.launcher starts one rank per GPU counted here; TODO confirm.
    world_size = get_device_count_by_fork("gpu")
    t = np.random.rand()
    x1 = np.random.rand(3 * world_size, 3, 3, 3).astype("float32")
    x2 = np.random.rand(3 * world_size, 3, 3, 3).astype("float32")
    expected_min = x1.min() * t + x2.min() * (1 - t)
    expected_max = x1.max() * t + x2.max() * (1 - t)

    @dist.launcher
    def worker():
        rank = dist.get_rank()
        m = SyncExponentialMovingAverageObserver(momentum=t)
        # Each rank sees only its own slice of the first axis.
        y1 = mge.tensor(x1[rank * 3 : (rank + 1) * 3])
        y2 = mge.tensor(x2[rank * 3 : (rank + 1) * 3])
        m(y1)
        m(y2)
        np.testing.assert_allclose(m.min_val.numpy(), expected_min, atol=1e-6)
        np.testing.assert_allclose(m.max_val.numpy(), expected_max, atol=1e-6)

    worker()