Commit 3ef308e7 authored by Megvii Engine Team, committed by Xinran Xu

ci(copybara): fix copybara of arm

GitOrigin-RevId: 2aa113ef476aebad69db7d28a79e2da53eb1c360
Parent 86598767
......@@ -19,11 +19,14 @@ CHECK_CXX_COMPILER_FLAG(-Wclass-memaccess CXX_SUPPORT_WCLASS_MEMACCESS)
set(MGE_ARCH AUTO CACHE STRING "Architecture on which MegEngine to be built.")
set_property(CACHE MGE_ARCH PROPERTY STRINGS AUTO
x86_64 i386
armv7 aarch64
naive fallback
)
option(MGE_WITH_JIT "Build MegEngine with JIT." ON)
option(MGE_WITH_HALIDE "Build MegEngine with Halide JIT" ON)
option(MGE_ARMV8_2_FEATURE_FP16 "Enable armv8.2-a+fp16 support" OFF)
option(MGE_ARMV8_2_FEATURE_DOTPROD "Enable armv8.2-a+dotprod support" OFF)
option(MGE_DISABLE_FLOAT16 "Disable MegEngine float16 support." OFF)
option(MGE_WITH_CUDA "Enable MegEngine CUDA support." ON)
option(MGE_CUDA_USE_STATIC "Enable MegEngine CUDA static linking." ON)
......@@ -31,12 +34,52 @@ option(MGE_WITH_TRT "Build MegEngine with TensorRT." ON)
option(MGE_USE_SYSTEM_LIB "Build MegEngine with system libraries." OFF)
option(MGB_WITH_FLATBUFFERS "Build MegBrain with FlatBuffers serialization support." ON)
if(CMAKE_TOOLCHAIN_FILE)
message("We are cross compiling.")
message("config FLATBUFFERS_FLATC_EXECUTABLE to: ${PROJECT_SOURCE_DIR}/build_dir/host_flatc/install/bin/flatc")
set(FLATBUFFERS_FLATC_EXECUTABLE "${PROJECT_SOURCE_DIR}/build_dir/host_flatc/install/bin/flatc")
if(ANDROID_TOOLCHAIN_ROOT)
if(NOT "${ANDROID_ARCH_NAME}" STREQUAL "")
set(ANDROID_ARCH ${ANDROID_ARCH_NAME})
endif()
if(${ANDROID_ARCH} STREQUAL "arm")
set(MGE_ARCH "armv7")
elseif(${ANDROID_ARCH} STREQUAL "arm64")
set(MGE_ARCH "aarch64")
else()
message(FATAL_ERROR "DO NOT SUPPORT ANDROID ARCH NOW")
endif()
elseif(IOS_TOOLCHAIN_ROOT)
if(${IOS_ARCH} STREQUAL "armv7")
set(MGE_ARCH "armv7")
elseif(${IOS_ARCH} STREQUAL "arm64")
set(MGE_ARCH "aarch64")
elseif(${IOS_ARCH} STREQUAL "armv7k")
set(MGE_ARCH "armv7")
elseif(${IOS_ARCH} STREQUAL "arm64e")
set(MGE_ARCH "aarch64")
elseif(${IOS_ARCH} STREQUAL "armv7s")
set(MGE_ARCH "armv7")
else()
message(FATAL_ERROR "Unsupported IOS_ARCH.")
endif()
elseif(NOT "${ARM_CROSS_BUILD_ARCH}" STREQUAL "")
set(MGE_ARCH ${ARM_CROSS_BUILD_ARCH})
else()
message(FATAL_ERROR "Unknown cross-compiling settings.")
endif()
message("CONFIG MGE_ARCH TO ${MGE_ARCH}")
endif()
if(${MGE_ARCH} STREQUAL "AUTO")
if(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64")
set(MGE_ARCH "x86_64")
elseif(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "i386" OR ${CMAKE_SYSTEM_PROCESSOR} STREQUAL "i686")
set(MGE_ARCH "i386")
elseif(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "aarch64" OR ${CMAKE_SYSTEM_PROCESSOR} STREQUAL "arm64")
set(MGE_ARCH "aarch64")
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "^arm")
set(MGE_ARCH "armv7")
else()
message(FATAL "Unknown machine architecture for MegEngine.")
endif()
......@@ -399,6 +442,38 @@ if(MGE_ARCH STREQUAL "x86_64" OR MGE_ARCH STREQUAL "i386")
endif()
endif()
if(MGE_ARCH STREQUAL "armv7")
# -funsafe-math-optimizations is needed to enable NEON auto-vectorization (NEON is not fully IEEE 754 compliant, so GCC does not enable it by default)
if(ANDROID)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfloat-abi=softfp -mfpu=neon")
endif()
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -funsafe-math-optimizations")
set (MARCH "-march=armv7-a")
set (MEGDNN_ARMV7 1)
endif()
if(MGE_ARCH STREQUAL "aarch64")
set(MEGDNN_AARCH64 1)
set(MEGDNN_64_BIT 1)
set(MARCH "-march=armv8-a")
if(MGE_ARMV8_2_FEATURE_FP16)
message("Enable fp16 feature support in armv8.2")
if(NOT ${MGE_DISABLE_FLOAT16})
set(MEGDNN_ENABLE_FP16_NEON 1)
endif()
set(MARCH "-march=armv8.2-a+fp16")
endif()
if(MGE_ARMV8_2_FEATURE_DOTPROD)
message("Enable dotprod feature support in armv8.2")
if(MGE_ARMV8_2_FEATURE_FP16)
set(MARCH "-march=armv8.2-a+fp16+dotprod")
else()
set(MARCH "-march=armv8.2-a+dotprod")
endif()
endif()
endif()
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MARCH}")
......
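For reference, the -march variants selected above are what later enable the __ARM_FEATURE_FP16_VECTOR_ARITHMETIC and __ARM_FEATURE_DOTPROD guards used in the sources below. A minimal standalone check (illustrative only, not part of this commit) that reports which of the optional features the chosen flags turned on:

// Illustrative only: compile with the -march flags picked by the build above.
#include <cstdio>

int main() {
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
    std::puts("fp16 vector arithmetic enabled (armv8.2-a+fp16)");
#else
    std::puts("fp16 vector arithmetic not enabled");
#endif
#if defined(__ARM_FEATURE_DOTPROD)
    std::puts("dot product instructions enabled (armv8.2-a+dotprod)");
#else
    std::puts("dot product instructions not enabled");
#endif
    return 0;
}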
......@@ -29,6 +29,9 @@ class Handle {
NAIVE = 0,
FALLBACK = 1,
X86 = 2,
ARM_COMMON = 3,
ARMV7 = 4,
AARCH64 = 5,
CUDA = 6,
};
......
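The three new HandleType values register the ARM backends alongside the existing ones. A tiny hypothetical helper (it re-declares a scoped copy of the enum and is not part of this commit) that maps the values to readable names, e.g. for logging:

// Hypothetical helper, not in MegEngine: re-declares the values above for illustration.
#include <cstdio>

enum class HandleType { NAIVE = 0, FALLBACK = 1, X86 = 2,
                        ARM_COMMON = 3, ARMV7 = 4, AARCH64 = 5, CUDA = 6 };

const char* handle_type_name(HandleType t) {
    switch (t) {
        case HandleType::NAIVE:      return "NAIVE";
        case HandleType::FALLBACK:   return "FALLBACK";
        case HandleType::X86:        return "X86";
        case HandleType::ARM_COMMON: return "ARM_COMMON";
        case HandleType::ARMV7:      return "ARMV7";
        case HandleType::AARCH64:    return "AARCH64";
        case HandleType::CUDA:       return "CUDA";
    }
    return "UNKNOWN";
}

int main() {
    std::printf("%s\n", handle_type_name(HandleType::AARCH64));  // prints AARCH64
    return 0;
}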
......@@ -17,6 +17,22 @@ if(NOT ${MGE_ARCH} STREQUAL "naive")
set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C)
list(APPEND SOURCES ${SOURCES_})
endif()
elseif(${MGE_ARCH} STREQUAL "armv7")
file(GLOB_RECURSE SOURCES_ armv7/*.cpp)
list(APPEND SOURCES ${SOURCES_})
file(GLOB_RECURSE SOURCES_ arm_common/*.cpp)
list(APPEND SOURCES ${SOURCES_})
file(GLOB_RECURSE SOURCES_ armv7/*.S)
set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C)
list(APPEND SOURCES ${SOURCES_})
elseif(${MGE_ARCH} STREQUAL "aarch64")
file(GLOB_RECURSE SOURCES_ aarch64/*.cpp)
list(APPEND SOURCES ${SOURCES_})
file(GLOB_RECURSE SOURCES_ arm_common/*.cpp)
list(APPEND SOURCES ${SOURCES_})
file(GLOB_RECURSE SOURCES_ aarch64/*.S)
set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C)
list(APPEND SOURCES ${SOURCES_})
endif()
endif()
......
/**
* \file dnn/src/aarch64/conv_bias/fp16/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/conv_bias/fp16/algos.h"
#include "src/aarch64/conv_bias/fp16/stride2_kern.h"
#include "src/arm_common/conv_bias/direct/multi_thread_common.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"
using namespace megdnn;
using namespace aarch64;
#include "midout.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/* ===================== stride-2 algo ===================== */
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16)
bool ConvBiasImpl::AlgoF16DirectStride2::usable(
FallbackConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool available =
param.filter_meta.format == param::Convolution::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16 &&
!fm.should_flip && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return available;
}
MIDOUT_END();
return false;
}
size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 1) {
auto wbundle = arm_common::MultithreadDirectConvCommon<
dt_float16, __fp16>::get_bundle_stride(param, m_large_group);
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) {
return get_kimpls(param);
}
MIDOUT_END();
return {};
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride2::get_kimpls(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
auto FH = fm.spatial[0];
size_t N = param.n;
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
using Func = std::function<void(const __fp16*, const __fp16*, __fp16*,
size_t, size_t, size_t, size_t, size_t)>;
Func conv = nullptr;
if (FH == 2) {
conv = fp16::conv_stride2::do_conv_2x2_stride2;
} else if (FH == 3) {
conv = fp16::conv_stride2::do_conv_3x3_stride2;
} else if (FH == 5) {
conv = fp16::conv_stride2::do_conv_5x5_stride2;
} else if (FH == 7) {
conv = fp16::conv_stride2::do_conv_7x7_stride2;
}
WorkspaceBundle wbundle = arm_common::MultithreadDirectConvCommon<
dt_float16, __fp16>::get_bundle_stride(param, m_large_group);
SmallVector<NCBKern> ret_kerns;
if (m_large_group) {
//! Channel-wise conv and big groups: each kernel invocation handles one whole
//! group, padding every input channel and then convolving every output channel
auto exec_one_group = [wbundle, conv](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
auto fm = kern_param.filter_meta;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
WorkspaceBundle bundle = wbundle;
for (size_t ic = 0; ic < IC; ic++) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
}
for (size_t oc = 0; oc < OC; oc++) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
do_conv_kern_stride(bundle, kern_param, ncb_index, conv,
{ncb_index.thread_id, 0, oc});
}
};
ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else {
//! Dense conv and small groups: padding and conv are dispatched as separate
//! kernels, parallelized over (group, batch, channel)
WorkspaceBundle bundle = wbundle;
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
ncb_index.ndrange_id);
};
ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle, conv](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>::
do_conv_kern_stride(bundle, kern_param, ncb_index, conv,
ncb_index.ndrange_id);
};
ret_kerns.push_back({do_conv, {group, N, OC}});
}
return ret_kerns;
}
#endif
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/conv_bias/fp16/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/aarch64/conv_bias/opr_impl.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
namespace megdnn {
namespace aarch64 {
/* ===================== stride-2 algo ===================== */
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF16DirectStride2(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "ARMV8F16STRD2_LARGE_GROUP"
: "ARMV8F16STRD2_SMALL_GROUP";
}
bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override;
SmallVector<NCBKern> dispatch_kerns(FallbackConvBiasImpl*,
const NCBKernSizeParam&) const override;
};
} // namespace aarch64
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
This diff is collapsed.
/**
* \file dnn/src/aarch64/conv_bias/fp32/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/conv_bias/fp32/algos.h"
#include "src/aarch64/conv_bias/fp32/stride2_kern.h"
#include "src/arm_common/conv_bias/direct/multi_thread_common.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/fallback/conv_bias/common.h"
#include "midout.h"
using namespace megdnn;
using namespace aarch64;
MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32)
bool ConvBiasImpl::AlgoF32DirectStride2::usable(
FallbackConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
bool available =
param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
!fm.should_flip && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7);
if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
bool large_group = param.filter_meta.group >= param.nr_threads;
aviliable &= (large_group == m_large_group);
}
return available;
}
MIDOUT_END();
return false;
}
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 1) {
auto wbundle = arm_common::MultithreadDirectConvCommon<
float, float>::get_bundle_stride(param, m_large_group);
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) {
return get_kimpls(param);
}
MIDOUT_END();
return {};
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
auto FH = fm.spatial[0];
size_t N = param.n;
size_t IC = param.filter_meta.icpg;
size_t OC = param.filter_meta.ocpg;
size_t group = fm.group;
using Func = std::function<void(const float*, const float*, float*, size_t,
size_t, size_t, size_t, size_t)>;
Func conv = nullptr;
if (FH == 2) {
conv = fp32::conv_stride2::do_conv_2x2_stride2;
} else if (FH == 3) {
conv = fp32::conv_stride2::do_conv_3x3_stride2;
} else if (FH == 5) {
conv = fp32::conv_stride2::do_conv_5x5_stride2;
} else if (FH == 7) {
conv = fp32::conv_stride2::do_conv_7x7_stride2;
}
WorkspaceBundle wbundle = arm_common::MultithreadDirectConvCommon<
float, float>::get_bundle_stride(param, m_large_group);
SmallVector<NCBKern> ret_kerns;
if (m_large_group) {
//! Channel-wise conv and big groups: each kernel invocation handles one whole
//! group, padding every input channel and then convolving every output channel
auto exec_one_group = [wbundle, conv](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
auto fm = kern_param.filter_meta;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
WorkspaceBundle bundle = wbundle;
for (size_t ic = 0; ic < IC; ic++) {
arm_common::MultithreadDirectConvCommon<float, float>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
{ncb_index.thread_id, 0, ic});
}
for (size_t oc = 0; oc < OC; oc++) {
arm_common::MultithreadDirectConvCommon<
float, float>::do_conv_kern_stride(bundle, kern_param,
ncb_index, conv,
{ncb_index.thread_id,
0, oc});
}
};
ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
} else {
//! Dense conv and small groups: padding and conv are dispatched as separate
//! kernels, parallelized over (group, batch, channel)
WorkspaceBundle bundle = wbundle;
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
arm_common::MultithreadDirectConvCommon<float, float>::
copy_padding_kern_stride(bundle, kern_param, ncb_index,
ncb_index.ndrange_id);
};
ret_kerns.push_back({copy_padding, {group, N, IC}});
auto do_conv = [bundle, conv](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
arm_common::MultithreadDirectConvCommon<
float, float>::do_conv_kern_stride(bundle, kern_param,
ncb_index, conv,
ncb_index.ndrange_id);
};
ret_kerns.push_back({do_conv, {group, N, OC}});
}
return ret_kerns;
}
/**
* \file dnn/src/aarch64/conv_bias/fp32/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/opr_impl.h"
namespace megdnn {
namespace aarch64 {
using FallbackConvBiasImpl = fallback::ConvBiasImpl;
/* ===================== stride-2 algo ===================== */
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
bool m_large_group;
public:
AlgoF32DirectStride2(bool large_group) : m_large_group(large_group) {}
bool is_reproducible() const override { return true; }
const char* name() const override {
return m_large_group ? "ARMV8F32STRD2_LARGE_GROUP"
: "ARMV8F32STRD2_SMALL_GROUP";
}
bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override;
SmallVector<NCBKern> dispatch_kerns(FallbackConvBiasImpl*,
const NCBKernSizeParam&) const override;
};
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
This diff is collapsed.
/**
* \file dnn/src/aarch64/conv_bias/int8/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/conv_bias/int8/algos.h"
#include "src/aarch64/conv_bias/int8/strategy.h"
#include "src/arm_common/convolution/img2col_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
#include "midout.h"
MIDOUT_DECL(megdnn_aarch64_conv_bias_int8_gemm)
using namespace megdnn;
using namespace aarch64;
using megdnn::arm_common::HSwishOp;
using megdnn::arm_common::ReluOp;
using megdnn::arm_common::TypeCvtOp;
/* ===================== matrix mul algo ===================== */
bool ConvBiasImpl::AlgoS8MatrixMul::usable(
FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(opr);
auto&& fm = param.filter_meta;
return param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8 &&
fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
//! In the postprocess stage the bias is not read contiguously, which
//! hurts performance, so we do not handle it in the fused kernel
param.bias_mode != BiasMode::BIAS &&
//! This algo only supports a single thread
param.nr_threads == 1_z;
}
WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
const NCBKernSizeParam& param) {
UNPACK_CONV_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW;
bool can_matrix_mul_direct =
(FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0);
// part0: temp space for the zero-padded copy of src (with 16 extra int8)
// part1: temp space for the im2col-unrolled matrix (with 16 extra int8)
// part2: workspace for the matrix mul opr
size_t part0, part1, part2;
if (can_matrix_mul_direct) {
part0 = part1 = 0;
} else {
part0 = (IC * IH2 * IW2 + 16) * sizeof(int8_t);
part1 = (IC * FH * FW * OH * OW + 16) * sizeof(int8_t);
}
{
size_t M = OC;
size_t K = IC * FH * FW;
size_t N = OH * OW;
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline>( \
M, N, K, false, false, strategy) \
.get_workspace_size(); \
} \
MIDOUT_END()
#if !(__ARM_FEATURE_DOTPROD)
DISPATCH_GEMM_BIAS(s8_4x4, 0)
#else
DISPATCH_GEMM_BIAS(s8_8x12, 1)
#endif
#undef DISPATCH_GEMM_STRATEGY
}
return {nullptr, {part0, part1, part2}};
}
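A worked example of the three bundle parts computed above, with hypothetical shapes (IC=8, IH=IW=32, PH=PW=1, FH=FW=3, SH=SW=1, so OH=OW=32); illustrative only, not taken from the commit:

// Illustrative only: reproduces the part0/part1 arithmetic for made-up shapes.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    std::size_t IC = 8, IH = 32, IW = 32, PH = 1, PW = 1;
    std::size_t FH = 3, FW = 3, OH = 32, OW = 32;
    std::size_t IH2 = IH + 2 * PH, IW2 = IW + 2 * PW;  // 34 x 34 padded plane
    // part0: zero-padded copy of src, plus 16 bytes of slack
    std::size_t part0 = (IC * IH2 * IW2 + 16) * sizeof(int8_t);          // 9264
    // part1: im2col-unrolled matrix of shape (IC*FH*FW) x (OH*OW), plus slack
    std::size_t part1 = (IC * FH * FW * OH * OW + 16) * sizeof(int8_t);  // 73744
    std::printf("part0=%zu part1=%zu\n", part0, part1);
    return 0;
}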
void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
auto is_xcorr = !param.filter_meta.should_flip;
UNPACK_CONV_NCB_KERN_SIZES(param);
auto bundle = get_bundle(param);
bundle.set(param.workspace_ptr);
auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW;
size_t group_id = ncb_index.ndrange_id[0];
// workspace = tmp..src2
for (size_t n = 0; n < N; ++n) {
dt_int8* src = const_cast<dt_int8*>(param.src<dt_int8>(n, group_id));
dt_int8* filter = const_cast<dt_int8*>(param.filter<dt_int8>(group_id));
dt_int8* dst = static_cast<dt_int8*>(param.dst<dt_int8>(n, group_id));
dt_int32* bias = const_cast<dt_int32*>(param.bias<dt_int32>(n, group_id));
dt_int8 *B, *src2;
if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) {
// special case: 1x1
B = const_cast<dt_int8*>(src);
} else {
src2 = static_cast<dt_int8*>(bundle.get(0));
// copy src to src2;
dt_int8* src2_ptr = src2;
const dt_int8* src_ptr = src;
rep(ic, IC) {
if (PH != 0) {
std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2);
src2_ptr += PH * IW2;
}
rep(ih, IH) {
if (PW != 0)
rep(pw, PW) { *(src2_ptr++) = 0; }
std::memcpy(src2_ptr, src_ptr, sizeof(dt_int8) * IW);
src2_ptr += IW;
src_ptr += IW;
if (PW != 0)
rep(pw, PW) { *(src2_ptr++) = 0; }
}
if (PH != 0) {
std::memset(src2_ptr, 0, sizeof(dt_int8) * PH * IW2);
src2_ptr += PH * IW2;
}
}
B = static_cast<dt_int8*>(bundle.get(1));
if (SH == 1 && SW == 1) {
if (is_xcorr)
img2col<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
else
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
} else {
if (is_xcorr)
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
else
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
}
}
{
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)),
bundle.get_size(2));
size_t M = OC;
size_t K = IC * FH * FW;
size_t N = OH * OW;
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \
bias); \
} \
MIDOUT_END()
#if !(__ARM_FEATURE_DOTPROD)
DISPATCH_GEMM_BIAS(s8_4x4, 0)
#else
DISPATCH_GEMM_BIAS(s8_8x12, 1)
#endif
#undef DISPATCH_GEMM_STRATEGY
}
}
}
// vim: syntax=cpp.doxygen
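The 1x1/stride-1/no-padding special case above feeds src straight into the GEMM because such a convolution already is a plain matrix product: dst[oc][p] equals the sum over ic of filter[oc][ic] * src[ic][p], with p = oh * OW + ow. A naive reference with hypothetical buffer layouts (not MegEngine code):

// Naive reference for the 1x1 special case; hypothetical layout, illustrative only.
#include <cstddef>
#include <cstdint>
#include <vector>

void conv1x1_as_gemm(const std::vector<int8_t>& src,     // IC x (OH*OW)
                     const std::vector<int8_t>& filter,  // OC x IC
                     std::vector<int32_t>& dst,          // OC x (OH*OW)
                     std::size_t OC, std::size_t IC, std::size_t OHW) {
    for (std::size_t oc = 0; oc < OC; ++oc) {
        for (std::size_t p = 0; p < OHW; ++p) {
            int32_t acc = 0;
            for (std::size_t ic = 0; ic < IC; ++ic)
                acc += int32_t(filter[oc * IC + ic]) * int32_t(src[ic * OHW + p]);
            dst[oc * OHW + p] = acc;  // bias/requantization happen in postprocess
        }
    }
}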
/**
* \file dnn/src/aarch64/conv_bias/int8/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/opr_impl.h"
namespace megdnn {
namespace aarch64 {
using FallbackConvBiasImpl = fallback::ConvBiasImpl;
class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase {
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index);
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "S8MATMUL"; }
bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<NCBKern> dispatch_kerns(
FallbackConvBiasImpl*, const NCBKernSizeParam& param) const override {
size_t group = param.filter_meta.group;
return {{kimpl, {group, 1_z, 1_z}}};
}
//! give matmul the highest preference when it is the preferred algo
bool is_preferred(FallbackConvBiasImpl* opr,
const NCBKernSizeParam& param) const override {
return static_cast<arm_common::ConvBiasImpl*>(opr)
->is_matmul_quantized_prefer(param);
}
};
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/conv_bias/int8/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/conv_bias/int8/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h"
#include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h"
#include "src/arm_common/conv_bias/matmul_postprocess.h"
using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;
namespace impl {
template <BiasMode bmode, typename Op, int block_m, int block_n>
struct KernCaller;
#if __ARM_FEATURE_DOTPROD
template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 12> {
static void run(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k,
Op op, const dt_int32* bias, dt_int32* workspace) {
megdnn_assert(is_first_k);
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 12;
//! K is padded to a multiple of 4
K = round_up<size_t>(K, 4);
const int K8 = (K << 3);
const int K12 = K * 12;
const int K4 = K * 4;
size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int8_t* output = C + (m * LDC);
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_8x12(packA, cur_packB, K, workspace, 12,
is_first_k);
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8,
12>::postprocess(bias, workspace,
output, LDC, op);
output += B_INTERLEAVE;
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_8x12x4::kern_8x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(N - n, 4));
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 4, 8, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4));
#undef cb
output += 4;
cur_packB += K4;
}
packA += K8;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += A_INTERLEAVE;
}
}
for (; m < M; m += 4) {
int8_t* output = C + (m * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x12x4::kern_4x12(packA, cur_packB, K, workspace, 12,
is_first_k,
std::min<size_t>(M - m, 4));
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 12, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M_N(cb, std::min<size_t>(M - m, 4), 12);
#undef cb
output += B_INTERLEAVE;
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_8x12x4::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
#undef cb
output += 4;
cur_packB += K4;
}
packA += K4;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += 4;
}
}
}
};
#else
template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 4, 4> {
static void run(const dt_int8* packA, const dt_int8* packB, size_t M,
size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k,
Op op, const dt_int32* bias, dt_int32* workspace) {
megdnn_assert(is_first_k);
constexpr size_t A_INTERLEAVE = 4;
constexpr size_t B_INTERLEAVE = 4;
//! K is padded to a multiple of 16
K = round_up<size_t>(K, 16);
const int K4 = K * 4;
size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
int8_t* output = C + (m * LDC);
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k);
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4,
4>::postprocess(bias, workspace,
output, LDC, op);
output += B_INTERLEAVE;
cur_packB += K4;
}
for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, workspace,
4, is_first_k, 4,
std::min<size_t>(N - n, 4));
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_N(cb, 4, std::min<size_t>(N - n, 4));
#undef cb
output += B_INTERLEAVE;
cur_packB += K4;
}
packA += K4;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += A_INTERLEAVE;
}
}
for (; m < M; m += A_INTERLEAVE) {
int8_t* output = C + (m * LDC);
size_t n = 0;
const dt_int8* cur_packB = packB;
for (; n < N; n += B_INTERLEAVE) {
matmul_4x4x16::kern_4x4_remain(
packA, cur_packB, K, workspace, 4, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
#undef cb
output += B_INTERLEAVE;
cur_packB += K4;
}
packA += K4;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += A_INTERLEAVE;
}
}
}
};
#endif
} // namespace impl
#if !(__ARM_FEATURE_DOTPROD)
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity)
void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
} else {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
}
}
void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_4x4x16::gemm_s8_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}
size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const {
return 4 * 4 * sizeof(dt_int32);
}
#else
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity)
void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t);
MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t);
if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
} else {
matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax);
}
}
void gemm_s8_8x12_nobias_identity::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_8x12x4::gemm_s8_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}
size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const {
return 8 * 12 * sizeof(dt_int32);
}
#endif
#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \
void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \
const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, \
size_t K, dt_int8* C, size_t LDC, bool is_first_k, \
const dt_int32* bias, dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \
float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \
float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \
workspace); \
}
#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, scale_C);
#if !(__ARM_FEATURE_DOTPROD)
KERN(4, 4, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
KERN(4, 4, nobias, BiasMode::NO_BIAS, relu, ReluOp)
KERN(4, 4, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#else
KERN(8, 12, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
KERN(8, 12, nobias, BiasMode::NO_BIAS, relu, ReluOp)
KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#endif
#undef DEFINE_OP
#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, \
scale_A* scale_B, scale_C);
#if !(__ARM_FEATURE_DOTPROD)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
#else
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
#endif
#undef DEFINE_OP
#undef KERN
// vim: syntax=cpp.doxygen
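To make the macro machinery above easier to follow, this is roughly what the preprocessor produces for KERN(4, 4, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) with the first DEFINE_OP; a hand-expanded sketch that relies on the class and helpers in this file, shown only for illustration:

// Hand-expanded sketch of one KERN instantiation (illustration, not extra code).
void gemm_s8_4x4_nobias_identity::kern(
        const dt_int8* packA, const dt_int8* packB, size_t M, size_t N,
        size_t K, dt_int8* C, size_t LDC, bool is_first_k,
        const dt_int32* bias, dt_int32* workspace) const {
    float scale_A = A_dtype.param<dtype::QuantizedS8>().scale;
    float scale_B = B_dtype.param<dtype::QuantizedS8>().scale;
    float scale_C = C_dtype.param<dtype::QuantizedS8>().scale;
    arm_common::TypeCvtOp<dt_qint32, dt_qint8> op(scale_A * scale_B, scale_C);
    impl::KernCaller<BiasMode::NO_BIAS, decltype(op), 4, 4>::run(
            packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace);
}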
/**
* \file dnn/src/aarch64/conv_bias/int8/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
namespace megdnn {
namespace aarch64 {
namespace matmul {
#if !(__ARM_FEATURE_DOTPROD)
/**
* \brief base strategy of gemm.
*
* \name gemm_<type>_<block>_biasmode_nolinemode
*/
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 4, 16,
false, true,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_relu,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_hswish,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_identity,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu,
gemm_s8_4x4_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish,
gemm_s8_4x4_nobias_identity);
#else
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4,
false, true,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_relu,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_hswish,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_identity,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu,
gemm_s8_8x12_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish,
gemm_s8_8x12_nobias_identity);
#endif
} // namespace matmul
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/conv_bias/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/aarch64/conv_bias/int8/algos.h"
#include "src/aarch64/conv_bias/quint8/algos.h"
#include "src/naive/handle.h"
#include "src/common/utils.h"
#include "src/common/metahelper.h"
#include "src/fallback/convolution/opr_impl.h"
#include "src/aarch64/conv_bias/fp32/algos.h"
#include "src/aarch64/conv_bias/fp16/algos.h"
using namespace megdnn;
using namespace aarch64;
class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoF32DirectStride2 f32_direct_stride2_large_group{true};
AlgoF32DirectStride2 f32_direct_stride2_small_group{false};
AlgoS8MatrixMul s8_matrix_mul;
AlgoQU8MatrixMul qu8_matrix_mul;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16DirectStride2 f16_direct_stride2_large_group{true};
AlgoF16DirectStride2 f16_direct_stride2_small_group{false};
#endif
public:
AlgoPack() {
matmul_algos.emplace_back(&qu8_matrix_mul);
matmul_algos.emplace_back(&s8_matrix_mul);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
direct_algos.emplace_back(&f16_direct_stride2_large_group);
direct_algos.emplace_back(&f16_direct_stride2_small_group);
#endif
direct_algos.emplace_back(&f32_direct_stride2_large_group);
direct_algos.emplace_back(&f32_direct_stride2_small_group);
}
SmallVector<AlgoBase*> direct_algos;
SmallVector<AlgoBase*> matmul_algos;
};
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack;
auto&& algos = arm_common::ConvBiasImpl::algo_pack();
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
sl_algo_pack.direct_algos.end());
//! Matmul algos are placed at the end because a matmul algo still takes
//! priority whenever its is_preferred() returns true. See
//! fallback::ConvolutionImpl::ncb_1g_get_all_algorithms for more details.
algos.insert(algos.end(), sl_algo_pack.matmul_algos.begin(),
sl_algo_pack.matmul_algos.end());
return std::move(algos);
}
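A simplified sketch of why the ordering above is safe (this is not the fallback implementation, only an illustration under that assumption): even with matmul algos at the tail of the list, a selection loop that honors is_preferred() still picks them first whenever they report preference:

// Hypothetical selection loop; the real logic lives in
// fallback::ConvolutionImpl::ncb_1g_get_all_algorithms.
#include <vector>

struct Algo {
    bool usable;
    bool preferred;
};

const Algo* pick(const std::vector<Algo>& algos) {
    const Algo* first_usable = nullptr;
    for (const Algo& a : algos) {
        if (!a.usable)
            continue;
        if (a.preferred)
            return &a;  // a preferred algo wins regardless of its position
        if (!first_usable)
            first_usable = &a;
    }
    return first_usable;  // otherwise list order decides
}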
const char* ConvBiasImpl::get_algorithm_set_name() const {
return "AARCH64";
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/conv_bias/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/common/utils.h"
#include "src/arm_common/conv_bias/opr_impl.h"
namespace megdnn {
namespace aarch64 {
class ConvBiasImpl : public arm_common::ConvBiasImpl {
public:
using arm_common::ConvBiasImpl::ConvBiasImpl;
SmallVector<AlgoBase*> algo_pack() override;
protected:
const char* get_algorithm_set_name() const override;
private:
class AlgoF32DirectStride2;
class AlgoS8MatrixMul;
class AlgoQU8MatrixMul;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoF16DirectStride2;
#endif
class AlgoPack;
};
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/conv_bias/quint8/algos.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/conv_bias/quint8/algos.h"
#include "src/aarch64/conv_bias/quint8/strategy.h"
#include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
#include "src/aarch64/matrix_mul/quint8_dot/strategy.h"
#include "src/arm_common/convolution/img2col_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
#include "midout.h"
MIDOUT_DECL(megdnn_aarch64_conv_bias_quint8_gemm)
using namespace megdnn;
using namespace aarch64;
using megdnn::arm_common::HSwishOp;
using megdnn::arm_common::ReluOp;
using megdnn::arm_common::TypeCvtOp;
/* ===================== matrix mul algo ===================== */
bool ConvBiasImpl::AlgoQU8MatrixMul::usable(
FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MEGDNN_MARK_USED_VAR(opr);
auto&& fm = param.filter_meta;
return param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm &&
fm.format == param::ConvBias::Format::NCHW && fm.spatial_ndim == 2 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
//! In the postprocess stage the bias is not read contiguously, which
//! hurts performance, so we do not handle it in the fused kernel
param.bias_mode != BiasMode::BIAS &&
//! This algo only supports a single thread
param.nr_threads == 1_z;
}
WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
const NCBKernSizeParam& param) {
UNPACK_CONV_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW;
bool can_matrix_mul_direct =
(FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0);
// part0: temp space for the zero-point-padded copy of src (with 16 extra uint8)
// part1: temp space for the im2col-unrolled matrix (with 16 extra uint8)
// part2: workspace for the matrix mul opr
size_t part0, part1, part2;
if (can_matrix_mul_direct) {
part0 = part1 = 0;
} else {
part0 = (IC * IH2 * IW2 + 16) * sizeof(uint8_t);
part1 = (IC * FH * FW * OH * OW + 16) * sizeof(uint8_t);
}
{
size_t M = OC;
size_t K = IC * FH * FW;
size_t N = OH * OW;
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
part2 = megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline>( \
M, N, K, false, false, strategy) \
.get_workspace_size(); \
} \
MIDOUT_END()
DISPATCH_GEMM_BIAS(u8_8x8, 0)
#undef DISPATCH_GEMM_STRATEGY
}
return {nullptr, {part0, part1, part2}};
}
void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
auto is_xcorr = !param.filter_meta.should_flip;
UNPACK_CONV_NCB_KERN_SIZES(param);
auto bundle = get_bundle(param);
bundle.set(param.workspace_ptr);
auto IH2 = IH + 2 * PH;
auto IW2 = IW + 2 * PW;
size_t group_id = ncb_index.ndrange_id[0];
uint8_t src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
// workspace = tmp..src2
for (size_t n = 0; n < N; ++n) {
uint8_t* src = const_cast<uint8_t*>(param.src<uint8_t>(n, group_id));
uint8_t* filter = const_cast<uint8_t*>(param.filter<uint8_t>(group_id));
uint8_t* dst = static_cast<uint8_t*>(param.dst<uint8_t>(n, group_id));
int32_t* bias = const_cast<int32_t*>(param.bias<int32_t>(n, group_id));
uint8_t *B, *src2;
if (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0) {
// special case: 1x1
B = const_cast<uint8_t*>(src);
} else {
src2 = static_cast<uint8_t*>(bundle.get(0));
// copy src to src2;
uint8_t* src2_ptr = src2;
const uint8_t* src_ptr = src;
rep(ic, IC) {
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2);
src2_ptr += PH * IW2;
}
rep(ih, IH) {
if (PW != 0)
rep(pw, PW) { *(src2_ptr++) = src_zp; }
std::memcpy(src2_ptr, src_ptr, sizeof(uint8_t) * IW);
src2_ptr += IW;
src_ptr += IW;
if (PW != 0)
rep(pw, PW) { *(src2_ptr++) = src_zp; }
}
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(uint8_t) * PH * IW2);
src2_ptr += PH * IW2;
}
}
B = static_cast<uint8_t*>(bundle.get(1));
if (SH == 1 && SW == 1) {
if (is_xcorr)
img2col<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
else
img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW);
} else {
if (is_xcorr)
img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
else
img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH,
FW, SH, SW);
}
}
{
Workspace workspace(static_cast<dt_byte*>(bundle.get(2)),
bundle.get_size(2));
size_t M = OC;
size_t K = IC * FH * FW;
size_t N = OH * OW;
#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
_bias_midout_enum, _nonline, \
_nonline_midout_enum) \
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \
_bias_midout_enum, _nonline_midout_enum) { \
matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \
M, N, K, param.filter_type, param.src_type, param.dst_type); \
megdnn::matmul::GemmInterleaved< \
matmul::gemm_##_gemm##_##_bias##_##_nonline> \
gemm_interleaved(M, N, K, false, false, strategy); \
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \
bias); \
} \
MIDOUT_END()
DISPATCH_GEMM_BIAS(u8_8x8, 0)
#undef DISPATCH_GEMM_STRATEGY
}
}
}
// vim: syntax=cpp.doxygen
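The padding loop above writes src_zp instead of 0 because, for Quantized8Asymm data, real_value = scale * (q - zero_point), so the quantized code that represents a real zero is the zero point itself. A tiny check under made-up quantization parameters:

// Illustrative only: made-up scale/zero_point, not taken from the commit.
#include <cassert>

int main() {
    const float scale = 0.05f;
    const int zero_point = 128;
    auto dequant = [&](int q) { return scale * (q - zero_point); };
    assert(dequant(zero_point) == 0.0f);  // padding with zp encodes a real 0
    return 0;
}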
/**
* \file dnn/src/aarch64/conv_bias/quint8/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/opr_impl.h"
namespace megdnn {
namespace aarch64 {
using FallbackConvBiasImpl = fallback::ConvBiasImpl;
class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase {
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
static void kimpl(const NCBKernParam& param, const NCBKernIndex&);
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "QU8MATMUL"; }
bool usable(FallbackConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override {
return get_bundle(param).total_size_in_bytes();
}
SmallVector<NCBKern> dispatch_kerns(
FallbackConvBiasImpl*,
const NCBKernSizeParam& param) const override {
size_t group = param.filter_meta.group;
return {{kimpl, {group, 1_z, 1_z}}};
}
//! give matmul the highest preference when it is the preferred algo
bool is_preferred(FallbackConvBiasImpl* opr,
const NCBKernSizeParam& param) const override {
return static_cast<arm_common::ConvBiasImpl*>(opr)
->is_matmul_quantized_prefer(param);
}
};
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/conv_bias/quint8/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/conv_bias/quint8/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"
#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h"
#include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h"
#include "src/arm_common/conv_bias/matmul_postprocess.h"
using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;
namespace impl {
template <BiasMode bmode, typename Op, int block_m, int block_n>
struct KernCaller;
#if __ARM_FEATURE_DOTPROD
template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 8> {
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M,
size_t N, size_t K, dt_uint8* C, size_t LDC,
bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
megdnn_assert(is_first_k);
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 8;
const uint32_t zAB =
static_cast<uint32_t>(zp_A) * static_cast<uint32_t>(zp_B) * K;
//! K is padded to a multiple of 4
K = round_up<size_t>(K, 4);
const int K8 = (K << 3);
const int K4 = K * 4;
size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
uint8_t* output = C + (m * LDC);
size_t n = 0;
const dt_uint8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_8x8(packA, cur_packB, K, workspace, 8,
is_first_k, zp_A, zp_B, zAB);
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8,
8>::postprocess(bias, workspace,
output, LDC, op);
output += B_INTERLEAVE;
cur_packB += K8;
}
for (; n < N; n += 4) {
matmul_8x8x4::kern_8x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(N - n, 4),
zp_A, zp_B, zAB);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4));
#undef cb
output += 4;
cur_packB += K4;
}
packA += K8;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += A_INTERLEAVE;
}
}
for (; m < M; m += 4) {
uint8_t* output = C + (m * LDC);
const dt_uint8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_4x8(packA, cur_packB, K, workspace, 8,
is_first_k, std::min<size_t>(M - m, 4),
zp_A, zp_B, zAB);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M_N(cb, std::min<size_t>(M - m, 4), 8);
#undef cb
output += B_INTERLEAVE;
cur_packB += K8;
}
for (; n < N; n += 4) {
matmul_8x8x4::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4), zp_A, zp_B,
zAB);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
#undef cb
output += 4;
cur_packB += K4;
}
packA += K4;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += 4;
}
}
}
};
#else
template <BiasMode bmode, typename Op>
struct KernCaller<bmode, Op, 8, 8> {
static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M,
size_t N, size_t K, dt_uint8* C, size_t LDC,
bool is_first_k, Op op, const dt_int32* bias,
dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) {
megdnn_assert(is_first_k);
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 8;
//! K is padded to a multiple of 8
K = round_up<size_t>(K, 8);
const int K8 = K * 8;
const int K4 = K * 4;
size_t m = 0;
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) {
uint8_t* output = C + (m * LDC);
size_t n = 0;
const dt_uint8* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_8x8(packA, cur_packB, K, workspace, 8,
is_first_k, zp_A, zp_B);
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8,
8>::postprocess(bias, workspace,
output, LDC, op);
output += B_INTERLEAVE;
cur_packB += K8;
}
for (; n < N; n += 4) {
matmul_8x8x8::kern_8x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(N - n, 4),
zp_A, zp_B);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4));
#undef cb
output += 4;
cur_packB += K4;
}
packA += K8;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += A_INTERLEAVE;
}
}
for (; m < M; m += 4) {
uint8_t* output = C + (m * LDC);
const dt_uint8* cur_packB = packB;
size_t n = 0;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_8x8x8::kern_4x8(packA, cur_packB, K, workspace, 8,
is_first_k, std::min<size_t>(M - m, 4),
zp_A, zp_B);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M_N(cb, std::min<size_t>(M - m, 4), 8);
#undef cb
output += B_INTERLEAVE;
cur_packB += K8;
}
for (; n < N; n += 4) {
matmul_8x8x8::kern_4x4(packA, cur_packB, K, workspace, 4,
is_first_k, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4), zp_A, zp_B);
#define cb(m, n) \
arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \
bias, workspace, output, LDC, op);
DISPATCH_M(cb, std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
#undef cb
output += 4;
cur_packB += K4;
}
packA += K4;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
bias += 4;
}
}
}
};
#endif
} // namespace impl
#if __ARM_FEATURE_DOTPROD
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity)
void gemm_u8_8x8_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr,
int ldin, int y0, int ymax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0,
ymax, k0, kmax);
} else {
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin,
y0, ymax, k0, kmax);
}
}
void gemm_u8_8x8_nobias_identity::pack_B(uint8_t* out, const uint8_t* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
if (transpose) {
matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0,
xmax, k0, kmax);
} else {
matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax,
k0, kmax);
}
}
#else
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity)
void gemm_u8_8x8_nobias_identity::pack_A(dt_uint8* outptr,
const dt_uint8* inptr, int ldin,
int y0, int ymax, int k0, int kmax,
bool transpose) const {
uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point;
if (transpose) {
matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0,
ymax, k0, kmax, zA);
} else {
matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0,
kmax, zA);
}
}
void gemm_u8_8x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool transpose) const {
uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point;
if (transpose) {
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax,
k0, kmax, zB);
} else {
matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax,
zB);
}
}
#endif
size_t gemm_u8_8x8_nobias_identity::get_workspace_size() const {
return 8 * 8 * sizeof(dt_int32);
}
#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \
void gemm_u8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \
const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \
size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \
const dt_int32* bias, dt_int32* workspace) const { \
float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \
float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \
uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \
DEFINE_OP(_OP); \
impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \
packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \
workspace, zp_A, zp_B); \
}
#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, scale_C, zp_C);
KERN(8, 8, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
KERN(8, 8, nobias, BiasMode::NO_BIAS, relu, ReluOp)
KERN(8, 8, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
#undef DEFINE_OP
#define DEFINE_OP(_Op) \
arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, \
scale_A* scale_B, scale_C, zp_C);
KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
FuseAddHSwishOp)
#undef DEFINE_OP
#undef KERN
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/conv_bias/quint8/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
namespace megdnn {
namespace aarch64 {
namespace matmul {
#if __ARM_FEATURE_DOTPROD
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4,
false, true,
gemm_u8_8x8_nobias_identity);
#else
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8,
false, true,
gemm_u8_8x8_nobias_identity);
#endif
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_relu,
gemm_u8_8x8_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_hswish,
gemm_u8_8x8_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_identity,
gemm_u8_8x8_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_relu,
gemm_u8_8x8_nobias_identity);
MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_hswish,
gemm_u8_8x8_nobias_identity);
} // namespace matmul
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/handle.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/common/handle_impl.h"
#include "src/aarch64/handle.h"
#include "src/aarch64/matrix_mul/opr_impl.h"
#include "src/aarch64/rotate/opr_impl.h"
#include "src/aarch64/relayout/opr_impl.h"
#include "src/aarch64/conv_bias/opr_impl.h"
#include "src/aarch64/warp_perspective/opr_impl.h"
namespace megdnn {
namespace aarch64 {
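// Operators without an aarch64-specific implementation fall back to the
// arm_common handle through this default create_operator().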
template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
return arm_common::HandleImpl::create_operator<Opr>();
}
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMul)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Rotate)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutForward)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspective)
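// The generic instantiation below follows the explicit specializations above,
// so -Winstantiation-after-specialization is silenced for it.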
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Winstantiation-after-specialization"
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
#pragma GCC diagnostic pop
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/handle.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/arm_common/handle.h"
namespace megdnn {
namespace aarch64 {
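// The aarch64 handle reuses the arm_common implementation and only overrides
// operator creation (see handle.cpp) for aarch64-specific kernels.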
class HandleImpl: public arm_common::HandleImpl {
public:
HandleImpl(megcoreComputingHandle_t computing_handle,
HandleType type = HandleType::AARCH64):
arm_common::HandleImpl::HandleImpl(computing_handle, type)
{}
template <typename Opr>
std::unique_ptr<Opr> create_operator();
};
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/matrix_mul/algos.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/aarch64/matrix_mul/opr_impl.h"
#include "src/arm_common/matrix_mul/algos.h"
#include "src/fallback/matrix_mul/gemm_common.h"
namespace megdnn {
namespace aarch64 {
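// Each Algo* class below wraps one aarch64 GEMM kernel for the MatrixMul
// dispatcher: usable() (and preferred(), where defined) guide algorithm
// selection, while get_workspace() and get_kern() describe and return the
// kernel itself.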
class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_F32K8X12X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_F32K4X16X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_F32_MK4_4x16"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};
class MatrixMulImpl::AlgoF32Gemv final
: public arm_common::MatrixMulImpl::AlgoF32Gemv {};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_F16_K8X24X1"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_F16_MK8_8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};
#endif
#if __ARM_FEATURE_DOTPROD
class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_INT8X8X32_K8X12X4_DOTPROD";
}
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x32GemvDotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_INT8X8X32_GEMV_DOTPROD";
}
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};
#else
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_INT8X8X32_MK4_4X4X16";
}
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::DEFAULT; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x32Gemv final
: public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {};
#endif
class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};
#if __ARM_FEATURE_DOTPROD
class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_QUINT8_K8X8X4_DOTPROD";
}
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
};
#else
class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
#endif
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/matrix_mul/fp16/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
namespace megdnn {
namespace aarch64 {
namespace matmul {
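// hgemm_8x24 is the packed fp16 kernel with an 8x24 output block;
// gemm_nopack_f16_8x8 works directly on MK8-layout data without a packing
// stage (see AlgoF16MK8_8x8).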
MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 8, 24, 1, false,
true, hgemm_8x24);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 8, 8, 1,
false, true, gemm_nopack_f16_8x8);
} // namespace matmul
} // namespace aarch64
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/matrix_mul/fp16/strategy.h"
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;
namespace {
// Overview of register layout:
//
// A 8x1 cell of Rhs is stored in 16bit in v0-v3
// A 8x1 cell of Lhs is stored in 16bit in v16-v23
// A 8x1 block of accumulators is stored in 16bit in v24-v27.
//
// Rhs +-------+
// |v0[0-7]|
// |v1[0-7]|
// |v2[0-7]|
// |v3[0-7]|
// +-------+
// Lhs
// +--------+
// |v16[0-7]|
// |v17[0-7]|
// |v18[0-7]|
//  |v19[0-7]|          +--------+
//  |v20[0-7]|          |v24[0-7]|
//  |v21[0-7]|          |v25[0-7]|
//  |v22[0-7]|          |v26[0-7]|
//  |v23[0-7]|          |v27[0-7]|
//  +--------+          +--------+
// Accumulator
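// Computes an 8x4 fp16 output tile. K must be a multiple of 8: the first
// eight K-values are peeled off before the loop and each iteration consumes
// another eight.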
void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
    //! LDB is the number of elements in one block of B. b_ptr has already been
    //! advanced by 24 elements (24 * 2 bytes) before the LDB post-increment,
    //! so subtract 24 elements here.
LDB = (LDB - 24) * sizeof(dt_float16);
asm volatile(
".arch armv8.2-a+fp16\n"
"ld1 {v16.4s, v17.4s}, [%[a_ptr]], 32\n"
"subs %w[K], %w[K], #8\n"
"ld1 {v0.4s}, [%[b_ptr]], 16\n"
"ld1 {v1.4s}, [%[b_ptr]], 16\n"
"fmul v24.8h, v16.8h, v0.h[0]\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"fmul v25.8h, v16.8h, v1.h[0]\n"
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"
"fmul v26.8h, v16.8h, v2.h[0]\n"
"ld1 {v18.4s}, [%[a_ptr]], 16\n"
"fmul v27.8h, v16.8h, v3.h[0]\n"
"fmla v24.8h, v17.8h, v0.h[1]\n"
"fmla v25.8h, v17.8h, v1.h[1]\n"
"fmla v26.8h, v17.8h, v2.h[1]\n"
"fmla v27.8h, v17.8h, v3.h[1]\n"
"ld1 {v19.4s}, [%[a_ptr]], 16\n"
"fmla v24.8h, v18.8h, v0.h[2]\n"
"fmla v25.8h, v18.8h, v1.h[2]\n"
"fmla v26.8h, v18.8h, v2.h[2]\n"
"fmla v27.8h, v18.8h, v3.h[2]\n"
"ld1 {v20.4s}, [%[a_ptr]], 16\n"
"fmla v24.8h, v19.8h, v0.h[3]\n"
"fmla v25.8h, v19.8h, v1.h[3]\n"
"fmla v26.8h, v19.8h, v2.h[3]\n"
"fmla v27.8h, v19.8h, v3.h[3]\n"
"ld1 {v21.4s}, [%[a_ptr]], 16\n"
"fmla v24.8h, v20.8h, v0.h[4]\n"
"fmla v25.8h, v20.8h, v1.h[4]\n"
"fmla v26.8h, v20.8h, v2.h[4]\n"
"fmla v27.8h, v20.8h, v3.h[4]\n"
"ld1 {v22.4s}, [%[a_ptr]], 16\n"
"fmla v24.8h, v21.8h, v0.h[5]\n"
"fmla v25.8h, v21.8h, v1.h[5]\n"
"fmla v26.8h, v21.8h, v2.h[5]\n"
"fmla v27.8h, v21.8h, v3.h[5]\n"
"ld1 {v23.4s}, [%[a_ptr]], 16\n"
"fmla v24.8h, v22.8h, v0.h[6]\n"
"fmla v25.8h, v22.8h, v1.h[6]\n"
"fmla v26.8h, v22.8h, v2.h[6]\n"
"fmla v27.8h, v22.8h, v3.h[6]\n"
"beq 2f\n"
"1:\n"
"ld1 {v16.4s}, [%[a_ptr]], 16\n"
"fmla v24.8h, v23.8h, v0.h[7]\n"
"ld1 {v0.4s}, [%[b_ptr]], 16\n"
"fmla v25.8h, v23.8h, v1.h[7]\n"
"ld1 {v1.4s}, [%[b_ptr]], 16\n"
"fmla v26.8h, v23.8h, v2.h[7]\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"fmla v27.8h, v23.8h, v3.h[7]\n"
"ld1 {v3.4s}, [%[b_ptr]], %x[LDB]\n"
"ld1 {v17.4s}, [%[a_ptr]], 16\n"
"fmla v24.8h, v16.8h, v0.h[0]\n"
"fmla v25.8h, v16.8h, v1.h[0]\n"
"fmla v26.8h, v16.8h, v2.h[0]\n"
"fmla v27.8h, v16.8h, v3.h[0]\n"
"ld1 {v18.4s}, [%[a_ptr]], 16\n"
"fmla v24.8h, v17.8h, v0.h[1]\n"
"fmla v25.8h, v17.8h, v1.h[1]\n"
"fmla v26.8h, v17.8h, v2.h[1]\n"
"fmla v27.8h, v17.8h, v3.h[1]\n"
"ld1 {v19.4s}, [%[a_ptr]], 16\n"
"fmla v24.8h, v18.8h, v0.h[2]\n"
"fmla v25.8h, v18.8h, v1.h[2]\n"
"fmla v26.8h, v18.8h, v2.h[2]\n"
"fmla v27.8h, v18.8h, v3.h[2]\n"
"ld1 {v20.4s}, [%[a_ptr]], 16\n"
"fmla v24.8h, v19.8h, v0.h[3]\n"
"fmla v25.8h, v19.8h, v1.h[3]\n"
"fmla v26.8h, v19.8h, v2.h[3]\n"
"fmla v27.8h, v19.8h, v3.h[3]\n"
"ld1 {v21.4s}, [%[a_ptr]], 16\n"
"fmla v24.8h, v20.8h, v0.h[4]\n"
"fmla v25.8h, v20.8h, v1.h[4]\n"
"fmla v26.8h, v20.8h, v2.h[4]\n"
"fmla v27.8h, v20.8h, v3.h[4]\n"
"ld1 {v22.4s}, [%[a_ptr]], 16\n"
"fmla v24.8h, v21.8h, v0.h[5]\n"
"fmla v25.8h, v21.8h, v1.h[5]\n"
"fmla v26.8h, v21.8h, v2.h[5]\n"
"fmla v27.8h, v21.8h, v3.h[5]\n"
"ld1 {v23.4s}, [%[a_ptr]], 16\n"
"fmla v24.8h, v22.8h, v0.h[6]\n"
"fmla v25.8h, v22.8h, v1.h[6]\n"
"fmla v26.8h, v22.8h, v2.h[6]\n"
"fmla v27.8h, v22.8h, v3.h[6]\n"
"subs %w[K], %w[K], #8\n"
"bne 1b\n"
"2:\n"
"fmla v24.8h, v23.8h, v0.h[7]\n"
"fmla v25.8h, v23.8h, v1.h[7]\n"
"fmla v26.8h, v23.8h, v2.h[7]\n"
"fmla v27.8h, v23.8h, v3.h[7]\n"
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory");
}
// Overview of register layout:
//
// A 8x1 cell of Rhs is stored in 16bit in v0-v7
// A 8x1 cell of Lhs is stored in 16bit in v8-v15
// A 8x1 block of accumulators is stored in 16bit in v24-v31.
//
//                 Rhs +--------+
//                     | v0[0-7]|
//                     | v1[0-7]|
//                     | v2[0-7]|
//                     | v3[0-7]|
//                     | v4[0-7]|
//                     | v5[0-7]|
//                     | v6[0-7]|
//                     | v7[0-7]|
//                     +--------+
//    Lhs
//  +--------+ - - - - -+--------+
//  | v8[0-7]|          |v24[0-7]|
//  | v9[0-7]|          |v25[0-7]|
//  |v10[0-7]|          |v26[0-7]|
//  |v11[0-7]|          |v27[0-7]|
//  |v12[0-7]|          |v28[0-7]|
//  |v13[0-7]|          |v29[0-7]|
//  |v14[0-7]|          |v30[0-7]|
//  |v15[0-7]|          |v31[0-7]|
//  +--------+          +--------+
// Accumulator
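// Computes an 8x8 fp16 output tile; as in kern_8x4 above, K must be a
// multiple of 8 and is consumed eight values per iteration.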
void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K,
dt_float16* output) {
    //! b_ptr has already been advanced by 32 elements (64 bytes) before the
    //! load that post-increments by LDB, so subtract 32 elements here.
LDB = (LDB - 32) * sizeof(dt_float16);
asm volatile(
".arch armv8.2-a+fp16\n"
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n"
"subs %w[K], %w[K], #8\n"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n"
"fmul v24.8h, v8.8h, v0.h[0]\n"
"fmul v25.8h, v8.8h, v1.h[0]\n"
"fmul v26.8h, v8.8h, v2.h[0]\n"
"fmul v27.8h, v8.8h, v3.h[0]\n"
"fmul v28.8h, v8.8h, v4.h[0]\n"
"fmul v29.8h, v8.8h, v5.h[0]\n"
"fmul v30.8h, v8.8h, v6.h[0]\n"
"fmul v31.8h, v8.8h, v7.h[0]\n"
"fmla v24.8h, v9.8h, v0.h[1]\n"
"fmla v25.8h, v9.8h, v1.h[1]\n"
"fmla v26.8h, v9.8h, v2.h[1]\n"
"fmla v27.8h, v9.8h, v3.h[1]\n"
"fmla v28.8h, v9.8h, v4.h[1]\n"
"fmla v29.8h, v9.8h, v5.h[1]\n"
"fmla v30.8h, v9.8h, v6.h[1]\n"
"fmla v31.8h, v9.8h, v7.h[1]\n"
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n"
"fmla v24.8h, v10.8h, v0.h[2]\n"
"fmla v25.8h, v10.8h, v1.h[2]\n"
"fmla v26.8h, v10.8h, v2.h[2]\n"
"fmla v27.8h, v10.8h, v3.h[2]\n"
"fmla v28.8h, v10.8h, v4.h[2]\n"
"fmla v29.8h, v10.8h, v5.h[2]\n"
"fmla v30.8h, v10.8h, v6.h[2]\n"
"fmla v31.8h, v10.8h, v7.h[2]\n"
"fmla v24.8h, v11.8h, v0.h[3]\n"
"fmla v25.8h, v11.8h, v1.h[3]\n"
"fmla v26.8h, v11.8h, v2.h[3]\n"
"fmla v27.8h, v11.8h, v3.h[3]\n"
"fmla v28.8h, v11.8h, v4.h[3]\n"
"fmla v29.8h, v11.8h, v5.h[3]\n"
"fmla v30.8h, v11.8h, v6.h[3]\n"
"fmla v31.8h, v11.8h, v7.h[3]\n"
"fmla v24.8h, v12.8h, v0.h[4]\n"
"fmla v25.8h, v12.8h, v1.h[4]\n"
"fmla v26.8h, v12.8h, v2.h[4]\n"
"fmla v27.8h, v12.8h, v3.h[4]\n"
"fmla v24.8h, v13.8h, v0.h[5]\n"
"fmla v25.8h, v13.8h, v1.h[5]\n"
"fmla v26.8h, v13.8h, v2.h[5]\n"
"fmla v27.8h, v13.8h, v3.h[5]\n"
"beq 2f\n"
"1:\n"
"ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[a_ptr]], 64\n"
"fmla v24.8h, v15.8h, v0.h[7]\n"
"fmla v25.8h, v15.8h, v1.h[7]\n"
"fmla v26.8h, v15.8h, v2.h[7]\n"
"fmla v27.8h, v15.8h, v3.h[7]\n"
"fmla v24.8h, v14.8h, v0.h[6]\n"
"fmla v25.8h, v14.8h, v1.h[6]\n"
"fmla v26.8h, v14.8h, v2.h[6]\n"
"fmla v27.8h, v14.8h, v3.h[6]\n"
"fmla v28.8h, v12.8h, v4.h[4]\n"
"fmla v29.8h, v12.8h, v5.h[4]\n"
"fmla v30.8h, v12.8h, v6.h[4]\n"
"fmla v31.8h, v12.8h, v7.h[4]\n"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b_ptr]], 64\n"
"fmla v28.8h, v13.8h, v4.h[5]\n"
"fmla v29.8h, v13.8h, v5.h[5]\n"
"fmla v30.8h, v13.8h, v6.h[5]\n"
"fmla v31.8h, v13.8h, v7.h[5]\n"
"fmla v28.8h, v14.8h, v4.h[6]\n"
"fmla v29.8h, v14.8h, v5.h[6]\n"
"fmla v30.8h, v14.8h, v6.h[6]\n"
"fmla v31.8h, v14.8h, v7.h[6]\n"
"fmla v28.8h, v15.8h, v4.h[7]\n"
"fmla v29.8h, v15.8h, v5.h[7]\n"
"fmla v30.8h, v15.8h, v6.h[7]\n"
"fmla v31.8h, v15.8h, v7.h[7]\n"
"fmla v24.8h, v8.8h, v0.h[0]\n"
"fmla v25.8h, v8.8h, v1.h[0]\n"
"fmla v26.8h, v8.8h, v2.h[0]\n"
"fmla v27.8h, v8.8h, v3.h[0]\n"
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b_ptr]], %x[LDB]\n"
"fmla v24.8h, v9.8h, v0.h[1]\n"
"fmla v25.8h, v9.8h, v1.h[1]\n"
"fmla v26.8h, v9.8h, v2.h[1]\n"
"fmla v27.8h, v9.8h, v3.h[1]\n"
"fmla v24.8h, v10.8h, v0.h[2]\n"
"fmla v25.8h, v10.8h, v1.h[2]\n"
"fmla v26.8h, v10.8h, v2.h[2]\n"
"fmla v27.8h, v10.8h, v3.h[2]\n"
"fmla v24.8h, v11.8h, v0.h[3]\n"
"fmla v25.8h, v11.8h, v1.h[3]\n"
"fmla v26.8h, v11.8h, v2.h[3]\n"
"fmla v27.8h, v11.8h, v3.h[3]\n"
"ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[a_ptr]], 64\n"
"fmla v28.8h, v10.8h, v4.h[2]\n"
"fmla v29.8h, v10.8h, v5.h[2]\n"
"fmla v30.8h, v10.8h, v6.h[2]\n"
"fmla v31.8h, v10.8h, v7.h[2]\n"
"fmla v28.8h, v8.8h, v4.h[0]\n"
"fmla v29.8h, v8.8h, v5.h[0]\n"
"fmla v30.8h, v8.8h, v6.h[0]\n"
"fmla v31.8h, v8.8h, v7.h[0]\n"
"fmla v28.8h, v9.8h, v4.h[1]\n"
"fmla v29.8h, v9.8h, v5.h[1]\n"
"fmla v30.8h, v9.8h, v6.h[1]\n"
"fmla v31.8h, v9.8h, v7.h[1]\n"
"fmla v28.8h, v11.8h, v4.h[3]\n"
"fmla v29.8h, v11.8h, v5.h[3]\n"
"fmla v30.8h, v11.8h, v6.h[3]\n"
"fmla v31.8h, v11.8h, v7.h[3]\n"
"fmla v24.8h, v12.8h, v0.h[4]\n"
"fmla v25.8h, v12.8h, v1.h[4]\n"
"fmla v26.8h, v12.8h, v2.h[4]\n"
"fmla v27.8h, v12.8h, v3.h[4]\n"
"fmla v24.8h, v13.8h, v0.h[5]\n"
"fmla v25.8h, v13.8h, v1.h[5]\n"
"fmla v26.8h, v13.8h, v2.h[5]\n"
"fmla v27.8h, v13.8h, v3.h[5]\n"
"subs %w[K], %w[K], #8\n"
"bne 1b\n"
"2:\n"
"fmla v24.8h, v14.8h, v0.h[6]\n"
"fmla v25.8h, v14.8h, v1.h[6]\n"
"fmla v26.8h, v14.8h, v2.h[6]\n"
"fmla v27.8h, v14.8h, v3.h[6]\n"
"fmla v24.8h, v15.8h, v0.h[7]\n"
"fmla v25.8h, v15.8h, v1.h[7]\n"
"fmla v26.8h, v15.8h, v2.h[7]\n"
"fmla v27.8h, v15.8h, v3.h[7]\n"
"fmla v28.8h, v12.8h, v4.h[4]\n"
"fmla v29.8h, v12.8h, v5.h[4]\n"
"fmla v28.8h, v13.8h, v4.h[5]\n"
"fmla v29.8h, v13.8h, v5.h[5]\n"
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[output]], 64\n"
"fmla v28.8h, v14.8h, v4.h[6]\n"
"fmla v29.8h, v14.8h, v5.h[6]\n"
"fmla v28.8h, v15.8h, v4.h[7]\n"
"fmla v29.8h, v15.8h, v5.h[7]\n"
"fmla v30.8h, v12.8h, v6.h[4]\n"
"fmla v31.8h, v12.8h, v7.h[4]\n"
"fmla v30.8h, v13.8h, v6.h[5]\n"
"fmla v31.8h, v13.8h, v7.h[5]\n"
"fmla v30.8h, v14.8h, v6.h[6]\n"
"fmla v31.8h, v14.8h, v7.h[6]\n"
"fmla v30.8h, v15.8h, v6.h[7]\n"
"fmla v31.8h, v15.8h, v7.h[7]\n"
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[output]], 64\n"
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
[output] "+r"(output), [LDB] "+r"(LDB)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27",
"v28", "v29", "v30", "v31", "cc", "memory");
}
} // anonymous namespace
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8);
void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA,
const dt_float16* B, size_t LDB, dt_float16* C,
size_t LDC, size_t M, size_t K, size_t N,
const dt_float16*, void*, bool trA,
bool trB) const {
constexpr static size_t MB = 8;
constexpr static size_t KB = 8;
constexpr static size_t NB = 8;
constexpr static size_t CALCBLK = 4;
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0);
//! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8)
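    //! A is packed as (M/8, K/8, 8, 8) blocks and B is laid out as
    //! (K/8, N, 8); each kern_8x8 call produces one 8x8 tile of the
    //! MK8-layout output, with kern_8x4 covering a trailing 4-column block.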
for (size_t m = 0; m < M; m += MB) {
dt_float16* output = C + (m / MB) * LDC;
const dt_float16* cur_B = B;
size_t n = 0;
for (; n + NB - 1 < N; n += NB) {
kern_8x8(A, cur_B, LDB, K, output);
cur_B += KB * NB;
output += MB * NB;
}
if (n < N) {
kern_8x4(A, cur_B, LDB, K, output);
}
A += LDA;
}
}
#endif
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/matrix_mul/fp32/common.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <cstddef>
#include "megdnn/arch.h"
#include "src/common/utils.h"
namespace megdnn {
namespace aarch64 {
MEGDNN_NOINLINE void sgemm_packA_n(const float* A, float* Apacked, size_t M,
size_t K, size_t LDA, const float* alpha);
MEGDNN_NOINLINE void sgemm_packA_t(const float* A, float* Apacked, size_t M,
size_t K, size_t LDA, const float* alpha);
MEGDNN_NOINLINE void sgemm_packB_n(const float* B, float* Bpacked, size_t K,
size_t N, size_t LDB);
MEGDNN_NOINLINE void sgemm_packB_t(const float* B, float* Bpacked, size_t K,
size_t N, size_t LDB);
MEGDNN_NOINLINE void sgemm_kernel12x8(const float* A, const float* B, float* C,
size_t LDC, size_t M, size_t N, size_t K,
int type, const float* beta);
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h"
#include "src/common/utils.h"
using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16);
void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0,
int ymax, int k0, int kmax, bool transpose_A) const {
if (transpose_A) {
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax);
} else {
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
}
}
void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
int k0, int kmax, bool transpose_B) const {
if (transpose_B) {
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
} else {
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
}
}
void sgemm_4x16::kern(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
constexpr size_t A_INTERLEAVE = 4;
constexpr size_t B_INTERLEAVE = 16;
const int K16 = K * 16;
const int K4 = K * 4;
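    // Walk A in 4-row panels; for each panel consume B in 16-column panels
    // (stride K16 in the packed buffer), then finish the remaining columns in
    // blocks of up to 4 (stride K4).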
size_t m = 0;
for (; m < M; m += A_INTERLEAVE) {
float* output = C + (m * LDC);
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K16;
}
for (; n < N; n += 4) {
matmul_general_4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K4;
}
}
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12);
void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0,
int ymax, int k0, int kmax, bool transpose_A) const {
if (transpose_A) {
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0,
kmax);
} else {
matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
}
}
void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
int k0, int kmax, bool transpose_B) const {
if (transpose_B) {
matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0,
kmax);
} else {
matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
}
}
void sgemm_8x12::kern(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t A_INTERLEAVE4 = 4;
constexpr size_t B_INTERLEAVE = 12;
const int K12 = K * 12;
const int K8 = K * 8;
const int K4 = K * 4;
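    // Main loop: 8-row panels of A against 12-column panels of B, with
    // 4-column tail blocks; any leftover rows (M % 8) are handled afterwards
    // by the 4-row kernels.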
size_t m = 0;
for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) {
float* output = C + (m * LDC);
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
matmul_general_8x12::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
output += B_INTERLEAVE;
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_general_8x12::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K8;
}
for (; m < M; m += A_INTERLEAVE4) {
float* output = C + (m * LDC);
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_general_8x12::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_general_8x12::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
packA += K4;
}
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/matrix_mul/fp32/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
namespace megdnn {
namespace aarch64 {
namespace matmul {
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true,
sgemm_8x12);
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true,
sgemm_4x16);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true,
sgemm_nopack_4x16);
} // namespace matmul
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/matrix_mul/int16/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"
namespace megdnn {
namespace aarch64 {
namespace matmul {
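// gemm_s16_12x8x1 is the packed int16 -> int32 kernel with a 12x8 output
// block; gemm_nopack_s16_8x8 is the no-pack MK8 variant (see
// AlgoInt16x16x32MK8_8x8).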
MEGDNN_REG_GEMM_STRATEGY(dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true,
gemm_s16_12x8x1);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 8, 8, 1, false,
true, gemm_nopack_s16_8x8);
} // namespace matmul
} // namespace aarch64
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/aarch64/matrix_mul/int8/strategy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#if !(__ARM_FEATURE_DOTPROD)
#include "src/fallback/matrix_mul/gemm_common.h"
namespace megdnn {
namespace aarch64 {
namespace matmul {
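// Non-dotprod int8 strategies: 4x4 and 8x8 output blocks with K-blocks of 16
// and 8 respectively; gemm_mk4_s8_4x4 is the MK4-layout variant of the 4x4
// kernel.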
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true,
gemm_s8_4x4);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false,
gemm_mk4_s8_4x4);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true,
gemm_s8_8x8);
} // namespace matmul
} // namespace aarch64
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen