Commit 14d8b709 authored by Megvii Engine Team

perf(mge/imperative): add mini graph to partially replace proxy graph

GitOrigin-RevId: 73e2529ba53ccb6c0607f52aee40e69e2c289343
Parent c294b9d1
@@ -258,6 +258,9 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice =
    MGB_LOCK_GUARD(m_mutex);
    dest->value_fetched = ptr->value_fetched();
    // update tensor desc for static infer
    // if (dest->desc.layout.ndim) {
    //     mgb_assert(dest->desc.layout.eq_shape(ptr->layout()));
    // }
    dest->desc.layout = ptr->layout();
    dest->desc.comp_node = ptr->comp_node();
    dest->ptr = std::move(ptr);
@@ -363,7 +366,7 @@ void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) {
            }
            inputs.push_back(i->ptr);
        }
        auto outputs = OpDef::apply_on_physical_tensor(*path.op, inputs);
        auto outputs = OpDef::apply_on_physical_tensor(*path.op, inputs);
        for (size_t i = 0; i < outputs.size(); i++) {
            auto out_ptr = path.outputs[i].lock();
            if (out_ptr) {
namespace mgb::imperative::proxy_graph {

// a "namespace" struct to simplify friend declaration,
// e.g. friend class mgb::imperative::proxy_graph::ProxyGraph
struct ProxyGraph {
    struct InputPlaceholder;
    struct MiniGraph;
};

} // namespace mgb::imperative::proxy_graph
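As the comment above notes, nesting InputPlaceholder and MiniGraph inside a single ProxyGraph struct lets other headers grant access to the whole mini-graph implementation with one friend declaration, because nested classes are members and share the access rights of their enclosing class (since C++11). A minimal illustrative sketch, not part of this commit; Widget is a hypothetical class:

class Widget {  // hypothetical class, used only for illustration
    // A single friend declaration covers ProxyGraph and every type nested in
    // it (InputPlaceholder, MiniGraph, ...), since members of a friend class
    // have the same access as the friend class itself.
    friend class mgb::imperative::proxy_graph::ProxyGraph;
    int m_private_state = 0;  // reachable from ProxyGraph::MiniGraph etc.
};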
This diff is collapsed.
#include "./mini_graph.h"
// #include "../proxy_graph.h"
namespace mgb::imperative::proxy_graph {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ProxyGraph::InputPlaceholder);
thread_local std::unique_ptr<ProxyGraphTypeI> ProxyGraphTypeI::sm_instance = {};
} // namespace mgb::imperative::proxy_graph
namespace mgb::imperative::proxy_graph_detail {

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto ret = proxy_graph::ProxyGraphTypeI::inst().infer_output_attrs_fallible(def, inputs);
    // auto ref = ProxyGraph::get_default_graph()->infer_output_attrs_fallible(def, inputs);
    // auto& [a, _1] = ret;
    // auto& [b, _2] = ref;
    // if (a.size() != b.size()) mgb_trap();
    // for (size_t i = 0; i < a.size(); ++i) {
    //     if (a[i].layout.dtype != b[i].layout.dtype) mgb_trap();
    //     if (a[i].comp_node != b[i].comp_node) mgb_trap();
    //     if (!a[i].layout.eq_shape(b[i].layout)) mgb_trap();
    // }
    return ret;
}

} // namespace mgb::imperative::proxy_graph_detail
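The wrapper above forwards to a per-thread ProxyGraphTypeI instance whose definition lives in the collapsed mini_graph.h; the commented-out block was evidently used to cross-check the mini graph against the legacy ProxyGraph result. A rough sketch of how a lazily constructed thread_local singleton of this shape is typically wired up (the body here is an assumption, not the actual implementation):

#include <memory>

// Hypothetical sketch only; the real class is defined in ./mini_graph.h.
class ProxyGraphTypeI {
public:
    static ProxyGraphTypeI& inst() {
        if (!sm_instance) {
            // one graph per thread, created on first use
            sm_instance.reset(new ProxyGraphTypeI);
        }
        return *sm_instance;
    }
private:
    static thread_local std::unique_ptr<ProxyGraphTypeI> sm_instance;
};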
#include "megbrain/graph/cg.h"
namespace mgb::imperative::proxy_graph {
using cg::VarNode;
struct ExecEnvBase : cg::GraphExecutable::ExecEnv {
void dispatch_on_comp_node(CompNode, Task&& task) override {
task();
}
void dispatch_on_comp_node_with_mask(CompNode, Task&&, cg::ExecutionMask*) override {mgb_assert(0);}
void pause_exec() override {mgb_assert(0);}
void resume_exec() override {mgb_assert(0);}
};
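ExecEnvBase runs every dispatched task inline, so ops executed through the mini graph complete synchronously on the calling thread; the asynchronous hooks are deliberately stubbed out. An illustrative sketch of that behaviour (demo_sync_dispatch is a hypothetical helper, not part of this diff):

void demo_sync_dispatch() {  // hypothetical helper, for illustration only
    ExecEnvBase env;
    bool ran = false;
    env.dispatch_on_comp_node(CompNode::load("xpux"), [&] { ran = true; });
    mgb_assert(ran);  // the task already ran; nothing was queued
}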
struct StaticInferManagerBase : cg::static_infer::StaticInferManager {
protected:
    void register_shape_infer(VarNode*, const cg::static_infer::ShapeInferDesc&) override {mgb_assert(0);}
    void register_value_infer(VarNode*, const cg::static_infer::ValueInferDesc&) override {mgb_assert(0);}
    cg::static_infer::InferType get_infer_type(VarNode*) override {mgb_assert(0);}
    const TensorShape& infer_shape(VarNode*) override {mgb_assert(0);}
    const TensorShape* infer_shape_fallible(VarNode*) override {mgb_assert(0);}
    const DeviceTensorND& infer_value(VarNode*) override {mgb_assert(0);}
    const DeviceTensorND* infer_value_fallible(VarNode*) override {mgb_assert(0);}
    cg::static_infer::DepVal get_rt_static_source_deps(const cg::static_infer::DepElement&) override {mgb_assert(0);}
};

struct SeqCompNodeOptimizerBase : cg::SeqCompNodeOptimizer {
protected:
    void register_stream_var(VarNode*, StreamPropType) override {}
    void register_propagate_function(VarNode*, PropFunction) override {}
    StreamPropType stream_prop_type(VarNode*) override {mgb_assert(0);}
};

struct ProxyGraphBase : cg::ComputingGraph {
private:
    VarReceiverInfo m_var_receiver_info;
    SeqCompNodeOptimizerBase m_seq_comp_node_optimizer;
    StaticInferManagerBase m_static_infer_manager;

protected:
    MemPool<VarNode> m_var_node_pool;

    ProxyGraphBase() {
        options().imperative_proxy_graph = true;
        options().no_force_inplace = true;
        options().log_level = 0;
        m_var_receiver_info.dev_value = 1;
        m_var_receiver_info.allow_empty_value = 1;
    }

    void* alloc_varnode_storage() override {
        return m_var_node_pool.alloc_raw();
    }

    void free_varnode_storage(void* ptr) override {
        m_var_node_pool.free_raw(ptr);
    }

    const VarReceiverInfo& var_receiver_in_current_comp_seq(const VarNode* var) const override {
        return m_var_receiver_info;
    }

    cg::static_infer::StaticInferManager& static_infer_manager() override {
        return m_static_infer_manager;
    }

    cg::SeqCompNodeOptimizer& seq_comp_node_optimizer() override {
        return m_seq_comp_node_optimizer;
    }

    std::shared_ptr<void> on_comp_node_finalize() override {
        return {};
    }

    std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec&) override {mgb_assert(0);}
    SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part(const SmallVector<OutputSpec>&) override {mgb_assert(0);}
    cg::AsyncExecutable* current_comp_seq() override {mgb_assert(0);}
    std::string get_mem_allocation_info() const override {mgb_assert(0);}
    VarNode* find_var_by_id(size_t) const override {mgb_assert(0);}
    void share_device_memory_with(ComputingGraph&) override {mgb_assert(0);}
    void set_device_memory_allocator(std::shared_ptr<cg::DeviceMemoryAllocator>) override {mgb_assert(0);}
    size_t get_device_memory_size(CompNode) override {mgb_assert(0);}
    size_t clear_device_memory() override {mgb_assert(0);}
    void set_as_subgraph(ComputingGraph&) override {mgb_assert(0);}
    void record_async_error(std::unique_ptr<MegBrainError>) override {mgb_assert(0);}
};
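ProxyGraphBase keeps only what shape/dtype inference and single-op execution need: a permissive VarReceiverInfo, stub infer/optimizer objects, and VarNode storage recycled through a MemPool so that building a throwaway graph per op stays cheap. A small illustrative sketch of the pooling idea behind alloc_varnode_storage()/free_varnode_storage() (demo_varnode_pool is hypothetical):

void demo_varnode_pool() {  // hypothetical helper, for illustration only
    MemPool<VarNode> pool;
    void* slot = pool.alloc_raw();  // uninitialized storage for one VarNode
    // the graph placement-news a VarNode into `slot`, uses it, destroys it...
    pool.free_raw(slot);            // storage returns to the pool for reuse
}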
MGB_DEFINE_OPR_CLASS(ProxyGraph::InputPlaceholder, cg::OperatorNodeBase) // {
    void on_output_comp_node_stream_changed() override {mgb_assert(0);}
    void init_output_comp_node() override {}
    void init_output_format() override {}
    void init_output_dtype() override {}
    void init_output_static_infer_desc() override {}
    void init_output_mem_plan(bool) override {mgb_assert(0);}
    void do_execute(ExecEnv&) override {mgb_assert(0);}

public:
    InputPlaceholder(cg::ComputingGraph& graph)
            : Super(&graph, {}, "placeholder", {}) {
        add_output(None)->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC);
        // never dedup
        add_equivalence_component<ScalarHash<void*>>(this);
    }

    InputPlaceholder(cg::ComputingGraph& graph, DType dtype, CompNode cn)
            : InputPlaceholder(graph) {
        output(0)->dtype(dtype).comp_node(cn);
    }
};

using InputPlaceholder = ProxyGraph::InputPlaceholder;

} // namespace mgb::imperative::proxy_graph
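InputPlaceholder is how imperative tensors enter the mini graph: it allocates no storage (NO_SYS_MEM_ALLOC), never dedups, and only carries dtype and comp_node. A hedged sketch of how one placeholder per input might be created; make_input_var and desc are assumptions for illustration, not part of this diff:

// Hypothetical sketch, assuming cg::ComputingGraph::insert_opr is usable on
// the concrete graph type that the mini graph builds on.
VarNode* make_input_var(cg::ComputingGraph& graph, const LogicalTensorDesc& desc) {
    auto* opr = graph.insert_opr(std::make_unique<InputPlaceholder>(
            graph, desc.layout.dtype, desc.comp_node));
    return opr->output(0);  // storage is supplied externally at execution time
}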
@@ -80,11 +80,11 @@ apply_on_physical_tensor(const OpDef& def,
    return outputs;
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& graph = ProxyGraph::get_default_graph();
    return graph->infer_output_attrs_fallible(def, inputs);
}

// std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
//         const SmallVector<LogicalTensorDesc>& inputs) {
//     auto&& graph = ProxyGraph::get_default_graph();
//     return graph->infer_output_attrs_fallible(def, inputs);
// }
namespace {
@@ -89,10 +89,18 @@ public:
        return m_blob->comp_node();
    }

    DType dtype() const {
        return m_layout.dtype;
    }

    TensorLayout layout() const {
        return m_layout;
    }

    const TensorShape& shape() const {
        return m_layout;
    }

    DeviceTensorND dev_tensor();

    static TensorPtr make_scalar(DTypeScalar value, CompNode cn);
@@ -16,7 +16,10 @@
namespace mgb {
namespace imperative {
class ProxyGraph;
namespace proxy_graph {
class ProxyGraph;
} // namespace proxy_graph
} // namespace imperative
namespace cg {
@@ -56,6 +59,7 @@ namespace static_infer {
    friend class StaticInferManagerImpl;
    friend class imperative::ProxyGraph;
    friend class imperative::proxy_graph::ProxyGraph;
public:
    /*!
@@ -342,4 +346,3 @@ using StaticInferInpVal = static_infer::InpVal;
} // mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -23,7 +23,10 @@
namespace mgb {
namespace imperative {
class ProxyGraph;
namespace proxy_graph {
class ProxyGraph;
}
} // namespace imperative
namespace cg {
@@ -587,6 +590,7 @@ class VarNode final: public GraphNodeBase {
    friend class EagerEvalManager;
    friend class MemAllocPlan;
    friend class imperative::ProxyGraph;
    friend class imperative::proxy_graph::ProxyGraph;
};

enum class VarNode::Flag : uint32_t {