Unverified commit e896567e, authored by sneaxiy, committed by GitHub

Fix some operators when the tensor.numel() > INT32_MAX (#46767)

* fix some ops for int64 range

* update error message
Parent: 05c2b9ba
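The root cause throughout this patch: CUDA's built-in indices (blockIdx, blockDim, threadIdx) are 32-bit, so index arithmetic written in `int` wraps once a tensor holds more than INT32_MAX elements, and the wrapped value is then used for out-of-bounds loads and stores. A minimal sketch (hypothetical kernel, not part of this patch) of the failure mode and of the fix pattern applied below:

```cuda
#include <cstdint>

// Hypothetical elementwise kernel illustrating the bug class this commit
// fixes; the name and sizes are illustrative only.
__global__ void scale_demo(const float* src, float* dst, int64_t numel) {
  // BAD: blockIdx.x and blockDim.x are 32-bit, so their product wraps
  // before the assignment widens it to 64 bits:
  //   int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  // FIX (the pattern used throughout this diff): cast the first operand
  // so the whole expression is evaluated in 64 bits.
  int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < numel) dst[i] = 2.0f * src[i];
}
```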
@@ -127,26 +127,27 @@ __device__ __forceinline__ void warp_reduce_upper_tri(T* sum) {
 template <typename T, int pow2_index>
 __global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,
                                                       T* dst,
-                                                      int batch_count,
-                                                      int key_seq_len) {
+                                                      int64_t batch_count,
+                                                      int64_t key_seq_len) {
   constexpr int next_pow2 = 1 << pow2_index;
   constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
   constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
   constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
   constexpr int kOneLoadingCounts = 4;
-  int key_seq_len_pow_2 = key_seq_len * key_seq_len;
+  int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len;
-  int first_idx =
-      (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * kLocalBatchSize +
+  int64_t first_idx =
+      (static_cast<int64_t>(blockDim.y) * blockIdx.y + threadIdx.y) *
+          gridDim.x * kLocalBatchSize +
       blockIdx.x;
-  int local_block_idx = blockIdx.x + 1;
-  int warp_iter_upper_bound =
+  int64_t local_block_idx = blockIdx.x + 1;
+  int64_t warp_iter_upper_bound =
       (local_block_idx + kOneLoadingCounts * warp_size - 1) / warp_size;
-  int local_batches = batch_count - first_idx;
+  int64_t local_batches = batch_count - first_idx;
   if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;
-  int local_idx = threadIdx.x;
+  int64_t local_idx = threadIdx.x;
   src += first_idx * key_seq_len + kOneLoadingCounts * local_idx;
   dst += first_idx * key_seq_len + kOneLoadingCounts * local_idx;
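Note the shape of the fix for `first_idx`: only the leftmost operand is cast. The usual arithmetic conversions then carry `int64_t` through every subsequent `*` and `+`, so a single cast keeps the entire expression out of 32-bit range. A host-side sketch with hypothetical stand-in values for the launch dimensions:

```cuda
#include <cassert>
#include <cstdint>

int main() {
  // Stand-ins for blockDim.y, blockIdx.y, gridDim.x (hypothetical values):
  unsigned int bdim_y = 1024, bidx_y = 70000, gdim_x = 4096;
  // One cast on the leftmost operand is enough: the usual arithmetic
  // conversions evaluate every later multiply in 64 bits as well.
  int64_t idx = static_cast<int64_t>(bdim_y) * bidx_y * gdim_x;
  assert(idx == 293601280000LL);  // well past INT32_MAX, no wraparound
  // Without the cast, the same expression wraps modulo 2^32:
  unsigned int wrapped = bdim_y * bidx_y * gdim_x;
  assert(wrapped == 1543503872u);
  return 0;
}
```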
@@ -156,11 +157,11 @@ __global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,
 #pragma unroll
   for (int i = 0; i < kLocalBatchSize; ++i) {
-    int batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
+    auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
 #pragma unroll
     for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
-      int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+      auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
       if (element_index < batch_total_number) {
         load_data_upper_tri(temp_in,
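The loop-body declarations switch from `int` to `auto` rather than to `int64_t` directly. Because `local_block_idx` and `local_idx` are now `int64_t`, expressions such as the conditional above already have type `int64_t` (the common type of an `int` literal and an `int64_t`), so `auto` tracks the widened type at every use site without further edits. A small sketch of the deduction, with hypothetical values:

```cuda
#include <cstdint>
#include <type_traits>

int main() {
  int64_t local_block_idx = 7;  // hypothetical value
  bool past_end = false;
  // The common type of `0` (int) and an int64_t operand is int64_t,
  // so `auto` deduces the widened type here.
  auto batch_total_number = past_end ? 0 : local_block_idx;
  static_assert(std::is_same<decltype(batch_total_number), int64_t>::value,
                "conditional expression widened to int64_t");
  return 0;
}
```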
@@ -215,7 +216,7 @@ __global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,
     if (i >= local_batches) break;
 #pragma unroll
     for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
-      int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+      auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
       if (element_index < local_block_idx) {
 #pragma unroll
@@ -241,31 +242,32 @@ template <typename T, int pow2_index>
 __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
                                                           T* grad_output,
                                                           const T* softmax_rst,
-                                                          int batch_count,
-                                                          int key_seq_len) {
+                                                          int64_t batch_count,
+                                                          int64_t key_seq_len) {
   constexpr int next_pow2 = 1 << pow2_index;
   constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
   constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
   constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
   constexpr int kOneLoadingCounts = 4;
-  int key_seq_len_pow_2 = key_seq_len * key_seq_len;
+  int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len;
-  int first_idx =
-      (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * kLocalBatchSize +
+  int64_t first_idx =
+      (static_cast<int64_t>(blockDim.y) * blockIdx.y + threadIdx.y) *
+          gridDim.x * kLocalBatchSize +
       blockIdx.x;
-  int local_block_idx = blockIdx.x + 1;
+  int64_t local_block_idx = blockIdx.x + 1;
   // micro_batch_size might not be a multiple of WARP_BATCH. Check how
   // many batches have to be computed within this WARP.
-  int local_batches = batch_count - first_idx;
+  int64_t local_batches = batch_count - first_idx;
   if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;
   // there might be multiple batches per warp. compute the index within the
   // batch
-  int local_idx = threadIdx.x;
+  int64_t local_idx = threadIdx.x;
   // the first element to process by the current thread
-  int offset = first_idx * key_seq_len + kOneLoadingCounts * local_idx;
+  int64_t offset = first_idx * key_seq_len + kOneLoadingCounts * local_idx;
   grad_input += offset;
   grad_output += offset;
   softmax_rst += offset;
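In the backward kernel the same widening protects the pointer offset: `first_idx * key_seq_len` exceeds 32-bit range long before anything exotic happens once long sequences are allowed, and advancing `grad_input` by a wrapped offset is undefined behavior. A host-side sketch with hypothetical sizes (unsigned arithmetic is used for the 32-bit branch so the wraparound in the demo stays well-defined):

```cuda
#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical sizes: row 300000 of a [batch_count, 16384] problem.
  int64_t first_idx = 300000, key_seq_len = 16384;
  int64_t offset = first_idx * key_seq_len;  // 4915200000, the true offset
  assert(offset > INT32_MAX);
  // A 32-bit computation silently drops the high bits:
  uint32_t wrapped = static_cast<uint32_t>(first_idx) *
                     static_cast<uint32_t>(key_seq_len);
  assert(wrapped == offset - (int64_t{1} << 32));  // 620232704
  return 0;
}
```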
@@ -278,11 +280,11 @@ __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
 #pragma unroll
   for (int i = 0; i < kLocalBatchSize; ++i) {
-    int batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
+    auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
 #pragma unroll
     for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
-      int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+      auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
       if (element_index < batch_total_number) {
         load_data_upper_tri(
             temp_grad_input,
@@ -327,7 +329,7 @@ __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
     if (i >= local_batches) break;
 #pragma unroll
     for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
-      int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+      auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
       if (element_index < key_seq_len) {
         // compute gradients
         T samples_out[kOneLoadingCounts];
@@ -368,10 +370,10 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
                           key_seq_len,
                           query_seq_len));
-    PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len < 8192,
+    PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len <= 16384,
                       true,
                       platform::errors::InvalidArgument(
-                          "Input x's last dim must be between [32, 8192) "
+                          "Input x's last dim must be between [32, 16384] "
                           "received the last dimension of x is %d",
                           key_seq_len));
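Relaxing the guard from `key_seq_len < 8192` to `key_seq_len <= 16384` only works because the dispatch below gains a matching instantiation for pow2_index == 14 (2^14 == 16384). A hedged sketch of what a pow2-index helper could look like; the actual `get_pow2_index_value` in Paddle may differ:

```cuda
#include <cassert>

// Sketch: exponent of the smallest power of two >= key_seq_len, so that
// key_seq_len == 16384 maps to 14, matching the new `case 14` below.
static int pow2_index_sketch(int key_seq_len) {
  int pow2_index = 0;
  while ((1 << pow2_index) < key_seq_len) ++pow2_index;
  return pow2_index;
}

int main() {
  assert(pow2_index_sketch(8192) == 13);
  assert(pow2_index_sketch(16384) == 14);  // the newly supported maximum
  assert(pow2_index_sketch(10000) == 14);  // non-powers of two round up
  return 0;
}
```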
@@ -380,7 +382,7 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
     int pow2_index = get_pow2_index_value(key_seq_len);
     const int next_pow2 = 1 << pow2_index;
-    int batch_count = attn_mul_batch * query_seq_len;
+    int64_t batch_count = attn_mul_batch * query_seq_len;
     int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
     int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
     constexpr int threads_per_block = 128;
@@ -447,7 +449,13 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
             <<<blocks, threads, 0, stream>>>(
                 x_data, y_data, batch_count, key_seq_len);
         break;
+      case 14:  // 16384
+        SoftmaxMaskFuseUpperTriangleGPUKernel<T, 14>
+            <<<blocks, threads, 0, stream>>>(
+                x_data, y_data, batch_count, key_seq_len);
+        break;
       default:
+        PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length."));
         break;
     }
   }
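Because the kernel is templated on `pow2_index`, every supported length class needs its own `case` naming the exponent as a compile-time constant; the new `case 14` extends coverage to key_seq_len in (8192, 16384], and the added PADDLE_THROW makes unsupported lengths an explicit error. A minimal sketch of this dispatch pattern (hypothetical names, host-side stub instead of a real launch):

```cuda
#include <cstdio>

template <int pow2_index>
void launch_sketch() {
  // In the real code this would be a <<<blocks, threads>>> launch of a
  // kernel specialized for next_pow2 elements per row.
  constexpr int next_pow2 = 1 << pow2_index;
  std::printf("launch kernel specialized for next_pow2 = %d\n", next_pow2);
}

void dispatch_sketch(int pow2_index) {
  switch (pow2_index) {
    case 13: launch_sketch<13>(); break;
    case 14: launch_sketch<14>(); break;  // newly added instantiation
    default: std::printf("unsupported sequence length\n"); break;
  }
}

int main() {
  dispatch_sketch(14);
  return 0;
}
```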
@@ -479,7 +487,7 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> {
     int pow2_index = get_pow2_index_value(key_seq_len);
     const int next_pow2 = 1 << pow2_index;
-    int batch_count = attn_mul_batch * query_seq_len;
+    int64_t batch_count = attn_mul_batch * query_seq_len;
     int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
     int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
     // use 128 threads per block to maximize gpu utilization
@@ -565,7 +573,16 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> {
                                              batch_count,
                                              key_seq_len);
         break;
+      case 14:
+        SoftmaxMaskFuseUpperTriangleGradGPUKernel<T, 14>
+            <<<blocks, threads, 0, stream>>>(grad_y_data,
+                                             grad_x_data,
+                                             softmax_rst_data,
+                                             batch_count,
+                                             key_seq_len);
+        break;
       default:
+        PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length."));
         break;
     }
   }
@@ -760,8 +760,10 @@ __global__ void VectorizedElementwiseKernel(
     kps::IndexType main_offset,
     int read_lens,
     Functor func) {
-  kps::IndexType data_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
-  kps::IndexType stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
+  kps::IndexType data_offset =
+      static_cast<kps::IndexType>(BLOCK_ID_X) * BLOCK_NUM_X * read_lens;
+  kps::IndexType stride =
+      static_cast<kps::IndexType>(BLOCK_NUM_X) * GRID_NUM_X * read_lens;
   for (; data_offset < main_offset; data_offset += stride) {
     VectorizedElementwiseKernelImpl<OutT,
                                     Functor,
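The elementwise kernel gets the same treatment in terms of `kps::IndexType`: the block offset and the grid stride are now computed with a widened first operand, so a grid-stride loop over more than INT32_MAX elements no longer wraps. A hedged sketch in plain CUDA (without the `kps::` wrappers and `BLOCK_ID_X`-style macros):

```cuda
#include <cstdint>

// Grid-stride loop with 64-bit offset and stride; hypothetical kernel,
// shown only to illustrate the widening applied in the diff above.
__global__ void grid_stride_demo(const float* in, float* out, int64_t numel) {
  int64_t offset =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = offset; i < numel; i += stride) {
    out[i] = in[i];
  }
}
```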