Commit ead611e1 authored by Megvii Engine Team

perf(dnn): slightly improve arm neon transcendental function performance

GitOrigin-RevId: 210d88f81e23efd104ff32ddb57c06b39d0e3e03
Parent 0d169524
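The pattern of the change is the same everywhere: a `vmulq_f32` followed by a dependent `vaddq_f32` in a Horner-style polynomial evaluation is collapsed into one fused multiply-add (`vfmaq_f32` where `__ARM_FEATURE_FMA` is available, `vmlaq_f32` otherwise). A minimal before/after sketch of that pattern, with made-up coefficients that are not from the MegEngine source:

```c
#include <arm_neon.h>

/* Before: evaluate 1.5*x^2 - 2*x + 0.25 with separate multiply and add.
 * Each vaddq_f32 depends on the preceding vmulq_f32, so the chain costs
 * two instructions (and two rounding steps) per coefficient. */
static inline float32x4_t poly3_mul_add(float32x4_t x) {
    float32x4_t y = vdupq_n_f32(1.5f);     /* c0 (illustrative) */
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, vdupq_n_f32(-2.0f));  /* c1 */
    y = vmulq_f32(y, x);
    y = vaddq_f32(y, vdupq_n_f32(0.25f));  /* c2 */
    return y;
}

/* After: one vfmaq_f32 per coefficient. vfmaq_f32(a, b, c) computes
 * a + b*c with a single rounding, halving the dependent-op chain. */
static inline float32x4_t poly3_fma(float32x4_t x) {
    float32x4_t y = vdupq_n_f32(1.5f);        /* c0 */
    y = vfmaq_f32(vdupq_n_f32(-2.0f), y, x);  /* c1 + y*x */
    y = vfmaq_f32(vdupq_n_f32(0.25f), y, x);  /* c2 + y*x */
    return y;
}
```

Besides saving an instruction per coefficient, the fused form rounds once instead of twice, so it is typically at least as accurate as the two-step version it replaces.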
@@ -86,11 +86,11 @@ v4sf log_ps_f32(v4sf x) {
     e = vaddq_f32(e, one);
 
     /* part2:
-       if( x < SQRTHF ) {
-         e -= 1;
-         x = x + x - 1.0;
-       } else { x = x - 1.0; }
+     * if( x < SQRTHF ) {
+     *   e -= 1;
+     *   x = x + x - 1.0;
+     * } else { x = x - 1.0; }
     */
     v4su mask = vcltq_f32(x, vdupq_n_f32(c_cephes_SQRTHF));
     v4sf tmp = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
     x = vsubq_f32(x, one);
@@ -101,38 +101,26 @@ v4sf log_ps_f32(v4sf x) {
     v4sf z = vmulq_f32(x, x);
 
     v4sf y = vdupq_n_f32(c_cephes_log_p0);
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8));
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p1), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p2), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p3), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p4), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p5), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p6), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p7), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p8), y, x);
     y = vmulq_f32(y, x);
     y = vmulq_f32(y, z);
 
-    tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1));
-    y = vaddq_f32(y, tmp);
+    y = fma_ps_f32(y, e, vdupq_n_f32(c_cephes_log_q1));
 
-    tmp = vmulq_f32(z, vdupq_n_f32(0.5f));
-    y = vsubq_f32(y, tmp);
+    y = vmlsq_f32(y, z, vdupq_n_f32(0.5f));
 
-    tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2));
     x = vaddq_f32(x, y);
-    x = vaddq_f32(x, tmp);
+    x = fma_ps_f32(x, e, vdupq_n_f32(c_cephes_log_q2));
     x = vreinterpretq_f32_u32(vorrq_u32(
-            vreinterpretq_u32_f32(x),
-            invalid_mask));  // negative arg will be NAN
+            vreinterpretq_u32_f32(x), invalid_mask));  // negative arg will be NAN
     return x;
 }
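Reading the rewritten chain off the diff: `y` starts at `c_cephes_log_p0`, each `fma_ps_f32(p_k, y, x)` step computes p_k + y·x, and the two trailing multiplies by `x` and `z = x²` scale the result, so the polynomial term is evaluated in pure Horner form with one fused instruction per coefficient:

$$ y = \Big(\big(\cdots(p_0\,x + p_1)\,x + p_2\big)\,x + \cdots + p_8\Big)\cdot x \cdot x^2 $$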
@@ -159,7 +147,7 @@ v4sf exp_ps_f32(v4sf x) {
     x = vmaxq_f32(x, vdupq_n_f32(c_exp_lo));
 
     /* express exp(x) as exp(g + n*log(2)) */
-    fx = vmlaq_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF));
+    fx = fma_ps_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF));
 
     /* perform a floorf */
     tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx));
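For context, this is the standard cephes range reduction: `c_cephes_LOG2EF` is log2(e), so `fx` holds x/ln 2 biased by 0.5, and the float-to-int-to-float round trip (together with the sign fix-up elided from this hunk) yields the nearest integer n:

$$ e^{x} = 2^{n}\,e^{g}, \qquad n = \left\lfloor x\log_2 e + \tfrac{1}{2} \right\rfloor, \qquad g = x - n\ln 2,\quad |g| \lesssim \tfrac{\ln 2}{2} $$

The polynomial below then only has to approximate exp on the small interval around zero where g lives.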
@@ -175,34 +163,20 @@ v4sf exp_ps_f32(v4sf x) {
     x = vsubq_f32(x, tmp);
     x = vsubq_f32(x, z);
 
-    static const float cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1,
-                                          c_cephes_exp_p2, c_cephes_exp_p3,
-                                          c_cephes_exp_p4, c_cephes_exp_p5};
-    v4sf y = vld1q_dup_f32(cephes_exp_p + 0);
-    v4sf c1 = vld1q_dup_f32(cephes_exp_p + 1);
-    v4sf c2 = vld1q_dup_f32(cephes_exp_p + 2);
-    v4sf c3 = vld1q_dup_f32(cephes_exp_p + 3);
-    v4sf c4 = vld1q_dup_f32(cephes_exp_p + 4);
-    v4sf c5 = vld1q_dup_f32(cephes_exp_p + 5);
-
-    y = vmulq_f32(y, x);
     z = vmulq_f32(x, x);
-    y = vaddq_f32(y, c1);
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, c2);
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, c3);
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, c4);
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, c5);
 
-    y = vmulq_f32(y, z);
-    y = vaddq_f32(y, x);
+    v4sf y = vdupq_n_f32(c_cephes_exp_p0);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p1), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p2), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p3), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p4), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p5), y, x);
+    y = fma_ps_f32(x, y, z);
     y = vaddq_f32(y, one);
 
     /* build 2^n */
-    int32x4_t mm;
+    v4si mm;
     mm = vcvtq_s32_f32(fx);
     mm = vaddq_s32(mm, vdupq_n_s32(0x7f));
     mm = vshlq_n_s32(mm, 23);
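The `build 2^n` tail works because of the IEEE-754 single-precision layout: an exponent field of n + 127 with a zero mantissa encodes exactly 2^n. A scalar sketch of the same bit trick (the helper name `pow2_int` is hypothetical):

```c
#include <stdint.h>
#include <string.h>

/* Build 2^n by writing the biased exponent (n + 127) into bits 23..30
 * of an IEEE-754 float; valid while n + 127 stays in [1, 254], i.e.
 * the result is a normal number. This is the scalar analogue of the
 * vaddq_s32 / vshlq_n_s32 / reinterpret sequence above. */
static float pow2_int(int n) {
    uint32_t bits = (uint32_t)(n + 0x7f) << 23; /* biased exponent, zero mantissa */
    float r;
    memcpy(&r, &bits, sizeof r);                /* bit-cast without aliasing UB */
    return r;
}
```

Multiplying the polynomial result by this value applies the 2^n factor from the range reduction without ever calling a scalar `exp`.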
@@ -249,8 +223,9 @@ float16x8_t exp_ps_f16(float16x8_t x) {
    almost no extra price so both sin_ps_f32 and cos_ps_f32 make use of
    sincos_ps_f32..
 */
-void sincos_ps_f32(v4sf x, v4sf* ysin, v4sf* ycos) {  // any x
-    v4sf xmm1, xmm2, xmm3, y;
+void sincos_ps_f32(v4sf x, v4sf* ysin, v4sf* ycos) {
+    // any x
+    v4sf y;
     v4su emm2;
@@ -269,44 +244,36 @@ void sincos_ps_f32(v4sf x, v4sf* ysin, v4sf* ycos) { // any x
     y = vcvtq_f32_u32(emm2);
 
     /* get the polynom selection mask
-       there is one polynom for 0 <= x <= Pi/4
-       and another one for Pi/4<x<=Pi/2
-
-       Both branches will be computed.
+     * there is one polynom for 0 <= x <= Pi/4
+     * and another one for Pi/4<x<=Pi/2
+     *
+     * Both branches will be computed.
     */
     v4su poly_mask = vtstq_u32(emm2, vdupq_n_u32(2));
 
     /* The magic pass: "Extended precision modular arithmetic"
-       x = ((x - y * DP1) - y * DP2) - y * DP3; */
-    xmm1 = vmulq_n_f32(y, c_minus_cephes_DP1);
-    xmm2 = vmulq_n_f32(y, c_minus_cephes_DP2);
-    xmm3 = vmulq_n_f32(y, c_minus_cephes_DP3);
-    x = vaddq_f32(x, xmm1);
-    x = vaddq_f32(x, xmm2);
-    x = vaddq_f32(x, xmm3);
+     * x = ((x - y * DP1) - y * DP2) - y * DP3; */
+    x = fma_ps_f32(x, y, vdupq_n_f32(c_minus_cephes_DP1));
+    x = fma_ps_f32(x, y, vdupq_n_f32(c_minus_cephes_DP2));
+    x = fma_ps_f32(x, y, vdupq_n_f32(c_minus_cephes_DP3));
 
     sign_mask_sin = veorq_u32(sign_mask_sin, vtstq_u32(emm2, vdupq_n_u32(4)));
     sign_mask_cos = vtstq_u32(vsubq_u32(emm2, vdupq_n_u32(2)), vdupq_n_u32(4));
 
     /* Evaluate the first polynom (0 <= x <= Pi/4) in y1,
-       and the second polynom (Pi/4 <= x <= 0) in y2 */
+     * and the second polynom (Pi/4 <= x <= 0) in y2 */
     v4sf z = vmulq_f32(x, x);
     v4sf y1, y2;
 
-    y1 = vmulq_n_f32(z, c_coscof_p0);
-    y2 = vmulq_n_f32(z, c_sincof_p0);
-    y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p1));
-    y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p1));
-    y1 = vmulq_f32(y1, z);
-    y2 = vmulq_f32(y2, z);
-    y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p2));
-    y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p2));
+    y1 = fma_ps_f32(vdupq_n_f32(c_coscof_p1), z, vdupq_n_f32(c_coscof_p0));
+    y2 = fma_ps_f32(vdupq_n_f32(c_sincof_p1), z, vdupq_n_f32(c_sincof_p0));
+    y1 = fma_ps_f32(vdupq_n_f32(c_coscof_p2), y1, z);
+    y2 = fma_ps_f32(vdupq_n_f32(c_sincof_p2), y2, z);
     y1 = vmulq_f32(y1, z);
     y2 = vmulq_f32(y2, z);
     y1 = vmulq_f32(y1, z);
-    y2 = vmulq_f32(y2, x);
-    y1 = vsubq_f32(y1, vmulq_f32(z, vdupq_n_f32(0.5f)));
-    y2 = vaddq_f32(y2, x);
+    y1 = vmlsq_f32(y1, z, vdupq_n_f32(0.5f));
+    y2 = fma_ps_f32(x, y2, x);
     y1 = vaddq_f32(y1, vdupq_n_f32(1));
 
     /* select the correct result from the two polynoms */
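The "magic pass" is Cody-Waite style range reduction: the octant multiple y must be subtracted as y·(π/4), but π/4 is split into three constants of decreasing magnitude (the `c_minus_cephes_DP1..3` constants carry the negative sign) so that each partial subtraction is nearly exact. With the FMA rewrite, each of the three steps now also incurs only a single rounding:

$$ x' = \big((x + y\,\mathrm{DP}_1) + y\,\mathrm{DP}_2\big) + y\,\mathrm{DP}_3, \qquad \mathrm{DP}_1 + \mathrm{DP}_2 + \mathrm{DP}_3 \approx -\tfrac{\pi}{4} $$

A single subtraction of y·(π/4) would lose most of the low-order bits of the reduced argument for large inputs; the three-piece split keeps them.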
@@ -407,20 +374,20 @@ v4sf sigmoid_ps_f32(v4sf src) {
     auto val = vmaxq_f32(vdupq_n_f32(sigmoid_constants.lower_range), src);
     val = vminq_f32(vdupq_n_f32(sigmoid_constants.upper_range), val);
     auto squared = vmulq_f32(val, val);
-    auto p = vmlaq_f32(
+    auto p = fma_ps_f32(
             vdupq_n_f32(sigmoid_constants.alpha_7), squared,
             vdupq_n_f32(sigmoid_constants.alpha_9));
-    p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_5), p, squared);
-    p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_3), p, squared);
-    p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_1), p, squared);
+    p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_5), p, squared);
+    p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_3), p, squared);
+    p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_1), p, squared);
     p = vmulq_f32(p, val);
-    auto q = vmlaq_f32(
+    auto q = fma_ps_f32(
             vdupq_n_f32(sigmoid_constants.beta_8), squared,
             vdupq_n_f32(sigmoid_constants.beta_10));
-    q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_6), q, squared);
-    q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_4), q, squared);
-    q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_2), q, squared);
-    q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_0), q, squared);
+    q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_6), q, squared);
+    q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_4), q, squared);
+    q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_2), q, squared);
+    q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_0), q, squared);
     return vaddq_f32(div_ps_f32(p, q), vdupq_n_f32(sigmoid_constants.one_half));
 }
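As the constant names suggest, sigmoid is computed here as a clamped rational approximation: an odd polynomial p in the numerator (built from `squared = x²` Horner steps, then multiplied by x) over an even polynomial q, shifted by one half into the (0, 1) range. Every Horner step in both polynomials is now a single fused instruction:

$$ \sigma(x) \approx \frac{p(x)}{q(x)} + \frac{1}{2}, \qquad p(x) = x\,(\alpha_1 + \alpha_3 x^2 + \alpha_5 x^4 + \alpha_7 x^6 + \alpha_9 x^8), \qquad q(x) = \beta_0 + \beta_2 x^2 + \cdots + \beta_{10} x^{10} $$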
...
@@ -54,7 +54,7 @@ v4sf cos_ps_f32(v4sf x);
 v4sf tan_ps_f32(v4sf x);
 
-static inline v4sf div_ps_f32(v4sf x, v4sf y) {
+static inline v4sf div_ps_f32(v4sf& x, v4sf& y) {
 #if MEGDNN_AARCH64
     return vdivq_f32(x, y);
 #else
@@ -65,6 +65,12 @@ static inline v4sf div_ps_f32(v4sf x, v4sf y) {
 #endif
 }
 
+#if defined(__ARM_FEATURE_FMA)
+#define fma_ps_f32(c, b, a) vfmaq_f32((c), (a), (b))
+#else
+#define fma_ps_f32(c, b, a) vmlaq_f32((c), (a), (b))
+#endif
+
 v4sf sigmoid_ps_f32(v4sf x);
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
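The argument order of the new macro is deliberate: `fma_ps_f32(c, b, a)` returns c + a·b, keeping the accumulator first exactly as in `vfmaq_f32` and the non-fused `vmlaq_f32` fallback, so call sites read "coefficient, running value, x". A small standalone sketch of a call site (the `horner_step` helper is hypothetical, not part of the header):

```c
#include <arm_neon.h>

#if defined(__ARM_FEATURE_FMA)
#define fma_ps_f32(c, b, a) vfmaq_f32((c), (a), (b)) /* fused c + a*b: one rounding */
#else
#define fma_ps_f32(c, b, a) vmlaq_f32((c), (a), (b)) /* plain c + a*b: mul then add */
#endif

/* One Horner step: new_acc = coeff + acc * x. */
static inline float32x4_t horner_step(float32x4_t coeff, float32x4_t acc, float32x4_t x) {
    return fma_ps_f32(coeff, acc, x);
}
```

On AArch64 (and ARMv7 with VFPv4) `__ARM_FEATURE_FMA` is defined and the fused path is taken; older targets transparently fall back to the multiply-accumulate pair with identical semantics up to rounding.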
@@ -73,7 +79,7 @@ v4sf sigmoid_ps_f32(v4sf x);
 */
 float16x8_t exp_ps_f16(float16x8_t x);
 
-static inline float16x8_t div_ps_f16(float16x8_t x, float16x8_t y) {
+static inline float16x8_t div_ps_f16(float16x8_t& x, float16x8_t& y) {
 #if MEGDNN_AARCH64
     return vdivq_f16(x, y);
 #else
...