提交 d6d3e6af 编写于 作者: D dzhwinter

add more skip strategy

上级 2739096e
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
......@@ -54,6 +55,11 @@ class GraphvizOp : public GraphvizNode {
}
}
// Appends one custom edge line to this op's graphviz text. `cb` is invoked
// exactly once with no arguments and whatever it returns is streamed into
// `stream_`, followed by a line terminator.
template <typename Callback>
void AddCustomEdge(const Callback& cb) {
  stream_ << cb();
  stream_ << std::endl;
}
private:
std::ostringstream stream_;
};
......@@ -68,12 +74,47 @@ std::vector<T*> FilterByNodeWrapper(const Container& con) {
return ret;
}
// bool DetectCircleRecursive(const std::map<ir::Node*,
// std::unordered_set<ir::Node*>>, std::unordered_set<ir::Node*>* visited,
// std::unordered_set<ir::Node*> *in_trace, std::vector<std::vector<ir::Node*>>*
// circles) {
// if (visited->find(node) == visited->end()) {
// visited->insert(node);
// in_trace->insert(node);
// for (ir::Node *in : adj_list.at(node)) {
// if (visited->find(in) == visited->end() &&
// HasCircleHelper(in, adj_list, visited, in_trace)) {
// return true;
// } else if (in_trace->find(in) != in_trace->end()) {
// circles->push_back(in_trace);
// return true;
// }
// }
// }
// in_trace->erase(node);
// return false;
// }
// bool DetectCircle(const std::map<ir::Node*, std::unordered_set<ir::Node*>>&
// adj_list, std::vector<std::vector<ir::Node*>>* circles) {
// std::unordered_set<ir::Node *> visited;
// std::unordered_set<ir::Node *> in_trace;
// bool has_circle = false;
// for(auto& adj : adj_list) {
// has_circle &= DetectCircleRecursive(adj, adj_list,&visited, &in_trace,
// circles);
// }
// return has_circle;
// }
std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
const ir::Graph& graph) const {
// Convert to GraphvizNode format
auto& graphviz_nodes = graph.Get<GraphvizNodes>(kGraphviz);
graphviz_nodes.clear();
std::unordered_map<ir::Node*, int> vars;
std::unordered_map<ir::Node*, GraphvizOp*> ops;
int var_id = 0;
int op_id = 0;
for (auto& node : graph.Nodes()) {
......@@ -81,11 +122,33 @@ std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
graphviz_nodes.emplace(new GraphvizVar(node, var_id));
vars.emplace(std::make_pair(node, var_id++));
} else if (node->IsOp()) {
graphviz_nodes.emplace(new GraphvizOp(node, op_id++));
std::unique_ptr<GraphvizOp> op(new GraphvizOp(node, op_id++));
ops[node] = op.get();
graphviz_nodes.emplace(std::move(op));
// graphviz_nodes.emplace(new GraphvizOp(node, op_id++));
// ops.emplace(std::make_pair(node, graphviz_nodes.back().get()));
} else {
PADDLE_THROW("Unknown op type");
}
}
// Detect circle. Draw circle in different lines
std::vector<std::vector<ir::Node*>> circles;
const std::string kCircleEdge = "[color=red,penwidth=3.0]";
if (ir::FindCircleSubGraph(graph, &circles)) {
VLOG(3) << "Graph has circle! circles count : " << circles.size();
for (auto& circle : circles) {
for (size_t i = 0; i < circle.size() - 1; ++i) {
GraphvizOp* prev = ops[circle[i]];
GraphvizOp* next = ops[circle[i + 1]];
std::string prev_op = "op_" + std::to_string(prev->Id());
std::string next_op = "op_" + std::to_string(next->Id());
prev->AddCustomEdge([&]() -> std::string {
return prev_op + "->" + next_op + kCircleEdge;
});
}
}
}
return vars;
}
......
......@@ -31,6 +31,8 @@ class GraphvizNode {
GraphvizNode(ir::Node* n, const int& i) : node_(n), id_(i) {}
virtual ~GraphvizNode() = default;
// Returns the numeric id this graphviz node was constructed with.
int Id() const { return id_; }
protected:
ir::Node* node_;
int id_;
......
......@@ -19,6 +19,9 @@ REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
paddle::framework::SumOpMaker);
REGISTER_OPERATOR(split, paddle::framework::DummyOp,
paddle::framework::SplitOpMaker);
REGISTER_OPERATOR(assign, paddle::framework::DummyOp,
paddle::framework::AssignOpMaker,
paddle::framework::DummyVarTypeInference);
/*
a @ b
......@@ -54,6 +57,12 @@ inline static ProgramDesc FillProgramDesc() {
op->SetInput("X", {"d", "e"});
op->SetOutput("Out", {"d"});
}
{
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("assign");
op->SetInput("X", {"d"});
op->SetOutput("Out", {"d"});
}
return prog;
}
......@@ -74,6 +83,108 @@ TEST(SSAGraphPrinter, Normal) {
printer->Print(*graph, *fout);
}
using ir::Graph;
using ir::Node;
// Adds the smallest possible loop to `g`: a single operation "op1" and a
// single variable "var1" that are both the input and the output of each
// other (op1 -> var1 -> op1).
void BuildCircleGraph(Graph* g) {
  ir::Node* op = g->CreateEmptyNode("op1", Node::Type::kOperation);
  ir::Node* var = g->CreateEmptyNode("var1", Node::Type::kVariable);

  // op1 -> var1
  op->outputs.push_back(var);
  var->inputs.push_back(op);

  // var1 -> op1
  var->outputs.push_back(op);
  op->inputs.push_back(var);
}
// Adds a two-op loop to `g`: op1 -> var1 -> op2 -> var2 -> op1.
void BuildCircleGraph2(Graph* g) {
  ir::Node* first_op = g->CreateEmptyNode("op1", Node::Type::kOperation);
  ir::Node* second_op = g->CreateEmptyNode("op2", Node::Type::kOperation);
  ir::Node* first_var = g->CreateEmptyNode("var1", Node::Type::kVariable);
  ir::Node* second_var = g->CreateEmptyNode("var2", Node::Type::kVariable);

  // op1 -> var1
  first_op->outputs.push_back(first_var);
  first_var->inputs.push_back(first_op);

  // var1 -> op2
  first_var->outputs.push_back(second_op);
  second_op->inputs.push_back(first_var);

  // op2 -> var2
  second_op->outputs.push_back(second_var);
  second_var->inputs.push_back(second_op);

  // var2 -> op1 (closes the loop)
  second_var->outputs.push_back(first_op);
  first_op->inputs.push_back(second_var);
}
// Adds five operations (op1..op5) and four variables (var1..var4) to `g` and
// wires them as listed edge by edge below.
// NOTE(review): the last edge (var3 -> op1) together with
// op1 -> var1 -> op2 -> var3 closes a loop, despite the function's name --
// confirm whether that is intended.
void BuildNoCircleGraph(Graph* g) {
  ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
  ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
  ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation);
  ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation);
  ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation);
  ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
  ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
  ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable);
  ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable);

  // Records that `op` writes `var`.
  auto produce = [](ir::Node* op, ir::Node* var) {
    op->outputs.push_back(var);
    var->inputs.push_back(op);
  };
  // Records that `op` reads `var`.
  auto consume = [](ir::Node* var, ir::Node* op) {
    var->outputs.push_back(op);
    op->inputs.push_back(var);
  };

  // o1->v1->o2
  produce(o1, v1);
  consume(v1, o2);

  // o2->v2->o3
  // o2->v2->o4
  produce(o2, v2);
  consume(v2, o3);
  consume(v2, o4);

  // o2->v3->o5
  produce(o2, v3);
  consume(v3, o5);

  // o3->v4->o5
  produce(o3, v4);
  consume(v4, o5);

  // o2->v3->o1
  consume(v3, o1);
}
// Builds the one-op loop graph, checks that it is reported as containing a
// circle, and prints it to a text file through SSAGraphPrinterImpl.
TEST(SSAGraphPrinter, SimpleCircle) {
  ProgramDesc prog;
  Graph graph(prog);
  BuildCircleGraph(&graph);
  ASSERT_TRUE(HasCircle(graph));

  // The printer reads its node storage from this graph attribute.
  graph.Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
  std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);

  // redirect debug graph to a file.
  constexpr char graph_path[] = "graph_print_pass_simple_circle.txt";
  std::ofstream out_file(graph_path);
  PADDLE_ENFORCE(out_file.good());
  printer->Print(graph, out_file);
}
// Builds the two-op loop graph, checks that it is reported as containing a
// circle, and prints it to a text file through SSAGraphPrinterImpl.
TEST(SSAGraphPrinter, ComplexCircle) {
  ProgramDesc prog;
  Graph graph(prog);
  BuildCircleGraph2(&graph);
  ASSERT_TRUE(HasCircle(graph));

  // The printer reads its node storage from this graph attribute.
  graph.Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
  std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);

  // redirect debug graph to a file.
  constexpr char graph_path[] = "graph_print_pass_complex_circle.txt";
  std::ofstream out_file(graph_path);
  PADDLE_ENFORCE(out_file.good());
  printer->Print(graph, out_file);
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -23,6 +23,7 @@
#include <vector>
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_info.h"
// NOTE(dzhwinter): inplace means one op output variable reuse the input space.
......@@ -39,16 +40,20 @@
// auto* out_ptr = out->mutable_data<T>(ctx.GetPlace());
// out_ptr[0] = 0; // input contect is overwrited.
// For backward compacity. if enable_inplace_whitelist is turn on.
// NOTE(dzhwinter):
// Kept only for backward compatibility and stability. If
// enable_inplace_whitelist is turned on,
// only the ops in the whitelist will use the inplace strategy.
// If not, every op will be inplaced if it is registered with InplaceClass.
DEFINE_bool(
enable_inplace_whitelist, true,
enable_inplace_whitelist, false,
"If this option turns on, only these op in whitelist can be inplaced."
"If it turns off, all of the running op can be candidate of inplaced op."
"Such as scale, elementwise_add"
"By default, it's turned on");
DECLARE_string(memory_optimize_debug);
// clang-format off
const std::string kInplacedOpWhiteList[] = { // NOLINT
"sigmoid",
......@@ -77,63 +82,6 @@ namespace paddle {
namespace framework {
namespace details {
// Builds a short text tag for `var`: the node kind ("kControlDepVarName",
// "kOperation" or "kVariable") followed by a space and, for operations and
// variables, the node name.
// For operations and variables it also enforces that the attached desc is
// non-null and agrees with the node name. Any other node kind is rejected
// with PADDLE_THROW.
static inline std::string NodeDebugString(ir::Node* var) {
  std::string text;
  if (var->IsCtrlVar()) {
    text = std::string("kControlDepVarName") + " ";
  } else if (var->IsOp()) {
    text = std::string("kOperation") + " " + var->Name();
    PADDLE_ENFORCE(var->Op() != nullptr && var->Op()->Type() == var->Name());
  } else if (var->IsVar()) {
    text = std::string("kVariable") + " " + var->Name();
    PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name());
  } else {
    PADDLE_THROW("Unknown node type.");
  }
  return text;
}
// Streams the name of every data variable in `nodes` into `*os`, each name
// followed by a single space. Control-dependency variables and non-variable
// nodes are skipped. Enforces that each reported variable carries a desc
// whose name matches the node name.
template <typename NodeList>
static inline void AppendDataVarNames(const NodeList& nodes,
                                      std::ostream* os) {
  for (auto* node : nodes) {
    if (node->IsVar() && !node->IsCtrlVar()) {
      PADDLE_ENFORCE(
          node->Var() != nullptr && node->Var()->Name() == node->Name(),
          "unmatched desc and var");
      *os << node->Name() << " ";
    }
  }
}

// Returns a one-line description of an operation in the form
//   "<op name> : Input <in vars...> Output <out vars...>"
// `var` may be the op node itself or a variable node; for a variable, the
// described op is the variable's first input (`inputs.at(0)`, which throws
// if the variable has no inputs).
// The returned text is the same as before; the ad-hoc VLOG statements
// (pointer dumps and a check for one hard-coded variable name) were debugging
// leftovers and are no longer emitted.
static inline std::string OpDebugString(ir::Node* var) {
  ir::Node* op = var->IsVar() ? var->inputs.at(0) : var;
  std::ostringstream os;
  os << op->Name() << " : ";
  os << "Input ";
  AppendDataVarNames(op->inputs, &os);
  os << "Output ";
  AppendDataVarNames(op->outputs, &os);
  return os.str();
}
static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) {
// if next op is inplaced, then return the output var
// otherwise return nullptr
......@@ -218,6 +166,10 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
InitSSAGraphNodes();
std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
constexpr char graph_path1[] = "ir_graph_before_inplaced.txt";
std::unique_ptr<std::ostream> fout1(new std::ofstream(graph_path1));
PADDLE_ENFORCE(fout1->good());
printer->Print(*graph, *fout1);
for (auto* op : view_.AllOps()) {
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
......@@ -230,9 +182,6 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
PADDLE_ENFORCE(fout->good());
printer->Print(*graph, *fout);
// for(auto* op : view_.AllOps()) {
// VLOG(3) << OpDebugString(op);
// }
return graph;
}
......@@ -250,6 +199,92 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
}
}
// First stage of the two-stage inplace rewrite: tentatively rewires the graph
// so that, from op index `idx` onward, every node named `var` that an op
// reads or writes is replaced by a freshly created node named `cache_var`.
//
// The original nodes are NOT removed and `var_nodes_` is NOT updated here;
// both happen in CommitModify(). Registering the new nodes here as well would
// list them twice after a commit and leave dangling pointers behind after
// WithDrawModify() removes them from the graph.
//
// Returns a map from each replaced node to the cache node(s) created for it,
// to be handed to either CommitModify() or WithDrawModify().
const SSANodeVector InplacePass::TryInplaceModifyVar(
    const std::string& var, const std::string& cache_var, const size_t& idx,
    ir::Graph* graph) const {
  PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
                 var_nodes_[var].at(0)->Var() != nullptr);
  // The cache nodes share one desc: a copy of `var`'s desc renamed to
  // `cache_var`.
  std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
  var_desc->SetName(cache_var);

  SSANodeVector swap_nodes;

  // AllOps() returns the op list by value, so fetch it once instead of
  // copying it on every loop test and every element access.
  const auto& all_ops = view_.AllOps();
  for (size_t i = idx; i < all_ops.size(); ++i) {
    auto* op = all_ops[i];

    // redirect the input to the latest version of cache_var
    for (auto* node : op->inputs) {
      if (node->Name() == var) {
        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());

        // swap node to cache_node
        cache_node->outputs.insert(cache_node->outputs.end(),
                                   node->outputs.begin(), node->outputs.end());
        PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
        auto* prev_op = node->inputs[0];
        std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
                     cache_node);
        cache_node->inputs.emplace_back(prev_op);
        for (auto* next_op : node->outputs) {
          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
                       cache_node);
        }

        swap_nodes[node].emplace_back(cache_node);
      }
    }

    // an output named `var` always gets a new cache node of its own
    for (auto* node : op->outputs) {
      if (node->Name() == var) {
        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());

        // swap node to cache node
        cache_node->outputs.insert(cache_node->outputs.end(),
                                   node->outputs.begin(), node->outputs.end());
        cache_node->inputs.emplace_back(op);
        std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
        for (auto* next_op : node->outputs) {
          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
                       cache_node);
        }

        swap_nodes[node].emplace_back(cache_node);
      }
    }
  }

  return swap_nodes;
}
// Second stage (success path) of the inplace rewrite. For every entry of
// `swap_nodes`, the replacement nodes are recorded in `var_nodes_` under
// their own names, the replaced node is dropped from the `var_nodes_` list
// of its name, and the replaced node is removed from `graph`.
void InplacePass::CommitModify(const SSANodeVector& swap_nodes,
                               ir::Graph* graph) const {
  for (auto& entry : swap_nodes) {
    auto* old_node = entry.first;
    const std::string old_name = old_node->Name();

    for (auto* replacement : entry.second) {
      var_nodes_[replacement->Name()].emplace_back(replacement);
    }

    // erase-remove: drop every occurrence of the replaced node.
    auto& tracked = var_nodes_.at(old_name);
    auto new_end = std::remove(tracked.begin(), tracked.end(), old_node);
    tracked.erase(new_end, tracked.end());

    graph->RemoveNode(old_node);
  }
}
void InplacePass::WithDrawModify(const SSANodeVector& nodes,
ir::Graph* graph) const {
for (auto& pair : nodes) {
auto* node = pair.first;
const std::string var = node->Name();
for (auto* cache_node : pair.second) {
const std::string cache_var = cache_node->Name();
auto* prev_op = node->inputs[0];
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node,
node);
for (auto* next_op : node->outputs) {
std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node,
node);
}
graph->RemoveNode(cache_node);
}
}
}
void InplacePass::InplaceModifyVar(const std::string& var,
const std::string& cache_var,
const size_t& idx, ir::Graph* graph) const {
......@@ -318,7 +353,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
ir::Graph* graph) const {
PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
"op_desc is nullptr");
// 3 pre-requirments need to meet if the op want to inplaced.
// 4 pre-requirments need to meet if the op want to inplaced.
// 1. infer_inplace_ is registered.
auto* op_desc = op->Op();
auto& infer_inplace =
......@@ -333,36 +368,68 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
auto& all_ops = view_.AllOps();
auto cursor = std::find(all_ops.begin(), all_ops.end(), op);
size_t idx = std::distance(all_ops.begin(), cursor);
VLOG(3) << op->Name() << idx;
for (auto& pair : in_to_outs) {
auto& in_var_name = pair.first;
auto& out_var_name = pair.second;
auto* in_node = view_.GetNodeByName(in_var_name, op->inputs);
auto* out_node = view_.GetNodeByName(out_var_name, op->outputs);
// 2. there is no external pending op on the input node
if (view_.PendingOpsOnVar(in_node).size() > 1) {
VLOG(3) << string::Sprintf(
"!!! %s input has external dependency, can not inplaced, %s => %s "
"skiped",
op->Name(), out_var_name, in_var_name);
VLOG(4) << string::Sprintf(
"Skiped pair %s => %s. %s input has external dependency."
"inplace such pair will overwrite the memory.",
out_var_name, in_var_name, op->Name());
continue;
}
// 3. if output reuse input inplaced, the dependency group is not changed.
// For detail, check
// the function description in "OutConnectInputByCtrlVar"
if (view_.OutConnectInputByCtrlVar(in_node, out_node)) {
VLOG(3) << string::Sprintf(
"!!! %s input output connect by ctrl var, cannot inplaced, %s => %s "
"skiped",
op->Name(), out_var_name, in_var_name);
VLOG(4) << string::Sprintf(
"Skiped pair %s => %s. %s input and output connect by ctrl var."
"inplace such pair will generate a circle.",
out_var_name, in_var_name, op->Name());
continue;
}
VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
out_var_name, in_var_name);
// VLOG(3) << "Out " << OpDebugString(op);
InplaceModifyDesc(out_var_name, in_var_name, idx);
InplaceModifyVar(out_var_name, in_var_name, idx, graph);
// 4. if output has been memory optimize by python(fluid.memory_optmize()).
// this candidate can not be inplaced. Will be deprecated in the future.
if (view_.ReusedInPythonMemOpt(out_node->Name())) {
VLOG(4) << string::Sprintf(
"Skiped %s => %s reused previous memory block in python memory "
"optmize,"
"it inplace may generate a circle",
out_var_name, in_var_name, op->Name());
continue;
}
// Debug Interface. Which would be skipped by the pass.
if (out_node->Name() == FLAGS_memory_optimize_debug) {
VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug="
<< out_node->Name();
continue;
}
auto swap_nodes =
TryInplaceModifyVar(out_var_name, in_var_name, idx, graph);
// NOTE(dzhwinter):
// two stage commit of inplaced op. If add such node generate a circle,
// then withdraw the changes. Otherwise, safely add the node.
if (!ir::HasCircle(*graph)) {
VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
out_var_name, in_var_name);
CommitModify(swap_nodes, graph);
InplaceModifyDesc(out_var_name, in_var_name, idx);
} else {
VLOG(3) << string::Sprintf(
"Skiped pair %s => %s, inplace will generate a circle. withdraw %s",
out_var_name, in_var_name, op->Name());
WithDrawModify(swap_nodes, graph);
}
}
}
......@@ -406,7 +473,28 @@ std::vector<ir::Node*> GraphView::PendingOpsOnVar(ir::Node* node) {
return pending_ops;
}
void GraphView::Build(ir::Graph* g) { ops_ = SortOpLikeDescOrder(*g); }
// Initializes the view from graph `g`.
//  1. `ops_` is set to the ops of `g` as ordered by SortOpLikeDescOrder().
//  2. `dup_nodes_` collects every output variable name that is seen more
//     than once across the outputs of all non-variable nodes. Control
//     variables and outputs without a var desc are ignored. These names are
//     later reported by ReusedInPythonMemOpt().
void GraphView::Build(ir::Graph* g) {
  ops_ = SortOpLikeDescOrder(*g);

  // Names already encountered as some op's output.
  std::unordered_set<std::string> seen_outputs;
  for (auto& node : g->Nodes()) {
    if (!node->IsVar()) {
      for (auto& out : node->outputs) {
        const bool is_data_var = !out->IsCtrlVar() && out->Var() != nullptr;
        // insert() reports through `.second` whether the name was new; a
        // repeated name is recorded as a duplicate.
        if (is_data_var && !seen_outputs.insert(out->Name()).second) {
          dup_nodes_.emplace(out->Name());
        }
      }
    }
  }
}
// Returns the op list stored in `ops_`. The list is returned by value, so
// every call hands back a fresh copy of the vector.
const std::vector<ir::Node*> GraphView::AllOps() { return ops_; }
......@@ -452,6 +540,10 @@ bool GraphView::OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var) {
return ConnectByCtrlVar(in_var_set, out_var_set);
}
// Returns true when `var` is one of the names recorded in `dup_nodes_`.
bool GraphView::ReusedInPythonMemOpt(const std::string& var) const {
  const bool is_duplicated = dup_nodes_.find(var) != dup_nodes_.end();
  return is_duplicated;
}
} // namespace details
} // namespace framework
} // namespace paddle
......
......@@ -2,7 +2,7 @@
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
......@@ -15,6 +15,7 @@
#pragma once
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
......@@ -40,10 +41,20 @@ class GraphView {
bool OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var);
// Will be deprecated in the future.
// NOTE(dzhwinter) : Python memory optimize reuses
// memory based on variable names, so different op outputs may
// have the same variable name. Enabling inplace on such a node
// will generate a circle in the ssa graph.
bool ReusedInPythonMemOpt(const std::string& var) const;
private:
std::vector<ir::Node*> ops_;
std::unordered_set<std::string> dup_nodes_; // mem opt affect nodes
std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
};
// Maps a var node to the list of nodes created to stand in for it.
using SSANodeVector = std::unordered_map<ir::Node*, std::vector<ir::Node*>>;
class InplacePass : public ir::Pass {
public:
InplacePass();
......@@ -58,6 +69,15 @@ class InplacePass : public ir::Pass {
void InplaceModifyVar(const std::string& in_var, const std::string& out_var,
const size_t& idx, ir::Graph* graph) const;
const SSANodeVector TryInplaceModifyVar(const std::string& var,
const std::string& cache_var,
const size_t& idx,
ir::Graph* graph) const;
void CommitModify(const SSANodeVector&, ir::Graph* graph) const;
void WithDrawModify(const SSANodeVector& nodes, ir::Graph* graph) const;
void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
const size_t& idx) const;
......
......@@ -52,16 +52,29 @@ bool HasCircleHelper(
ir::Node *node,
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
std::unordered_set<ir::Node *> *visited,
std::unordered_set<ir::Node *> *in_trace) {
std::unordered_set<ir::Node *> *in_trace,
std::vector<std::vector<ir::Node *>> *circles) {
if (visited->find(node) == visited->end()) {
visited->insert(node);
in_trace->insert(node);
for (ir::Node *in : adj_list.at(node)) {
if (visited->find(in) == visited->end() &&
HasCircleHelper(in, adj_list, visited, in_trace)) {
HasCircleHelper(in, adj_list, visited, in_trace, circles)) {
return true;
} else if (in_trace->find(in) != in_trace->end()) {
if (circles != nullptr) {
std::vector<ir::Node *> circle;
circle.emplace_back(in);
ir::Node *p = in;
for (auto &adj : adj_list.at(p)) {
if (in_trace->count(adj)) {
circle.emplace_back(adj);
p = adj;
}
}
circles->emplace_back(circle);
}
return true;
}
}
......@@ -71,11 +84,12 @@ bool HasCircleHelper(
}
bool HasCircleInternal(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
std::vector<std::vector<ir::Node *>> *circles) {
std::unordered_set<ir::Node *> visited;
std::unordered_set<ir::Node *> in_trace;
for (auto &adj : adj_list) {
if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) {
if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace, circles)) {
return true;
}
}
......@@ -84,13 +98,18 @@ bool HasCircleInternal(
} // namespace
bool HasCircle(const Graph &graph) {
return HasCircleInternal(BuildOperationAdjList(graph));
return HasCircleInternal(BuildOperationAdjList(graph), nullptr);
}
// Runs the circle search over the operation adjacency list of `graph` and
// passes `circles` through to it as the place to record what it finds.
// Returns the result of that search.
bool FindCircleSubGraph(const Graph &graph,
                        std::vector<std::vector<ir::Node *>> *circles) {
  const auto op_adj_list = BuildOperationAdjList(graph);
  return HasCircleInternal(op_adj_list, circles);
}
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildOperationAdjList(graph);
PADDLE_ENFORCE(!HasCircleInternal(adj_list));
PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr));
std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret;
for (auto adj : adj_list) {
......
......@@ -28,6 +28,11 @@ namespace ir {
// Test if the graph contains circle.
bool HasCircle(const Graph &graph);
// Detect a circle for debugging. The search stops at the first circle it
// finds; the nodes collected for that circle are appended to `circles`.
bool FindCircleSubGraph(const Graph &graph,
std::vector<std::vector<ir::Node *>> *circles);
size_t GraphNum(const Graph &graph);
// Topology Sort the operations in the graph from inputs to outputs.
......
......@@ -195,6 +195,17 @@ void BuildTwoGraphs(Graph* g) {
// v4->outputs.push_back(o5);
}
// Checks that FindCircleSubGraph() reports the loop built by
// BuildCircleGraph() and records exactly one circle for it.
TEST(GraphHelperTest, Circles) {
  ProgramDesc prog;
  Graph g(prog);
  BuildCircleGraph(&g);

  std::vector<std::vector<ir::Node*>> circles;
  ASSERT_TRUE(FindCircleSubGraph(g, &circles));
  // ASSERT_EQ takes the two values to compare as separate arguments.
  ASSERT_EQ(circles.size(), 1UL);
}
TEST(GraphHelperTest, GraphNum) {
ProgramDesc prog;
......
......@@ -32,7 +32,7 @@ class TestParallelExecutorBase(unittest.TestCase):
def check_network_convergence(self,
method,
use_cuda=True,
memory_opt=True,
memory_opt=False,
iter=50,
batch_size=None,
allow_op_delay=False,
......@@ -67,8 +67,6 @@ class TestParallelExecutorBase(unittest.TestCase):
if memory_opt:
fluid.memory_optimize(main)
with open("program_model.txt", "w") as f:
f.write(str(main))
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
......@@ -82,9 +80,10 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
build_strategy.memory_optimize = use_ir_memory_optimize
build_strategy.enable_inplace = enable_inplace
# Python memory optimization conflicts with the inplace pass.
# Using ir graph memory optimization after the inplace pass is the correct way.
build_strategy.enable_inplace = False if memory_opt else enable_inplace
build_strategy.enable_sequential_execution = enable_sequential_execution
build_strategy.debug_graphviz_path = "debug_ir_graph_"
if use_cuda and core.is_compiled_with_cuda():
build_strategy.remove_unnecessary_lock = True
......
......@@ -46,7 +46,10 @@ class TestIrInplace(TestParallelExecutorBase):
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
def _fc_with_batchnorm(self, ir_memory_optimize, enable_inplace):
def _fc_with_batchnorm(self,
ir_memory_optimize,
enable_inplace,
memory_opt=False):
np.random.seed(5)
img = np.random.random(size=[32, 784]).astype(np.float32)
label = np.ones(shape=[32, 1], dtype='int64')
......@@ -55,7 +58,7 @@ class TestIrInplace(TestParallelExecutorBase):
feed_dict={"image": img,
"label": label},
use_cuda=True,
memory_opt=False, # inplace is conflict with memory opt
memory_opt=memory_opt,
use_ir_memory_optimize=ir_memory_optimize,
enable_inplace=enable_inplace)
......@@ -67,3 +70,10 @@ class TestIrInplace(TestParallelExecutorBase):
self.assertAlmostEqual(loss00, loss10, delta=delta)
self.assertAlmostEqual(loss00, loss01, delta=delta)
self.assertAlmostEqual(loss00, loss11, delta=delta)
def test_fc_with_batchnorm_memory_opt(self, delta=1e-3):
loss00 = self._fc_with_batchnorm(False, True, False)
loss10 = self._fc_with_batchnorm(False, True, True)
loss10 = self._fc_with_batchnorm(True, True, True)
self.assertAlmostEqual(loss00, loss10, delta=delta)
self.assertAlmostEqual(loss00, loss01, delta=delta)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册