From 5f08b82f2cebe360d3d2788fe3a60486580695e7 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 17 Aug 2022 16:11:49 +0800
Subject: [PATCH] fix(dnn/cuda): fix ptx mma algo compute bugs

GitOrigin-RevId: 19628d0c94e93ff1072db2eb04547e6f8db5f809
---
 .../fuse_z_imma8832_ldg16_128x256_relu.cu     |  26 ++--
 .../fuse_z_imma8832_ldg16_256x64_relu.cu      |  57 +++++-----
 .../fuse_z_imma8832_ldgsts16_128x128_relu.cu  |  57 +++++-----
 .../uint4_int4/imma8832_ldg16_128x256_relu.cu |  25 +++--
 .../uint4_int4/imma8832_ldg16_256x64_relu.cu  |  25 +++--
 .../imma8832_ldgsts16_128x128_relu.cu         |  62 +++++------
 dnn/src/cuda/ptx/uint4_int4/macro.cuh         | 100 ++++--------------
 7 files changed, 157 insertions(+), 195 deletions(-)

diff --git a/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_128x256_relu.cu b/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_128x256_relu.cu
index 14285e6f..fc8659c8 100644
--- a/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_128x256_relu.cu
+++ b/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_128x256_relu.cu
@@ -476,6 +476,20 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
     }
 
+    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
+
     // read fuse_z
     int2 reg_fuse_z[reg_m] = {make_int2(z_zero_point, z_zero_point),
                               make_int2(z_zero_point, z_zero_point),
@@ -595,18 +609,7 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
 
     /// output
-    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
-    const float* bias_ptr = bias + oc;
-
-    int4 load_bias0 = make_int4(0, 0, 0, 0);
-    int4 load_bias1 = make_int4(0, 0, 0, 0);
-    int4 load_bias2 = make_int4(0, 0, 0, 0);
-    int4 load_bias3 = make_int4(0, 0, 0, 0);
     if (oc < param.oc) {
-        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
         mul_v4(load_bias0, load_bias0, beta);
         mul_v4(load_bias1, load_bias1, beta);
         mul_v4(load_bias2, load_bias2, beta);
@@ -617,7 +620,6 @@ extern "C" __global__ void __launch_bounds__(256)
 
 #pragma unroll
     for (int y = 0; y < reg_m; y += 4) {
-        I2F_4x8(reg_acc, y, 0);
         FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
         FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point);
         PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
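Note on the hunks above: the bias load is hoisted ahead of the main loop so the four 16-byte reads are already in flight while the mma iterations run, and only the beta scaling stays in the epilogue. A minimal sketch of the guarded vectorized load pattern, with hypothetical names (Param, bias_load_demo) and assuming the bias buffer is 16-byte aligned and padded so a guarded oc always has 16 readable floats:

#include <cuda_runtime.h>

struct Param {
    unsigned oc;  // hypothetical stand-in for the kernel's param struct
};

__global__ void bias_load_demo(const float* bias, Param param, float* out) {
    // Each thread owns a 16-channel slice of the output starting at `oc`.
    unsigned oc = threadIdx.x * 16;
    const float* bias_ptr = bias + oc;

    // Zero-initialized so out-of-range threads still hold a defined value.
    int4 load_bias0 = make_int4(0, 0, 0, 0);
    int4 load_bias1 = make_int4(0, 0, 0, 0);
    int4 load_bias2 = make_int4(0, 0, 0, 0);
    int4 load_bias3 = make_int4(0, 0, 0, 0);
    if (oc < param.oc) {
        // Four 128-bit loads cover 16 consecutive float bias values.
        load_bias0 = *reinterpret_cast<const int4*>(bias_ptr);
        load_bias1 = *reinterpret_cast<const int4*>(bias_ptr + 4);
        load_bias2 = *reinterpret_cast<const int4*>(bias_ptr + 8);
        load_bias3 = *reinterpret_cast<const int4*>(bias_ptr + 12);
    }
    // Consumers reinterpret the lanes back to float, as mul_v4/fma2 do.
    out[threadIdx.x] = reinterpret_cast<const float*>(&load_bias0)[0];
}

Keeping the floats in int4 registers lets the epilogue helpers (mul_v4, fma2) reinterpret the lanes in place instead of moving them.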
diff --git a/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_256x64_relu.cu b/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_256x64_relu.cu
index 87200dbf..3faa9bed 100644
--- a/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_256x64_relu.cu
+++ b/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldg16_256x64_relu.cu
@@ -657,6 +657,20 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
     }
 
+    size_t oc = bidy * BM + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
+
     // read fuse_z
     int2 reg_fuse_z[reg_m] = {make_int2(z_zero_point, z_zero_point),
                               make_int2(z_zero_point, z_zero_point),
                               make_int2(z_zero_point, z_zero_point),
@@ -712,6 +726,14 @@ extern "C" __global__ void __launch_bounds__(256)
         reg_flt[0][j] = make_int4(x, y, z, w);
     }
 
+    /// output
+    if (oc < param.oc) {
+        mul_v4(load_bias0, load_bias0, beta);
+        mul_v4(load_bias1, load_bias1, beta);
+        mul_v4(load_bias2, load_bias2, beta);
+        mul_v4(load_bias3, load_bias3, beta);
+    }
+
     // compute
 #pragma unroll
     for (int k_inner = 0; k_inner < BKd32; k_inner++) {
@@ -773,35 +795,20 @@ extern "C" __global__ void __launch_bounds__(256)
 
         __syncthreads();
 
-    /// output
-    size_t oc = bidy * BM + 16 * idx_in_quad;
-    const float* bias_ptr = bias + oc;
-
-    int4 load_bias0 = make_int4(0, 0, 0, 0);
-    int4 load_bias1 = make_int4(0, 0, 0, 0);
-    int4 load_bias2 = make_int4(0, 0, 0, 0);
-    int4 load_bias3 = make_int4(0, 0, 0, 0);
-    if (oc < param.oc) {
-        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
-        mul_v4(load_bias0, load_bias0, beta);
-        mul_v4(load_bias1, load_bias1, beta);
-        mul_v4(load_bias2, load_bias2, beta);
-        mul_v4(load_bias3, load_bias3, beta);
-    }
-
     int8_t* __restrict__ g_dst_ptr = dst + d_offset;
 
+    FMA_1x8(reg_acc, 0, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+    fuse_z_1x8(reg_acc[0], 0, reg_fuse_z[0], gamma, z_zero_point);
+    PACK_F2I_WITH_RELU_1x8(reg_acc, 0, 0, relu, dst_zero_point);
+
 #pragma unroll
-    for (int y = 0; y < reg_m; y += 4) {
-        I2F_4x8(reg_acc, y, 0);
-        FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
-        FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point);
-        PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
-        STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0);
+    for (int y = 1; y < reg_m; y += 1) {
+        FMA_1x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+        fuse_z_1x8(reg_acc[y], 0, reg_fuse_z[y], gamma, z_zero_point);
+        PACK_F2I_WITH_RELU_1x8(reg_acc, y, 0, relu, dst_zero_point);
+        STG_AFTER_LDG(g_offset[y - 1], reg_acc[y - 1][0], stg_guard[y - 1]);
     }
+    STG_AFTER_LDG(g_offset[7], reg_acc[7][0], stg_guard[7]);
 #endif
 }
 } // namespace
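The rewritten epilogue above software-pipelines the output stage: while row y is scaled, fused with z, and packed, the already-packed row y - 1 is stored, and the final row drains after the loop. A schematic of just that pipelining, with hypothetical stand-ins (REG_M, compute_row, store_row) for the 1x8 macros:

constexpr int REG_M = 8;  // rows held in registers, mirrors reg_m above

// Stand-in for FMA_1x8 + fuse_z_1x8 + PACK_F2I_WITH_RELU_1x8 on one row.
__device__ void compute_row(float* acc, int y) {
    acc[y] += 1.0f;
}

// Stand-in for the guarded STG_AFTER_LDG store of one packed row.
__device__ void store_row(float* dst, const float* acc, int y, bool guard) {
    if (guard)
        dst[y] = acc[y];
}

__device__ void pipelined_epilogue(float* dst, float* acc, const bool* stg_guard) {
    compute_row(acc, 0);  // prologue: get row 0 ready
#pragma unroll
    for (int y = 1; y < REG_M; ++y) {
        compute_row(acc, y);                           // ALU work for row y...
        store_row(dst, acc, y - 1, stg_guard[y - 1]);  // ...overlaps the store of row y-1
    }
    store_row(dst, acc, REG_M - 1, stg_guard[REG_M - 1]);  // drain the last row
}

This keeps one store in flight per iteration instead of bursting four stores at the end of each 4-row group.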
diff --git a/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu b/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu
index 9c7f4262..c3101c1e 100644
--- a/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu
+++ b/dnn/src/cuda/ptx/uint4_int4/fuse_z_imma8832_ldgsts16_128x128_relu.cu
@@ -437,7 +437,7 @@ extern "C" __global__ void __launch_bounds__(256)
         cp_async_fence();
     }
 
-    bool only_one_stage = (stage == 1) ? true : false;
+    bool only_one_stage = (stage == 1);
 
     if (stage >= 2) {
         cp_async_wait(stages - 2);
    } else {
@@ -844,6 +844,20 @@ extern "C" __global__ void __launch_bounds__(256)
         cp_async_wait(stages - 2);
     }
 
+    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
+
     if (!only_one_stage) {
 #pragma unroll  // low
         for (int i = 0; i < reg_nd4; ++i) {
@@ -975,6 +989,13 @@ extern "C" __global__ void __launch_bounds__(256)
             reg_flt[0][j] = make_int4(x, y, z, w);
         }
 
+        if (oc < param.oc) {
+            mul_v4(load_bias0, load_bias0, beta);
+            mul_v4(load_bias1, load_bias1, beta);
+            mul_v4(load_bias2, load_bias2, beta);
+            mul_v4(load_bias3, load_bias3, beta);
+        }
+
         // compute
 #pragma unroll
         for (int k_inner = 0; k_inner < BKd32; k_inner++) {
@@ -1038,34 +1059,20 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
 
     /// output
-    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
-    const float* bias_ptr = bias + oc;
-
-    int4 load_bias0 = make_int4(0, 0, 0, 0);
-    int4 load_bias1 = make_int4(0, 0, 0, 0);
-    int4 load_bias2 = make_int4(0, 0, 0, 0);
-    int4 load_bias3 = make_int4(0, 0, 0, 0);
-    if (oc < param.oc) {
-        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
-        mul_v4(load_bias0, load_bias0, beta);
-        mul_v4(load_bias1, load_bias1, beta);
-        mul_v4(load_bias2, load_bias2, beta);
-        mul_v4(load_bias3, load_bias3, beta);
-    }
-
     int8_t* __restrict__ g_dst_ptr = dst + d_offset;
 
+    FMA_1x8(reg_acc, 0, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+    fuse_z_1x8(reg_acc[0], 0, reg_fuse_z[0], gamma, z_zero_point);
+    PACK_F2I_WITH_RELU_1x8(reg_acc, 0, 0, relu, dst_zero_point);
+
 #pragma unroll
-    for (int y = 0; y < reg_m; y += 4) {
-        I2F_4x8(reg_acc, y, 0);
-        FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
-        FUSE_Z_4x8(reg_acc, y, 0, reg_fuse_z, gamma, z_zero_point);
-        PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
-        STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0);
+    for (int y = 1; y < reg_m; y += 1) {
+        FMA_1x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+        fuse_z_1x8(reg_acc[y], 0, reg_fuse_z[y], gamma, z_zero_point);
+        PACK_F2I_WITH_RELU_1x8(reg_acc, y, 0, relu, dst_zero_point);
+        STG_AFTER_LDG(g_offset[y - 1], reg_acc[y - 1][0], stg_guard[y - 1]);
     }
+    STG_AFTER_LDG(g_offset[7], reg_acc[7][0], stg_guard[7]);
 #endif
 }
 } // namespace
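Context for the only_one_stage simplification above: the ldgsts kernels keep several cp.async stages in flight and, before consuming a stage, wait until at most stages - 2 copy groups are still outstanding; if only a single stage was ever issued, they must wait for all of them. A hedged sketch of what cp_async_fence/cp_async_wait presumably wrap (sm_80+ inline PTX; the real wrappers may differ):

// Commit all previously issued cp.async copies as one group.
__device__ __forceinline__ void cp_async_fence_demo() {
    asm volatile("cp.async.commit_group;\n" ::);
}

// Block until at most N committed groups are still in flight.
// N must be a compile-time constant for the "n" constraint.
template <int N>
__device__ __forceinline__ void cp_async_wait_demo() {
    asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
}

template <int stages>  // number of prefetch stages, assumed >= 2
__device__ void wait_before_epilogue(int stage) {
    bool only_one_stage = (stage == 1);
    if (stage >= 2) {
        cp_async_wait_demo<(stages >= 2 ? stages - 2 : 0)>();  // keep the pipeline full
    } else {
        cp_async_wait_demo<0>();  // single stage: wait for everything
    }
    // `only_one_stage` later selects the non-pipelined epilogue path.
    (void)only_one_stage;
}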
diff --git a/dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_128x256_relu.cu b/dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_128x256_relu.cu
index 52558669..18658ae1 100644
--- a/dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_128x256_relu.cu
+++ b/dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_128x256_relu.cu
@@ -475,6 +475,20 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
     }
 
+    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
+
     guard = iter < 0;
 #pragma unroll
     for (int i = 0; i < reg_nd4; ++i) {
@@ -574,18 +588,8 @@ extern "C" __global__ void __launch_bounds__(256)
     size_t nhw_post3 = nhw_post0 + 24;
 
     size_t stg_oc = bidy * BM + (warp_y << 6);
-    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
-    const float* bias_ptr = bias + oc;
 
-    int4 load_bias0 = make_int4(0, 0, 0, 0);
-    int4 load_bias1 = make_int4(0, 0, 0, 0);
-    int4 load_bias2 = make_int4(0, 0, 0, 0);
-    int4 load_bias3 = make_int4(0, 0, 0, 0);
     if (oc < param.oc) {
-        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
         mul_v4(load_bias0, load_bias0, beta);
         mul_v4(load_bias1, load_bias1, beta);
         mul_v4(load_bias2, load_bias2, beta);
@@ -599,7 +603,6 @@ extern "C" __global__ void __launch_bounds__(256)
 
 #pragma unroll
     for (int y = 0; y < reg_m; y += 4) {
-        I2F_4x8(reg_acc, y, 0);
         FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
         PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
         STG_4x1(stg_ptr, reg_acc, y, 0);
diff --git a/dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_256x64_relu.cu b/dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_256x64_relu.cu
index ba6b016b..d2be4f26 100644
--- a/dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_256x64_relu.cu
+++ b/dnn/src/cuda/ptx/uint4_int4/imma8832_ldg16_256x64_relu.cu
@@ -659,6 +659,20 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
     }
 
+    size_t oc = bidy * BM + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
+
     guard = iter < 0;
 #pragma unroll  // low
     for (int i = 0; i < reg_nd4; ++i) {
@@ -755,18 +769,8 @@ extern "C" __global__ void __launch_bounds__(256)
     size_t nhw_post3 = nhw_post0 + 24;
 
     size_t stg_oc = bidy * BM;
-    size_t oc = bidy * BM + 16 * idx_in_quad;
-    const float* bias_ptr = bias + oc;
 
-    int4 load_bias0 = make_int4(0, 0, 0, 0);
-    int4 load_bias1 = make_int4(0, 0, 0, 0);
-    int4 load_bias2 = make_int4(0, 0, 0, 0);
-    int4 load_bias3 = make_int4(0, 0, 0, 0);
     if (oc < param.oc) {
-        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
         mul_v4(load_bias0, load_bias0, beta);
         mul_v4(load_bias1, load_bias1, beta);
         mul_v4(load_bias2, load_bias2, beta);
@@ -779,7 +783,6 @@ extern "C" __global__ void __launch_bounds__(256)
 
 #pragma unroll
     for (int y = 0; y < reg_m; y += 4) {
-        I2F_4x8(reg_acc, y, 0);
         FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
         PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
         STG_4x1(stg_ptr, reg_acc, y, 0);
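Why the I2F_4x8 pass can be deleted in these two kernels (and the fuse_z ones earlier): the rewritten fma2 in macro.cuh at the end of this patch reads the mma accumulator lanes as integers, and a0.x * alpha promotes the int32 lane to float inside the expression, so the separate in-place int-to-float conversion pass is redundant. Reduced to one register pair (fma2_demo is a hypothetical name):

// One lane pair of the new fma2: consume raw int32 imma accumulators,
// convert implicitly, scale by alpha, add the beta-scaled bias, and store
// the float result back into the same registers (as raw bits).
__device__ __forceinline__ void fma2_demo(
        int2& c0, const int2 a0, const float alpha, const float b0, const float b1) {
    ((float*)&c0)[0] = a0.x * alpha + b0;  // int -> float happens here
    ((float*)&c0)[1] = a0.y * alpha + b1;
}

Note that the rounding also changes: the old path used an explicit fma.rz.f32 on already-converted lanes, while the new expression uses the compiler's default rounding.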
diff --git a/dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu b/dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu
index eeacceda..f4da3884 100644
--- a/dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu
+++ b/dnn/src/cuda/ptx/uint4_int4/imma8832_ldgsts16_128x128_relu.cu
@@ -449,15 +449,15 @@ extern "C" __global__ void __launch_bounds__(256)
     bool stg_guard[8];
 #pragma unroll
     for (int y = 0; y < reg_m; y += 4) {
-        COMPUTE_OFFSET_4x1(reg_fuse_z, g_offset, y)
+        COMPUTE_OFFSET_4x1(g_offset, y);
 
-            nhw_post0 += 32;
+        nhw_post0 += 32;
         nhw_post1 += 32;
         nhw_post2 += 32;
         nhw_post3 += 32;
     }
 
-    bool only_one_stage = (stage == 1) ? true : false;
+    bool only_one_stage = (stage == 1);
 
     if (stage >= 2) {
         cp_async_wait(stages - 2);
    } else {
@@ -835,6 +835,20 @@ extern "C" __global__ void __launch_bounds__(256)
         cp_async_wait(stages - 2);
     }
 
+    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
+    const float* bias_ptr = bias + oc;
+
+    int4 load_bias0 = make_int4(0, 0, 0, 0);
+    int4 load_bias1 = make_int4(0, 0, 0, 0);
+    int4 load_bias2 = make_int4(0, 0, 0, 0);
+    int4 load_bias3 = make_int4(0, 0, 0, 0);
+    if (oc < param.oc) {
+        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
+        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
+        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
+        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
+    }
+
     if (!only_one_stage) {
 #pragma unroll  // low
         for (int i = 0; i < reg_nd4; ++i) {
@@ -965,6 +979,13 @@ extern "C" __global__ void __launch_bounds__(256)
             reg_flt[0][j] = make_int4(x, y, z, w);
         }
 
+        if (oc < param.oc) {
+            mul_v4(load_bias0, load_bias0, beta);
+            mul_v4(load_bias1, load_bias1, beta);
+            mul_v4(load_bias2, load_bias2, beta);
+            mul_v4(load_bias3, load_bias3, beta);
+        }
+
         // compute
 #pragma unroll
         for (int k_inner = 0; k_inner < BKd32; k_inner++) {
@@ -1028,38 +1049,19 @@ extern "C" __global__ void __launch_bounds__(256)
         __syncthreads();
 
     /// output
-    size_t oc = bidy * BM + (warp_y << 6) + 16 * idx_in_quad;
-    const float* bias_ptr = bias + oc;
-
-    int4 load_bias0 = make_int4(0, 0, 0, 0);
-    int4 load_bias1 = make_int4(0, 0, 0, 0);
-    int4 load_bias2 = make_int4(0, 0, 0, 0);
-    int4 load_bias3 = make_int4(0, 0, 0, 0);
-    if (oc < param.oc) {
-        load_bias0 = *(reinterpret_cast<const int4*>(bias_ptr));
-        load_bias1 = *(reinterpret_cast<const int4*>(bias_ptr + 4));
-        load_bias2 = *(reinterpret_cast<const int4*>(bias_ptr + 8));
-        load_bias3 = *(reinterpret_cast<const int4*>(bias_ptr + 12));
-        mul_v4(load_bias0, load_bias0, beta);
-        mul_v4(load_bias1, load_bias1, beta);
-        mul_v4(load_bias2, load_bias2, beta);
-        mul_v4(load_bias3, load_bias3, beta);
-    }
 
     int8_t* __restrict__ g_dst_ptr = dst + d_offset;
 
-#pragma unroll
-    for (int y = 0; y < reg_m; y += 4) {
-        I2F_4x8(reg_acc, y, 0);
-        FMA_4x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
-        PACK_F2I_WITH_RELU_4x8(reg_acc, y, 0, relu, dst_zero_point);
-        STG_AFTER_LDG_4x1(g_offset, reg_acc, y, 0);
+    FMA_1x8(reg_acc, 0, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+    PACK_F2I_WITH_RELU_1x8(reg_acc, 0, 0, relu, dst_zero_point);
 
-        nhw_post0 += 32;
-        nhw_post1 += 32;
-        nhw_post2 += 32;
-        nhw_post3 += 32;
+#pragma unroll
+    for (int y = 1; y < reg_m; y += 1) {
+        FMA_1x8(reg_acc, y, 0, alpha, load_bias0, load_bias1, load_bias2, load_bias3);
+        PACK_F2I_WITH_RELU_1x8(reg_acc, y, 0, relu, dst_zero_point);
+        STG_AFTER_LDG(g_offset[y - 1], reg_acc[y - 1][0], stg_guard[y - 1]);
     }
+    STG_AFTER_LDG(g_offset[7], reg_acc[7][0], stg_guard[7]);
 #endif
 }
 } // namespace
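The COMPUTE_OFFSET_4x1 call above drops its unused destination argument; the matching macro change is in macro.cuh below. Written out as a function, one offset/guard computation looks like the following sketch (DemoParam and compute_offset_demo are hypothetical; the real macro pastes nhw_post##idx and writes into reused cache registers):

#include <cstddef>

struct DemoParam {
    unsigned div_ohow;  // oh * ow, the divisor separating n from hw
    unsigned obs;       // output batch stride in bytes
    unsigned nhw;       // total n * oh * ow, for the tail guard
};

__device__ __forceinline__ void compute_offset_demo(
        size_t& offset, bool& guard, size_t nhw_post, const DemoParam& param,
        int packed_channel) {
    size_t n = nhw_post / param.div_ohow;   // batch index
    size_t hw = nhw_post % param.div_ohow;  // flattened spatial index
    // Two 4-bit output channels share one byte, hence packed_channel >> 1.
    offset = n * param.obs + hw * (packed_channel >> 1);
    guard = nhw_post < param.nhw;           // false for rows past the end
}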
diff --git a/dnn/src/cuda/ptx/uint4_int4/macro.cuh b/dnn/src/cuda/ptx/uint4_int4/macro.cuh
index 812a0b15..defb5f26 100644
--- a/dnn/src/cuda/ptx/uint4_int4/macro.cuh
+++ b/dnn/src/cuda/ptx/uint4_int4/macro.cuh
@@ -23,78 +23,26 @@ __device__ __forceinline__ void mul_v4(
 __device__ __forceinline__ void fma2(
         int2& c0, const int2 a0, int2& c1, const int2 a1, const float alpha,
         const int4 b) {
-    asm("fma.rz.f32 %0, %1, %2, %3;"
-        : "=f"(((float*)&c0)[0])
-        : "f"(((float*)&a0)[0]), "f"(alpha), "f"(((float*)&b)[0]));
-    asm("fma.rz.f32 %0, %1, %2, %3;"
-        : "=f"(((float*)&c0)[1])
-        : "f"(((float*)&a0)[1]), "f"(alpha), "f"(((float*)&b)[1]));
-    asm("fma.rz.f32 %0, %1, %2, %3;"
-        : "=f"(((float*)&c1)[0])
-        : "f"(((float*)&a1)[0]), "f"(alpha), "f"(((float*)&b)[2]));
-    asm("fma.rz.f32 %0, %1, %2, %3;"
-        : "=f"(((float*)&c1)[1])
-        : "f"(((float*)&a1)[1]), "f"(alpha), "f"(((float*)&b)[3]));
-}
-
-__device__ __forceinline__ void fuse_z_1x8(
-        int4* a, const int& j, const int4& fuse_z, const float& gamma,
-        const int32_t& zero_point) {
-    const int2 z[2] = {
-            *reinterpret_cast<const int2*>(&fuse_z),
-            *(reinterpret_cast<const int2*>(&fuse_z) + 1)};
-    for (int k = 0; k < 4; k++) {
-        int f = ((z[0].x >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[0] += (f - zero_point) * gamma;
-        f = ((z[0].x >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[1] += (f - zero_point) * gamma;
-
-        f = ((z[1].x >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[2] += (f - zero_point) * gamma;
-        f = ((z[1].x >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[3] += (f - zero_point) * gamma;
-    }
-    for (int k = 0; k < 4; k++) {
-        int f = ((z[0].y >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma;
-        f = ((z[0].y >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma;
-
-        f = ((z[1].y >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[2] += (f - zero_point) * gamma;
-        f = ((z[1].y >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[3] += (f - zero_point) * gamma;
-    }
+    ((float*)&c0)[0] = a0.x * alpha + ((float*)&b)[0];
+    ((float*)&c0)[1] = a0.y * alpha + ((float*)&b)[1];
+    ((float*)&c1)[0] = a1.x * alpha + ((float*)&b)[2];
+    ((float*)&c1)[1] = a1.y * alpha + ((float*)&b)[3];
 }
 
 __device__ __forceinline__ void fuse_z_1x8(
         int2* a, const int& j, const int2& fuse_z, const float& gamma,
         const int32_t& zero_point) {
+    float x = zero_point * gamma;
 #pragma unroll
     for (int k = 0; k < 4; k++) {
         int f = ((fuse_z.x >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[0] += (f - zero_point) * gamma;
+        ((float*)&(a[j + k]))[0] += f * gamma - x;
         f = ((fuse_z.x >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k]))[1] += (f - zero_point) * gamma;
-    }
-#pragma unroll
-    for (int k = 0; k < 4; k++) {
-        int f = ((fuse_z.y >> (k * 8)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[0] += (f - zero_point) * gamma;
+        ((float*)&(a[j + k]))[1] += f * gamma - x;
+        f = ((fuse_z.y >> (k * 8)) & 15);
+        ((float*)&(a[j + k + 4]))[0] += f * gamma - x;
         f = ((fuse_z.y >> (k * 8 + 4)) & 15);
-        f = (f << 28) >> 28;
-        ((float*)&(a[j + k + 4]))[1] += (f - zero_point) * gamma;
+        ((float*)&(a[j + k + 4]))[1] += f * gamma - x;
     }
 }
@@ -282,12 +230,6 @@ __device__ __forceinline__ void pack_f2i_with_relu(
         fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point); \
         fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point);
 
-#define FUSE_Z_4x8(a, i, j, fuse_z, gamma, zero_point)          \
-    fuse_z_1x8(a[i], j, fuse_z[i], gamma, zero_point);          \
-    fuse_z_1x8(a[i + 1], j, fuse_z[i + 1], gamma, zero_point);  \
-    fuse_z_1x8(a[i + 2], j, fuse_z[i + 2], gamma, zero_point);  \
-    fuse_z_1x8(a[i + 3], j, fuse_z[i + 3], gamma, zero_point);
-
 // 1x8 1x(2x8 int2) to 2 int2
 #define PACK_F2I_1x8(a, i, j)                                                       \
     pack_f2i(a[i][j].x, a[i][j].z, a[i][j], a[i][j + 1], a[i][j + 2], a[i][j + 3]); \
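The fuse_z_1x8 rewrite above is the main numerical fix in this patch. The old code sign-extended every 4-bit field with (f << 28) >> 28, which is correct only for signed int4 data; the fused z operand of these uint4_int4 kernels is quantized unsigned 4-bit, so fields 8..15 were misread as -8..-1. The new code keeps the field unsigned and hoists zero_point * gamma into a single constant. A scalar sketch of the fixed dequantization (dequant_u4 is a hypothetical name):

// Dequantize the nibble-th unsigned 4-bit field of `packed`:
// result = (f - zero_point) * gamma, computed as f * gamma - zero_point * gamma.
__device__ __forceinline__ float dequant_u4(
        int packed, int nibble, float gamma, float zp_times_gamma) {
    int f = (packed >> (nibble * 4)) & 15;  // 0..15, no sign extension
    return f * gamma - zp_times_gamma;
}

Algebraically f * gamma - zero_point * gamma equals (f - zero_point) * gamma, so hoisting the product also saves an integer subtract per field in the inner loop.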
@@ -316,24 +258,20 @@ __device__ __forceinline__ void pack_f2i_with_relu(
             stg_guard[i + 2])                                                      \
     LDG(d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3])
 
-#define COMPUTE_OFFSET(d, s, idx, n_reuse, hw_reuse, g)             \
+#define COMPUTE_OFFSET(s, idx, n_reuse, hw_reuse, g)                \
     n_reuse = nhw_post##idx / param.div_ohow;                       \
     hw_reuse = nhw_post##idx % param.div_ohow;                      \
     s = n_reuse * param.obs + hw_reuse * (packed_channel >> 1);     \
     g = nhw_post##idx < param.nhw;
 
-#define COMPUTE_OFFSET_4x1(d, s, i)                                              \
-    COMPUTE_OFFSET(                                                              \
-            d[i], s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i]) \
-    COMPUTE_OFFSET(                                                              \
-            d[i + 1], s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y,       \
-            stg_guard[i + 1])                                                    \
-    COMPUTE_OFFSET(                                                              \
-            d[i + 2], s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z,       \
-            stg_guard[i + 2])                                                    \
-    COMPUTE_OFFSET(                                                              \
-            d[i + 3], s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w,       \
-            stg_guard[i + 3])
+#define COMPUTE_OFFSET_4x1(s, i)                                                   \
+    COMPUTE_OFFSET(s[i], 0, reg_src_cache[0].x, reg_src_cache[1].x, stg_guard[i])  \
+    COMPUTE_OFFSET(                                                                \
+            s[i + 1], 1, reg_src_cache[0].y, reg_src_cache[1].y, stg_guard[i + 1]) \
+    COMPUTE_OFFSET(                                                                \
+            s[i + 2], 2, reg_src_cache[0].z, reg_src_cache[1].z, stg_guard[i + 2]) \
+    COMPUTE_OFFSET(                                                                \
+            s[i + 3], 3, reg_src_cache[0].w, reg_src_cache[1].w, stg_guard[i + 3])
 
 #define STG_AFTER_LDG(d, s, g)             \
     if (stg_oc < param.oc && g) {          \
--
GitLab
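For completeness, the STG_AFTER_LDG guard kept above predicates every store on both the channel bound and the per-row spatial guard produced by COMPUTE_OFFSET. Its body is cut off in this hunk's context, so the following is only a guess at its shape (StoreParam and stg_after_ldg_demo are hypothetical, assuming one packed 1x8 row is written as a single 8-byte int2):

#include <cstddef>
#include <cstdint>

struct StoreParam {
    unsigned oc;  // hypothetical stand-in for param.oc
};

// Guarded vector store of one packed output row. `guard` comes from
// COMPUTE_OFFSET (spatial bound); `stg_oc` comes from the channel tiling.
__device__ __forceinline__ void stg_after_ldg_demo(
        int8_t* g_dst_ptr, size_t offset, int2 packed, unsigned stg_oc,
        const StoreParam& param, bool guard) {
    if (stg_oc < param.oc && guard) {
        *reinterpret_cast<int2*>(g_dst_ptr + offset) = packed;
    }
}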