未验证 提交 ead81230 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Fix reshape move storage using error (#37765)

* fix reshape move storage error

* remove needless set type

* alloc tensor by shared storage
上级 1bdb8578
......@@ -383,13 +383,13 @@ class ReshapeKernel {
// 3. out tensor is view of input
// We can't MakePtenDenseTensor for case 2, so we solve this case by
// creating a temporary tensor here:
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
ctx.GetPlace());
pten::DenseTensorMeta meta{pten::TransToPtenDataType(in->type()),
in->dims(),
pten::TransToPtenDataLayout(in->layout())};
auto pt_out_tmp =
std::make_shared<pten::DenseTensor>(alloc, std::move(meta));
auto pt_out_tmp = std::make_shared<pten::DenseTensor>(
pten::make_intrusive<paddle::experimental::SharedStorage>(
ctx.GetPlace()),
std::move(meta));
pten::DenseTensor *pt_out = nullptr;
if (in == out) {
pt_out = pt_x.get();
......@@ -484,7 +484,8 @@ class ReshapeKernel {
// non-inplace need move all result from pt_out to out, inplace need set
// result dims.
if (in != out) {
paddle::experimental::MovesStorage(pt_out, static_cast<Tensor *>(out));
paddle::experimental::MovesSharedStorage(pt_out,
static_cast<Tensor *>(out));
} else {
out->Resize(pt_out->dims());
}
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include <utility>
#include <vector>
#include "paddle/pten/core/compat_utils.h"
......@@ -342,6 +343,29 @@ void MovesStorage(pten::DenseTensor* src, paddle::framework::LoDTensor* dst) {
MovesStorage(src, static_cast<paddle::framework::Tensor*>(dst));
}
void MovesSharedStorage(pten::DenseTensor* src,
paddle::framework::Tensor* dst) {
PADDLE_ENFORCE_NOT_NULL(
src,
platform::errors::InvalidArgument(
"The source DenseTensor is nullptr when move allocation."));
PADDLE_ENFORCE_NOT_NULL(
dst,
platform::errors::InvalidArgument(
"The destination Tensor is nullptr when move allocation."));
dst->Resize(src->dims());
auto* storage = static_cast<SharedStorage*>(
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src));
dst->ResetHolderWithType(storage->GetAllocation(),
pten::TransToProtoVarType(src->dtype()));
}
void MovesSharedStorage(pten::DenseTensor* src,
paddle::framework::LoDTensor* dst) {
MovesSharedStorage(src, static_cast<paddle::framework::Tensor*>(dst));
SetLoD(dst->mutable_lod(), src->lod());
}
void ReMakePtenDenseTensor(const paddle::framework::Tensor& src,
const pten::TensorArgDef& arg_def,
pten::DenseTensor* dst) {
......
......@@ -58,6 +58,11 @@ void MovesStorage(pten::DenseTensor* src, paddle::framework::Tensor* dst);
void MovesStorage(pten::DenseTensor* src, paddle::framework::LoDTensor* dst);
void MovesSharedStorage(pten::DenseTensor* src, paddle::framework::Tensor* dst);
void MovesSharedStorage(pten::DenseTensor* src,
paddle::framework::LoDTensor* dst);
/**
* In order to improve the compatibility state performance, some tricky tool
* functions are added.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册