未验证 提交 c342651e 编写于 作者: W wuhuachaocoding 提交者: GitHub

fix concat bug (#34319)

上级 609f8225
......@@ -40,18 +40,18 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
int num = input.size();
size_t num = input.size();
int rows = 1;
int64_t rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;
int64_t out_rows = rows, out_cols = 0;
std::vector<int64_t> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
for (size_t i = 0; i < num; ++i) {
int64_t t_cols = input[i].numel() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
......@@ -59,11 +59,11 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
// computation
auto output_data = output->data<T>();
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
int64_t col_idx = 0;
for (size_t j = 0; j < num; ++j) {
int64_t col_len = input_cols[j];
auto input_data = input[j].data<T>();
for (int k = 0; k < out_rows; ++k) {
for (int64_t k = 0; k < out_rows; ++k) {
memory::Copy(cpu_place, output_data + k * out_cols + col_idx, cpu_place,
input_data + k * col_len, sizeof(T) * col_len);
}
......
......@@ -26,9 +26,9 @@ namespace operators {
namespace math {
template <typename T>
__global__ void ConcatKernel(const T** inputs, const int* input_cols,
int col_size, const int output_rows,
const int output_cols, T* output) {
__global__ void ConcatKernel(const T** inputs, const int64_t* input_cols,
int col_size, const int64_t output_rows,
const int64_t output_cols, T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int curr_segment = 0;
int curr_offset = input_cols[0];
......@@ -70,8 +70,8 @@ __device__ void ConcatKernelDetail(const T** inputs_data,
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
const int fixed_in_col, const int out_rows,
const int out_cols, T* output_data) {
const int64_t fixed_in_col, const int64_t out_rows,
const int64_t out_cols, T* output_data) {
const T* inputs_data[2];
inputs_data[0] = input_addr0;
inputs_data[1] = input_addr1;
......@@ -81,8 +81,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
const T* input_addr2, const int fixed_in_col,
const int out_rows, const int out_cols,
const T* input_addr2, const int64_t fixed_in_col,
const int64_t out_rows, const int64_t out_cols,
T* output_data) {
const T* inputs_data[3];
inputs_data[0] = input_addr0;
......@@ -95,8 +95,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
const T* input_addr2, const T* input_addr3,
const int fixed_in_col, const int out_rows,
const int out_cols, T* output_data) {
const int64_t fixed_in_col, const int64_t out_rows,
const int64_t out_cols, T* output_data) {
const T* inputs_data[4];
inputs_data[0] = input_addr0;
inputs_data[1] = input_addr1;
......@@ -108,8 +108,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
template <typename T>
__global__ void ConcatKernel(const T** inputs_data, const int in_num,
const int fixed_in_col, const int out_rows,
const int out_cols, T* output_data) {
const int64_t fixed_in_col, const int64_t out_rows,
const int64_t out_cols, T* output_data) {
ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
output_data);
}
......@@ -235,19 +235,19 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
int in_num = input.size();
int in_row = 1;
int64_t in_row = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
in_row *= dim_0[i];
}
int in_col = input[0].numel() / in_row;
int out_row = in_row, out_col = 0;
int64_t in_col = input[0].numel() / in_row;
int64_t out_row = in_row, out_col = 0;
int inputs_col_num = in_num + 1;
std::vector<const T*> inputs_data_vec(in_num);
std::vector<int> inputs_col_vec(inputs_col_num);
std::vector<int64_t> inputs_col_vec(inputs_col_num);
const T** inputs_data = inputs_data_vec.data();
int* inputs_col = inputs_col_vec.data();
int64_t* inputs_col = inputs_col_vec.data();
// There are some differences between hip runtime and NV runtime.
// In NV, when the pageable memory data less than 64K is transferred from
......@@ -263,13 +263,13 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
col_alloc = memory::Alloc(platform::CUDAPinnedPlace(),
inputs_col_num * sizeof(int));
inputs_col = reinterpret_cast<int*>(col_alloc->ptr());
inputs_col = reinterpret_cast<int64_t*>(col_alloc->ptr());
#endif
inputs_col[0] = 0;
bool has_same_shape = true;
for (int i = 0; i < in_num; ++i) {
int t_cols = input[i].numel() / in_row;
int64_t t_cols = input[i].numel() / in_row;
if (has_same_shape) {
if (t_cols != in_col) has_same_shape = false;
}
......@@ -312,17 +312,19 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
}
} else {
auto tmp_dev_ins_col_data =
memory::Alloc(context, inputs_col_num * sizeof(int));
memory::Alloc(context, inputs_col_num * sizeof(int64_t));
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
static_cast<void*>(inputs_col), inputs_col_num * sizeof(int),
context.stream());
int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr());
static_cast<void*>(inputs_col),
inputs_col_num * sizeof(int64_t), context.stream());
int64_t* dev_ins_col_data =
static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col_num),
out_row, out_col, output->data<T>());
}
#ifdef PADDLE_WITH_HIP
// Prevent the pinned memory value from being covered and release the memory
// after the launch kernel of the stream is executed (reapply pinned memory
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册