From 8f7f52ae4d2b5fc09e4bcacf44240007061e0025 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 30 Dec 2020 21:56:26 +0800 Subject: [PATCH] feat(jit): add memfwd in jit executor opr GitOrigin-RevId: b58860bbe87582023d96fdfe9d2cb5c6c93b8731 --- dnn/src/common/pooling.cpp | 2 +- src/jit/impl/executor_opr.cpp | 7 +++++ src/jit/include/megbrain/jit/executor_opr.h | 6 +++- src/jit/test/fusion.cpp | 6 ++++ src/opr/impl/basic_arith.cpp | 27 +--------------- src/opr/impl/internal/identical_fwd.cpp | 31 +++++++++++++++++++ src/opr/include/megbrain/opr/basic_arith.h | 3 +- .../megbrain/opr/internal/identical_fwd.h | 13 ++++++++ 8 files changed, 66 insertions(+), 29 deletions(-) diff --git a/dnn/src/common/pooling.cpp b/dnn/src/common/pooling.cpp index 6eb381fdd..b875746bb 100644 --- a/dnn/src/common/pooling.cpp +++ b/dnn/src/common/pooling.cpp @@ -92,7 +92,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, size_t sw = this->param().stride_w; size_t ph = this->param().pad_h; size_t pw = this->param().pad_w; - if (ph < fh && pw < fw) { + if (ph >= fh || pw >= fw) { megdnn_log_error( "pooling padding size (%zu %zu) should not be bigger than " "window size (%zu %zu), it only can be used in CaffePooling", diff --git a/src/jit/impl/executor_opr.cpp b/src/jit/impl/executor_opr.cpp index f76506d1b..b3d975d0d 100644 --- a/src/jit/impl/executor_opr.cpp +++ b/src/jit/impl/executor_opr.cpp @@ -135,6 +135,13 @@ void JITExecutor::init_output_mem_plan(bool dynamic) { m_args.need_update = true; } +void JITExecutor::mem_plan_fwd_in2out_writable() { + //! 
currently mem fwd only support elemwise fusion + if (m_feature_bits != JITFeatureBits::NONE) return; + mixin_mem_plan_fwd_in2out_writable(*this); +} + + SymbolVar JITExecutor::make(const InternalGraphPtr& internal_graph, const VarNodeArray& inputs, const OperatorNodeConfig& config) { diff --git a/src/jit/include/megbrain/jit/executor_opr.h b/src/jit/include/megbrain/jit/executor_opr.h index dabfede6f..4bc8b5b5f 100644 --- a/src/jit/include/megbrain/jit/executor_opr.h +++ b/src/jit/include/megbrain/jit/executor_opr.h @@ -13,6 +13,7 @@ #include "megbrain/graph/operator_node.h" #include "megbrain/jit/internal_graph.h" +#include "megbrain/opr/internal/identical_fwd.h" #if MGB_JIT @@ -31,7 +32,8 @@ class Compiler; * JITExecutor generates runtime Args for this specific inputs, and calls * methods in Compiler to get the Executable object for actual computing. */ -MGB_DEFINE_OPR_CLASS(JITExecutor, cg::SingleCNOperatorNodeBase) // { +MGB_DEFINE_OPR_CLASS(JITExecutor, cg::SingleCNOperatorNodeBase, + opr::mixin::FwdIn2OutWritableHelper) // { using ModeTrait = megdnn::Elemwise::ModeTrait; InternalGraphPtr m_internal_graph; @@ -57,6 +59,8 @@ public: void init_output_mem_plan(bool dynamic) override; + void mem_plan_fwd_in2out_writable() override; + const InternalGraph& internal_graph() const { return *m_internal_graph; } const InternalGraphPtr internal_graph_ptr() const { diff --git a/src/jit/test/fusion.cpp b/src/jit/test/fusion.cpp index b60a8ef49..0561d72b1 100644 --- a/src/jit/test/fusion.cpp +++ b/src/jit/test/fusion.cpp @@ -137,6 +137,12 @@ void run(Backend backend, CompNode cn) { // only one broadcast is allowed in JIT fusion ASSERT_EQ(1u, jits[0]->input().size()); ASSERT_EQ(4u, jits[1]->input().size()); + + //! 
check memfwd + ASSERT_EQ(prev_dev_ptr(jits[0]->input(0)), + prev_dev_ptr(jits[0]->output(0))); + ASSERT_EQ(prev_dev_ptr(jits[1]->input(0)), + prev_dev_ptr(jits[1]->output(0))); } template <> diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index 997564057..1a97ecf09 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -338,32 +338,7 @@ void Elemwise::broadcast_collective_collapse( } void Elemwise::mem_plan_fwd_in2out_writable() { - auto &&inp = input(); - auto isize = inp.size(); - mgb_assert(isize <= 6); - bool have_conflict[6] = {false}; - for (size_t i = 0; i < isize; ++i) { - for (size_t j = i + 1; j < isize; ++j) { - auto type = cg::get_mem_plan_intersection_type(inp[i], inp[j]); - using Type = cg::MemPlanIntersectionType; - bool overlap = type == Type::OVERLAP; - bool self_fwd = type == Type::IDENTICAL && - (!inp[i]->layout().is_contiguous() || - !inp[j]->layout().is_contiguous()); - if (overlap || self_fwd) { - have_conflict[i] = true; - have_conflict[j] = true; - } - } - } - auto o = output(0); - for (size_t idx = 0; idx < isize; ++ idx) { - auto i = inp[idx]; - // equal shape means no broadcast - if (!have_conflict[idx] && - o->shape().eq_shape(i->shape()) && i->layout().is_contiguous()) - o->set_fwd_in2out_writable(i); - } + mixin_mem_plan_fwd_in2out_writable(*this); } void Elemwise::scn_do_execute() { diff --git a/src/opr/impl/internal/identical_fwd.cpp b/src/opr/impl/internal/identical_fwd.cpp index 8bbf576df..0afd21a7d 100644 --- a/src/opr/impl/internal/identical_fwd.cpp +++ b/src/opr/impl/internal/identical_fwd.cpp @@ -33,6 +33,37 @@ void mixin::init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o( valid_out->add_rt_force_dynamic_mem_alloc_imply_chain(opr.input(0)); } +/* ===================== FwdIn2OutWritableHelper ===================== */ +void FwdIn2OutWritableHelper::mixin_mem_plan_fwd_in2out_writable( + OperatorNodeBase& opr) { + auto&& inp = opr.input(); + auto isize = inp.size(); + 
+    std::vector<bool> have_conflict(isize, false);
+    for (size_t i = 0; i < isize; ++i) {
+        for (size_t j = i + 1; j < isize; ++j) {
+            auto type = cg::get_mem_plan_intersection_type(inp[i], inp[j]);
+            using Type = cg::MemPlanIntersectionType;
+            bool overlap = type == Type::OVERLAP;
+            bool self_fwd = type == Type::IDENTICAL &&
+                            (!inp[i]->layout().is_contiguous() ||
+                             !inp[j]->layout().is_contiguous());
+            if (overlap || self_fwd) {
+                have_conflict[i] = true;
+                have_conflict[j] = true;
+            }
+        }
+    }
+    auto o = opr.output(0);
+    for (size_t idx = 0; idx < isize; ++ idx) {
+        auto i = inp[idx];
+        // equal shape means no broadcast
+        if (!have_conflict[idx] && o->shape().eq_shape(i->shape()) &&
+            o->dtype().enumv() == i->dtype().enumv() &&
+            i->layout().is_contiguous())
+            o->set_fwd_in2out_writable(i);
+    }
+}
+
 /* ===================== ReadonlyFwdHelper ===================== */
 
 void ReadonlyFwdHelper::mixin_rofwd_init_mem_plan(OperatorNodeBase &opr) {
diff --git a/src/opr/include/megbrain/opr/basic_arith.h b/src/opr/include/megbrain/opr/basic_arith.h
index 1dfda12d2..2ef8e8d63 100644
--- a/src/opr/include/megbrain/opr/basic_arith.h
+++ b/src/opr/include/megbrain/opr/basic_arith.h
@@ -58,7 +58,8 @@ namespace intl {
  * The operands are broadcasted automatically on dimensions of shape one to
  * match shapes of each other; it works like broadcasting in numpy.
  */
-MGB_DEFINE_OPR_CLASS(Elemwise, intl::ElemwiseBase) // {
+MGB_DEFINE_OPR_CLASS(Elemwise, intl::ElemwiseBase,
+                     mixin::FwdIn2OutWritableHelper) // {
     using ModeTrait = megdnn::Elemwise::ModeTrait;
 
     public:
diff --git a/src/opr/include/megbrain/opr/internal/identical_fwd.h b/src/opr/include/megbrain/opr/internal/identical_fwd.h
index 57c900f3d..8e2b7702e 100644
--- a/src/opr/include/megbrain/opr/internal/identical_fwd.h
+++ b/src/opr/include/megbrain/opr/internal/identical_fwd.h
@@ -19,6 +19,19 @@ namespace opr {
 namespace mixin {
 
+/*!
+ * \brief mixin for operators which essentially work by forwarding input to output
+ */
+class FwdIn2OutWritableHelper : public cg::OperatorNodeMixinBase {
+protected:
+    /*!
+     * \brief call this function in mem_plan_fwd_in2out_writable();
+     * this function will check whether the inputs have conflicts to find
+     * whether the output can be forwarded.
+     */
+    void mixin_mem_plan_fwd_in2out_writable(OperatorNodeBase &opr);
+};
+
 //! for internal use by DynamicOutputIfInputDynamic
 void init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o(
         OperatorNodeBase &opr);
-- 
GitLab