提交 40bab1ed 编写于 作者: M Megvii Engine Team

feat(log): opt log, enable mgb sdk log at opt build

more info: 16cd674c56

* change MGE_OVERRIDE_LOG_LEVEL to RUNTIME_OVERRIDE_LOG_LEVEL
* use ::std::getenv not MGB_GETENV for special ENV

GitOrigin-RevId: ee0f9c0f72e627c331c00100f6a21adc927081df
上级 c3a1ac3d
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#ifdef __ANDROID__ #ifdef __ANDROID__
#include <android/log.h> #include <android/log.h>
#include <sys/system_properties.h>
#endif #endif
using namespace mgb; using namespace mgb;
...@@ -32,11 +33,19 @@ LogLevel config_default_log_level() { ...@@ -32,11 +33,19 @@ LogLevel config_default_log_level() {
auto default_level = LogLevel::ERROR; auto default_level = LogLevel::ERROR;
//! env to config LogLevel //! env to config LogLevel
//! DEBUG = 0, INFO = 1, WARN = 2, ERROR = 3, NO_LOG = 4 //! DEBUG = 0, INFO = 1, WARN = 2, ERROR = 3, NO_LOG = 4
//! for example , export MGE_OVERRIDE_LOG_LEVEL=0, means set LogLevel to //! for example , export RUNTIME_OVERRIDE_LOG_LEVEL=0, means set LogLevel to
//! DEBUG //! DEBUG
if (auto env = MGB_GETENV("MGE_OVERRIDE_LOG_LEVEL")) if (auto env = ::std::getenv("RUNTIME_OVERRIDE_LOG_LEVEL"))
default_level = static_cast<LogLevel>(std::stoi(env)); default_level = static_cast<LogLevel>(std::stoi(env));
#ifdef __ANDROID__
//! special for Android prop, attention: getprop may need permission
char buf[PROP_VALUE_MAX];
if (__system_property_get("RUNTIME_OVERRIDE_LOG_LEVEL", buf) > 0) {
default_level = static_cast<LogLevel>(atoi(buf));
}
#endif
return default_level; return default_level;
} }
...@@ -155,7 +164,7 @@ void default_log_handler(LogLevel level, ...@@ -155,7 +164,7 @@ void default_log_handler(LogLevel level,
default: default:
android_level = ANDROID_LOG_ERROR; android_level = ANDROID_LOG_ERROR;
} }
__android_log_vprint(android_level, "megbrain", fmt, ap); __android_log_vprint(android_level, "runtime", fmt, ap);
#endif #endif
#undef HDR_FMT #undef HDR_FMT
...@@ -185,7 +194,7 @@ class MegDNNLogHandler { ...@@ -185,7 +194,7 @@ class MegDNNLogHandler {
return; return;
} }
std::string new_fmt{"[megdnn] "}; std::string new_fmt{"[dnn] "};
new_fmt.append(fmt); new_fmt.append(fmt);
log_handler(mgb_level, file, func, line, new_fmt.c_str(), ap); log_handler(mgb_level, file, func, line, new_fmt.c_str(), ap);
} }
...@@ -238,9 +247,17 @@ namespace { ...@@ -238,9 +247,17 @@ namespace {
#endif // MGB_ENABLE_LOGGING #endif // MGB_ENABLE_LOGGING
LogLevel mgb::set_log_level(LogLevel level) { LogLevel mgb::set_log_level(LogLevel level) {
if (auto env = MGB_GETENV("MGE_OVERRIDE_LOG_LEVEL")) if (auto env = ::std::getenv("RUNTIME_OVERRIDE_LOG_LEVEL"))
level = static_cast<LogLevel>(std::stoi(env)); level = static_cast<LogLevel>(std::stoi(env));
#ifdef __ANDROID__
//! special for Android prop, attention: getprop may need permission
char buf[PROP_VALUE_MAX];
if (__system_property_get("RUNTIME_OVERRIDE_LOG_LEVEL", buf) > 0) {
level = static_cast<LogLevel>(atoi(buf));
}
#endif
auto ret = min_log_level; auto ret = min_log_level;
min_log_level = level; min_log_level = level;
return ret; return ret;
...@@ -256,7 +273,6 @@ LogHandler mgb::set_log_handler(LogHandler handler) { ...@@ -256,7 +273,6 @@ LogHandler mgb::set_log_handler(LogHandler handler) {
return ret; return ret;
} }
#if MGB_ASSERT_LOC
void mgb::__assert_fail__( void mgb::__assert_fail__(
const char *file, int line, const char *func, const char *file, int line, const char *func,
const char *expr, const char *msg_fmt, ...) { const char *expr, const char *msg_fmt, ...) {
...@@ -273,11 +289,6 @@ void mgb::__assert_fail__( ...@@ -273,11 +289,6 @@ void mgb::__assert_fail__(
} }
mgb_throw_raw(AssertionError{msg}); mgb_throw_raw(AssertionError{msg});
} }
#else
void mgb::__assert_fail__() {
mgb_throw(AssertionError, "assertion failed");
}
#endif
#if MGB_ENABLE_LOGGING && !MGB_ENABLE_EXCEPTION #if MGB_ENABLE_LOGGING && !MGB_ENABLE_EXCEPTION
void mgb::__on_exception_throw__(const std::exception &exc) { void mgb::__on_exception_throw__(const std::exception &exc) {
......
...@@ -759,7 +759,7 @@ public: ...@@ -759,7 +759,7 @@ public:
#else #else
mgb_throw(MegBrainError, mgb_throw(MegBrainError,
"Atlas comp_node used but " "Atlas comp_node used but "
"MGB_ATLAS not enabled"); "ATLAS BUILD not enabled");
#endif #endif
} else if (dest_impl->env().property().type == } else if (dest_impl->env().property().type ==
DeviceType::CAMBRICON) { DeviceType::CAMBRICON) {
...@@ -769,7 +769,7 @@ public: ...@@ -769,7 +769,7 @@ public:
#else #else
mgb_throw(MegBrainError, mgb_throw(MegBrainError,
"Cambricon comp_node used but " "Cambricon comp_node used but "
"MGB_CAMBRICON not enabled"); "CAMBRICON BUILD not enabled");
#endif #endif
} }
else { else {
...@@ -1035,7 +1035,7 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by( ...@@ -1035,7 +1035,7 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(
return m_comp_node_impl->sync(); return m_comp_node_impl->sync();
#else #else
mgb_throw(MegBrainError, mgb_throw(MegBrainError,
"Atlas comp_node used but MGB_ATLAS not enabled"); "Atlas comp_node used but ATLAS BUILD not enabled");
#endif #endif
} else if (cn_impl->env().property().type == } else if (cn_impl->env().property().type ==
CompNode::DeviceType::CAMBRICON) { CompNode::DeviceType::CAMBRICON) {
......
...@@ -51,14 +51,13 @@ MegDNNHandle& MegDNNHandle::get(const CompNodeEnv& env) { ...@@ -51,14 +51,13 @@ MegDNNHandle& MegDNNHandle::get(const CompNodeEnv& env) {
MegDNNHandle::MegDNNHandle(const CompNodeEnv& env) { MegDNNHandle::MegDNNHandle(const CompNodeEnv& env) {
auto megdnn_version = megdnn::get_version(); auto megdnn_version = megdnn::get_version();
mgb_throw_if( mgb_throw_if(megdnn_version.major != MEGDNN_MAJOR ||
megdnn_version.major != MEGDNN_MAJOR || megdnn_version.minor < MEGDNN_MINOR,
megdnn_version.minor < MEGDNN_MINOR, SystemError,
SystemError, "incompatible dnn version: compiled with %d.%d, get %d.%d.%d "
"incompatible megdnn version: compiled with %d.%d, get %d.%d.%d " "at runtime",
"at runtime", MEGDNN_MAJOR, MEGDNN_MINOR, megdnn_version.major,
MEGDNN_MAJOR, MEGDNN_MINOR, megdnn_version.major, megdnn_version.minor, megdnn_version.patch);
megdnn_version.minor, megdnn_version.patch);
bool init = false; bool init = false;
#if MGB_CUDA #if MGB_CUDA
if (env.property().type == CompNode::DeviceType::CUDA) { if (env.property().type == CompNode::DeviceType::CUDA) {
......
...@@ -880,7 +880,7 @@ std::string ComputingGraphImpl::get_mem_allocation_info() const { ...@@ -880,7 +880,7 @@ std::string ComputingGraphImpl::get_mem_allocation_info() const {
return objlist->to_string(); return objlist->to_string();
#endif // MGB_ENABLE_JSON #endif // MGB_ENABLE_JSON
mgb_log_warn("mgb is not configured with MGB_ENABLE_JSON on," mgb_log_warn("target is not configured with JSON BUILD on,"
"get_mem_allocation_info returns null string"); "get_mem_allocation_info returns null string");
return std::string(); return std::string();
} }
......
...@@ -619,7 +619,7 @@ void ComputingGraphImpl::MegDNNDtorCheck::enable() { ...@@ -619,7 +619,7 @@ void ComputingGraphImpl::MegDNNDtorCheck::enable() {
mgb_assert(!m_enabled); mgb_assert(!m_enabled);
m_enabled = true; m_enabled = true;
auto cb_dnn = [](megdnn::OperatorBase* opr) { auto cb_dnn = [](megdnn::OperatorBase* opr) {
mgb_log_error("unexpected destruction of megdnn opr %p", opr); mgb_log_error("unexpected destruction of dnn opr %p", opr);
mgb_trap(); mgb_trap();
}; };
auto cb_mem = [](size_t alloc_size, bool, void* ptr) { auto cb_mem = [](size_t alloc_size, bool, void* ptr) {
......
...@@ -108,34 +108,33 @@ void __on_exception_throw__(const std::exception &exc) ...@@ -108,34 +108,33 @@ void __on_exception_throw__(const std::exception &exc)
} while(0) } while(0)
// assert // assert
void __assert_fail__(const char* file, int line, const char* func,
const char* expr, const char* msg_fmt = 0, ...)
__attribute__((format(printf, 5, 6), noreturn));
#if MGB_ASSERT_LOC #if MGB_ASSERT_LOC
/*! /*!
* \brief extended assert * \brief extended assert
* extra diagnostics message (in printf format) could be printed when assertion * extra diagnostics message (in printf format) could be printed when assertion
* fails; the asserted expression is guaranteed to be evaluated * fails; the asserted expression is guaranteed to be evaluated
*/ */
#define mgb_assert(expr, msg...) \ #define mgb_assert(expr, msg...) \
do { \ do { \
if (mgb_unlikely(!(expr))) \ if (mgb_unlikely(!(expr))) \
::mgb::__assert_fail__(__FILE__, __LINE__, \ ::mgb::__assert_fail__(__FILE__, __LINE__, __PRETTY_FUNCTION__, \
__PRETTY_FUNCTION__, # expr, ##msg); \ #expr, ##msg); \
} while(0) } while (0)
void __assert_fail__(
const char *file, int line, const char *func,
const char *expr, const char *msg_fmt = 0, ...)
__attribute__((format(printf, 5, 6), noreturn));
#else #else
#define mgb_assert(expr, msg...) \ #define mgb_assert(expr, msg...) \
do { \ do { \
if (mgb_unlikely(!(expr))) \ if (mgb_unlikely(!(expr))) \
::mgb::__assert_fail__(); \ ::mgb::__assert_fail__( \
} while(0) "about location info, please build with debug", __LINE__, \
void __assert_fail__() __attribute__((noreturn)); NULL, #expr, ##msg); \
#endif // MGB_ASSERT_LOC } while (0)
#endif // MGB_ASSERT_LOC
/* ================ logging ================ */ /* ================ logging ================ */
//! caused by need remve some words at opt release #if MGB_ASSERT_LOC
#if MGB_ENABLE_LOGGING
#define mgb_log_debug(fmt...) \ #define mgb_log_debug(fmt...) \
_mgb_do_log(::mgb::LogLevel::DEBUG, __FILE__, __func__, __LINE__, fmt) _mgb_do_log(::mgb::LogLevel::DEBUG, __FILE__, __func__, __LINE__, fmt)
#define mgb_log(fmt...) \ #define mgb_log(fmt...) \
...@@ -154,7 +153,6 @@ void __assert_fail__() __attribute__((noreturn)); ...@@ -154,7 +153,6 @@ void __assert_fail__() __attribute__((noreturn));
_mgb_do_log(::mgb::LogLevel::WARN, "", "", __LINE__, fmt) _mgb_do_log(::mgb::LogLevel::WARN, "", "", __LINE__, fmt)
#define mgb_log_error(fmt...) \ #define mgb_log_error(fmt...) \
_mgb_do_log(::mgb::LogLevel::ERROR, LOC, "", __LINE__, fmt) _mgb_do_log(::mgb::LogLevel::ERROR, LOC, "", __LINE__, fmt)
#undef LOC
#endif #endif
enum class LogLevel { DEBUG, INFO, WARN, ERROR, NO_LOG }; enum class LogLevel { DEBUG, INFO, WARN, ERROR, NO_LOG };
......
...@@ -1045,7 +1045,8 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1045,7 +1045,8 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
INTER_WEIGHT_DENSEI_DOT; INTER_WEIGHT_DENSEI_DOT;
return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_DENSEI; return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_DENSEI;
} else { } else {
mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); mgb_throw_if(conv_mode != megdnn::param::Convolution::Sparse::GROUP,
MegBrainError, "mode error");
if (filter->shape()[1] == 1 && filter->shape()[2] == 1) { if (filter->shape()[1] == 1 && filter->shape()[2] == 1) {
return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_CHANI; return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_CHANI;
} else { } else {
...@@ -1081,9 +1082,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1081,9 +1082,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
mgb_assert(conv_opr.param().format == mgb_throw_if(
megdnn::param::Convolution::Format::NCHW, conv_opr.param().format !=
"ConvertFormat Pass only support converting NCHW to NHWCD4"); megdnn::param::Convolution::Format::NCHW,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NHWCD4");
VarNode *conv_src = nullptr, *conv_weights = nullptr; VarNode *conv_src = nullptr, *conv_weights = nullptr;
if (new_inp[0]->shape().ndim == 4) { if (new_inp[0]->shape().ndim == 4) {
// new input src is NCHW // new input src is NCHW
...@@ -1094,8 +1097,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1094,8 +1097,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
icpg = new_inp[1]->shape()[1]; icpg = new_inp[1]->shape()[1];
ocpg = new_inp[1]->shape()[0]; ocpg = new_inp[1]->shape()[0];
} else { } else {
mgb_assert(conv_opr.param().sparse == mgb_throw_if(conv_opr.param().sparse !=
megdnn::param::Convolution::Sparse::GROUP); megdnn::param::Convolution::Sparse::GROUP,
MegBrainError, "ERROR mode");
group = new_inp[1]->shape()[0]; group = new_inp[1]->shape()[0];
icpg = new_inp[1]->shape()[2]; icpg = new_inp[1]->shape()[2];
ocpg = new_inp[1]->shape()[1]; ocpg = new_inp[1]->shape()[1];
...@@ -1117,8 +1121,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1117,8 +1121,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
megdnn::param::Convolution::Sparse::DENSE) { megdnn::param::Convolution::Sparse::DENSE) {
ocpg = new_inp[1]->shape()[0]; ocpg = new_inp[1]->shape()[0];
} else { } else {
mgb_assert(conv_opr.param().sparse == mgb_throw_if(conv_opr.param().sparse !=
megdnn::param::Convolution::Sparse::GROUP); megdnn::param::Convolution::Sparse::GROUP,
MegBrainError, "ERROR mode");
size_t icpg = new_inp[1]->shape()[2]; size_t icpg = new_inp[1]->shape()[2];
ocpg = new_inp[1]->shape()[1]; ocpg = new_inp[1]->shape()[1];
if (icpg == 1 && ocpg == 1) { if (icpg == 1 && ocpg == 1) {
...@@ -1176,9 +1181,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1176,9 +1181,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
mgb_assert(conv_bias_opr.param().format == mgb_throw_if(
megdnn::param::ConvBias::Format::NCHW, conv_bias_opr.param().format !=
"ConvertFormat Pass only support converting NCHW to NHWCD4"); megdnn::param::ConvBias::Format::NCHW,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NHWCD4");
VarNode *conv_bias_src = nullptr, *conv_bias_weights = nullptr, VarNode *conv_bias_src = nullptr, *conv_bias_weights = nullptr,
*conv_bias_bias = nullptr; *conv_bias_bias = nullptr;
if (new_inp[0]->shape().ndim == 4) { if (new_inp[0]->shape().ndim == 4) {
...@@ -1190,8 +1197,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1190,8 +1197,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
icpg = new_inp[1]->shape()[1]; icpg = new_inp[1]->shape()[1];
ocpg = new_inp[1]->shape()[0]; ocpg = new_inp[1]->shape()[0];
} else { } else {
mgb_assert(conv_bias_opr.param().sparse == mgb_throw_if(conv_bias_opr.param().sparse !=
megdnn::param::ConvBias::Sparse::GROUP); megdnn::param::ConvBias::Sparse::GROUP,
MegBrainError, "mode error");
group = new_inp[1]->shape()[0]; group = new_inp[1]->shape()[0];
icpg = new_inp[1]->shape()[2]; icpg = new_inp[1]->shape()[2];
ocpg = new_inp[1]->shape()[1]; ocpg = new_inp[1]->shape()[1];
...@@ -1213,8 +1221,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1213,8 +1221,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
megdnn::param::ConvBias::Sparse::DENSE) { megdnn::param::ConvBias::Sparse::DENSE) {
ocpg = new_inp[1]->shape()[0]; ocpg = new_inp[1]->shape()[0];
} else { } else {
mgb_assert(conv_bias_opr.param().sparse == mgb_throw_if(conv_bias_opr.param().sparse !=
megdnn::param::ConvBias::Sparse::GROUP); megdnn::param::ConvBias::Sparse::GROUP,
MegBrainError, "ERROR mode");
size_t icpg = new_inp[1]->shape()[2]; size_t icpg = new_inp[1]->shape()[2];
ocpg = new_inp[1]->shape()[1]; ocpg = new_inp[1]->shape()[1];
if (icpg == 1 && ocpg == 1) { if (icpg == 1 && ocpg == 1) {
...@@ -1293,9 +1302,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1293,9 +1302,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>(); auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>();
mgb_assert(deconv_opr.param().format == mgb_throw_if(
megdnn::param::Convolution::Format::NCHW, deconv_opr.param().format !=
"ConvertFormat Pass only support converting NCHW to NHWCD4"); megdnn::param::Convolution::Format::NCHW,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NHWCD4");
VarNode *deconv_src = nullptr, *deconv_weights = nullptr; VarNode *deconv_src = nullptr, *deconv_weights = nullptr;
if (new_inp[1]->shape().ndim == 4) { if (new_inp[1]->shape().ndim == 4) {
// new input src is NCHW // new input src is NCHW
...@@ -1306,8 +1317,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1306,8 +1317,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
icpg = new_inp[0]->shape()[0]; icpg = new_inp[0]->shape()[0];
ocpg = new_inp[0]->shape()[1]; ocpg = new_inp[0]->shape()[1];
} else { } else {
mgb_assert(deconv_opr.param().sparse == mgb_throw_if(deconv_opr.param().sparse !=
megdnn::param::Convolution::Sparse::GROUP); megdnn::param::Convolution::Sparse::GROUP,
MegBrainError, "mode error");
group = new_inp[0]->shape()[0]; group = new_inp[0]->shape()[0];
icpg = new_inp[0]->shape()[1]; icpg = new_inp[0]->shape()[1];
ocpg = new_inp[0]->shape()[2]; ocpg = new_inp[0]->shape()[2];
...@@ -1329,8 +1341,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1329,8 +1341,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
megdnn::param::Convolution::Sparse::DENSE) { megdnn::param::Convolution::Sparse::DENSE) {
ocpg = new_inp[0]->shape()[1]; ocpg = new_inp[0]->shape()[1];
} else { } else {
mgb_assert(deconv_opr.param().sparse == mgb_throw_if(deconv_opr.param().sparse !=
megdnn::param::Convolution::Sparse::GROUP); megdnn::param::Convolution::Sparse::GROUP,
MegBrainError, "mode error");
ocpg = new_inp[0]->shape()[2]; ocpg = new_inp[0]->shape()[2];
} }
...@@ -1393,9 +1406,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1393,9 +1406,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
return opr_shallow_copy; return opr_shallow_copy;
} }
auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>(); auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>();
mgb_assert(resize_opr.param().format == mgb_throw_if(
megdnn::param::Resize::Format::NCHW, resize_opr.param().format !=
"ConvertFormat Pass only support converting NCHW to NHWCD4"); megdnn::param::Resize::Format::NCHW,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NHWCD4");
VarNode* inp = nullptr; VarNode* inp = nullptr;
if (new_inp[0]->shape().ndim == 4) { if (new_inp[0]->shape().ndim == 4) {
auto param = megdnn::param::RelayoutFormat(); auto param = megdnn::param::RelayoutFormat();
...@@ -1425,9 +1440,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1425,9 +1440,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
return opr_shallow_copy; return opr_shallow_copy;
} }
auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>(); auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>();
mgb_assert(warp_opr.param().format == mgb_throw_if(
megdnn::param::WarpPerspective::Format::NCHW, warp_opr.param().format !=
"ConvertFormat Pass only support converting NCHW to NHWCD4"); megdnn::param::WarpPerspective::Format::NCHW,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NHWCD4");
VarNode* inp = nullptr; VarNode* inp = nullptr;
if (new_inp[0]->shape().ndim == 4) { if (new_inp[0]->shape().ndim == 4) {
// new input src is NCHW // new input src is NCHW
...@@ -1466,9 +1483,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1466,9 +1483,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
return opr_shallow_copy; return opr_shallow_copy;
} }
auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>(); auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>();
mgb_assert(warp_opr.param().format == mgb_throw_if(
megdnn::param::WarpAffine::Format::NCHW, warp_opr.param().format !=
"ConvertFormat Pass only support converting NCHW to NHWCD4"); megdnn::param::WarpAffine::Format::NCHW,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NHWCD4");
VarNode* inp = nullptr; VarNode* inp = nullptr;
if (new_inp[0]->shape().ndim == 4) { if (new_inp[0]->shape().ndim == 4) {
// new input src is NCHW // new input src is NCHW
...@@ -1499,9 +1518,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1499,9 +1518,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
return opr_shallow_copy; return opr_shallow_copy;
} }
auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>(); auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
mgb_assert(pooling_opr.param().format == mgb_throw_if(
megdnn::param::Pooling::Format::NCHW, pooling_opr.param().format !=
"ConvertFormat Pass only support converting NCHW to NHWCD4"); megdnn::param::Pooling::Format::NCHW,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NHWCD4");
VarNode* inp = nullptr; VarNode* inp = nullptr;
if (new_inp[0]->shape().ndim == 4) { if (new_inp[0]->shape().ndim == 4) {
// new input src is NCHW // new input src is NCHW
......
...@@ -1465,7 +1465,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { ...@@ -1465,7 +1465,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
return {weight_to_nchw4_mode_dense, src_to_nchw4_mode}; return {weight_to_nchw4_mode_dense, src_to_nchw4_mode};
} }
} else { } else {
mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); mgb_throw_if(conv_mode != megdnn::param::Convolution::Sparse::GROUP,
MegBrainError, "mode error");
mgb_assert(filter->shape().ndim == 5, mgb_assert(filter->shape().ndim == 5,
"The origin filter if not NCHW mode"); "The origin filter if not NCHW mode");
size_t IC = filter->shape()[2]; size_t IC = filter->shape()[2];
...@@ -2018,7 +2019,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { ...@@ -2018,7 +2019,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
ret.second = hybrid_nchw_nchwxx; ret.second = hybrid_nchw_nchwxx;
} }
} else { } else {
mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); mgb_throw_if(conv_mode != megdnn::param::Convolution::Sparse::GROUP,
MegBrainError, "mode error");
size_t group = filter->shape()[0]; size_t group = filter->shape()[0];
size_t ocpg = filter->shape()[1]; size_t ocpg = filter->shape()[1];
size_t icpg = filter->shape()[2]; size_t icpg = filter->shape()[2];
...@@ -2038,9 +2040,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { ...@@ -2038,9 +2040,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
mgb_assert(conv_opr.param().format == mgb_throw_if(
megdnn::param::Convolution::Format::NCHW, conv_opr.param().format !=
"ConvertFormat Pass only support converting NCHW to NCHWXX"); megdnn::param::Convolution::Format::NCHW,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NCHWXX");
bool valid_nchw_nchw44 = bool valid_nchw_nchw44 =
nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size); nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size);
auto is_trans = test_trans_nchwxx( auto is_trans = test_trans_nchwxx(
...@@ -2118,9 +2122,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { ...@@ -2118,9 +2122,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
mgb_assert(opr->input().size() <= 3, mgb_assert(opr->input().size() <= 3,
"nchwxx does not support conv_bias fuse Z right now"); "nchwxx does not support conv_bias fuse Z right now");
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
mgb_assert(conv_bias_opr.param().format == mgb_throw_if(
megdnn::param::ConvBias::Format::NCHW, conv_bias_opr.param().format !=
"ConvertFormat Pass only support converting NCHW to NCHWXX"); megdnn::param::ConvBias::Format::NCHW,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NCHWXX");
bool valid_nchw_nchw44 = bool valid_nchw_nchw44 =
nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size, nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size,
conv_bias_opr.param().nonlineMode); conv_bias_opr.param().nonlineMode);
...@@ -2244,9 +2250,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { ...@@ -2244,9 +2250,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>(); auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>();
mgb_assert(pooling_opr.param().format == mgb_throw_if(
megdnn::param::Pooling::Format::NCHW, pooling_opr.param().format !=
"ConvertFormat Pass only support converting NCHW to NCHWxx"); megdnn::param::Pooling::Format::NCHW,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NCHWxx");
VarNode* inp = new_inp[0]; VarNode* inp = new_inp[0];
//! if input is nchwxx //! if input is nchwxx
if (inp->shape().ndim == 5) { if (inp->shape().ndim == 5) {
...@@ -2433,7 +2441,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2433,7 +2441,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
} }
} }
} else { } else {
mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); mgb_throw_if(conv_mode != megdnn::param::Convolution::Sparse::GROUP,
MegBrainError, "mode error");
size_t group = filter->shape()[0]; size_t group = filter->shape()[0];
size_t ocpg = filter->shape()[1]; size_t ocpg = filter->shape()[1];
size_t icpg = filter->shape()[2]; size_t icpg = filter->shape()[2];
...@@ -2462,10 +2471,11 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2462,10 +2471,11 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
mgb_assert(conv_opr.param().format == mgb_throw_if(conv_opr.param().format !=
megdnn::param::Convolution::Format::NCHW, megdnn::param::Convolution::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to " MegBrainError,
"NCHW44_DOT"); "ConvertFormat Pass only support converting NCHW to "
"NCHW44_DOT");
bool valid_nchw_nchw44 = nchw_nchwxx_valid( bool valid_nchw_nchw44 = nchw_nchwxx_valid(
conv_opr, new_inp, pack_c_size, conv_opr, new_inp, pack_c_size,
megdnn::param::ConvBias::NonlineMode::IDENTITY, true); megdnn::param::ConvBias::NonlineMode::IDENTITY, true);
...@@ -2543,9 +2553,11 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2543,9 +2553,11 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
mgb_assert(opr->input().size() <= 3, mgb_assert(opr->input().size() <= 3,
"nchwxx-dot does not support conv_bias fuse Z right now"); "nchwxx-dot does not support conv_bias fuse Z right now");
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
mgb_assert(conv_bias_opr.param().format == mgb_throw_if(
megdnn::param::ConvBias::Format::NCHW, conv_bias_opr.param().format !=
"ConvertFormat Pass only support converting NCHW to NCHWXX"); megdnn::param::ConvBias::Format::NCHW,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NCHWXX");
bool valid_nchw_nchw44 = bool valid_nchw_nchw44 =
nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size, nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size,
conv_bias_opr.param().nonlineMode, true); conv_bias_opr.param().nonlineMode, true);
......
...@@ -127,7 +127,7 @@ ...@@ -127,7 +127,7 @@
// whether to enbale configuing megbrain internals through env vars // whether to enbale configuing megbrain internals through env vars
#ifndef MGB_ENABLE_GETENV #ifndef MGB_ENABLE_GETENV
#define MGB_ENABLE_GETENV 1 #define MGB_ENABLE_GETENV MGB_ASSERT_LOC
#endif #endif
// whether to remove unnecessary features when used for serving // whether to remove unnecessary features when used for serving
......
...@@ -343,24 +343,24 @@ void Elemwise::mem_plan_fwd_in2out_writable() { ...@@ -343,24 +343,24 @@ void Elemwise::mem_plan_fwd_in2out_writable() {
} }
void Elemwise::scn_do_execute() { void Elemwise::scn_do_execute() {
auto &&inp = input(); auto&& inp = input();
megdnn::TensorNDArray megdnn_inp; megdnn::TensorNDArray dnn_inp;
mgb_assert(megdnn_inp.capacity() >= inp.size(), mgb_assert(dnn_inp.capacity() >= inp.size(),
"heap allocation in elemwise exec"); "heap allocation in elemwise exec");
megdnn_inp.resize(inp.size()); dnn_inp.resize(inp.size());
for (size_t i = 0; i < inp.size(); ++ i) { for (size_t i = 0; i < inp.size(); ++i) {
if (inp[i]->dev_tensor().empty()) { if (inp[i]->dev_tensor().empty()) {
mgb_assert(output(0)->dev_tensor().empty()); mgb_assert(output(0)->dev_tensor().empty());
return; return;
} }
megdnn_inp[i] = (inp[i]->dev_tensor().as_megdnn()); dnn_inp[i] = (inp[i]->dev_tensor().as_megdnn());
} }
mgb_assert(!output(0)->dev_tensor().empty()); mgb_assert(!output(0)->dev_tensor().empty());
megdnn_opr()->param() = param(); megdnn_opr()->param() = param();
call_megdnn_opr_exec( call_megdnn_opr_exec(comp_node(), dnn_inp,
comp_node(), megdnn_inp, output(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), megdnn_opr(),
megdnn_opr(), this); this);
} }
void Elemwise::init_output_static_infer_desc() { void Elemwise::init_output_static_infer_desc() {
......
...@@ -126,10 +126,11 @@ namespace serialization { ...@@ -126,10 +126,11 @@ namespace serialization {
MGB_MARK_USED_VAR(graph); MGB_MARK_USED_VAR(graph);
SymbolVar target_shape; SymbolVar target_shape;
if (inputs.size() == 1) { if (inputs.size() == 1) {
mgb_assert(param.axis >= mgb_throw_if(
-megdnn::param::OptionalAxisV1::MAX_NDIM && param.axis < -megdnn::param::OptionalAxisV1::MAX_NDIM ||
param.axis < param.axis >=
megdnn::param::OptionalAxisV1::MAX_NDIM); megdnn::param::OptionalAxisV1::MAX_NDIM,
MegBrainError, "DIM error");
} else { } else {
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() == 2);
target_shape = inputs[1]; target_shape = inputs[1];
......
...@@ -470,9 +470,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(SVD); ...@@ -470,9 +470,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(SVD);
SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) : SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) :
Super(OperatorNodeBaseCtorParam{src->owner_graph(), Super(OperatorNodeBaseCtorParam{src->owner_graph(),
config, "svd", {src}}) { config, "svd", {src}}) {
mgb_assert(src->dtype() == megdnn::dtype::Float32(), mgb_throw_if(src->dtype() != megdnn::dtype::Float32(), MegDNNError,
"Singular Value Decomposition on non-float32 tensors is " "Singular Value Decomposition on non-float32 tensors is not "
"not supoorted."); "supoorted.");
init_megdnn_opr(*this, param); init_megdnn_opr(*this, param);
add_input({src}); add_input({src});
......
...@@ -187,12 +187,12 @@ template<class Opr> ...@@ -187,12 +187,12 @@ template<class Opr>
Opr& mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::megdnn_opr( Opr& mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::megdnn_opr(
cg::SingleCNOperatorNodeBase& self) { cg::SingleCNOperatorNodeBase& self) {
auto comp_node = self.comp_node(); auto comp_node = self.comp_node();
if (!m_megdnn_opr || m_megdnn_opr.comp_node() != comp_node) { if (!m_dnn_opr || m_dnn_opr.comp_node() != comp_node) {
m_megdnn_opr = intl::create_megdnn_opr<Opr>(comp_node); m_dnn_opr = intl::create_megdnn_opr<Opr>(comp_node);
m_megdnn_opr->set_error_tracker( m_dnn_opr->set_error_tracker(
static_cast<cg::OperatorNodeBase*>(&self)); static_cast<cg::OperatorNodeBase*>(&self));
} }
return *m_megdnn_opr; return *m_dnn_opr;
} }
template<class Opr> template<class Opr>
...@@ -228,7 +228,7 @@ template <class Opr> ...@@ -228,7 +228,7 @@ template <class Opr>
void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::record_megdnn_opr( void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::record_megdnn_opr(
mgb::cg::GraphExecutable::ExecDependencyArray& deps) { mgb::cg::GraphExecutable::ExecDependencyArray& deps) {
deps.emplace_back( deps.emplace_back(
std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); std::make_unique<intl::MegDNNGraphDep>(std::move(m_dnn_opr)));
} }
/* ==================== MultiAxisVecFancyIndexingHelper ==================== */ /* ==================== MultiAxisVecFancyIndexingHelper ==================== */
...@@ -258,14 +258,24 @@ intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc( ...@@ -258,14 +258,24 @@ intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc(
} }
} }
if (all_scalar) { if (all_scalar) {
mgb_log_warn("%s{%s}: no vector indexer; consider using Subtensor " #if MGB_ENABLE_GETENV
mgb_log_warn(
"%s{%s}: no vector indexer; consider using Subtensor "
"family for better performance; you can set " "family for better performance; you can set "
"MGB_THROW_ON_SCALAR_IDX to throw an exception to help " "MGB_THROW_ON_SCALAR_IDX to throw an exception to help "
"tracking the related operator", "tracking the related operator",
cname(), dyn_typeinfo()->name); cname(), dyn_typeinfo()->name);
mgb_throw_if(MGB_GETENV("MGB_THROW_ON_SCALAR_IDX"), #else
MegBrainError, "vector-indexing operator used with all " mgb_log_warn(
"scalar indices"); "%s{%s}: no vector indexer; consider using Subtensor "
"family for better performance",
cname(), dyn_typeinfo()->name);
#endif
#if MGB_ENABLE_GETENV
mgb_throw_if(MGB_GETENV("MGB_THROW_ON_SCALAR_IDX"), MegBrainError,
"vector-indexing operator used with all "
"scalar indices");
#endif
} }
// always set m_scalar_idx_warn_printed to be true, so we do not print // always set m_scalar_idx_warn_printed to be true, so we do not print
......
...@@ -377,21 +377,21 @@ MegDNNOprHolder::~MegDNNOprHolder() noexcept = default; ...@@ -377,21 +377,21 @@ MegDNNOprHolder::~MegDNNOprHolder() noexcept = default;
void MegDNNOprHolder::mixin_init_output_comp_node(OperatorNodeBase &self) { void MegDNNOprHolder::mixin_init_output_comp_node(OperatorNodeBase &self) {
SingleCNOperatorNode::mixin_init_output_comp_node(self); SingleCNOperatorNode::mixin_init_output_comp_node(self);
create_megdnn_opr(); create_megdnn_opr();
mgb_assert(m_megdnn_opr); mgb_assert(m_dnn_opr);
m_megdnn_opr->set_error_tracker(&self); m_dnn_opr->set_error_tracker(&self);
} }
void MegDNNOprHolder::mixin_on_output_comp_node_stream_changed( void MegDNNOprHolder::mixin_on_output_comp_node_stream_changed(
OperatorNodeBase &self) { OperatorNodeBase &self) {
SingleCNOperatorNode::mixin_on_output_comp_node_stream_changed(self); SingleCNOperatorNode::mixin_on_output_comp_node_stream_changed(self);
create_megdnn_opr(); create_megdnn_opr();
mgb_assert(m_megdnn_opr); mgb_assert(m_dnn_opr);
m_megdnn_opr->set_error_tracker(&self); m_dnn_opr->set_error_tracker(&self);
} }
void MegDNNOprHolder::set_megdnn_opr( void MegDNNOprHolder::set_megdnn_opr(
std::unique_ptr<megdnn::OperatorBase> self) { std::unique_ptr<megdnn::OperatorBase> self) {
m_megdnn_opr = std::move(self); m_dnn_opr = std::move(self);
} }
void MegDNNOprHolder::record_megdnn_opr( void MegDNNOprHolder::record_megdnn_opr(
...@@ -402,7 +402,7 @@ void MegDNNOprHolder::record_megdnn_opr( ...@@ -402,7 +402,7 @@ void MegDNNOprHolder::record_megdnn_opr(
void MegDNNOprHolder::record_megdnn_opr( void MegDNNOprHolder::record_megdnn_opr(
cg::GraphExecutable::ExecDependencyArray& deps) { cg::GraphExecutable::ExecDependencyArray& deps) {
record_megdnn_opr(std::move(m_megdnn_opr), deps); record_megdnn_opr(std::move(m_dnn_opr), deps);
} }
/* ================== MegDNNOprHolderBwdStaticInfer ================== */ /* ================== MegDNNOprHolderBwdStaticInfer ================== */
......
...@@ -59,10 +59,10 @@ cg::OperatorNodeBase::NodeProp* RNGOprBase::do_make_node_prop() const { ...@@ -59,10 +59,10 @@ cg::OperatorNodeBase::NodeProp* RNGOprBase::do_make_node_prop() const {
} }
void RNGOprBase::ensure_megdnn_opr() { void RNGOprBase::ensure_megdnn_opr() {
if (!m_megdnn_opr || m_megdnn_opr.comp_node() != comp_node()) { if (!m_dnn_opr || m_dnn_opr.comp_node() != comp_node()) {
// activate comp_node for curandCreateGenerator in create_megdnn_opr // activate comp_node for curandCreateGenerator in create_megdnn_opr
comp_node().activate(); comp_node().activate();
m_megdnn_opr = create_megdnn_opr(); m_dnn_opr = create_megdnn_opr();
} }
} }
...@@ -76,7 +76,7 @@ void RNGOprBase::init_output_static_infer_desc() { ...@@ -76,7 +76,7 @@ void RNGOprBase::init_output_static_infer_desc() {
auto infer_wk = [this](TensorShape &dest, const InpVal &inp) { auto infer_wk = [this](TensorShape &dest, const InpVal &inp) {
ensure_megdnn_opr(); ensure_megdnn_opr();
dest.ndim = 1; dest.ndim = 1;
dest.shape[0] = m_megdnn_opr->get_workspace_in_bytes( dest.shape[0] = m_dnn_opr->get_workspace_in_bytes(
{inp.val.at(0).shape(), output(0)->dtype()}); {inp.val.at(0).shape(), output(0)->dtype()});
return true; return true;
}; };
...@@ -87,7 +87,7 @@ void RNGOprBase::init_output_static_infer_desc() { ...@@ -87,7 +87,7 @@ void RNGOprBase::init_output_static_infer_desc() {
} }
void RNGOprBase::scn_do_execute() { void RNGOprBase::scn_do_execute() {
m_megdnn_opr->exec( m_dnn_opr->exec(
output(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
get_megdnn_workspace_from_var(output(1))); get_megdnn_workspace_from_var(output(1)));
} }
......
...@@ -332,7 +332,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::AlgoChooserHelper( ...@@ -332,7 +332,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::AlgoChooserHelper(
const megdnn::param::ExecutionPolicy& execution_policy, const megdnn::param::ExecutionPolicy& execution_policy,
bool allow_weight_preprocess) bool allow_weight_preprocess)
: m_layouts{layouts}, : m_layouts{layouts},
m_megdnn_opr{megdnn_opr}, m_dnn_opr{megdnn_opr},
m_param{param_str}, m_param{param_str},
m_base_mgb_opr{mgb_opr}, m_base_mgb_opr{mgb_opr},
m_cn{cn}, m_cn{cn},
...@@ -356,15 +356,15 @@ AlgoChooser<Opr>::AlgoChooserHelper::choose_by_heuristic( ...@@ -356,15 +356,15 @@ AlgoChooser<Opr>::AlgoChooserHelper::choose_by_heuristic(
owner_graph(), m_cn, m_execution_policy.workspace_limit); owner_graph(), m_cn, m_execution_policy.workspace_limit);
auto attr = extract_algo_attribute(selected_strategy); auto attr = extract_algo_attribute(selected_strategy);
policy.algo = policy.algo =
APPLY(m_megdnn_opr->get_algorithm_info_heuristic( APPLY(m_dnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit, attr.first, attr.second), args..., workspace_limit, attr.first, attr.second),
m_layouts) m_layouts)
.desc; .desc;
Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
mgb_assert(algo, "Unknown algo description"); mgb_assert(algo, "Unknown algo description");
std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list( std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list(
to_layout_array<Opr>(m_layouts), m_megdnn_opr); to_layout_array<Opr>(m_layouts), m_dnn_opr);
FOREACH_OPR_TYPE_DISPATCH(sub_items, { FOREACH_OPR_TYPE_DISPATCH(sub_items, {
auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn); auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn);
...@@ -389,7 +389,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::choose_by_profile( ...@@ -389,7 +389,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::choose_by_profile(
const ExecutionStrategy& selected_strategy, bool enable_update) const { const ExecutionStrategy& selected_strategy, bool enable_update) const {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_profile"))) MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_profile")))
if (owner_graph()->options().no_profiling_on_shape_change) { if (owner_graph()->options().no_profiling_on_shape_change) {
auto policy = m_megdnn_opr->execution_policy(); auto policy = m_dnn_opr->execution_policy();
if (policy.algo.valid()) { if (policy.algo.valid()) {
return policy; return policy;
} }
...@@ -439,9 +439,9 @@ typename AlgoChooser<Opr>::ImplAlgoDesc ...@@ -439,9 +439,9 @@ typename AlgoChooser<Opr>::ImplAlgoDesc
AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache( AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
const ExecutionStrategy& selected_strategy) const { const ExecutionStrategy& selected_strategy) const {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_profile_result_from_cache"))) MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_profile_result_from_cache")))
AlgoChooserProfileCache cache(m_cn, profile_name(m_megdnn_opr).c_str()); AlgoChooserProfileCache cache(m_cn, profile_name(m_dnn_opr).c_str());
typename Opr::Param origin_param = m_megdnn_opr->param(); typename Opr::Param origin_param = m_dnn_opr->param();
AlgoChooserProfileCache::Key cache_key{m_layouts.data(), m_layouts.size(), AlgoChooserProfileCache::Key cache_key{m_layouts.data(), m_layouts.size(),
&origin_param, sizeof(origin_param)}; &origin_param, sizeof(origin_param)};
auto&& rst = cache.get(cache_key); auto&& rst = cache.get(cache_key);
...@@ -504,7 +504,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy( ...@@ -504,7 +504,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
std::string layouts_str = format_fixlayouts<Opr>( std::string layouts_str = format_fixlayouts<Opr>(
m_layouts, arity_in, arity_out); m_layouts, arity_in, arity_out);
std::string msg = ssprintf( std::string msg = ssprintf(
"(mbg_opr : %s, layouts %s, with attribute(%s) and " "(opr : %s, layouts %s, with attribute(%s) and "
"without attribute(%s)", "without attribute(%s)",
m_base_mgb_opr->dyn_typeinfo()->name, m_base_mgb_opr->dyn_typeinfo()->name,
layouts_str.c_str(), layouts_str.c_str(),
...@@ -526,7 +526,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy( ...@@ -526,7 +526,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
owner_graph(), m_cn, m_execution_policy.workspace_limit); owner_graph(), m_cn, m_execution_policy.workspace_limit);
auto attr = extract_algo_attribute(selected_strategy); auto attr = extract_algo_attribute(selected_strategy);
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( policy.algo = APPLY(m_dnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit, attr.first, args..., workspace_limit, attr.first,
attr.second), attr.second),
m_layouts) m_layouts)
...@@ -539,10 +539,10 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy( ...@@ -539,10 +539,10 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
} }
} }
Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
mgb_assert(algo, "Unknown algo description"); mgb_assert(algo, "Unknown algo description");
std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list( std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list(
to_layout_array<Opr>(m_layouts), m_megdnn_opr); to_layout_array<Opr>(m_layouts), m_dnn_opr);
FOREACH_OPR_TYPE_DISPATCH(sub_items, { FOREACH_OPR_TYPE_DISPATCH(sub_items, {
auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn); auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn);
...@@ -571,11 +571,11 @@ template <typename Opr> ...@@ -571,11 +571,11 @@ template <typename Opr>
size_t AlgoChooser<Opr>::AlgoChooserHelper::get_workspace_size_bytes( size_t AlgoChooser<Opr>::AlgoChooserHelper::get_workspace_size_bytes(
const ImplExecutionPolicy& policy) const { const ImplExecutionPolicy& policy) const {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_workspace_size_bytes"))) MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_workspace_size_bytes")))
m_megdnn_opr->execution_policy() = policy; m_dnn_opr->execution_policy() = policy;
size_t result; size_t result;
if_constexpr<opr_supports_preprocess<Opr>()>( if_constexpr<opr_supports_preprocess<Opr>()>(
[&](auto _) { [&](auto _) {
auto&& opr = _(m_megdnn_opr); auto&& opr = _(m_dnn_opr);
auto prep = this->construct_fake_preprocess_filter(); auto prep = this->construct_fake_preprocess_filter();
PreprocessFilter<Opr>* prep_ptr = PreprocessFilter<Opr>* prep_ptr =
prep.valid() ? &prep.val() : nullptr; prep.valid() ? &prep.val() : nullptr;
...@@ -587,7 +587,7 @@ size_t AlgoChooser<Opr>::AlgoChooserHelper::get_workspace_size_bytes( ...@@ -587,7 +587,7 @@ size_t AlgoChooser<Opr>::AlgoChooserHelper::get_workspace_size_bytes(
}, },
/* else */ /* else */
[&](auto _) { [&](auto _) {
result = APPLY(_(m_megdnn_opr)->get_workspace_in_bytes(args...), result = APPLY(_(m_dnn_opr)->get_workspace_in_bytes(args...),
m_layouts); m_layouts);
}); });
return result; return result;
...@@ -600,7 +600,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const { ...@@ -600,7 +600,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates"))) MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates")))
auto heu = choose_by_heuristic(m_execution_policy.strategy); auto heu = choose_by_heuristic(m_execution_policy.strategy);
auto&& ret = auto&& ret =
APPLY(m_megdnn_opr->get_all_algorithms_info(args...), m_layouts); APPLY(m_dnn_opr->get_all_algorithms_info(args...), m_layouts);
bool found = false; bool found = false;
for (size_t i = 0; i < ret.size(); ++i) { for (size_t i = 0; i < ret.size(); ++i) {
if (ret[i].desc == heu.algo) { if (ret[i].desc == heu.algo) {
...@@ -610,7 +610,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const { ...@@ -610,7 +610,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const {
} }
} }
Algorithm* palgo = m_megdnn_opr->get_algorithm_from_desc(heu.algo); Algorithm* palgo = m_dnn_opr->get_algorithm_from_desc(heu.algo);
mgb_assert(palgo, "Unknown algo description"); mgb_assert(palgo, "Unknown algo description");
mgb_assert(found, mgb_assert(found,
"algo %s got by heuristic not found in " "algo %s got by heuristic not found in "
...@@ -644,10 +644,10 @@ AlgoChooser<Opr>::AlgoChooserHelper::profile_single_algo( ...@@ -644,10 +644,10 @@ AlgoChooser<Opr>::AlgoChooserHelper::profile_single_algo(
mgb_assert(param.shapes.size() == m_layouts.size()); mgb_assert(param.shapes.size() == m_layouts.size());
for (size_t i = 0; i < param.shapes.size(); ++i) for (size_t i = 0; i < param.shapes.size(); ++i)
param.shapes[i] = m_layouts[i]; param.shapes[i] = m_layouts[i];
param.opr_param = m_megdnn_opr->param(); param.opr_param = m_dnn_opr->param();
param.allow_weight_preprocess = m_allow_weight_preprocess; param.allow_weight_preprocess = m_allow_weight_preprocess;
Algorithm* palgo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); Algorithm* palgo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
mgb_assert(palgo, "can not find algo when profile single algo"); mgb_assert(palgo, "can not find algo when profile single algo");
auto rst = TimedProfiler<Opr>::profile(param, timeout); auto rst = TimedProfiler<Opr>::profile(param, timeout);
...@@ -691,7 +691,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile( ...@@ -691,7 +691,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile(
policy.algo = algo.desc; policy.algo = algo.desc;
//! check negative attribute : skip negative attribute //! check negative attribute : skip negative attribute
auto palgo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); auto palgo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
if (palgo->contain_attribute_any(target_attr.second)) { if (palgo->contain_attribute_any(target_attr.second)) {
mgb_log_debug( mgb_log_debug(
"skip algo %s, which matches the profile strategy required " "skip algo %s, which matches the profile strategy required "
...@@ -748,12 +748,12 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile( ...@@ -748,12 +748,12 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile(
mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); mgb_assert(!prof_rst.empty(), "%s", msg.c_str());
FixedTensorLayouts origin_layouts = m_layouts; FixedTensorLayouts origin_layouts = m_layouts;
typename Opr::Param origin_param = m_megdnn_opr->param(); typename Opr::Param origin_param = m_dnn_opr->param();
AlgoChooserProfileCache::Key cache_key{origin_layouts.data(), AlgoChooserProfileCache::Key cache_key{origin_layouts.data(),
origin_layouts.size(), &origin_param, origin_layouts.size(), &origin_param,
sizeof(origin_param)}; sizeof(origin_param)};
AlgoChooserProfileCache cache(m_cn, profile_name(m_megdnn_opr).c_str()); AlgoChooserProfileCache cache(m_cn, profile_name(m_dnn_opr).c_str());
cache.put(cache_key, prof_rst); cache.put(cache_key, prof_rst);
MIDOUT_E MIDOUT_E
} }
...@@ -766,7 +766,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::construct_fake_preprocess_filter() const { ...@@ -766,7 +766,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::construct_fake_preprocess_filter() const {
if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) { if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) {
if (!m_allow_weight_preprocess) if (!m_allow_weight_preprocess)
return; return;
auto opr = _(m_megdnn_opr); auto opr = _(m_dnn_opr);
auto layouts = APPLY(opr->deduce_preprocessed_filter_layout(args...), auto layouts = APPLY(opr->deduce_preprocessed_filter_layout(args...),
m_layouts); m_layouts);
//! No preprocess layout means no need weight preprocess //! No preprocess layout means no need weight preprocess
......
...@@ -312,10 +312,15 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl( ...@@ -312,10 +312,15 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
double next_report_time = 0.5; double next_report_time = 0.5;
while (!ev_end->finished()) { while (!ev_end->finished()) {
if (timer.get_secs() >= next_report_time) { if (timer.get_secs() >= next_report_time) {
#if MGB_ENABLE_GETENV
mgb_log_warn( mgb_log_warn(
"profiling conv algo %s already took %.3f/%.3f secs" "profiling conv algo %s already took %.3f/%.3f secs"
" (limit can be set by MGB_CONV_PROFILING_TIMEOUT) ", " (limit can be set by MGB_CONV_PROFILING_TIMEOUT) ",
algo->name(), timer.get_secs(), param.actual_timeout); algo->name(), timer.get_secs(), param.actual_timeout);
#else
mgb_log_warn("profiling conv algo %s already took %.3f/%.3f secs",
algo->name(), timer.get_secs(), param.actual_timeout);
#endif
next_report_time = timer.get_secs() + 1; next_report_time = timer.get_secs() + 1;
} }
using namespace std::literals; using namespace std::literals;
......
...@@ -111,7 +111,7 @@ void Linspace::scn_do_execute() { ...@@ -111,7 +111,7 @@ void Linspace::scn_do_execute() {
stop.dtype(), stop.raw_ptr()).get_cast<double>(); stop.dtype(), stop.raw_ptr()).get_cast<double>();
auto cn = comp_node(); auto cn = comp_node();
auto &&opr = m_megdnn_opr; auto &&opr = m_dnn_opr;
if (!opr || opr.comp_node() != cn) if (!opr || opr.comp_node() != cn)
opr = intl::create_megdnn_opr<megdnn::Linspace>(cn); opr = intl::create_megdnn_opr<megdnn::Linspace>(cn);
opr->param() = {startv, stopv, m_param.endpoint}; opr->param() = {startv, stopv, m_param.endpoint};
...@@ -122,7 +122,7 @@ void Linspace::scn_do_execute() { ...@@ -122,7 +122,7 @@ void Linspace::scn_do_execute() {
void Linspace::record_execute_deps(ExecDependencyArray& deps) { void Linspace::record_execute_deps(ExecDependencyArray& deps) {
deps.emplace_back( deps.emplace_back(
std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); std::make_unique<intl::MegDNNGraphDep>(std::move(m_dnn_opr)));
} }
#if MGB_ENABLE_GRAD #if MGB_ENABLE_GRAD
...@@ -184,7 +184,7 @@ cg::OperatorNodeBase::NodeProp* Eye::do_make_node_prop() const { ...@@ -184,7 +184,7 @@ cg::OperatorNodeBase::NodeProp* Eye::do_make_node_prop() const {
void Eye::scn_do_execute() { void Eye::scn_do_execute() {
auto cn = comp_node(); auto cn = comp_node();
auto &&opr = m_megdnn_opr; auto &&opr = m_dnn_opr;
if (!opr || opr.comp_node() != cn) { if (!opr || opr.comp_node() != cn) {
opr = intl::create_megdnn_opr<megdnn::Eye>(cn); opr = intl::create_megdnn_opr<megdnn::Eye>(cn);
opr->param() = m_param; opr->param() = m_param;
...@@ -196,7 +196,7 @@ void Eye::scn_do_execute() { ...@@ -196,7 +196,7 @@ void Eye::scn_do_execute() {
void Eye::record_execute_deps(ExecDependencyArray& deps) { void Eye::record_execute_deps(ExecDependencyArray& deps) {
deps.emplace_back( deps.emplace_back(
std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); std::make_unique<intl::MegDNNGraphDep>(std::move(m_dnn_opr)));
} }
#if MGB_ENABLE_GRAD #if MGB_ENABLE_GRAD
......
...@@ -88,7 +88,7 @@ namespace mixin { ...@@ -88,7 +88,7 @@ namespace mixin {
template<class Opr> template<class Opr>
class IndexingMultiAxisVecMegDNNOprHolder { class IndexingMultiAxisVecMegDNNOprHolder {
intl::UniqPtrWithCN<Opr> m_megdnn_opr; intl::UniqPtrWithCN<Opr> m_dnn_opr;
protected: protected:
Opr& megdnn_opr(cg::SingleCNOperatorNodeBase& self); Opr& megdnn_opr(cg::SingleCNOperatorNodeBase& self);
......
...@@ -136,7 +136,7 @@ namespace mixin { ...@@ -136,7 +136,7 @@ namespace mixin {
virtual void create_megdnn_opr() = 0; virtual void create_megdnn_opr() = 0;
megdnn::OperatorBase* megdnn_opr() const { megdnn::OperatorBase* megdnn_opr() const {
return m_megdnn_opr.get(); return m_dnn_opr.get();
} }
void set_megdnn_opr(std::unique_ptr<megdnn::OperatorBase> opr); void set_megdnn_opr(std::unique_ptr<megdnn::OperatorBase> opr);
...@@ -146,7 +146,7 @@ namespace mixin { ...@@ -146,7 +146,7 @@ namespace mixin {
cg::GraphExecutable::ExecDependencyArray& deps); cg::GraphExecutable::ExecDependencyArray& deps);
private: private:
std::unique_ptr<megdnn::OperatorBase> m_megdnn_opr; std::unique_ptr<megdnn::OperatorBase> m_dnn_opr;
}; };
class MegDNNOprHolderBwdStaticInfer: public MegDNNOprHolder { class MegDNNOprHolderBwdStaticInfer: public MegDNNOprHolder {
......
...@@ -23,7 +23,7 @@ namespace opr { ...@@ -23,7 +23,7 @@ namespace opr {
namespace intl { namespace intl {
MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // { MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // {
UniqPtrWithCN<megdnn::RNGBase> m_megdnn_opr; UniqPtrWithCN<megdnn::RNGBase> m_dnn_opr;
void ensure_megdnn_opr(); void ensure_megdnn_opr();
void init_output_static_infer_desc() override; void init_output_static_infer_desc() override;
......
...@@ -69,7 +69,7 @@ public: ...@@ -69,7 +69,7 @@ public:
using FixedTensorLayouts = std::array<TensorLayout, arity>; using FixedTensorLayouts = std::array<TensorLayout, arity>;
class AlgoChooserHelper { class AlgoChooserHelper {
FixedTensorLayouts m_layouts; FixedTensorLayouts m_layouts;
Opr* m_megdnn_opr; Opr* m_dnn_opr;
std::string m_param; std::string m_param;
const cg::OperatorNodeBase* m_base_mgb_opr; const cg::OperatorNodeBase* m_base_mgb_opr;
CompNode m_cn; CompNode m_cn;
...@@ -84,7 +84,7 @@ public: ...@@ -84,7 +84,7 @@ public:
const megdnn::param::ExecutionPolicy& execution_policy, const megdnn::param::ExecutionPolicy& execution_policy,
bool allow_weight_preprocess); bool allow_weight_preprocess);
Opr* megdnn_opr() const { return m_megdnn_opr; } Opr* megdnn_opr() const { return m_dnn_opr; }
const cg::OperatorNodeBase* mgb_opr() const { return m_base_mgb_opr; } const cg::OperatorNodeBase* mgb_opr() const { return m_base_mgb_opr; }
...@@ -106,7 +106,7 @@ public: ...@@ -106,7 +106,7 @@ public:
megdnn::Algorithm* get_algorithm_from_desc( megdnn::Algorithm* get_algorithm_from_desc(
const megdnn::Algorithm::Info::Desc& desc) const { const megdnn::Algorithm::Info::Desc& desc) const {
return m_megdnn_opr->get_algorithm_from_desc(desc); return m_dnn_opr->get_algorithm_from_desc(desc);
} }
const FixedTensorLayouts& layouts() const { return m_layouts; } const FixedTensorLayouts& layouts() const { return m_layouts; }
......
...@@ -72,7 +72,7 @@ MGB_DEFINE_OPR_CLASS(Linspace, cg::SingleCNOperatorNodeBase) // { ...@@ -72,7 +72,7 @@ MGB_DEFINE_OPR_CLASS(Linspace, cg::SingleCNOperatorNodeBase) // {
private: private:
const Param m_param; const Param m_param;
intl::UniqPtrWithCN<megdnn::Linspace> m_megdnn_opr; intl::UniqPtrWithCN<megdnn::Linspace> m_dnn_opr;
void scn_do_execute() override; void scn_do_execute() override;
void init_output_static_infer_desc() override; void init_output_static_infer_desc() override;
...@@ -97,7 +97,7 @@ MGB_DEFINE_OPR_CLASS(Eye, cg::SingleCNOperatorNodeBase) // { ...@@ -97,7 +97,7 @@ MGB_DEFINE_OPR_CLASS(Eye, cg::SingleCNOperatorNodeBase) // {
private: private:
const Param m_param; const Param m_param;
intl::UniqPtrWithCN<megdnn::Eye> m_megdnn_opr; intl::UniqPtrWithCN<megdnn::Eye> m_dnn_opr;
void scn_do_execute() override; void scn_do_execute() override;
void init_output_static_infer_desc() override; void init_output_static_infer_desc() override;
......
...@@ -279,6 +279,7 @@ void VarSanityCheck::check_single_input(bool add_debug_log, ...@@ -279,6 +279,7 @@ void VarSanityCheck::check_single_input(bool add_debug_log,
} }
if (checksum != checksum_expect) { if (checksum != checksum_expect) {
#if MGB_ENABLE_GETENV
mgb_throw(Error, mgb_throw(Error,
"var sanity check failed: var: %s" "var sanity check failed: var: %s"
" (checksum: expect=%s got=%s); receiver: %s{%s}(%zu);" " (checksum: expect=%s got=%s); receiver: %s{%s}(%zu);"
...@@ -288,6 +289,15 @@ void VarSanityCheck::check_single_input(bool add_debug_log, ...@@ -288,6 +289,15 @@ void VarSanityCheck::check_single_input(bool add_debug_log,
str(checksum_expect).c_str(), str(checksum).c_str(), str(checksum_expect).c_str(), str(checksum).c_str(),
recv_opr->cname(), recv_opr->dyn_typeinfo()->name, recv_opr->cname(), recv_opr->dyn_typeinfo()->name,
recv_opr->id(), var->id(), !add_debug_log); recv_opr->id(), var->id(), !add_debug_log);
#else
mgb_throw(Error,
"var sanity check failed: var: %s"
" (checksum: expect=%s got=%s); receiver: %s{%s}(%zu);",
cg::dump_var_info({var}).c_str(),
str(checksum_expect).c_str(), str(checksum).c_str(),
recv_opr->cname(), recv_opr->dyn_typeinfo()->name,
recv_opr->id());
#endif
} }
} }
......
...@@ -292,7 +292,7 @@ ExternCOprRunner::ExternCOprRunner(std::string& name, ...@@ -292,7 +292,7 @@ ExternCOprRunner::ExternCOprRunner(std::string& name,
auto size_diff = sizeof(MGBOprDesc) - m_desc->size; auto size_diff = sizeof(MGBOprDesc) - m_desc->size;
is_loader_support_dynamic_param = (0 == size_diff) ? true : false; is_loader_support_dynamic_param = (0 == size_diff) ? true : false;
mgb_assert(0 == size_diff || sizeof(ExternCOprParam*) == size_diff, mgb_assert(0 == size_diff || sizeof(ExternCOprParam*) == size_diff,
"invalid MGBOprDesc size: expect=%zu got=%u, may caused by " "invalid OprDesc size: expect=%zu got=%u, may caused by "
"extern_c_opr.h mismatch, please confirm that the " "extern_c_opr.h mismatch, please confirm that the "
"extern_c_opr.h used when compiling the loader is consistent " "extern_c_opr.h used when compiling the loader is consistent "
"with the runtime caller build used", "with the runtime caller build used",
...@@ -531,8 +531,8 @@ cg::OperatorNodeBase* ExternCOprRunner::shallow_copy( ...@@ -531,8 +531,8 @@ cg::OperatorNodeBase* ExternCOprRunner::shallow_copy(
} }
MGBTensorShape ExternCOprRunner::tensor_shape_to_c(const TensorShape& shape) { MGBTensorShape ExternCOprRunner::tensor_shape_to_c(const TensorShape& shape) {
mgb_assert(shape.ndim <= MGB_TENSOR_MAX_NDIM, "shape ndim too large: %zu", mgb_throw_if(shape.ndim > MGB_TENSOR_MAX_NDIM, MegBrainError,
shape.ndim); "shape ndim too large: %zu", shape.ndim);
MGBTensorShape ret; MGBTensorShape ret;
ret.ndim = shape.ndim; ret.ndim = shape.ndim;
for (size_t i = 0; i < shape.ndim; ++i) { for (size_t i = 0; i < shape.ndim; ++i) {
......
...@@ -41,7 +41,8 @@ DType OprLoadContextRawPOD::read_param() { ...@@ -41,7 +41,8 @@ DType OprLoadContextRawPOD::read_param() {
if (m_check_param_tag) { if (m_check_param_tag) {
uint32_t tag; uint32_t tag;
read_raw(&tag, sizeof(tag)); read_raw(&tag, sizeof(tag));
mgb_assert(tag == megdnn::param::FakeSerializedDType::TAG); mgb_throw_if(tag != megdnn::param::FakeSerializedDType::TAG,
MegBrainError, "ERROR tag");
} }
return serialization::deserialize_dtype( return serialization::deserialize_dtype(
[this](void* data, size_t len) { read_raw(data, len); }); [this](void* data, size_t len) { read_raw(data, len); });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册