提交 67f11788 编写于 作者: M Megvii Engine Team

perf(arm_common): add elemwise unary multithread support

GitOrigin-RevId: 8eac123f67224e283b368c515bf0b8e7ef565158
上级 3afa3893
......@@ -71,12 +71,19 @@ void ElemwiseImpl::AlgoUnary::exec(const KernParam& kern_param) const {
thin_function<void(const _type*, _type*, DType, DType, size_t)> \
run = OpCallerUnary<_op<_type, _type>, \
BcastType::VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
/* Per-task kernel: task `task_id` processes the contiguous slice */ \
/* [offset, offset + nr_elems_thread) of the source and destination. */ \
/* The second argument is unused. */ \
auto kernel = [nr_elems, nr_elems_per_thread, src0, dst_tensor, \
               run](size_t task_id, size_t) { \
    size_t offset = task_id * nr_elems_per_thread; \
    /* nr_elems_per_thread is a ceiling division, so trailing tasks */ \
    /* may start at or past the end (e.g. 5 elements on 4 threads). */ \
    /* Bail out before the unsigned `nr_elems - offset` wraps around */ \
    /* and makes run() touch memory outside the tensors. */ \
    if (offset >= nr_elems) { \
        return; \
    } \
    size_t nr_elems_thread = \
            std::min(nr_elems - offset, nr_elems_per_thread); \
    run(static_cast<const _type*>(src0.raw_ptr) + offset, \
        static_cast<_type*>(dst_tensor.raw_ptr) + offset, \
        src0.layout.dtype, dst_tensor.layout.dtype, \
        nr_elems_thread); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
static_cast<_type*>(dst_tensor.raw_ptr), \
src0.layout.dtype, dst_tensor.layout.dtype, \
nr_elems)); \
nr_threads, kernel); \
} \
MIDOUT_END(); \
return
......@@ -86,7 +93,12 @@ void ElemwiseImpl::AlgoUnary::exec(const KernParam& kern_param) const {
// Unary elemwise: the single input operand and the output tensor.
auto& src0 = elparam[0];
auto& dst_tensor = *(kern_param.m_dst);
// Thread count reported by the handle's megcore dispatcher.
// NOTE(review): the division below assumes nr_threads is never 0 — confirm.
size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle)
->megcore_dispatcher()
->nr_threads();
// Total element count of the input, taken from its layout.
size_t nr_elems = src0.layout.total_nr_elems();
// Ceiling division: nr_threads chunks of this size cover every element.
// Because of the rounding up, the last chunk(s) may be short or even empty.
size_t nr_elems_per_thread = (nr_elems + nr_threads - 1) / nr_threads;
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
switch (kern_param.mode) { \
......
......@@ -26,6 +26,13 @@ TYPED_TEST(ARM_ELEMWISE, run) {
elemwise::run_test<TypeParam>(this->handle());
}
// Typed-test fixture layered on ARM_COMMON_MULTI_THREADS (presumably the
// fixture that provides a multi-threaded handle — confirm against its
// definition). The `tag` parameter is not used by the fixture itself; it only
// lets gtest instantiate one fixture per tested type.
template <typename tag>
class ARM_ELEMWISE_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {};
// Instantiate the fixture once for each type listed in elemwise::test_types.
TYPED_TEST_CASE(ARM_ELEMWISE_MULTI_THREADS, elemwise::test_types);
// Run the shared elemwise test body for the current type on this fixture's
// handle.
TYPED_TEST(ARM_ELEMWISE_MULTI_THREADS, run) {
elemwise::run_test<TypeParam>(this->handle());
}
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());
......
......@@ -2,7 +2,7 @@
set -e
ARCHS=("arm64-v8a" "armeabi-v7a")
BUILD_TYPE=RelWithDebInfo
BUILD_TYPE=Release
MGE_ARMV8_2_FEATURE_FP16=OFF
MGE_DISABLE_FLOAT16=OFF
ARCH=arm64-v8a
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册