From 9da26407ea6af313fe60b2756f6813ef5bbe8e9c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 11 Mar 2021 17:00:05 +0800 Subject: [PATCH] feat(ci): and model to test algo policy compatible GitOrigin-RevId: e58caf08c8488cf2416b24c9deb5a7a50dab89bf --- .gitattributes | 1 + imperative/python/src/graph_rt.cpp | 14 +++++++------- src/gopt/impl/inference.cpp | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/.gitattributes b/.gitattributes index e189efca..be56319f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -12,3 +12,4 @@ ci/resource/models/float/mobilenet_v2.pkl filter=lfs diff=lfs merge=lfs -text ci/resource/models/float/shufflenet_v2.pkl filter=lfs diff=lfs merge=lfs -text ci/resource/dump/roi_align_backward_8.8.0.mdl filter=lfs diff=lfs merge=lfs -text ci/resource/dump/relayout_format_8.10.0.mdl filter=lfs diff=lfs merge=lfs -text +ci/resource/dump/batch_conv_bias_with_policy_8.8.0.mdl filter=lfs diff=lfs merge=lfs -text diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 151d8a8f..240bbdbf 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -259,14 +259,14 @@ void init_graph_rt(py::module m) { return vars; }); - m.def("modify_opr_algo_strategy_inplace", [](const VarNodeArray& dest_vars, const std::string& strategy) { + m.def("modify_opr_algo_strategy_inplace", [](const VarNodeArray& dest_vars, + const std::string& strategy) { _AlgoStrategy stg; - const std::unordered_map> m{ - {"HEURISTIC", [&](){ stg = _AlgoStrategy::HEURISTIC; }}, - {"HEURISTIC_REPRODUCIBLE", [&](){ stg = _AlgoStrategy::HEURISTIC_REPRODUCIBLE; }}, - {"PROFILE", [&](){ stg = _AlgoStrategy::PROFILE; }}, - {"PROFILE_REPRODUCIBLE", [&](){ stg = _AlgoStrategy::PROFILE_REPRODUCIBLE; }}, - {"PROFILE_HEURISTIC", [&](){ stg = _AlgoStrategy::PROFILE_HEURISTIC; }}, + const std::unordered_map> m{ + {"HEURISTIC", [&]() { stg = _AlgoStrategy::HEURISTIC; }}, + {"PROFILE", [&]() { stg = _AlgoStrategy::PROFILE; }}, + {"REPRODUCIBLE", [&]() { stg = _AlgoStrategy::REPRODUCIBLE; }}, + {"OPTMIZED", [&]() { stg = _AlgoStrategy::OPTMIZED; }}, }; auto it = m.find(strategy); mgb_assert(it != m.end(), "Invalid strategy string!"); diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 3b4a8608..58acd3d4 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -154,7 +154,7 @@ void gopt::modify_opr_algo_strategy_inplace( opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { #if !MGB_ENABLE_FASTRUN using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; - if (strategy == S::PROFILE || strategy == S::PROFILE_REPRODUCIBLE) { + if ((strategy & S::PROFILE) && !(strategy & S::HEURISTIC)) { mgb_throw(MegBrainError, "fastrun is disabled at compile time"); } #endif -- GitLab