Commit 5b62acfa authored by Megvii Engine Team

feat(dnn/armv7): add new matmul strategy k8x8x4

GitOrigin-RevId: 0c6b7fa1b2ad8724a5c68036d58b3c1e13c3bb42
Parent ad87f78a
@@ -541,6 +541,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8,
armv7::matmul::gemm_s8x8x16_4x8, int8_t,
int16_t, AlgoDataType::INT8X8X16, DEFAULT);
/* ===================== Int8x8x16 Kernel 8x8x4 algo ===================== */
namespace {
void kern_int8x8x16_k8x8x4(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
midout_iv("kern_int8x8x16_k8x8x4"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int16>();
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
auto trA = kern_param.trA, trB = kern_param.trB;
armv7::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, kern_param.A_type,
kern_param.B_type,
kern_param.C_type);
megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_s8x8x16_8x8>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
kern_param.workspace_ptr);
}
MIDOUT_END();
}
} // anonymous namespace
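For orientation, a minimal sketch of the pack-then-kern flow that GemmInterleaved drives for this strategy (hypothetical simplified code: the real driver also tiles M/N/K for the cache and computes the packA/packB split itself; the names here are illustrative, not megdnn's implementation):

#include <cstddef>
#include <cstdint>

// Strategy stands in for armv7::matmul::gemm_s8x8x16_8x8; packA/packB point
// into the workspace that get_workspace() sizes.
template <typename Strategy>
void gemm_interleaved_sketch(const Strategy& strategy, const int8_t* A,
                             size_t LDA, const int8_t* B, size_t LDB,
                             int16_t* C, size_t LDC, size_t M, size_t N,
                             size_t K, bool trA, bool trB, int8_t* packA,
                             int8_t* packB) {
    // Reorder A into 8-row panels and B into 8-column panels (K padded to a
    // multiple of 4), then run the micro-kernels over the packed buffers.
    strategy.pack_A(packA, A, LDA, 0, M, 0, K, trA);
    strategy.pack_B(packB, B, LDB, 0, N, 0, K, trB);
    strategy.kern(packA, packB, M, N, K, C, LDC, /*is_first_k=*/true,
                  nullptr, nullptr);
}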
bool MatrixMulImpl::AlgoInt8x8x16K8x8x4::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.A_type == kern_size_param.B_type &&
kern_size_param.A_type == dtype::Int8() &&
kern_size_param.C_type == dtype::Int16() &&
kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
}
size_t MatrixMulImpl::AlgoInt8x8x16K8x8x4::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
midout_iv("AlgoInt8x8x16K8x8x4::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N,
K = kern_size_param.K;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type, C_type);
return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_8x8>(
M, N, K, trA, trB, strategy)
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
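As a rough cross-check of what get_workspace() accounts for, the packed int8 operands alone need about (padded M + padded N) * (padded K) bytes. A back-of-envelope sketch, assuming no alignment padding or cache-blocking overhead (both of which the real GemmInterleaved accounting adds on top):

#include <cstddef>

// Lower-bound estimate for the 8x8x4 blocking: K padded to a multiple of 4,
// M to the 4-row tail tile, N to the 8-column tile; int8 elements are one
// byte each.
size_t workspace_lower_bound(size_t M, size_t N, size_t K) {
    size_t K4 = (K + 3) / 4 * 4;
    size_t M4 = (M + 3) / 4 * 4;
    size_t N8 = (N + 7) / 8 * 8;
    return M4 * K4 + N8 * K4;
}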
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K8x8x4::get_kern(
const KernSizeParam&) const {
return kern_int8x8x16_k8x8x4;
}
bool MatrixMulImpl::AlgoInt8x8x16K8x8x4::preferred(
const KernSizeParam& kern_size_param) const {
return kern_size_param.K >= 8 && kern_size_param.K <= 128;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x4,
megdnn_armv7_matmul_kern,
"AlgoInt8x8x16K8x8x4"_hash,
armv7::matmul::gemm_s8x8x16_8x8, int8_t,
int16_t, AlgoDataType::INT8X8X16, DEFAULT);
/* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/
namespace {
......
@@ -181,6 +181,18 @@ public:
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X8X8)
};
class MatrixMulImpl::AlgoInt8x8x16K8x8x4 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARMV7_INT8X8X16_K8X8X4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K8X8X4)
};
class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
......
This diff is collapsed.
@@ -10,12 +10,13 @@
* implied.
*/
#include "src/armv7/matrix_mul/int8x8x16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/armv7/matrix_mul/asm/common.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_8x8x4.h"
#include "src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h"
#include "src/armv7/matrix_mul/int8x8x16/strategy.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"
@@ -181,6 +182,79 @@ void gemm_s8x8x16_4x8::kern(const dt_int8* packA, const dt_int8* packB,
}
}
// ===========================gemm_s8x8x16_8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8);
void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,
int ymax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x8x4::gemm_s8x8x16_8x8_pack_A_t(out, in, ldin, y0, ymax, k0,
kmax);
} else {
matmul_8x8x4::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
}
}
void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0,
int xmax, int k0, int kmax,
bool transpose) const {
if (transpose) {
matmul_8x8x4::gemm_s8x8x16_8x8_pack_B_t(out, in, ldin, x0, xmax, k0,
kmax);
} else {
matmul_8x8x4::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
}
}
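The pack_A/pack_B entry points above forward to the asm packers in kernel_8x8x4.h. For intuition only, a scalar reference of what the non-transposed A packing plausibly produces for one full in-bounds 8-row panel (assumption: rows are interleaved per 4-wide K step with each row's 4 int8 values kept contiguous; the authoritative byte order and tail zero-padding are whatever the asm micro-kernels expect):

#include <cstdint>

// Illustrative panel packer: out receives K4/4 blocks, each holding 8 rows
// times 4 consecutive K values.
void pack_A_panel_sketch(int8_t* out, const int8_t* in, int ldin, int y0,
                         int k0, int K4 /* K rounded up to 4 */) {
    for (int kb = 0; kb < K4; kb += 4)   // 4-deep K blocks
        for (int m = 0; m < 8; ++m)      // 8 interleaved rows
            for (int k = 0; k < 4; ++k)
                *out++ = in[(y0 + m) * ldin + (k0 + kb + k)];
}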
void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
((A_dtype.enumv() == DTypeEnum::Int8 &&
C_dtype.enumv() == DTypeEnum::Int16)),
"A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
C_dtype.name());
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t B_INTERLEAVE = 8;
//! K is padded to a multiple of 4
K = round_up<size_t>(K, 4);
size_t m = 0;
for (; m + 7 < M; m += A_INTERLEAVE) {
int16_t* output = C + (m * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 8));
output += B_INTERLEAVE;
cur_packB += K * 8;
}
packA += K * 8;
}
for (; m < M; m += 4) {
int16_t* output = C + (m * LDC);
const dt_int8* cur_packB = packB;
size_t n = 0;
for (; n < N; n += B_INTERLEAVE) {
matmul_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 8));
output += B_INTERLEAVE;
cur_packB += K * 8;
}
packA += K * 4;
}
}
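The two M loops above emit full 8-row blocks first, then a masked 4-row tail. A standalone trace of that tiling for M = 13 (illustrative):

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    std::size_t M = 13, m = 0;
    for (; m + 7 < M; m += 8)  // full kern_8x8 blocks
        std::printf("kern_8x8: rows [%zu, %zu)\n", m, m + 8);
    for (; m < M; m += 4)      // kern_4x8 tail, masked by min(M - m, 4)
        std::printf("kern_4x8: rows [%zu, %zu), valid = %zu\n", m, m + 4,
                    std::min<std::size_t>(M - m, 4));
    // Prints one 8-row block [0, 8), then tails [8, 12) with 4 valid rows
    // and [12, 16) with 1 valid row.
    return 0;
}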
// ===========================gemm_s8x8x16_mk4_8x8==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8);
......
@@ -22,6 +22,9 @@ MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 2, 16, false, true,
MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 8, 8, false, true,
gemm_s8x8x16_4x8);
MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 8, 8, 4, false, true,
gemm_s8x8x16_8x8);
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(int8_t, int16_t, int16_t, int16_t, 8,
8, 4, false, false,
gemm_s8x8x16_mk4_8x8);
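In the registration above, the 8, 8, 4 triple is the micro-kernel tile, 8 rows of A by 8 columns of B advancing 4 steps along K per iteration, which is what the k8x8x4 in the commit title refers to; this reading is inferred from the neighbouring 4x2x16 and 4x8x8 registrations rather than from documented macro semantics.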
......
@@ -39,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoQuint8K4x8x8 quint8_k4x8x8;
AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16;
AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8;
AlgoInt8x8x16K8x8x4 int8x8x16_k8x8x4;
AlgoInt8x8x16MK4_8x8x4 int8x8x16_mk4_8x8x4;
AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1;
AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8;
@@ -47,7 +48,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;
public:
AlgoPack() {
m_all_algos.emplace_back(&f32_gemv);
m_all_algos.emplace_back(&f32);
@@ -69,6 +69,7 @@ public:
m_all_algos.emplace_back(&int8x8x16_mk4_8x8x4);
m_all_algos.emplace_back(&int8x8x16_k4x2x16);
m_all_algos.emplace_back(&int8x8x16_k4x8x8);
m_all_algos.emplace_back(&int8x8x16_k8x8x4);
m_all_algos.emplace_back(&int16x16x32_k12x4x1);
m_all_algos.emplace_back(&int16x16x32_mk8_4x8);
......
@@ -41,7 +41,8 @@ private:
class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8
class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16
class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8
class AlgoInt8x8x16MK4_8x8x4; // Armv7 Int8x8x16 Kernel 8x8x8
class AlgoInt8x8x16K8x8x4; // Armv7 Int8x8x16 Kernel 8x8x4
class AlgoInt8x8x16MK4_8x8x4; // Armv7 Int8x8x16 Kernel mk4_8x8x4
class AlgoInt16x16x32K12x4x1; // Armv7 Int16x16x32 Kernel 12x4x1
class AlgoInt16x16x32MK8_4x8; // Armv7 Int16x16x32 MK8 Format block 4x8
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......
@@ -174,7 +174,8 @@ public:
ARMV7_INT8X8X16_MK4_K8X8X4,
ARMV7_INT16X16X32_K12X4X1,
ARMV7_INT16X16X32_MK8_4X8,
ARMV7_INT8X8X32_MK4_4X2X16
ARMV7_INT8X8X32_MK4_4X2X16,
ARMV7_INT8X8X16_K8X8X4
#endif
#endif
};
......
@@ -52,6 +52,12 @@ TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K4x8x8) {
handle(), "ARMV7_INT8X8X16_K4X8X8");
}
TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K8x8x4) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
handle(), "ARMV7_INT8X8X16_K8X8X4");
}
TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_MK4_K8x8x4) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
handle(), "ARMV7_INT8X8X16_MK4_K8X8X4",
@@ -183,6 +189,68 @@ void run_8x8x16_benchmark(
}
}
}
void run_8x8x16_contrast(
const char* algo0, const char* algo, Handle* handle,
MatrixMul::Param::Format format = MatrixMul::Param::Format::DEFAULT) {
constexpr size_t RUNS = 100;
param::MatrixMul param;
Benchmarker<MatrixMul> benchmarker_int(handle);
Benchmarker<MatrixMul> benchmarker_target(handle);
benchmarker_int.set_before_exec_callback(AlgoChecker<MatrixMul>(algo0));
benchmarker_int.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_param(param)
.set_display(false);
param::MatrixMul target_param;
target_param.format = format;
benchmarker_target.set_before_exec_callback(
AlgoChecker<MatrixMul>(algo));
benchmarker_target.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_param(target_param)
.set_display(false);
auto run = [&](size_t M, size_t N, size_t K) {
auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS;
auto int_kern_used = 1e10;
double computation = 2.0f * M * N * K * 1e-6;
if (format == MatrixMul::Param::Format::MK4) {
int_kern_used = benchmarker_target.exec(
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
RUNS;
} else {
int_kern_used =
benchmarker_target.exec({{M, K}, {K, N}, {}}) / RUNS;
}
printf(" %f(%f)\t %f(%f)\t %f\n", int_used, computation / int_used,
int_kern_used, computation / int_kern_used,
int_used / int_kern_used);
};
printf("\nN\t K\t M\t %s ms(GFlops)\t %s ms(GFlops)\t SPEEDUP\n", algo0,
algo);
for (size_t M : {8}) {
for (size_t K : {72}) {
for (size_t N : {8, 16, 32, 64, 72, 128, 256, 512, 1024, 4096, 8192,
16384, 32768, 65536}) {
printf("%zu\t %zu\t %zu\t", N, K, M);
run(M, N, K);
}
}
}
printf("512\t 512\t 512\t");
run(512, 512, 512);
}
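The ms(GFlops) columns follow from the scaling above: computation is 2*M*N*K*1e-6, i.e. megaflops, so dividing by a time in milliseconds yields Gflop/s directly. A worked check with an assumed, purely hypothetical 10 ms run:

#include <cstdio>

int main() {
    double M = 512, N = 512, K = 512;
    double time_ms = 10.0;                        // assumed, not a measurement
    double computation = 2.0 * M * N * K * 1e-6;  // 268.435456 Mflop
    std::printf("%.2f GFlop/s\n", computation / time_ms);  // prints 26.84
    return 0;
}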
void run_16x16x32_benchmark(const char* algo, Handle* handle) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
@@ -383,6 +451,10 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8) {
run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X8X8", handle());
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K8x8x4) {
run_8x8x16_benchmark("ARMV7_INT8X8X16_K8X8X4", handle());
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_MK4_K4x8x8) {
run_8x8x16_benchmark("ARMV7_INT8X8X16_MK4_K8X8X4", handle(),
MatrixMul::Param::Format::MK4);
@@ -392,6 +464,21 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_K12x4x1) {
run_16x16x32_benchmark("ARMV7_INT16X16X32_K12X4X1", handle());
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K8x8x4_CONTRAST) {
run_8x8x16_contrast("ARM_COMMON_INT8X8X16", "ARMV7_INT8X8X16_K8X8X4",
handle());
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8_CONTRAST) {
run_8x8x16_contrast("ARM_COMMON_INT8X8X16", "ARMV7_INT8X8X16_K4X8X8",
handle());
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8_K8x8x4_CONTRAST) {
run_8x8x16_contrast("ARMV7_INT8X8X16_K4X8X8", "ARMV7_INT8X8X16_K8X8X4",
handle());
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_FP16) {
constexpr size_t RUNS = 50;
......
@@ -517,9 +517,18 @@ void convolution::test_conv_config_combinations(int k_size,
param.compute_mode = Param::ComputeMode::FLOAT32;
}
size_t IC = 6, OC = 9, G = 3, FH = ksize, FW = ksize;
TensorShape ishp = format ?
TensorShape{2, 18, 18, IC} : TensorShape{2, IC, 18, 18},
fshp;
TensorShape ishp = TensorShape{2, 18, 18, IC}, fshp;
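// format != 0 selects NHWC ({N, H, W, C}); otherwise NCHW ({N, C, H, W}).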
if (format) {
ishp.shape[0] = 2;
ishp.shape[1] = 18;
ishp.shape[2] = 18;
ishp.shape[3] = IC;
} else {
ishp.shape[0] = 2;
ishp.shape[1] = IC;
ishp.shape[2] = 18;
ishp.shape[3] = 18;
}
if (padding) {
param.pad_h = 2 + non_square;
param.pad_w = 2 - non_square;
......