Commit 821656aa authored by Megvii Engine Team

refactor(megdnn): refactor brute force algo in batched matmul

GitOrigin-RevId: 5c143ab3acadf70f94fd3f6395c5916192da6103
Parent 08ff62de
......@@ -54,12 +54,7 @@ BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&cublasLt);
#endif
all_algos.push_back(&int8x8x32);
for (auto& algo : mm_pack.all_algos) {
brute_force_algos.emplace_back(AlgoBruteForce(algo));
}
for (auto& algo : brute_force_algos) {
all_algos.push_back(&algo);
}
all_algos.push_back(&brute_force);
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
......
......@@ -87,26 +87,20 @@ public:
class BatchedMatrixMulForwardImpl::AlgoBruteForce final
: public BatchedMatrixMulForwardImpl::AlgoBase {
using Param = MatrixMulForward::Param;
private:
std::string m_name;
MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr;
WorkspaceBundle get_workspace_bundle();
public:
AlgoBruteForce(MatrixMulForwardImpl::AlgoBase* algo);
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override;
void exec(const ExecArgs& args) const final;
bool is_reproducible() const override { return true; }
const char* name() const override { return m_name.c_str(); }
const char* name() const override { return "BRUTE_FORCE"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE)
std::string param() const override {
std::string ret;
serialize_write_pod(m_algorithm, ret);
return ret;
}
std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray& layouts,
const OperatorBase* opr) const override;
};
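// Editorial note on the hunk above (not part of the commit): AlgoBruteForce no
// longer carries per-instance state (the old m_name/m_algorithm fields), so a
// single "BRUTE_FORCE" instance replaces the old one-wrapper-per-inner-algorithm
// scheme; the inner matmul is now chosen through get_subopr_list() and
// ExecutionPolicy::sub_policy.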
class BatchedMatrixMulForwardImpl::AlgoCublas final
: public BatchedMatrixMulForwardImpl::AlgoBase {
......@@ -157,7 +151,7 @@ public:
#endif
AlgoInt8x8x32 int8x8x32;
std::vector<AlgoBase*> all_algos;
std::vector<AlgoBruteForce> brute_force_algos;
AlgoBruteForce brute_force;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
......
......@@ -9,48 +9,86 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./algo.h"
#include "megdnn/opr_param_defs.h"
#include "src/common/algo_chooser.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
BatchedMatrixMulForwardImpl::AlgoBruteForce::AlgoBruteForce(
MatrixMulForwardImpl::AlgoBase* algo)
: m_algorithm(algo) {
m_name = ssprintf("BRUTE_FORCE-%s", algo->name());
namespace {
std::pair<TensorLayoutArray, MatrixMulForward::Param> sub_opr_config(
const TensorLayout& layout_a, const TensorLayout& layout_b,
const TensorLayout& layout_c, const BatchedMatrixMulForward* opr) {
auto mm_layout_a = layout_a.remove_axis(0);
auto mm_layout_b = layout_b.remove_axis(0);
auto mm_layout_c = layout_c.remove_axis(0);
return {{mm_layout_a, mm_layout_b, mm_layout_c}, opr->param()};
}
} // namespace
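// Editorial sketch (not part of the commit): sub_opr_config() strips the
// leading batch axis so the per-sample problem can be handed to a plain
// MatrixMulForward. Assuming contiguous Float32 layouts, e.g.
//   A: {32, 64, 128}, B: {32, 128, 256}, C: {32, 64, 256}
// it returns the single-matrix layouts
//   A: {64, 128},     B: {128, 256},    C: {64, 256}
// together with the unchanged BatchedMatrixMul param (transposeA/transposeB).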
std::vector<Algorithm::SearchItem>
BatchedMatrixMulForwardImpl::AlgoBruteForce::get_subopr_list(
const TensorLayoutArray& layouts, const OperatorBase* opr) const {
const BatchedMatrixMulForwardImpl* bmm_opr =
static_cast<const BatchedMatrixMulForwardImpl*>(opr);
auto&& config = sub_opr_config(layouts[0], layouts[1], layouts[2], bmm_opr);
std::string param_str;
Algorithm::serialize_write_pod(config.second, param_str);
return {{Algorithm::OprType::MATRIX_MUL_FORWARD, param_str, config.first}};
}
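// Editorial note (hedged): the SearchItem returned above advertises the
// MATRIX_MUL_FORWARD sub-operator -- its serialized param plus the
// axis-stripped layouts -- so that an external searcher (e.g. a fast-run
// style profiler) can profile the inner matmul and record the winner into
// ExecutionPolicy::sub_policy, which the methods below then honor.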
bool BatchedMatrixMulForwardImpl::AlgoBruteForce::is_available(
const SizeArgs& args) const {
MatrixMulForwardImpl mm{args.opr->handle()};
mm.param() = {args.opr->param().transposeA, args.opr->param().transposeB};
mm.execution_policy() = {m_algorithm->desc(), {}};
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>();
if (args.opr->execution_policy().algo.valid() &&
!args.opr->execution_policy().sub_policy.empty()) {
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
matmul_opr->execution_policy() =
args.opr->execution_policy().sub_policy[0];
}
auto mm_layout_a = args.layout_a.remove_axis(0);
auto mm_layout_b = args.layout_b.remove_axis(0);
auto mm_layout_c = args.layout_c.remove_axis(0);
auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c,
args.opr);
matmul_opr->param() = config.second;
MatrixMulForwardImpl::AlgoBase::SizeArgs mm_args{&mm, mm_layout_a,
mm_layout_b, mm_layout_c};
return m_algorithm->is_available(mm_args);
return get_algorithm(static_cast<MatrixMulForwardImpl*>(matmul_opr.get()),
config.first[0], config.first[1], config.first[2]);
}
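// Hedged usage sketch (the desc identifiers are assumptions, not from this
// diff): a caller that previously selected one of the per-algorithm wrappers,
// such as "BRUTE_FORCE-CUBLAS", now pins the inner algorithm via a nested
// policy:
//   megdnn::ExecutionPolicy policy;
//   policy.algo = brute_force_desc;                 // desc of BRUTE_FORCE
//   policy.sub_policy = {{cublas_matmul_desc, {}}}; // desc of inner CUBLAS
//   bmm_opr->execution_policy() = policy;
// With no sub_policy given, is_available()/exec() fall back to the heuristic
// choice made by get_algorithm() for the inner matmul.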
size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes(
const SizeArgs& args) const {
auto mm_opr = args.opr->handle()->create_operator<MatrixMulForward>();
mm_opr->param() = {args.opr->param().transposeA,
args.opr->param().transposeB};
mm_opr->execution_policy() = {m_algorithm->desc(), {}};
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>();
if (args.opr->execution_policy().algo.valid() &&
!args.opr->execution_policy().sub_policy.empty()) {
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
matmul_opr->execution_policy() =
args.opr->execution_policy().sub_policy[0];
}
auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c,
args.opr);
matmul_opr->param() = config.second;
return mm_opr->get_workspace_in_bytes(args.layout_a, args.layout_b,
args.layout_c);
return matmul_opr->get_workspace_in_bytes(config.first[0], config.first[1],
config.first[2]);
}
void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec(
const ExecArgs& args) const {
auto N = args.layout_a.shape[0];
auto&& mm_opr = args.opr->handle()->create_operator<MatrixMulForward>();
mm_opr->param() = {args.opr->param().transposeA,
args.opr->param().transposeB};
mm_opr->execution_policy() = {m_algorithm->desc(), {}};
auto matmul_opr = args.opr->handle()->create_operator<MatrixMulForward>();
if (args.opr->execution_policy().algo.valid() &&
!args.opr->execution_policy().sub_policy.empty()) {
megdnn_assert(args.opr->execution_policy().sub_policy.size() == 1);
matmul_opr->execution_policy() =
args.opr->execution_policy().sub_policy[0];
}
auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c,
args.opr);
matmul_opr->param() = config.second;
rep(n, N) {
TensorND A_, B_, C_;
auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) {
......@@ -62,6 +100,6 @@ void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec(
tensor_n_from_batch(args.tensor_a, A_);
tensor_n_from_batch(args.tensor_b, B_);
tensor_n_from_batch(args.tensor_c, C_);
mm_opr->exec(A_, B_, C_, args.workspace);
matmul_opr->exec(A_, B_, C_, args.workspace);
}
}
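// Editorial illustration: the body of tensor_n_from_batch is elided in this
// hunk; the sketch below only shows the general idea, assuming the batch
// stride of each tensor is layout.stride[0] (in elements):
//   auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) {
//       out.layout = in.layout.remove_axis(0);  // drop the batch axis
//       out.raw_ptr = static_cast<dt_byte*>(in.raw_ptr) +
//                     n * in.layout.stride[0] * in.layout.dtype.size();
//   };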
......@@ -56,9 +56,8 @@ std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms(
Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, bool reproducible) {
MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);
AlgoBase::SizeArgs args(this, A, B, C);
std::vector<AlgoBase*> brute_force_algos;
if (sm_algo_pack.cublas.is_available_reproducible(args, reproducible)) {
return &sm_algo_pack.cublas;
}
......@@ -72,25 +71,14 @@ Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic(
reproducible)) {
return &sm_algo_pack.int8x8x32;
} else {
for (auto& algo : sm_algo_pack.brute_force_algos) {
if (algo.is_available_reproducible(args, reproducible)) {
return &algo;
}
if (sm_algo_pack.brute_force.is_available_reproducible(args,
reproducible)) {
return &sm_algo_pack.brute_force;
}
}
for (auto& algo : sm_algo_pack.brute_force_algos)
brute_force_algos.push_back(&algo);
if (reproducible) {
return megdnn::get_reproducible_algo<BatchedMatrixMulForwardImpl>(
brute_force_algos, args, workspace_limit_in_bytes,
"batched matrix mul");
} else {
return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>(
brute_force_algos, args, workspace_limit_in_bytes,
"batched matrix mul");
}
megdnn_throw("No usable algo for batched_matrix_mul");
return nullptr;
};
// vim: syntax=cpp.doxygen
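// Editorial note on the hunk above: with a single brute_force entry, the old
// scan over brute_force_algos (via get_reproducible_algo/get_usable_algo) is
// gone, which is why workspace_limit_in_bytes is now only silenced with
// MEGDNN_MARK_USED_VAR rather than enforced; when no candidate is usable, the
// function reports failure through megdnn_throw.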
......@@ -138,12 +138,13 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args() {
template <typename Opr>
void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
Handle* handle, const char* algo,
Handle* handle,
const ExecutionPolicyAlgoName& algo,
param::MatrixMul::Format format, size_t nbase,
float eps, std::vector<TestArg>&& user_args) {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
Checker<Opr> checker(handle);
if (algo) {
if (!algo.name.empty()) {
checker.set_before_exec_callback(AlgoChecker<Opr>(algo));
}
std::unique_ptr<RNG> rng;
......@@ -267,7 +268,8 @@ void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype,
DType C_dtype, Handle* handle,
const char* algo, float eps,
const ExecutionPolicyAlgoName& algo,
float eps,
std::vector<TestArg>&& args) {
check_matrix_mul<megdnn::BatchedMatrixMul>(
A_dtype, B_dtype, C_dtype, handle, algo,
......@@ -276,7 +278,8 @@ void matrix_mul::check_batched_matrix_mul(DType A_dtype, DType B_dtype,
}
void matrix_mul::check_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
Handle* handle, const char* algo,
Handle* handle,
const ExecutionPolicyAlgoName& algo,
param::MatrixMul::Format format, size_t nbase,
float eps) {
check_matrix_mul<megdnn::MatrixMul>(A_dtype, B_dtype, C_dtype, handle, algo,
......
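// Hedged usage note for the new ExecutionPolicyAlgoName overloads: a plain
// algorithm is still named by a string, while the brute-force wrapper names
// its inner matmul through a nested sub-policy, as the CUDA tests below do:
//   matrix_mul::check_batched_matrix_mul(
//           dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
//           ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});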
......@@ -16,6 +16,7 @@
#include "megdnn/handle.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
namespace megdnn {
namespace test {
......@@ -58,18 +59,19 @@ using TestArgFilterFunc = std::function<bool(const TestArg&)>;
template <typename Opr = megdnn::MatrixMul>
void check_matrix_mul(
DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
const char* algo = nullptr,
const ExecutionPolicyAlgoName& algo = {"", {}},
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
size_t nbase = 8, float eps = 1e-3, std::vector<TestArg>&& args = {});
void check_matrix_mul(
DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
const char* algo = nullptr,
const ExecutionPolicyAlgoName& algo = {"", {}},
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT,
size_t nbase = 8, float eps = 1e-3);
void check_batched_matrix_mul(DType A_dtype, DType B_dtype, DType C_dtype,
Handle* handle, const char* algo = nullptr,
Handle* handle,
const ExecutionPolicyAlgoName& algo = {"", {}},
float eps = 1e-3,
std::vector<TestArg>&& args = {});
......
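// Editorial note: the empty-name default {"", {}} plays the role of the old
// nullptr default -- check_matrix_mul() only installs an AlgoChecker when
// algo.name is non-empty (see the !algo.name.empty() test in the .cpp hunk).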
......@@ -20,8 +20,8 @@ using namespace test;
//! check batch=1 with arbitrary batch_stride
TEST_F(CPU, BATCHED_MATRIX_MUL_BATCH_1) {
matrix_mul::check_batched_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
nullptr, 1e-3,
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "",
1e-3,
std::vector<matrix_mul::TestArg>{
{5, 5, 5, 0, 5, 5, 5, 1, 5, 5, 5}});
}
......
......@@ -62,6 +62,34 @@ TEST_F(CUDA, BATCHED_MATRIX_MUL_LT_F32_PART4) {
#undef F32_TEST_PART
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART1) {
matrix_mul::check_batched_matrix_mul(
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
matrix_mul::get_batched_matmul_args_mask(0));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART2) {
matrix_mul::check_batched_matrix_mul(
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
matrix_mul::get_batched_matmul_args_mask(1));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART3) {
matrix_mul::check_batched_matrix_mul(
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
matrix_mul::get_batched_matmul_args_mask(2));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F32_BRUTE_FORCE_PART4) {
matrix_mul::check_batched_matrix_mul(
dtype::Float32{}, dtype::Float32{}, {}, handle_cuda(),
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}, 1e-3,
matrix_mul::get_batched_matmul_args_mask(3));
}
TEST_F(CUDA, BATCHED_MATRIX_MUL_F16_PART1) {
require_compute_capability(6, 0);
matrix_mul::check_batched_matrix_mul(
......@@ -150,7 +178,8 @@ TEST_F(CUDA, BATCHED_MATRIX_MUL_INT8x8x32) {
TEST_F(CUDA, BATCHED_MATMUL_8x8x32_BENCHMARK) {
require_compute_capability(6, 1);
auto run = [&](bool transA, bool transB, size_t m, size_t n, size_t k,
const char* algo1, const char* algo2, size_t b = 128) {
const ExecutionPolicyAlgoName& algo1,
const ExecutionPolicyAlgoName& algo2, size_t b = 128) {
size_t RUNS = 10;
CUBenchmarker<BatchedMatrixMul> bencher1(handle_cuda());
bencher1.set_display(false).set_times(RUNS);
......@@ -196,19 +225,20 @@ TEST_F(CUDA, BATCHED_MATMUL_8x8x32_BENCHMARK) {
printf("trA: %d, trB: %d, m: %ld, n: %ld, k: %ld, b: %ld, speedup: %s "
"/ "
"%s %.3f\n",
transA, transB, m, n, k, b, algo1, algo2, flops1 / flops2);
transA, transB, m, n, k, b, algo1.name.c_str(),
algo2.name.c_str(), flops1 / flops2);
};
for (bool transA : {0, 1})
for (bool transB : {0, 1}) {
run(transA, transB, 128, 576, 128, "INT8x8x32",
"BRUTE_FORCE-CUBLAS");
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
run(transA, transB, 256, 144, 256, "INT8x8x32",
"BRUTE_FORCE-CUBLAS");
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
run(transA, transB, 512, 36, 512, "INT8x8x32",
"BRUTE_FORCE-CUBLAS");
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
run(transA, transB, 1024, 8, 1024, "INT8x8x32",
"BRUTE_FORCE-CUBLAS");
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}});
}
}
#endif
......