From a149784db0802c302f957d6e8211f2dfc29f2f0a Mon Sep 17 00:00:00 2001 From: zhaoting Date: Wed, 9 Sep 2020 10:08:04 +0800 Subject: [PATCH] fix CPU BatchNorm infer error --- .../cpu/mkldnn/fused_batch_norm_cpu_kernel.cc | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/fused_batch_norm_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/fused_batch_norm_cpu_kernel.cc index 8654998ec..9cb821409 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/fused_batch_norm_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/fused_batch_norm_cpu_kernel.cc @@ -50,11 +50,13 @@ void FusedBatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) { dnnl::memory::desc scale_bias_desc = GetDefaultMemDesc({2, channel}); auto epsilon = AnfAlgo::GetNodeAttr(kernel_node, "epsilon"); auto prop_kind = dnnl::prop_kind::forward_inference; + auto normalization_flags = dnnl::normalization_flags::use_scale_shift | dnnl::normalization_flags::use_global_stats; if (is_train) { prop_kind = dnnl::prop_kind::forward_training; + normalization_flags = dnnl::normalization_flags::use_scale_shift; } dnnl::batch_normalization_forward::desc desc = - dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, dnnl::normalization_flags::use_scale_shift); + dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, normalization_flags); auto prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); primitive_ = std::make_shared(prim_desc); AddArgument(DNNL_ARG_SRC, x_desc); @@ -74,14 +76,14 @@ bool FusedBatchNormCPUKernel::Launch(const std::vector &inpu auto wksp = reinterpret_cast(workspace[0]->addr); memcpy_s(wksp, workspace[0]->size, inputs[1]->addr, inputs[1]->size); memcpy_s(wksp + (inputs[1]->size / sizeof(float)), inputs[2]->size, inputs[2]->addr, inputs[2]->size); - - SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_MEAN, outputs[3]->addr); - SetArgumentHandle(DNNL_ARG_VARIANCE, outputs[4]->addr); - SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr); - SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); - ExecutePrimitive(); if (is_train) { + SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_MEAN, outputs[3]->addr); + SetArgumentHandle(DNNL_ARG_VARIANCE, outputs[4]->addr); + SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr); + SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); + ExecutePrimitive(); + auto moving_mean = reinterpret_cast(inputs[3]->addr); auto moving_variance = reinterpret_cast(inputs[4]->addr); auto mean = reinterpret_cast(outputs[3]->addr); @@ -90,6 +92,13 @@ bool FusedBatchNormCPUKernel::Launch(const std::vector &inpu moving_mean[i] = moving_mean[i] * (1 - momentum) + mean[i] * momentum; moving_variance[i] = moving_variance[i] * (1 - momentum) + variance[i] * momentum; } + } else { + SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_MEAN, inputs[3]->addr); + SetArgumentHandle(DNNL_ARG_VARIANCE, inputs[4]->addr); + SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr); + SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); + ExecutePrimitive(); } return true; } -- GitLab