Unverified commit d71b9ba7 authored by Leo Chen, committed by GitHub

[NPU] Avoid cpu tensor freed before copying to npu completed (#34475)

Parent 76710e5f
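Background for the change below: `TensorCopy` from CPU to NPU enqueues an asynchronous copy on the NPU stream, but the source CPU tensor may be a temporary that is freed as soon as the call returns, before the stream has actually read it. The fix stages the data in NPU pinned memory and records a stream event so the staging buffer outlives the in-flight copy. A minimal, runnable C++ sketch of that pattern, where `std::async` stands in for the NPU stream and none of the names are Paddle APIs:

```cpp
#include <cstring>
#include <future>
#include <memory>
#include <vector>

// Models an async device copy on a "stream": ownership of the staging
// buffer travels with the task, so the buffer cannot be freed while the
// copy is still in flight (the role RecordEvent plays in the real change).
std::future<void> copy_to_device(std::vector<float>* device,
                                 std::shared_ptr<std::vector<float>> staging) {
  return std::async(std::launch::async, [device, staging] {
    std::memcpy(device->data(), staging->data(),
                staging->size() * sizeof(float));
  });
}

int main() {
  std::vector<float> device(4);
  std::future<void> done;
  {
    std::vector<float> host{1.f, 2.f, 3.f, 4.f};
    // 1. synchronous copy: host tensor -> staging ("pinned") buffer
    auto staging = std::make_shared<std::vector<float>>(host);
    // 2. async copy: staging buffer -> device. `host` may be freed as soon
    //    as this returns, because the task only ever reads `staging`.
    done = copy_to_device(&device, std::move(staging));
  }  // `host` is destroyed here; the copy may still be running -- safely.
  // 3. "event": wait for completion; the task releases `staging` when done.
  done.wait();
}
```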
@@ -25,7 +25,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/profiler.h"
 #ifdef PADDLE_WITH_MKLDNN
-#include "dnnl_debug.h"
+#include "dnnl_debug.h"  // NOLINT
 #endif

 namespace paddle {
@@ -112,11 +112,32 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
   }
   else if (platform::is_cpu_place(src_place) &&  // NOLINT
            platform::is_npu_place(dst_place)) {
-    auto stream =
-        reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream();
-    memory::Copy(BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr,
-                 BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size,
-                 stream);
+    // 1. cpu tensor -> npu pinned tensor
+    platform::NPUPinnedPlace npu_pinned_place;
+    Tensor npu_pinned_tensor;
+    npu_pinned_tensor.Resize(src.dims());
+    auto npu_pinned_ptr =
+        npu_pinned_tensor.mutable_data(npu_pinned_place, src.type());
+    memory::Copy(npu_pinned_place, npu_pinned_ptr,
+                 BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size);
+
+    // 2. async copy npu pinned tensor -> npu tensor
+    memory::Copy(
+        BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr,
+        npu_pinned_place, npu_pinned_ptr, size,
+        reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
+
+    // 3. record event
+    auto npu_pinned_allocator =
+        static_cast<paddle::memory::allocation::NPUPinnedAllocator*>(
+            paddle::memory::allocation::AllocatorFacade::Instance()
+                .GetAllocator(npu_pinned_place)
+                .get());
+    paddle::memory::allocation::Allocation* allocation =
+        npu_pinned_tensor.Holder().get();
+    npu_pinned_allocator->RecordEvent(
+        allocation,
+        reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
   }
   else if (platform::is_npu_place(src_place) &&  // NOLINT
            platform::is_npu_place(dst_place)) {
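Step 3 is the heart of the fix: `npu_pinned_tensor` is a local, so its allocation would normally be returned to the allocator when `TensorCopy` returns, which can still be before the stream finishes the async copy in step 2. `RecordEvent` tells the pinned allocator to defer the real free until the recorded stream event has completed. A hedged sketch of that deferred-free idea; the class, the `Event` type, and the method names below are illustrative, not the actual `NPUPinnedAllocator` API:

```cpp
#include <functional>
#include <unordered_map>
#include <vector>

// Stand-in for an NPU stream event; completed() polls whether the stream
// has passed the point where the event was recorded.
struct Event {
  std::function<bool()> completed;
};

class DeferredFreeAllocator {
 public:
  // Called right after the async copy is enqueued (step 3 in the diff).
  void RecordEvent(void* allocation, Event e) {
    events_[allocation] = std::move(e);
  }

  // Called when the tensor holder dies. If a recorded event has not yet
  // completed, the block is parked instead of being returned to the pool.
  void Free(void* allocation) {
    auto it = events_.find(allocation);
    if (it != events_.end() && !it->second.completed()) {
      pending_.push_back(allocation);  // copy still in flight: defer
      return;
    }
    ReallyFree(allocation);
  }

  // Reap parked blocks whose events have since completed (would be called
  // periodically, e.g. on the next allocation).
  void Reap() {
    for (auto it = pending_.begin(); it != pending_.end();) {
      if (events_[*it].completed()) {
        ReallyFree(*it);
        it = pending_.erase(it);
      } else {
        ++it;
      }
    }
  }

 private:
  void ReallyFree(void* /*p*/) { /* return the block to the pinned pool */ }
  std::unordered_map<void*, Event> events_;
  std::vector<void*> pending_;
};
```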
@@ -40,10 +40,6 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument("npu only accept LoDTensor"));
     output_t->mutable_data<T>(ctx.GetPlace());

-    // add copy ids to ensure ids_t is prepared.
-    std::vector<int> ids;
-    TensorToVector(*ids_t, ctx.device_context(), &ids);
-
     NpuOpRunner runner;
     runner.SetType("GatherV2")
         .AddInput(*table_t)
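This second hunk removes a workaround rather than adding behavior: the synchronous `TensorToVector` round trip existed only to force `ids_t` to be materialized before launching `GatherV2`. With `TensorCopy` itself now staging through pinned memory and recording an event, the input is guaranteed to stay valid until the device has consumed it, so the extra host copy is redundant and the kernel can feed `ids_t` to the runner directly.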