diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
index 429a8b8456821f148804ac77ed8b388b2b2c45e9..177e539c4b6c294b23dfd10127b9606262d59f71 100644
--- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
@@ -83,30 +83,11 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   const auto *x = ctx.Input<Tensor>("X");
   auto *y = ctx.Output<Tensor>("Out");
 
-  float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
-  float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
-
-  // paddle uses beta but mkldnn uses alpha for swish
-  if (algorithm == mkldnn::algorithm::eltwise_swish) {
-    std::swap(alpha, beta);
-  } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
-    alpha = ctx.Attr<float>("threshold");
-  }
-
-  PADDLE_ENFORCE(
-      x->dims().size() >= 1 || x->dims().size() <= 6,
-      platform::errors::Unimplemented("Input dimension size can be 1, 2, 3, 4, "
-                                      "5, or 6, but now the dimension size is",
-                                      x->dims().size()));
-
   bool is_inplaced = x->IsSharedBufferWith(*y);
-  auto src_tz = framework::vectorize(x->dims());
-  auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
-
-  platform::ActivationMKLDNNHandler<T> handler(
-      src_tz, algorithm, alpha, beta, src_format, dev_ctx, ctx.GetPlace(),
-      ctx.InputName("X"), is_inplaced);
+  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, dev_ctx,
+                                               ctx.GetPlace(), x,
+                                               ctx.InputName("X"), is_inplaced);
 
   auto src_memory_p = handler.AcquireSrcMemory(x);
   auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
@@ -130,28 +111,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
   auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
 
-  float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
-  float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
-
-  // paddle uses beta but mkldnn uses alpha for swish
-  if (algorithm == mkldnn::algorithm::eltwise_swish) {
-    std::swap(alpha, beta);
-  } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
-    alpha = ctx.Attr<float>("threshold");
-  }
-
-  auto diff_dst_tz = framework::vectorize(diff_y->dims());
-
-  // diff_dst and src dims should be the same
-  auto src_format =
-      diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
-
-  auto diff_y_format =
-      diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : diff_y->format();
-
   platform::ActivationMKLDNNHandler<T> handler(
-      diff_dst_tz, algorithm, alpha, beta, src_format, diff_y_format, dev_ctx,
-      ctx.GetPlace(), ctx.InputName("X"));
+      algorithm, ctx, dev_ctx, ctx.GetPlace(), x, diff_y, ctx.InputName("X"));
 
   auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
   auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
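
The attribute handling removed from eltwise_forward and eltwise_grad above now lives inside ActivationMKLDNNHandler (see the mkldnn_reuse.h hunk near the end of this patch). Below is a standalone, simplified sketch of that alpha/beta mapping for readers who want to check the arithmetic; the enum and struct are illustrative stand-ins, not Paddle types.

    #include <cassert>
    #include <utility>

    // Sketch of how Paddle attributes map onto the (alpha, beta) pair that
    // oneDNN's eltwise primitives consume, mirroring the new handler body.
    enum class Algo { kSwish, kBoundedRelu, kLinear, kOther };

    struct Attrs {            // hypothetical stand-in for ExecutionContext attrs
      float alpha = 0.f;
      float beta = 0.f;
      float threshold = 6.f;  // bounded_relu (relu6) upper bound
      float scale = 1.f;      // scale op
      float bias = 0.f;
      bool bias_after_scale = true;
    };

    std::pair<float, float> ResolveAlphaBeta(Algo algo, const Attrs& a) {
      float alpha = a.alpha, beta = a.beta;
      if (algo == Algo::kLinear) {
        // scale op: out = alpha * x + beta; fold the bias when it is applied
        // before scaling, since scale * (x + bias) == scale*x + scale*bias.
        alpha = a.scale;
        beta = a.bias_after_scale ? a.bias : a.bias * a.scale;
      } else if (algo == Algo::kSwish) {
        std::swap(alpha, beta);  // paddle stores swish's slope in beta
      } else if (algo == Algo::kBoundedRelu) {
        alpha = a.threshold;     // relu6 threshold goes into alpha
      }
      return {alpha, beta};
    }

    int main() {
      Attrs a;
      a.scale = 2.f;
      a.bias = 3.f;
      a.bias_after_scale = false;
      const auto ab = ResolveAlphaBeta(Algo::kLinear, a);
      assert(ab.first == 2.f && ab.second == 6.f);  // 2*(x+3) == 2*x + 6
      return 0;
    }
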
diff --git a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
index 75367ba0573209338b3ba85ab2ac7240f07d58d3..99b8d020436fc1418bd8877dd1fd640ae0bb3994 100644
--- a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
@@ -85,24 +85,54 @@ class BatchNormMKLDNNHandler
           md, epsilon, flags);
     }
   }
-  BatchNormMKLDNNHandler(const std::vector<int64_t> &dims, const float &epsilon,
-                         const mkldnn::normalization_flags &flags,
-                         const MKLDNNMemoryFormat diff_fmt,
-                         const MKLDNNMemoryFormat src_fmt,
+
+  BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
                          const platform::MKLDNNDeviceContext &dev_ctx,
-                         platform::Place cpu_place,
-                         const std::string &uniq_name)
+                         platform::Place cpu_place, const Tensor *in_x,
+                         const Tensor *scale, const Tensor *out_grad,
+                         const std::string &unique_name)
       : platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
                                  mkldnn::batch_normalization_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dev_ctx, dims, uniq_name)) {
-    auto diff_dst_md =
-        mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
-    auto src_md =
-        mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
-
-    this->AcquireBackwardPrimitiveDescriptor(
-        mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon, flags);
+            platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
+                                unique_name)) {
+    if (!this->isBwdCached()) {
+      PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
+                        platform::errors::InvalidArgument(
+                            "Wrong layout set for Input out_grad tensor"));
+      PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
+                        platform::errors::InvalidArgument(
+                            "Wrong format set for Input out_grad tensor"));
+
+      auto src_tz = paddle::framework::vectorize(in_x->dims());
+      auto scale_tz = paddle::framework::vectorize(scale->dims());
+      PADDLE_ENFORCE_EQ(
+          scale_tz.size(), 1,
+          platform::errors::InvalidArgument(
+              "Dims of scale tensor must be 1, but received scale's size is %d",
+              scale_tz.size()));
+
+      MKLDNNMemoryFormat diff_fmt =
+          platform::MKLDNNFormatForSize(src_tz.size(), out_grad->format());
+
+      MKLDNNMemoryFormat src_fmt =
+          platform::MKLDNNFormatForSize(src_tz.size(), in_x->format());
+
+      auto dims = framework::vectorize(in_x->dims());
+      auto diff_dst_md = mkldnn::memory::desc(
+          dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
+      auto src_md =
+          mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
+
+      const float epsilon = ctx.Attr<float>("epsilon");
+
+      this->AcquireForwardPrimitiveDescriptor(
+          mkldnn::prop_kind::forward_training, src_md, epsilon,
+          mkldnn::normalization_flags::use_scale_shift);
+      this->AcquireBackwardPrimitiveDescriptor(
+          mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon,
+          mkldnn::normalization_flags::use_scale_shift);
+    }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(const Tensor *scale,
@@ -263,8 +293,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto &dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
     auto mkldnn_engine = dev_ctx.GetEngine();
 
-    const float epsilon = ctx.Attr<float>("epsilon");
-
     const auto *x = ctx.Input<Tensor>("X");
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *shift = ctx.Input<Tensor>("Bias");
@@ -275,35 +303,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
     auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
 
-    PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
-                      platform::errors::InvalidArgument(
-                          "Wrong layout set for Input diff_y tensor"));
-    PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
-                      platform::errors::InvalidArgument(
-                          "Wrong format set for Input diff_y tensor"));
-
-    auto src_tz = paddle::framework::vectorize(x->dims());
-    auto scale_tz = paddle::framework::vectorize(scale->dims());
-    PADDLE_ENFORCE_EQ(
-        scale_tz.size(), 1,
-        platform::errors::InvalidArgument(
-            "Dims of scale tensor must be 1, but received scale's size is %d",
-            scale_tz.size()));
-
-    const unsigned int C = scale_tz[0];
-
-    MKLDNNMemoryFormat dst_format =
-        platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
-
-    MKLDNNMemoryFormat input_format =
-        platform::MKLDNNFormatForSize(src_tz.size(), x->format());
-
-    BatchNormMKLDNNHandler<T> handler(
-        src_tz, epsilon, mkldnn::normalization_flags::use_scale_shift,
-        dst_format, input_format, dev_ctx, ctx.GetPlace(),
-        ctx.InputName("SavedMean"));
+    BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), x, scale,
+                                      diff_y, ctx.InputName("SavedMean"));
 
     // MKLDNN requires a single piece of memory for scale and shift/bias data
+    const unsigned int C = paddle::framework::vectorize(scale->dims())[0];
     const size_t scaleshift_size = 2 * C;
     std::vector<T> diff_scaleshift_data;
     diff_scaleshift_data.reserve(scaleshift_size);
@@ -335,7 +339,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
     T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
 
-    // copy back diff sacle/shift to output tensors (diff scale/shift)
+    // copy back diff scale/shift to output tensors (diff scale/shift)
     diff_scaleshift_data.resize(scaleshift_size);
     auto it = std::begin(diff_scaleshift_data);
     std::copy(it, std::next(it, C), diff_scale_data);
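
The grad kernel keeps packing scale and shift into one contiguous buffer of 2*C values and then splits the computed gradients back into diff_scale and diff_shift, exactly as the std::copy calls above do. A self-contained illustration of that layout, using plain std::vector instead of Paddle tensors:

    #include <algorithm>
    #include <cassert>
    #include <iterator>
    #include <vector>

    // oneDNN expects a single 2*C buffer: the first C entries are the scale
    // gradient, the next C entries are the shift (bias) gradient.
    int main() {
      const size_t C = 4;
      std::vector<float> diff_scaleshift(2 * C);
      for (size_t i = 0; i < 2 * C; ++i) {
        diff_scaleshift[i] = static_cast<float>(i);
      }

      std::vector<float> diff_scale(C), diff_shift(C);
      auto it = diff_scaleshift.begin();
      std::copy(it, std::next(it, C), diff_scale.begin());
      std::copy(std::next(it, C), diff_scaleshift.end(), diff_shift.begin());

      assert(diff_scale[0] == 0.f);                      // first half
      assert(diff_shift[0] == static_cast<float>(C));    // second half
      return 0;
    }
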
diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index fed6a7dfa5e1ce408d954ce3576bedc7e96b0d35..0065f3ae39483236622fb13b95ab8b6a14ca4095 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -90,7 +90,7 @@ class ConvMKLDNNHandlerT
           dev_ctx, mkldnn_engine, cpu_place,
           platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
                               unique_name)) {
-    if (!this->isCachedNonBlocking()) {
+    if (!this->isCached()) {
      PADDLE_ENFORCE_EQ(
          input->layout(), DataLayout::kMKLDNN,
          platform::errors::InvalidArgument(
@@ -228,12 +228,12 @@ class ConvMKLDNNHandlerT
         auto bias_md = platform::MKLDNNMemDesc(bias_tz, data_type,
                                                MKLDNNMemoryFormat::x);
 
-        this->AcquireForwardPrimitiveDescriptorNonBlocking(
+        this->AcquireForwardPrimitiveDescriptor(
             conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
             src_md, weights_md, bias_md, dst_md, stride_dims, dilations_dims,
             mkldnn_paddings[0], mkldnn_paddings[1]);
       } else {
-        this->AcquireForwardPrimitiveDescriptorNonBlocking(
+        this->AcquireForwardPrimitiveDescriptor(
             conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
             src_md, weights_md, dst_md, stride_dims, dilations_dims,
             mkldnn_paddings[0], mkldnn_paddings[1]);
@@ -352,25 +352,25 @@ class ConvMKLDNNHandlerT
         auto bias_md = platform::MKLDNNMemDesc(
             bias_tz, mkldnn::memory::data_type::f32, MKLDNNMemoryFormat::x);
 
-        this->AcquireForwardPrimitiveDescriptorNonBlocking(
+        this->AcquireForwardPrimitiveDescriptor(
             conv_attr, mkldnn::prop_kind::forward_training,
             dnnl::algorithm::convolution_direct, src_md, weights_md, bias_md,
             dst_md, stride_dims, dilations_dims, mkldnn_paddings[0],
             mkldnn_paddings[1]);
       } else {
-        this->AcquireForwardPrimitiveDescriptorNonBlocking(
+        this->AcquireForwardPrimitiveDescriptor(
             conv_attr, mkldnn::prop_kind::forward_training,
             dnnl::algorithm::convolution_direct, src_md, weights_md, dst_md,
             stride_dims, dilations_dims, mkldnn_paddings[0],
             mkldnn_paddings[1]);
       }
 
-      this->AcquireBackwardPrimitiveDescriptorNonBlocking(
+      this->AcquireBackwardPrimitiveDescriptor(
          mkldnn::algorithm::convolution_direct, diff_src_md, weights_md,
          diff_dst_md, strides, dilations_dims, mkldnn_paddings[0],
          mkldnn_paddings[1]);
 
-      this->AcquireBackwardWeightsPrimitiveDescriptorNonBlocking(
+      this->AcquireBackwardWeightsPrimitiveDescriptor(
          mkldnn::algorithm::convolution_direct, src_md, diff_weights_md,
          diff_dst_md, strides, dilations_dims, mkldnn_paddings[0],
          mkldnn_paddings[1]);
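
The *NonBlocking suffix can disappear because there is now a single code path that stores every blob under the handler's own key_ rather than the shared key_common_ (the removed comment in mkldnn_reuse.h notes key_common_ was the variant without the thread id), so no cross-thread blocking section is needed. A minimal standalone sketch of that get-or-create lookup, with an ordinary map standing in for the device context's blob store; this is not Paddle's actual API:

    #include <cassert>
    #include <memory>
    #include <string>
    #include <unordered_map>

    struct PrimitiveDesc { int dummy = 0; };  // stand-in for a oneDNN fwd PD

    using BlobMap = std::unordered_map<std::string, std::shared_ptr<void>>;

    std::shared_ptr<PrimitiveDesc> AcquireFwdPD(BlobMap* blobs,
                                                const std::string& key) {
      const std::string key_pd = key + "@fwd_pd";
      auto it = blobs->find(key_pd);
      if (it != blobs->end()) {  // cache hit: reuse the stored descriptor
        return std::static_pointer_cast<PrimitiveDesc>(it->second);
      }
      auto pd = std::make_shared<PrimitiveDesc>();  // miss: create and store
      (*blobs)[key_pd] = pd;
      return pd;
    }

    int main() {
      BlobMap blobs;
      auto first = AcquireFwdPD(&blobs, "conv2d_k3s1");
      auto second = AcquireFwdPD(&blobs, "conv2d_k3s1");
      assert(first == second);  // second call is served from the cache
      return 0;
    }
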
diff --git a/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
index b6b0b486bf060b9f163759540dfa3b17bbe68cb2..5b563e666af0aaa7034594de18fbb69813a93195 100644
--- a/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
@@ -34,7 +34,7 @@ class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
                              framework::vectorize(input->dims()),
                              unique_name)) {
-    if (!this->isCachedNonBlocking()) {
+    if (!this->isCached()) {
       const int n = ctx.Attr<int>("n");
       // MKL-DNN implements LRN in a caffe way:
       // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
@@ -52,7 +52,7 @@ class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
       auto src_md = mkldnn::memory::desc(
           src_tz, platform::MKLDNNGetDataType<T>(), input->format());
 
-      this->AcquireForwardPrimitiveDescriptorNonBlocking(
+      this->AcquireForwardPrimitiveDescriptor(
          is_test ? mkldnn::prop_kind::forward_inference
                  : mkldnn::prop_kind::forward_training,
          mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
@@ -86,11 +86,11 @@ class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
       auto diff_md = mkldnn::memory::desc(
           dims, platform::MKLDNNGetDataType<T>(), out_grad->format());
 
-      this->AcquireForwardPrimitiveDescriptorNonBlocking(
+      this->AcquireForwardPrimitiveDescriptor(
          mkldnn::prop_kind::forward_training,
          mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
 
-      this->AcquireBackwardPrimitiveDescriptorNonBlocking(
+      this->AcquireBackwardPrimitiveDescriptor(
          mkldnn::algorithm::lrn_across_channels, src_md, diff_md, n, alpha,
          beta, k);
     }
diff --git a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
index 04e0bcbfc7ce3550c6a874d1949094cea09661a1..920ec97a769b6d12bdcc28606813003b353f0aef 100644
--- a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
@@ -43,7 +43,7 @@ class PoolingMKLDNNHandler
             platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
                                 framework::ToMKLDNNDataType(input->type()),
                                 unique_name)) {
-    if (!this->isCachedNonBlocking()) {
+    if (!this->isCached()) {
       PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
                         platform::errors::InvalidArgument(
                             "Wrong layout set for Input tensor."));
@@ -123,7 +123,7 @@ class PoolingMKLDNNHandler
       ComputeAdaptivePoolParameters(ctx, src_tz, &ksize, &strides);
 
-      this->AcquireForwardPrimitiveDescriptorNonBlocking(
+      this->AcquireForwardPrimitiveDescriptor(
          is_test ? mkldnn::prop_kind::forward_inference
                  : mkldnn::prop_kind::forward_training,
          pooling_type == "max"
@@ -220,7 +220,7 @@ class PoolingMKLDNNHandler
       const auto exclude_padding = ctx.Attr<bool>("exclusive");
 
-      this->AcquireForwardPrimitiveDescriptorNonBlocking(
+      this->AcquireForwardPrimitiveDescriptor(
          mkldnn::prop_kind::forward_training,
          pooling_type == "max"
              ? mkldnn::algorithm::pooling_max
@@ -230,7 +230,7 @@ class PoolingMKLDNNHandler
          src_md, dst_md, strides, ksize, mkldnn_paddings[0],
          mkldnn_paddings[1]);
 
-      this->AcquireBackwardPrimitiveDescriptorNonBlocking(
+      this->AcquireBackwardPrimitiveDescriptor(
          pooling_type == "max"
              ? mkldnn::algorithm::pooling_max
              : (exclude_padding
diff --git a/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc
index e91bbd15cfb7c612719ab3b6cdf2ac439b616ab5..ae17048b5d568baf4722e63299c9ef2ca3fb6bae 100644
--- a/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc
@@ -30,28 +30,14 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
     const auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
 
-    bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
     auto* x = ctx.Input<Tensor>("X");
     auto* out = ctx.Output<Tensor>("Out");
-    auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
 
-    float scale = (scale_tensor == nullptr) ? ctx.Attr<float>("scale") : (float)*(scale_tensor->data<float>());
-    float bias = ctx.Attr<float>("bias");
-
-    // if bias_after_scale == true
-    //    out = scale*X + bias
-    // else
-    //    out = scale*(X + bias) = scale*X + scale*bias
-
-    if (!bias_after_scale) bias *= scale;
-
-    auto x_tz = framework::vectorize(x->dims());
     bool is_inplaced = x->IsSharedBufferWith(*out);
 
     platform::ActivationMKLDNNHandler<T> handler(
-        x_tz, mkldnn::algorithm::eltwise_linear, scale, bias, x->format(),
-        dev_ctx, ctx.GetPlace(), ctx.InputName("X"), is_inplaced);
+        mkldnn::algorithm::eltwise_linear, ctx, dev_ctx, ctx.GetPlace(), x,
+        ctx.InputName("X"), is_inplaced);
 
     auto src_memory_p = handler.AcquireSrcMemory(x);
     auto dst_memory_p = handler.AcquireDstMemory(out);
diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
index 1d177e120b59f869d6f9ba96197a973d0ad62d5b..e065800e4d1c71ee4bc47fe09b26ed1ea0b9d2c9 100644
--- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
@@ -50,7 +50,7 @@ class SoftmaxMKLDNNHandler
                   : platform::CreateKey(
                         dev_ctx, framework::vectorize(input->dims()),
                         uniq_name)) {
-    if (!this->isCachedNonBlocking()) {
+    if (!this->isCached()) {
       PADDLE_ENFORCE_EQ(
           input->dims(), output->dims(),
           platform::errors::InvalidArgument(
@@ -60,8 +60,8 @@ class SoftmaxMKLDNNHandler
       auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
                              input->format());
 
-      this->AcquireForwardPrimitiveDescriptorNonBlocking(
-          prop_kind::forward_scoring, md, axis);
+      this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
+                                              axis);
     }
   }
 
@@ -90,10 +90,10 @@ class SoftmaxMKLDNNHandler
       auto diff_softmax_md = MKLDNNMemDesc(
           softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
 
-      this->AcquireForwardPrimitiveDescriptorNonBlocking(
-          prop_kind::forward_scoring, data_softmax_md, axis);
-      this->AcquireBackwardPrimitiveDescriptorNonBlocking(
-          diff_softmax_md, data_softmax_md, axis);
+      this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
+                                              data_softmax_md, axis);
+      this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
+                                               axis);
     }
   }
 };
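
With the scale kernel above rewritten on top of ActivationMKLDNNHandler, the whole op reduces to oneDNN's eltwise_linear formula out = alpha * x + beta, with alpha = scale and beta = bias (pre-multiplied by scale when bias_after_scale is false). A small numeric check of that equivalence, written with plain loops rather than oneDNN calls:

    #include <cassert>
    #include <vector>

    // Reference elementwise linear transform: out[i] = alpha * x[i] + beta.
    std::vector<float> EltwiseLinear(const std::vector<float>& x, float alpha,
                                     float beta) {
      std::vector<float> out(x.size());
      for (size_t i = 0; i < x.size(); ++i) out[i] = alpha * x[i] + beta;
      return out;
    }

    int main() {
      const std::vector<float> x = {1.f, 2.f, 3.f};
      const float scale = 2.f, bias = 3.f;

      // bias_after_scale == true:  out = scale*x + bias
      auto a = EltwiseLinear(x, scale, bias);
      assert(a[1] == 2.f * 2.f + 3.f);

      // bias_after_scale == false: out = scale*(x + bias) = scale*x + scale*bias
      auto b = EltwiseLinear(x, scale, bias * scale);
      assert(b[1] == 2.f * (2.f + 3.f));
      return 0;
    }
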
diff --git a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
index 7618b1d9c31218bf6e15b048801a3bb196a94fce..1813aabf1d8548453932d5850dd48facc980b0ab 100644
--- a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
@@ -118,17 +118,6 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
 
   inline int GetNumInputs(void) { return num_inputs_; }
 
- protected:
-  // isCached need to be overloaded as base one works on key_common
-  bool isCached() {
-    const std::string key_pd = this->key_ + "@fwd_pd";
-    this->fwd_pd_ = std::static_pointer_cast<dnnl::sum::primitive_desc>(
-        this->dev_ctx_.GetBlob(key_pd));
-
-    const std::string key_p = this->key_ + "@fwd_p";
-    return (this->dev_ctx_.GetBlob(key_p) != nullptr);
-  }
-
  private:
   int num_inputs_;
   std::vector<std::string> srcs_suffix_;
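
The mkldnn_reuse.h changes that follow also rework isBwdCached(): a cached backward descriptor is only accepted when the matching forward descriptor can be fetched under the same key, because building and running the backward primitive still needs the forward PD as a hint. A self-contained sketch of that contract, with an ordinary map standing in for the blob store; none of this is Paddle's actual API:

    #include <cassert>
    #include <memory>
    #include <stdexcept>
    #include <string>
    #include <unordered_map>

    struct FwdPD {};
    struct BwdPD {};

    using BlobMap = std::unordered_map<std::string, std::shared_ptr<void>>;

    bool IsBwdCached(const BlobMap& blobs, const std::string& key,
                     std::shared_ptr<FwdPD>* fwd_pd,
                     std::shared_ptr<BwdPD>* bwd_pd) {
      auto bwd_it = blobs.find(key + "@bwd_pd");
      if (bwd_it == blobs.end()) return false;  // plain cache miss
      *bwd_pd = std::static_pointer_cast<BwdPD>(bwd_it->second);

      auto fwd_it = blobs.find(key + "@fwd_pd");
      if (fwd_it == blobs.end()) {
        // Mirrors the PADDLE_ENFORCE_NOT_NULL: a cached BWD PD without its
        // FWD PD is treated as a hard error, not as a cache miss.
        throw std::runtime_error("FWD PD should be set when BWD PD is cached.");
      }
      *fwd_pd = std::static_pointer_cast<FwdPD>(fwd_it->second);
      return true;
    }

    int main() {
      BlobMap blobs;
      std::shared_ptr<FwdPD> fwd;
      std::shared_ptr<BwdPD> bwd;
      assert(!IsBwdCached(blobs, "softmax_grad", &fwd, &bwd));  // nothing yet

      blobs["softmax_grad@fwd_pd"] = std::make_shared<FwdPD>();
      blobs["softmax_grad@bwd_pd"] = std::make_shared<BwdPD>();
      assert(IsBwdCached(blobs, "softmax_grad", &fwd, &bwd) && fwd && bwd);
      return 0;
    }
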
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 2981e5502ce6ac2d5cf55e8bf60a30035f032a3a..514c0b3d3ce7f892d5ad5adc397b44a7bf184171 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -157,15 +157,6 @@ class MKLDNNHandlerT {
 
  protected:
   bool isCached() {
-    const std::string key_pd = key_common_ + "@fwd_pd";
-    fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
-        dev_ctx_.GetBlob(key_pd));
-
-    const std::string key_p = key_ + "@fwd_p";
-    return (dev_ctx_.GetBlob(key_p) != nullptr);
-  }
-
-  bool isCachedNonBlocking() {
     const std::string key_pd = key_ + "@fwd_pd";
     fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
         dev_ctx_.GetBlob(key_pd));
@@ -178,7 +169,18 @@ class MKLDNNHandlerT {
     bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
         dev_ctx_.GetBlob(key_pd));
 
-    return (bwd_pd_ != nullptr);
+    if (bwd_pd_ == nullptr) {
+      return false;
+    } else {
+      // When BWD is cached then still we need to Get FWD PD
+      const std::string key_fpd = key_ + "@fwd_pd";
+      fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
+          dev_ctx_.GetBlob(key_fpd));
+      PADDLE_ENFORCE_NOT_NULL(
+          fwd_pd_, platform::errors::Unavailable(
+                       "Error: FWD PD should be set when BWD PD is cached."));
+      return true;
+    }
   }
 
   // If your primitive descriptor requires attributes, pass them as a
   // constructor, including the first one.
   template <typename Arg, typename... Args>
   void AcquireForwardPrimitiveDescriptor(Arg&& first_arg, Args&&... args) {
-    // Forward PD has to be passed to Grad op that
-    // may be executed by diffrent thread, hence
-    // for that one we use key that does not contain TID
-    const std::string key_pd = key_common_ + "@fwd_pd";
-    fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
-        dev_ctx_.GetBlob(key_pd));
-    if (fwd_pd_ == nullptr) {
-      static std::mutex acquire_barrier;
-      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
-          acquire_barrier);
-      fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
-          dev_ctx_.GetBlob(key_pd));
-      if (fwd_pd_ == nullptr) {
-        CreateForwardPrimitiveDescriptor(first_arg,
-                                         std::forward<Args>(args)...);
-        dev_ctx_.SetBlob(key_pd, fwd_pd_);
-      }
-    }
-  }
-
-  template <typename Arg, typename... Args>
-  void AcquireForwardPrimitiveDescriptorNonBlocking(Arg&& first_arg,
-                                                    Args&&... args) {
     // This is used when we can recreate FWD PD in BWD so
     // we do not need to pass FWD to BWD
     const std::string key_pd = key_ + "@fwd_pd";
@@ -242,31 +221,10 @@ class MKLDNNHandlerT {
         std::make_shared<typename TForward::primitive_desc>(fwd_desc, engine_);
   }
 
-  // TODO(jczaja): After/if all ops can used xxxNonBlocking version
-  // then remove this one
   template <typename... Args>
   void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
-    const std::string key_fwd_pd = key_common_ + "@fwd_pd";
-    fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
-        dev_ctx_.GetBlob(key_fwd_pd));
-    PADDLE_ENFORCE_NOT_NULL(
-        fwd_pd_, platform::errors::Unavailable(
-                     "Get MKLDNN Forward primitive %s failed.", key_fwd_pd));
-    const std::string key_pd = key_ + "@bwd_pd";
-    bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
-        dev_ctx_.GetBlob(key_pd));
-    if (bwd_pd_ == nullptr) {
-      auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
-      bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
-          bwd_desc, engine_, *fwd_pd_);
-      dev_ctx_.SetBlob(key_pd, bwd_pd_);
-    }
-  }
-
-  template <typename... Args>
-  void AcquireBackwardPrimitiveDescriptorNonBlocking(Args&&... args) {
     // fwd_pd_ is set during grad by calling
-    // AcquireForwardPrimitiveDescriptorNonBlocking
+    // AcquireForwardPrimitiveDescriptor
     PADDLE_ENFORCE_NOT_NULL(
         fwd_pd_,
         platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed.",
@@ -283,9 +241,9 @@ class MKLDNNHandlerT {
   }
 
   template <typename... Args>
-  void AcquireBackwardWeightsPrimitiveDescriptorNonBlocking(Args&&... args) {
+  void AcquireBackwardWeightsPrimitiveDescriptor(Args&&... args) {
     // fwd_pd_ is set during grad by calling
-    // AcquireForwardPrimitiveDescriptorNonBlocking
+    // AcquireForwardPrimitiveDescriptor
     PADDLE_ENFORCE_NOT_NULL(
         fwd_pd_,
         platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed.",
@@ -834,45 +792,100 @@ class ActivationMKLDNNHandler
     : public MKLDNNHandlerT<T, mkldnn::eltwise_forward,
                             mkldnn::eltwise_backward> {
  public:
-  ActivationMKLDNNHandler(const std::vector<int64_t>& dims,
-                          mkldnn::algorithm algorithm, float alpha, float beta,
-                          const MKLDNNMemoryFormat fmt,
-                          const platform::MKLDNNDeviceContext& dev_ctx,
-                          platform::Place cpu_place,
+  ActivationMKLDNNHandler(mkldnn::algorithm algorithm,
+                          const framework::ExecutionContext& ctx,
+                          const MKLDNNDeviceContext& dev_ctx, Place cpu_place,
+                          const framework::Tensor* in_x,
                           const std::string& unique_name, bool is_inplaced)
       : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
                                  mkldnn::eltwise_backward>(
            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            is_inplaced
-                ? platform::CreateKey(dev_ctx, dims, "a", algorithm,
-                                      unique_name)
-                : platform::CreateKey(dev_ctx, dims, "a", unique_name)) {
-    auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
-
-    this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
-                                            algorithm, md, alpha, beta);
-  }
-
-  ActivationMKLDNNHandler(const std::vector<int64_t>& dims,
-                          mkldnn::algorithm algorithm, float alpha, float beta,
-                          const MKLDNNMemoryFormat fmt,
-                          const MKLDNNMemoryFormat diff_fmt,
-                          const platform::MKLDNNDeviceContext& dev_ctx,
-                          platform::Place cpu_place,
-                          const std::string& unique_name)
+            is_inplaced ? platform::CreateKey(
+                              dev_ctx, framework::vectorize(in_x->dims()), "a",
+                              algorithm, unique_name)
+                        : platform::CreateKey(
+                              dev_ctx, framework::vectorize(in_x->dims()), "a",
+                              unique_name)) {
+    if (!this->isCached()) {
+      float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
+      float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
+      // eltwise_linear means we are in scale op
+      if (algorithm == mkldnn::algorithm::eltwise_linear) {
+        bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
+        auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
+        alpha = (scale_tensor == nullptr) ? ctx.Attr<float>("scale")
+                                          : (float)*(scale_tensor->data<float>());
+        beta = ctx.Attr<float>("bias");
+        // if bias_after_scale == true
+        //    out = scale*X + bias
+        // else
+        //    out = scale*(X + bias) = scale*X + scale*bias
+        if (!bias_after_scale) beta *= alpha;
+      } else {
+        // paddle uses beta but mkldnn uses alpha for swish
+        if (algorithm == mkldnn::algorithm::eltwise_swish) {
+          std::swap(alpha, beta);
+        } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
+          alpha = ctx.Attr<float>("threshold");
+        }
+      }
+
+      PADDLE_ENFORCE(in_x->dims().size() >= 1 || in_x->dims().size() <= 6,
+                     platform::errors::Unimplemented(
+                         "Input dimension size can be 1, 2, 3, 4, "
+                         "5, or 6, but now the dimension size is",
+                         in_x->dims().size()));
+
+      auto src_tz = framework::vectorize(in_x->dims());
+      auto src_fmt =
+          src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format();
+      auto md = mkldnn::memory::desc(src_tz, platform::MKLDNNGetDataType<T>(),
+                                     src_fmt);
+
+      this->AcquireForwardPrimitiveDescriptor(
+          mkldnn::prop_kind::forward_training, algorithm, md, alpha, beta);
+    }
+  }
+
+  ActivationMKLDNNHandler(mkldnn::algorithm algorithm,
+                          const framework::ExecutionContext& ctx,
+                          const MKLDNNDeviceContext& dev_ctx, Place cpu_place,
+                          const framework::Tensor* in_x, const Tensor* out_grad,
+                          const std::string& unique_name)
       : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
                                  mkldnn::eltwise_backward>(
            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dev_ctx, dims, "a", unique_name)) {
-    auto diff_dst_md = platform::MKLDNNMemDesc(
-        dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
-    auto src_md =
-        platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt);
-
-    this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
-                                             alpha, beta);
+            platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
+                                "a", unique_name)) {
+    if (!this->isBwdCached()) {
+      float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
+      float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
+
+      // paddle uses beta but mkldnn uses alpha for swish
+      if (algorithm == mkldnn::algorithm::eltwise_swish) {
+        std::swap(alpha, beta);
+      } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
+        alpha = ctx.Attr<float>("threshold");
+      }
+
+      auto diff_dst_tz = framework::vectorize(out_grad->dims());
+
+      auto src_fmt =
+          diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format();
+      auto diff_fmt =
+          diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : out_grad->format();
+
+      auto dims = framework::vectorize(in_x->dims());
+      auto diff_dst_md = platform::MKLDNNMemDesc(
+          dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
+      auto src_md = platform::MKLDNNMemDesc(
+          dims, platform::MKLDNNGetDataType<T>(), src_fmt);
+
+      this->AcquireForwardPrimitiveDescriptor(
+          mkldnn::prop_kind::forward_training, algorithm, src_md, alpha, beta);
+      this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
+                                               alpha, beta);
+    }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
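
One detail of the new constructors above that is easy to misread is the swish swap: oneDNN's eltwise_swish computes x * sigmoid(alpha * x), while the Paddle op stores that slope in its beta attribute, hence std::swap(alpha, beta). A short numeric illustration, plain math with no oneDNN calls:

    #include <cassert>
    #include <cmath>

    // Reference formula for oneDNN's eltwise_swish: x * sigmoid(alpha * x).
    float SwishOneDNN(float x, float alpha) {
      return x / (1.f + std::exp(-alpha * x));
    }

    int main() {
      const float paddle_beta = 2.f;          // Paddle's swish attribute
      const float onednn_alpha = paddle_beta; // what the handler passes on
      const float x = 0.5f;
      const float y = SwishOneDNN(x, onednn_alpha);
      assert(y > 0.f && y < x);  // sanity: swish stays below x for x >= 0
      return 0;
    }
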
ctx.Attr("scale") + : (float)*(scale_tensor->data()); + beta = ctx.Attr("bias"); + // if bias_after_scale == true + // out = scale*X + bias + // else + // out = scale*(X + bias) = scale*X + scale*bias + if (!bias_after_scale) beta *= alpha; + } else { + // paddle uses beta but mkldnn uses alpha for swish + if (algorithm == mkldnn::algorithm::eltwise_swish) { + std::swap(alpha, beta); + } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) { + alpha = ctx.Attr("threshold"); + } + } + + PADDLE_ENFORCE(in_x->dims().size() >= 1 || in_x->dims().size() <= 6, + platform::errors::Unimplemented( + "Input dimension size can be 1, 2, 3, 4, " + "5, or 6, but now the dimension size is", + in_x->dims().size())); + auto src_tz = framework::vectorize(in_x->dims()); + auto src_fmt = + src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format(); + auto md = mkldnn::memory::desc(src_tz, platform::MKLDNNGetDataType(), + src_fmt); + + this->AcquireForwardPrimitiveDescriptor( + mkldnn::prop_kind::forward_training, algorithm, md, alpha, beta); + } + } + + ActivationMKLDNNHandler(mkldnn::algorithm algorithm, + const framework::ExecutionContext& ctx, + const MKLDNNDeviceContext& dev_ctx, Place cpu_place, + const framework::Tensor* in_x, const Tensor* out_grad, + const std::string& unique_name) : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dev_ctx, dims, "a", unique_name)) { - auto diff_dst_md = platform::MKLDNNMemDesc( - dims, platform::MKLDNNGetDataType(), diff_fmt); - auto src_md = - platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType(), fmt); - - this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md, - alpha, beta); + platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()), + "a", unique_name)) { + if (!this->isBwdCached()) { + float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 0; + float beta = ctx.HasAttr("beta") ? ctx.Attr("beta") : 0; + + // paddle uses beta but mkldnn uses alpha for swish + if (algorithm == mkldnn::algorithm::eltwise_swish) { + std::swap(alpha, beta); + } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) { + alpha = ctx.Attr("threshold"); + } + + auto diff_dst_tz = framework::vectorize(out_grad->dims()); + + auto src_fmt = + diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format(); + auto diff_fmt = + diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : out_grad->format(); + + auto dims = framework::vectorize(in_x->dims()); + auto diff_dst_md = platform::MKLDNNMemDesc( + dims, platform::MKLDNNGetDataType(), diff_fmt); + auto src_md = platform::MKLDNNMemDesc( + dims, platform::MKLDNNGetDataType(), src_fmt); + + this->AcquireForwardPrimitiveDescriptor( + mkldnn::prop_kind::forward_training, algorithm, src_md, alpha, beta); + this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md, + alpha, beta); + } } std::shared_ptr AcquireBackwardSrcMemory( diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_batch_norm_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_batch_norm_mkldnn_op.py index 1f34bebe949df3be505364429113d66e46ca48da..85b398f684237ceb70e6a43cd710d0b00d989106 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_batch_norm_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_batch_norm_mkldnn_op.py @@ -115,4 +115,6 @@ class TestMKLDNNBatchNormOpWithReluInference(TestBatchNormOpInference): if __name__ == '__main__': + from paddle import enable_static + enable_static() unittest.main()