diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index c4b70d6cdefd988acd24dc5dbfa15825e58622f8..c95f5402ff73aebff38472f065e66943ba145a07 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -36,7 +36,7 @@ from ..core._imperative_rt.ops import ( ) from ..core._trace_option import set_symbolic_shape from ..core._wrap import device as as_device -from ..core.ops.builtin import BackwardGraph, OpDef +from ..core.ops.builtin import BackwardGraph, BatchNorm, OpDef from ..core.ops.special import Const from ..core.tensor import megbrain_graph as G from ..core.tensor.utils import setscalar @@ -833,6 +833,10 @@ class trace: if isinstance(op, BackwardGraph): ovars = G.apply_backward_varnode(op, *ivars) else: + if isinstance(op, BatchNorm): + assert ( + op.fwd_mode == BatchNorm.FwdMode.INFERENCE + ), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?" ovars = G.apply_normal_varnode(op, *ivars) AutoNaming.record_opnode(ovars[0].op) diff --git a/imperative/python/test/integration/test_trace_dump.py b/imperative/python/test/integration/test_trace_dump.py index e0b876c516d545ec2a5f3dc837e027e5daa0892c..c719ee94b24f93966ec71031fb3848d4444c1ffd 100644 --- a/imperative/python/test/integration/test_trace_dump.py +++ b/imperative/python/test/integration/test_trace_dump.py @@ -11,6 +11,7 @@ import os import tempfile import numpy as np +import pytest import megengine as mge import megengine.functional as F @@ -140,3 +141,15 @@ def test_xornet_trace_dump(): with mkstemp() as out: pred_fun.dump(out, arg_names=["data"], output_names=["label"]) + + +def test_dump_bn_train_mode(): + @trace(symbolic=True, capture_as_const=True) + def bn_train(data): + pred = M.BatchNorm2d(10)(data).sum() + return pred + + data = mge.tensor(np.random.random((10, 10, 10, 10))) + bn_train(data) + with pytest.raises(AssertionError): + bn_train.dump("test.mge")