diff --git a/dnn/src/cuda/batched_matrix_mul/algo.cpp b/dnn/src/cuda/batched_matrix_mul/algo.cpp index bbe562bda3473d555108a20111622f4b7aed6cd5..cd59ed846e1dc7a080ef1230b83acc71c68584fe 100644 --- a/dnn/src/cuda/batched_matrix_mul/algo.cpp +++ b/dnn/src/cuda/batched_matrix_mul/algo.cpp @@ -54,12 +54,7 @@ BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() { all_algos.push_back(&cublasLt); #endif all_algos.push_back(&int8x8x32); - for (auto& algo : mm_pack.all_algos) { - brute_force_algos.emplace_back(AlgoBruteForce(algo)); - } - for (auto& algo : brute_force_algos) { - all_algos.push_back(&algo); - } + all_algos.push_back(&brute_force); for (auto&& algo : all_algos) { m_all_algos_map.emplace(algo->info().desc, algo); diff --git a/dnn/src/cuda/batched_matrix_mul/algo.h b/dnn/src/cuda/batched_matrix_mul/algo.h index 8fd34fdf01c5bba2c734c6332148793996e29bc2..40ff620063e45c52ab8dd52e5c0bc4367d26b978 100644 --- a/dnn/src/cuda/batched_matrix_mul/algo.h +++ b/dnn/src/cuda/batched_matrix_mul/algo.h @@ -87,26 +87,20 @@ public: class BatchedMatrixMulForwardImpl::AlgoBruteForce final : public BatchedMatrixMulForwardImpl::AlgoBase { using Param = MatrixMulForward::Param; - private: - std::string m_name; - MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr; WorkspaceBundle get_workspace_bundle(); public: - AlgoBruteForce(MatrixMulForwardImpl::AlgoBase* algo); bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; void exec(const ExecArgs& args) const final; bool is_reproducible() const override { return true; } - const char* name() const override { return m_name.c_str(); } + const char* name() const override { return "BRUTE_FORCE"; } MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE) - std::string param() const override { - std::string ret; - serialize_write_pod(m_algorithm, ret); - return ret; - } + std::vector get_subopr_list( + const TensorLayoutArray& layouts, + const OperatorBase* opr) const override; }; class BatchedMatrixMulForwardImpl::AlgoCublas final : public BatchedMatrixMulForwardImpl::AlgoBase { @@ -157,7 +151,7 @@ public: #endif AlgoInt8x8x32 int8x8x32; std::vector all_algos; - std::vector brute_force_algos; + AlgoBruteForce brute_force; const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; diff --git a/dnn/src/cuda/batched_matrix_mul/brute_force.cpp b/dnn/src/cuda/batched_matrix_mul/brute_force.cpp index f3c093ef5b06059c3aa9774f5e8d690d25e85edd..62da5cc75aed977b00148cf3b5704fe56e762b33 100644 --- a/dnn/src/cuda/batched_matrix_mul/brute_force.cpp +++ b/dnn/src/cuda/batched_matrix_mul/brute_force.cpp @@ -9,48 +9,86 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "./algo.h" +#include "megdnn/opr_param_defs.h" +#include "src/common/algo_chooser.h" #include "src/cuda/handle.h" #include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; -BatchedMatrixMulForwardImpl::AlgoBruteForce::AlgoBruteForce( - MatrixMulForwardImpl::AlgoBase* algo) - : m_algorithm(algo) { - m_name = ssprintf("BRUTE_FORCE-%s", algo->name()); +namespace { +std::pair sub_opr_config( + const TensorLayout& layout_a, const TensorLayout& layout_b, + const TensorLayout& layout_c, const BatchedMatrixMulForward* opr) { + auto mm_layout_a = layout_a.remove_axis(0); + auto mm_layout_b = layout_b.remove_axis(0); + auto mm_layout_c = layout_c.remove_axis(0); + + return {{mm_layout_a, mm_layout_b, mm_layout_c}, opr->param()}; +} +} // namespace + +std::vector +BatchedMatrixMulForwardImpl::AlgoBruteForce::get_subopr_list( + const TensorLayoutArray& layouts, const OperatorBase* opr) const { + const BatchedMatrixMulForwardImpl* bmm_opr = + static_cast(opr); + auto&& config = sub_opr_config(layouts[0], layouts[1], layouts[2], bmm_opr); + + std::string param_str; + Algorithm::serialize_write_pod(config.second, param_str); + return {{Algorithm::OprType::MATRIX_MUL_FORWARD, param_str, config.first}}; } + bool BatchedMatrixMulForwardImpl::AlgoBruteForce::is_available( const SizeArgs& args) const { - MatrixMulForwardImpl mm{args.opr->handle()}; - mm.param() = {args.opr->param().transposeA, args.opr->param().transposeB}; - mm.execution_policy() = {m_algorithm->desc(), {}}; + auto matmul_opr = args.opr->handle()->create_operator(); + if (args.opr->execution_policy().algo.valid() && + !args.opr->execution_policy().sub_policy.empty()) { + megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1); + matmul_opr->execution_policy() = + args.opr->execution_policy().sub_policy[0]; + } - auto mm_layout_a = args.layout_a.remove_axis(0); - auto mm_layout_b = args.layout_b.remove_axis(0); - auto mm_layout_c = args.layout_c.remove_axis(0); + auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c, + args.opr); + matmul_opr->param() = config.second; - MatrixMulForwardImpl::AlgoBase::SizeArgs mm_args{&mm, mm_layout_a, - mm_layout_b, mm_layout_c}; - return m_algorithm->is_available(mm_args); + return get_algorithm(static_cast(matmul_opr.get()), + config.first[0], config.first[1], config.first[2]); } size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes( const SizeArgs& args) const { - auto mm_opr = args.opr->handle()->create_operator(); - mm_opr->param() = {args.opr->param().transposeA, - args.opr->param().transposeB}; - mm_opr->execution_policy() = {m_algorithm->desc(), {}}; + auto matmul_opr = args.opr->handle()->create_operator(); + if (args.opr->execution_policy().algo.valid() && + !args.opr->execution_policy().sub_policy.empty()) { + megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1); + matmul_opr->execution_policy() = + args.opr->execution_policy().sub_policy[0]; + } + + auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c, + args.opr); + matmul_opr->param() = config.second; - return mm_opr->get_workspace_in_bytes(args.layout_a, args.layout_b, - args.layout_c); + return matmul_opr->get_workspace_in_bytes(config.first[0], config.first[1], + config.first[2]); } void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec( const ExecArgs& args) const { auto N = args.layout_a.shape[0]; - auto&& mm_opr = args.opr->handle()->create_operator(); - mm_opr->param() = {args.opr->param().transposeA, - args.opr->param().transposeB}; - mm_opr->execution_policy() = {m_algorithm->desc(), {}}; + auto matmul_opr = args.opr->handle()->create_operator(); + if (args.opr->execution_policy().algo.valid()) { + megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1); + matmul_opr->execution_policy() = + args.opr->execution_policy().sub_policy[0]; + } + + auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c, + args.opr); + matmul_opr->param() = config.second; + rep(n, N) { TensorND A_, B_, C_; auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { @@ -62,6 +100,6 @@ void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec( tensor_n_from_batch(args.tensor_a, A_); tensor_n_from_batch(args.tensor_b, B_); tensor_n_from_batch(args.tensor_c, C_); - mm_opr->exec(A_, B_, C_, args.workspace); + matmul_opr->exec(A_, B_, C_, args.workspace); } } diff --git a/dnn/src/cuda/batched_matrix_mul/opr_impl.cpp b/dnn/src/cuda/batched_matrix_mul/opr_impl.cpp index 2f309c3bb339a6b68114412f628c8351743fc066..d1366a8e6b0624ff833a01f9c81936e59dad9637 100644 --- a/dnn/src/cuda/batched_matrix_mul/opr_impl.cpp +++ b/dnn/src/cuda/batched_matrix_mul/opr_impl.cpp @@ -56,9 +56,8 @@ std::vector BatchedMatrixMulForwardImpl::get_all_algorithms( Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, size_t workspace_limit_in_bytes, bool reproducible) { + MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); AlgoBase::SizeArgs args(this, A, B, C); - std::vector brute_force_algos; - if (sm_algo_pack.cublas.is_available_reproducible(args, reproducible)) { return &sm_algo_pack.cublas; } @@ -72,25 +71,14 @@ Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( reproducible)) { return &sm_algo_pack.int8x8x32; } else { - for (auto& algo : sm_algo_pack.brute_force_algos) { - if (algo.is_available_reproducible(args, reproducible)) { - return &algo; - } + if (sm_algo_pack.brute_force.is_available_reproducible(args, + reproducible)) { + return &sm_algo_pack.brute_force; } } - for (auto& algo : sm_algo_pack.brute_force_algos) - brute_force_algos.push_back(&algo); - - if (reproducible) { - return megdnn::get_reproducible_algo( - brute_force_algos, args, workspace_limit_in_bytes, - "batched matrix mul"); - } else { - return megdnn::get_usable_algo( - brute_force_algos, args, workspace_limit_in_bytes, - "batched matrix mul"); - } + megdnn_throw("No usable algo for batched_matrix_mul"); + return nullptr; }; // vim: syntax=cpp.doxygen diff --git a/dnn/test/common/matrix_mul.cpp b/dnn/test/common/matrix_mul.cpp index 7cb553746c22abb3d552b46ef70c979036a58f57..f1f96dc2e6862958c315669470789ead10af00ab 100644 --- a/dnn/test/common/matrix_mul.cpp +++ b/dnn/test/common/matrix_mul.cpp @@ -138,12 +138,13 @@ std::vector matrix_mul::get_batched_matmul_args() { template void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, - Handle* handle, const char* algo, + Handle* handle, + const ExecutionPolicyAlgoName& algo, param::MatrixMul::Format format, size_t nbase, float eps, std::vector&& user_args) { megdnn_assert(A_dtype.enumv() == B_dtype.enumv()); Checker checker(handle); - if (algo) { + if (!algo.name.empty()) { checker.set_before_exec_callback(AlgoChecker(algo)); } std::unique_ptr rng; @@ -267,7 +268,8 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, - const char* algo, float eps, + const ExecutionPolicyAlgoName& algo, + float eps, std::vector&& args) { check_matrix_mul( A_dtype, B_dtype, C_dtype, handle, algo, @@ -276,7 +278,8 @@ void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype, } void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, - Handle* handle, const char* algo, + Handle* handle, + const ExecutionPolicyAlgoName& algo, param::MatrixMul::Format format, size_t nbase, float eps) { check_matrix_mul(A_dtype, B_dtype, C_dtype, handle, algo, diff --git a/dnn/test/common/matrix_mul.h b/dnn/test/common/matrix_mul.h index 7853c09e06375aad3da612195f8302ec08b6499d..5da5bcec0330aa60cc2e373c3f176361f2f2fb2b 100644 --- a/dnn/test/common/matrix_mul.h +++ b/dnn/test/common/matrix_mul.h @@ -16,6 +16,7 @@ #include "megdnn/handle.h" #include "megdnn/opr_param_defs.h" #include "megdnn/oprs.h" +#include "test/common/checker.h" namespace megdnn { namespace test { @@ -58,18 +59,19 @@ using TestArgFilterFunc = std::function; template void check_matrix_mul( DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, - const char* algo = nullptr, + const ExecutionPolicyAlgoName& algo = {"", {}}, param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, size_t nbase = 8, float eps = 1e-3, std::vector&& args = {}); void check_matrix_mul( DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle, - const char* algo = nullptr, + const ExecutionPolicyAlgoName& algo = {"", {}}, param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT, size_t nbase = 8, float eps = 1e-3); void check_batched_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype, - Handle* handle, const char* algo = nullptr, + Handle* handle, + const ExecutionPolicyAlgoName& algo = {"", {}}, float eps = 1e-3, std::vector&& args = {}); diff --git a/dnn/test/cpu/batched_matrix_mul.cpp b/dnn/test/cpu/batched_matrix_mul.cpp index 8c2137e6a0c77792e6fed448b890ae1fa1e7ec76..387e1de3ab5b3050a426e74ef5af86413587b787 100644 --- a/dnn/test/cpu/batched_matrix_mul.cpp +++ b/dnn/test/cpu/batched_matrix_mul.cpp @@ -20,8 +20,8 @@ using namespace test; //! check batch=1 and batch_stride is arbitrarily TEST_F(CPU, BATCHED_MATRIX_MUL_BATCH_1) { matrix_mul::check_batched_matrix_mul( - dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), - nullptr, 1e-3, + dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "", + 1e-3, std::vector{ {5, 5, 5, 0, 5, 5, 5, 1, 5, 5, 5}}); } diff --git a/dnn/test/cuda/batched_matrix_mul.cpp b/dnn/test/cuda/batched_matrix_mul.cpp index 9b84595069e020eaa1e8da3e1abfeb80e943aca6..d237416de52b44906bef7f12e344a7042d0a584f 100644 --- a/dnn/test/cuda/batched_matrix_mul.cpp +++ b/dnn/test/cuda/batched_matrix_mul.cpp @@ -62,6 +62,34 @@ TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART4) { #undef F32_TEST_PART +TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART1) { + matrix_mul::check_batched_matrix_mul( + dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), + ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3, + matrix_mul::get_batched_matmul_args_mask(0)); +} + +TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART2) { + matrix_mul::check_batched_matrix_mul( + dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), + ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3, + matrix_mul::get_batched_matmul_args_mask(1)); +} + +TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART3) { + matrix_mul::check_batched_matrix_mul( + dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), + ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3, + matrix_mul::get_batched_matmul_args_mask(2)); +} + +TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART4) { + matrix_mul::check_batched_matrix_mul( + dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(), + ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3, + matrix_mul::get_batched_matmul_args_mask(3)); +} + TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART1) { require_compute_capability(6, 0); matrix_mul::check_batched_matrix_mul( @@ -150,7 +178,8 @@ TEST_F(CUDA, BATCHED_MATRIX_MUL_INT8x8x32) { TEST_F(CUDA, BATCHED_MATMUL_8x8x32_BENCHMARK) { require_compute_capability(6, 1); auto run = [&](bool transA, bool transB, size_t m, size_t n, size_t k, - const char* algo1, const char* algo2, size_t b = 128) { + const ExecutionPolicyAlgoName& algo1, + const ExecutionPolicyAlgoName& algo2, size_t b = 128) { size_t RUNS = 10; CUBenchmarker bencher1(handle_cuda()); bencher1.set_display(false).set_times(RUNS); @@ -196,19 +225,20 @@ TEST_F(CUDA, BATCHED_MATMUL_8x8x32_BENCHMARK) { printf("trA: %d, trB: %d, m: %ld, n: %ld, k: %ld, b: %ld, speedup: %s " "/ " "%s %.3f\n", - transA, transB, m, n, k, b, algo1, algo2, flops1 / flops2); + transA, transB, m, n, k, b, algo1.name.c_str(), + algo2.name.c_str(), flops1 / flops2); }; for (bool transA : {0, 1}) for (bool transB : {0, 1}) { run(transA, transB, 128, 576, 128, "INT8x8x32", - "BRUTE_FORCE-CUBLAS"); + ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}); run(transA, transB, 256, 144, 256, "INT8x8x32", - "BRUTE_FORCE-CUBLAS"); + ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}); run(transA, transB, 512, 36, 512, "INT8x8x32", - "BRUTE_FORCE-CUBLAS"); + ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}); run(transA, transB, 1024, 8, 1024, "INT8x8x32", - "BRUTE_FORCE-CUBLAS"); + ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}); } } #endif