Unverified commit 90dad8b2, authored by Chen Weihang, committed by GitHub

[PTen] Add cast method for Tensor and rename to method to copy_to (#37423)

* rename to api to copy_to

* support cast method for tensor

* fix compile failed
Parent 73f4601d
@@ -86,16 +86,13 @@ class AbstractAutogradMeta {
class PD_DLL_DECL Tensor final {
public:
/* Part 1: Construction and destruction methods */
/**
* @brief Construct a new Tensor object
*/
Tensor() = default;
/**
* @brief Construct a new Tensor object with name
* */
explicit Tensor(const std::string& name) { name_ = name; }
/**
* @brief Construct a new Tensor object by copy
*/
@@ -132,18 +129,14 @@ class PD_DLL_DECL Tensor final {
Tensor(const PlaceType& place, const std::vector<int64_t>& shape);
/**
* @brief Return the name of Tensor.
*
* @return const std::string&
*/
const std::string& name() const { return name_; }
/**
* @brief Set name of Tensor.
*
* @param const std::string& name
*/
void set_name(const std::string& name) { name_ = name; }
/**
* @brief Construct a new Tensor object with name
*
* @note Used to adapt original execution mechanism and debug analysis
* in the development of new dygraph. It may be removed in the future.
* */
explicit Tensor(const std::string& name) : name_(name) {}
/* Part 2: Dimension, DataType and DataLayout methods */
/**
* @brief Return the number of elements of Tensor.
@@ -179,7 +172,7 @@ class PD_DLL_DECL Tensor final {
/**
* @brief Reset the shape of the tensor.
* Note: This method means Reset the shape of the tensor,
* @note: This method means Reset the shape of the tensor,
* and must be called before calling mutable_data() or
* copy_to(const PlaceType& place), this is not a standard definition of
* reshape behavior, so we will deprecate this feature in the future.
@@ -329,14 +322,33 @@ class PD_DLL_DECL Tensor final {
gpuStream_t stream() const;
#endif
/**
* @brief Return the name of Tensor.
* @note Used to adapt original execution mechanism and debug analysis
* in the development of new dygraph. It may be removed in the future.
*
* @return const std::string&
*/
const std::string& name() const { return name_; }
/**
* @brief Set name of Tensor.
* @note Used to adapt original execution mechanism and debug analysis
* in the development of new dygraph. It may be removed in the future.
*
* @param const std::string& name
*/
void set_name(const std::string& name) { name_ = name; }
/* Part 5: Data Transform methods */
/**
* @brief Copy the current Tensor data to the specified device
* and return the new Tensor. It's usually used to set the input tensor data.
* Note: The Tensor's `copy_to` method is deprecated since version 2.3, and
* will be removed in version 2.4, please use `to` method instead. reason:
* copying a Tensor to another device does not need to specify the
* @note The Tensor's `copy_to` method is deprecated since version 2.3, and
* will be removed in version 2.4, please use `copy_to` method without
* template argument instead.
* reason: copying a Tensor to another device does not need to specify the
* data type template argument
*
* @tparam T
@@ -352,9 +364,8 @@ class PD_DLL_DECL Tensor final {
* @param place, the target place of which the tensor will copy to.
* @return Tensor
*/
// TODO(chenweihang): replace Backend by new Place, may be append dtype and
// layout arguments in the future
Tensor to(Backend backend, bool blocking) const;
// TODO(chenweihang): replace Backend by new Place
Tensor copy_to(Backend backend, bool blocking) const;
/**
* @brief Cast datatype from one to another
@@ -362,7 +373,7 @@
* @param target_type
* @return Tensor
*/
Tensor cast(const DataType& target_type) const;
Tensor cast(DataType target_type) const;
/* Part 6: Status utils methods */
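The two declarations above are the core of this change. A minimal usage sketch, mirroring the tests added later in this commit (the wrapper function name is hypothetical, and the cast/copy kernels are assumed to be registered, as they are in those tests):

#include "paddle/pten/api/include/creation.h"      // paddle::experimental::full
#include "paddle/pten/api/include/manipulation.h"  // cast functional API
#include "paddle/pten/api/include/utils.h"         // copy_to functional API

void TensorCopyToAndCastExample() {
  // Build a 3x4 float32 tensor filled with 1.0, as in TEST(Tensor, cast) below.
  auto x = paddle::experimental::full({3, 4}, 1.0, pten::DataType::FLOAT32);

  // Renamed method (was `to`): copy x to the given backend; the second
  // argument selects blocking vs. non-blocking behavior.
  auto out = x.copy_to(pten::Backend::CPU, /*blocking=*/false);

  // New method: returns a new Tensor whose data is cast to INT32; x itself
  // is left unchanged.
  auto y = x.cast(pten::DataType::INT32);
}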
@@ -470,7 +481,7 @@ class PD_DLL_DECL Tensor final {
std::shared_ptr<AbstractAutogradMeta> autograd_meta_{nullptr};
/**
* Tensor name: used for adapt original execution mechanism and debug analysis
* Tensor name: used to adapt original execution mechanism and debug analysis
* in the development of new dygraph. It may be removed in the future.
*/
std::string name_;
@@ -21,8 +21,7 @@ namespace paddle {
namespace experimental {
// TODO(chenweihang): Replace backend by place when place is ready
// TODO(chenweihang): Add layout and dtype argument if needed
PD_DLL_DECL Tensor to(const Tensor& x, Backend backend, bool blocking);
PD_DLL_DECL Tensor copy_to(const Tensor& x, Backend backend, bool blocking);
} // namespace experimental
} // namespace paddle
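This functional form takes the tensor as an explicit argument instead of being invoked as a method; TEST(API, copy_to) later in this diff exercises exactly this call. A rough illustration, with a hypothetical wrapper name:

// Non-blocking copy of x to the CPU backend via the functional API; the tests
// pass blocking=true for the final device-to-host copy so the result is ready.
paddle::experimental::Tensor CopyToCpuExample(
    const paddle::experimental::Tensor& x) {
  return paddle::experimental::copy_to(x, pten::Backend::CPU, /*blocking=*/false);
}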
@@ -19,6 +19,7 @@ limitations under the License. */
#include <vector>
#include "glog/logging.h"
#include "paddle/pten/api/include/manipulation.h"
#include "paddle/pten/api/include/utils.h"
#include "paddle/pten/api/lib/ext_compat_utils.h"
#include "paddle/pten/api/lib/utils/allocator.h"
@@ -281,11 +282,11 @@ gpuStream_t Tensor::stream() const {
template <typename T>
Tensor Tensor::copy_to(const PlaceType &target_place) const {
LOG(WARNING) << "The Tensor's `copy_to` method is deprecated since version "
"2.3, and will be removed in version 2.4, please use `to` "
"method instead. "
"2.3, and will be removed in version 2.4, please use "
"`copy_to` method without template argument instead. "
"reason: copying a Tensor to another device does not need "
"to specify the data type template argument.";
return to(ConvertExtPlaceToBackend(target_place), /*blocking=*/false);
return copy_to(ConvertExtPlaceToBackend(target_place), /*blocking=*/false);
}
template PD_DLL_DECL Tensor
@@ -311,15 +312,12 @@ template PD_DLL_DECL Tensor Tensor::copy_to<paddle::platform::complex<double>>(
template PD_DLL_DECL Tensor
Tensor::copy_to<paddle::platform::float16>(const PlaceType &target_place) const;
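For callers of the deprecated template overload above, the migration is mechanical. A sketch, assuming the extension-API PlaceType::kCPU value and a hypothetical wrapper name:

void MigrateCopyToExample(const paddle::experimental::Tensor& x) {
  // Deprecated form: triggers the warning above, because the data type
  // template argument is not needed just to copy to another device.
  auto old_style = x.copy_to<float>(paddle::PlaceType::kCPU);

  // Replacement form: the target is given as a Backend plus a blocking flag.
  auto new_style = x.copy_to(pten::Backend::CPU, /*blocking=*/false);
}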
Tensor Tensor::to(Backend backend, bool blocking) const {
return experimental::to(*this, backend, blocking);
Tensor Tensor::copy_to(Backend backend, bool blocking) const {
return experimental::copy_to(*this, backend, blocking);
}
Tensor Tensor::cast(const DataType &target_type) const {
PADDLE_THROW(platform::errors::Unimplemented(
"The cast operation is not supported now, "
"and it will be implemented by calling the cast kernel later."));
return Tensor();
Tensor Tensor::cast(DataType target_type) const {
return experimental::cast(*this, target_type);
}
/* Part 6: Status utils methods */
@@ -34,7 +34,7 @@ PT_DECLARE_MODULE(UtilsCUDA);
namespace paddle {
namespace experimental {
PD_DLL_DECL Tensor to(const Tensor& x, Backend backend, bool blocking) {
PD_DLL_DECL Tensor copy_to(const Tensor& x, Backend backend, bool blocking) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
......
if(WITH_ROCM)
hip_test(test_pten_tensor SRCS test_pten_tensor.cc DEPS pten_tensor utils_api glog)
hip_test(test_pten_tensor SRCS test_pten_tensor.cc DEPS pten_tensor utils_api manipulation_api glog)
else()
cc_test(test_pten_tensor SRCS test_pten_tensor.cc DEPS pten_tensor utils_api glog)
cc_test(test_pten_tensor SRCS test_pten_tensor.cc DEPS pten_tensor utils_api manipulation_api glog)
endif()
cc_test(test_pten_exception SRCS test_pten_exception.cc DEPS gtest)
@@ -15,17 +15,15 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/api/include/creation.h"
#include "paddle/pten/api/include/manipulation.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
PT_DECLARE_MODULE(ManipulationCPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(ManipulationCUDA);
#endif
namespace pten {
namespace tests {
namespace framework = paddle::framework;
using DDim = paddle::framework::DDim;
@@ -67,3 +65,24 @@ TEST(API, cast) {
ASSERT_NEAR(dense_out_data[i], static_cast<double>(dense_x_data[i]), 1e-6f);
}
}
TEST(Tensor, cast) {
auto x = paddle::experimental::full({3, 4}, 1.0, pten::DataType::FLOAT32);
auto y = x.cast(pten::DataType::INT32);
// check cast result
ASSERT_EQ(y.dims().size(), 2);
ASSERT_EQ(y.dims()[0], 3);
ASSERT_EQ(y.dims()[1], 4);
ASSERT_EQ(y.numel(), 12);
ASSERT_EQ(y.is_cpu(), true);
ASSERT_EQ(y.type(), pten::DataType::INT32);
ASSERT_EQ(y.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(y.initialized(), true);
for (int64_t i = 0; i < y.numel(); ++i) {
ASSERT_EQ(y.mutable_data<int>()[i], 1);
}
}
} // namespace tests
} // namespace pten
@@ -58,39 +58,39 @@ void CheckOutputResult(const paddle::experimental::Tensor& out) {
}
}
TEST(API, to) {
TEST(API, copy_to) {
// 1. create tensor
auto x = CreateInputTensor();
// 2. test API
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto tmp = paddle::experimental::to(x, pten::Backend::CUDA, false);
auto out = paddle::experimental::to(tmp, pten::Backend::CPU, true);
auto tmp = paddle::experimental::copy_to(x, pten::Backend::CUDA, false);
auto out = paddle::experimental::copy_to(tmp, pten::Backend::CPU, true);
#else
auto out = paddle::experimental::to(x, pten::Backend::CPU, false);
auto out = paddle::experimental::copy_to(x, pten::Backend::CPU, false);
#endif
// 3. check result
CheckOutputResult(out);
}
TEST(Tensor, to) {
TEST(Tensor, copy_to) {
// 1. create tensor
auto x = CreateInputTensor();
// 2. test API
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto tmp = x.to(pten::Backend::CUDA, false);
auto out = tmp.to(pten::Backend::CPU, true);
auto tmp = x.copy_to(pten::Backend::CUDA, false);
auto out = tmp.copy_to(pten::Backend::CPU, true);
#else
auto out = x.to(pten::Backend::CPU, false);
auto out = x.copy_to(pten::Backend::CPU, false);
#endif
// 3. check result
CheckOutputResult(out);
}
TEST(Tensor, copy_to) {
TEST(Tensor, old_copy_to) {
// 1. create tensor
auto x = CreateInputTensor();
@@ -23,12 +23,6 @@ limitations under the License. */
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
PT_DECLARE_MODULE(ManipulationCPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(ManipulationCUDA);
#endif
namespace framework = paddle::framework;
using DDim = paddle::framework::DDim;