Unverified · Commit bedcf0dd authored by Leo Chen, committed by GitHub

[cherry-pick] fix bug when the cuda kernel config exceeds dims max (#33748) (#33893)

fix bug when the cuda kernel config exceeds dims max
Parent 702610ef
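For context on the bug: CUDA caps the grid's y and z dimensions at 65535 blocks, while the x dimension can go up to 2^31 - 1 on compute capability 3.0 and newer. The original launch placed batch_size on the y axis, so any batch larger than 65535 produced an invalid kernel configuration. The snippet below is a minimal standalone sketch (not Paddle code; the device index and batch size are assumptions) that queries the device limits and checks both layouts.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);  // device 0 assumed
  // Typically prints 2147483647 65535 65535 on current GPUs.
  printf("maxGridSize: %d %d %d\n",
         prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);

  unsigned int batch_size = 100000;   // hypothetical batch, larger than 65535
  dim3 blocks_old(1, batch_size, 1);  // old layout: rows on the y axis
  dim3 blocks_new(batch_size, 1, 1);  // new layout: rows on the x axis
  printf("old config valid: %d\n", blocks_old.y <= (unsigned)prop.maxGridSize[1]);
  printf("new config valid: %d\n", blocks_new.x <= (unsigned)prop.maxGridSize[0]);
  return 0;
}
```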
@@ -399,9 +399,9 @@ __global__ void LayerNormBackwardComputeGradInput(
     const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
     const U *gamma, T *grad_input) {
 #ifdef __HIPCC__
-  for (auto i1 = hipBlockIdx_y; i1 < n1; i1 += hipGridDim_y) {
+  for (auto i1 = hipBlockIdx_x; i1 < n1; i1 += hipGridDim_x) {
 #else
-  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (auto i1 = blockIdx.x; i1 < n1; i1 += gridDim.x) {
 #endif
   U sum_loss1 = U(0);
   U sum_loss2 = U(0);
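The hunk above keeps the same grid-stride pattern but strides along the x dimension instead of y. A toy kernel (hypothetical, not the Paddle kernel) showing the same row-wise pattern:

```cuda
// Each block handles one row at a time; the grid-stride loop over
// blockIdx.x / gridDim.x covers all n1 rows even if fewer blocks
// are launched than there are rows.
__global__ void RowScale(const float *in, float *out, int n1, int n2) {
  for (int i1 = blockIdx.x; i1 < n1; i1 += gridDim.x) {
    for (int j = threadIdx.x; j < n2; j += blockDim.x) {
      out[i1 * n2 + j] = 2.0f * in[i1 * n2 + j];
    }
  }
}
// Launched with one block per row, mirroring the fixed code:
//   RowScale<<<batch_size, 128, 0, stream>>>(in, out, batch_size, feature_size);
```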
@@ -867,9 +867,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
       constexpr int BDIMX1 = 32;
       constexpr int BDIMY1 = 4;
       dim3 threads1(BDIMX1, BDIMY1, 1);
-      const dim3 blocks1(1, batch_size, 1);
       LayerNormBackwardComputeGradInput<
-          T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
+          T, U, BDIMX1, BDIMY1><<<batch_size, threads1, 0, stream>>>(
           d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
       break;
     }
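Note on the launch-site hunk: passing the integer batch_size as the first <<<...>>> argument is implicitly converted to dim3(batch_size, 1, 1), so the row count now lands on the x dimension (limit 2^31 - 1) rather than the y dimension (limit 65535), and the temporary blocks1 variable is no longer needed.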