未验证 提交 a5241e1d 编写于 作者: Y Yan Chunwei 提交者: GitHub

refine program (#17726)

上级 01e1cdac
......@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_desc.h"
#include <glog/logging.h>
#include <algorithm>
#include <functional>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include "glog/logging.h"
#include <utility>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
......
......@@ -64,7 +64,7 @@ class LightPredictor {
private:
void BuildRuntimeProgram(const framework::proto::ProgramDesc& prog) {
std::vector<Instruct> insts;
std::vector<Instruction> insts;
// 1. Create op first
Program program(prog, scope_, {});
......@@ -72,7 +72,7 @@ class LightPredictor {
// Create the kernels of the target places, and filter out the specific
// kernel with the target alias.
for (auto& op : program.ops) {
for (auto& op : program.ops_) {
lite::pb::OpDesc desc(op->op_info()->desc());
auto kernel_type = desc.GetAttr(kKernelTypeAttr).get<std::string>();
std::string op_type, alias;
......@@ -89,8 +89,8 @@ class LightPredictor {
insts.emplace_back(op, std::move(*it));
}
program_.reset(new RuntimeProgram(std::move(insts)));
CHECK(program.exec_scope);
program_->set_exec_scope(program.exec_scope);
CHECK(program.exec_scope_);
program_->set_exec_scope(program.exec_scope_);
}
private:
......
......@@ -30,7 +30,7 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp
cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite)
cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)
add_subdirectory(mir)
......
......@@ -41,7 +41,7 @@ class GenerateProgramPass : public ProgramPass {
}
private:
std::vector<Instruct> insts_;
std::vector<Instruction> insts_;
};
} // namespace mir
......
......@@ -94,7 +94,7 @@ std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() {
}
void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
for (const auto &name : program.tmp_vars) {
for (const auto &name : program.tmp_vars()) {
CHECK(!arguments_.count(name)) << "duplicate creating temp variable: "
<< name;
VLOG(5) << "create arg node " << name;
......@@ -107,7 +107,7 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
// create weight nodes.
for (const auto &name : program.weights) {
for (const auto &name : program.weights()) {
CHECK(!arguments_.count(name)) << "duplicate creating weight variable: "
<< name;
VLOG(5) << "create arg node " << name;
......@@ -140,7 +140,7 @@ void SSAGraph::Build(const Program &program,
GraphCreateWeightVarNodes(program);
CHECK(CheckNodesRoleSet());
for (auto &op : program.ops) {
for (auto &op : program.ops()) {
auto *op_node = GraphCreateInstructNode(program, op, valid_places);
for (const std::string &name : op->op_info()->input_names()) {
auto *arg = Argument(name);
......
......@@ -77,7 +77,7 @@ class SSAGraph : GraphBase {
bool CheckLinksRoleSet();
void MarkArgumentWeights(const Program &program) {
for (const auto &name : program.weights) {
for (const auto &name : program.weights()) {
arguments_[name]->AsArg().is_weight = true;
}
}
......
......@@ -147,7 +147,7 @@ class OpLite : public Registry {
class OpInfo : public cpp::OpDesc {
public:
OpInfo(const OpInfo &) = default;
OpInfo(const cpp::OpDesc &other) : cpp::OpDesc(other) {}
explicit OpInfo(const cpp::OpDesc &other) : cpp::OpDesc(other) {}
// Collect all the input variable's name.
std::vector<std::string> input_names() const {
......
......@@ -64,7 +64,7 @@ class Optimizer {
RunPasses(passes);
}
#endif
exec_scope_ = program.exec_scope;
exec_scope_ = program.exec_scope();
}
void KernelPickPreferPlace(const Place& place) {
......
......@@ -62,5 +62,45 @@ void RuntimeProgram::SaveParams(const std::string &dir,
}
}
void Program::Build(const framework::proto::ProgramDesc &program) {
CHECK(ops_.empty()) << "Executor duplicate Build found";
// Create operators.
for (const auto &proto_op_desc : program.blocks(0).ops()) {
lite::OpDesc op_desc_dummy(proto_op_desc);
cpp::OpDesc op_desc;
TransformOpDescPbToCpp(op_desc_dummy, &op_desc);
auto op_type = op_desc.Type();
// if (op_type == "feed" || op_type == "fetch") continue;
VLOG(4) << "create Op [" << op_type << "]";
LOG(INFO) << "create Op [" << op_type << "]";
auto op = LiteOpRegistry::Global().Create(op_type);
CHECK(op) << "no Op found for " << op_type;
ops_.emplace_back(std::move(op));
ops_.back()->Attach(op_desc, exec_scope_);
}
}
void Program::PrepareWorkspace(const framework::proto::ProgramDesc &program) {
CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found";
exec_scope_ = &scope_->NewScope();
// Create Feed and Fetch var.
scope_->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
scope_->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
tmp_vars_.push_back("feed");
tmp_vars_.push_back("fetch");
CHECK(!program.blocks().empty());
for (auto proto_var_desc : program.blocks(0).vars()) {
lite::VarDesc var_desc(proto_var_desc);
if (!var_desc.Persistable()) {
tmp_vars_.push_back(var_desc.Name());
exec_scope_->Var(var_desc.Name());
} else {
if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
weights_.push_back(var_desc.Name());
}
}
}
} // namespace lite
} // namespace paddle
......@@ -37,79 +37,48 @@ static const char kKernelTypeAttr[] = "__@kernel_type_attr@__";
// - main block, which is a list of OpLite
// - scope: which contains all the weights
struct Program {
std::list<std::string> tmp_vars;
std::list<std::string> weights;
std::list<std::shared_ptr<OpLite>> ops;
// the scope to run the kernels, NOTE this is the execution scope.
std::shared_ptr<lite::Scope> scope;
std::vector<Place> valid_places;
// Runtime scope.
lite::Scope* exec_scope{};
const framework::proto::ProgramDesc desc;
explicit Program(const std::shared_ptr<Scope>& root) { scope = root; }
public:
explicit Program(const std::shared_ptr<Scope>& root) { scope_ = root; }
Program(const framework::proto::ProgramDesc& desc,
const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places)
: scope(root), valid_places(valid_places), desc(desc) {
CHECK(scope) << "scope should be init first";
: scope_(root), valid_places_(valid_places), desc_(desc) {
CHECK(scope_) << "scope should be init first";
PrepareWorkspace(desc);
Build(desc);
}
std::unique_ptr<Program> Clone() const {
std::unique_ptr<Program> res(new Program(desc, scope, valid_places));
std::unique_ptr<Program> res(new Program(desc_, scope_, valid_places_));
return res;
}
const std::list<std::string>& weights() const { return weights_; }
const std::list<std::string>& tmp_vars() const { return tmp_vars_; }
const std::list<std::shared_ptr<OpLite>>& ops() const { return ops_; }
lite::Scope* exec_scope() { return exec_scope_; }
private:
// Build from a program and scope.
void Build(const framework::proto::ProgramDesc& program) {
CHECK(ops.empty()) << "Executor duplicate Build found";
// Create operators.
for (const auto& proto_op_desc : program.blocks(0).ops()) {
pb::OpDesc op_desc(proto_op_desc);
auto op_type = op_desc.Type();
// if (op_type == "feed" || op_type == "fetch") continue;
VLOG(4) << "create Op [" << op_type << "]";
LOG(INFO) << "create Op [" << op_type << "]";
auto op = LiteOpRegistry::Global().Create(op_type);
CHECK(op) << "no Op found for " << op_type;
ops.emplace_back(std::move(op));
cpp::OpDesc cpp_op_desc;
TransformOpDescPbToCpp(op_desc, &cpp_op_desc);
ops.back()->Attach(cpp_op_desc, exec_scope);
}
}
void Build(const framework::proto::ProgramDesc& program);
// Create temporary variables.
void PrepareWorkspace(const framework::proto::ProgramDesc& program) {
CHECK(!exec_scope) << "Duplicate PrepareWorkspace found";
exec_scope = &scope->NewScope();
// Create Feed and Fetch var.
scope->Var("feed")->GetMutable<std::vector<lite::Tensor>>();
scope->Var("fetch")->GetMutable<std::vector<lite::Tensor>>();
tmp_vars.push_back("feed");
tmp_vars.push_back("fetch");
CHECK(!program.blocks().empty());
for (auto proto_var_desc : program.blocks(0).vars()) {
lite::VarDesc var_desc(proto_var_desc);
if (!var_desc.Persistable()) {
tmp_vars.push_back(var_desc.Name());
exec_scope->Var(var_desc.Name());
} else {
if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
weights.push_back(var_desc.Name());
}
}
}
void PrepareWorkspace(const framework::proto::ProgramDesc& program);
private:
std::list<std::string> tmp_vars_;
std::list<std::string> weights_;
std::list<std::shared_ptr<OpLite>> ops_;
// the scope to run the kernels, NOTE this is the execution scope.
std::shared_ptr<lite::Scope> scope_;
std::vector<Place> valid_places_;
// Runtime scope.
lite::Scope* exec_scope_{};
const framework::proto::ProgramDesc desc_;
};
struct Instruct {
Instruct(const std::shared_ptr<OpLite>& op,
std::unique_ptr<KernelBase>&& kernel)
struct Instruction {
Instruction(const std::shared_ptr<OpLite>& op,
std::unique_ptr<KernelBase>&& kernel)
: op_(op), kernel_(std::move(kernel)) {
#ifdef LITE_WITH_PROFILE
profile_id_ = profile::BasicProfiler<profile::BasicTimer>::Global()
......@@ -132,7 +101,7 @@ struct Instruct {
kernel_->Launch();
}
friend std::ostream& operator<<(std::ostream& os, const Instruct& other) {
friend std::ostream& operator<<(std::ostream& os, const Instruction& other) {
os << other.kernel_->summary() << "\t(" << other.kernel_->doc() << ")";
return os;
}
......@@ -156,7 +125,7 @@ struct Instruct {
*/
class RuntimeProgram {
public:
explicit RuntimeProgram(std::vector<Instruct>&& insts)
explicit RuntimeProgram(std::vector<Instruction>&& insts)
: instructions_(std::move(insts)) {
if (instructions_.empty()) {
LOG(FATAL) << "no instructions";
......@@ -186,7 +155,7 @@ class RuntimeProgram {
private:
RuntimeProgram(const RuntimeProgram&) = delete;
std::vector<Instruct> instructions_;
std::vector<Instruction> instructions_;
lite::Scope* exec_scope_{};
};
......
......@@ -33,9 +33,9 @@ Program FakeProgram() {
std::string w1 = "w" + std::to_string(id);
std::string b1 = "b" + std::to_string(id);
std::string out1 = "out" + std::to_string(id);
auto w1v = program.scope->Var(w1)->GetMutable<lite::Tensor>();
auto b1v = program.scope->Var(b1)->GetMutable<lite::Tensor>();
auto out1v = program.scope->Var(out1)->GetMutable<lite::Tensor>();
auto w1v = program.scope_->Var(w1)->GetMutable<lite::Tensor>();
auto b1v = program.scope_->Var(b1)->GetMutable<lite::Tensor>();
auto out1v = program.scope_->Var(out1)->GetMutable<lite::Tensor>();
lite::OpDesc desc;
desc.SetInput("Input", {x});
......@@ -46,12 +46,12 @@ Program FakeProgram() {
desc.SetAttr("in_num_col_dims", 1);
// add to input
program.tmp_vars.push_back(w1);
program.tmp_vars.push_back(b1);
program.tmp_vars_.push_back(w1);
program.tmp_vars_.push_back(b1);
auto fc_op = LiteOpRegistry::Global().Create("fc");
fc_op->Attach(desc, program.scope.get());
program.ops.emplace_back(std::move(fc_op));
fc_op->Attach(desc, program.scope_.get());
program.ops_.emplace_back(std::move(fc_op));
w1v->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
b1v->Resize(DDimHvy(std::vector<int64_t>({100, 1})));
......@@ -64,8 +64,8 @@ Program FakeProgram() {
// out1, w2, b2 -fc-> out2
std::string x = "x";
program.tmp_vars.push_back(x);
auto* xv = program.scope->Var(x)->GetMutable<lite::Tensor>();
program.tmp_vars_.push_back(x);
auto* xv = program.scope_->Var(x)->GetMutable<lite::Tensor>();
xv->Resize(DDimHvy(std::vector<int64_t>({100, 100})));
for (int i = 0; i < 3; i++) {
......
......@@ -13,7 +13,8 @@
// limitations under the License.
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
#include "compatible_pb.h"
#include <string>
#include <vector>
namespace paddle {
namespace lite {
......
......@@ -61,6 +61,7 @@ static std::string Join(const std::vector<std::string>& vec,
if (!vec.empty()) {
ss << vec.back();
}
return ss.str();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册