Commit ff755451 authored by Megvii Engine Team

refactor(mgb): move the algorithm name from Info into Desc and delete some algorithms' unnecessary param() methods

GitOrigin-RevId: 144ff547d1aaed633838d35524a0d548823a4f27
Parent a437ec8e
......@@ -136,16 +136,16 @@ public:
//! numeric tag identifying which concrete algorithm this descriptor names
uint32_t type = INVALID_ALGO_TYPE;
//! serialized param of the algo type
std::string param;
//! algorithm name
std::string name;
//! a Desc is valid once a concrete algo type has been assigned
bool valid() const { return type != INVALID_ALGO_TYPE; }
//! drop back to the "no algorithm chosen" state
void reset() { type = INVALID_ALGO_TYPE; }
//! Two descriptors are equal only when every identifying field matches.
//! `name` participates in the comparison because (type, param) alone is
//! not guaranteed to be unique across algorithm objects.
bool operator==(const Desc& rhs) const {
    return handle_type == rhs.handle_type && type == rhs.type &&
           param == rhs.param && name == rhs.name;
}
} desc;
//! NOTE: the algorithm name lives in `desc` (Desc::name); Info deliberately
//! does not keep a second copy, so the two cannot drift apart.
Attribute attribute;
//! an Info is valid/reset entirely through its descriptor
bool valid() const { return desc.valid(); }
void reset() { desc.reset(); }
......@@ -178,12 +178,12 @@ public:
//! human-readable string form of an Attribute (used in logs/messages)
static std::string attribute_str(const Attribute& attr);
//! which backend handle (e.g. CUDA, fallback) this algorithm belongs to
Handle::HandleType handle_type() const { return m_handle_type; }
//! Build the serializable descriptor; the name is part of the identity.
Info::Desc desc() const { return {handle_type(), type(), param(), name()}; }
//! Full info = descriptor (identity) + attributes (properties).
Info info() const {
    return {desc(), attribute()};
}
template <typename T>
static void serialize_write_pod(const T& val, std::string& result) {
static_assert(std::is_trivially_copyable<T>::value,
......
......@@ -116,8 +116,10 @@ struct hash<megdnn::detail::Algorithm::Info::Desc> {
const megdnn::detail::Algorithm::Info::Desc& desc) const {
// Fold name, param, type and handle_type into one hash value. `name` must
// participate here because it participates in Desc::operator== — hash and
// equality have to agree for unordered containers.
return megdnn::hash_combine<size_t>(
        megdnn::hash_combine<size_t>(
                std::hash<std::string>()(desc.name),
                megdnn::hash_combine<size_t>(
                        std::hash<std::string>()(desc.param),
                        std::hash<uint32_t>()(desc.type))),
        std::hash<uint32_t>()(static_cast<uint32_t>(desc.handle_type)));
}
};
......
......@@ -439,12 +439,6 @@ public:
TensorLayout& dst_pg, TensorLayout& bias_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
//! Serialize the wrapped implementation's name as this algo's param blob.
std::string param() const override {
    std::string buf;
    serialize_write_pod(m_impl->name(), buf);
    return buf;
}
private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
AlgoBase* m_impl;
......
......@@ -237,12 +237,6 @@ public:
}
return ret;
}
//! Param blob is just the delegated implementation's name.
std::string param() const override {
    std::string serialized;
    serialize_write_pod(m_impl->name(), serialized);
    return serialized;
}
};
class ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm final
......
......@@ -222,12 +222,6 @@ public:
}
return ret;
}
//! Param blob is just the delegated implementation's name.
std::string param() const override {
    std::string serialized;
    serialize_write_pod(m_impl->name(), serialized);
    return serialized;
}
};
class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj {
......
......@@ -174,14 +174,8 @@ public:
}
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
//! Serialize the wrapped implementation's name as this algo's param blob.
std::string param() const override {
    std::string buf;
    serialize_write_pod(m_impl->name(), buf);
    return buf;
}
};
class Convolution3DBackwardDataImpl::AlgoPack : NonCopyableObj {
// defined in cudnn.cpp
void fill_cudnn_algos();
......
......@@ -183,11 +183,6 @@ public:
TensorLayout& diff_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
//! Serialize the wrapped implementation's name as this algo's param blob.
std::string param() const override {
    std::string buf;
    serialize_write_pod(m_impl->name(), buf);
    return buf;
}
};
class Convolution3DBackwardFilterImpl::AlgoPack : NonCopyableObj {
......
......@@ -135,11 +135,6 @@ public:
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
TensorLayout& dst_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
//! Param blob is just the delegated implementation's name.
std::string param() const override {
    std::string serialized;
    serialize_write_pod(m_impl->name(), serialized);
    return serialized;
}
};
class Convolution3DForwardImpl::AlgoCUDNN final : public AlgoBase {
......
......@@ -65,11 +65,6 @@ public:
return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD};
}
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_F32)
//! Param blob records which matmul algorithm this winograd algo wraps.
std::string param() const override {
    std::string buf;
    serialize_write_pod(m_matmul_algo->name(), buf);
    return buf;
}
private:
MatrixMulImpl::AlgoBase* m_matmul_algo;
......@@ -101,11 +96,6 @@ public:
return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD};
}
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_4X4_F32)
//! Param blob records which matmul algorithm this winograd algo wraps.
std::string param() const override {
    std::string buf;
    serialize_write_pod(m_matmul_algo->name(), buf);
    return buf;
}
private:
MatrixMulImpl::AlgoBase* m_matmul_algo;
......@@ -137,11 +127,6 @@ public:
return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD};
}
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_QS8)
//! Param blob records the underlying matmul algorithm's name.
std::string param() const override {
    std::string serialized;
    serialize_write_pod(m_matmul_algo->name(), serialized);
    return serialized;
}
private:
MatrixMulImpl::AlgoBase* m_matmul_algo;
......@@ -173,11 +158,6 @@ public:
return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD};
}
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_8X8_QS8)
//! Param blob records the underlying matmul algorithm's name.
std::string param() const override {
    std::string serialized;
    serialize_write_pod(m_matmul_algo->name(), serialized);
    return serialized;
}
private:
MatrixMulImpl::AlgoBase* m_matmul_algo;
......
......@@ -157,7 +157,6 @@ using BiasMode = ConvBiasForward::BiasMode;
} \
std::string param() const override {                              \
    /* only the tile size is serialized; the wrapped matmul        \
     * algorithm's name is carried by Desc::name instead */        \
    std::string ret;                                               \
    serialize_write_pod(m_tile_size, ret);                         \
    return ret;                                                    \
}                                                                  \
......
......@@ -62,10 +62,9 @@ public:
return {m_matmul_algo->matmul_description().algo_type.data_type,
AlgoCategory::IM2COL};
}
MEGDNN_DECL_ALGO_TYPE(FB_CONV1x1)
//! Param serializes only the OC block size; the matmul algorithm's name
//! is part of Desc::name and is no longer duplicated here.
std::string param() const override {
    std::string ret;
    serialize_write_pod(m_oc_block_size, ret);
    return ret;
}
......
......@@ -74,7 +74,6 @@ public:
//! Param serializes only the OHW tile size; the matmul algorithm's name
//! is part of Desc::name and is no longer duplicated here.
std::string param() const override {
    std::string ret;
    serialize_write_pod(m_ohw_tile_size, ret);
    return ret;
}
......
......@@ -155,12 +155,6 @@ public:
//! select matmul to the highest preference
bool is_preferred(const NCBKernSizeParam& param) const override;
//! Param blob is the preferred sub-algorithm's name.
std::string param() const override {
    std::string buf;
    serialize_write_pod(m_algorithm->name(), buf);
    return buf;
}
static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param(
const NCBKernSizeParam& param);
......
......@@ -380,13 +380,13 @@ float algo_benchmark(Benchmarker<Opr, T>& benchmark, TensorLayoutArray layouts,
float min_used = std::numeric_limits<float>::max();
bool execed = false;
for (auto i : algos) {
if (std::regex_match(i.name,
if (std::regex_match(i.desc.name,
std::regex("(" + algo_base + ")(.*)"))) {
opr->execution_policy().algo = i.desc;
auto used = benchmark.exec(layouts);
min_used = std::min(min_used, used);
printf("run algo: %s used: %f ms min_used: %f ms\n", i.name.c_str(),
used, min_used);
printf("run algo: %s used: %f ms min_used: %f ms\n",
i.desc.name.c_str(), used, min_used);
execed = true;
}
}
......
......@@ -482,7 +482,7 @@ public:
AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info(
opr.get(), layouts)) {
if (std::regex_match(
algo_info.name,
algo_info.desc.name,
std::regex("(" + policy_name.name + ")(.*)"))) {
ret.algo = algo_info.desc;
} else {
......@@ -495,7 +495,7 @@ public:
if (sub_items.size() != policy_name.sub_policy_names.size()) {
printf("Invalid sub_policy_names in %s, expected %zu but got "
"%zu\n",
algo_info.name.c_str(), sub_items.size(),
algo_info.desc.name.c_str(), sub_items.size(),
policy_name.sub_policy_names.size());
return {};
}
......@@ -528,7 +528,7 @@ public:
auto algo =
OprAlgoProxy::get_algorithm_info_heuristic(opr, layouts);
ASSERT_STREQ(opr->get_algorithm_from_desc(m_policy.algo)->name(),
algo.name.c_str());
algo.desc.name.c_str());
} else {
opr->execution_policy() = m_policy;
}
......
......@@ -629,11 +629,10 @@ Checker<Convolution> checker(handle);
out_type = inp_type;
}
checker
.set_dtype(0, inp_type)
.set_dtype(1, inp_type)
.set_dtype(2, out_type)
.set_param(param);
checker.set_dtype(0, inp_type)
.set_dtype(1, inp_type)
.set_dtype(2, out_type)
.set_param(param);
auto opr = checker.opr();
opr->param() = param;
std::string param_str;
......@@ -642,7 +641,8 @@ Checker<Convolution> checker(handle);
oly.dtype = out_type;
opr->deduce_layout(ily, fly, oly);
int channel_start = 1;
if (format) channel_start = 3;
if (format)
channel_start = 3;
float scale = 1.0f / sqrt(fshp[channel_start] * FH * FW);
UniformFloatRNG rng(scale, 2 * scale);
checker.set_rng(0, &rng).set_rng(1, &rng);
......@@ -653,11 +653,11 @@ Checker<Convolution> checker(handle);
construct_sub_execution_policy_heuristic<ConvolutionForward>(
opr->execution_policy(), {ily, fly, oly}, param_str,
opr->handle());
checker
.set_epsilon(eps_getter(dtype == 1, 0, algo.name.c_str()))
.execs({ishp, fshp, {}});
checker.set_epsilon(
eps_getter(dtype == 1, 0, algo.desc.name.c_str()))
.execs({ishp, fshp, {}});
opr->execution_policy() = {};
ASSERT_TRUE(checker.prev_succ()) << errmsg(algo.name.c_str());
ASSERT_TRUE(checker.prev_succ()) << errmsg(algo.desc.name.c_str());
}
if (test_backward) {
......@@ -671,7 +671,7 @@ Checker<Convolution> checker(handle);
opr->param() = param;
std::string param_str;
Algorithm::serialize_write_pod(opr->param(), param_str);
for (auto algo: opr->get_all_algorithms_info(fly, oly, ily)) {
for (auto algo : opr->get_all_algorithms_info(fly, oly, ily)) {
used_algos_bwd_data.insert(algo.desc);
opr->execution_policy().algo = algo.desc;
construct_sub_execution_policy_heuristic<
......@@ -679,26 +679,26 @@ Checker<Convolution> checker(handle);
{fly, oly, ily}, param_str,
opr->handle());
checker_bwd_data
.set_epsilon(eps_getter(dtype == 1, 1, algo.name.c_str()))
.execl({fly, oly, ily});
.set_epsilon(eps_getter(dtype == 1, 1,
algo.desc.name.c_str()))
.execl({fly, oly, ily});
opr->execution_policy() = {};
ASSERT_TRUE(checker_bwd_data.prev_succ()) <<
errmsg(algo.name.c_str());
ASSERT_TRUE(checker_bwd_data.prev_succ())
<< errmsg(algo.desc.name.c_str());
}
}
if (test_backward) {
// backward filter
checker_bwd_filter
.set_dtype(0, inp_type)
.set_dtype(1, out_type)
.set_dtype(2, inp_type)
.set_param(param);
checker_bwd_filter.set_dtype(0, inp_type)
.set_dtype(1, out_type)
.set_dtype(2, inp_type)
.set_param(param);
auto opr = checker_bwd_filter.opr();
opr->param() = param;
std::string param_str;
Algorithm::serialize_write_pod(opr->param(), param_str);
for (auto algo: opr->get_all_algorithms_info(ily, oly, fly)) {
for (auto algo : opr->get_all_algorithms_info(ily, oly, fly)) {
used_algos_bwd_flt.insert(algo.desc);
opr->execution_policy().algo = algo.desc;
construct_sub_execution_policy_heuristic<
......@@ -706,11 +706,12 @@ Checker<Convolution> checker(handle);
{ily, oly, fly}, param_str,
opr->handle());
checker_bwd_filter
.set_epsilon(eps_getter(dtype == 1, 2, algo.name.c_str()))
.execl({ily, oly, fly});
.set_epsilon(eps_getter(dtype == 1, 2,
algo.desc.name.c_str()))
.execl({ily, oly, fly});
opr->execution_policy() = {};
ASSERT_TRUE(checker_bwd_filter.prev_succ()) <<
errmsg(algo.name.c_str());
ASSERT_TRUE(checker_bwd_filter.prev_succ())
<< errmsg(algo.desc.name.c_str());
}
}
}
......
......@@ -400,7 +400,7 @@ struct OprProxyProfilingBase
megcoreSynchronize(opr->handle()->megcore_computing_handle());
timer.stop();
megdnn_log("%.3fms %s", timer.get_time_in_us() / 1e3,
algo.name.c_str());
algo.desc.name.c_str());
if (min_time > timer.get_time_in_us()) {
min_time = timer.get_time_in_us();
best_algo = algo.desc;
......@@ -522,7 +522,7 @@ struct OprWeightPreprocessProxyImpl : public OprProxyProfilingBase<Opr> {
megcoreSynchronize(opr->handle()->megcore_computing_handle());
timer.stop();
printf("%.3fms %s\n", timer.get_time_in_us() / 1e3,
algo.name.c_str());
algo.desc.name.c_str());
if (min_time > timer.get_time_in_us()) {
min_time = timer.get_time_in_us();
Base::target_execution_policy.algo = algo.desc;
......
......@@ -88,7 +88,7 @@ void test_multibatchsize(
A_tensor.layout(), B_tensor.layout(),
C_tensor.layout())) {
if (std::regex_match(
i.name.c_str(),
i.desc.name.c_str(),
std::regex("(" + std::string(algo) + ")(.*)"))) {
opr_reference->execution_policy().algo = i.desc;
break;
......@@ -117,7 +117,7 @@ void test_multibatchsize(
A_tensor_prime.layout(), B_tensor.layout(),
C_tensor_batch.layout())) {
if (std::regex_match(
i.name.c_str(),
i.desc.name.c_str(),
std::regex("(" + std::string(algo) + ")(.*)"))) {
opr_reference->execution_policy().algo = i.desc;
break;
......
......@@ -318,7 +318,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst;
std::string msg = ssprintf("profiling %s algorithm %s %s",
ctx.mgb_opr()->dyn_typeinfo()->name,
algo.name.c_str(), layouts_str.c_str());
algo.desc.name.c_str(), layouts_str.c_str());
ImplExecutionPolicy policy;
policy.algo = algo.desc;
ctx.construct_execution_policy(selected_strategy, policy);
......@@ -327,12 +327,12 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
}
auto palgo = ctx.megdnn_opr()->get_algorithm_from_desc(policy.algo);
if (!(palgo->contain_attribute_all(target_attr.first) &&
!palgo->contain_attribute_any(target_attr.second))) {
!palgo->contain_attribute_any(target_attr.second))) {
mgb_log_debug(
"skip algo %s with attribute(%s), which is not match the "
"profile strategy required contain attribute(%s) and not "
"contain attribute(%s).",
algo.name.c_str(),
algo.desc.name.c_str(),
Algorithm::attribute_str(palgo->attribute()).c_str(),
Algorithm::attribute_str(target_attr.first).c_str(),
Algorithm::attribute_str(target_attr.second).c_str());
......@@ -552,8 +552,8 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
auto&& prof = rst.val();
std::unordered_map<std::string, ImplAlgo> algo_map;
for (auto i : get_all_candidates()) {
auto ins = algo_map.emplace(i.name.c_str(), i);
mgb_assert(ins.second, "duplicated algo name: %s", i.name.c_str());
auto ins = algo_map.emplace(i.desc.name.c_str(), i);
mgb_assert(ins.second, "duplicated algo name: %s", i.desc.name.c_str());
}
if (prof.empty())
......
......@@ -41,8 +41,11 @@ std::string serialize_policy(const megdnn::ExecutionPolicy& policy) {
megdnn::Algorithm::serialize_write_pod(policy.algo.handle_type, ret);
megdnn::Algorithm::serialize_write_pod(policy.algo.type, ret);
uint32_t param_size = policy.algo.param.size();
uint32_t name_size = policy.algo.name.size();
megdnn::Algorithm::serialize_write_pod<uint32_t>(param_size, ret);
megdnn::Algorithm::serialize_write_pod<uint32_t>(name_size, ret);
ret += policy.algo.param;
ret += policy.algo.name;
//! serialize sub_policy
uint32_t size = policy.sub_policy.size();
......@@ -64,11 +67,17 @@ megdnn::ExecutionPolicy deserialize_policy(const char* buf, uint32_t size,
cb(ret.algo.type, uint32_t);
uint32_t param_size = 0;
uint32_t name_size = 0;
cb(param_size, uint32_t);
cb(name_size, uint32_t);
if (param_size > 0) {
ret.algo.param = std::string(buf + offset, param_size);
offset += param_size;
}
if (name_size > 0) {
ret.algo.name = std::string(buf + offset, name_size);
offset += name_size;
}
uint32_t nr_policy = 0;
cb(nr_policy, uint32_t);
......
Markdown is supported.
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.