提交 b6f71b8a 编写于 作者: Z Zhen Wang 提交者: Yan Chunwei

Add graph fusion(pattern matcher) in the lite framework. (#17839)

* add the pattern matcher logic. test=develop

* update op_info usage. test=develop
上级 ae60589f
......@@ -48,4 +48,7 @@ if (LITE_WITH_CUDA)
endif()
cc_test(test_variable_place_infrence_pass SRCS variable_place_inference_pass_test.cc DEPS
${test_variable_place_infrence_pass_DEPS})
cc_library(pattern_matcher_lite SRCS pattern_matcher.cc DEPS mir_node mir_ssa_graph op_lite)
cc_test(test_pattern_matcher_lite SRCS pattern_matcher_tester.cc DEPS pattern_matcher_lite)
......@@ -93,6 +93,16 @@ class Node {
return x;
}
Stmt* stmt() const {
CHECK(IsStmt());
return stmt_.get();
}
Arg* arg() const {
CHECK(IsArg());
return arg_.get();
}
// Set roles.
Arg& AsArg() {
if (role_ != Role::kUnk) {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <array>
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/lite/core/mir/pattern_matcher.h"
#include "paddle/fluid/lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace mir {
size_t PMPattern::id_ = 0UL;
PMNode *PMPattern::NewNode(const std::string &name) {
if (!name.empty()) {
CHECK_EQ(node_map_.count(name), 0UL)
<< "PMNode's name should be unique, get duplicate " << name;
}
nodes_.emplace_back(new PMNode(this, name));
auto *cur = nodes_.back().get();
node_map_[name] = cur;
return cur;
}
PMNode *PMPattern::NewNode(PMNode::teller_t &&teller, const std::string &name) {
if (!name.empty()) {
CHECK_EQ(node_map_.count(name), 0UL)
<< "PMNode's name should be unique, get duplicate " << name;
}
nodes_.emplace_back(new PMNode(std::move(teller), this, name));
auto *cur = nodes_.back().get();
node_map_[name] = cur;
return cur;
}
PMNode *PMPattern::RetrieveNode(const std::string &id) const {
auto it = node_map_.find(id);
if (it == node_map_.end()) {
return nullptr;
}
return it->second;
}
void PMPattern::AddEdge(PMNode *a, PMNode *b) {
CHECK(a);
CHECK(b);
CHECK_NE(a, b) << "Can't connect to the same nodes.";
edges_.emplace_back(a, b);
}
void PatternMatcher::operator()(SSAGraph *graph,
PatternMatcher::handle_t handler) {
if (!MarkPMNodesInGraph(graph)) {
return;
}
auto subgraphs = DetectPatterns();
UniquePatterns(&subgraphs);
RemoveOverlappedMatch(&subgraphs);
ValidateByNodeRole(&subgraphs);
if (subgraphs.empty()) return;
LOG(INFO) << "--- detected " << subgraphs.size() << " subgraphs.";
int id = 0;
for (auto &g : subgraphs) {
VLOG(3) << "optimizing #" << id++ << " subgraph";
handler(g, graph);
}
}
bool PatternMatcher::MarkPMNodesInGraph(SSAGraph *graph) {
VLOG(3) << "mark pmnodes in graph";
if (graph->nodes().empty()) return false;
for (auto &node : graph->mutable_nodes()) {
for (const auto &pmnode : pattern_.nodes()) {
if (pmnode->Tell(&node)) {
pmnodes2nodes_[pmnode.get()].insert(&node);
}
}
}
// Check to early stop if some PMNode can't find matched Node.
for (auto &pmnode : pattern_.nodes()) {
if (!pmnodes2nodes_.count(pmnode.get())) {
VLOG(4) << pmnode->name() << " can't find matched Node, early stop";
// return false;
}
}
VLOG(3) << pmnodes2nodes_.size() << " nodes marked";
return !pmnodes2nodes_.empty();
}
// The intermediate Nodes can only link to the nodes inside the pattern, or this
// subgraph will be droped.
void PatternMatcher::ValidateByNodeRole(
std::vector<PatternMatcher::subgraph_t> *subgraphs) {
std::vector<PatternMatcher::subgraph_t> result;
subgraphs->erase(
std::remove_if(subgraphs->begin(), subgraphs->end(),
[](const PatternMatcher::subgraph_t &subgraph) -> bool {
// Collect the inlinks and outlinks.
std::unordered_set<Node *> ios;
for (auto &item : subgraph) {
if (!item.first->IsIntermediate()) {
ios.insert(item.second);
}
}
for (auto &item : subgraph) {
if (item.first->IsIntermediate()) {
for (auto *x : item.second->inlinks) {
if (!ios.count(x)) {
return true;
}
}
for (auto *x : item.second->outlinks) {
if (!ios.count(x)) {
return true;
}
}
}
}
return false;
}),
subgraphs->end());
}
struct HitGroup {
std::unordered_map<PMNode *, Node *> roles;
bool Match(Node *node, PMNode *pat) {
if (nodes_.count(node)) {
if (roles.count(pat) && roles[pat] == node) return true;
return false;
} else {
if (roles.count(pat) && roles[pat] != node) return false;
return true;
}
}
void Register(Node *node, PMNode *pat) {
roles[pat] = node;
nodes_.insert(node);
}
private:
std::unordered_set<Node *> nodes_;
};
// Tell whether Node a links to b.
bool IsNodesLink(Node *a, Node *b) {
for (auto *node : a->outlinks) {
if (b == node) {
return true;
}
}
return false;
}
std::vector<PatternMatcher::subgraph_t> PatternMatcher::DetectPatterns() {
// Init empty subgraphs.
std::vector<PatternMatcher::subgraph_t> result;
std::vector<HitGroup> init_groups;
std::array<std::vector<HitGroup>, 2> bi_records;
auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
: pattern_.edges().front().first;
if (!pmnodes2nodes_.count(first_pnode)) return result;
for (auto *node : pmnodes2nodes_[first_pnode]) {
HitGroup group;
group.roles[first_pnode] = node;
init_groups.emplace_back(group);
}
int step = 0;
bi_records[0] = std::move(init_groups);
// Extend a PMNode to subgraphs by deducing the connection relations defined
// in edges of PMNodes.
for (const auto &edge : pattern_.edges()) {
VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name();
// TODO(Superjomn) Fix bug here, the groups might be duplicate here.
// Each role has two PMNodes, which indicates two roles.
// Detect two Nodes that can match these two roles and they are connected.
auto &pre_groups = bi_records[step % 2];
auto &cur_groups = bi_records[1 - (step++ % 2)];
cur_groups.clear();
if (pre_groups.empty()) break;
// source -> target
for (Node *source : pmnodes2nodes_[edge.first]) {
for (Node *target : pmnodes2nodes_[edge.second]) {
// TODO(Superjomn) add some prune strategies.
for (const auto &group : pre_groups) {
if (IsNodesLink(source, target)) {
HitGroup new_group = group;
bool flag = new_group.Match(source, edge.first) &&
new_group.Match(target, edge.second);
if (flag) {
new_group.Register(source, edge.first);
new_group.Register(target, edge.second);
cur_groups.push_back(new_group);
// TODO(Superjomn) need to unique
}
}
}
}
}
VLOG(3) << "step " << step << " get records: " << cur_groups.size();
}
for (auto &group : bi_records[step % 2]) {
PatternMatcher::subgraph_t subgraph;
for (auto &role : group.roles) {
subgraph.emplace(role.first, role.second);
}
result.emplace_back(subgraph);
}
return result;
}
struct GraphItemLessThan {
bool operator()(const std::pair<PMNode *, Node *> &a,
const std::pair<PMNode *, Node *> &b) {
if (a.first != b.first) {
return a.first < b.first;
} else {
return a.second < b.second;
}
}
};
// TODO(Superjomn) enhance the function as it marks unique unique as duplicates
// see https://github.com/PaddlePaddle/Paddle/issues/13550
void PatternMatcher::UniquePatterns(
std::vector<PatternMatcher::subgraph_t> *subgraphs) {
if (subgraphs->empty()) return;
std::vector<PatternMatcher::subgraph_t> result;
std::unordered_set<size_t> set;
std::hash<std::string> hasher;
for (auto &g : *subgraphs) {
// Sort the items in the sub-graph, and transform to a string key.
std::vector<std::pair<PMNode *, Node *>> sorted_keys(g.begin(), g.end());
std::sort(sorted_keys.begin(), sorted_keys.end(), GraphItemLessThan());
std::stringstream ss;
for (auto &item : sorted_keys) {
ss << item.first << ":" << item.second;
}
auto key = hasher(ss.str());
if (!set.count(key)) {
result.emplace_back(g);
set.insert(key);
}
}
*subgraphs = result;
}
void PatternMatcher::RemoveOverlappedMatch(std::vector<subgraph_t> *subgraphs) {
std::vector<subgraph_t> result;
std::unordered_set<Node *> node_set;
for (const auto &subgraph : *subgraphs) {
bool valid = true;
for (auto &item : subgraph) {
if (item.first->IsIntermediate() && node_set.count(item.second)) {
valid = false;
break;
}
}
if (valid) {
for (auto &item : subgraph) {
node_set.insert(item.second);
}
result.push_back(subgraph);
}
}
*subgraphs = result;
}
std::string PMPattern::DotString() const {
using inference::analysis::Dot;
Dot dot;
int id = 0;
// Create Nodes
std::unordered_map<PMNode *, std::string> node2dot;
for (const auto &node : nodes()) {
std::string node_id = "Node" + std::to_string(id++);
dot.AddNode(node_id, {}, node->name());
node2dot[node.get()] = node_id;
}
// Create Edges
for (const auto &edge : edges()) {
if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) {
LOG(ERROR) << "no node " << edge.first << " " << edge.second;
continue;
}
auto &src = node2dot.at(edge.first);
auto &trg = node2dot.at(edge.second);
dot.AddEdge(src, trg, {});
}
return dot.Build();
}
PMNode &PMNode::LinksTo(const std::vector<PMNode *> &others) {
// extend outlinks.
for (PMNode *x : others) {
pattern_->AddEdge(this, x);
}
return *this;
}
PMNode &PMNode::LinksFrom(const std::vector<PMNode *> &others) {
// extend outlinks.
for (PMNode *x : others) {
pattern_->AddEdge(x, this);
}
return *this;
}
PMNode *PMNode::assert_is_op() {
asserts_.emplace_back([](const Node *x) { return x && x->IsStmt(); });
return this;
}
PMNode *PMNode::assert_is_op(const std::string &op_type) {
asserts_.emplace_back([op_type](const Node *x) {
if (x && x->IsStmt()) {
auto *op_info = x->stmt()->op_info();
return op_info->Type() == op_type;
} else {
return false;
}
});
return this;
}
PMNode *PMNode::assert_is_var() {
asserts_.emplace_back([](const Node *x) { return x && x->IsArg(); });
return this;
}
PMNode *PMNode::assert_var_not_persistable() {
assert_is_var();
asserts_.emplace_back([](const Node *x) { return !x->arg()->is_weight; });
return this;
}
PMNode *PMNode::assert_is_persistable_var() {
assert_is_var();
asserts_.emplace_back([=](const Node *x) { return x->arg()->is_weight; });
return this;
}
PMNode *PMNode::assert_is_op_output(const std::string &op_type) {
assert_is_var();
asserts_.emplace_back([=](const Node *x) {
for (auto *op : x->inlinks) {
if (op && op->IsStmt()) {
auto *op_info = x->stmt()->op_info();
if (op_info->Type() == op_type) return true;
}
}
return false;
});
return this;
}
PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
assert_is_var();
asserts_.emplace_back([=](const Node *x) {
for (auto *op : x->outlinks) {
if (op && op->IsStmt()) {
auto *op_info = op->stmt()->op_info();
if (op_info->Type() == op_type) {
return true;
}
}
}
return false;
});
return this;
}
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_TESTING
#include <gtest/gtest_prod.h>
#endif
#include <glog/logging.h>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
namespace paddle {
namespace lite {
namespace mir {
class PMPattern;
// Some basic terminologies:
// - PMPattern: a pattern defined as a data flow graph.
// - PMNode: the node in the pattern, each PMNode represents an `mir::Node`
// that meets some conditions defined in `PMNode.teller`.
// - A pattern is defined with PMNodes with edges.
// Pattern matcher node. This node helps to build a pattern.
struct PMNode {
// tell whether an mir::Node* is a candidation for a PMNode.
using teller_t = std::function<bool(const Node*)>;
enum class Type { kOp, kVar };
enum class Role {
kUnknown, // No role,
kInput, // an input and will be retained,
kOutput, // an output and will be retained,
kIntermediate // will be removed after handler.
};
// this link to others
PMNode& LinksTo(const std::vector<PMNode*>& others);
PMNode& LinksFrom(const std::vector<PMNode*>& others);
bool Tell(const Node* node) const {
if (teller_) return teller_(node);
for (auto& asrt : asserts_) {
if (!asrt(node)) return false;
}
return true;
}
bool IsOp() const { return type_ == Type::kOp; }
bool IsVar() const { return type_ == Type::kVar; }
const std::string& name() const { return name_; }
PMNode& operator=(const PMNode&) = delete;
PMNode(const PMNode&) = delete;
// Mark this node is an Input of a subgraph and will be retained.
PMNode* AsInput() {
role_ = Role::kInput;
return this;
}
// Mark this node is an Output of a subgraph and will be retained.
PMNode* AsOutput() {
role_ = Role::kOutput;
return this;
}
// Mark this node will be removed, so all the links should be inside a matched
// sub-graph.
PMNode* AsIntermediate() {
role_ = Role::kIntermediate;
return this;
}
bool IsIntermediate() const { return role_ == Role::kIntermediate; }
bool IsInput() const { return role_ == Role::kInput; }
bool IsOutput() const { return role_ == Role::kOutput; }
// Assertions, helper functions to simplify the pattern definition.
PMNode* assert_is_op();
PMNode* assert_is_op(const std::string& op_type);
PMNode* assert_is_var();
PMNode* assert_var_not_persistable();
PMNode* assert_is_persistable_var();
PMNode* assert_is_op_output(const std::string& op_type);
PMNode* assert_is_op_input(const std::string& op_type);
template <typename T>
PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
asserts_.emplace_back([=](Node* x) {
if (x && x->IsStmt()) {
auto* op_info = x->stmt()->op_info();
return op_info->HasAttr(attr_name) &&
op_info->GetAttr<T>(attr_name) == attr;
} else {
return false;
}
});
return this;
}
private:
PMNode(PMPattern* pattern, const std::string& name = "",
Type type = Type::kVar)
: pattern_(pattern), name_(name), type_(type) {}
PMNode(teller_t&& teller, PMPattern* pattern, const std::string& name = "",
Type type = Type::kVar)
: teller_(std::move(teller)),
pattern_(pattern),
name_(name),
type_(type) {
CHECK(teller_ != nullptr) << "invalid teller functer is set.";
}
PMNode(PMNode&& other) = default;
friend class PMPattern;
// Will removed latter.
teller_t teller_;
std::vector<teller_t> asserts_;
PMPattern* pattern_;
std::string name_;
Type type_;
Role role_{Role::kUnknown};
};
/*
* A pattern in a graph, which defined with PMNode and edges. Most graph
* patterns can be divided into PMNodes and link relations between them.
*
* For example, the FC fusion need to filter the MUL and ELEMENTWISE_ADD
* operators from the computation graph, the MUL's output should have only one
* consumer which is the ELEMENTWISE_ADD.
* This pattern can be defined as with the following pseudo codes
*
* // Create two operator PMNodes.
* MUL = PMPattern.NewNode().assert_is_op("mul");
* ELE = PMPattern.NewNode().assert_is_op("elementwise_add");
* // Create the variable PMNodes.
* MUL_out = PMPattern.NewNode().assert_is_op_output("mul") \
* .assert_is_op_input("elementwise_add") \
* .AsIntermediate();
* // Add relations.
* MUL->LinksTo({MUL_out});
* MUL_out->LinksTo({ELE});
*
* One can add more specific asserts for PMNodes or edges, both the Operator
* and Variable Nodes can be ruled in PMNode.assert_more(...).
*
* PMPattern can record the general patterns, such as the pattern represents
* - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place.
* - Ops whose inputs and outputs share the same variables
*/
class PMPattern {
public:
using edge_t = std::pair<PMNode*, PMNode*>;
void AddEdge(PMNode* a, PMNode* b);
PMNode* NewNode(PMNode::teller_t&& teller, const std::string& name = NewID());
PMNode* NewNode(const std::string& name = NewID());
PMNode* NewNode(const std::string& prefix, const std::string& name) {
return NewNode(prefix + "/" + name);
}
PMNode* RetrieveNode(const std::string& id) const;
const std::vector<std::unique_ptr<PMNode>>& nodes() const { return nodes_; }
const std::vector<edge_t>& edges() const { return edges_; }
std::string DotString() const;
private:
#ifdef PADDLE_WITH_TESTING
FRIEND_TEST(PMPattern, AddEdge);
FRIEND_TEST(PMPattern, NewNode);
#endif
static std::string NewID() { return "pmnode-" + std::to_string(id_++); }
std::vector<std::unique_ptr<PMNode>> nodes_;
std::vector<edge_t> edges_;
std::unordered_map<std::string, PMNode*> node_map_;
static size_t id_;
};
/*
* PatternMatcher helps to detect the specific patterns in the graph.
* Input a pattern, output a list of the matched subgraphs/nodes.
* This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.).
*
* The algorithm has three phases:
* 1. Mark the nodes that match the defined PMNodes in a PMPattern,
* 2. Extend a PMNode to subgraphs by deducing the connection relation defined
* in PAPattern(the edges),
* 3. Get the filtered subgraphs and treat them with a pre-defined handler.
*
* Usage:
* // Create a matcher
* PatternMatcher matcher;
* // Define the matcher's pattern, by adding PMNode and define the edges.
* auto* node0 = matcher.mutable_pattern().AddNode(...)
* auto* node1 = matcher.mutable_pattern().AddNode(...)
* node0->teller = some lambda.
* node1->teller = some lambda.
* matcher.mutable_pattern().AddEdge(node0, node1);
* // Create an handler, to define the behavior of treating the filtered
* // subgraphs that comply with the patterns.
* PatternMatcher::handle_t handler = some labmda
* // Execute the matcher.
* matcher(&graph, handler);
*/
class PatternMatcher {
public:
using subgraph_t = std::unordered_map<PMNode*, Node*>;
// Operate on the detected pattern.
using handle_t =
std::function<void(const subgraph_t& /*hitted pattern*/, SSAGraph*)>;
void operator()(SSAGraph* graph, handle_t handler);
const PMPattern& pattern() const { return pattern_; }
PMPattern* mutable_pattern() { return &pattern_; }
private:
// Mark the nodes that fits the pattern.
bool MarkPMNodesInGraph(SSAGraph* graph);
// Detect all the pattern and output the hit records.
std::vector<subgraph_t> DetectPatterns();
// Remove duplicate patterns.
void UniquePatterns(std::vector<subgraph_t>* subgraphs);
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
// The intermediate PMNodes will be removed, so can't shared by multiple
// patterns.
void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs);
// Validate whether the intermediate nodes are linked by external nodes.
void ValidateByNodeRole(std::vector<subgraph_t>* subgraphs);
#ifdef PADDLE_WITH_TESTING
FRIEND_TEST(PatternMatcher, MarkPMNodesInGraph);
FRIEND_TEST(PatternMatcher, DetectPatterns);
#endif
private:
using hit_rcd_t =
std::pair<Node* /*node in graph*/, PMNode* /*node in pattern*/>;
PMPattern pattern_;
std::unordered_map<const PMNode*, std::unordered_set<Node*>> pmnodes2nodes_;
};
// Some pre-defined patterns those can be reused in multiple passes.
// The related Fluid Layer or Op should be one pattern here for better re-usage
// across different fusion.
namespace patterns {
struct KeyCounter {
static KeyCounter& Instance() {
static KeyCounter x;
return x;
}
int IncCounter(const std::string& key) { return dic_[key]++; }
private:
std::unordered_map<std::string, size_t> dic_;
};
// Generate a unique PMNode's name with name_scope and id.
// The format is {name_scope}/{repr}/{id}/{name}
static std::string PMNodeName(const std::string& name_scope,
const std::string& repr, size_t id,
const std::string& name) {
std::stringstream ss;
ss << name_scope << "/" << repr << "/" << id << "/" << name;
return ss.str();
}
// Generate a unique PMNode's name.
// The format is {name_scope}/{repr}/{id}
static std::string PMNodeName(const std::string& name_scope,
const std::string& repr) {
std::stringstream ss;
ss << name_scope << "/" << repr << "/"
<< KeyCounter::Instance().IncCounter(repr);
return ss.str();
}
// Generate a unique key. It can be used for a universally unique temporary
// name.
// The format is {repr}/{id}
static std::string UniqueKey(const std::string& repr) {
std::stringstream ss;
ss << repr << "/" << KeyCounter::Instance().IncCounter(repr);
return ss.str();
}
// Declare a PMNode in a pattern, will create two methods:
// std::string xxx_repr(); return this PMNode's string id.
// PMNode* xxx_n(); return the corresponding PMNode.
#define PATTERN_DECL_NODE(name__) \
std::string name__##_repr() const { \
return PMNodeName(name_scope_, repr_, id_, #name__); \
} \
PMNode* name__##_n() const { return pattern->RetrieveNode(name__##_repr()); }
// Get an mir::Node* from the matched subgraph.
// var: variable.
// arg: the argument declared by PATTERN_DECL_NODE in a pattern definition.
// pat: the pattern object.
#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat) \
CHECK(subgraph.count(pat.arg##_n())) \
<< "Node not found for PMNode " pat.arg##_repr(); \
Node* var = subgraph.at(pat.arg##_n()); \
CHECK(var) << "node " << #arg << "not exists in the sub-graph"
// The base class of all the patterns.
struct PatternBase {
PatternBase(PMPattern* pattern, const std::string& name_scope,
const std::string& repr)
: pattern(pattern),
name_scope_(name_scope),
repr_(repr),
id_(KeyCounter::Instance().IncCounter(repr)) {}
PMPattern* pattern;
protected:
std::string name_scope_;
std::string repr_;
size_t id_;
};
} // namespace patterns
// Link two mir::Nodes from each other.
#define IR_NODE_LINK_TO(a, b) \
a->outlinks.push_back(b); \
b->inlinks.push_back(a);
// Set the out_var as the output of the op
#define IR_OP_VAR_LINK(op, out_var) \
op->outlinks.push_back(out_var); \
out_var->inlinks.clear(); \
out_var->inlinks.push_back(op);
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pattern_matcher.h"
#include <gtest/gtest.h>
namespace paddle {
namespace lite {
namespace mir {
void BuildGraph(SSAGraph* g) {
g->mutable_nodes().emplace_back();
Node& o1 = g->mutable_nodes().back();
o1.AsStmt().op_type = "op1";
g->mutable_nodes().emplace_back();
Node& o2 = g->mutable_nodes().back();
o2.AsStmt().op_type = "op2";
g->mutable_nodes().emplace_back();
Node& o3 = g->mutable_nodes().back();
o3.AsStmt().op_type = "op3";
g->mutable_nodes().emplace_back();
Node& o4 = g->mutable_nodes().back();
o4.AsStmt().op_type = "op4";
g->mutable_nodes().emplace_back();
Node& o5 = g->mutable_nodes().back();
o5.AsStmt().op_type = "op5";
g->mutable_nodes().emplace_back();
Node& v1 = g->mutable_nodes().back();
v1.AsArg("var1");
g->mutable_nodes().emplace_back();
Node& v2 = g->mutable_nodes().back();
v2.AsArg("var2");
g->mutable_nodes().emplace_back();
Node& v3 = g->mutable_nodes().back();
v3.AsArg("var3");
g->mutable_nodes().emplace_back();
Node& v4 = g->mutable_nodes().back();
v4.AsArg("var4");
// o1->v1->o2
o1.outlinks.push_back(&v1);
o2.inlinks.push_back(&v1);
v1.inlinks.push_back(&o1);
v1.outlinks.push_back(&o2);
// o2->v2->o3
// o2->v2->o4
o2.outlinks.push_back(&v2);
o3.inlinks.push_back(&v2);
o4.inlinks.push_back(&v2);
v2.inlinks.push_back(&o2);
v2.outlinks.push_back(&o3);
v2.outlinks.push_back(&o4);
// o2->v3->o5
o2.outlinks.push_back(&v3);
o5.inlinks.push_back(&v3);
v3.inlinks.push_back(&o2);
v3.outlinks.push_back(&o5);
// o3-v4->o5
o3.outlinks.push_back(&v4);
o5.inlinks.push_back(&v4);
v4.inlinks.push_back(&o3);
v4.outlinks.push_back(&o5);
}
TEST(PMPattern, NewNode) {
PMPattern x;
auto* n = x.NewNode([](const Node* x) { return true; });
ASSERT_TRUE(n);
ASSERT_EQ(x.nodes_.size(), 1UL);
}
TEST(PMPattern, AddEdge) {
PMPattern x;
auto* a = x.NewNode([](const Node* x) { return true; });
auto* b = x.NewNode([](const Node* x) { return true; });
ASSERT_TRUE(a);
ASSERT_TRUE(b);
x.AddEdge(a, b);
ASSERT_EQ(x.nodes_.size(), 2UL);
ASSERT_EQ(x.edges_.size(), 1UL);
ASSERT_EQ(x.edges_.front().first, a);
ASSERT_EQ(x.edges_.front().second, b);
ASSERT_EQ(x.nodes().size(), 2UL);
ASSERT_EQ(x.edges().size(), 1UL);
ASSERT_EQ(x.edges().front().first, a);
ASSERT_EQ(x.edges().front().second, b);
}
TEST(PatternMatcher, MarkPMNodesInGraph) {
PatternMatcher x;
// mark o2, o3, v2
// The pattern is a graph:
// o2(a node named o2) -> v2(a node named v2)
// v2 -> o3(a node named o3)
auto* o2 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape.
return node && node->IsStmt() && node->stmt()->op_type == "op2";
});
auto* o3 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape.
return node && node->IsStmt() && node->stmt()->op_type == "op3";
});
auto* v2 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape.
return node && node->IsArg() && node->arg()->name == "var2";
});
ASSERT_FALSE(o2->Tell(nullptr));
ASSERT_FALSE(o3->Tell(nullptr));
ASSERT_FALSE(v2->Tell(nullptr));
x.pattern_.AddEdge(o2, v2);
x.pattern_.AddEdge(v2, o3);
ASSERT_EQ(x.pattern_.edges().size(), 2UL);
ASSERT_EQ(x.pattern_.edges()[0].first, o2);
ASSERT_EQ(x.pattern_.edges()[0].second, v2);
ASSERT_EQ(x.pattern_.edges()[1].first, v2);
ASSERT_EQ(x.pattern_.edges()[1].second, o3);
SSAGraph graph;
BuildGraph(&graph);
x.MarkPMNodesInGraph(&graph);
ASSERT_EQ(x.pmnodes2nodes_.size(), 3UL);
auto subgraphs = x.DetectPatterns();
ASSERT_EQ(subgraphs.size(), 1UL);
}
TEST(PatternMatcher, MultiSubgraph) {
SSAGraph graph;
BuildGraph(&graph);
PatternMatcher x;
// The pattern is a graph:
// op -> var
auto* any_op = x.mutable_pattern()->NewNode(
[](const Node* node) {
return node->IsStmt() && (node->stmt()->op_type == "op2" ||
node->stmt()->op_type == "op3");
},
"OP0");
auto* any_var =
x.mutable_pattern()
->NewNode([](const Node* node) { return node->IsArg(); }, "VAR")
->AsIntermediate();
auto* any_op1 = x.mutable_pattern()->NewNode(
[](const Node* node) { return node->IsStmt(); }, "OP1");
x.mutable_pattern()->AddEdge(any_op, any_var);
x.mutable_pattern()->AddEdge(any_var, any_op1);
int count = 0;
PatternMatcher::handle_t handle = [&](const PatternMatcher::subgraph_t& s,
SSAGraph* g) {
LOG(INFO) << "Detect " << s.at(any_op)->stmt()->op_type << " -> "
<< s.at(any_var)->arg()->name << " -> "
<< s.at(any_op1)->stmt()->op_type;
count++;
};
x(&graph, handle);
// 1. Detect op3 -> var4 -> op5
// 2. Detect op2 -> var2 -> op3
// 3. Detect op2 -> var2 -> op4
// 4. Detect op2 -> var3 -> op5
// But 2 and 3 and 4 overlapped, so keep 2, so the final choices are 1 and 2
ASSERT_GE(count, 1);
ASSERT_LE(count, 2);
}
TEST(PatternMatcher, IntermediateCheck) {
SSAGraph graph;
BuildGraph(&graph);
// o2->v2->o3
// o2->v2->o4
// check o2+o3 fuse, should fail because v2 also link to o4.
PatternMatcher matcher;
auto* op2 = matcher.mutable_pattern()->NewNode(
[](const Node* x) {
return x && x->IsStmt() && x->stmt()->op_type == "op2";
},
"op2");
auto* op3 = matcher.mutable_pattern()->NewNode(
[](const Node* x) {
return x && x->IsStmt() && x->stmt()->op_type == "op3";
},
"op3");
auto* v2 = matcher.mutable_pattern()
->NewNode(
[](const Node* x) {
return x && x->IsArg() && x->arg()->name == "var2";
},
"var2")
->AsIntermediate();
v2->LinksFrom({op2}).LinksTo({op3});
int count = 0;
matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) {
++count;
});
EXPECT_EQ(count, 0);
count = 0;
v2->AsInput();
matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) {
++count;
});
ASSERT_EQ(count, 1);
}
} // namespace mir
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册