/** * \file dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ #pragma once #include "megdnn/opr_param_defs.h" #include "src/fallback/conv_bias/opr_impl.h" #if MEGDNN_X86 #include "src/x86/conv_bias/postprocess_helper.h" #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) #include "src/arm_common/conv_bias/postprocess_helper.h" #endif namespace megdnn { namespace fallback { namespace conv1x1 { #if MEGDNN_X86 using namespace x86; #endif namespace { //! get_matmul_kern_param MatrixMulImpl::KernSizeParam get_matmul_kern_param( const ConvBiasImpl::NCBKernSizeParam& param, size_t n, size_t m) { size_t M = m; size_t N = n; size_t K = param.filter_meta.icpg; //! K = IC size_t LDA = K, LDB = N, LDC = N; bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.dst_type.enumv() == DTypeEnum::QuantizedS8) || (param.src_type.enumv() == DTypeEnum::Quantized8Asymm && param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); return {param.filter_type, param.src_type, is_dst_8bit ? param.bias_type : param.dst_type, M, N, K, LDA, LDB, LDC, false, false, param::MatrixMul::ComputeMode::DEFAULT, param::MatrixMul::Format::DEFAULT}; } } // namespace class Conv1x1StrategyBase { public: virtual void packA(WorkspaceBundle& whole_bundle, WorkspaceBundle& matmul_bundle, size_t oc_tile_size, const MatrixMulImpl::AlgoBase* matmul_algo, const ConvBiasImpl::NCBKernSizeParam& param, const ConvBiasImpl::NCBKernParam& ncb_param, const ConvBiasImpl::NCBKernIndex& ncb_index) = 0; virtual void packB(WorkspaceBundle& whole_bundle, WorkspaceBundle& matmul_bundle, const MatrixMulImpl::AlgoBase* matmul_algo, const ConvBiasImpl::NCBKernSizeParam& param, const ConvBiasImpl::NCBKernParam& ncb_param, const ConvBiasImpl::NCBKernIndex& ncb_index) = 0; virtual void exec(WorkspaceBundle& whole_bundle, WorkspaceBundle& matmul_bundle, WorkspaceBundle& thread_bundle, size_t oc_tile_size, const MatrixMulImpl::AlgoBase* matmul_algo, const ConvBiasImpl::NCBKernSizeParam& param, const ConvBiasImpl::NCBKernParam& ncb_param, const ConvBiasImpl::NCBKernIndex& ncb_index) = 0; virtual ~Conv1x1StrategyBase() = default; }; template class Conv1x1Strategy : public Conv1x1StrategyBase { public: explicit Conv1x1Strategy(size_t pack_size = 1) : m_pack_size(pack_size) {} void packA(WorkspaceBundle& whole_bundle, WorkspaceBundle& matmul_bundle, size_t oc_tile_size, const MatrixMulImpl::AlgoBase* matmul_algo, const ConvBiasImpl::NCBKernSizeParam& param, const ConvBiasImpl::NCBKernParam& ncb_param, const ConvBiasImpl::NCBKernIndex& ncb_index) override { if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { megdnn_log_error("NoPack mode has no packA kernel"); return; } whole_bundle.set(ncb_param.workspace_ptr); //! packa size per group size_t OC = param.filter_meta.ocpg; size_t oc_tiles_per_group = div_ceil(OC, oc_tile_size); size_t packa_bytes_per_oc_tile = matmul_bundle.get_size(0); size_t packa_bytes_per_group = oc_tiles_per_group * packa_bytes_per_oc_tile; size_t group_id = ncb_index.ndrange_id[0]; size_t oc_tile_id_in_group = ncb_index.ndrange_id[1]; size_t oc_start = oc_tile_id_in_group * oc_tile_size; size_t oc_end = oc_start + oc_tile_size; oc_end = (oc_end <= OC ? oc_end : OC); size_t OH = param.osz[0]; size_t OW = param.osz[1]; size_t IC = param.filter_meta.icpg; MatrixMulImpl::KernParam matmul_kern_param; static_cast(matmul_kern_param) = get_matmul_kern_param(param, OH * OW, oc_end - oc_start); size_t bytes_offset_of_a_panel = group_id * packa_bytes_per_group + oc_tile_id_in_group * packa_bytes_per_oc_tile; size_t numbers_offset_of_filter = oc_tile_size * IC * oc_tile_id_in_group; src_ctype* a_panel = reinterpret_cast( reinterpret_cast(whole_bundle.get(0)) + bytes_offset_of_a_panel); matmul_kern_param.LDA *= m_pack_size; matmul_kern_param.A_ptr = const_cast( ncb_param.filter(group_id) + numbers_offset_of_filter); matmul_algo->pack_A(matmul_kern_param, a_panel, 0, oc_end - oc_start); } void packB(WorkspaceBundle& whole_bundle, WorkspaceBundle& matmul_bundle, const MatrixMulImpl::AlgoBase* matmul_algo, const ConvBiasImpl::NCBKernSizeParam& param, const ConvBiasImpl::NCBKernParam& ncb_param, const ConvBiasImpl::NCBKernIndex& ncb_index) override { MEGDNN_MARK_USED_VAR(ncb_index); if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::DEFAULT) { whole_bundle.set(ncb_param.workspace_ptr); //! packb size per group size_t packb_bytes_per_group = matmul_bundle.get_size(1); size_t GROUP = param.filter_meta.group; size_t BATCH = param.n; size_t SH = param.filter_meta.stride[0]; size_t SW = param.filter_meta.stride[1]; size_t OH = param.osz[0]; size_t OW = param.osz[1]; size_t OC = param.filter_meta.ocpg; MatrixMulImpl::KernParam matmul_kern_param; static_cast(matmul_kern_param) = get_matmul_kern_param(param, OH * OW, OC); matmul_kern_param.LDB *= m_pack_size; rep(batch, BATCH) { rep(g, GROUP) { if (SH == 2 && SW == 2) megdnn_throw("no support for stride = 2"); size_t bytes_offset_of_b_panel = batch * packb_bytes_per_group * GROUP + g * packb_bytes_per_group; src_ctype* b_panel = reinterpret_cast( reinterpret_cast(whole_bundle.get(1)) + bytes_offset_of_b_panel); matmul_kern_param.B_ptr = const_cast( ncb_param.src(batch, g)); matmul_algo->pack_B(matmul_kern_param, b_panel, 0, OH * OW); } } } else { megdnn_log_error("OnlyPackA mode and NoPack mode has no packB kernel"); } } void exec(WorkspaceBundle& whole_bundle, WorkspaceBundle& matmul_bundle, WorkspaceBundle& thread_bundle, size_t oc_tile_size, const MatrixMulImpl::AlgoBase* matmul_algo, const ConvBiasImpl::NCBKernSizeParam& param, const ConvBiasImpl::NCBKernParam& ncb_param, const ConvBiasImpl::NCBKernIndex& ncb_index) override { whole_bundle.set(ncb_param.workspace_ptr); size_t OC = param.filter_meta.ocpg; size_t IC = param.filter_meta.icpg; //! packa bytes per group size_t oc_tiles_per_group = div_ceil(OC, oc_tile_size); size_t packa_bytes_per_oc_tile = matmul_bundle.get_size(0); size_t packa_bytes_per_group = packa_bytes_per_oc_tile * oc_tiles_per_group; //! packb bytes per group size_t packb_bytes_per_group = matmul_bundle.get_size(1); //! matmul bytes per thread size_t matmul_bytes_per_thread = thread_bundle.get_size(0); size_t batch_id = ncb_index.ndrange_id[0]; size_t group_id = ncb_index.ndrange_id[1]; size_t oc_tile_id_in_group = ncb_index.ndrange_id[2]; size_t thread_id = ncb_index.thread_id; size_t GROUP = param.filter_meta.group; size_t OH = param.osz[0]; size_t OW = param.osz[1]; size_t oc_start = oc_tile_size * oc_tile_id_in_group; size_t oc_end = oc_start + oc_tile_size; oc_end = (oc_end <= OC ? oc_end : OC); MatrixMulImpl::KernParam matmul_kern_param; static_cast(matmul_kern_param) = get_matmul_kern_param(param, OH * OW, oc_end - oc_start); size_t bytes_offset_of_a_panel = group_id * packa_bytes_per_group + oc_tile_id_in_group * packa_bytes_per_oc_tile; int8_t* a_panel = reinterpret_cast(whole_bundle.get(0)) + bytes_offset_of_a_panel; size_t bytes_offset_of_b_panel = batch_id * packb_bytes_per_group * GROUP + group_id * packb_bytes_per_group; int8_t* b_panel = reinterpret_cast(whole_bundle.get(1)) + bytes_offset_of_b_panel; size_t thread_offset = thread_bundle.total_size_in_bytes() * thread_id; size_t bytes_offset_of_matmul_dst_this_thread = thread_offset + thread_bundle.get_size(0); int8_t* matmul_temp_dst = reinterpret_cast(whole_bundle.get(2)) + bytes_offset_of_matmul_dst_this_thread; size_t numbers_of_ncb_dst_offset = oc_tile_size * OH * OW * oc_tile_id_in_group; void* conv_bias_dst = static_cast( ncb_param.dst(batch_id, group_id) + numbers_of_ncb_dst_offset); size_t numbers_of_ncb_filter_offset = oc_tile_size * IC * oc_tile_id_in_group; matmul_kern_param.A_ptr = const_cast( ncb_param.filter(group_id) + numbers_of_ncb_filter_offset); matmul_kern_param.B_ptr = const_cast( ncb_param.src(batch_id, group_id)); matmul_kern_param.workspace_ptr = reinterpret_cast(whole_bundle.get(2)) + thread_offset; matmul_kern_param.workspace_size = matmul_bytes_per_thread; bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.dst_type.enumv() == DTypeEnum::QuantizedS8) || (param.src_type.enumv() == DTypeEnum::Quantized8Asymm && param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); void* matmul_dst = is_dst_8bit ? matmul_temp_dst : conv_bias_dst; matmul_kern_param.C_ptr = matmul_dst; matmul_kern_param.LDC *= m_pack_size; if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { auto matmul_kern = matmul_algo->get_kern(matmul_kern_param); matmul_kern(matmul_kern_param); } else { auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_kern_param); matmul_kern_naked(matmul_kern_param, a_panel, b_panel); } //! do postprocess void* bias_ptr = nullptr; if (param.bias_mode == megdnn::BiasMode::BIAS) bias_ptr = static_cast(const_cast( ncb_param.bias(batch_id, group_id) + numbers_of_ncb_dst_offset)); else bias_ptr = static_cast(const_cast( ncb_param.bias(batch_id, group_id) + oc_start)); PostProcess::run( matmul_dst, bias_ptr, conv_bias_dst, param.bias_mode, param.nonlineMode, param.bias_type, param.dst_type, 1_z, (oc_end - oc_start) / m_pack_size, OH, OW, m_pack_size); } private: size_t m_pack_size = 1; }; class Conv1x1Factory { public: static Conv1x1StrategyBase* make_conv1x1_strategy( const ConvBiasImpl::NCBKernSizeParam& param, MatrixMulImpl::AlgoBase::PackMode pack_mode, param::ConvBias::Format format); }; } // namespace conv1x1 } // namespace fallback } // namespace megdnn // vim: syntax=cpp.doxygen