Unverified commit f9ce1b1a authored by Jacek Czaja, committed by GitHub

[oneDNN] Further ops refactoring of oneDNN cache access (#33515)

* - Draft of implementation of refactoring

- compilation fix

* - Fixes after review

* - Removed unnecessary comment
Parent 4ddd595f
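
At a glance, this refactor moves attribute parsing and primitive-descriptor creation out of the op kernels and into the handler constructors, guarded by a cache check so that work runs only on a cache miss. Below is a minimal sketch of the pattern, distilled from the diff itself; MyOpMKLDNNHandler is a hypothetical name, and the real handlers take additional parameters:

    // Hypothetical handler illustrating the refactored cache-access pattern.
    template <typename T>
    class MyOpMKLDNNHandler
        : public platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
                                          mkldnn::eltwise_backward> {
     public:
      MyOpMKLDNNHandler(mkldnn::algorithm algorithm,
                        const framework::ExecutionContext& ctx,
                        const platform::MKLDNNDeviceContext& dev_ctx,
                        platform::Place cpu_place,
                        const framework::Tensor* in_x,
                        const std::string& unique_name)
          : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
                                     mkldnn::eltwise_backward>(
                dev_ctx, dev_ctx.GetEngine(), cpu_place,
                platform::CreateKey(dev_ctx,
                                    framework::vectorize(in_x->dims()),
                                    unique_name)) {
        // Descriptor creation is skipped entirely when the primitive
        // descriptor is already in the per-device cache.
        if (!this->isCached()) {
          auto md = mkldnn::memory::desc(
              framework::vectorize<int64_t>(in_x->dims()),
              platform::MKLDNNGetDataType<T>(), in_x->format());
          this->AcquireForwardPrimitiveDescriptor(
              mkldnn::prop_kind::forward_training, algorithm, md,
              /*alpha=*/0.0f, /*beta=*/0.0f);
        }
      }
    };
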
......@@ -83,30 +83,11 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
const auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Output<Tensor>("Out");
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// paddle uses beta but mkldnn uses alpha for swish
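// (Paddle defines swish(x) = x * sigmoid(beta * x), while oneDNN's
// eltwise_swish takes that coefficient as alpha, hence the swap below.)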
if (algorithm == mkldnn::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold");
}
PADDLE_ENFORCE(
    x->dims().size() >= 1 && x->dims().size() <= 6,
    platform::errors::Unimplemented("Input dimension size can be 1, 2, 3, 4, "
                                    "5, or 6, but now the dimension size is %d",
                                    x->dims().size()));
bool is_inplaced = x->IsSharedBufferWith(*y);
auto src_tz = framework::vectorize<int64_t>(x->dims());
auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
platform::ActivationMKLDNNHandler<T> handler(
src_tz, algorithm, alpha, beta, src_format, dev_ctx, ctx.GetPlace(),
ctx.InputName("X"), is_inplaced);
platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, dev_ctx,
ctx.GetPlace(), x,
ctx.InputName("X"), is_inplaced);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
......@@ -130,28 +111,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == mkldnn::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold");
}
auto diff_dst_tz = framework::vectorize<int64_t>(diff_y->dims());
// diff_dst and src dims should be the same
auto src_format =
diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
auto diff_y_format =
diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : diff_y->format();
platform::ActivationMKLDNNHandler<T> handler(
diff_dst_tz, algorithm, alpha, beta, src_format, diff_y_format, dev_ctx,
ctx.GetPlace(), ctx.InputName("X"));
algorithm, ctx, dev_ctx, ctx.GetPlace(), x, diff_y, ctx.InputName("X"));
auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
......
......@@ -85,24 +85,54 @@ class BatchNormMKLDNNHandler
md, epsilon, flags);
}
}
BatchNormMKLDNNHandler(const std::vector<int64_t> &dims, const float &epsilon,
const mkldnn::normalization_flags &flags,
const MKLDNNMemoryFormat diff_fmt,
const MKLDNNMemoryFormat src_fmt,
BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx,
const platform::MKLDNNDeviceContext &dev_ctx,
platform::Place cpu_place,
const std::string &uniq_name)
platform::Place cpu_place, const Tensor *in_x,
const Tensor *scale, const Tensor *out_grad,
const std::string &unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, dims, uniq_name)) {
auto diff_dst_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
this->AcquireBackwardPrimitiveDescriptor(
mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon, flags);
platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
unique_name)) {
if (!this->isBwdCached()) {
PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for Input out_grad tensor"));
PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for Input out_grad tensor"));
auto src_tz = paddle::framework::vectorize<int64_t>(in_x->dims());
auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
PADDLE_ENFORCE_EQ(
scale_tz.size(), 1,
platform::errors::InvalidArgument(
"Dims of scale tensor must be 1, but received scale's size is %d",
scale_tz.size()));
MKLDNNMemoryFormat diff_fmt =
platform::MKLDNNFormatForSize(src_tz.size(), out_grad->format());
MKLDNNMemoryFormat src_fmt =
platform::MKLDNNFormatForSize(src_tz.size(), in_x->format());
auto dims = framework::vectorize(in_x->dims());
auto diff_dst_md = mkldnn::memory::desc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), src_fmt);
const float epsilon = ctx.Attr<float>("epsilon");
this->AcquireForwardPrimitiveDescriptor(
mkldnn::prop_kind::forward_training, src_md, epsilon,
mkldnn::normalization_flags::use_scale_shift);
this->AcquireBackwardPrimitiveDescriptor(
mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon,
mkldnn::normalization_flags::use_scale_shift);
}
}
std::shared_ptr<mkldnn::memory> AcquireScaleShiftMemory(const Tensor *scale,
......@@ -263,8 +293,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const float epsilon = ctx.Attr<float>("epsilon");
const auto *x = ctx.Input<Tensor>("X");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *shift = ctx.Input<Tensor>("Bias");
......@@ -275,35 +303,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for Input diff_y tensor"));
PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for Input diff_y tensor"));
auto src_tz = paddle::framework::vectorize<int64_t>(x->dims());
auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
PADDLE_ENFORCE_EQ(
scale_tz.size(), 1,
platform::errors::InvalidArgument(
"Dims of scale tensor must be 1, but received scale's size is %d",
scale_tz.size()));
const unsigned int C = scale_tz[0];
MKLDNNMemoryFormat dst_format =
platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
MKLDNNMemoryFormat input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
BatchNormMKLDNNHandler<T> handler(
src_tz, epsilon, mkldnn::normalization_flags::use_scale_shift,
dst_format, input_format, dev_ctx, ctx.GetPlace(),
ctx.InputName("SavedMean"));
BatchNormMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), x, scale,
diff_y, ctx.InputName("SavedMean"));
// MKLDNN requires a single piece of memory for scale and shift/bias data
const unsigned int C = paddle::framework::vectorize(scale->dims())[0];
const size_t scaleshift_size = 2 * C;
std::vector<T> diff_scaleshift_data;
diff_scaleshift_data.reserve(scaleshift_size);
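// layout of diff_scaleshift_data: the first C values hold the scale
// gradients and the next C values hold the shift/bias gradients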
......@@ -335,7 +339,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
// copy back diff sacle/shift to output tensors (diff scale/shift)
// copy back diff scale/shift to output tensors (diff scale/shift)
diff_scaleshift_data.resize(scaleshift_size);
auto it = std::begin(diff_scaleshift_data);
std::copy(it, std::next(it, C), diff_scale_data);
......
......@@ -90,7 +90,7 @@ class ConvMKLDNNHandlerT
dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) {
if (!this->isCachedNonBlocking()) {
if (!this->isCached()) {
PADDLE_ENFORCE_EQ(
input->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
......@@ -228,12 +228,12 @@ class ConvMKLDNNHandlerT
auto bias_md =
platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);
this->AcquireForwardPrimitiveDescriptorNonBlocking(
this->AcquireForwardPrimitiveDescriptor(
conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
src_md, weights_md, bias_md, dst_md, stride_dims, dilations_dims,
mkldnn_paddings[0], mkldnn_paddings[1]);
} else {
this->AcquireForwardPrimitiveDescriptorNonBlocking(
this->AcquireForwardPrimitiveDescriptor(
conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
src_md, weights_md, dst_md, stride_dims, dilations_dims,
mkldnn_paddings[0], mkldnn_paddings[1]);
......@@ -352,25 +352,25 @@ class ConvMKLDNNHandlerT
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, mkldnn::memory::data_type::f32, MKLDNNMemoryFormat::x);
this->AcquireForwardPrimitiveDescriptorNonBlocking(
this->AcquireForwardPrimitiveDescriptor(
conv_attr, mkldnn::prop_kind::forward_training,
dnnl::algorithm::convolution_direct, src_md, weights_md, bias_md,
dst_md, stride_dims, dilations_dims, mkldnn_paddings[0],
mkldnn_paddings[1]);
} else {
this->AcquireForwardPrimitiveDescriptorNonBlocking(
this->AcquireForwardPrimitiveDescriptor(
conv_attr, mkldnn::prop_kind::forward_training,
dnnl::algorithm::convolution_direct, src_md, weights_md, dst_md,
stride_dims, dilations_dims, mkldnn_paddings[0],
mkldnn_paddings[1]);
}
this->AcquireBackwardPrimitiveDescriptorNonBlocking(
this->AcquireBackwardPrimitiveDescriptor(
mkldnn::algorithm::convolution_direct, diff_src_md, weights_md,
diff_dst_md, strides, dilations_dims, mkldnn_paddings[0],
mkldnn_paddings[1]);
this->AcquireBackwardWeightsPrimitiveDescriptorNonBlocking(
this->AcquireBackwardWeightsPrimitiveDescriptor(
mkldnn::algorithm::convolution_direct, src_md, diff_weights_md,
diff_dst_md, strides, dilations_dims, mkldnn_paddings[0],
mkldnn_paddings[1]);
......
......@@ -34,7 +34,7 @@ class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) {
if (!this->isCachedNonBlocking()) {
if (!this->isCached()) {
const int n = ctx.Attr<int>("n");
// MKL-DNN implements LRN in a caffe way:
// http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
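// i.e. dst(c) = src(c) / (k + (alpha / n) * sum_{i in window} src(i)^2)^beta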
......@@ -52,7 +52,7 @@ class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
auto src_md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptorNonBlocking(
this->AcquireForwardPrimitiveDescriptor(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
......@@ -86,11 +86,11 @@ class LRNMKLDNNHandler : public platform::MKLDNNHandlerT<T, mkldnn::lrn_forward,
auto diff_md = mkldnn::memory::desc(
dims, platform::MKLDNNGetDataType<T>(), out_grad->format());
this->AcquireForwardPrimitiveDescriptorNonBlocking(
this->AcquireForwardPrimitiveDescriptor(
mkldnn::prop_kind::forward_training,
mkldnn::algorithm::lrn_across_channels, src_md, n, alpha, beta, k);
this->AcquireBackwardPrimitiveDescriptorNonBlocking(
this->AcquireBackwardPrimitiveDescriptor(
mkldnn::algorithm::lrn_across_channels, src_md, diff_md, n, alpha,
beta, k);
}
......
......@@ -43,7 +43,7 @@ class PoolingMKLDNNHandler
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
framework::ToMKLDNNDataType(input->type()),
unique_name)) {
if (!this->isCachedNonBlocking()) {
if (!this->isCached()) {
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for Input tensor."));
......@@ -123,7 +123,7 @@ class PoolingMKLDNNHandler
ComputeAdaptivePoolParameters(ctx, src_tz, &ksize, &strides);
this->AcquireForwardPrimitiveDescriptorNonBlocking(
this->AcquireForwardPrimitiveDescriptor(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
pooling_type == "max"
......@@ -220,7 +220,7 @@ class PoolingMKLDNNHandler
const auto exclude_padding = ctx.Attr<bool>("exclusive");
this->AcquireForwardPrimitiveDescriptorNonBlocking(
this->AcquireForwardPrimitiveDescriptor(
mkldnn::prop_kind::forward_training,
pooling_type == "max"
? mkldnn::algorithm::pooling_max
......@@ -230,7 +230,7 @@ class PoolingMKLDNNHandler
src_md, dst_md, strides, ksize, mkldnn_paddings[0],
mkldnn_paddings[1]);
this->AcquireBackwardPrimitiveDescriptorNonBlocking(
this->AcquireBackwardPrimitiveDescriptor(
pooling_type == "max"
? mkldnn::algorithm::pooling_max
: (exclude_padding
......
......@@ -30,28 +30,14 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
float scale = (scale_tensor == nullptr) ? ctx.Attr<float>("scale")
: (float)*(scale_tensor->data<T>());
float bias = ctx.Attr<float>("bias");
// if bias_after_scale == true
// out = scale*X + bias
// else
// out = scale*(X + bias) = scale*X + scale*bias
if (!bias_after_scale) bias *= scale;
auto x_tz = framework::vectorize<int64_t>(x->dims());
bool is_inplaced = x->IsSharedBufferWith(*out);
platform::ActivationMKLDNNHandler<T> handler(
x_tz, mkldnn::algorithm::eltwise_linear, scale, bias, x->format(),
dev_ctx, ctx.GetPlace(), ctx.InputName("X"), is_inplaced);
mkldnn::algorithm::eltwise_linear, ctx, dev_ctx, ctx.GetPlace(), x,
ctx.InputName("X"), is_inplaced);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out);
......
......@@ -50,7 +50,7 @@ class SoftmaxMKLDNNHandler
: platform::CreateKey(
dev_ctx, framework::vectorize(input->dims()),
uniq_name)) {
if (!this->isCachedNonBlocking()) {
if (!this->isCached()) {
PADDLE_ENFORCE_EQ(
input->dims(), output->dims(),
platform::errors::InvalidArgument(
......@@ -60,8 +60,8 @@ class SoftmaxMKLDNNHandler
auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptorNonBlocking(
prop_kind::forward_scoring, md, axis);
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
axis);
}
}
......@@ -90,10 +90,10 @@ class SoftmaxMKLDNNHandler
auto diff_softmax_md = MKLDNNMemDesc(
softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
this->AcquireForwardPrimitiveDescriptorNonBlocking(
prop_kind::forward_scoring, data_softmax_md, axis);
this->AcquireBackwardPrimitiveDescriptorNonBlocking(
diff_softmax_md, data_softmax_md, axis);
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
data_softmax_md, axis);
this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
axis);
}
}
};
......
......@@ -118,17 +118,6 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
inline int GetNumInputs(void) { return num_inputs_; }
protected:
// isCached needs to be overloaded because the base one works on key_common
bool isCached() {
const std::string key_pd = this->key_ + "@fwd_pd";
this->fwd_pd_ = std::static_pointer_cast<dnnl::sum::primitive_desc>(
this->dev_ctx_.GetBlob(key_pd));
const std::string key_p = this->key_ + "@fwd_p";
return (this->dev_ctx_.GetBlob(key_p) != nullptr);
}
private:
int num_inputs_;
std::vector<std::string> srcs_suffix_;
......
......@@ -157,15 +157,6 @@ class MKLDNNHandlerT {
protected:
bool isCached() {
const std::string key_pd = key_common_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
const std::string key_p = key_ + "@fwd_p";
return (dev_ctx_.GetBlob(key_p) != nullptr);
}
bool isCachedNonBlocking() {
const std::string key_pd = key_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
......@@ -178,7 +169,18 @@ class MKLDNNHandlerT {
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
return (bwd_pd_ != nullptr);
if (bwd_pd_ == nullptr) {
return false;
} else {
// Even when BWD is cached we still need to fetch the FWD PD
const std::string key_fpd = key_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_fpd));
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_, platform::errors::Unavailable(
"Error: FWD PD should be set when BWD PD is cached."));
return true;
}
}
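// A TBackward::primitive_desc is constructed with the forward
// primitive_desc as a hint (see AcquireBackwardPrimitiveDescriptor:
// `typename TBackward::primitive_desc(bwd_desc, engine_, *fwd_pd_)`),
// which is why fwd_pd_ must be repopulated even on a backward cache hit.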
// If your primitive descriptor requires attributes, pass them as a
......@@ -187,29 +189,6 @@ class MKLDNNHandlerT {
// constructor, including the first one.
template <typename Arg, typename... Args>
void AcquireForwardPrimitiveDescriptor(Arg&& first_arg, Args&&... args) {
// Forward PD has to be passed to the Grad op, which
// may be executed by a different thread; hence
// for that one we use a key that does not contain the TID
const std::string key_pd = key_common_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (fwd_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (fwd_pd_ == nullptr) {
CreateForwardPrimitiveDescriptor(first_arg,
std::forward<Args>(args)...);
dev_ctx_.SetBlob(key_pd, fwd_pd_);
}
}
}
template <typename Arg, typename... Args>
void AcquireForwardPrimitiveDescriptorNonBlocking(Arg&& first_arg,
Args&&... args) {
// This is used when we can recreate FWD PD in BWD so
// we do not need to pass FWD to BWD
const std::string key_pd = key_ + "@fwd_pd";
......@@ -242,31 +221,10 @@ class MKLDNNHandlerT {
std::make_shared<typename TForward::primitive_desc>(fwd_desc, engine_);
}
// TODO(jczaja): After/if all ops can use the xxxNonBlocking version,
// then remove this one
template <typename... Args>
void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
const std::string key_fwd_pd = key_common_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_fwd_pd));
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_, platform::errors::Unavailable(
"Get MKLDNN Forward primitive %s failed.", key_fwd_pd));
const std::string key_pd = key_ + "@bwd_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (bwd_pd_ == nullptr) {
auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
dev_ctx_.SetBlob(key_pd, bwd_pd_);
}
}
template <typename... Args>
void AcquireBackwardPrimitiveDescriptorNonBlocking(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptorNonBlocking
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed.",
......@@ -283,9 +241,9 @@ class MKLDNNHandlerT {
}
template <typename... Args>
void AcquireBackwardWeightsPrimitiveDescriptorNonBlocking(Args&&... args) {
void AcquireBackwardWeightsPrimitiveDescriptor(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptorNonBlocking
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed.",
......@@ -834,45 +792,100 @@ class ActivationMKLDNNHandler
: public MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward> {
public:
ActivationMKLDNNHandler(const std::vector<int64_t>& dims,
mkldnn::algorithm algorithm, float alpha, float beta,
const MKLDNNMemoryFormat fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place,
ActivationMKLDNNHandler(mkldnn::algorithm algorithm,
const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx, Place cpu_place,
const framework::Tensor* in_x,
const std::string& unique_name, bool is_inplaced)
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
is_inplaced
? platform::CreateKey(dev_ctx, dims, "a", algorithm,
unique_name)
: platform::CreateKey(dev_ctx, dims, "a", unique_name)) {
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
algorithm, md, alpha, beta);
}
ActivationMKLDNNHandler(const std::vector<int64_t>& dims,
mkldnn::algorithm algorithm, float alpha, float beta,
const MKLDNNMemoryFormat fmt,
const MKLDNNMemoryFormat diff_fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place,
const std::string& unique_name)
is_inplaced ? platform::CreateKey(
dev_ctx, framework::vectorize(in_x->dims()), "a",
algorithm, unique_name)
: platform::CreateKey(
dev_ctx, framework::vectorize(in_x->dims()), "a",
unique_name)) {
if (!this->isCached()) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// eltwise_linear means we are in scale op
if (algorithm == mkldnn::algorithm::eltwise_linear) {
bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
alpha = (scale_tensor == nullptr) ? ctx.Attr<float>("scale")
: (float)*(scale_tensor->data<T>());
beta = ctx.Attr<float>("bias");
// if bias_after_scale == true
// out = scale*X + bias
// else
// out = scale*(X + bias) = scale*X + scale*bias
if (!bias_after_scale) beta *= alpha;
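// e.g. scale = 2, bias = 3, X = 1: with bias_after_scale,
// out = 2 * 1 + 3 = 5; otherwise beta becomes 6 and out = 2 * 1 + 6 = 8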
} else {
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == mkldnn::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold");
}
}
PADDLE_ENFORCE(in_x->dims().size() >= 1 && in_x->dims().size() <= 6,
               platform::errors::Unimplemented(
                   "Input dimension size can be 1, 2, 3, 4, "
                   "5, or 6, but now the dimension size is %d",
                   in_x->dims().size()));
auto src_tz = framework::vectorize<int64_t>(in_x->dims());
auto src_fmt =
src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format();
auto md = mkldnn::memory::desc(src_tz, platform::MKLDNNGetDataType<T>(),
src_fmt);
this->AcquireForwardPrimitiveDescriptor(
mkldnn::prop_kind::forward_training, algorithm, md, alpha, beta);
}
}
ActivationMKLDNNHandler(mkldnn::algorithm algorithm,
const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx, Place cpu_place,
const framework::Tensor* in_x, const Tensor* out_grad,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, dims, "a", unique_name)) {
auto diff_dst_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md =
platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
alpha, beta);
platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
"a", unique_name)) {
if (!this->isBwdCached()) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == mkldnn::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold");
}
auto diff_dst_tz = framework::vectorize<int64_t>(out_grad->dims());
auto src_fmt =
diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format();
auto diff_fmt =
diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : out_grad->format();
auto dims = framework::vectorize(in_x->dims());
auto diff_dst_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), src_fmt);
this->AcquireForwardPrimitiveDescriptor(
mkldnn::prop_kind::forward_training, algorithm, src_md, alpha, beta);
this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
alpha, beta);
}
}
std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
......
......@@ -115,4 +115,6 @@ class TestMKLDNNBatchNormOpWithReluInference(TestBatchNormOpInference):
if __name__ == '__main__':
from paddle import enable_static
enable_static()
unittest.main()