提交 442b4f6c 编写于 作者: M Megvii Engine Team

test(traced_module): add some testcases for traced module

GitOrigin-RevId: 0d6bb20b2b5110b5ecd280ec055bb14aed74ebfc
上级 f2691566
......@@ -201,7 +201,8 @@ class Apply(Expr):
NodeMixin.wrap_safe(i, Constant.make(i))
apply_node = cls.make(opdef)
for i in inputs:
apply_node.add_input(NodeMixin.get(i))
assert isinstance(i, RawTensor)
apply_node.inputs.append(NodeMixin.get(i))
unset_module_tracing()
outputs = apply(opdef, *inputs)
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import Callable, NamedTuple
SUPPORTED_TYPE = {}
......@@ -9,11 +19,22 @@ def register_supported_type(type, flatten, unflatten):
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
def _dict_flatten(inp):
aux_data = []
results = []
for key, value in sorted(inp.items()):
results.append(value)
aux_data.append(key)
return results, aux_data
def _dict_unflatten(inps, aux_data):
return dict(zip(aux_data, inps))
register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(
dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
)
register_supported_type(dict, _dict_flatten, _dict_unflatten)
register_supported_type(
slice,
lambda x: ([x.start, x.stop, x.step], None),
......@@ -68,6 +89,8 @@ class TreeDef:
class LeafDef(TreeDef):
def __init__(self, type):
if not isinstance(type, collections.abc.Sequence):
type = (type,)
super().__init__(type, None, [])
self.num_leaves = 1
......@@ -77,4 +100,4 @@ class LeafDef(TreeDef):
return leaves[0]
def __repr__(self):
return "Leaf({})".format(self.type.__name__)
return "Leaf({})".format(", ".join(t.__name__ for t in self.type))
......@@ -14,6 +14,7 @@ import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.module import Linear, Module
from megengine.optimizer import SGD
......@@ -71,8 +72,13 @@ class XORNet(Module):
return x
def test_training_converge():
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_training_converge(test_traced_module):
net = XORNet()
if test_training_converge:
inp = Tensor(np.random.random((14, 2)))
net = trace_module(net, inp)
opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
gm = ad.GradManager().attach(net.parameters())
......@@ -105,9 +111,8 @@ def test_training_converge():
xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1))
data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
pred = infer(data).numpy()
precision = calculate_precision(data.numpy(), pred)
pred = infer(data)
precision = calculate_precision(data.numpy(), pred.numpy())
assert precision == 1.0, "Test precision must be high enough, get {}".format(
precision
)
......@@ -15,6 +15,7 @@ import megengine.autodiff as ad
import megengine.functional as F
import megengine.optimizer as optim
from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace
from megengine.module import Linear, Module
from megengine.optimizer import SGD
......@@ -73,8 +74,12 @@ class XORNet(Module):
return x
def test_training_converge():
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_training_converge(test_traced_module):
net = XORNet()
if test_traced_module:
inp = Tensor(np.random.random((14, 2)))
net = trace_module(net, inp)
opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
gm = ad.GradManager().attach(net.parameters())
......@@ -110,9 +115,8 @@ def test_training_converge():
xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1))
data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
pred = infer(data).numpy()
precision = calculate_precision(data.numpy(), pred)
pred = infer(data)
precision = calculate_precision(data.numpy(), pred.numpy())
print("precision=", precision)
assert precision == 1.0, "Test precision must be high enough, get {}".format(
precision
......
......@@ -19,6 +19,7 @@ import megengine.module as M
import megengine.optimizer as optim
from megengine import tensor
from megengine.autodiff import GradManager
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace
......
......@@ -15,6 +15,7 @@ import pytest
import megengine as mge
import megengine.functional as F
from megengine import Parameter, Tensor, tensor
from megengine.experimental.traced_module import TracedModule, trace_module
from megengine.module import (
BatchNorm1d,
BatchNorm2d,
......@@ -67,8 +68,18 @@ class MyModule(Module):
return x
def test_module_api():
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api(test_traced_module):
m = MyModule()
if test_traced_module:
buff = m.buff
param = m.param
m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
assert "buff" not in m.__dict__
assert "param" not in m.__dict__
m.buff = buff
m.param = param
assert list(m.children()) == [m.bn, m.i]
assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)]
assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
......@@ -141,8 +152,11 @@ def test_module_api():
assert m.bn.training == False and m.i.bn.training == False
def test_module_api_reuse_submodule():
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api_reuse_submodule(test_traced_module):
m = MyModule()
if test_traced_module:
m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
m.h = m.i # pylint: disable=attribute-defined-outside-init
assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
assert list(m.named_modules()) == [
......@@ -153,15 +167,21 @@ def test_module_api_reuse_submodule():
]
def test_module_api_iterable_stability():
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api_iterable_stability(test_traced_module):
m = MyModule()
if test_traced_module:
m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
l = list(m.modules())
for _ in range(100):
assert list(m.modules()) == l
def test_module_api_hooks():
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api_hooks(test_traced_module):
net = MyModule()
if test_traced_module:
net = trace_module(net, Tensor(np.zeros((1, 4, 1, 1))))
pre_hook_num = 0
post_hook_num = 0
hooks = []
......@@ -383,11 +403,16 @@ class Simple(Module):
self.conv1.weight = self.conv0.weight
def forward(self, inputs):
pass
x = self.conv0(inputs)
y = self.conv1(inputs)
return x + y
def test_shared_param():
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_shared_param(test_traced_module):
net = Simple()
if test_traced_module:
net = trace_module(net, tensor(np.random.random((1, 1, 8, 8))))
assert net.conv0.weight is net.conv1.weight
data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())
......@@ -449,15 +474,21 @@ def test_shared_param_1d():
np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())
def test_pickle_module():
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_pickle_module(test_traced_module):
data_shape = (2, 28)
data = tensor(np.random.random(data_shape))
mlp = MLP()
pred_gt = mlp(data)
if test_traced_module:
mlp = trace_module(mlp, data)
# pickle before forward
with BytesIO() as fout:
mge.save(mlp, fout)
fout.seek(0)
mlp1 = mge.load(fout)
if test_traced_module:
assert type(mlp1) == TracedModule
pred0 = mlp1(data)
pred1 = mlp(data)
......@@ -467,8 +498,11 @@ def test_pickle_module():
mge.save(mlp, fout)
fout.seek(0)
mlp1 = mge.load(fout)
if test_traced_module:
assert type(mlp1) == TracedModule
pred2 = mlp1(data)
np.testing.assert_allclose(pred_gt.numpy(), pred1.numpy(), atol=5e-6)
np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)
np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io
import numpy as np
import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace
from megengine.module import Module
class MyBlock(Module):
def __init__(self, in_channels, channels):
super(MyBlock, self).__init__()
self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
self.bn1 = M.BatchNorm2d(channels)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x) + 1
return x
class MyModule(Module):
def __init__(self):
super(MyModule, self).__init__()
self.block0 = MyBlock(8, 4)
self.block1 = MyBlock(4, 2)
def forward(self, x):
x = self.block0(x)
x = self.block1(x)
return x
def test_jit_trace():
module = MyModule()
module.eval()
x = F.ones((1, 8, 14, 14))
expect = module(x)
traced_module = trace_module(module, x)
func = trace(traced_module, capture_as_const=True)
np.testing.assert_array_equal(func(x), expect)
model = io.BytesIO()
func.dump(model)
model.seek(0)
infer_cg = cgtools.GraphInference(model)
np.testing.assert_allclose(
list(infer_cg.run(x.numpy()).values())[0], expect, atol=1e-6
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册