未验证 提交 905b0765 编写于 作者: X xiaoting 提交者: GitHub

rm max_input in conv2d for kunlun, test=kunlun (#28063)

上级 8600f474
......@@ -27,10 +27,10 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
// that avoids modifying the variable in the Scope.
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
Tensor* max_input = context.Output<Tensor>("MaxInput");
Tensor* max_filter = context.Output<Tensor>("MaxFilter");
max_input->mutable_data<T>(context.GetPlace());
max_filter->mutable_data<T>(context.GetPlace());
// Tensor* max_input = context.Output<Tensor>("MaxInput");
// Tensor* max_filter = context.Output<Tensor>("MaxFilter");
// max_input->mutable_data<T>(context.GetPlace());
// max_filter->mutable_data<T>(context.GetPlace());
output->mutable_data<T>(context.GetPlace());
int groups = context.Attr<int>("groups");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
......@@ -47,28 +47,28 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
dilations[0] == 1 && dilations[1] == 1, true,
platform::errors::InvalidArgument("XPU only support dilation == 1."));
auto& dev_ctx = context.template device_context<DeviceContext>();
PADDLE_ENFORCE_EQ(
xpu::findmax(dev_ctx.x_context(), input->data<T>(), input->numel(),
max_input->data<T>()) == xpu::Error_t::SUCCESS,
true, platform::errors::InvalidArgument(
"XPU conv kernel error,can not finde max_input,please "
"check whether Baidu Kunlun "
"Card is properly installed."));
PADDLE_ENFORCE_EQ(
xpu::findmax(dev_ctx.x_context(), filter.data<T>(), filter.numel(),
max_filter->data<T>()) == xpu::Error_t::SUCCESS,
true, platform::errors::InvalidArgument(
"XPU conv kernel error,can not find max_filter,please "
"check whether Baidu Kunlun "
"Card is properly installed."));
// PADDLE_ENFORCE_EQ(
// xpu::findmax(dev_ctx.x_context(), input->data<T>(), input->numel(),
// max_input->data<T>()) == xpu::Error_t::SUCCESS,
// true, platform::errors::InvalidArgument(
// "XPU conv kernel error,can not finde max_input,please "
// "check whether Baidu Kunlun "
// "Card is properly installed."));
// PADDLE_ENFORCE_EQ(
// xpu::findmax(dev_ctx.x_context(), filter.data<T>(), filter.numel(),
// max_filter->data<T>()) == xpu::Error_t::SUCCESS,
// true, platform::errors::InvalidArgument(
// "XPU conv kernel error,can not find max_filter,please "
// "check whether Baidu Kunlun "
// "Card is properly installed."));
if (groups == 1) {
int r = xpu::conv2d_forward_int16<float, float, float, float>(
dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
strides[0], strides[1], paddings[0], paddings[1], dilations[0],
dilations[1], groups, input->data<float>(), filter.data<float>(),
output->data<float>(), nullptr, nullptr, xpu::Activation_t::LINEAR,
// nullptr, nullptr);
max_input->data<float>(), max_filter->data<float>());
nullptr, nullptr);
// max_input->data<float>(), max_filter->data<float>());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d], "
......@@ -80,8 +80,8 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
output->data<float>(), batch_size, img_c, img_h, img_w, f, win_h,
win_w, groups, strides[0], strides[1], paddings[0], paddings[1],
// nullptr, nullptr);
max_input->data<float>(), max_filter->data<float>());
nullptr, nullptr);
// max_input->data<float>(), max_filter->data<float>());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d], "
......@@ -96,9 +96,9 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* max_input = context.Input<Tensor>("MaxInput");
const Tensor* max_filter = context.Input<Tensor>("MaxFilter");
Tensor* max_output_grad = context.Output<Tensor>("MaxOutputGrad");
// const Tensor* max_input = context.Input<Tensor>("MaxInput");
// const Tensor* max_filter = context.Input<Tensor>("MaxFilter");
// Tensor* max_output_grad = context.Output<Tensor>("MaxOutputGrad");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad =
......@@ -133,25 +133,25 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
filter_grad->mutable_data<T>(context.GetPlace());
}
auto& dev_ctx = context.template device_context<DeviceContext>();
max_output_grad->Resize({4});
max_output_grad->mutable_data<T>(context.GetPlace());
PADDLE_ENFORCE_EQ(
xpu::findmax(dev_ctx.x_context(), output_grad->data<T>(),
output_grad->numel(),
max_output_grad->data<T>()) == xpu::Error_t::SUCCESS,
true,
platform::errors::External(
"XPU conv kernel error, can not find max_output_grad, please check "
"whether Baidu Kunlun Card is "
"properly installed."));
// max_output_grad->Resize({4});
// max_output_grad->mutable_data<T>(context.GetPlace());
// PADDLE_ENFORCE_EQ(
// xpu::findmax(dev_ctx.x_context(), output_grad->data<T>(),
// output_grad->numel(),
// max_output_grad->data<T>()) == xpu::Error_t::SUCCESS,
// true,
// platform::errors::External(
// "XPU conv kernel error, can not find max_output_grad, please
// check "
// "whether Baidu Kunlun Card is "
// "properly installed."));
if (input_grad) {
int r = xpu::conv2d_backward_int16(
dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
strides[0], strides[1], paddings[0], paddings[1], dilations[0],
dilations[1], groups, output_grad->data<float>(),
filter.data<float>(), input_grad->data<float>(),
// nullptr, nullptr,
max_output_grad->data<float>(), max_filter->data<float>());
filter.data<float>(), input_grad->data<float>(), nullptr, nullptr);
// max_output_grad->data<float>(), max_filter->data<float>());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d], "
......@@ -164,9 +164,8 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
strides[0], strides[1], paddings[0], paddings[1], dilations[0],
dilations[1], groups, output_grad->data<float>(),
input->data<float>(), filter_grad->data<float>(),
// nullptr, nullptr,
max_output_grad->data<float>(), max_input->data<float>());
input->data<float>(), filter_grad->data<float>(), nullptr, nullptr);
// max_output_grad->data<float>(), max_input->data<float>());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d], "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册