提交 bc95e873 编写于 作者: M Megvii Engine Team

fix(jit): fix jit grad

a) fix shape mismatch when take grad of JITExecutor including Dimshuffle
b) avoid redundant computation in the grad of JITExecutor
c) not pass unused vars as inputs to the grad of JITExecutor to save device memory
d) traverse internal graph only once in JITExecutor ctor instead of traverse
   whole graph in each call of setup_args()
e) expand the gradient graph into the origin graph if all inputs are const

GitOrigin-RevId: ba6a2b29e975c7f63a21785efad87dbda76143d4
上级 fc1ce273
......@@ -88,15 +88,34 @@ JITExecutor::JITExecutor(const InternalGraphPtr& internal_graph,
cg::add_workspace_output(this);
}
// check if output of internal_graph is depend on all placeholders
size_t nr_placeholders = internal_graph_ptr()->placeholders().size();
std::vector<bool> used(nr_placeholders, false);
// check if there is reduce or dimshuffle opr
cg::DepOprIter{[this](cg::OperatorNodeBase* opr) {
cg::DepOprIter{[this, nr_placeholders, &used](cg::OperatorNodeBase* opr) {
if (opr->same_type<opr::Reduce>()) {
m_feature_bits |= JITFeatureBits::REDUCE;
}
if (opr->same_type<opr::Dimshuffle>()) {
m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
}
if (auto ph = opr->try_cast_final<JITPlaceholder>()) {
mgb_assert(ph->input_id() < nr_placeholders,
"bad placeholders %s in JITExecutor %s",
ph->cname(), cname());
used[ph->input_id()] = true;
}
}}.add(internal_graph->output());
for (size_t i = 0; i < nr_placeholders; ++ i) {
mgb_assert(used[i],
"placeholder %s is not depended on the output of %s",
internal_graph_ptr()->placeholders()[i]->cname(), cname());
}
if (has_dimshuffle()) {
prepare_dimshuffle();
}
}
void JITExecutor::add_input_layout_constraint() {
......@@ -151,14 +170,14 @@ void JITExecutor::scn_do_execute() {
//! can be ignored
void JITExecutor::do_dimshuffle() {
auto get_dimshuffled_layout = [](const TensorLayout& ily, int32_t* pattern,
size_t pattern_len) {
static auto get_dimshuffled_layout = [](const TensorLayout& ily,
std::vector<int> pattern) {
TensorLayout oly{ily.dtype};
oly.ndim = pattern_len;
oly.ndim = pattern.size();
bool input_used[TensorLayout::MAX_NDIM] = {0};
for (uint32_t idx = 0; idx < pattern_len; ++idx) {
for (uint32_t idx = 0; idx < pattern.size(); ++idx) {
auto i = pattern[idx];
if (i < 0) {
oly.shape[idx] = 1;
......@@ -179,53 +198,20 @@ void JITExecutor::do_dimshuffle() {
return oly;
};
// DFS to make sure traverse the dimshuffles in one branch
std::unordered_set<VarNode*> visited;
std::vector<OperatorNodeBase*> stack(0);
std::vector<uint8_t> idx(0); // input index
stack.push_back(m_internal_graph->output()->owner_opr());
idx.push_back(0);
while (!stack.empty()) {
if (idx.back() < stack.back()->input().size() &&
!visited.count(stack.back()->input(idx.back()))) {
visited.insert(stack.back()->input(idx.back()));
stack.push_back(stack.back()->input(idx.back())->owner_opr());
if (stack.back()->same_type<jit::JITPlaceholder>()) {
auto jitph = gopt::try_cast_as_op<JITPlaceholder>(stack.back());
size_t input_id = jitph->input_id();
auto&& input = m_args.inputs[input_id];
for (int i = stack.size() - 1; i >= 0; --i) {
if (stack[i]->same_type<opr::Dimshuffle>()) {
auto param =
stack[i]->cast_final_safe<opr::Dimshuffle>()
.param();
mgb_assert(input.layout.ndim == param.ndim,
"input ndim mismatch for Dimshuffle: "
"expect=%u "
"actual=%zu",
param.ndim, input.layout.ndim);
auto dimshuffled_layout = get_dimshuffled_layout(
input.layout, param.pattern, param.pattern_len);
input.layout = dimshuffled_layout;
}
}
stack.pop_back();
++idx.back();
} else {
idx.push_back(0);
}
} else {
stack.pop_back();
idx.pop_back();
if (!stack.empty())
++idx.back();
}
for (auto&& i : m_internal_graph->placeholders()) {
auto&& input = m_args.inputs[i->input_id()];
auto&& iter = m_jitph2dimshuffle.find(i);
if (iter == m_jitph2dimshuffle.end()) continue;
auto&& param = iter->second;
mgb_assert(input.layout.ndim == param.second,
"input ndim mismatch for Dimshuffle: "
"expect=%u "
"actual=%zu",
param.second, input.layout.ndim);
auto dimshuffled_layout = get_dimshuffled_layout(
input.layout, param.first);
input.layout = dimshuffled_layout;
}
}
void JITExecutor::update_args() {
......@@ -259,7 +245,9 @@ void JITExecutor::update_args() {
}
//! dimshuffle opr need to change the input.
do_dimshuffle();
if (has_dimshuffle()) {
do_dimshuffle();
}
if (m_compiler->property().contain_flag(CPFlag::NEED_INPUT_COLLAPSE)) {
// collective collapse datum layout, try to reduce the output ndim
......@@ -304,6 +292,82 @@ void JITExecutor::update_args() {
m_args.need_update = false;
}
void JITExecutor::prepare_dimshuffle() {
std::unordered_set<OperatorNodeBase*> visited;
std::vector<OperatorNodeBase*> stack(0);
std::vector<uint8_t> idx(0); // input index
using Param = DimshuffleParam;
std::vector<Param> dimshuffle_stack;
auto merge_dimshuffle = [&](const opr::Dimshuffle::Param& p) {
if (dimshuffle_stack.empty()) {
dimshuffle_stack.emplace_back();
auto&& param = dimshuffle_stack.back();
param.first.insert(param.first.end(), p.pattern, p.pattern + p.pattern_len);
param.second = p.ndim;
} else {
// merge(p, src) -> param and it has performing dimshuffle(dimshuffle(x, p), src)
// is equivalent to dimshuffle(x, param)
dimshuffle_stack.emplace_back();
auto&& param = dimshuffle_stack.back();
auto&& src = dimshuffle_stack[dimshuffle_stack.size() - 2];
mgb_assert(p.pattern_len == src.second);
param.first.resize(src.first.size());
for (size_t i = 0; i < src.first.size(); ++ i) {
if (src.first[i] == -1) {
param.first[i] = -1;
} else {
param.first[i] = p.pattern[src.first[i]];
}
}
param.second = p.ndim;
}
};
auto push_back = [&](cg::OperatorNodeBase* op) {
mgb_assert(!op->same_type<jit::JITPlaceholder>());
if (auto o = op->try_cast_final<opr::Dimshuffle>()) {
merge_dimshuffle(o->param());
}
stack.push_back(op);
idx.push_back(0);
};
auto pop_back = [&]() {
auto&& op = stack.back();
if (op->same_type<opr::Dimshuffle>()) {
dimshuffle_stack.pop_back();
}
stack.pop_back();
idx.pop_back();
};
push_back(m_internal_graph->output()->owner_opr());
while (!stack.empty()) {
if (idx.back() < stack.back()->input().size()) {
auto cur_opr = stack.back()->input(idx.back())->owner_opr();
if (visited.insert(cur_opr).second) {
if (auto jitph = cur_opr->try_cast_final<jit::JITPlaceholder>()) {
if (!dimshuffle_stack.empty()) {
mgb_assert(
m_jitph2dimshuffle.emplace(jitph, dimshuffle_stack.back()).second,
"already visited JITPlaceholder %s",
jitph->cname());
}
++ idx.back();
} else {
push_back(cur_opr);
}
} else {
++ idx.back();
}
} else {
pop_back();
if (!stack.empty())
++ idx.back();
}
}
}
const JITExecutor::Args& JITExecutor::args() const {
if (m_args.need_update) {
const_cast<JITExecutor*>(this)->update_args();
......@@ -383,6 +447,56 @@ megdnn::TensorShape JITExecutor::broadcasted_input_shape() const {
#if MGB_ENABLE_GRAD
namespace {
class InternalGraphRewriter {
ThinHashMap<VarNode*, VarNode*> m_var_map;
VarNode* m_dest_var;
VarNodeArray m_new_inp;
VarNode* get_var(VarNode* var) {
auto&& iter = m_var_map.find(var);
if (iter != m_var_map.end()) {
return iter->second;
}
return var;
}
public:
InternalGraphRewriter(VarNode* dest_var)
:m_dest_var{dest_var}{}
void iter(thin_function<void(cg::OperatorNodeBase*)>&& cb) {
m_var_map.clear();
cg::DepOprIter{std::move(cb)}.add(m_dest_var->owner_opr());
m_dest_var = get_var(m_dest_var);
}
VarNode* dest_var() {
return m_dest_var;
}
void replace_var(VarNode* src, VarNode* dst) {
// Note: do not perform var replacing recursively
// when we extract used placeholders from internal graph, we don't
// consider placeholder replacement pair (a to b), (b to c) as a
// var replacing chain (a to b to c) but as a injective function
// from (a, b) to (b, c)
// in other cases, each var node would be passed as \p src or
// \p dst at most once
m_var_map[src] = dst;
}
void auto_replace_outputs(cg::OperatorNodeBase* opr) {
// in JIT internal graph, output size of opr is always 1
mgb_assert(opr->usable_output().size() == 1);
m_new_inp.clear();
bool need_replace = false;
for (auto&& i : opr->input()) {
auto inp = get_var(i);
m_new_inp.push_back(inp);
need_replace |= (inp != i);
}
if (need_replace) {
auto new_op = serialization::copy_opr_shallow(*opr, m_new_inp);
replace_var(opr->output(0), new_op->output(0));
}
}
};
} // anonymous namespace
MGB_IMPL_OPR_GRAD(JITExecutor) {
VarNodeArray grad_inputs;
for (auto input : opr.input())
......@@ -404,49 +518,120 @@ MGB_IMPL_OPR_GRAD(JITExecutor) {
if (gx.node()->owner_opr()->same_type<opr::InvalidGrad>()) {
return opr::InvalidGrad::make(opr, wrt_idx);
}
// early return if grad expression is single node
for (size_t i = 0; i < fwd_igraph_ptr->placeholders().size(); ++i) {
if (gx.node() == fwd_igraph_ptr->placeholders()[i]->output(0)) {
return grad_inputs[i];
}
}
if (gx.node() == og_ph.node()) {
return out_grad[0];
}
if (gx.node() == fwd_igraph_ptr->output()) {
return opr.output(0);
}
if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(gx.node()->owner_opr())) {
HostTensorND hval{grad_inputs[0]->comp_node()};
hval.copy_from(imm->value()).sync();
return opr::ImmutableTensor::make(*imm->owner_graph(), hval).node();
}
// replace output var in internal graph with output placeholder, so
// we could forward opr.output(computeed by forward JITExecutor) into
// placeholder to avoid redundant computation
InternalGraphRewriter rewriter{gx.node()};
rewriter.iter([&rewriter, &fwd_igraph_ptr,
&output_ph](cg::OperatorNodeBase* opr) {
if (opr == fwd_igraph_ptr->output()->owner_opr()) {
rewriter.replace_var(opr->output(0), output_ph.node());
return;
}
rewriter.auto_replace_outputs(opr);
});
static auto expand_into_origin_graph = [](cg::OperatorNodeBase* opr,
InternalGraphRewriter& rewriter, const VarNodeArray& grad_inputs) {
if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
rewriter.replace_var(
opr->output(0), grad_inputs.at(ph->input_id()));
return;
}
if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) {
HostTensorND hval{grad_inputs[0]->comp_node()};
hval.copy_from(imm->value()).sync();
rewriter.replace_var(opr->output(0),
opr::ImmutableTensor::make(*opr->owner_graph(), hval).node());
return;
}
rewriter.auto_replace_outputs(opr);
};
if (opr.compiler()->property().feature_bits & JITFeatureBits::REDUCE) {
// expand the gradient graph into the original graph to handle bcast
// oprs
ThinHashMap<VarNode*, VarNode*> old2new;
VarNodeArray new_inp;
auto on_opr = [&old2new, &grad_inputs,
&new_inp](cg::OperatorNodeBase* opr) {
using namespace std::placeholders;
rewriter.iter(std::bind(expand_into_origin_graph, _1,
std::ref(rewriter), std::cref(grad_inputs)));
return rewriter.dest_var();
} else {
VarNodeArray new_grad_inputs;
PlaceholderArray placeholders;
bool all_inp_const = true;
// gx was not depend on all JITPlaceholders so we need to extract used
// placeholders and build a new internal graph
rewriter.iter([&rewriter, &grad_inputs, &new_grad_inputs,
&placeholders, &all_inp_const](cg::OperatorNodeBase* opr) {
if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
old2new[opr->output(0)] = grad_inputs.at(ph->input_id());
return;
}
if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) {
HostTensorND hval{grad_inputs[0]->comp_node()};
hval.copy_from(imm->value()).sync();
old2new[opr->output(0)] =
opr::ImmutableTensor::make(*opr->owner_graph(), hval)
.node();
new_grad_inputs.push_back(grad_inputs[ph->input_id()]);
auto new_ph = JITPlaceholder::make(
new_grad_inputs.back(), placeholders.size())
.node()->owner_opr();
placeholders.push_back(new_ph->try_cast_final<JITPlaceholder>());
mgb_assert(placeholders.back());
rewriter.replace_var(opr->output(0), new_ph->output(0));
if (!cg::is_const_var_value(new_grad_inputs.back())) {
all_inp_const = false;
}
return;
}
new_inp.clear();
for (auto inp : opr->input()) {
new_inp.push_back(old2new.at(inp));
}
auto new_opr = serialization::copy_opr_shallow(*opr, new_inp);
old2new[opr->output(0)] = new_opr->output(0);
};
cg::DepOprIter{on_opr}.add(gx.node());
return old2new.at(gx.node());
} else {
PlaceholderArray placeholders = fwd_igraph_ptr->placeholders();
for (SymbolVar i : {output_ph, og_ph}) {
placeholders.push_back(
&i.node()->owner_opr()->cast_final_safe<JITPlaceholder>());
rewriter.auto_replace_outputs(opr);
});
if (all_inp_const) {
// if all_inp_const, expand grad graph into origin graph by replace
// placeholders with const inputs, so it could benefit from static
// infer and const folding mechanism
using namespace std::placeholders;
rewriter.iter(std::bind(expand_into_origin_graph, _1,
std::ref(rewriter), std::cref(new_grad_inputs)));
return rewriter.dest_var();
}
for (size_t i = 0; i < placeholders.size(); ++i) {
if (gx.node() == placeholders[i]->output(0)) {
return grad_inputs[i];
gx = rewriter.dest_var();
auto shape_infer = fwd_igraph_ptr->shape_infer();
if (opr.has_dimshuffle()) {
auto&& iter = opr.dimshuffle_params().find(
fwd_igraph_ptr->placeholders()[wrt_idx]);
if (iter != opr.dimshuffle_params().end()) {
auto&& pattern = iter->second.first;
auto&& ndim = iter->second.second;
std::vector<int> back(ndim, -1);
for (size_t i = 0; i < pattern.size(); i ++) {
// outdim[i] is indim[j]
auto j = pattern[i];
if (j >= 0) {
mgb_assert(back[j] == -1,
"taking grad for Dimshuffle with duplicated "
"input axis unsupported");
back[j] = i;
}
}
shape_infer = opr::Dimshuffle::make(shape_infer, back, pattern.size()).node();
}
}
auto grad_ig = std::make_shared<InternalGraph>(
gx.node(), fwd_igraph_ptr->shape_infer(), nullptr,
gx.node(), shape_infer, nullptr,
std::move(placeholders));
auto grad_jit = JITExecutor::make(grad_ig, grad_inputs);
auto grad_jit = JITExecutor::make(grad_ig, new_grad_inputs);
if (opr.input_broadcastable()[wrt_idx]) {
grad_jit = opr::reduce_sum(
......
......@@ -26,7 +26,6 @@ JITPlaceholder::JITPlaceholder(VarNode* src_var, size_t id, InpType inp_type)
{}),
m_inp_type{inp_type},
m_id{id} {
add_equivalence_component<ScalarHash<size_t>>(m_id);
mgb_assert(src_var->dtype().category() == DTypeCategory::FLOAT ||
src_var->dtype().category() == DTypeCategory::INT,
"JIT can only be applied to float/int operators, got %s",
......
......@@ -35,6 +35,7 @@ MGB_DEFINE_OPR_CLASS(JITExecutor, cg::SingleCNOperatorNodeBase) // {
using ModeTrait = megdnn::Elemwise::ModeTrait;
InternalGraphPtr m_internal_graph;
using DimshuffleParam = std::pair<std::vector<int>, uint32_t>;
public:
using Mode = opr::Elemwise::Mode;
......@@ -112,6 +113,11 @@ public:
return static_cast<bool>(m_feature_bits & JITFeatureBits::DIMSHUFFLE);
}
const ThinHashMap<jit::JITPlaceholder*, DimshuffleParam>&
dimshuffle_params() const {
return m_jitph2dimshuffle;
}
//! get broadcasted shape of inputs
megdnn::TensorShape broadcasted_input_shape() const;
......@@ -124,8 +130,14 @@ private:
Compiler* const m_compiler = nullptr;
Executable* m_executable = nullptr;
std::vector<bool> m_input_broadcastable;
// JITPlaceHolder -> pair of (dimshuffle pattern, ndim)
// do DFS on internal graph only once in prepare_dimshuffle(), so we can
// easily get the dimshuffle param which should be applied on given
// JITPlaceholder
ThinHashMap<jit::JITPlaceholder*, DimshuffleParam> m_jitph2dimshuffle;
void update_args();
void do_dimshuffle();
void prepare_dimshuffle();
NodeProp* do_make_node_prop() const override;
};
......
......@@ -61,8 +61,6 @@ public:
const PlaceholderArray& placeholders() const { return m_placeholders; }
static InternalGraphPtr expand_excutor_op(const InternalGraphPtr&);
private:
// For compilation cache, if the output_for_cache is same means the
// expression tree is same.
......
......@@ -1435,6 +1435,16 @@ TEST(TestJITNvrtc, DimshuffleGrad) {
funcs.second->execute();
MGB_ASSERT_TENSOR_NEAR(host_y1, host_y2, 1e-3);
}
{
FusionChecker checker{2,
[](const SymbolVarArray& inp) -> SymbolVar {
auto var = opr::Dimshuffle::make(inp[0], {1, 2, 3, 0});
return inp[1] * var;
},
CompNode::load("gpu0")};
checker.set_jit_level(1)
.run({TensorShape{1, 2, 3, 4}, {2, 3, 4, 1}});
}
}
#endif // MGB_JIT
......
......@@ -98,7 +98,7 @@ void FusionChecker::ensure_init_graph() {
} else {
ComputingGraph::Options opt;
opt.graph_opt_level = 3;
opt.graph_opt.jit = 2;
opt.graph_opt.jit = m_jit_level;
unpack_vector(gopt::GraphOptimizer{}
.add_preset_passes(true, nullptr, &opt)
.apply({{m_truth_y}})
......
......@@ -65,6 +65,13 @@ public:
return *this;
}
//! set jit level, default is 2, see graph_opt.jit in graph options
//! for more details
FusionChecker& set_jit_level(uint8_t jit_level) {
m_jit_level = jit_level;
return *this;
}
/*!
* \brief run and check correctness
*
......@@ -76,6 +83,7 @@ private:
bool m_check_opr_type = true;
bool m_direct_build = false;
const size_t m_nr_input;
uint8_t m_jit_level = 2;
const CompNode m_comp_node;
HostTensorGenerator<> m_input_gen;
SmallVector<std::shared_ptr<HostTensorND>> m_inputs_val;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册