提交 a8caedbe 编写于 作者: Z zhoubo01

remove depedence on predictor.clone()

上级 bc4c9c43
......@@ -24,20 +24,6 @@ using namespace paddle::lite_api;
const int ITER = 10;
std::shared_ptr<PaddlePredictor> create_paddle_predictor(const std::string& model_dir) {
// 1. Create CxxConfig
CxxConfig config;
config.set_model_dir(model_dir);
config.set_valid_places({
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kHost), PRECISION(kFloat)}
});
// 2. Create PaddlePredictor by CxxConfig
std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<CxxConfig>(config);
return predictor;
}
// Use PaddlePredictor of CartPole model to predict the action.
std::vector<float> forward(std::shared_ptr<PaddlePredictor> predictor, const float* obs) {
std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
......@@ -86,8 +72,7 @@ int main(int argc, char* argv[]) {
envs.push_back(CartPole());
}
std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model");
std::shared_ptr<AsyncESAgent> agent = std::make_shared<AsyncESAgent>(paddle_predictor, "../benchmark/cartpole_config.prototxt");
std::shared_ptr<AsyncESAgent> agent = std::make_shared<AsyncESAgent>("../demo/paddle/cartpole_init_model", "../benchmark/cartpole_config.prototxt");
// Clone agents to sample (explore).
std::vector< std::shared_ptr<AsyncESAgent> > sampling_agents;
......
......@@ -24,20 +24,6 @@ using namespace paddle::lite_api;
const int ITER = 10;
std::shared_ptr<PaddlePredictor> create_paddle_predictor(const std::string& model_dir) {
// 1. Create CxxConfig
CxxConfig config;
config.set_model_dir(model_dir);
config.set_valid_places({
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kHost), PRECISION(kFloat)}
});
// 2. Create PaddlePredictor by CxxConfig
std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<CxxConfig>(config);
return predictor;
}
// Use PaddlePredictor of CartPole model to predict the action.
std::vector<float> forward(std::shared_ptr<PaddlePredictor> predictor, const float* obs) {
std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
......@@ -86,8 +72,8 @@ int main(int argc, char* argv[]) {
envs.push_back(CartPole());
}
std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model");
std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>(paddle_predictor, "../benchmark/cartpole_config.prototxt");
//std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model");
std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>("../demo/paddle/cartpole_init_model", "../benchmark/cartpole_config.prototxt");
// Clone agents to sample (explore).
std::vector< std::shared_ptr<ESAgent> > sampling_agents;
......
......@@ -36,4 +36,6 @@ if __name__ == '__main__':
dirname='cartpole_init_model',
feeded_var_names=['obs'],
target_vars=[prob],
params_filename='param',
model_filename='model',
executor=exe)
......@@ -28,7 +28,9 @@ namespace DeepES{
*/
class AsyncESAgent: public ESAgent {
public:
AsyncESAgent() {}
AsyncESAgent() = delete;
AsyncESAgent(const CxxConfig& cxx_config);
~AsyncESAgent();
......@@ -40,8 +42,8 @@ class AsyncESAgent: public ESAgent {
* Please use the up-to-date configuration.
*/
AsyncESAgent(
std::shared_ptr<PaddlePredictor> predictor,
std::string config_path);
const std::string& model_dir,
const std::string& config_path);
/**
* @brief: Clone an agent for sampling.
......
......@@ -38,14 +38,14 @@ int64_t ShapeProduction(const shape_t& shape);
*/
class ESAgent {
public:
ESAgent();
ESAgent() = delete;
~ESAgent();
ESAgent(
std::shared_ptr<PaddlePredictor> predictor,
std::string config_path);
ESAgent(const std::string& model_dir, const std::string& config_path);
ESAgent(const CxxConfig& cxx_config);
/**
* @breif Clone a sampling agent
*
......@@ -83,15 +83,16 @@ class ESAgent {
std::shared_ptr<PaddlePredictor> _predictor;
std::shared_ptr<PaddlePredictor> _sampling_predictor;
bool _is_sampling_agent;
std::shared_ptr<SamplingMethod> _sampling_method;
std::shared_ptr<Optimizer> _optimizer;
std::shared_ptr<DeepESConfig> _config;
int64_t _param_size;
std::shared_ptr<CxxConfig> _cxx_config;
std::vector<std::string> _param_names;
// malloc memory of noise and neg_gradients in advance.
float* _noise;
float* _neg_gradients;
int64_t _param_size;
bool _is_sampling_agent;
};
}
......
......@@ -20,6 +20,7 @@
#include <glog/logging.h>
#include "deepes.pb.h"
#include <google/protobuf/text_format.h>
#include <fstream>
namespace DeepES{
......@@ -29,6 +30,8 @@ namespace DeepES{
*/
bool compute_centered_ranks(std::vector<float> &reward);
std::string read_file(const std::string& filename);
/* Load a protobuf-based configuration from the file.
* Args:
* config_file: file path.
......
......@@ -16,8 +16,8 @@
namespace DeepES {
AsyncESAgent::AsyncESAgent(
std::shared_ptr<PaddlePredictor> predictor,
std::string config_path): ESAgent(predictor, config_path) {
const std::string& model_dir,
const std::string& config_path): ESAgent(model_dir, config_path) {
_config_path = config_path;
}
AsyncESAgent::~AsyncESAgent() {
......@@ -154,15 +154,16 @@ std::shared_ptr<PaddlePredictor> AsyncESAgent::_load_previous_model(std::string
return predictor;
}
AsyncESAgent::AsyncESAgent(const CxxConfig& cxx_config): ESAgent(cxx_config){
}
std::shared_ptr<AsyncESAgent> AsyncESAgent::clone() {
std::shared_ptr<PaddlePredictor> new_sampling_predictor = _predictor->Clone();
std::shared_ptr<AsyncESAgent> new_agent = std::make_shared<AsyncESAgent>();
std::shared_ptr<AsyncESAgent> new_agent = std::make_shared<AsyncESAgent>(*_cxx_config);
float* noise = new float [_param_size];
new_agent->_predictor = _predictor;
new_agent->_sampling_predictor = new_sampling_predictor;
new_agent->_is_sampling_agent = true;
new_agent->_sampling_method = _sampling_method;
......
......@@ -23,22 +23,31 @@ int64_t ShapeProduction(const shape_t& shape) {
return res;
}
ESAgent::ESAgent() {}
ESAgent::~ESAgent() {
delete[] _noise;
if (!_is_sampling_agent)
delete[] _neg_gradients;
}
ESAgent::ESAgent(
std::shared_ptr<PaddlePredictor> predictor,
std::string config_path) {
ESAgent::ESAgent(const std::string& model_dir, const std::string& config_path) {
// 1. Create CxxConfig
_cxx_config = std::make_shared<CxxConfig>();
std::string model_path = model_dir + "/model";
std::string param_path = model_dir + "/param";
std::string model_buffer = read_file(model_path);
std::string param_buffer = read_file(param_path);
_cxx_config->set_model_buffer(model_buffer.c_str(), model_buffer.size(),
param_buffer.c_str(), param_buffer.size());
_cxx_config->set_valid_places({
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kHost), PRECISION(kFloat)}
});
_predictor = CreatePaddlePredictor<CxxConfig>(*_cxx_config);
_is_sampling_agent = false;
_predictor = predictor;
// Original agent can't be used to sample, so keep it same with _predictor for evaluating.
_sampling_predictor = predictor;
_sampling_predictor = _predictor;
_config = std::make_shared<DeepESConfig>();
load_proto_conf(config_path, *_config);
......@@ -55,16 +64,21 @@ ESAgent::ESAgent(
_neg_gradients = new float [_param_size];
}
std::shared_ptr<ESAgent> ESAgent::clone() {
std::shared_ptr<PaddlePredictor> new_sampling_predictor = _predictor->Clone();
ESAgent::ESAgent(const CxxConfig& cxx_config) {
_sampling_predictor = CreatePaddlePredictor<CxxConfig>(cxx_config);
}
std::shared_ptr<ESAgent> new_agent = std::make_shared<ESAgent>();
std::shared_ptr<ESAgent> ESAgent::clone() {
if (_is_sampling_agent) {
LOG(ERROR) << "[DeepES] only original ESAgent can call `clone` function.";
return nullptr;
}
std::shared_ptr<ESAgent> new_agent = std::make_shared<ESAgent>(*_cxx_config);
float* noise = new float [_param_size];
new_agent->_predictor = _predictor;
new_agent->_sampling_predictor = new_sampling_predictor;
new_agent->_cxx_config = _cxx_config;
new_agent->_is_sampling_agent = true;
new_agent->_sampling_method = _sampling_method;
new_agent->_param_names = _param_names;
......
......@@ -52,4 +52,16 @@ std::vector<std::string> list_all_model_dirs(std::string path) {
return model_dirs;
}
std::string read_file(const std::string& filename) {
std::ifstream ifile(filename.c_str());
if (!ifile.is_open()) {
LOG(FATAL) << "Open file: [" << filename << "] failed.";
}
std::ostringstream buf;
char ch;
while (buf && ifile.get(ch)) buf.put(ch);
ifile.close();
return buf.str();
}
}//namespace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册