/** * \file dnn/src/arm_common/elemwise_helper/kimpl/none.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include "src/arm_common/elemwise_helper/kimpl/op_base.h" namespace megdnn { namespace arm_common { template struct NoneOpBase : UnaryOpBase { using UnaryOpBase::UnaryOpBase; dst_ctype operator()(const src_ctype& src) const { return src; } }; template struct NoneOp; #define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ template <> \ struct NoneOp<_ctype> : NoneOpBase<_ctype> { \ using NoneOpBase::NoneOpBase; \ using NoneOpBase::operator(); \ constexpr static size_t SIMD_WIDTH = _simd_width; \ _neon_type2 operator()(const _neon_type2& src) const { return src; } \ _neon_type operator()(const _neon_type& src) const { return src; } \ }; OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) #endif OP(dt_int32, int32x4_t, int32x4x2_t, s32, 4) OP(dt_int16, int16x8_t, int16x8x2_t, s16, 8) OP(dt_int8, int8x16_t, int8x16x2_t, s8, 16) #undef OP template <> struct NoneOpBase : UnaryOpBase { using UnaryOpBase::UnaryOpBase; void operator()(const dt_qint8& src, dt_qint8* dst) const { *dst = src; } }; template <> struct NoneOpBase : UnaryOpBase { using UnaryOpBase::UnaryOpBase; void operator()(const dt_quint8& src, dt_quint8* dst) const { *dst = src; } }; template <> struct NoneOpBase : UnaryOpBase { using UnaryOpBase::UnaryOpBase; void operator()(const dt_qint32& src, dt_qint8* dst) const { *(reinterpret_cast(dst)) = src; } }; template <> struct NoneOp : NoneOpBase { using NoneOpBase::NoneOpBase; using NoneOpBase::operator(); constexpr static size_t SIMD_WIDTH = 4; void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { vst1q_s32(reinterpret_cast(dst), vsrc.val[0]); vst1q_s32(reinterpret_cast(dst + 16), vsrc.val[1]); } void operator()(const int32x4_t& src, dt_qint8* dst) const { vst1q_s32(reinterpret_cast(dst), src); } }; } // namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen