提交 e1c83d8d 编写于 作者: M Megvii Engine Team

fix(mgb/core): add warning information about const_var_shape when record mode

GitOrigin-RevId: a99f9c4e5ddf92c62aaa17fb7b1baf10b68a5411
上级 fc8b501b
......@@ -243,7 +243,7 @@ public:
};
using CompNodeBaseImpl = CpuCompNode::CompNodeBaseImpl;
using CompNodeNoRecorderImpl = CpuCompNode::CompNodeNoRecorderImpl;
using CompNodeDefaultImpl = CpuCompNode::CompNodeDefaultImpl;
using CompNodeRecorderImpl = CpuCompNode::CompNodeRecorderImpl;
//! ==================== CompNodeBaseImpl ======================
......@@ -466,29 +466,29 @@ public:
}
};
//! ==================== CompNodeNoRecorderImpl ======================
//! ==================== CompNodeDefaultImpl ======================
/**
* \note: CompNodeNoRecorderImpl will use most implements in base including:
* \note: CompNodeDefaultImpl will use most implements in base including:
* alloc_device, alloc_host, copy_to_host, copy_to_device, peer_copy_to,
* add_callback ...
*/
class CpuCompNode::CompNodeNoRecorderImpl final : public CompNodeBaseImpl {
class CpuCompNode::CompNodeDefaultImpl final : public CompNodeBaseImpl {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
//! ptr to default cpu, only used by check_global_finalized
static CompNodeNoRecorderImpl* sm_default_cpu_comp_node_ptr;
static CompNodeDefaultImpl* sm_default_cpu_comp_node_ptr;
static void static_free_device(ImplBase* self, void* ptr) {
static_cast<CompNodeNoRecorderImpl*>(self)->free_device(ptr);
static_cast<CompNodeDefaultImpl*>(self)->free_device(ptr);
}
static void static_free_host(ImplBase* self, void* ptr) {
static_cast<CompNodeNoRecorderImpl*>(self)->free_host(ptr);
static_cast<CompNodeDefaultImpl*>(self)->free_host(ptr);
}
using CpuEventImpl = CpuDispatchableBase::EventImpl;
CompNodeNoRecorderImpl(const Locator& locator,
CompNodeDefaultImpl(const Locator& locator,
const Locator& locator_logical)
: CompNodeBaseImpl(locator, locator_logical, static_free_device,
static_free_host) {
......@@ -501,7 +501,7 @@ public:
sm_default_cpu_comp_node_ptr = this;
}
~CompNodeNoRecorderImpl() {
~CompNodeDefaultImpl() {
m_env.fini();
sm_default_cpu_comp_node_ptr = nullptr;
}
......@@ -551,8 +551,8 @@ public:
SeqRecorderImpl* cur_recorder() const override { return nullptr; }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeNoRecorderImpl);
CompNodeNoRecorderImpl* CompNodeNoRecorderImpl::sm_default_cpu_comp_node_ptr =
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeDefaultImpl);
CompNodeDefaultImpl* CompNodeDefaultImpl::sm_default_cpu_comp_node_ptr =
nullptr;
//! ==================== CompNodeRecorderImpl ======================
......@@ -746,7 +746,7 @@ public:
void peer_copy_to(Impl* dest_impl, void* dest, const void* src,
size_t size) override {
//! copy to default_cpu
if (dest_impl->same_type<CpuCompNode::CompNodeNoRecorderImpl>()) {
if (dest_impl->same_type<CpuCompNode::CompNodeDefaultImpl>()) {
CompNodeBaseImpl::peer_copy_to(dest_impl, dest, src, size);
return;
}
......@@ -986,7 +986,7 @@ void CpuCompNode::sync_all() {
// CpuCompNode::Pool
CompNode CompNode::default_cpu() {
static Locator locator{DeviceType::CPU, Locator::DEVICE_CPU_DEFAULT, {-1}};
static CompNodeNoRecorderImpl impl{locator, locator};
static CompNodeDefaultImpl impl{locator, locator};
return &impl;
}
......
......@@ -55,7 +55,7 @@ namespace mgb {
};
class CompNodeBaseImpl;
class CompNodeNoRecorderImpl;
class CompNodeDefaultImpl;
class CompNodeRecorderImpl;
static void foreach(thin_function<void(CompNode)> callback);
......
......@@ -11,6 +11,7 @@
#include "./cg_impl_seq.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/opr/tensor_manip.h"
using namespace mgb;
using namespace cg;
......@@ -255,6 +256,22 @@ ComputingGraphImpl::ComputingSequence::check_enable_comp_node_seq_recorder() {
}
}
}
auto check_const_shape = [&]() {
for (auto i : *m_opr_seq) {
for (auto j : i->output()) {
if (j->shape().ndim && !is_const_var_shape(j)) {
mgb_log_warn(
"Non-const var shape detected. Make sure all "
"shapes are constant. Check whether "
"'const_var_shape' is set "
"in GraphLoadConfig under record mode");
return;
}
}
}
};
check_const_shape();
auto cn = *m_used_comp_node.begin();
auto rec = cn.create_seq_recorder(m_owner_graph);
if (!rec) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册