import platform

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.jit as jit
import megengine.tensor as tensor
from megengine import is_cuda_available
from megengine.autodiff.grad_manager import GradManager


@pytest.mark.skipif(
    tuple(int(part) for part in platform.python_version_tuple()[:2]) < (3, 8),
    reason="need py38",
)
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_elemwise():
    """Compare elemwise ops run through ``jit.xla_trace`` across two calls.

    Each op is executed twice through the same ``jit.xla_trace`` function; the
    first call's results are treated as the MegEngine reference and the second
    call's as the XLA results. Forward outputs (and, when ``backward`` is set,
    the gradients w.r.t. every input) are compared element-wise.
    """
    np.random.seed(123)
    mge.random.seed(123)

    def tester(felemwise, *inp_shapes, backward=True, dtype=None, atol=1e-5, **kwargs):
        """Run ``felemwise`` twice through an XLA-traced function and compare.

        Args:
            felemwise: elemwise function under test (e.g. ``F.add``).
            *inp_shapes: one shape per positional input tensor.
            backward: when True, also compare gradients w.r.t. every input.
            dtype: input dtype; defaults to ``np.float32``.
            atol: absolute tolerance passed to ``np.testing.assert_allclose``.
            **kwargs: extra keyword arguments forwarded to ``felemwise``.
        """
        dtype = dtype or np.float32
        if dtype in (np.int16, np.int32, np.uint16, np.uint32):
            # Integer inputs are drawn from [0, 10).
            inps = [
                tensor(np.random.randint(0, 10, size=inp_shape), dtype=dtype)
                for inp_shape in inp_shapes
            ]
        else:
            inps = [
                tensor(0.1 * np.random.randn(*inp_shape), dtype=dtype)
                for inp_shape in inp_shapes
            ]
        # Upstream gradient shaped like the op's output; computed eagerly,
        # outside the traced function.
        doup = tensor(
            0.1 * np.random.randn(*felemwise(*inps, **kwargs).shape), dtype=dtype
        )

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(inps, doup):
            if backward:
                gm.attach(inps)
                with gm:
                    oup = felemwise(*inps, **kwargs)
                    gm.backward(oup, doup)
                    return [oup, *[inp.grad for inp in inps]]
            else:
                oup = felemwise(*inps, **kwargs)
                return [oup]

        mge_rsts = list(func(inps, doup))
        xla_rsts = list(func(inps, doup))
        # zip() would silently drop unmatched results, so check counts first.
        assert len(mge_rsts) == len(xla_rsts), (len(mge_rsts), len(xla_rsts))
        op_name = getattr(felemwise, "__name__", repr(felemwise))
        for idx, (mge_rst, xla_rst) in enumerate(zip(mge_rsts, xla_rsts)):
            np.testing.assert_allclose(
                mge_rst.numpy(),
                xla_rst.numpy(),
                atol=atol,
                err_msg=f"{op_name}: result #{idx} differs between MegEngine and XLA",
            )

    tester(F.neg, (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
    tester(F.abs, (2, 32, 16), dtype=np.float32, atol=1e-5)
    tester(F.sin, (1, 16, 3, 1), dtype=np.float32, atol=1e-5)
    tester(F.cos, (4, 16, 3), dtype=np.float32, atol=1e-5)
    tester(F.tan, (4, 16, 1), dtype=np.float32, atol=1e-5)
    tester(F.sinh, (4, 16, 1), dtype=np.float32, atol=1e-5)
    tester(F.cosh, (3, 16, 1), dtype=np.float32, atol=1e-5)
    tester(F.tanh, (4, 6, 3, 1), dtype=np.float32, atol=5e-4)
    tester(F.asin, (4, 1, 3, 1), dtype=np.float32, atol=1e-5)
    # Disabled (xla compute error):
    # tester(F.acos, (4, 16, 3, 1), dtype=np.float32, atol=1e-5)
    tester(F.atan, (4, 16, 3, 1), dtype=np.float32, atol=1e-5)
    tester(F.asinh, (4, 1, 3, 1), dtype=np.float32, atol=1e-5)
    tester(F.acosh, (4, 1), dtype=np.float32, atol=1e-5)
    tester(F.atanh, (1,), dtype=np.float32, atol=1e-5)
    tester(F.exp, (2, 8), dtype=np.float32, atol=1e-5)
    tester(F.sqrt, (32,), dtype=np.float32, atol=1e-5)
    tester(F.square, (32,), dtype=np.float32, atol=1e-5)
    tester(F.log, (8, 8, 16), dtype=np.float32, atol=1e-5)
    tester(F.log1p, (8, 1, 16), dtype=np.float32, atol=1e-5)
    tester(F.expm1, (6, 8, 2), dtype=np.float32, atol=1e-5)
    tester(F.floor, (4, 16, 1, 1), backward=False, dtype=np.float32, atol=1e-5)
    tester(F.ceil, (4, 1, 1), backward=False, dtype=np.float32, atol=1e-5)
    tester(F.round, (1, 4, 1), backward=False, dtype=np.float32, atol=1e-5)
    tester(F.clip, (4, 16, 1), dtype=np.float32, atol=1e-5, lower=-1.0, upper=1.0)
    tester(F.relu, (1,), dtype=np.float32, atol=1e-5)
    tester(F.gelu, (4, 16, 12, 12), dtype=np.float32, atol=2e-5)
    tester(F.sigmoid, (4, 16, 16, 12), dtype=np.float32, atol=1e-5)
    tester(F.hsigmoid, (4, 16, 16, 12), dtype=np.float32, atol=1e-5)
    tester(F.hswish, (4, 16, 16, 12), dtype=np.float32, atol=1e-5)
    tester(F.relu6, (12, 16, 1), dtype=np.float32, atol=1e-5)
    tester(F.leaky_relu, (1, 16, 1), dtype=np.float32, atol=1e-5)
    tester(F.leaky_relu, (12, 16, 1), dtype=np.float32, atol=1e-5, negative_slope=0.5)
    tester(F.silu, (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
    tester(F.logsigmoid, (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
    tester(F.softplus, (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
    tester(F.add, (4, 16, 12, 12), (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
    tester(F.sub, (4, 16, 12, 12), (4, 16, 1, 1), dtype=np.float32, atol=1e-5)
    tester(F.mul, (4, 16, 12, 12), (1, 1, 12, 12), dtype=np.float32, atol=1e-5)
    tester(F.div, (4, 16, 1, 1), (4, 16, 12, 12), atol=5e-4)
    tester(F.floor_div, (4, 16, 12, 12), (4, 16, 1, 1), backward=False, atol=5e-5)
    # Disabled (not supported by xla):
    # tester(F.mod, (8, 1, 4), (8, 1, 1), backward=False, dtype=np.int32, atol=1e-5)
    tester(F.pow, (4, 1, 12, 12), (1, 16, 12, 12), dtype=np.float32, atol=5e-5)
    tester(F.prelu, (4, 16, 12, 12), (1,), dtype=np.float32, atol=1e-5)
    tester(F.prelu, (16, 5, 12), (1, 5, 1), dtype=np.float32, atol=1e-5)
    tester(F.logaddexp, (16, 5, 12), (1, 5, 12), dtype=np.float32, atol=1e-5)
    tester(F.maximum, (1, 5, 1), (1, 5, 12), dtype=np.float32, atol=1e-5)
    tester(F.minimum, (1, 5, 12), (16, 5, 12), dtype=np.float32, atol=1e-5)

    tester(
        F.left_shift, (4, 16, 12, 12), (1, 1, 12, 12), backward=False, dtype=np.int32
    )
    tester(
        F.right_shift, (4, 16, 12, 12), (1, 1, 12, 12), backward=False, dtype=np.int32
    )

    tester(F.equal, (4, 16, 12, 12), (1, 1), backward=False)
    tester(F.not_equal, (4, 16, 12, 12), (4, 16, 1, 1), backward=False)
    tester(F.greater, (4, 16, 1, 1), (4, 16, 12, 12), backward=False)
    tester(F.greater_equal, (16, 1, 1), (4, 16, 12, 12), backward=False)
    tester(F.less, (4, 16, 12, 1), (4, 16, 12, 12), backward=False)
    tester(F.less_equal, (1, 1, 12, 12), (4, 16, 12, 12), backward=False)

    # Disabled: bool tensors are not supported by dlpack yet.
    # tester(F.logical_and, (4, 16, 12, 12), (1, 1), backward=False, dtype=np.bool_)
    # tester(F.logical_or, (4, 16, 12, 12), (4, 16, 1, 1), backward=False, dtype=np.bool_)
    # tester(
    #     F.logical_xor, (4, 16, 1, 1), (4, 16, 12, 12), backward=False, dtype=np.bool_
    # )
    # tester(F.logical_not, (16, 1, 1), backward=False, dtype=np.bool_)