diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py
index ad188d1a5f186fe36e38be9de6584f4081ec9b89..2e95f25a67d5bbb9537bb30cfdc3c44bbd50d390 100644
--- a/imperative/python/megengine/module/batchnorm.py
+++ b/imperative/python/megengine/module/batchnorm.py
@@ -100,16 +100,6 @@ class _BatchNorm(Module):
             if _bias is not None:
                 _bias = _bias.detach()
 
-            # Need to expand to elementwise operations here
-            # see MGB_IMPL_OPR_GRAD(BatchNormForward) in src/opr/impl/dnn/batch_norm.cpp
-            scale = (self.running_var + self.eps) ** (-0.5)
-            if _weight is not None:
-                scale *= _weight
-            bias = -self.running_mean * scale
-            if _bias is not None:
-                bias += _bias
-            return inp * scale + bias
-
         if self.training and self.track_running_stats:
             exponential_average_factor = self.momentum
         else:
diff --git a/imperative/python/test/integration/test_bn.py b/imperative/python/test/integration/test_bn.py
index a1cd9628084c848576cb4bf09f5df2335a058251..7ce2c2ddcf352bec8d96cf4c70d8552415f1f22e 100644
--- a/imperative/python/test/integration/test_bn.py
+++ b/imperative/python/test/integration/test_bn.py
@@ -19,9 +19,13 @@ from megengine.jit import trace
 from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm
 
 
-def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
+def run_frozen_bn(BNModule, is_training, use_trace, use_symbolic):
     nchannel = 3
     m = BNModule(nchannel, freeze=True)
+    if is_training:
+        m.train()
+    else:
+        m.eval()
     var = 4.0
     bias = 1.0
     shape = (1, nchannel, 1, 1)
@@ -51,30 +55,33 @@ def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
         train_fn = trace(train_fn, symbolic=use_symbolic)
 
     for _ in range(3):
-        loss = train_fn(megengine.Tensor(data))
-        np.testing.assert_equal(m.running_var.numpy(), saved_var)
-        np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
+        loss = train_fn(megengine.tensor(data))
+        if not is_training:
+            np.testing.assert_equal(m.running_var.numpy(), saved_var)
+            np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
+            np.testing.assert_almost_equal(
+                loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
+            )
         np.testing.assert_equal(m.weight.numpy(), saved_wt)
         np.testing.assert_equal(m.bias.numpy(), saved_bias)
-        np.testing.assert_almost_equal(
-            loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
-        )
 
 
-def test_frozen_bn():
-    run_frozen_bn(BatchNorm2d)
-    run_frozen_bn(BatchNorm2d, True, False)
-    run_frozen_bn(BatchNorm2d, True, True)
+@pytest.mark.parametrize("is_training", [False, True])
+@pytest.mark.parametrize("use_trace", [False, True])
+@pytest.mark.parametrize("use_symbolic", [False, True])
+def test_frozen_bn(is_training, use_trace, use_symbolic):
+    run_frozen_bn(BatchNorm2d, is_training, use_trace, use_symbolic)
 
 
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
-def test_frozen_synced_bn():
+@pytest.mark.parametrize("is_training", [False, True])
+@pytest.mark.parametrize("use_trace", [False, True])
+@pytest.mark.parametrize("use_symbolic", [False, True])
+def test_frozen_synced_bn(is_training, use_trace, use_symbolic):
     @dist.launcher(n_gpus=2)
     def worker():
-        run_frozen_bn(SyncBatchNorm)
-        run_frozen_bn(SyncBatchNorm, True, False)
-        run_frozen_bn(SyncBatchNorm, True, True)
+        run_frozen_bn(SyncBatchNorm, is_training, use_trace, use_symbolic)
 
     worker()
 
@@ -190,8 +197,13 @@ def test_trace_several_syncbn(trace_mode):
 
 
 # https://github.com/MegEngine/MegEngine/issues/145
-def test_frozen_bn_no_affine():
+@pytest.mark.parametrize("is_training", [False, True])
+def test_frozen_bn_no_affine(is_training):
     nchannel = 3
     m = BatchNorm2d(nchannel, freeze=True, affine=False)
-    data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
+    if is_training:
+        m.train()
+    else:
+        m.eval()
+    data = megengine.tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
     m(data).numpy()
diff --git a/src/opr/impl/dnn/batch_norm.cpp b/src/opr/impl/dnn/batch_norm.cpp
index 989755303d602ec65eaaa36b5a4d9ef58ba7380c..b3e2419980b542ac51d435b67ed3f7da4c681cc3 100644
--- a/src/opr/impl/dnn/batch_norm.cpp
+++ b/src/opr/impl/dnn/batch_norm.cpp
@@ -12,6 +12,8 @@
 #include "megbrain/opr/dnn/batch_norm.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/graph/grad_impl.h"
+#include "megbrain/opr/basic_arith.h"
+#include "megbrain/opr/tensor_manip.h"
 
 #include "../internal/megdnn_opr_wrapper.inl"
 
@@ -243,16 +245,34 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() {
 
 #if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchNormForward) {
-    mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING,
-            "batch norm could only take grad in training mode");
     mgb_assert(wrt_idx < 5, "wrt_idx %zu is out of range", wrt_idx);
     VarNodeArray ret(opr.input().size(), nullptr);
-    SymbolVarArray grad = BatchNormBackward::make(
-            opr.input(0), out_grad[4],
-            opr.output(2), opr.output(3),
-            opr.input(1), opr.param());
-    for (size_t i = 0; i < 3; ++ i) {
-        ret[i] = grad[(i + 2) % 3].node();
+    SymbolVarArray grad;
+    switch (opr.param().fwd_mode) {
+    case BatchNorm::Param::FwdMode::TRAINING:
+        grad = BatchNormBackward::make(
+                opr.input(0), out_grad[4],
+                opr.output(2), opr.output(3),
+                opr.input(1), opr.param());
+        for (size_t i = 0; i < 3; ++ i) {
+            ret[i] = grad[(i + 2) % 3].node();
+        }
+        return ret;
+    case BatchNorm::Param::FwdMode::INFERENCE:
+        auto sqrt_var = PowC::make((SymbolVar{opr.input(4)} +
+                static_cast<float>(opr.param().epsilon)), 0.5, opr.config());
+        auto d_bn_scale_unreduced = SymbolVar{out_grad[4]} *
+                (SymbolVar{opr.input(0)} - SymbolVar{opr.input(3)}) / sqrt_var;
+        auto d_bn_scale = Reduce::make(d_bn_scale_unreduced,
+                Reduce::Param::Mode::SUM, GetVarShape::make(opr.input(1)));
+        auto d_bn_bias = Reduce::make(out_grad[4],
+                Reduce::Param::Mode::SUM, GetVarShape::make(opr.input(2)));
+        auto dx = SymbolVar{out_grad[4]} * SymbolVar{opr.input(1)} / sqrt_var;
+
+        ret[0] = dx.node();
+        ret[1] = d_bn_scale.node();
+        ret[2] = d_bn_bias.node();
+        return ret;
     }
     return ret;
 }
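
For reference, the inference-mode gradients emitted by the new MGB_IMPL_OPR_GRAD(BatchNormForward) branch reduce to dx = dy * scale / sqrt(var + eps), d_scale = sum(dy * (x - mean) / sqrt(var + eps)) and d_bias = sum(dy), with the sums reduced to the shapes of scale and bias. The standalone NumPy sketch below is not part of the patch; the names, shapes and NCHW layout are illustrative assumptions. It checks the dx formula against a finite difference of the inference-mode forward pass.

import numpy as np

np.random.seed(0)
N, C, H, W = 2, 3, 4, 4
eps = 1e-5
x = np.random.randn(N, C, H, W)
scale = np.random.randn(1, C, 1, 1)
bias = np.random.randn(1, C, 1, 1)
mean = np.random.randn(1, C, 1, 1)
var = np.random.rand(1, C, 1, 1) + 0.5


def bn_inference(x):
    # y = scale * (x - running_mean) / sqrt(running_var + eps) + bias
    return scale * (x - mean) / np.sqrt(var + eps) + bias


dy = np.random.randn(N, C, H, W)  # upstream gradient w.r.t. the BN output
inv_std = (var + eps) ** -0.5     # i.e. 1 / PowC(var + eps, 0.5) from the patch

# Analytic gradients matching the expressions built with PowC / Reduce above.
dx = dy * scale * inv_std
d_scale = (dy * (x - mean) * inv_std).sum(axis=(0, 2, 3), keepdims=True)
d_bias = dy.sum(axis=(0, 2, 3), keepdims=True)

# Finite-difference check of dx on one element, using loss = sum(y * dy).
i, h = (0, 1, 2, 3), 1e-3
xp, xm = x.copy(), x.copy()
xp[i] += h
xm[i] -= h
num = ((bn_inference(xp) * dy).sum() - (bn_inference(xm) * dy).sum()) / (2 * h)
np.testing.assert_allclose(num, dx[i], rtol=1e-5, atol=1e-8)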