Commit e6caa9ff, authored by Megvii Engine Team, committed by huangxinda

feat(opr): add bn backward for inference mode

GitOrigin-RevId: bb643cb62fbba90ca8846a3550f88bf6763ddd58
Parent c90fa087
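This commit teaches the BatchNorm operator to produce gradients when it runs in inference mode, which is what a frozen BN layer inside a training graph uses. Because the running statistics are constants in that mode, the backward pass reduces to elementwise arithmetic plus per-channel sums. The sketch below is a NumPy rendition of the formulas the new C++ branch assembles with PowC / Reduce; the function name and the (1, C, 1, 1) parameter layout are illustrative assumptions, not MegEngine API.

import numpy as np

def frozen_bn_backward(dy, x, scale, running_mean, running_var, eps=1e-5):
    # Gradients of y = (x - running_mean) / sqrt(running_var + eps) * scale + bias
    # when the running statistics are treated as constants (inference / frozen mode).
    # x and dy have shape (N, C, H, W); scale, running_mean, running_var are (1, C, 1, 1).
    inv_std = 1.0 / np.sqrt(running_var + eps)
    # dL/dx: the statistics do not depend on x, so the chain rule stays elementwise.
    dx = dy * scale * inv_std
    # dL/dscale and dL/dbias: sum over every axis except the channel axis.
    d_scale = (dy * (x - running_mean) * inv_std).sum(axis=(0, 2, 3), keepdims=True)
    d_bias = dy.sum(axis=(0, 2, 3), keepdims=True)
    return dx, d_scale, d_bias

In training mode the gradient rule is unchanged and still dispatches to the dedicated BatchNormBackward kernel with the saved batch statistics, as the C++ diff further below shows.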
......@@ -100,16 +100,6 @@ class _BatchNorm(Module):
if _bias is not None:
_bias = _bias.detach()
# Need to expand to elementwise operations here
# see MGB_IMPL_OPR_GRAD(BatchNormForward) in src/opr/impl/dnn/batch_norm.cpp
scale = (self.running_var + self.eps) ** (-0.5)
if _weight is not None:
scale *= _weight
bias = -self.running_mean * scale
if _bias is not None:
bias += _bias
return inp * scale + bias
if self.training and self.track_running_stats:
exponential_average_factor = self.momentum
else:
......
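The hunk above (16 lines reduced to 6) drops the Python-side shortcut that folded the frozen statistics into a single scale and bias and ran the normalization as plain elementwise operations, which was needed precisely because the BatchNorm op itself had no gradient in inference mode. A rough NumPy equivalent of the deleted lines, for reference (function and argument names are illustrative):

import numpy as np

def frozen_bn_forward(inp, running_mean, running_var, weight=None, bias_param=None, eps=1e-5):
    # Elementwise frozen-BN forward that the deleted Python branch computed.
    scale = (running_var + eps) ** -0.5      # 1 / sqrt(running_var + eps)
    if weight is not None:
        scale = scale * weight
    bias = -running_mean * scale
    if bias_param is not None:
        bias = bias + bias_param
    return inp * scale + bias

With the inference-mode gradient now defined on the C++ side, the module can presumably route frozen layers through the regular BatchNorm op, which is what the updated tests below exercise.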
......@@ -19,9 +19,13 @@ from megengine.jit import trace
from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm
def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
def run_frozen_bn(BNModule, is_training, use_trace, use_symbolic):
nchannel = 3
m = BNModule(nchannel, freeze=True)
if is_training:
m.train()
else:
m.eval()
var = 4.0
bias = 1.0
shape = (1, nchannel, 1, 1)
......@@ -51,30 +55,33 @@ def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
train_fn = trace(train_fn, symbolic=use_symbolic)
for _ in range(3):
loss = train_fn(megengine.Tensor(data))
np.testing.assert_equal(m.running_var.numpy(), saved_var)
np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
loss = train_fn(megengine.tensor(data))
if not is_training:
np.testing.assert_equal(m.running_var.numpy(), saved_var)
np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
np.testing.assert_almost_equal(
loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
)
np.testing.assert_equal(m.weight.numpy(), saved_wt)
np.testing.assert_equal(m.bias.numpy(), saved_bias)
np.testing.assert_almost_equal(
loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
)
def test_frozen_bn():
run_frozen_bn(BatchNorm2d)
run_frozen_bn(BatchNorm2d, True, False)
run_frozen_bn(BatchNorm2d, True, True)
@pytest.mark.parametrize("is_training", [False, True])
@pytest.mark.parametrize("use_trace", [False, True])
@pytest.mark.parametrize("use_symbolic", [False, True])
def test_frozen_bn(is_training, use_trace, use_symbolic):
run_frozen_bn(BatchNorm2d, is_training, use_trace, use_symbolic)
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_frozen_synced_bn():
@pytest.mark.parametrize("is_training", [False, True])
@pytest.mark.parametrize("use_trace", [False, True])
@pytest.mark.parametrize("use_symbolic", [False, True])
def test_frozen_synced_bn(is_training, use_trace, use_symbolic):
@dist.launcher(n_gpus=2)
def worker():
run_frozen_bn(SyncBatchNorm)
run_frozen_bn(SyncBatchNorm, True, False)
run_frozen_bn(SyncBatchNorm, True, True)
run_frozen_bn(SyncBatchNorm, is_training, use_trace, use_symbolic)
worker()
......@@ -190,8 +197,13 @@ def test_trace_several_syncbn(trace_mode):
# https://github.com/MegEngine/MegEngine/issues/145
def test_frozen_bn_no_affine():
@pytest.mark.parametrize("is_training", [False, True])
def test_frozen_bn_no_affine(is_training):
nchannel = 3
m = BatchNorm2d(nchannel, freeze=True, affine=False)
data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
if is_training:
m.train()
else:
m.eval()
data = megengine.tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
m(data).numpy()
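The parametrized tests above run a frozen BatchNorm2d / SyncBatchNorm in both train() and eval() mode, with and without tracing, and check that the running statistics and affine parameters stay untouched while the loss still matches the frozen normalization. A minimal end-to-end sketch of the scenario the new gradient enables, assuming MegEngine's GradManager autodiff API (not part of this diff):

import numpy as np
import megengine
from megengine.autodiff import GradManager
from megengine.module import BatchNorm2d

m = BatchNorm2d(3, freeze=True)    # frozen: running statistics are constants
m.train()                          # training mode, but the frozen layer normalizes with its running statistics

data = megengine.tensor(np.random.random((2, 3, 4, 4)).astype("float32"))
gm = GradManager()
gm.attach([data])
with gm:
    loss = m(data).mean()
    gm.backward(loss)              # relies on the inference-mode BN backward added here
print(data.grad.shape)             # gradient flows through the frozen layer

Attaching the input stands in for an upstream layer such as the Conv2d used in the tests; the point is that the gradient with respect to the BN input is now defined.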
......@@ -12,6 +12,8 @@
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/io.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/tensor_manip.h"
#include "../internal/megdnn_opr_wrapper.inl"
......@@ -243,16 +245,34 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() {
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchNormForward) {
mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING,
"batch norm could only take grad in training mode");
mgb_assert(wrt_idx < 5, "wrt_idx %zu is out of range", wrt_idx);
VarNodeArray ret(opr.input().size(), nullptr);
SymbolVarArray grad = BatchNormBackward::make(
opr.input(0), out_grad[4],
opr.output(2), opr.output(3),
opr.input(1), opr.param());
for (size_t i = 0; i < 3; ++ i) {
ret[i] = grad[(i + 2) % 3].node();
SymbolVarArray grad;
switch (opr.param().fwd_mode) {
case BatchNorm::Param::FwdMode::TRAINING:
grad = BatchNormBackward::make(
opr.input(0), out_grad[4],
opr.output(2), opr.output(3),
opr.input(1), opr.param());
for (size_t i = 0; i < 3; ++ i) {
ret[i] = grad[(i + 2) % 3].node();
}
return ret;
case BatchNorm::Param::FwdMode::INFERENCE:
auto sqrt_var = PowC::make((SymbolVar{opr.input(4)}
+ static_cast<dt_float32>(opr.param().epsilon)), 0.5, opr.config());
auto d_bn_scale_unreduced = SymbolVar{out_grad[4]} *
(SymbolVar{opr.input(0)} - SymbolVar{opr.input(3)}) / sqrt_var;
auto d_bn_scale = Reduce::make(d_bn_scale_unreduced,
Reduce::Param::Mode::SUM, GetVarShape::make(opr.input(1)));
auto d_bn_bias = Reduce::make(out_grad[4],
Reduce::Param::Mode::SUM, GetVarShape::make(opr.input(2)));
auto dx = SymbolVar{out_grad[4]} * SymbolVar{opr.input(1)} / sqrt_var;
ret[0] = dx.node();
ret[1] = d_bn_scale.node();
ret[2] = d_bn_bias.node();
return ret;
}
return ret;
}
......
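As a quick sanity check, the dx formula built in the new inference branch can be compared against a finite difference of the frozen forward in plain NumPy, independent of MegEngine (the helper below is illustrative only):

import numpy as np

def frozen_bn(x, mean, var, scale, bias, eps=1e-5):
    return (x - mean) / np.sqrt(var + eps) * scale + bias

rng = np.random.RandomState(0)
x = rng.rand(2, 3, 4, 4)
mean, var = rng.rand(1, 3, 1, 1), rng.rand(1, 3, 1, 1) + 0.5
scale, bias = rng.rand(1, 3, 1, 1), rng.rand(1, 3, 1, 1)

# Analytic input gradient of frozen_bn(x).sum(): dy is all ones, so dx = scale / sqrt(var + eps).
dx = np.ones_like(x) * scale / np.sqrt(var + 1e-5)

# Central finite difference on a single element as a spot check.
h, idx = 1e-6, (0, 1, 2, 3)
xp, xm = x.copy(), x.copy()
xp[idx] += h
xm[idx] -= h
numeric = (frozen_bn(xp, mean, var, scale, bias).sum()
           - frozen_bn(xm, mean, var, scale, bias).sum()) / (2 * h)
assert np.isclose(numeric, dx[idx], rtol=1e-4)

The d_bn_scale and d_bn_bias reductions can be checked the same way by perturbing a single channel of scale or bias instead of x.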