提交 a8caedbe 编写于 作者: Z zhoubo01

remove depedence on predictor.clone()

上级 bc4c9c43
...@@ -24,20 +24,6 @@ using namespace paddle::lite_api; ...@@ -24,20 +24,6 @@ using namespace paddle::lite_api;
const int ITER = 10; const int ITER = 10;
std::shared_ptr<PaddlePredictor> create_paddle_predictor(const std::string& model_dir) {
// 1. Create CxxConfig
CxxConfig config;
config.set_model_dir(model_dir);
config.set_valid_places({
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kHost), PRECISION(kFloat)}
});
// 2. Create PaddlePredictor by CxxConfig
std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<CxxConfig>(config);
return predictor;
}
// Use PaddlePredictor of CartPole model to predict the action. // Use PaddlePredictor of CartPole model to predict the action.
std::vector<float> forward(std::shared_ptr<PaddlePredictor> predictor, const float* obs) { std::vector<float> forward(std::shared_ptr<PaddlePredictor> predictor, const float* obs) {
std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0))); std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
...@@ -86,8 +72,7 @@ int main(int argc, char* argv[]) { ...@@ -86,8 +72,7 @@ int main(int argc, char* argv[]) {
envs.push_back(CartPole()); envs.push_back(CartPole());
} }
std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model"); std::shared_ptr<AsyncESAgent> agent = std::make_shared<AsyncESAgent>("../demo/paddle/cartpole_init_model", "../benchmark/cartpole_config.prototxt");
std::shared_ptr<AsyncESAgent> agent = std::make_shared<AsyncESAgent>(paddle_predictor, "../benchmark/cartpole_config.prototxt");
// Clone agents to sample (explore). // Clone agents to sample (explore).
std::vector< std::shared_ptr<AsyncESAgent> > sampling_agents; std::vector< std::shared_ptr<AsyncESAgent> > sampling_agents;
......
...@@ -24,20 +24,6 @@ using namespace paddle::lite_api; ...@@ -24,20 +24,6 @@ using namespace paddle::lite_api;
const int ITER = 10; const int ITER = 10;
std::shared_ptr<PaddlePredictor> create_paddle_predictor(const std::string& model_dir) {
// 1. Create CxxConfig
CxxConfig config;
config.set_model_dir(model_dir);
config.set_valid_places({
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kHost), PRECISION(kFloat)}
});
// 2. Create PaddlePredictor by CxxConfig
std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<CxxConfig>(config);
return predictor;
}
// Use PaddlePredictor of CartPole model to predict the action. // Use PaddlePredictor of CartPole model to predict the action.
std::vector<float> forward(std::shared_ptr<PaddlePredictor> predictor, const float* obs) { std::vector<float> forward(std::shared_ptr<PaddlePredictor> predictor, const float* obs) {
std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0))); std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
...@@ -86,8 +72,8 @@ int main(int argc, char* argv[]) { ...@@ -86,8 +72,8 @@ int main(int argc, char* argv[]) {
envs.push_back(CartPole()); envs.push_back(CartPole());
} }
std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model"); //std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model");
std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>(paddle_predictor, "../benchmark/cartpole_config.prototxt"); std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>("../demo/paddle/cartpole_init_model", "../benchmark/cartpole_config.prototxt");
// Clone agents to sample (explore). // Clone agents to sample (explore).
std::vector< std::shared_ptr<ESAgent> > sampling_agents; std::vector< std::shared_ptr<ESAgent> > sampling_agents;
......
...@@ -36,4 +36,6 @@ if __name__ == '__main__': ...@@ -36,4 +36,6 @@ if __name__ == '__main__':
dirname='cartpole_init_model', dirname='cartpole_init_model',
feeded_var_names=['obs'], feeded_var_names=['obs'],
target_vars=[prob], target_vars=[prob],
params_filename='param',
model_filename='model',
executor=exe) executor=exe)
...@@ -28,7 +28,9 @@ namespace DeepES{ ...@@ -28,7 +28,9 @@ namespace DeepES{
*/ */
class AsyncESAgent: public ESAgent { class AsyncESAgent: public ESAgent {
public: public:
AsyncESAgent() {} AsyncESAgent() = delete;
AsyncESAgent(const CxxConfig& cxx_config);
~AsyncESAgent(); ~AsyncESAgent();
...@@ -40,8 +42,8 @@ class AsyncESAgent: public ESAgent { ...@@ -40,8 +42,8 @@ class AsyncESAgent: public ESAgent {
* Please use the up-to-date configuration. * Please use the up-to-date configuration.
*/ */
AsyncESAgent( AsyncESAgent(
std::shared_ptr<PaddlePredictor> predictor, const std::string& model_dir,
std::string config_path); const std::string& config_path);
/** /**
* @brief: Clone an agent for sampling. * @brief: Clone an agent for sampling.
......
...@@ -38,14 +38,14 @@ int64_t ShapeProduction(const shape_t& shape); ...@@ -38,14 +38,14 @@ int64_t ShapeProduction(const shape_t& shape);
*/ */
class ESAgent { class ESAgent {
public: public:
ESAgent(); ESAgent() = delete;
~ESAgent(); ~ESAgent();
ESAgent(
std::shared_ptr<PaddlePredictor> predictor, ESAgent(const std::string& model_dir, const std::string& config_path);
std::string config_path);
ESAgent(const CxxConfig& cxx_config);
/** /**
* @breif Clone a sampling agent * @breif Clone a sampling agent
* *
...@@ -83,15 +83,16 @@ class ESAgent { ...@@ -83,15 +83,16 @@ class ESAgent {
std::shared_ptr<PaddlePredictor> _predictor; std::shared_ptr<PaddlePredictor> _predictor;
std::shared_ptr<PaddlePredictor> _sampling_predictor; std::shared_ptr<PaddlePredictor> _sampling_predictor;
bool _is_sampling_agent;
std::shared_ptr<SamplingMethod> _sampling_method; std::shared_ptr<SamplingMethod> _sampling_method;
std::shared_ptr<Optimizer> _optimizer; std::shared_ptr<Optimizer> _optimizer;
std::shared_ptr<DeepESConfig> _config; std::shared_ptr<DeepESConfig> _config;
int64_t _param_size; std::shared_ptr<CxxConfig> _cxx_config;
std::vector<std::string> _param_names; std::vector<std::string> _param_names;
// malloc memory of noise and neg_gradients in advance. // malloc memory of noise and neg_gradients in advance.
float* _noise; float* _noise;
float* _neg_gradients; float* _neg_gradients;
int64_t _param_size;
bool _is_sampling_agent;
}; };
} }
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <glog/logging.h> #include <glog/logging.h>
#include "deepes.pb.h" #include "deepes.pb.h"
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
#include <fstream>
namespace DeepES{ namespace DeepES{
...@@ -29,6 +30,8 @@ namespace DeepES{ ...@@ -29,6 +30,8 @@ namespace DeepES{
*/ */
bool compute_centered_ranks(std::vector<float> &reward); bool compute_centered_ranks(std::vector<float> &reward);
std::string read_file(const std::string& filename);
/* Load a protobuf-based configuration from the file. /* Load a protobuf-based configuration from the file.
* Args: * Args:
* config_file: file path. * config_file: file path.
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
namespace DeepES { namespace DeepES {
AsyncESAgent::AsyncESAgent( AsyncESAgent::AsyncESAgent(
std::shared_ptr<PaddlePredictor> predictor, const std::string& model_dir,
std::string config_path): ESAgent(predictor, config_path) { const std::string& config_path): ESAgent(model_dir, config_path) {
_config_path = config_path; _config_path = config_path;
} }
AsyncESAgent::~AsyncESAgent() { AsyncESAgent::~AsyncESAgent() {
...@@ -154,15 +154,16 @@ std::shared_ptr<PaddlePredictor> AsyncESAgent::_load_previous_model(std::string ...@@ -154,15 +154,16 @@ std::shared_ptr<PaddlePredictor> AsyncESAgent::_load_previous_model(std::string
return predictor; return predictor;
} }
AsyncESAgent::AsyncESAgent(const CxxConfig& cxx_config): ESAgent(cxx_config){
}
std::shared_ptr<AsyncESAgent> AsyncESAgent::clone() { std::shared_ptr<AsyncESAgent> AsyncESAgent::clone() {
std::shared_ptr<PaddlePredictor> new_sampling_predictor = _predictor->Clone();
std::shared_ptr<AsyncESAgent> new_agent = std::make_shared<AsyncESAgent>(); std::shared_ptr<AsyncESAgent> new_agent = std::make_shared<AsyncESAgent>(*_cxx_config);
float* noise = new float [_param_size]; float* noise = new float [_param_size];
new_agent->_predictor = _predictor; new_agent->_predictor = _predictor;
new_agent->_sampling_predictor = new_sampling_predictor;
new_agent->_is_sampling_agent = true; new_agent->_is_sampling_agent = true;
new_agent->_sampling_method = _sampling_method; new_agent->_sampling_method = _sampling_method;
......
...@@ -23,22 +23,31 @@ int64_t ShapeProduction(const shape_t& shape) { ...@@ -23,22 +23,31 @@ int64_t ShapeProduction(const shape_t& shape) {
return res; return res;
} }
ESAgent::ESAgent() {}
ESAgent::~ESAgent() { ESAgent::~ESAgent() {
delete[] _noise; delete[] _noise;
if (!_is_sampling_agent) if (!_is_sampling_agent)
delete[] _neg_gradients; delete[] _neg_gradients;
} }
ESAgent::ESAgent( ESAgent::ESAgent(const std::string& model_dir, const std::string& config_path) {
std::shared_ptr<PaddlePredictor> predictor, // 1. Create CxxConfig
std::string config_path) { _cxx_config = std::make_shared<CxxConfig>();
std::string model_path = model_dir + "/model";
std::string param_path = model_dir + "/param";
std::string model_buffer = read_file(model_path);
std::string param_buffer = read_file(param_path);
_cxx_config->set_model_buffer(model_buffer.c_str(), model_buffer.size(),
param_buffer.c_str(), param_buffer.size());
_cxx_config->set_valid_places({
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kHost), PRECISION(kFloat)}
});
_predictor = CreatePaddlePredictor<CxxConfig>(*_cxx_config);
_is_sampling_agent = false; _is_sampling_agent = false;
_predictor = predictor;
// Original agent can't be used to sample, so keep it same with _predictor for evaluating. // Original agent can't be used to sample, so keep it same with _predictor for evaluating.
_sampling_predictor = predictor; _sampling_predictor = _predictor;
_config = std::make_shared<DeepESConfig>(); _config = std::make_shared<DeepESConfig>();
load_proto_conf(config_path, *_config); load_proto_conf(config_path, *_config);
...@@ -55,16 +64,21 @@ ESAgent::ESAgent( ...@@ -55,16 +64,21 @@ ESAgent::ESAgent(
_neg_gradients = new float [_param_size]; _neg_gradients = new float [_param_size];
} }
std::shared_ptr<ESAgent> ESAgent::clone() { ESAgent::ESAgent(const CxxConfig& cxx_config) {
std::shared_ptr<PaddlePredictor> new_sampling_predictor = _predictor->Clone(); _sampling_predictor = CreatePaddlePredictor<CxxConfig>(cxx_config);
}
std::shared_ptr<ESAgent> new_agent = std::make_shared<ESAgent>(); std::shared_ptr<ESAgent> ESAgent::clone() {
if (_is_sampling_agent) {
LOG(ERROR) << "[DeepES] only original ESAgent can call `clone` function.";
return nullptr;
}
std::shared_ptr<ESAgent> new_agent = std::make_shared<ESAgent>(*_cxx_config);
float* noise = new float [_param_size]; float* noise = new float [_param_size];
new_agent->_predictor = _predictor; new_agent->_predictor = _predictor;
new_agent->_sampling_predictor = new_sampling_predictor; new_agent->_cxx_config = _cxx_config;
new_agent->_is_sampling_agent = true; new_agent->_is_sampling_agent = true;
new_agent->_sampling_method = _sampling_method; new_agent->_sampling_method = _sampling_method;
new_agent->_param_names = _param_names; new_agent->_param_names = _param_names;
......
...@@ -52,4 +52,16 @@ std::vector<std::string> list_all_model_dirs(std::string path) { ...@@ -52,4 +52,16 @@ std::vector<std::string> list_all_model_dirs(std::string path) {
return model_dirs; return model_dirs;
} }
std::string read_file(const std::string& filename) {
std::ifstream ifile(filename.c_str());
if (!ifile.is_open()) {
LOG(FATAL) << "Open file: [" << filename << "] failed.";
}
std::ostringstream buf;
char ch;
while (buf && ifile.get(ch)) buf.put(ch);
ifile.close();
return buf.str();
}
}//namespace }//namespace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册