未验证 提交 85e5ab2e 编写于 作者: F furnace 提交者: GitHub

[NPU] add int64 support for scatter op (#37440)

* [NPU] add int64 support for scatter op

* [NPU] delete debug codes

* [NPU] optimize codes
上级 ddf38a3f
......@@ -48,18 +48,49 @@ class ScatterNPUKernel : public framework::OpKernel<T> {
index = &tmp_tensor;
}
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
const auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto op_func_update = [](const std::vector<Tensor>& inputs,
const std::vector<Tensor>& outputs,
const NPUAttributeMap& attrs,
const platform::NPUDeviceContext& dev_ctx) {
const auto& runner =
NpuOpRunner("TensorScatterUpdate", inputs, outputs, attrs);
runner.Run(dev_ctx.stream());
};
auto op_func_add = [](const std::vector<Tensor>& inputs,
const std::vector<Tensor>& outputs,
const NPUAttributeMap& attrs,
const platform::NPUDeviceContext& dev_ctx) {
const auto& runner =
NpuOpRunner("TensorScatterAdd", inputs, outputs, attrs);
runner.Run(dev_ctx.stream());
};
if (overwrite) {
const auto& runner_update = NpuOpRunner(
"TensorScatterUpdate", {*x, *index, *updates}, {*out}, {});
runner_update.Run(stream);
if (x->type() == framework::proto::VarType::INT64) {
NpuOpRunner::TypeAdapter(
{*x, *index, *updates}, {*out}, {}, dev_ctx, op_func_update,
{framework::proto::VarType::INT32, framework::proto::VarType::INT32,
framework::proto::VarType::INT32},
{framework::proto::VarType::INT32});
} else {
const auto& runner_update = NpuOpRunner(
"TensorScatterUpdate", {*x, *index, *updates}, {*out}, {});
runner_update.Run(dev_ctx.stream());
}
} else {
const auto& runner_add =
NpuOpRunner("TensorScatterAdd", {*x, *index, *updates}, {*out}, {});
runner_add.Run(stream);
if (x->type() == framework::proto::VarType::INT64) {
NpuOpRunner::TypeAdapter(
{*x, *index, *updates}, {*out}, {}, dev_ctx, op_func_add,
{framework::proto::VarType::INT32, framework::proto::VarType::INT32,
framework::proto::VarType::INT32},
{framework::proto::VarType::INT32});
} else {
const auto& runner_add =
NpuOpRunner("TensorScatterAdd", {*x, *index, *updates}, {*out}, {});
runner_add.Run(dev_ctx.stream());
}
}
}
};
......@@ -70,6 +101,10 @@ namespace ops = paddle::operators;
// Register the scatter kernel on NPU for the supported element types:
// float, int32, float16, and — only when the Ascend toolkit exposes native
// int64 support (PADDLE_WITH_ASCEND_INT64) — int64. The int64 variant is
// guarded because, per the kernel body above, int64 inputs are routed
// through NpuOpRunner::TypeAdapter and executed as int32 casts.
// NOTE(review): index tensors are assumed int32 regardless of element type
// — confirm against the op's InferShape/kernel-selection code.
REGISTER_OP_NPU_KERNEL(
scatter, ops::ScatterNPUKernel<paddle::platform::NPUDeviceContext, float>,
#ifdef PADDLE_WITH_ASCEND_INT64
ops::ScatterNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
#endif
ops::ScatterNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::ScatterNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
#endif
......@@ -27,7 +27,7 @@ paddle.enable_static()
SEED = 2021
class TestCast1(OpTest):
class TestCast1_FP32(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "scatter"
......@@ -50,7 +50,7 @@ class TestCast1(OpTest):
self.check_output_with_place(self.place)
class TestCast2(OpTest):
class TestCast_INT32(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "scatter"
......@@ -73,7 +73,7 @@ class TestCast2(OpTest):
self.check_output_with_place(self.place)
class TestCast3(OpTest):
class TestCast2_FP32(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "scatter"
......@@ -96,7 +96,7 @@ class TestCast3(OpTest):
self.check_output_with_place(self.place)
class TestCast4(OpTest):
class TestCast3_FP32(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "scatter"
......@@ -120,5 +120,28 @@ class TestCast4(OpTest):
self.check_output_with_place(self.place)
class TestCast_INT64(OpTest):
    """Exercise the NPU scatter kernel's int64 path.

    X and Updates are int64 tensors, Ids is an int32 index, and
    overwrite=True selects the TensorScatterUpdate behavior; the expected
    output is computed with plain NumPy fancy-index assignment.
    """

    def setUp(self):
        self.set_npu()
        self.op_type = "scatter"
        self.place = paddle.NPUPlace(0)

        # int64 source/update data with an int32 index, mirroring the
        # kernel's int64 code path.
        x = np.ones((3, 2)).astype("int64")
        ids = np.array([1]).astype("int32")
        updates = np.zeros((1, 2)).astype("int64")

        # Reference result: row 1 of x is overwritten by the update row.
        expected = np.copy(x)
        expected[ids] = updates

        self.inputs = {'X': x, 'Ids': ids, 'Updates': updates}
        self.outputs = {'Out': expected}
        self.attrs = {'overwrite': True}

    def set_npu(self):
        # Mark the test class as NPU-enabled for the OpTest machinery.
        self.__class__.use_npu = True

    def test_check_output(self):
        self.check_output_with_place(self.place)
# Run the unittest discovery/runner when executed as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册