diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index a7739d5882b93fea4de2351813a57bf97179f76a..9e40def9439e18ec6ad58a70fd5f6ba2246d20d8 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1592,7 +1592,8 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const { SymbolVar bias = {rewriter.get_var(bn->input(2))}; SymbolVar mean = {rewriter.get_var(bn->input(3))}; SymbolVar variance = {rewriter.get_var(bn->input(4))}; - SymbolVar invsqrt_variance = opr::PowC::make(variance, {-0.5}); + SymbolVar invsqrt_variance = opr::PowC::make(variance + + variance.make_scalar_dt(float(bn->param().epsilon)), {-0.5}); auto res = scale * (x - mean) * invsqrt_variance + bias; rewriter.replace_var( opr->output(4), res.node(), diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index c6a7d7950bebce4224acec38c67f3db8e9abba21..44989430f8c1c3f0943c43bab13ca91fc28bb3a8 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -1404,7 +1404,7 @@ TEST(TestGoptInference, ConvertBatchNormPass) { auto func = graph->compile({make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)}); func->execute(); - MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5); } TEST(TestGoptInference, ConvBiasNonlinearityFusePass) {