提交 e1c83d8d 编写于 作者: M Megvii Engine Team

fix(mgb/core): add warning information about const_var_shape when record mode

GitOrigin-RevId: a99f9c4e5ddf92c62aaa17fb7b1baf10b68a5411
上级 fc8b501b
...@@ -243,7 +243,7 @@ public: ...@@ -243,7 +243,7 @@ public:
}; };
using CompNodeBaseImpl = CpuCompNode::CompNodeBaseImpl; using CompNodeBaseImpl = CpuCompNode::CompNodeBaseImpl;
using CompNodeNoRecorderImpl = CpuCompNode::CompNodeNoRecorderImpl; using CompNodeDefaultImpl = CpuCompNode::CompNodeDefaultImpl;
using CompNodeRecorderImpl = CpuCompNode::CompNodeRecorderImpl; using CompNodeRecorderImpl = CpuCompNode::CompNodeRecorderImpl;
//! ==================== CompNodeBaseImpl ====================== //! ==================== CompNodeBaseImpl ======================
...@@ -466,29 +466,29 @@ public: ...@@ -466,29 +466,29 @@ public:
} }
}; };
//! ==================== CompNodeNoRecorderImpl ====================== //! ==================== CompNodeDefaultImpl ======================
/** /**
* \note: CompNodeNoRecorderImpl will use most implements in base including: * \note: CompNodeDefaultImpl will use most implements in base including:
* alloc_device, alloc_host, copy_to_host, copy_to_device, peer_copy_to, * alloc_device, alloc_host, copy_to_host, copy_to_device, peer_copy_to,
* add_callback ... * add_callback ...
*/ */
class CpuCompNode::CompNodeNoRecorderImpl final : public CompNodeBaseImpl { class CpuCompNode::CompNodeDefaultImpl final : public CompNodeBaseImpl {
MGB_DYN_TYPE_OBJ_FINAL_DECL; MGB_DYN_TYPE_OBJ_FINAL_DECL;
public: public:
//! ptr to default cpu, only used by check_global_finalized //! ptr to default cpu, only used by check_global_finalized
static CompNodeNoRecorderImpl* sm_default_cpu_comp_node_ptr; static CompNodeDefaultImpl* sm_default_cpu_comp_node_ptr;
static void static_free_device(ImplBase* self, void* ptr) { static void static_free_device(ImplBase* self, void* ptr) {
static_cast<CompNodeNoRecorderImpl*>(self)->free_device(ptr); static_cast<CompNodeDefaultImpl*>(self)->free_device(ptr);
} }
static void static_free_host(ImplBase* self, void* ptr) { static void static_free_host(ImplBase* self, void* ptr) {
static_cast<CompNodeNoRecorderImpl*>(self)->free_host(ptr); static_cast<CompNodeDefaultImpl*>(self)->free_host(ptr);
} }
using CpuEventImpl = CpuDispatchableBase::EventImpl; using CpuEventImpl = CpuDispatchableBase::EventImpl;
CompNodeNoRecorderImpl(const Locator& locator, CompNodeDefaultImpl(const Locator& locator,
const Locator& locator_logical) const Locator& locator_logical)
: CompNodeBaseImpl(locator, locator_logical, static_free_device, : CompNodeBaseImpl(locator, locator_logical, static_free_device,
static_free_host) { static_free_host) {
...@@ -501,7 +501,7 @@ public: ...@@ -501,7 +501,7 @@ public:
sm_default_cpu_comp_node_ptr = this; sm_default_cpu_comp_node_ptr = this;
} }
~CompNodeNoRecorderImpl() { ~CompNodeDefaultImpl() {
m_env.fini(); m_env.fini();
sm_default_cpu_comp_node_ptr = nullptr; sm_default_cpu_comp_node_ptr = nullptr;
} }
...@@ -551,8 +551,8 @@ public: ...@@ -551,8 +551,8 @@ public:
SeqRecorderImpl* cur_recorder() const override { return nullptr; } SeqRecorderImpl* cur_recorder() const override { return nullptr; }
}; };
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeNoRecorderImpl); MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeDefaultImpl);
CompNodeNoRecorderImpl* CompNodeNoRecorderImpl::sm_default_cpu_comp_node_ptr = CompNodeDefaultImpl* CompNodeDefaultImpl::sm_default_cpu_comp_node_ptr =
nullptr; nullptr;
//! ==================== CompNodeRecorderImpl ====================== //! ==================== CompNodeRecorderImpl ======================
...@@ -746,7 +746,7 @@ public: ...@@ -746,7 +746,7 @@ public:
void peer_copy_to(Impl* dest_impl, void* dest, const void* src, void peer_copy_to(Impl* dest_impl, void* dest, const void* src,
size_t size) override { size_t size) override {
//! copy to default_cpu //! copy to default_cpu
if (dest_impl->same_type<CpuCompNode::CompNodeNoRecorderImpl>()) { if (dest_impl->same_type<CpuCompNode::CompNodeDefaultImpl>()) {
CompNodeBaseImpl::peer_copy_to(dest_impl, dest, src, size); CompNodeBaseImpl::peer_copy_to(dest_impl, dest, src, size);
return; return;
} }
...@@ -986,7 +986,7 @@ void CpuCompNode::sync_all() { ...@@ -986,7 +986,7 @@ void CpuCompNode::sync_all() {
// CpuCompNode::Pool // CpuCompNode::Pool
CompNode CompNode::default_cpu() { CompNode CompNode::default_cpu() {
static Locator locator{DeviceType::CPU, Locator::DEVICE_CPU_DEFAULT, {-1}}; static Locator locator{DeviceType::CPU, Locator::DEVICE_CPU_DEFAULT, {-1}};
static CompNodeNoRecorderImpl impl{locator, locator}; static CompNodeDefaultImpl impl{locator, locator};
return &impl; return &impl;
} }
......
...@@ -55,7 +55,7 @@ namespace mgb { ...@@ -55,7 +55,7 @@ namespace mgb {
}; };
class CompNodeBaseImpl; class CompNodeBaseImpl;
class CompNodeNoRecorderImpl; class CompNodeDefaultImpl;
class CompNodeRecorderImpl; class CompNodeRecorderImpl;
static void foreach(thin_function<void(CompNode)> callback); static void foreach(thin_function<void(CompNode)> callback);
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "./cg_impl_seq.h" #include "./cg_impl_seq.h"
#include "megbrain/graph/exc_extra_info.h" #include "megbrain/graph/exc_extra_info.h"
#include "megbrain/opr/tensor_manip.h"
using namespace mgb; using namespace mgb;
using namespace cg; using namespace cg;
...@@ -255,6 +256,22 @@ ComputingGraphImpl::ComputingSequence::check_enable_comp_node_seq_recorder() { ...@@ -255,6 +256,22 @@ ComputingGraphImpl::ComputingSequence::check_enable_comp_node_seq_recorder() {
} }
} }
} }
auto check_const_shape = [&]() {
for (auto i : *m_opr_seq) {
for (auto j : i->output()) {
if (j->shape().ndim && !is_const_var_shape(j)) {
mgb_log_warn(
"Non-const var shape detected. Make sure all "
"shapes are constant. Check whether "
"'const_var_shape' is set "
"in GraphLoadConfig under record mode");
return;
}
}
}
};
check_const_shape();
auto cn = *m_used_comp_node.begin(); auto cn = *m_used_comp_node.begin();
auto rec = cn.create_seq_recorder(m_owner_graph); auto rec = cn.create_seq_recorder(m_owner_graph);
if (!rec) { if (!rec) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册