未验证 提交 c342651e 作者: wuhuachaocoding 提交者: GitHub

fix concat bug (#34319): widen row/column counts and offsets from int to int64_t in the CPU and CUDA concat functors

上级 609f8225
...@@ -40,18 +40,18 @@ class ConcatFunctor<platform::CPUDeviceContext, T> { ...@@ -40,18 +40,18 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
const std::vector<framework::Tensor>& input, int axis, const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output) { framework::Tensor* output) {
// TODO(zcd): Add input data validity checking // TODO(zcd): Add input data validity checking
int num = input.size(); size_t num = input.size();
int rows = 1; int64_t rows = 1;
auto dim_0 = input[0].dims(); auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
rows *= dim_0[i]; rows *= dim_0[i];
} }
int out_rows = rows, out_cols = 0; int64_t out_rows = rows, out_cols = 0;
std::vector<int64_t> input_cols(input.size()); std::vector<int64_t> input_cols(input.size());
for (int i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows; int64_t t_cols = input[i].numel() / rows;
out_cols += t_cols; out_cols += t_cols;
input_cols[i] = t_cols; input_cols[i] = t_cols;
} }
...@@ -59,11 +59,11 @@ class ConcatFunctor<platform::CPUDeviceContext, T> { ...@@ -59,11 +59,11 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
// computation // computation
auto output_data = output->data<T>(); auto output_data = output->data<T>();
int col_idx = 0; int64_t col_idx = 0;
for (int j = 0; j < num; ++j) { for (size_t j = 0; j < num; ++j) {
int col_len = input_cols[j]; int64_t col_len = input_cols[j];
auto input_data = input[j].data<T>(); auto input_data = input[j].data<T>();
for (int k = 0; k < out_rows; ++k) { for (int64_t k = 0; k < out_rows; ++k) {
memory::Copy(cpu_place, output_data + k * out_cols + col_idx, cpu_place, memory::Copy(cpu_place, output_data + k * out_cols + col_idx, cpu_place,
input_data + k * col_len, sizeof(T) * col_len); input_data + k * col_len, sizeof(T) * col_len);
} }
......
...@@ -26,9 +26,9 @@ namespace operators { ...@@ -26,9 +26,9 @@ namespace operators {
namespace math { namespace math {
template <typename T> template <typename T>
__global__ void ConcatKernel(const T** inputs, const int* input_cols, __global__ void ConcatKernel(const T** inputs, const int64_t* input_cols,
int col_size, const int output_rows, int col_size, const int64_t output_rows,
const int output_cols, T* output) { const int64_t output_cols, T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int curr_segment = 0; int curr_segment = 0;
int curr_offset = input_cols[0]; int curr_offset = input_cols[0];
...@@ -70,8 +70,8 @@ __device__ void ConcatKernelDetail(const T** inputs_data, ...@@ -70,8 +70,8 @@ __device__ void ConcatKernelDetail(const T** inputs_data,
template <typename T> template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
const int fixed_in_col, const int out_rows, const int64_t fixed_in_col, const int64_t out_rows,
const int out_cols, T* output_data) { const int64_t out_cols, T* output_data) {
const T* inputs_data[2]; const T* inputs_data[2];
inputs_data[0] = input_addr0; inputs_data[0] = input_addr0;
inputs_data[1] = input_addr1; inputs_data[1] = input_addr1;
...@@ -81,8 +81,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, ...@@ -81,8 +81,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
template <typename T> template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
const T* input_addr2, const int fixed_in_col, const T* input_addr2, const int64_t fixed_in_col,
const int out_rows, const int out_cols, const int64_t out_rows, const int64_t out_cols,
T* output_data) { T* output_data) {
const T* inputs_data[3]; const T* inputs_data[3];
inputs_data[0] = input_addr0; inputs_data[0] = input_addr0;
...@@ -95,8 +95,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, ...@@ -95,8 +95,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
template <typename T> template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
const T* input_addr2, const T* input_addr3, const T* input_addr2, const T* input_addr3,
const int fixed_in_col, const int out_rows, const int64_t fixed_in_col, const int64_t out_rows,
const int out_cols, T* output_data) { const int64_t out_cols, T* output_data) {
const T* inputs_data[4]; const T* inputs_data[4];
inputs_data[0] = input_addr0; inputs_data[0] = input_addr0;
inputs_data[1] = input_addr1; inputs_data[1] = input_addr1;
...@@ -108,8 +108,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, ...@@ -108,8 +108,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
template <typename T> template <typename T>
__global__ void ConcatKernel(const T** inputs_data, const int in_num, __global__ void ConcatKernel(const T** inputs_data, const int in_num,
const int fixed_in_col, const int out_rows, const int64_t fixed_in_col, const int64_t out_rows,
const int out_cols, T* output_data) { const int64_t out_cols, T* output_data) {
ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols, ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
output_data); output_data);
} }
...@@ -235,19 +235,19 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { ...@@ -235,19 +235,19 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
framework::Tensor* output) { framework::Tensor* output) {
// TODO(zcd): Add input data validity checking // TODO(zcd): Add input data validity checking
int in_num = input.size(); int in_num = input.size();
int in_row = 1; int64_t in_row = 1;
auto dim_0 = input[0].dims(); auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
in_row *= dim_0[i]; in_row *= dim_0[i];
} }
int in_col = input[0].numel() / in_row; int64_t in_col = input[0].numel() / in_row;
int out_row = in_row, out_col = 0; int64_t out_row = in_row, out_col = 0;
int inputs_col_num = in_num + 1; int inputs_col_num = in_num + 1;
std::vector<const T*> inputs_data_vec(in_num); std::vector<const T*> inputs_data_vec(in_num);
std::vector<int> inputs_col_vec(inputs_col_num); std::vector<int64_t> inputs_col_vec(inputs_col_num);
const T** inputs_data = inputs_data_vec.data(); const T** inputs_data = inputs_data_vec.data();
int* inputs_col = inputs_col_vec.data(); int64_t* inputs_col = inputs_col_vec.data();
// There are some differences between hip runtime and NV runtime. // There are some differences between hip runtime and NV runtime.
// In NV, when the pageable memory data less than 64K is transferred from // In NV, when the pageable memory data less than 64K is transferred from
...@@ -263,13 +263,13 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { ...@@ -263,13 +263,13 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
inputs_data = reinterpret_cast<const T**>(data_alloc->ptr()); inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
col_alloc = memory::Alloc(platform::CUDAPinnedPlace(), col_alloc = memory::Alloc(platform::CUDAPinnedPlace(),
inputs_col_num * sizeof(int)); inputs_col_num * sizeof(int));
inputs_col = reinterpret_cast<int*>(col_alloc->ptr()); inputs_col = reinterpret_cast<int64_t*>(col_alloc->ptr());
#endif #endif
inputs_col[0] = 0; inputs_col[0] = 0;
bool has_same_shape = true; bool has_same_shape = true;
for (int i = 0; i < in_num; ++i) { for (int i = 0; i < in_num; ++i) {
int t_cols = input[i].numel() / in_row; int64_t t_cols = input[i].numel() / in_row;
if (has_same_shape) { if (has_same_shape) {
if (t_cols != in_col) has_same_shape = false; if (t_cols != in_col) has_same_shape = false;
} }
...@@ -312,17 +312,19 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { ...@@ -312,17 +312,19 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
} }
} else { } else {
auto tmp_dev_ins_col_data = auto tmp_dev_ins_col_data =
memory::Alloc(context, inputs_col_num * sizeof(int)); memory::Alloc(context, inputs_col_num * sizeof(int64_t));
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(), tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
static_cast<void*>(inputs_col), inputs_col_num * sizeof(int), static_cast<void*>(inputs_col),
context.stream()); inputs_col_num * sizeof(int64_t), context.stream());
int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr()); int64_t* dev_ins_col_data =
static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>( ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col_num), dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col_num),
out_row, out_col, output->data<T>()); out_row, out_col, output->data<T>());
} }
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
// Prevent the pinned memory value from being covered and release the memory // Prevent the pinned memory value from being covered and release the memory
// after the launch kernel of the stream is executed (reapply pinned memory // after the launch kernel of the stream is executed (reapply pinned memory
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册