diff --git a/src/dev/cpu/op/conv/cortex_a/conv_dw_kernel_int8_arm.c b/src/dev/cpu/op/conv/cortex_a/conv_dw_kernel_int8_arm.c
index 348f7da935b52a5851688b78724aa552cb7c859e..54c9285865b9e8e740cb176a12acb23e968324a0 100644
--- a/src/dev/cpu/op/conv/cortex_a/conv_dw_kernel_int8_arm.c
+++ b/src/dev/cpu/op/conv/cortex_a/conv_dw_kernel_int8_arm.c
@@ -1,6 +1,7 @@
-//
-// Created by Hebing Shi on 2020/11/4.
-//
+/*
+ * Author: 1091545398@qq.com
+ */
+
 #include "conv_dw_kernel_int8_arm.h"
 #include "tengine_ir.h"
 #include "sys_port.h"
@@ -117,6 +118,7 @@ static inline void conv_dw_int8_3x3s1(const int8_t* input, int8_t* kernel, const
     float32x4_t f1 = vdupq_n_f32(1);
     float32x4_t f6 = vdupq_n_f32(6);
     float32x4_t f_1 = vdupq_n_f32(-1);
+    uint32x4_t u_1 = vmovq_n_u32(1);
 
     for (; i + 1 < outh; i += 2)
     {
@@ -238,8 +240,13 @@ static inline void conv_dw_int8_3x3s1(const int8_t* input, int8_t* kernel, const
 
             sum0_f = vmulq_f32(sum0_f, outscale);
             sum0_1_f = vmulq_f32(sum0_1_f, outscale);
-            sum0_f = vaddq_f32(sum0_f, f_0_5);
-            sum0_1_f = vaddq_f32(sum0_1_f, f_0_5);
+
+            sum0_f = vaddq_f32(sum0_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f0, sum0_f), u_1)), 0.5));
+            sum0_1_f = vaddq_f32(sum0_1_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f0, sum0_1_f), u_1)), 0.5));
+            sum0_f = vaddq_f32(sum0_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f0, sum0_f), u_1)), -0.5));
+            sum0_1_f = vaddq_f32(sum0_1_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f0, sum0_1_f), u_1)), -0.5));
+            //sum0_f = vaddq_f32(sum0_f, f_0_5);
+            //sum0_1_f = vaddq_f32(sum0_1_f, f_0_5);
 
             float32x4_t sum2_f = vcvtq_f32_s32(sum2);
             float32x4_t sum2_1_f = vcvtq_f32_s32(sum2_1);
@@ -277,9 +284,11 @@ static inline void conv_dw_int8_3x3s1(const int8_t* input, int8_t* kernel, const
             }
             sum2_f = vmulq_f32(sum2_f, outscale);
             sum2_1_f = vmulq_f32(sum2_1_f, outscale);
-
-            sum2_f = vaddq_f32(sum2_f, f_0_5);
-            sum2_1_f = vaddq_f32(sum2_1_f, f_0_5);
+            /* round */
+            sum2_f = vaddq_f32(sum2_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f0, sum2_f), u_1)), 0.5));
+            sum2_1_f = vaddq_f32(sum2_1_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f0, sum2_1_f), u_1)), 0.5));
+            sum2_f = vaddq_f32(sum2_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f0, sum2_f), u_1)), -0.5));
+            sum2_1_f = vaddq_f32(sum2_1_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f0, sum2_1_f), u_1)), -0.5));
 
             sum2 = vcvtq_s32_f32(sum2_f);
             sum2_1 = vcvtq_s32_f32(sum2_1_f);
@@ -425,10 +434,13 @@ static inline void conv_dw_int8_3x3s1(const int8_t* input, int8_t* kernel, const
                     sum0_1_f = vminq_f32(sum0_1_f, f6);
                 }
             }
+            /* round */
             sum0_f = vmulq_f32(sum0_f, outscale);
             sum0_1_f = vmulq_f32(sum0_1_f, outscale);
-            sum0_f = vaddq_f32(sum0_f, f_0_5);
-            sum0_1_f = vaddq_f32(sum0_1_f, f_0_5);
+            sum0_f = vaddq_f32(sum0_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f0, sum0_f), u_1)), 0.5));
+            sum0_1_f = vaddq_f32(sum0_1_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f0, sum0_1_f), u_1)), 0.5));
+            sum0_f = vaddq_f32(sum0_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f0, sum0_f), u_1)), -0.5));
+            sum0_1_f = vaddq_f32(sum0_1_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f0, sum0_1_f), u_1)), -0.5));
 
             sum0 = vcvtq_s32_f32(sum0_f);
             sum0_1 = vcvtq_s32_f32(sum0_1_f);
@@ -507,6 +519,7 @@ static inline void conv_dw_int8_3x3s2(const int8_t* input, const int8_t* kernel,
     float32x4_t f1 = vdupq_n_f32(1);
     float32x4_t f6 = vdupq_n_f32(6);
     float32x4_t f_1 = vdupq_n_f32(-1);
+    uint32x4_t u_1 = vmovq_n_u32(1);
 
     for (; i < outh; ++i)
     {
@@ -593,8 +606,11 @@ static inline void conv_dw_int8_3x3s2(const int8_t* input, const int8_t* kernel,
 
             sum0_f = vmulq_f32(sum0_f, outscale);
             sum0_1_f = vmulq_f32(sum0_1_f, outscale);
-            sum0_f = vaddq_f32(sum0_f, f_0_5);
-            sum0_1_f = vaddq_f32(sum0_1_f, f_0_5);
+            /* round */
+            sum0_f = vaddq_f32(sum0_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f0, sum0_f), u_1)), 0.5));
+            sum0_1_f = vaddq_f32(sum0_1_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f0, sum0_1_f), u_1)), 0.5));
+            sum0_f = vaddq_f32(sum0_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f0, sum0_f), u_1)), -0.5));
+            sum0_1_f = vaddq_f32(sum0_1_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f0, sum0_1_f), u_1)), -0.5));
 
             sum0 = vcvtq_s32_f32(sum0_f);
             sum0_1 = vcvtq_s32_f32(sum0_1_f);
diff --git a/src/dev/cpu/op/conv/cortex_a/conv_dw_kernel_int8_arm.h b/src/dev/cpu/op/conv/cortex_a/conv_dw_kernel_int8_arm.h
index fc4ac308a80a8861e662489652858a1f1b38ee48..430bb38b0392bb4a61c3f2d80b2587bbc6e90e9a 100644
--- a/src/dev/cpu/op/conv/cortex_a/conv_dw_kernel_int8_arm.h
+++ b/src/dev/cpu/op/conv/cortex_a/conv_dw_kernel_int8_arm.h
@@ -1,6 +1,6 @@
-//
-// Created by Hebing Shi on 2020/11/4.
-//
+/*
+ * Author: 1091545398@qq.com
+ */
 
 #ifndef TENGINLITE_CONV_DW_KERNEL_INT8_ARM_H
 #define TENGINLITE_CONV_DW_KERNEL_INT8_ARM_H
diff --git a/src/dev/cpu/op/conv/cortex_a/conv_kernel_int8_arm.c b/src/dev/cpu/op/conv/cortex_a/conv_kernel_int8_arm.c
index 64270af4481787013dcefb8b596df80d3dcce197..34088f5a403a9f26ca5176e7bb876be56f9614b4 100644
--- a/src/dev/cpu/op/conv/cortex_a/conv_kernel_int8_arm.c
+++ b/src/dev/cpu/op/conv/cortex_a/conv_kernel_int8_arm.c
@@ -1,7 +1,6 @@
-//
-// Created by Hebing Shi on 2020/11/9.
-//
-
+/*
+ * Author: 1091545398@qq.com
+ */
 #include "conv_kernel_arm.h"
 #include
 #define PER_OUT_CHAN 8
@@ -264,22 +263,77 @@ static inline void sgemm_8x8(int32_t* biases, int8_t* input, int8_t* kernel, lon
     float32x4_t f_0_5 = vdupq_n_f32(0.5);
     int32x4_t d127 = vdupq_n_s32(127);
     int32x4_t d_127 = vdupq_n_s32(-127);
-    out0 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out0_f, outputscales), f_0_5));
-    out8 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out8_f, outputscales), f_0_5));
-    out1 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out1_f, outputscales), f_0_5));
-    out9 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out9_f, outputscales), f_0_5));
-    out2 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out2_f, outputscales), f_0_5));
-    out10 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out10_f, outputscales), f_0_5));
-    out3 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out3_f, outputscales), f_0_5));
-    out11 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out11_f, outputscales), f_0_5));
-    out4 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out4_f, outputscales), f_0_5));
-    out12 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out12_f, outputscales), f_0_5));
-    out5 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out5_f, outputscales), f_0_5));
-    out13 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out13_f, outputscales), f_0_5));
-    out6 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out6_f, outputscales), f_0_5));
-    out14 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out14_f, outputscales), f_0_5));
-    out7 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out7_f, outputscales), f_0_5));
-    out15 = vcvtq_s32_f32(vaddq_f32(vmulq_n_f32(out15_f, outputscales), f_0_5));
+    uint32x4_t u_1 = vmovq_n_u32(1);
+
+    out0_f = vmulq_n_f32(out0_f, outputscales);
+    out1_f = vmulq_n_f32(out1_f, outputscales);
+    out2_f = vmulq_n_f32(out2_f, outputscales);
+    out3_f = vmulq_n_f32(out3_f, outputscales);
+    out4_f = vmulq_n_f32(out4_f, outputscales);
+    out5_f = vmulq_n_f32(out5_f, outputscales);
+    out6_f = vmulq_n_f32(out6_f, outputscales);
+    out7_f = vmulq_n_f32(out7_f, outputscales);
+    out8_f = vmulq_n_f32(out8_f, outputscales);
+    out9_f = vmulq_n_f32(out9_f, outputscales);
+    out10_f = vmulq_n_f32(out10_f, outputscales);
+    out11_f = vmulq_n_f32(out11_f, outputscales);
+    out12_f = vmulq_n_f32(out12_f, outputscales);
+    out13_f = vmulq_n_f32(out13_f, outputscales);
+    out14_f = vmulq_n_f32(out14_f, outputscales);
+    out15_f = vmulq_n_f32(out15_f, outputscales);
+
+    /* round pos: add +0.5 to non-negative lanes */
+    out0_f = vaddq_f32(out0_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out0_f), u_1)), 0.5));
+    out1_f = vaddq_f32(out1_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out1_f), u_1)), 0.5));
+    out2_f = vaddq_f32(out2_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out2_f), u_1)), 0.5));
+    out3_f = vaddq_f32(out3_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out3_f), u_1)), 0.5));
+    out4_f = vaddq_f32(out4_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out4_f), u_1)), 0.5));
+    out5_f = vaddq_f32(out5_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out5_f), u_1)), 0.5));
+    out6_f = vaddq_f32(out6_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out6_f), u_1)), 0.5));
+    out7_f = vaddq_f32(out7_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out7_f), u_1)), 0.5));
+    out8_f = vaddq_f32(out8_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out8_f), u_1)), 0.5));
+    out9_f = vaddq_f32(out9_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out9_f), u_1)), 0.5));
+    out10_f = vaddq_f32(out10_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out10_f), u_1)), 0.5));
+    out11_f = vaddq_f32(out11_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out11_f), u_1)), 0.5));
+    out12_f = vaddq_f32(out12_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out12_f), u_1)), 0.5));
+    out13_f = vaddq_f32(out13_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out13_f), u_1)), 0.5));
+    out14_f = vaddq_f32(out14_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out14_f), u_1)), 0.5));
+    out15_f = vaddq_f32(out15_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcgtq_f32(f_0, out15_f), u_1)), 0.5));
+
+    /* round neg: add -0.5 to non-positive lanes */
+    out0_f = vaddq_f32(out0_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out0_f), u_1)), -0.5));
+    out1_f = vaddq_f32(out1_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out1_f), u_1)), -0.5));
+    out2_f = vaddq_f32(out2_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out2_f), u_1)), -0.5));
+    out3_f = vaddq_f32(out3_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out3_f), u_1)), -0.5));
+    out4_f = vaddq_f32(out4_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out4_f), u_1)), -0.5));
+    out5_f = vaddq_f32(out5_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out5_f), u_1)), -0.5));
+    out6_f = vaddq_f32(out6_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out6_f), u_1)), -0.5));
+    out7_f = vaddq_f32(out7_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out7_f), u_1)), -0.5));
+    out8_f = vaddq_f32(out8_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out8_f), u_1)), -0.5));
+    out9_f = vaddq_f32(out9_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out9_f), u_1)), -0.5));
+    out10_f = vaddq_f32(out10_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out10_f), u_1)), -0.5));
+    out11_f = vaddq_f32(out11_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out11_f), u_1)), -0.5));
+    out12_f = vaddq_f32(out12_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out12_f), u_1)), -0.5));
+    out13_f = vaddq_f32(out13_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out13_f), u_1)), -0.5));
+    out14_f = vaddq_f32(out14_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out14_f), u_1)), -0.5));
+    out15_f = vaddq_f32(out15_f, vmulq_n_f32(vcvtq_f32_u32(vaddq_u32(vcltq_f32(f_0, out15_f), u_1)), -0.5));
+
+    out0 = vcvtq_s32_f32(out0_f);
+    out1 = vcvtq_s32_f32(out1_f);
+    out2 = vcvtq_s32_f32(out2_f);
+    out3 = vcvtq_s32_f32(out3_f);
+    out4 = vcvtq_s32_f32(out4_f);
+    out5 = vcvtq_s32_f32(out5_f);
+    out6 = vcvtq_s32_f32(out6_f);
+    out7 = vcvtq_s32_f32(out7_f);
+    out8 = vcvtq_s32_f32(out8_f);
+    out9 = vcvtq_s32_f32(out9_f);
+    out10 = vcvtq_s32_f32(out10_f);
+    out11 = vcvtq_s32_f32(out11_f);
+    out12 = vcvtq_s32_f32(out12_f);
+    out13 = vcvtq_s32_f32(out13_f);
+    out14 = vcvtq_s32_f32(out14_f);
+    out15 = vcvtq_s32_f32(out15_f);
 
     out0 = vminq_s32(d127, vmaxq_s32(out0, d_127));
     out1 = vminq_s32(d127, vmaxq_s32(out1, d_127));
diff --git a/src/dev/cpu/op/fc/cortex_a/fc_kernel_int8_arm.c b/src/dev/cpu/op/fc/cortex_a/fc_kernel_int8_arm.c
index d95b333d63780d093f6c37e7c76f0215bc939409..2fc6d9a86a129781df9ac8a29b45930971784a09 100644
--- a/src/dev/cpu/op/fc/cortex_a/fc_kernel_int8_arm.c
+++ b/src/dev/cpu/op/fc/cortex_a/fc_kernel_int8_arm.c
@@ -17,6 +17,10 @@
  * under the License.
  */
 
+/*
+ * Author: 1091545398@qq.com
+ */
+
 #include
 #include
 #include
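
Every hunk above makes the same change: the old code rounded by adding an unconditional +0.5 before the truncating vcvtq_s32_f32 conversion, which is only correct for non-negative values (for example, -2.4f + 0.5f = -1.9f truncates to -1 instead of rounding to -2). The new compare-and-add sequence instead rounds half away from zero for both signs: vcgtq_f32(0, x) yields an all-ones lane (0xFFFFFFFF) for negative x, so adding u_1 wraps it to 0 and the +0.5 bias is applied only to non-negative lanes; vcltq_f32 does the mirror-image for the -0.5 bias. Below is a scalar sketch of what one lane computes; the helper name and main() harness are illustrative, not part of the patch.

#include <stdio.h>

/* Scalar model of the per-lane NEON rounding sequence (illustrative only). */
static int round_away_from_zero(float x)
{
    /* vcgtq_f32(0, x) + 1 wraps to 0 for x < 0, is 1 otherwise:
       add +0.5 only to non-negative values */
    x += (x >= 0.0f ? 1.0f : 0.0f) * 0.5f;
    /* vcltq_f32(0, x) + 1 wraps to 0 for x > 0, is 1 otherwise:
       add -0.5 only to non-positive values */
    x += (x <= 0.0f ? 1.0f : 0.0f) * -0.5f;
    return (int)x; /* truncation toward zero, like vcvtq_s32_f32 */
}

int main(void)
{
    /* old scheme: (int)(-2.4f + 0.5f) == -1; new scheme gives -2 */
    printf("%d %d %d\n",
           round_away_from_zero(2.5f),   /* 3  */
           round_away_from_zero(-2.4f),  /* -2 */
           round_away_from_zero(-2.5f)); /* -3 */
    return 0;
}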