提交 408e638c 编写于 作者: Y Yihua Xu 提交者: Tao Luo

Fix the crash issue when scale or bias was null-pointer. (#21284) (#21444)

* Fix the crash issue when scale or bias was null-pointer.

* Add the error message for passing CI.

test=release/1.6
上级 77268831
...@@ -224,17 +224,35 @@ class LayerNormKernel : public framework::OpKernel<T> { ...@@ -224,17 +224,35 @@ class LayerNormKernel : public framework::OpKernel<T> {
ctx, &out, bias, /*axis*/ 1, AddFunctor<T>(), &out); ctx, &out, bias, /*axis*/ 1, AddFunctor<T>(), &out);
} }
#else #else
PADDLE_ENFORCE_EQ(mean->numel(), left); PADDLE_ENFORCE_EQ(mean->numel(), left,
PADDLE_ENFORCE_EQ(var->numel(), left); platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(scale->numel(), right); "mean's length (%d) is not equal with expected (%d).",
PADDLE_ENFORCE_EQ(bias->numel(), right); mean->numel(), left));
PADDLE_ENFORCE_EQ(var->numel(), left,
platform::errors::InvalidArgument(
"var's length (%d) is not equal with expected (%d).",
var->numel(), left));
if (scale) {
PADDLE_ENFORCE_EQ(
scale->numel(), right,
platform::errors::InvalidArgument(
"scale's length (%d) is not equal with expected (%d).",
scale->numel(), right));
}
if (bias) {
PADDLE_ENFORCE_EQ(
bias->numel(), right,
platform::errors::InvalidArgument(
"bias's length (%d) is not equal with expected (%d).",
bias->numel(), right));
}
auto ker = auto ker =
jit::KernelFuncs<jit::LayerNormTuple<T>, platform::CPUPlace>::Cache() jit::KernelFuncs<jit::LayerNormTuple<T>, platform::CPUPlace>::Cache()
.At(right); .At(right);
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(), ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
scale->data<T>(), bias->data<T>(), static_cast<int>(left), scale ? scale->data<T>() : nullptr, bias ? bias->data<T>() : nullptr,
static_cast<const float>(epsilon), right); static_cast<int>(left), static_cast<const float>(epsilon), right);
#endif #endif
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册