Commit 8f7f52ae authored by Megvii Engine Team

feat(jit): add memfwd in jit executor opr

GitOrigin-RevId: b58860bbe87582023d96fdfe9d2cb5c6c93b8731
Parent dfb2b2ce
......@@ -92,7 +92,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
size_t sw = this->param().stride_w;
size_t ph = this->param().pad_h;
size_t pw = this->param().pad_w;
if (ph < fh && pw < fw) {
if (ph >= fh || pw >= fw) {
megdnn_log_error(
"pooling padding size (%zu %zu) should not be bigger than "
"window size (%zu %zu), it only can be used in CaffePooling",
......
......@@ -135,6 +135,13 @@ void JITExecutor::init_output_mem_plan(bool dynamic) {
m_args.need_update = true;
}
void JITExecutor::mem_plan_fwd_in2out_writable() {
    // memory forwarding is currently only supported for pure elemwise
    // fusion; any extra JIT feature (e.g. reduce/dimshuffle) disables it
    if (m_feature_bits == JITFeatureBits::NONE) {
        mixin_mem_plan_fwd_in2out_writable(*this);
    }
}
SymbolVar JITExecutor::make(const InternalGraphPtr& internal_graph,
const VarNodeArray& inputs,
const OperatorNodeConfig& config) {
......
......@@ -13,6 +13,7 @@
#include "megbrain/graph/operator_node.h"
#include "megbrain/jit/internal_graph.h"
#include "megbrain/opr/internal/identical_fwd.h"
#if MGB_JIT
......@@ -31,7 +32,8 @@ class Compiler;
* JITExecutor generates runtime Args for this specific inputs, and calls
* methods in Compiler to get the Executable object for actual computing.
*/
MGB_DEFINE_OPR_CLASS(JITExecutor, cg::SingleCNOperatorNodeBase) // {
MGB_DEFINE_OPR_CLASS(JITExecutor, cg::SingleCNOperatorNodeBase,
opr::mixin::FwdIn2OutWritableHelper) // {
using ModeTrait = megdnn::Elemwise::ModeTrait;
InternalGraphPtr m_internal_graph;
......@@ -57,6 +59,8 @@ public:
void init_output_mem_plan(bool dynamic) override;
void mem_plan_fwd_in2out_writable() override;
const InternalGraph& internal_graph() const { return *m_internal_graph; }
const InternalGraphPtr internal_graph_ptr() const {
......
......@@ -137,6 +137,12 @@ void run<basic>(Backend backend, CompNode cn) {
// only one broadcast is allowed in JIT fusion
ASSERT_EQ(1u, jits[0]->input().size());
ASSERT_EQ(4u, jits[1]->input().size());
//! check memfwd
ASSERT_EQ(prev_dev_ptr(jits[0]->input(0)),
prev_dev_ptr(jits[0]->output(0)));
ASSERT_EQ(prev_dev_ptr(jits[1]->input(0)),
prev_dev_ptr(jits[1]->output(0)));
}
template <>
......
......@@ -338,32 +338,7 @@ void Elemwise::broadcast_collective_collapse(
}
// NOTE(review): this span is diff residue from the commit page — the loop
// body below is the OLD (removed) inline implementation, and the trailing
// mixin_mem_plan_fwd_in2out_writable(*this) call is the NEW (added) code
// that replaces it; after the commit the function contains only the mixin
// call. Do not read both as running together.
void Elemwise::mem_plan_fwd_in2out_writable() {
auto &&inp = input();
auto isize = inp.size();
// old implementation capped the input count so a fixed-size flag array
// could be used
mgb_assert(isize <= 6);
bool have_conflict[6] = {false};
// pairwise check: an input whose mem plan overlaps another, or shares an
// identical non-contiguous plan, must not be forwarded to the output
for (size_t i = 0; i < isize; ++i) {
for (size_t j = i + 1; j < isize; ++j) {
auto type = cg::get_mem_plan_intersection_type(inp[i], inp[j]);
using Type = cg::MemPlanIntersectionType;
bool overlap = type == Type::OVERLAP;
bool self_fwd = type == Type::IDENTICAL &&
(!inp[i]->layout().is_contiguous() ||
!inp[j]->layout().is_contiguous());
if (overlap || self_fwd) {
have_conflict[i] = true;
have_conflict[j] = true;
}
}
}
auto o = output(0);
// mark conflict-free, shape-matching, contiguous inputs as writable
// forward candidates for output(0)
for (size_t idx = 0; idx < isize; ++ idx) {
auto i = inp[idx];
// equal shape means no broadcast
if (!have_conflict[idx] &&
o->shape().eq_shape(i->shape()) && i->layout().is_contiguous())
o->set_fwd_in2out_writable(i);
}
// new implementation: delegate to the shared FwdIn2OutWritableHelper mixin
mixin_mem_plan_fwd_in2out_writable(*this);
}
void Elemwise::scn_do_execute() {
......
......@@ -33,6 +33,37 @@ void mixin::init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o(
valid_out->add_rt_force_dynamic_mem_alloc_imply_chain(opr.input(0));
}
/* ===================== FwdIn2OutWritableHelper ===================== */
void FwdIn2OutWritableHelper::mixin_mem_plan_fwd_in2out_writable(
        OperatorNodeBase& opr) {
    auto&& inputs = opr.input();
    size_t nr_inp = inputs.size();
    // conflict[i] is set when input i's memory plan overlaps another
    // input's plan, or shares an identical but non-contiguous plan; such
    // an input must not be reused as the output buffer
    std::vector<bool> conflict(nr_inp, false);
    using Type = cg::MemPlanIntersectionType;
    for (size_t i = 0; i < nr_inp; ++i) {
        for (size_t j = i + 1; j < nr_inp; ++j) {
            auto itype =
                    cg::get_mem_plan_intersection_type(inputs[i], inputs[j]);
            bool bad_pair =
                    itype == Type::OVERLAP ||
                    (itype == Type::IDENTICAL &&
                     (!inputs[i]->layout().is_contiguous() ||
                      !inputs[j]->layout().is_contiguous()));
            if (bad_pair) {
                conflict[i] = true;
                conflict[j] = true;
            }
        }
    }
    auto out = opr.output(0);
    for (size_t i = 0; i < nr_inp; ++i) {
        auto var = inputs[i];
        // forwarding requires: no plan conflict, equal shape (i.e. no
        // broadcast), matching dtype, and a contiguous input layout
        if (!conflict[i] && out->shape().eq_shape(var->shape()) &&
            out->dtype().enumv() == var->dtype().enumv() &&
            var->layout().is_contiguous()) {
            out->set_fwd_in2out_writable(var);
        }
    }
}
/* ===================== ReadonlyFwdHelper ===================== */
void ReadonlyFwdHelper::mixin_rofwd_init_mem_plan(OperatorNodeBase &opr) {
......
......@@ -58,7 +58,8 @@ namespace intl {
* The operands are broadcasted automatically on dimensions of shape one to
* match shapes of each other; it works like broadcasting in numpy.
*/
MGB_DEFINE_OPR_CLASS(Elemwise, intl::ElemwiseBase) // {
MGB_DEFINE_OPR_CLASS(Elemwise, intl::ElemwiseBase,
mixin::FwdIn2OutWritableHelper) // {
using ModeTrait = megdnn::Elemwise::ModeTrait;
public:
......
......@@ -19,6 +19,19 @@ namespace opr {
namespace mixin {
/*!
 * \brief mixin for operators that may forward an input buffer directly to
 *      the output (in-place / memory-forwarding execution)
 */
class FwdIn2OutWritableHelper : public cg::OperatorNodeMixinBase {
protected:
/*!
 * \brief call this from the operator's mem_plan_fwd_in2out_writable()
 *      override; it checks the inputs' memory plans for conflicts to
 *      decide whether output(0) can reuse one of the input buffers.
 */
void mixin_mem_plan_fwd_in2out_writable(OperatorNodeBase &opr);
};
//! for internal use by DynamicOutputIfInputDynamic
void init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o(
OperatorNodeBase &opr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.