提交 3afa3893 编写于 作者: M Megvii Engine Team

perf(arm_common): optimize arm common pooling 9x9 and 13x13

GitOrigin-RevId: 33d5a624784a5dde61b6c9cfe461297a0f2950fe
上级 d16c5caf
......@@ -124,4 +124,5 @@ __ai void load_helper_x(T& weight, T2 ptr, int oc_offset, XT... args) {
} // namespace
} // namespace megdnn
#undef __ai
// vim: syntax=cpp.doxygen
\ No newline at end of file
// vim: syntax=cpp.doxygen
......@@ -30,10 +30,12 @@ bool PoolingImpl::AlgoFp32ModexStridexNCHW44::usable(
bool avaible = param.src_type.enumv() == DTypeEnum::Float32 &&
param.format == Param::Format::NCHW44 &&
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
fh == fw && sh == sw &&
(fh == 2 || fh == 3 || fh == 4 || fh == 5) &&
(sh == 1 || sh == 2);
return avaible;
fh == fw && sh == sw;
bool size_ok = ((fh == 2 || fh == 3 || fh == 4 || fh == 5) &&
(sh == 1 || sh == 2));
size_ok |= ((fh == 9 || fh == 13) && (sh == 1));
return avaible && size_ok;
}
void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec(
......@@ -94,6 +96,15 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec(
megdnn_assert(0, "invalid stride %d", sh); \
}
#define DISPATCH_STRIDE_1(filter) \
switch (sh) { \
case 1: \
DISPATCH_MODE(filter, 1); \
break; \
default: \
megdnn_assert(0, "invalid stride %d", sh); \
}
#define DISPATCH_FILTER() \
switch (fh) { \
case 2: \
......@@ -108,6 +119,12 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec(
case 5: \
DISPATCH_STRIDE(5); \
break; \
case 9: \
DISPATCH_STRIDE_1(9); \
break; \
case 13: \
DISPATCH_STRIDE_1(13); \
break; \
default: \
megdnn_assert(0, "invalid filter %d", fh); \
}
......@@ -123,4 +140,4 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec(
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
// vim: syntax=cpp.doxygen
......@@ -64,6 +64,8 @@ INSTANCE_CAL(2)
INSTANCE_CAL(3)
INSTANCE_CAL(4)
INSTANCE_CAL(5)
INSTANCE_CAL(9)
INSTANCE_CAL(13)
#undef INSTANCE_CAL
#undef CALCULATE_AVG_CB
......@@ -305,4 +307,4 @@ static inline void pooling_fp32_nchw44(const float32_t* src, float32_t* dst,
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
// vim: syntax=cpp.doxygen
......@@ -116,6 +116,31 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_NCHW44_FP32) {
}
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W9_w13_NCHW44)
{
UniformIntRNG rng{-10, 10};
Checker<Pooling> checker(handle());
checker.set_rng(0, &rng);
// clang-format off
for (size_t ih: {20, 15})
for (size_t iw: {15, 20})
for (size_t kernel: {9, 13})
for (size_t pad: {4, 6})
for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE})
if (kernel > pad)
{
param::Pooling param;
param.mode = mode;
param.format = param::Pooling::Format::NCHW44;
param.pad_h = pad;
param.pad_w = pad;
param.stride_h = param.stride_w = 1;
param.window_h = param.window_w = kernel ;
checker.set_param(param).exec(TensorShapeArray{{2, 8, ih, iw, 4}, {}});
}
// clang-format on
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
......
......@@ -2,7 +2,7 @@
set -e
ARCHS=("arm64-v8a" "armeabi-v7a")
BUILD_TYPE=Release
BUILD_TYPE=RelWithDebInfo
MGE_ARMV8_2_FEATURE_FP16=OFF
MGE_DISABLE_FLOAT16=OFF
ARCH=arm64-v8a
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册