Unverified commit b9a1d954, authored by zhaoyuchen2018, committed by GitHub

[cherry-pick] Fix softmax cuda bug (#21720) (#22160)

* Fix softmax cuda bug

* Refine multihead log and softmax logic

* Align block to 32
Parent 835201bf
@@ -84,15 +84,39 @@ class MultiHeadMatMulOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_v[0],
"Multihead input bias should have same batch size");
PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_k[1],
"Multihead input bias should have same size");
PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_v[1],
"Multihead input bias should have same size");
auto dim_bias_qk = context->GetInputDim("BiasQK");
PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3,
"Multihead input bias qk should be at least 4-D tensor.");
int b_indx = dim_bias_q.size() - 1;
int indx = dim_q.size() - 1;
PADDLE_ENFORCE_EQ(
dim_bias_q[b_indx], dim_q[indx],
platform::errors::InvalidArgument(
"bias_q's last dim size should equal to"
" q last dim size, but received bias_q's size is:%d q is:%d",
dim_bias_q[b_indx], dim_q[indx]));
PADDLE_ENFORCE_EQ(
dim_bias_k[b_indx], dim_k[indx],
platform::errors::InvalidArgument(
"bias_k's last dim size should equal to"
" k last dim size, but received bias_k's size is:%d k is:%d",
dim_bias_k[b_indx], dim_k[indx]));
PADDLE_ENFORCE_EQ(
dim_bias_v[b_indx], dim_v[indx],
platform::errors::InvalidArgument(
"bias_v's last dim size should equal to"
" v last dim size, but received bias_v's size is:%d v is:%d",
dim_bias_v[b_indx], dim_v[indx]));
PADDLE_ENFORCE_EQ(dim_q[0], dim_bias_qk[0],
platform::errors::InvalidArgument(
"q should have same batch size"
"with bias_qk, but received q's batch size is:%d "
"bias_qk's batch size is:%d",
dim_q[0], dim_bias_qk[0]));
int head_number = context->Attrs().Get<int>("head_number");
PADDLE_ENFORCE_GT(head_number, 1,
"Multihead input head number should be at least 1.");
......
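For reference, the contract these added checks enforce can be summarised in a small standalone sketch (hypothetical helper and example shapes, not part of the patch): each bias must match the last dimension of its input, BiasQK must be at least 4-D, and its leading dimension must equal Q's batch size.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical illustration of the shape contract checked in InferShape above.
// Example layout: Q = {batch, seq_len, hidden}, BiasQ = {1, hidden},
// BiasQK = {batch, head_num, seq_len, seq_len}.
void CheckMultiheadShapes(const std::vector<int64_t> &dim_q,
                          const std::vector<int64_t> &dim_bias_q,
                          const std::vector<int64_t> &dim_bias_qk) {
  assert(dim_bias_qk.size() > 3);                // BiasQK is at least a 4-D tensor
  assert(dim_bias_q.back() == dim_q.back());     // bias last dim == input last dim
  assert(dim_q.front() == dim_bias_qk.front());  // Q and BiasQK share a batch size
}

int main() {
  CheckMultiheadShapes({8, 128, 768}, {1, 768}, {8, 12, 128, 128});
  return 0;
}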
@@ -196,15 +196,14 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
const int head_num,
const int seq_len,
const unsigned mask) {
int seq_id = blockIdx.x % seq_len;
int qk_offset = blockIdx.x * seq_len;
int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
assert(blockDim.x % 32 == 0);
__shared__ float s_sum, s_max;
float qk = threadIdx.x < seq_len
? static_cast<float>((qk_buf_[threadIdx.x + qk_offset] +
bias_qk_[threadIdx.x + bias_offset]))
bias_qk_[threadIdx.x + qk_offset]))
: 0.0f;
float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
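The index change above is the core of the softmax fix. The launch uses one block per attention row, so blockIdx.x runs over batch_size * head_num * seq_len rows and qk_offset = blockIdx.x * seq_len is the row's base offset in qk_buf_. Once BiasQK is laid out as [batch, head_num, seq_len, seq_len] (as the updated test below constructs it), that same offset also addresses the bias row; the old bias_offset dropped the batch term via the modulo, so every batch reused batch 0's bias. A hedged sketch of the offset arithmetic (hypothetical helper, not in the patch):

#include <cstdio>

// Base offset of row (b, h, row) in a tensor laid out as
// [batch, head_num, seq_len, seq_len]. Since the kernel launches one block per
// row, blockIdx.x == (b * head_num + h) * seq_len + row, so this value equals
// blockIdx.x * seq_len -- the qk_offset already used for qk_buf_.
inline int RowOffset(int b, int h, int row, int head_num, int seq_len) {
  return ((b * head_num + h) * seq_len + row) * seq_len;
}

int main() {
  // For batch index 1 the old batch-agnostic bias_offset would wrap back to
  // batch 0; the full row offset below does not.
  printf("%d\n", RowOffset(/*b=*/1, /*h=*/0, /*row=*/0,
                           /*head_num=*/12, /*seq_len=*/128));  // prints 196608
  return 0;
}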
@@ -259,15 +258,29 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
q_buf_, k_buf_, beta, qk_buf_, batch_size * head_num,
seq_len * size_per_head, seq_len * size_per_head);
int m = batch_size * head_num * seq_len;
int k = seq_len;
int grid = m;
int block = k;
int grid = batch_size * head_num * seq_len;
int block = seq_len;
// Align block to 32, also limit seq_len to max block size.
PADDLE_ENFORCE_LE(seq_len, 1024, platform::errors::InvalidArgument(
"seq_len should <= 1024, "
"but received seq_len is:%d",
seq_len));
if (seq_len <= 32)
block = 32;
else if (seq_len > 32 && seq_len <= 64)
block = 64;
else if (seq_len > 64 && seq_len <= 128)
block = 128;
else if (seq_len > 128 && seq_len <= 256)
block = 256;
else if (seq_len > 256 && seq_len <= 512)
block = 512;
else
block = 1024;
unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
qk_buf_, bias_qk, batch_size, head_num, seq_len, mask);
qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
}
template <typename T>
......
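The new if/else ladder rounds seq_len up to the next power of two, clamped to [32, 1024], so blockDim.x is always a whole number of 32-thread warps (matching the assert added in the kernel) and the shuffle-based reductions can always use the full-warp FINAL_MASK instead of the previously computed partial mask. A compact equivalent of that ladder, as a hedged sketch rather than the patch's code:

#include <cassert>
#include <cstdio>

// Hypothetical equivalent of the block-size selection above: round seq_len up
// to the next power of two in [32, 1024], i.e. a multiple of the warp size
// that never exceeds the CUDA per-block thread limit.
inline int AlignBlockSize(int seq_len) {
  assert(seq_len >= 1 && seq_len <= 1024);
  int block = 32;
  while (block < seq_len) block <<= 1;
  return block;
}

int main() {
  printf("%d %d %d\n", AlignBlockSize(20), AlignBlockSize(100),
         AlignBlockSize(1000));  // prints 32 128 1024
  return 0;
}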
@@ -54,7 +54,8 @@ class TestFusedMultiheadMatmulOp(OpTest):
self.BiasK = np.random.random((1, w)).astype("float32")
self.BiasV = np.random.random((1, w)).astype("float32")
self.BiasQK = np.random.random(
(1, self.head_number, self.seq_len, self.seq_len)).astype("float32")
(self.batch_size, self.head_number, self.seq_len,
self.seq_len)).astype("float32")
# Compute Q path
fc_q = self.Q + self.BiasQ
reshape_q = np.reshape(fc_q, (self.batch_size, self.seq_len,
......