From e19b9af19bf4b256de4e877754f64d9208afe152 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 9 Mar 2021 17:53:06 +0800 Subject: [PATCH] feat(imperative): add bit combined enum to python C extension GitOrigin-RevId: 92307dd2ca077ea5606657f7cb7b321fd0dc8129 --- imperative/python/src/ops.cpp | 123 +++++++++- imperative/tablegen/autogen.cpp | 227 +++++++++++++----- sdk/load-and-run/src/mgblar.cpp | 37 ++- src/core/include/megbrain/common.h | 7 - .../include/megbrain/graph/operator_node.h | 7 + src/opr/impl/search_policy/algo_chooser.cpp | 60 +++-- .../megbrain/opr/search_policy/algo_chooser.h | 12 +- 7 files changed, 359 insertions(+), 114 deletions(-) diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index 06100741e..8eb2c5a9b 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -73,7 +73,7 @@ PyTypeObject PyOpType(name); } \ } while (0) -template +template struct pyobj_convert_generic { static T from(PyObject* obj) { // TODO: remove this guard which is used for pybind11 implicit conversion @@ -87,7 +87,12 @@ struct pyobj_convert_generic { } }; -template +template +struct EnumTrait { + static constexpr bool is_bit_combined = false; +}; + +template PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { PyObject* obj = type->tp_alloc(type, 0); T* self = reinterpret_cast(obj); @@ -203,9 +208,10 @@ struct EnumWrapper { } }; -template +template struct pyobj_convert_generic>>> { + std::enable_if_t> && + !EnumTrait::is_bit_combined>> { using Wrapper = EnumWrapper; static T from(PyObject* obj) { if (PyObject_TypeCheck(obj, &Wrapper::type)) { @@ -223,6 +229,115 @@ struct pyobj_convert_generic +struct BitCombinedEnumWrapper { + static_assert(std::is_enum_v); + PyObject_HEAD + T value; + static const char* name; + static PyTypeObject type; + static std::unordered_map type2str; + static std::unordered_map str2type; + static PyNumberMethods number_methods; + BitCombinedEnumWrapper() = default; + BitCombinedEnumWrapper(T v): value(v) {} + BitCombinedEnumWrapper(std::string&& str) + : BitCombinedEnumWrapper(str2type.at(normalize_enum(str))) {} + std::string to_string() const { + if (static_cast(value) == 0) { + return "None"; + } else { + auto ret = std::string(); + bool first = true; + for (uint32_t i = 0; i < 32; i++) { + uint32_t value_int = static_cast(value); + auto it = type2str.find(static_cast((1 << i) & value_int)); + if (it != type2str.end()) { + if (!first) { + ret += " + "; + } else { + first = false; + } + ret += (std::string(name) + "." + it->second); + } + } + return ret; + } + } + static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) { + PyObject* obj = type->tp_alloc(type, 0); + reinterpret_cast(obj)->value = static_cast(1); + return obj; + } + static int py_init(PyObject* self, PyObject* args, PyObject*) { + int input = 1; + if (PyArg_ParseTuple(args, "|i", &input)){ + reinterpret_cast(self)->value = + static_cast(input); + } + return 0; + } + static PyObject* py_repr(PyObject* self) { + return pyobj_convert_generic::to( + reinterpret_cast(self)->to_string()); + } + static PyObject* py_or(PyObject* self, PyObject* other) { + if(!(self->ob_type == other->ob_type)){ + return PyErr_Format( + PyExc_RuntimeError, + "Operand in or operator must be the same type."); + } + PyObject* obj = type.tp_alloc(&type, 0); + T lhs = reinterpret_cast(self)->value, + rhs = reinterpret_cast(other)->value; + reinterpret_cast(obj)->value = static_cast( + static_cast(lhs) | static_cast(rhs)); + return obj; + } + static PyObject* py_and(PyObject* self, PyObject* other) { + if (!(self->ob_type == other->ob_type)) { + return PyErr_Format( + PyExc_RuntimeError, + "Operand in and operator must be the same type."); + } + PyObject* obj = type.tp_alloc(&type, 0); + T lhs = reinterpret_cast(self)->value, + rhs = reinterpret_cast(other)->value; + reinterpret_cast(obj)->value = static_cast( + static_cast(lhs) & static_cast(rhs)); + return obj; + } + static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) { + T lhs = reinterpret_cast(self)->value, + rhs = reinterpret_cast(other)->value; + if (op == Py_EQ || op == Py_NE) { + RETURN_RICHCOMPARE(lhs, rhs, op); + } + Py_RETURN_NOTIMPLEMENTED; + } +}; + +template +struct pyobj_convert_generic> && + EnumTrait::is_bit_combined>> { + using Wrapper = BitCombinedEnumWrapper; + static T from(PyObject* obj) { + if (PyObject_TypeCheck(obj, &Wrapper::type)) { + return reinterpret_cast(obj)->value; + } + // try as string + // TODO: type checkcd + return Wrapper(pyobj_convert_generic::from(obj)).value; + } + static PyObject* to(T t) { + PyTypeObject* pytype = &Wrapper::type; + PyObject* obj = pytype->tp_alloc(pytype, 0); + reinterpret_cast(obj)->value = t; + return obj; + } +}; + void _init_py_op_def(py::module m) { using py_op = PyOp(OpDef); auto& py_type = PyOpType(OpDef); diff --git a/imperative/tablegen/autogen.cpp b/imperative/tablegen/autogen.cpp index 1e00f8f3d..44b3dabfb 100644 --- a/imperative/tablegen/autogen.cpp +++ b/imperative/tablegen/autogen.cpp @@ -408,61 +408,58 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& os << ";\n\n"; } -static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { - auto className = op.getCppClassName(); +static std::string gen_op_def_python_c_extension_enum( + raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, + llvm::StringRef className) { std::string body; - - // generate PyType for enum class member - for (auto&& i : op.getMgbAttributes()) { - if (auto attr = llvm::dyn_cast(&i.attr)) { - unsigned int enumID; - if (auto alias = llvm::dyn_cast(attr)) { - auto&& aliasBase = alias->getAliasBase(); - enumID = - llvm::cast(aliasBase) - .getBaseRecord()->getID(); - } else { - enumID = attr->getBaseRecord()->getID(); - } - auto&& enumAlias = ctx.enumAlias; - auto&& iter = enumAlias.find(enumID); - auto enumName = attr->getEnumName(); - body += "{\n"; - body += formatv( - "auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName - ); - if (iter == enumAlias.end()) { - os << formatv( - "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", - className, enumName); - os << formatv( - "template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n", - className, enumName); - std::vector pairStr; - for (auto&& i: attr->getEnumMembers()) { - pairStr.push_back(formatv( - "{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", - className, enumName, i)); - } - os << formatv(R"( + unsigned int enumID; + if (auto alias = llvm::dyn_cast(attr)) { + auto&& aliasBase = alias->getAliasBase(); + enumID = llvm::cast(aliasBase).getBaseRecord()->getID(); + } else { + enumID = attr->getBaseRecord()->getID(); + } + auto&& enumAlias = ctx.enumAlias; + auto&& iter = enumAlias.find(enumID); + auto enumName = attr->getEnumName(); + body += "{\n"; + body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className, + enumName); + if (iter == enumAlias.end()) { + os << formatv( + "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", + className, enumName); + os << formatv( + "template<> const char* EnumWrapper<{0}::{1}>::name = " + "\"{0}.{1}\";\n", + className, enumName); + std::vector pairStr; + for (auto&& i : attr->getEnumMembers()) { + pairStr.push_back( + formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", + className, enumName, i)); + } + os << formatv(R"( template<> std::unordered_map EnumWrapper<{0}::{1}>::str2type = {{ {2} }; -)", className, enumName, llvm::join(pairStr, ", ")); - pairStr.clear(); - for (auto&& i: attr->getEnumMembers()) { - pairStr.push_back(formatv( - "{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", - className, enumName, i)); - } - os << formatv(R"( +)", + className, enumName, llvm::join(pairStr, ", ")); + pairStr.clear(); + for (auto&& i : attr->getEnumMembers()) { + pairStr.push_back( + formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", + className, enumName, i)); + } + os << formatv(R"( template<> std::unordered_map<{0}::{1}, std::string> EnumWrapper<{0}::{1}>::type2str = {{ {2} }; -)", className, enumName, llvm::join(pairStr, ", ")); - body += formatv(R"( +)", + className, enumName, llvm::join(pairStr, ", ")); + body += formatv(R"( e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); @@ -472,22 +469,140 @@ EnumWrapper<{0}::{1}>::type2str = {{ e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; mgb_assert(PyType_Ready(&e_type) >= 0); -)", className, enumName); - for (auto&& i: attr->getEnumMembers()) { - body += formatv(R"({{ +)", + className, enumName); + for (auto&& i : attr->getEnumMembers()) { + body += formatv(R"({{ PyObject* inst = e_type.tp_alloc(&e_type, 0); reinterpret_cast*>(inst)->value = {0}::{1}::{2}; mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); -})", className, enumName, i); - } - enumAlias.emplace(enumID, std::make_pair(className, enumName)); - } - body += formatv(R"( +})", + className, enumName, i); + } + enumAlias.emplace(enumID, std::make_pair(className, enumName)); + } + body += formatv(R"( + PyType_Modified(&e_type); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "{0}", reinterpret_cast(&e_type)) >= 0); +)", + enumName); + body += "}\n"; + return body; +} + +static std::string gen_op_def_python_c_extension_bit_combined_enum( + raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, + llvm::StringRef className) { + std::string body; + unsigned int enumID; + if (auto alias = llvm::dyn_cast(attr)) { + auto&& aliasBase = alias->getAliasBase(); + enumID = llvm::cast(aliasBase).getBaseRecord()->getID(); + } else { + enumID = attr->getBaseRecord()->getID(); + } + auto&& enumAlias = ctx.enumAlias; + auto&& iter = enumAlias.find(enumID); + auto enumName = attr->getEnumName(); + body += "{\n"; + body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;", + className, enumName); + if (iter == enumAlias.end()) { + os << formatv( + "template<> PyTypeObject " + "BitCombinedEnumWrapper<{0}::{1}>::type={{};\n", + className, enumName); + os << formatv( + "template<> PyNumberMethods " + "BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n", + className, enumName); + os << formatv( + "template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name " + "= \"{0}.{1}\";\n", + className, enumName); + os << formatv( + "template<> struct EnumTrait<{0}::{1}> {{ static constexpr " + "bool is_bit_combined = true;};\n", + className, enumName); + std::vector pairStr; + for (auto&& i : attr->getEnumMembers()) { + pairStr.push_back( + formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", + className, enumName, i)); + } + os << formatv(R"( +template<> std::unordered_map +BitCombinedEnumWrapper<{0}::{1}>::str2type = {{ + {2} +}; +)", + className, enumName, llvm::join(pairStr, ", ")); + pairStr.clear(); + for (auto&& i : attr->getEnumMembers()) { + pairStr.push_back( + formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", + className, enumName, i)); + } + os << formatv(R"( +template<> std::unordered_map<{0}::{1}, std::string> +BitCombinedEnumWrapper<{0}::{1}>::type2str = {{ + {2} +}; +)", + className, enumName, llvm::join(pairStr, ", ")); + body += formatv(R"( + e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; + e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; + e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>); + e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + e_type.tp_doc = "{0}.{1}"; + e_type.tp_base = &PyBaseObject_Type; + e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum; + e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init; + e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr; + e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare; + auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods; + number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or; + number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and; + e_type.tp_as_number = &number_method; + mgb_assert(PyType_Ready(&e_type) >= 0); +)", + className, enumName); + for (auto&& i : attr->getEnumMembers()) { + body += formatv(R"({{ + PyObject* inst = e_type.tp_alloc(&e_type, 0); + reinterpret_cast*>(inst)->value = {0}::{1}::{2}; + mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); +})", + className, enumName, i); + } + enumAlias.emplace(enumID, std::make_pair(className, enumName)); + } + body += formatv(R"( PyType_Modified(&e_type); mgb_assert(PyDict_SetItemString( py_type.tp_dict, "{0}", reinterpret_cast(&e_type)) >= 0); -)", enumName); - body += "}\n"; +)", + enumName); + body += "}\n"; + return body; +} + +static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { + auto className = op.getCppClassName(); + std::string body; + + // generate PyType for enum class member + for (auto&& i : op.getMgbAttributes()) { + if (auto attr = llvm::dyn_cast(&i.attr)) { + if (attr->getEnumCombinedFlag()) { + body += gen_op_def_python_c_extension_bit_combined_enum( + os, ctx, attr, className); + } else { + body += gen_op_def_python_c_extension_enum(os, ctx, attr, + className); + } } } diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp index f25a0e374..e51ffa6a2 100644 --- a/sdk/load-and-run/src/mgblar.cpp +++ b/sdk/load-and-run/src/mgblar.cpp @@ -141,15 +141,13 @@ R"__usage__( )__usage__" #if MGB_ENABLE_FASTRUN R"__usage__( - --fast-run - This param will be deperated later, please replace with param --full-profile. - --full-profile - Enable full-profile mode. Operators with multiple algorithms would be profiled + --full-run + Enable full-run mode. Operators with multiple algorithms would be profiled on the real device with actual input shapes, all algorithms will be profiled include naive algorithms. See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. - --fast-profile - Enable fast-profile mode. Operators with multiple algorithms would be profiled + --fast-run + Enable fast-run mode. Operators with multiple algorithms would be profiled on the real device with actual input shapes, this mode will only profile the well optimized algorithms to get the profile result fast. See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. @@ -519,8 +517,8 @@ struct Args { bool disable_assert_throw = false; bool share_param_mem = false; #if MGB_ENABLE_FASTRUN - bool use_full_profile = false; - bool use_fast_profile = false; + bool use_full_run = false; + bool use_fast_run = false; #endif bool reproducible = false; std::string fast_run_cache_path; @@ -704,13 +702,13 @@ void run_test_st(Args &env) { using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; S strategy = S::HEURISTIC; #if MGB_ENABLE_FASTRUN - if (env.use_full_profile) { + if (env.use_full_run) { if (env.reproducible) { strategy = S::PROFILE | S::REPRODUCIBLE; } else { strategy = S::PROFILE; } - } else if (env.use_fast_profile) { + } else if (env.use_fast_run) { strategy = S::PROFILE | S::OPTMIZED; } else if (env.reproducible) { strategy = S::HEURISTIC | S::REPRODUCIBLE; @@ -740,12 +738,12 @@ void run_test_st(Args &env) { std::make_shared(buf.get(), flen)); #if MGB_ENABLE_FASTRUN } else { - mgb_assert(env.use_full_profile || env.use_fast_profile, - "fast-run or fast-profile should be enabled"); + mgb_assert(env.use_full_run || env.use_fast_run, + "fast-run or fast-run should be enabled"); PersistentCache::set_impl( std::make_shared()); } - if (!env.use_full_profile && !env.use_fast_profile) + if (!env.use_full_run && !env.use_fast_run) #endif mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); } @@ -1326,18 +1324,11 @@ Args Args::from_argv(int argc, char **argv) { } #if MGB_ENABLE_FASTRUN if (!strcmp(argv[i], "--fast-run")) { - mgb_log_warn( - "--fast-run param will be deperated later, please replace " - "with --full-profile or --fast-profile."); - ret.use_full_profile = true; + ret.use_fast_run = true; continue; } - if (!strcmp(argv[i], "--full-profile")) { - ret.use_full_profile = true; - continue; - } - if (!strcmp(argv[i], "--fast-profile")) { - ret.use_fast_profile = true; + if (!strcmp(argv[i], "--full-run")) { + ret.use_full_run = true; continue; } #endif diff --git a/src/core/include/megbrain/common.h b/src/core/include/megbrain/common.h index 085ff4144..cb2781e30 100644 --- a/src/core/include/megbrain/common.h +++ b/src/core/include/megbrain/common.h @@ -12,7 +12,6 @@ #pragma once #include "megbrain_build_config.h" -#include "megbrain/opr/param_defs.h" #include "megdnn/basic_types.h" #include @@ -250,10 +249,4 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) { } // namespace mgb -namespace megdnn { -namespace param { -MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy) -} -} // namespace megdnn - // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/core/include/megbrain/graph/operator_node.h b/src/core/include/megbrain/graph/operator_node.h index 27c597416..021e255f9 100644 --- a/src/core/include/megbrain/graph/operator_node.h +++ b/src/core/include/megbrain/graph/operator_node.h @@ -18,6 +18,7 @@ #include "megbrain/utils/hashable.h" #include "megbrain/utils/thin/hash_table.h" #include "megbrain/utils/small_vector.h" +#include "megbrain/opr/param_defs.h" #include @@ -1043,4 +1044,10 @@ MGB_DEFINE_CLS_WITH_SUPER(_name final, _base ,##__VA_ARGS__) \ } // namespace cg } // namespace mgb +namespace megdnn { +namespace param { +MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy) +} +} // namespace megdnn + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp index 22c491557..9f712c404 100644 --- a/src/opr/impl/search_policy/algo_chooser.cpp +++ b/src/opr/impl/search_policy/algo_chooser.cpp @@ -278,6 +278,19 @@ std::vector flatten_search_space( return ret; } +//! Test whether the algo attribute of a algo match the require +//! algo_strategy +static bool algo_attribute_match_strategy(AlgoAttribute attribute, + ExecutionStrategy selected_strategy) { + bool ret = true; + if (selected_strategy & ExecutionStrategy::OPTMIZED) { + ret &= (!static_cast(AlgoAttribute::NAIVE & attribute)); + } else if (selected_strategy & ExecutionStrategy::REPRODUCIBLE) { + ret &= static_cast(AlgoAttribute::REPRODUCIBLE & attribute); + } + return ret; +} + } // namespace namespace mgb { @@ -285,8 +298,8 @@ namespace opr { template void AlgoChooser::profile(ExeContext& ctx, - ExecutionStrategy select_strategy) { - if (ctx.get_profile_result_from_cache(select_strategy).valid()) + ExecutionStrategy selected_strategy) { + if (ctx.get_profile_result_from_cache(selected_strategy).valid()) return; AlgoChooserProfileCache::Result prof_rst; @@ -306,9 +319,19 @@ void AlgoChooser::profile(ExeContext& ctx, algo.name.c_str(), str_on_inp_shape.c_str()); ImplExecutionPolicy policy; policy.algo = algo.desc; - ctx.construct_execution_policy(select_strategy, policy); - if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) + ctx.construct_execution_policy(selected_strategy, policy); + if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) { continue; + } + auto algo_attribute = ctx.megdnn_opr() + ->get_algorithm_from_desc(policy.algo) + ->attribute(); + if (!algo_attribute_match_strategy(algo_attribute, selected_strategy)) { + mgb_log_debug( + "skip algo %s, which is not match the profile strategy.", + algo.name.c_str()); + continue; + } timer.reset(); MGB_TRY { cur_rst = ctx.profile_single_algo(policy, cur_timeout); } @@ -356,7 +379,7 @@ void AlgoChooser::profile(ExeContext& ctx, template typename AlgoChooser::ImplExecutionPolicy AlgoChooser::choose_by_profile(ExeContext& ctx, - ExecutionStrategy select_strategy, + ExecutionStrategy selected_strategy, bool enable_update) { MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) if (ctx.owner_graph()->options().no_profiling_on_shape_change) { @@ -378,11 +401,11 @@ AlgoChooser::choose_by_profile(ExeContext& ctx, to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), _item.param, ctx.mgb_opr(), ctx.comp_node(), ctx.execution_policy(), ctx.allow_weight_preprocess()); - AlgoChooser<_Opr>::profile(sub_ctx, select_strategy); + AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy); }); } typename AlgoChooser::ImplExecutionPolicy policy; - ctx.construct_execution_policy(select_strategy, policy); + ctx.construct_execution_policy(selected_strategy, policy); return policy; MIDOUT_E } @@ -440,7 +463,8 @@ typename AlgoChooser::ImplExecutionPolicy AlgoChooser::get_policy( if (!policy.algo.valid()) policy = ctx.choose_by_heuristic(opr_strategy); return policy; - } else if ((opr_strategy & ExecutionStrategy::HEURISTIC)) { + } else if (!static_cast(opr_strategy) || + (opr_strategy & ExecutionStrategy::HEURISTIC)) { return ctx.choose_by_heuristic(opr_strategy); } #if MGB_ENABLE_FASTRUN @@ -449,7 +473,7 @@ typename AlgoChooser::ImplExecutionPolicy AlgoChooser::get_policy( } #endif else { - mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy"); + mgb_throw(GraphError, "bad ExecutionPolicy strategy"); } } @@ -495,7 +519,7 @@ AlgoChooser::ExeContext::ExeContext( template typename AlgoChooser::ImplAlgo AlgoChooser::ExeContext::get_profile_result_from_cache( - ExecutionStrategy select_strategy) const { + ExecutionStrategy selected_strategy) const { MIDOUT_B(Opr, midout_iv(MGB_HASH_STR( "AlgoChooser::ExeContext::get_profile_result_from_cache"))) @@ -519,7 +543,7 @@ AlgoChooser::ExeContext::get_profile_result_from_cache( if (prof.empty()) return {}; for (auto&& i : prof) { - if (!(select_strategy & ExecutionStrategy::REPRODUCIBLE) || + if (!(selected_strategy & ExecutionStrategy::REPRODUCIBLE) || static_cast(i.attribute) & AlgoAttribute::REPRODUCIBLE) { auto iter = algo_map.find(i.algo); @@ -550,7 +574,7 @@ AlgoChooser::ExeContext::get_profile_result_from_cache( template typename AlgoChooser::ImplExecutionPolicy AlgoChooser::ExeContext::choose_by_heuristic( - ExecutionStrategy select_strategy) const { + ExecutionStrategy selected_strategy) const { if (m_execution_policy.workspace_limit != std::numeric_limits::max()) { @@ -558,7 +582,7 @@ AlgoChooser::ExeContext::choose_by_heuristic( "workspace_limit should not be setted if choose algo by " "heuristic"); } - bool reproducible = static_cast(select_strategy & + bool reproducible = static_cast(selected_strategy & ExecutionStrategy::REPRODUCIBLE); auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( owner_graph(), m_cn, m_execution_policy.workspace_limit); @@ -582,7 +606,7 @@ AlgoChooser::ExeContext::choose_by_heuristic( _item.param, m_base_mgb_opr, m_cn, m_execution_policy, m_allow_weight_preprocess); policy.sub_policy.push_back( - sub_ctx.choose_by_heuristic(select_strategy)); + sub_ctx.choose_by_heuristic(selected_strategy)); }); return policy; @@ -613,15 +637,15 @@ AlgoChooser::ExeContext::get_all_candidates() const { template void AlgoChooser::ExeContext::construct_execution_policy( - ExecutionStrategy select_strategy, + ExecutionStrategy selected_strategy, typename AlgoChooser::ImplExecutionPolicy& policy, bool retrive_from_cache) const { - bool reproducible = static_cast(select_strategy & + bool reproducible = static_cast(selected_strategy & ExecutionStrategy::REPRODUCIBLE); if (!policy.algo.valid()) { if (retrive_from_cache) { policy.algo = - get_profile_result_from_cache(select_strategy).desc; + get_profile_result_from_cache(selected_strategy).desc; } else { auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( owner_graph(), m_cn, m_execution_policy.workspace_limit); @@ -651,7 +675,7 @@ void AlgoChooser::ExeContext::construct_execution_policy( _item.param, m_base_mgb_opr, m_cn, m_execution_policy, m_allow_weight_preprocess); policy.sub_policy.push_back({}); - sub_ctx.construct_execution_policy(select_strategy, + sub_ctx.construct_execution_policy(selected_strategy, policy.sub_policy.back(), retrive_from_cache); }); diff --git a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h index a9af20813..bb193e18a 100644 --- a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h +++ b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h @@ -110,7 +110,7 @@ public: const FixedTensorLayouts& layouts() const { return m_layouts; } ImplExecutionPolicy choose_by_heuristic( - ExecutionStrategy select_strategy) const; + ExecutionStrategy selected_strategy) const; //! get all candidate algos, and the one choose_by_heuristic() is //! put first @@ -134,17 +134,17 @@ public: //! get all profile algorithm from cache, return invalid if not exists ImplAlgo get_profile_result_from_cache( - ExecutionStrategy select_strategy) const; + ExecutionStrategy selected_strategy) const; /** * \brief construct execution policy from cache or heuristic. * - * \param select_strategy select algo which matched this strategy + * \param selected_strategy select algo which matched this strategy * \param policy execution policy * \param retrive_from_cache retrive algo from cache if set True, get * from heuristic otherwise. */ - void construct_execution_policy(ExecutionStrategy select_strategy, + void construct_execution_policy(ExecutionStrategy selected_strategy, ImplExecutionPolicy& policy, bool retrive_from_cache = true) const; @@ -161,10 +161,10 @@ private: //! profile and save to cache - static void profile(ExeContext& ctx, ExecutionStrategy select_strategy); + static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy); static ImplExecutionPolicy choose_by_profile( - ExeContext& ctx, ExecutionStrategy select_strategy, + ExeContext& ctx, ExecutionStrategy selected_strategy, bool enable_update = true); public: -- GitLab