test_multi_out_jit.py 8.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
17

18
import numpy as np
19
from utils import check_output, extra_cc_args, paddle_includes
20 21

import paddle
22
from paddle import static
23 24
from paddle.utils.cpp_extension import get_build_directory, load
from paddle.utils.cpp_extension.extension_utils import run_cmd
25

26
# On Windows the tests do not run inside docker, so a shared library left in
# the cache directory by an earlier run would be reused instead of being
# compiled again. Remove it up front to force a fresh JIT build.
file = f'{get_build_directory()}\\multi_out_jit\\multi_out_jit.pyd'
if os.name == 'nt' and os.path.isfile(file):
    # Delete through the standard library rather than a shell `del` command,
    # so a build directory whose path contains spaces or shell
    # metacharacters is still handled correctly.
    os.remove(file)
32

33
# Build the custom operator just-in-time and import the resulting module.
multi_out_module = load(
    name='multi_out_jit',
    sources=['multi_out_test_op.cc'],
    # Extra compiler flags are passed to exercise the cflags code path.
    extra_cxx_cflags=extra_cc_args,
    # Extra include paths were added for the Coverage CI job.
    extra_include_paths=paddle_includes,
    verbose=True,
)
41 42


43
def discrete_out_dynamic(use_custom, device, dtype, np_w, np_x, np_y, np_z):
    """Evaluate the discrete-output computation in dynamic-graph mode.

    The four numpy inputs are converted to tensors of ``dtype`` with
    gradients enabled. The result is produced either by the custom
    ``discrete_out`` operator (``use_custom`` is true) or by the reference
    expression ``w * 1 + x * 2 + y * 3 + z * 4``, then back-propagated.

    Returns:
        A tuple ``(out, w_grad, y_grad)`` of numpy arrays.
    """
    paddle.set_device(device)
    w, x, y, z = (
        paddle.to_tensor(array, dtype=dtype, stop_gradient=False)
        for array in (np_w, np_x, np_y, np_z)
    )

    out = (
        multi_out_module.discrete_out(w, x, y, z)
        if use_custom
        else w * 1 + x * 2 + y * 3 + z * 4
    )
    out.backward()

    return out.numpy(), w.grad.numpy(), y.grad.numpy()


58
def discrete_out_static(use_custom, device, dtype, np_w, np_x, np_y, np_z):
    """Evaluate the discrete-output computation in static-graph mode.

    Builds a fresh program inside a fresh scope, declares one placeholder per
    input with gradients enabled, computes the result with either the custom
    ``discrete_out`` operator (``use_custom`` is true) or the reference
    expression ``w * 1 + x * 2 + y * 3 + z * 4``, appends the backward pass
    and runs the program once.

    Static mode is always switched off again before returning, even when
    building or running the program raises.

    Returns:
        A tuple ``(out, w_grad, y_grad)`` of the fetched values.
    """
    paddle.enable_static()
    try:
        paddle.set_device(device)
        with static.scope_guard(static.Scope()):
            with static.program_guard(static.Program()):
                # Each placeholder takes its second dimension from its own
                # input array.
                w = static.data(
                    name="w", shape=[None, np_w.shape[1]], dtype=dtype
                )
                x = static.data(
                    name="x", shape=[None, np_x.shape[1]], dtype=dtype
                )
                y = static.data(
                    name="y", shape=[None, np_y.shape[1]], dtype=dtype
                )
                z = static.data(
                    name="z", shape=[None, np_z.shape[1]], dtype=dtype
                )
                w.stop_gradient = False
                x.stop_gradient = False
                y.stop_gradient = False
                z.stop_gradient = False

                if use_custom:
                    out = multi_out_module.discrete_out(w, x, y, z)
                else:
                    out = w * 1 + x * 2 + y * 3 + z * 4
                static.append_backward(out)

                exe = static.Executor()
                exe.run(static.default_startup_program())

                # Gradient variables are fetched by name, using the "@GRAD"
                # suffix on the corresponding input variable name.
                out_v, w_grad_v, y_grad_v = exe.run(
                    static.default_main_program(),
                    feed={
                        "w": np_w.astype(dtype),
                        "x": np_x.astype(dtype),
                        "y": np_y.astype(dtype),
                        "z": np_z.astype(dtype),
                    },
                    fetch_list=[
                        out.name,
                        w.name + "@GRAD",
                        y.name + "@GRAD",
                    ],
                )
    finally:
        # Restore dynamic mode so a failure here cannot leak static mode
        # into whatever runs next.
        paddle.disable_static()
    return out_v, w_grad_v, y_grad_v


98 99
class TestMultiOutputDtypes(unittest.TestCase):
    """Tests for the JIT-compiled ``multi_out`` and ``discrete_out`` ops.

    ``multi_out`` is checked for the dtype and value of its second and third
    outputs; ``discrete_out`` is compared against the equivalent expression
    built from native operators, in both static and dynamic mode.
    """

    def setUp(self):
        """Select the op, dtypes and devices under test and create inputs."""
        self.custom_op = multi_out_module.multi_out
        self.dtypes = ['float32', 'float64']
        self.devices = ['cpu']
        # The inputs are created as float32; the discrete_out helpers convert
        # them to the dtype under test.
        self.np_w = np.random.uniform(-1, 1, [4, 8]).astype("float32")
        self.np_x = np.random.uniform(-1, 1, [4, 8]).astype("float32")
        self.np_y = np.random.uniform(-1, 1, [4, 8]).astype("float32")
        self.np_z = np.random.uniform(-1, 1, [4, 8]).astype("float32")

    def run_static(self, device, dtype):
        """Run ``multi_out`` once in a fresh static program.

        The caller is responsible for enabling static mode beforehand.

        Returns:
            The values fetched for every output of the custom op.
        """
        paddle.set_device(device)
        x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype)

        with paddle.static.scope_guard(paddle.static.Scope()):
            with paddle.static.program_guard(paddle.static.Program()):
                x = paddle.static.data(name='X', shape=[None, 8], dtype=dtype)
                outs = self.custom_op(x)

                exe = paddle.static.Executor()
                exe.run(paddle.static.default_startup_program())
                res = exe.run(
                    paddle.static.default_main_program(),
                    feed={'X': x_data},
                    fetch_list=outs,
                )

                return res

    def check_multi_outputs(self, outs, is_dynamic=False):
        """Check dtype and value of the second and third ``multi_out`` outputs.

        Args:
            outs: The three outputs of the op; the first one is not checked.
            is_dynamic: When true, the outputs are tensors and are converted
                with ``.numpy()`` before being compared.
        """
        # Unpacking also enforces that exactly three outputs were produced.
        _, zero_float64, one_int32 = outs
        if is_dynamic:
            zero_float64 = zero_float64.numpy()
            one_int32 = one_int32.numpy()
        # Fake_float64
        self.assertIn('float64', str(zero_float64.dtype))
        check_output(
            zero_float64, np.zeros([4, 8]).astype('float64'), "zero_float64"
        )
        # ZFake_int32
        self.assertIn('int32', str(one_int32.dtype))
        check_output(one_int32, np.ones([4, 8]).astype('int32'), "one_int32")

    def test_multi_out_static(self):
        """``multi_out`` yields the expected outputs in static mode."""
        paddle.enable_static()
        try:
            for device in self.devices:
                for dtype in self.dtypes:
                    res = self.run_static(device, dtype)
                    self.check_multi_outputs(res)
        finally:
            # Always return to dynamic mode so a failing assertion does not
            # leave static mode enabled for the remaining tests.
            paddle.disable_static()

    def test_multi_out_dynamic(self):
        """``multi_out`` yields the expected outputs in dynamic mode."""
        for device in self.devices:
            for dtype in self.dtypes:
                paddle.set_device(device)
                x_data = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                x = paddle.to_tensor(x_data)
                outs = self.custom_op(x)

                self.assertEqual(len(outs), 3)
                self.check_multi_outputs(outs, True)

    def test_discrete_out_static(self):
        """Custom ``discrete_out`` matches the native expression (static)."""
        inputs = (self.np_w, self.np_x, self.np_y, self.np_z)
        for device in self.devices:
            for dtype in self.dtypes:
                pd_out, pd_w_grad, pd_y_grad = discrete_out_static(
                    False, device, dtype, *inputs
                )
                custom_out, custom_w_grad, custom_y_grad = discrete_out_static(
                    True, device, dtype, *inputs
                )
                check_output(custom_out, pd_out, "out")
                # NOTE: In static mode, the output gradient of the custom
                # operator has been optimized to shape=[1]. However, the
                # native paddle op's output shape is [4, 8], hence we need to
                # fetch pd_w_grad[0][0]. (By the way, something is wrong with
                # native paddle's gradient: the outputs at indexes other than
                # pd_w_grad[0][0] are undefined in this unittest.)
                check_output(custom_w_grad, pd_w_grad[0][0], "w_grad")
                check_output(custom_y_grad, pd_y_grad[0][0], "y_grad")

    def test_discrete_out_dynamic(self):
        """Custom ``discrete_out`` matches the native expression (dynamic)."""
        inputs = (self.np_w, self.np_x, self.np_y, self.np_z)
        for device in self.devices:
            for dtype in self.dtypes:
                pd_out, pd_w_grad, pd_y_grad = discrete_out_dynamic(
                    False, device, dtype, *inputs
                )
                custom_out, custom_w_grad, custom_y_grad = discrete_out_dynamic(
                    True, device, dtype, *inputs
                )
                check_output(custom_out, pd_out, "out")
                check_output(custom_w_grad, pd_w_grad, "w_grad")
                check_output(custom_y_grad, pd_y_grad, "y_grad")
218

219

220 221
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()