提交 2ec7c167 编写于 作者: M Megvii Engine Team

feat(mgb/gopt): profiler supports opr filter and var node filter

GitOrigin-RevId: 5f8d86687f6316a80cd89601361b308c13a62f59
上级 50ea5ae8
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "megbrain/opr/imgproc.h" #include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h" #include "megbrain/opr/nn_int.h"
#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/plugin/base.h" #include "megbrain/plugin/base.h"
#include "megbrain/serialization/sereg.h" #include "megbrain/serialization/sereg.h"
...@@ -246,6 +247,8 @@ float ProfilerImpl::profile_operator(const OperatorNodeBase* opr, ...@@ -246,6 +247,8 @@ float ProfilerImpl::profile_operator(const OperatorNodeBase* opr,
} }
auto new_opr = serialization::copy_opr_shallow( auto new_opr = serialization::copy_opr_shallow(
*opr, new_inps, opr->config(), {graph.get()}); *opr, new_inps, opr->config(), {graph.get()});
if (!m_opr_filter(opr, new_opr))
return PROFILE_TIME_OUT;
auto y = new_opr->output(0); auto y = new_opr->output(0);
auto mark = MarkInputContiguous::make(SymbolVar(y)); auto mark = MarkInputContiguous::make(SymbolVar(y));
auto func = graph->compile({{mark, {}}}); auto func = graph->compile({{mark, {}}});
...@@ -338,6 +341,8 @@ float ProfilerImpl::profile_operator( ...@@ -338,6 +341,8 @@ float ProfilerImpl::profile_operator(
!mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr())) !mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr()))
return PROFILE_TIME_OUT; return PROFILE_TIME_OUT;
#endif #endif
if (!m_opr_filter(opr, y->owner_opr()))
return PROFILE_TIME_OUT;
auto mark = MarkInputContiguous::make(SymbolVar(y)); auto mark = MarkInputContiguous::make(SymbolVar(y));
auto func = graph->compile({{mark, {}}}); auto func = graph->compile({{mark, {}}});
auto new_opr = y->owner_opr(); auto new_opr = y->owner_opr();
...@@ -384,6 +389,9 @@ float ProfilerImpl::profile_var_node(const VarNode* var, ...@@ -384,6 +389,9 @@ float ProfilerImpl::profile_var_node(const VarNode* var,
auto builder = ReformatManager::instance().auto_aligned_reformat_featrue( auto builder = ReformatManager::instance().auto_aligned_reformat_featrue(
var, base_format, key); var, base_format, key);
auto y = builder({aligned_var.node()}); auto y = builder({aligned_var.node()});
if (!m_var_node_filter(var, aligned_tensor_shape, y->shape(),
TensorFormat{}))
return PROFILE_TIME_OUT;
ThinHashSet<OperatorNodeBase*> set; ThinHashSet<OperatorNodeBase*> set;
DepOprIter iter([&set](OperatorNodeBase* opr) { set.insert(opr); }); DepOprIter iter([&set](OperatorNodeBase* opr) { set.insert(opr); });
iter.add(y->owner_opr()); iter.add(y->owner_opr());
...@@ -503,6 +511,40 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile( ...@@ -503,6 +511,40 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(
} }
/* ================== ProfilerBase =================*/ /* ================== ProfilerBase =================*/
//! Construct a profiler with threshold-based default filters.
//!
//! \param opr_threshold reject a transformed operator whose estimated
//!     computation exceeds this multiple of the original operator's.
//! \param var_node_threshold reject a reformat whose memory traffic
//!     exceeds this multiple of one read + one write of the original var.
ProfilerBase::ProfilerBase(float opr_threshold, float var_node_threshold)
        : m_opr_threshold{opr_threshold},
          m_var_node_threshold{var_node_threshold} {
    // Default operator filter: compare the FLOP estimate of the
    // layout-transformed operator against the original one.
    m_opr_filter = [this](const OperatorNodeBase* orig_opr,
                          OperatorNodeBase* trans_opr) {
        // OprFootprint::get_computation takes a non-const pointer, hence
        // the const_cast on the (unmodified) original operator.
        float orig_comp = m_opr_footprint.get_computation(
                const_cast<OperatorNodeBase*>(orig_opr));
        float trans_comp = m_opr_footprint.get_computation(trans_opr);
        return !(trans_comp > m_opr_threshold * orig_comp);
    };
    // Default var node filter: estimate the memory traffic of the
    // reformat (read `from` + write `to`) and compare it against
    // touching the original var twice (one read + one write).
    m_var_node_filter = [this](const VarNode* var, TensorShape from,
                               TensorShape to, TensorFormat format) {
        TensorFormat default_format;
        TensorLayout orig_ly, from_ly, to_ly;
        // A default-constructed TensorFormat means "use the plain
        // layout"; only pass `format` through when it is non-default.
        if (format == default_format) {
            orig_ly = TensorLayout{var->shape(), var->dtype()};
            from_ly = TensorLayout{from, var->dtype()};
            to_ly = TensorLayout{to, var->dtype()};
        } else {
            orig_ly = TensorLayout{var->shape(), var->dtype(), format};
            from_ly = TensorLayout{from, var->dtype(), format};
            to_ly = TensorLayout{to, var->dtype(), format};
        }
        float baseline_memory = orig_ly.span().dist_byte() * 2.f;
        float reformat_memory =
                from_ly.span().dist_byte() + to_ly.span().dist_byte();
        return !(reformat_memory > baseline_memory * m_var_node_threshold);
    };
}
std::string ProfilerBase::OperatorNodeRecord::to_string() const { std::string ProfilerBase::OperatorNodeRecord::to_string() const {
auto str = ssprintf("\nopr type: %s\nopr name: %s\ninputs:\n", auto str = ssprintf("\nopr type: %s\nopr name: %s\ninputs:\n",
opr->dyn_typeinfo()->name, opr->cname()); opr->dyn_typeinfo()->name, opr->cname());
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "megbrain/gopt/reformat_manager.h" #include "megbrain/gopt/reformat_manager.h"
#include "megbrain/gopt/subgraph_extractor.h" #include "megbrain/gopt/subgraph_extractor.h"
#include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/convolution.h"
#include "megbrain/plugin/opr_footprint.h"
namespace mgb { namespace mgb {
namespace gopt { namespace gopt {
...@@ -218,11 +219,27 @@ public: ...@@ -218,11 +219,27 @@ public:
/// A hashmap, that maps the var node to the costs of layout transform /// A hashmap, that maps the var node to the costs of layout transform
ThinHashMap<VarNode*, VarNodeRecord> var_record; ThinHashMap<VarNode*, VarNodeRecord> var_record;
}; };
using OprFilter = thin_function<bool(const cg::OperatorNodeBase*,
cg::OperatorNodeBase*)>;
using VarNodeFilter = thin_function<bool(const VarNode*, TensorShape,
TensorShape, TensorFormat)>;
ProfilerBase() = default; ProfilerBase(float opr_threshold = 2.f, float var_node_threshold = 2.f);
ProfilerBase(OprFilter opr_filter, VarNodeFilter var_node_filter = {})
: m_opr_filter{std::move(opr_filter)},
m_var_node_filter{std::move(var_node_filter)} {}
virtual ~ProfilerBase() = default; virtual ~ProfilerBase() = default;
virtual ProfilingResult profile(const Problem& problem) const = 0; virtual ProfilingResult profile(const Problem& problem) const = 0;
static std::unique_ptr<ProfilerBase> make_profiler(); static std::unique_ptr<ProfilerBase> make_profiler();
protected:
OprFilter m_opr_filter;
VarNodeFilter m_var_node_filter;
float m_opr_threshold;
float m_var_node_threshold;
private:
OprFootprint m_opr_footprint;
}; };
/*! /*!
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册