From bedcf0dd98e30accd32969a70d5729ef8a8d2f15 Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Thu, 1 Jul 2021 19:09:04 +0800
Subject: [PATCH] [cherry-pick] fix bug when the cuda kernel config exceeds
 dims max (#33748) (#33893)

fix bug when the cuda kernel config exceeds dims max
---
 paddle/fluid/operators/layer_norm_op.cu | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)
 mode change 100755 => 100644 paddle/fluid/operators/layer_norm_op.cu

diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu
old mode 100755
new mode 100644
index 25c722358c4..0410e051158
--- a/paddle/fluid/operators/layer_norm_op.cu
+++ b/paddle/fluid/operators/layer_norm_op.cu
@@ -399,9 +399,9 @@ __global__ void LayerNormBackwardComputeGradInput(
     const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
     const U *gamma, T *grad_input) {
 #ifdef __HIPCC__
-  for (auto i1 = hipBlockIdx_y; i1 < n1; i1 += hipGridDim_y) {
+  for (auto i1 = hipBlockIdx_x; i1 < n1; i1 += hipGridDim_x) {
 #else
-  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (auto i1 = blockIdx.x; i1 < n1; i1 += gridDim.x) {
 #endif
     U sum_loss1 = U(0);
     U sum_loss2 = U(0);
@@ -867,9 +867,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
       constexpr int BDIMX1 = 32;
       constexpr int BDIMY1 = 4;
       dim3 threads1(BDIMX1, BDIMY1, 1);
-      const dim3 blocks1(1, batch_size, 1);
       LayerNormBackwardComputeGradInput<
-          T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
+          T, U, BDIMX1, BDIMY1><<<batch_size, threads1, 0, stream>>>(
           d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
       break;
     }
--
GitLab
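
A minimal CUDA sketch (not part of the patch; kernel and variable names are hypothetical) illustrating why the batch loop is moved from the grid's y-dimension to its x-dimension: on current devices gridDim.y is capped at 65535 blocks, while gridDim.x allows up to 2^31 - 1, so launching one block per batch row on the y-axis fails with an invalid-configuration error once batch_size exceeds 65535.

// Illustrative sketch only, not PaddlePaddle code.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void RowKernel(int n_rows) {
  // Grid-stride loop over rows on the x-axis, mirroring the patched kernel.
  for (int row = blockIdx.x; row < n_rows; row += gridDim.x) {
    // ... per-row work ...
  }
}

int main() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  // Typical limits: maxGridSize[0] = 2147483647, maxGridSize[1] = 65535.
  printf("max grid x = %d, max grid y = %d\n",
         prop.maxGridSize[0], prop.maxGridSize[1]);

  int batch_size = 100000;  // > 65535, exceeds the y-dimension limit
  dim3 threads(32, 4, 1);

  // Pre-patch style launch, one block per row on the y-axis, would fail here:
  //   RowKernel<<<dim3(1, batch_size, 1), threads>>>(batch_size);
  // Patched style: rows on the x-axis, which allows up to 2^31 - 1 blocks.
  RowKernel<<<batch_size, threads>>>(batch_size);
  cudaDeviceSynchronize();
  return 0;
}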