Unverified commit f9ce1b1a, authored by Jacek Czaja, committed by GitHub

[oneDNN] Further ops refactoring of oneDNN cache access (#33515)

* - Draft of implementation of refactoring

- compilation fix

* - Fixes after review

* - Removed unnecessary comment
Parent 4ddd595f
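
The shape of the refactor is easiest to see at a call site: kernels no longer compute attribute values, dims, and memory formats up front and pass them into the handler; instead each handler constructor receives the ExecutionContext and the input tensor(s) and derives everything itself, skipping that work entirely on a cache hit. A before/after sketch of the activation call site changed in the first hunk below (both lines quoted from this diff):

    // Before: the kernel pre-computed alpha/beta, dims, and formats.
    platform::ActivationMKLDNNHandler<T> handler(
        src_tz, algorithm, alpha, beta, src_format, dev_ctx, ctx.GetPlace(),
        ctx.InputName("X"), is_inplaced);

    // After: the handler pulls attributes and shapes from ctx and x itself.
    platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, dev_ctx,
                                                 ctx.GetPlace(), x,
                                                 ctx.InputName("X"), is_inplaced);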
@@ -83,30 +83,11 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   const auto *x = ctx.Input<Tensor>("X");
   auto *y = ctx.Output<Tensor>("Out");
 
-  float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
-  float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
-
-  // paddle uses beta but mkldnn uses alpha for swish
-  if (algorithm == mkldnn::algorithm::eltwise_swish) {
-    std::swap(alpha, beta);
-  } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
-    alpha = ctx.Attr<float>("threshold");
-  }
-
-  PADDLE_ENFORCE(
-      x->dims().size() >= 1 || x->dims().size() <= 6,
-      platform::errors::Unimplemented("Input dimension size can be 1, 2, 3, 4, "
-                                      "5, or 6, but now the dimension size is",
-                                      x->dims().size()));
-
   bool is_inplaced = x->IsSharedBufferWith(*y);
-  auto src_tz = framework::vectorize<int64_t>(x->dims());
-  auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
 
-  platform::ActivationMKLDNNHandler<T> handler(
-      src_tz, algorithm, alpha, beta, src_format, dev_ctx, ctx.GetPlace(),
-      ctx.InputName("X"), is_inplaced);
+  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, dev_ctx,
+                                               ctx.GetPlace(), x,
+                                               ctx.InputName("X"), is_inplaced);
 
   auto src_memory_p = handler.AcquireSrcMemory(x);
   auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
@@ -130,28 +111,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
   auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
 
-  float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
-  float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
-
-  // paddle uses beta but mkldnn uses alpha for swish
-  if (algorithm == mkldnn::algorithm::eltwise_swish) {
-    std::swap(alpha, beta);
-  } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
-    alpha = ctx.Attr<float>("threshold");
-  }
-
-  auto diff_dst_tz = framework::vectorize<int64_t>(diff_y->dims());
-
-  // diff_dst and src dims should be the same
-  auto src_format =
-      diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
-
-  auto diff_y_format =
-      diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : diff_y->format();
-
   platform::ActivationMKLDNNHandler<T> handler(
-      diff_dst_tz, algorithm, alpha, beta, src_format, diff_y_format, dev_ctx,
-      ctx.GetPlace(), ctx.InputName("X"));
+      algorithm, ctx, dev_ctx, ctx.GetPlace(), x, diff_y, ctx.InputName("X"));
 
   auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
   auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
...
@@ -85,24 +85,54 @@ class BatchNormMKLDNNHandler
           md, epsilon, flags);
     }
   }
 
-  BatchNormMKLDNNHandler(const std::vector<int64_t> &dims, const float &epsilon,
-                         const mkldnn::normalization_flags &flags,
-                         const MKLDNNMemoryFormat diff_fmt,
-                         const MKLDNNMemoryFormat src_fmt,
+  BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
                          const platform::MKLDNNDeviceContext &dev_ctx,
-                         platform::Place cpu_place,
-                         const std::string &uniq_name)
+                         platform::Place cpu_place, const Tensor *in_x,
+                         const Tensor *scale, const Tensor *out_grad,
+                         const std::string &unique_name)
       : platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
                                  mkldnn::batch_normalization_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dev_ctx, dims, uniq_name)) {
-    auto diff_dst_md =
-        mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
-    auto src_md =
-        mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
-
-    this->AcquireBackwardPrimitiveDescriptor(
-        mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon, flags);
+            platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
+                                unique_name)) {
+    if (!this->isBwdCached()) {
+      PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
+                        platform::errors::InvalidArgument(
+                            "Wrong layout set for Input out_grad tensor"));
+      PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
+                        platform::errors::InvalidArgument(
+                            "Wrong format set for Input out_grad tensor"));
+
+      auto src_tz = paddle::framework::vectorize<int64_t>(in_x->dims());
+      auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
+      PADDLE_ENFORCE_EQ(
+          scale_tz.size(), 1,
+          platform::errors::InvalidArgument(
+              "Dims of scale tensor must be 1, but received scale's size is %d",
+              scale_tz.size()));
+
+      MKLDNNMemoryFormat diff_fmt =
+          platform::MKLDNNFormatForSize(src_tz.size(), out_grad->format());
+
+      MKLDNNMemoryFormat src_fmt =
+          platform::MKLDNNFormatForSize(src_tz.size(), in_x->format());
+
+      auto dims = framework::vectorize(in_x->dims());
+      auto diff_dst_md = mkldnn::memory::desc(
+          dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
+      auto src_md =
+          mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
+
+      const float epsilon = ctx.Attr<float>("epsilon");
+
+      this->AcquireForwardPrimitiveDescriptor(
+          mkldnn::prop_kind::forward_training, src_md, epsilon,
+          mkldnn::normalization_flags::use_scale_shift);
+
+      this->AcquireBackwardPrimitiveDescriptor(
+          mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon,
+          mkldnn::normalization_flags::use_scale_shift);
+    }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(const Tensor *scale,
@@ -263,8 +293,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
     auto mkldnn_engine = dev_ctx.GetEngine();
 
-    const float epsilon = ctx.Attr<float>("epsilon");
-
     const auto *x = ctx.Input<Tensor>("X");
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *shift = ctx.Input<Tensor>("Bias");
@@ -275,35 +303,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
     auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
 
-    PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
-                      platform::errors::InvalidArgument(
-                          "Wrong layout set for Input diff_y tensor"));
-    PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
-                      platform::errors::InvalidArgument(
-                          "Wrong format set for Input diff_y tensor"));
-
-    auto src_tz = paddle::framework::vectorize<int64_t>(x->dims());
-    auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
-    PADDLE_ENFORCE_EQ(
-        scale_tz.size(), 1,
-        platform::errors::InvalidArgument(
-            "Dims of scale tensor must be 1, but received scale's size is %d",
-            scale_tz.size()));
-
-    const unsigned int C = scale_tz[0];
-
-    MKLDNNMemoryFormat dst_format =
-        platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
-
-    MKLDNNMemoryFormat input_format =
-        platform::MKLDNNFormatForSize(src_tz.size(), x->format());
-
-    BatchNormMKLDNNHandler<T> handler(
-        src_tz, epsilon, mkldnn::normalization_flags::use_scale_shift,
-        dst_format, input_format, dev_ctx, ctx.GetPlace(),
-        ctx.InputName("SavedMean"));
+    BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), x, scale,
+                                      diff_y, ctx.InputName("SavedMean"));
 
     // MKLDNN requires a single piece of memory for scale and shift/bias data
+    const unsigned int C = paddle::framework::vectorize(scale->dims())[0];
     const size_t scaleshift_size = 2 * C;
     std::vector<T> diff_scaleshift_data;
     diff_scaleshift_data.reserve(scaleshift_size);
@@ -335,7 +339,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
     T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
 
-    // copy back diff sacle/shift to output tensors (diff scale/shift)
+    // copy back diff scale/shift to output tensors (diff scale/shift)
     diff_scaleshift_data.resize(scaleshift_size);
     auto it = std::begin(diff_scaleshift_data);
     std::copy(it, std::next(it, C), diff_scale_data);
...
@@ -90,7 +90,7 @@ class ConvMKLDNNHandlerT
             dev_ctx, mkldnn_engine, cpu_place,
             platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
                                 unique_name)) {
-    if (!this->isCachedNonBlocking()) {
+    if (!this->isCached()) {
       PADDLE_ENFORCE_EQ(
           input->layout(), DataLayout::kMKLDNN,
           platform::errors::InvalidArgument(
@@ -228,12 +228,12 @@ class ConvMKLDNNHandlerT
         auto bias_md =
             platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);
 
-        this->AcquireForwardPrimitiveDescriptorNonBlocking(
+        this->AcquireForwardPrimitiveDescriptor(
             conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
             src_md, weights_md, bias_md, dst_md, stride_dims, dilations_dims,
             mkldnn_paddings[0], mkldnn_paddings[1]);
       } else {
-        this->AcquireForwardPrimitiveDescriptorNonBlocking(
+        this->AcquireForwardPrimitiveDescriptor(
             conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
             src_md, weights_md, dst_md, stride_dims, dilations_dims,
             mkldnn_paddings[0], mkldnn_paddings[1]);
@@ -352,25 +352,25 @@ class ConvMKLDNNHandlerT
         auto bias_md = platform::MKLDNNMemDesc(
             bias_tz, mkldnn::memory::data_type::f32, MKLDNNMemoryFormat::x);
 
-        this->AcquireForwardPrimitiveDescriptorNonBlocking(
+        this->AcquireForwardPrimitiveDescriptor(
             conv_attr, mkldnn::prop_kind::forward_training,
             dnnl::algorithm::convolution_direct, src_md, weights_md, bias_md,
             dst_md, stride_dims, dilations_dims, mkldnn_paddings[0],
             mkldnn_paddings[1]);
       } else {
-        this->AcquireForwardPrimitiveDescriptorNonBlocking(
+        this->AcquireForwardPrimitiveDescriptor(
             conv_attr, mkldnn::prop_kind::forward_training,
             dnnl::algorithm::convolution_direct, src_md, weights_md, dst_md,
             stride_dims, dilations_dims, mkldnn_paddings[0],
             mkldnn_paddings[1]);
       }
 
-      this->AcquireBackwardPrimitiveDescriptorNonBlocking(
+      this->AcquireBackwardPrimitiveDescriptor(
           mkldnn::algorithm::convolution_direct, diff_src_md, weights_md,
           diff_dst_md, strides, dilations_dims, mkldnn_paddings[0],
           mkldnn_paddings[1]);
 
-      this->AcquireBackwardWeightsPrimitiveDescriptorNonBlocking(
+      this->AcquireBackwardWeightsPrimitiveDescriptor(
           mkldnn::algorithm::convolution_direct, src_md, diff_weights_md,
           diff_dst_md, strides, dilations_dims, mkldnn_paddings[0],
           mkldnn_paddings[1]);
...
@@ -34,7 +34,7 @@ class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
             dev_ctx, mkldnn_engine, cpu_place,
             platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
                                 unique_name)) {
-    if (!this->isCachedNonBlocking()) {
+    if (!this->isCached()) {
       const int n = ctx.Attr<int>("n");
       // MKL-DNN implements LRN in a caffe way:
       // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
@@ -52,7 +52,7 @@ class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
       auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
                                          input->format());
 
-      this->AcquireForwardPrimitiveDescriptorNonBlocking(
+      this->AcquireForwardPrimitiveDescriptor(
           is_test ? mkldnn::prop_kind::forward_inference
                   : mkldnn::prop_kind::forward_training,
           mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
@@ -86,11 +86,11 @@ class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
       auto diff_md = mkldnn::memory::desc(
           dims, platform::MKLDNNGetDataType<T>(), out_grad->format());
 
-      this->AcquireForwardPrimitiveDescriptorNonBlocking(
+      this->AcquireForwardPrimitiveDescriptor(
           mkldnn::prop_kind::forward_training,
           mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
 
-      this->AcquireBackwardPrimitiveDescriptorNonBlocking(
+      this->AcquireBackwardPrimitiveDescriptor(
           mkldnn::algorithm::lrn_across_channels, src_md, diff_md, n, alpha,
           beta, k);
     }
...
@@ -43,7 +43,7 @@ class PoolingMKLDNNHandler
             platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
                                 framework::ToMKLDNNDataType(input->type()),
                                 unique_name)) {
-    if (!this->isCachedNonBlocking()) {
+    if (!this->isCached()) {
       PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
                         platform::errors::InvalidArgument(
                             "Wrong layout set for Input tensor."));
@@ -123,7 +123,7 @@ class PoolingMKLDNNHandler
       ComputeAdaptivePoolParameters(ctx, src_tz, &ksize, &strides);
 
-      this->AcquireForwardPrimitiveDescriptorNonBlocking(
+      this->AcquireForwardPrimitiveDescriptor(
           is_test ? mkldnn::prop_kind::forward_inference
                   : mkldnn::prop_kind::forward_training,
           pooling_type == "max"
@@ -220,7 +220,7 @@ class PoolingMKLDNNHandler
       const auto exclude_padding = ctx.Attr<bool>("exclusive");
 
-      this->AcquireForwardPrimitiveDescriptorNonBlocking(
+      this->AcquireForwardPrimitiveDescriptor(
           mkldnn::prop_kind::forward_training,
           pooling_type == "max"
               ? mkldnn::algorithm::pooling_max
@@ -230,7 +230,7 @@ class PoolingMKLDNNHandler
           src_md, dst_md, strides, ksize, mkldnn_paddings[0],
           mkldnn_paddings[1]);
 
-      this->AcquireBackwardPrimitiveDescriptorNonBlocking(
+      this->AcquireBackwardPrimitiveDescriptor(
           pooling_type == "max"
               ? mkldnn::algorithm::pooling_max
               : (exclude_padding
...
@@ -30,28 +30,14 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
     const auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
 
-    bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
     auto* x = ctx.Input<Tensor>("X");
     auto* out = ctx.Output<Tensor>("Out");
-    auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
 
-    float scale = (scale_tensor == nullptr) ? ctx.Attr<float>("scale")
-                                            : (float)*(scale_tensor->data<T>());
-    float bias = ctx.Attr<float>("bias");
-
-    // if bias_after_scale == true
-    //    out = scale*X + bias
-    // else
-    //    out = scale*(X + bias) = scale*X + scale*bias
-
-    if (!bias_after_scale) bias *= scale;
-
-    auto x_tz = framework::vectorize<int64_t>(x->dims());
     bool is_inplaced = x->IsSharedBufferWith(*out);
 
     platform::ActivationMKLDNNHandler<T> handler(
-        x_tz, mkldnn::algorithm::eltwise_linear, scale, bias, x->format(),
-        dev_ctx, ctx.GetPlace(), ctx.InputName("X"), is_inplaced);
+        mkldnn::algorithm::eltwise_linear, ctx, dev_ctx, ctx.GetPlace(), x,
+        ctx.InputName("X"), is_inplaced);
 
     auto src_memory_p = handler.AcquireSrcMemory(x);
     auto dst_memory_p = handler.AcquireDstMemory(out);
...
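
The scale kernel now delegates to the eltwise_linear path of ActivationMKLDNNHandler (shown near the end of this diff), where scale and bias become oneDNN's alpha and beta. The bias pre-folding it relies on is plain algebra; a minimal standalone check, not part of the patch:

    #include <cassert>

    int main() {
      // With bias_after_scale == false Paddle defines out = scale * (x + bias),
      // which equals scale * x + scale * bias, so folding beta *= alpha reduces
      // both modes to the single eltwise_linear form out = alpha * x + beta.
      float scale = 2.0f, bias = 3.0f, x = 5.0f;
      float folded = bias * scale;
      assert(scale * (x + bias) == scale * x + folded);
      return 0;
    }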
@@ -50,7 +50,7 @@ class SoftmaxMKLDNNHandler
                   : platform::CreateKey(
                         dev_ctx, framework::vectorize(input->dims()),
                         uniq_name)) {
-    if (!this->isCachedNonBlocking()) {
+    if (!this->isCached()) {
       PADDLE_ENFORCE_EQ(
           input->dims(), output->dims(),
           platform::errors::InvalidArgument(
@@ -60,8 +60,8 @@ class SoftmaxMKLDNNHandler
       auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
                              input->format());
 
-      this->AcquireForwardPrimitiveDescriptorNonBlocking(
-          prop_kind::forward_scoring, md, axis);
+      this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
+                                              axis);
     }
   }
@@ -90,10 +90,10 @@ class SoftmaxMKLDNNHandler
       auto diff_softmax_md = MKLDNNMemDesc(
           softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
 
-      this->AcquireForwardPrimitiveDescriptorNonBlocking(
-          prop_kind::forward_scoring, data_softmax_md, axis);
-      this->AcquireBackwardPrimitiveDescriptorNonBlocking(
-          diff_softmax_md, data_softmax_md, axis);
+      this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
+                                              data_softmax_md, axis);
+      this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
+                                               axis);
     }
   }
 };
...
@@ -118,17 +118,6 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
   inline int GetNumInputs(void) { return num_inputs_; }
 
- protected:
-  // isCached need to be overloaded as base one works on key_common
-  bool isCached() {
-    const std::string key_pd = this->key_ + "@fwd_pd";
-    this->fwd_pd_ = std::static_pointer_cast<dnnl::sum::primitive_desc>(
-        this->dev_ctx_.GetBlob(key_pd));
-
-    const std::string key_p = this->key_ + "@fwd_p";
-    return (this->dev_ctx_.GetBlob(key_p) != nullptr);
-  }
-
  private:
   int num_inputs_;
   std::vector<std::string> srcs_suffix_;
...
@@ -157,15 +157,6 @@ class MKLDNNHandlerT {
  protected:
   bool isCached() {
-    const std::string key_pd = key_common_ + "@fwd_pd";
-    fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
-        dev_ctx_.GetBlob(key_pd));
-
-    const std::string key_p = key_ + "@fwd_p";
-    return (dev_ctx_.GetBlob(key_p) != nullptr);
-  }
-
-  bool isCachedNonBlocking() {
     const std::string key_pd = key_ + "@fwd_pd";
     fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
         dev_ctx_.GetBlob(key_pd));
@@ -178,7 +169,18 @@ class MKLDNNHandlerT {
     bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
         dev_ctx_.GetBlob(key_pd));
 
-    return (bwd_pd_ != nullptr);
+    if (bwd_pd_ == nullptr) {
+      return false;
+    } else {
+      // When BWD is cached then still we need to Get FWD PD
+      const std::string key_fpd = key_ + "@fwd_pd";
+      fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
+          dev_ctx_.GetBlob(key_fpd));
+      PADDLE_ENFORCE_NOT_NULL(
+          fwd_pd_, platform::errors::Unavailable(
+                       "Error: FWD PD should be set when BWD PD is cached."));
+      return true;
+    }
   }
@@ -187,29 +189,6 @@ class MKLDNNHandlerT {
   // If your primitive descriptor requires attributes, pass them as a
   // constructor, including the first one.
   template <typename Arg, typename... Args>
   void AcquireForwardPrimitiveDescriptor(Arg&& first_arg, Args&&... args) {
-    // Forward PD has to be passed to Grad op that
-    // may be executed by diffrent thread, hence
-    // for that one we use key that does not contain TID
-    const std::string key_pd = key_common_ + "@fwd_pd";
-    fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
-        dev_ctx_.GetBlob(key_pd));
-    if (fwd_pd_ == nullptr) {
-      static std::mutex acquire_barrier;
-      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
-          acquire_barrier);
-      fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
-          dev_ctx_.GetBlob(key_pd));
-      if (fwd_pd_ == nullptr) {
-        CreateForwardPrimitiveDescriptor(first_arg,
-                                         std::forward<Args>(args)...);
-        dev_ctx_.SetBlob(key_pd, fwd_pd_);
-      }
-    }
-  }
-
-  template <typename Arg, typename... Args>
-  void AcquireForwardPrimitiveDescriptorNonBlocking(Arg&& first_arg,
-                                                    Args&&... args) {
     // This is used when we can recreate FWD PD in BWD so
     // we do not need to pass FWD to BWD
     const std::string key_pd = key_ + "@fwd_pd";
@@ -242,31 +221,10 @@ class MKLDNNHandlerT {
         std::make_shared<typename TForward::primitive_desc>(fwd_desc, engine_);
   }
 
-  // TODO(jczaja): After/if all ops can used xxxNonBlocking version
-  // then remove this one
   template <typename... Args>
   void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
-    const std::string key_fwd_pd = key_common_ + "@fwd_pd";
-    fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
-        dev_ctx_.GetBlob(key_fwd_pd));
-    PADDLE_ENFORCE_NOT_NULL(
-        fwd_pd_, platform::errors::Unavailable(
-                     "Get MKLDNN Forward primitive %s failed.", key_fwd_pd));
-    const std::string key_pd = key_ + "@bwd_pd";
-    bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
-        dev_ctx_.GetBlob(key_pd));
-    if (bwd_pd_ == nullptr) {
-      auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
-      bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
-          bwd_desc, engine_, *fwd_pd_);
-      dev_ctx_.SetBlob(key_pd, bwd_pd_);
-    }
-  }
-
-  template <typename... Args>
-  void AcquireBackwardPrimitiveDescriptorNonBlocking(Args&&... args) {
     // fwd_pd_ is set during grad by calling
-    // AcquireForwardPrimitiveDescriptorNonBlocking
+    // AcquireForwardPrimitiveDescriptor
     PADDLE_ENFORCE_NOT_NULL(
         fwd_pd_,
         platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed.",
@@ -283,9 +241,9 @@ class MKLDNNHandlerT {
   }
 
   template <typename... Args>
-  void AcquireBackwardWeightsPrimitiveDescriptorNonBlocking(Args&&... args) {
+  void AcquireBackwardWeightsPrimitiveDescriptor(Args&&... args) {
     // fwd_pd_ is set during grad by calling
-    // AcquireForwardPrimitiveDescriptorNonBlocking
+    // AcquireForwardPrimitiveDescriptor
     PADDLE_ENFORCE_NOT_NULL(
         fwd_pd_,
         platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed.",
@@ -834,45 +792,100 @@ class ActivationMKLDNNHandler
     : public MKLDNNHandlerT<T, mkldnn::eltwise_forward,
                             mkldnn::eltwise_backward> {
  public:
-  ActivationMKLDNNHandler(const std::vector<int64_t>& dims,
-                          mkldnn::algorithm algorithm, float alpha, float beta,
-                          const MKLDNNMemoryFormat fmt,
-                          const platform::MKLDNNDeviceContext& dev_ctx,
-                          platform::Place cpu_place,
+  ActivationMKLDNNHandler(mkldnn::algorithm algorithm,
+                          const framework::ExecutionContext& ctx,
+                          const MKLDNNDeviceContext& dev_ctx, Place cpu_place,
+                          const framework::Tensor* in_x,
                           const std::string& unique_name, bool is_inplaced)
       : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
                                  mkldnn::eltwise_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            is_inplaced
-                ? platform::CreateKey(dev_ctx, dims, "a", algorithm,
-                                      unique_name)
-                : platform::CreateKey(dev_ctx, dims, "a", unique_name)) {
-    auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
-
-    this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
-                                            algorithm, md, alpha, beta);
-  }
+            is_inplaced ? platform::CreateKey(
+                              dev_ctx, framework::vectorize(in_x->dims()), "a",
+                              algorithm, unique_name)
+                        : platform::CreateKey(
+                              dev_ctx, framework::vectorize(in_x->dims()), "a",
+                              unique_name)) {
+    if (!this->isCached()) {
+      float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
+      float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
 
-  ActivationMKLDNNHandler(const std::vector<int64_t>& dims,
-                          mkldnn::algorithm algorithm, float alpha, float beta,
-                          const MKLDNNMemoryFormat fmt,
-                          const MKLDNNMemoryFormat diff_fmt,
-                          const platform::MKLDNNDeviceContext& dev_ctx,
-                          platform::Place cpu_place,
-                          const std::string& unique_name)
+      // eltwise_linear means we are in scale op
+      if (algorithm == mkldnn::algorithm::eltwise_linear) {
+        bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
+        auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
+        alpha = (scale_tensor == nullptr) ? ctx.Attr<float>("scale")
+                                          : (float)*(scale_tensor->data<T>());
+        beta = ctx.Attr<float>("bias");
+        // if bias_after_scale == true
+        //    out = scale*X + bias
+        // else
+        //    out = scale*(X + bias) = scale*X + scale*bias
+        if (!bias_after_scale) beta *= alpha;
+      } else {
+        // paddle uses beta but mkldnn uses alpha for swish
+        if (algorithm == mkldnn::algorithm::eltwise_swish) {
+          std::swap(alpha, beta);
+        } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
+          alpha = ctx.Attr<float>("threshold");
+        }
+      }
+
+      PADDLE_ENFORCE(in_x->dims().size() >= 1 || in_x->dims().size() <= 6,
+                     platform::errors::Unimplemented(
+                         "Input dimension size can be 1, 2, 3, 4, "
+                         "5, or 6, but now the dimension size is",
+                         in_x->dims().size()));
+
+      auto src_tz = framework::vectorize<int64_t>(in_x->dims());
+      auto src_fmt =
+          src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format();
+      auto md = mkldnn::memory::desc(src_tz, platform::MKLDNNGetDataType<T>(),
+                                     src_fmt);
+
+      this->AcquireForwardPrimitiveDescriptor(
+          mkldnn::prop_kind::forward_training, algorithm, md, alpha, beta);
+    }
+  }
+
+  ActivationMKLDNNHandler(mkldnn::algorithm algorithm,
+                          const framework::ExecutionContext& ctx,
+                          const MKLDNNDeviceContext& dev_ctx, Place cpu_place,
+                          const framework::Tensor* in_x, const Tensor* out_grad,
+                          const std::string& unique_name)
       : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
                                  mkldnn::eltwise_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dev_ctx, dims, "a", unique_name)) {
-    auto diff_dst_md = platform::MKLDNNMemDesc(
-        dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
-    auto src_md =
-        platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt);
-
-    this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
-                                             alpha, beta);
+            platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
+                                "a", unique_name)) {
+    if (!this->isBwdCached()) {
+      float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
+      float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
+
+      // paddle uses beta but mkldnn uses alpha for swish
+      if (algorithm == mkldnn::algorithm::eltwise_swish) {
+        std::swap(alpha, beta);
+      } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
+        alpha = ctx.Attr<float>("threshold");
+      }
+
+      auto diff_dst_tz = framework::vectorize<int64_t>(out_grad->dims());
+
+      auto src_fmt =
+          diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format();
+      auto diff_fmt =
+          diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : out_grad->format();
+
+      auto dims = framework::vectorize(in_x->dims());
+      auto diff_dst_md = platform::MKLDNNMemDesc(
+          dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
+      auto src_md = platform::MKLDNNMemDesc(
+          dims, platform::MKLDNNGetDataType<T>(), src_fmt);
+
+      this->AcquireForwardPrimitiveDescriptor(
+          mkldnn::prop_kind::forward_training, algorithm, src_md, alpha, beta);
+      this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
+                                               alpha, beta);
+    }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
...
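
The net caching contract after this change: both primitive descriptors live under the thread-aware key_, "@fwd_pd" alongside "@bwd_pd", and a backward cache hit must restore fwd_pd_ too, since the backward descriptor is built against it. A toy model of that lookup over a plain map, with hypothetical stand-in types for the real blob cache in MKLDNNDeviceContext:

    #include <map>
    #include <memory>
    #include <string>

    // Type-erased blob store keyed like dev_ctx: "<key>@fwd_pd", "<key>@bwd_pd".
    std::map<std::string, std::shared_ptr<void>> blobs;

    // Mirrors isBwdCached(): a BWD hit without a matching FWD PD is an error
    // in the real code (PADDLE_ENFORCE_NOT_NULL); here we just report a miss.
    bool is_bwd_cached(const std::string& key) {
      if (blobs.find(key + "@bwd_pd") == blobs.end()) return false;
      return blobs.find(key + "@fwd_pd") != blobs.end();
    }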
@@ -115,4 +115,6 @@ class TestMKLDNNBatchNormOpWithReluInference(TestBatchNormOpInference):
 
 if __name__ == '__main__':
+    from paddle import enable_static
+    enable_static()
     unittest.main()