diff --git a/.gitattributes b/.gitattributes index e189efca867699e831a8eabe2e2d800420f3558d..be56319f3240269f821d0e40b49d3ab0dc021a99 100644 --- a/.gitattributes +++ b/.gitattributes @@ -12,3 +12,4 @@ ci/resource/models/float/mobilenet_v2.pkl filter=lfs diff=lfs merge=lfs -text ci/resource/models/float/shufflenet_v2.pkl filter=lfs diff=lfs merge=lfs -text ci/resource/dump/roi_align_backward_8.8.0.mdl filter=lfs diff=lfs merge=lfs -text ci/resource/dump/relayout_format_8.10.0.mdl filter=lfs diff=lfs merge=lfs -text +ci/resource/dump/batch_conv_bias_with_policy_8.8.0.mdl filter=lfs diff=lfs merge=lfs -text diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 151d8a8f94bb6139ed613316882d7d8d1091fc00..240bbdbf6c3e9f20ef1d130bf3df34087ae73939 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -259,14 +259,14 @@ void init_graph_rt(py::module m) { return vars; }); - m.def("modify_opr_algo_strategy_inplace", [](const VarNodeArray& dest_vars, const std::string& strategy) { + m.def("modify_opr_algo_strategy_inplace", [](const VarNodeArray& dest_vars, + const std::string& strategy) { _AlgoStrategy stg; - const std::unordered_map> m{ - {"HEURISTIC", [&](){ stg = _AlgoStrategy::HEURISTIC; }}, - {"HEURISTIC_REPRODUCIBLE", [&](){ stg = _AlgoStrategy::HEURISTIC_REPRODUCIBLE; }}, - {"PROFILE", [&](){ stg = _AlgoStrategy::PROFILE; }}, - {"PROFILE_REPRODUCIBLE", [&](){ stg = _AlgoStrategy::PROFILE_REPRODUCIBLE; }}, - {"PROFILE_HEURISTIC", [&](){ stg = _AlgoStrategy::PROFILE_HEURISTIC; }}, + const std::unordered_map> m{ + {"HEURISTIC", [&]() { stg = _AlgoStrategy::HEURISTIC; }}, + {"PROFILE", [&]() { stg = _AlgoStrategy::PROFILE; }}, + {"REPRODUCIBLE", [&]() { stg = _AlgoStrategy::REPRODUCIBLE; }}, + {"OPTMIZED", [&]() { stg = _AlgoStrategy::OPTMIZED; }}, }; auto it = m.find(strategy); mgb_assert(it != m.end(), "Invalid strategy string!"); diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 3b4a86087f5b525a74afbb28330f231fb97784b6..58acd3d448a540adb700424249012ffa3122c7f6 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -154,7 +154,7 @@ void gopt::modify_opr_algo_strategy_inplace( opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { #if !MGB_ENABLE_FASTRUN using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; - if (strategy == S::PROFILE || strategy == S::PROFILE_REPRODUCIBLE) { + if ((strategy & S::PROFILE) && !(strategy & S::HEURISTIC)) { mgb_throw(MegBrainError, "fastrun is disabled at compile time"); } #endif