Commit 56381f80 authored by Megvii Engine Team, committed by Xinran Xu

fix(dnn/arm): use vcvtq_f32_s32 for all arm code

GitOrigin-RevId: 27effe7d2402b08559ac59bfab202251d1082efc
Parent 11732057
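
This commit drops the ARMv7-only fallback of TypeCvtOp<dt_qint32, dt_qint8> (the #else branch built on FixupBase and vqrdmulhq_s32, deleted in the second hunk below), so that all ARM builds requantize through the float-based path that starts from the vcvtq_f32_s32 intrinsic named in the title. As a rough illustration of that float path only: the sketch below is not the MegEngine QConverter code; the helper name and the +/-0.5 rounding trick are assumptions, chosen because they work on both ARMv7 and ARMv8.

#include <arm_neon.h>

// Hypothetical sketch of float-path requantization: int32 accumulator -> qint8.
// Not the MegEngine implementation; illustrative names only.
static inline int8x8_t requant_via_float(int32x4_t acc, float scale) {
    // vcvtq_f32_s32 exists on ARMv7 and ARMv8 alike, unlike the ARMv8-only
    // round-to-nearest convert vcvtaq_s32_f32.
    float32x4_t f = vmulq_n_f32(vcvtq_f32_s32(acc), scale);
    // Round half away from zero: add +/-0.5 depending on sign, then truncate.
    uint32x4_t is_neg = vcltq_f32(f, vdupq_n_f32(0.f));
    float32x4_t half = vbslq_f32(is_neg, vdupq_n_f32(-0.5f), vdupq_n_f32(0.5f));
    int32x4_t rounded = vcvtq_s32_f32(vaddq_f32(f, half));
    // Saturating narrow: int32 -> int16 -> int8 (low half duplicated).
    int16x4_t s16 = vqmovn_s32(rounded);
    return vqmovn_s16(vcombine_s16(s16, s16));
}
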
@@ -19,7 +19,6 @@ namespace arm_common {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct TypeCvtOp;
#if __ARM_ARCH >= 8
template <>
struct TypeCvtOp<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
@@ -55,59 +54,6 @@ struct TypeCvtOp<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
return QConverter::convert<int8x8_t, float32x4_t>(vitem0);
}
};
#else
template <>
struct TypeCvtOp<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8>,
FixupBase {
constexpr static size_t SIMD_WIDTH = 4;
TypeCvtOp(DType src_dtype, DType dst_dtype)
: UnaryOpBase(src_dtype, dst_dtype), FixupBase(scale) {}
TypeCvtOp(float src_scale, float dst_scale)
: UnaryOpBase(src_scale, dst_scale), FixupBase(scale) {}
void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const {
vst1_s8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));
}
void operator()(const int32x4_t& vsrc, dt_qint8* dst) const {
vst1_lane_s32(reinterpret_cast<int32_t*>(dst),
(int32x2_t)(operator()(vsrc)), 0);
}
dt_qint8 operator()(const dt_qint32& src) const {
float fsrc = src.as_int32() * this->scale;
return QConverter::convert<dt_qint8, float>(fsrc);
}
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
int8x8_t operator()(const int32x4x2_t& vsrc) const {
int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier);
int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier);
auto fixup0 = vshrq_n_s32(vitem0, 31);
auto fixup1 = vshrq_n_s32(vitem1, 31);
// FIXME Theoretically, we should check shift != 0 here.
vitem0 = vqaddq_s32(vitem0, fixup0);
vitem1 = vqaddq_s32(vitem1, fixup1);
return vqmovn_s16(vcombine_s16(vqmovn_s32(vrshlq_s32(vitem0, vshift)),
vqmovn_s32(vrshlq_s32(vitem1, vshift))));
}
int8x8_t operator()(const int32x4_t& src) const {
int32x4_t vitem0 = vqrdmulhq_s32(src, vmultiplier);
auto fixup0 = vshrq_n_s32(vitem0, 31);
vitem0 = vqaddq_s32(vitem0, fixup0);
int16x4_t vres0_int16 = vqmovn_s32(vrshlq_s32(vitem0, vshift));
return vqmovn_s16(vcombine_s16(vres0_int16, vres0_int16));
}
int8x8_t operator()(const float32x4_t& src) const {
int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(src), vmultiplier);
auto fixup0 = vshrq_n_s32(vitem0, 31);
vitem0 = vqaddq_s32(vitem0, fixup0);
int16x4_t vres0_int16 = vqmovn_s32(vrshlq_s32(vitem0, vshift));
return vqmovn_s16(vcombine_s16(vres0_int16, vres0_int16));
}
};
#endif
template <>
struct TypeCvtOp<dt_qint32, dt_quint8> : UnaryOpBase<dt_qint32, dt_quint8> {
......
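
For reference, the #else branch removed above did the same requantization in fixed point: FixupBase precomputes an integer multiplier and a (negative) shift from the float scale, and the vector operators then combine vqrdmulhq_s32 with a sign fixup and a rounding shift. A condensed sketch of that pattern follows; vmultiplier and vshift are stand-ins for the FixupBase members, and this is not the exact removed code.

#include <arm_neon.h>

// Condensed sketch of the removed fixed-point requantization pattern.
// vmultiplier/vshift stand in for the precomputed FixupBase members;
// vshift holds negative lanes, so vrshlq_s32 performs a rounding right shift.
static inline int16x4_t requant_fixed_point(int32x4_t acc, int32x4_t vmultiplier,
                                            int32x4_t vshift) {
    int32x4_t v = vqrdmulhq_s32(acc, vmultiplier);  // saturating rounding doubling high multiply
    int32x4_t fixup = vshrq_n_s32(v, 31);           // -1 for negative lanes, 0 otherwise
    v = vqaddq_s32(v, fixup);                       // nudge negatives so the shift rounds half away from zero
    return vqmovn_s32(vrshlq_s32(v, vshift));       // rounding shift, saturating narrow to int16
}
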
@@ -144,7 +144,7 @@ inline bool nchw_nchwxx_valid<NCHW88>(
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode) {
const param::ConvBias::NonlineMode ) {
bool ok_type = ((src_dtype == DTypeEnum::Float32 &&
filter_dtype == DTypeEnum::Float32 &&
(dst_dtype == DTypeEnum::Float32))) &&
......
@@ -32,7 +32,6 @@
#include "./helper.h"
#include "megbrain/comp_node_env.h"
#include "cpuinfo.h"
#include "megdnn/tensor_format.h"
#include <random>
......