diff --git a/deepes/demo/paddle/cartpole_async_solver.cc b/deepes/demo/paddle/cartpole_async_solver.cc index 5cbe48eb0ddc8e1dc7458371ae8c0367e19e0198..7d244b104375d4e323659fcc13960d044c1a74ed 100644 --- a/deepes/demo/paddle/cartpole_async_solver.cc +++ b/deepes/demo/paddle/cartpole_async_solver.cc @@ -24,20 +24,6 @@ using namespace paddle::lite_api; const int ITER = 10; -std::shared_ptr create_paddle_predictor(const std::string& model_dir) { - // 1. Create CxxConfig - CxxConfig config; - config.set_model_dir(model_dir); - config.set_valid_places({ - Place{TARGET(kX86), PRECISION(kFloat)}, - Place{TARGET(kHost), PRECISION(kFloat)} - }); - - // 2. Create PaddlePredictor by CxxConfig - std::shared_ptr predictor = CreatePaddlePredictor(config); - return predictor; -} - // Use PaddlePredictor of CartPole model to predict the action. std::vector forward(std::shared_ptr predictor, const float* obs) { std::unique_ptr input_tensor(std::move(predictor->GetInput(0))); @@ -86,8 +72,7 @@ int main(int argc, char* argv[]) { envs.push_back(CartPole()); } - std::shared_ptr paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model"); - std::shared_ptr agent = std::make_shared(paddle_predictor, "../benchmark/cartpole_config.prototxt"); + std::shared_ptr agent = std::make_shared("../demo/paddle/cartpole_init_model", "../benchmark/cartpole_config.prototxt"); // Clone agents to sample (explore). std::vector< std::shared_ptr > sampling_agents; diff --git a/deepes/demo/paddle/cartpole_solver_parallel.cc b/deepes/demo/paddle/cartpole_solver_parallel.cc index 9fccb1a995774a9e98d50dfbf4e42470237c0fed..239fbb052d0785cddbd94d40a1aad118b055d90f 100644 --- a/deepes/demo/paddle/cartpole_solver_parallel.cc +++ b/deepes/demo/paddle/cartpole_solver_parallel.cc @@ -24,20 +24,6 @@ using namespace paddle::lite_api; const int ITER = 10; -std::shared_ptr create_paddle_predictor(const std::string& model_dir) { - // 1. Create CxxConfig - CxxConfig config; - config.set_model_dir(model_dir); - config.set_valid_places({ - Place{TARGET(kX86), PRECISION(kFloat)}, - Place{TARGET(kHost), PRECISION(kFloat)} - }); - - // 2. Create PaddlePredictor by CxxConfig - std::shared_ptr predictor = CreatePaddlePredictor(config); - return predictor; -} - // Use PaddlePredictor of CartPole model to predict the action. std::vector forward(std::shared_ptr predictor, const float* obs) { std::unique_ptr input_tensor(std::move(predictor->GetInput(0))); @@ -86,8 +72,8 @@ int main(int argc, char* argv[]) { envs.push_back(CartPole()); } - std::shared_ptr paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model"); - std::shared_ptr agent = std::make_shared(paddle_predictor, "../benchmark/cartpole_config.prototxt"); + //std::shared_ptr paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model"); + std::shared_ptr agent = std::make_shared("../demo/paddle/cartpole_init_model", "../benchmark/cartpole_config.prototxt"); // Clone agents to sample (explore). std::vector< std::shared_ptr > sampling_agents; diff --git a/deepes/demo/paddle/gen_cartpole_init_model.py b/deepes/demo/paddle/gen_cartpole_init_model.py index 66b841aaf4ac428ca2232324a35fa66bd683c572..9295224953e74a9572915d3612bd4634f61de55e 100644 --- a/deepes/demo/paddle/gen_cartpole_init_model.py +++ b/deepes/demo/paddle/gen_cartpole_init_model.py @@ -36,4 +36,6 @@ if __name__ == '__main__': dirname='cartpole_init_model', feeded_var_names=['obs'], target_vars=[prob], + params_filename='param', + model_filename='model', executor=exe) diff --git a/deepes/include/paddle/async_es_agent.h b/deepes/include/paddle/async_es_agent.h index 11b8dff53bdafd23b1ce3524f42b97a688283a53..edc3548bdfb22bce2d010349e45f64505e5c6216 100644 --- a/deepes/include/paddle/async_es_agent.h +++ b/deepes/include/paddle/async_es_agent.h @@ -28,7 +28,9 @@ namespace DeepES{ */ class AsyncESAgent: public ESAgent { public: - AsyncESAgent() {} + AsyncESAgent() = delete; + + AsyncESAgent(const CxxConfig& cxx_config); ~AsyncESAgent(); @@ -40,8 +42,8 @@ class AsyncESAgent: public ESAgent { * Please use the up-to-date configuration. */ AsyncESAgent( - std::shared_ptr predictor, - std::string config_path); + const std::string& model_dir, + const std::string& config_path); /** * @brief: Clone an agent for sampling. diff --git a/deepes/include/paddle/es_agent.h b/deepes/include/paddle/es_agent.h index 25c9d98e9b11776669692233cc1b9061cc8fe1eb..734a998c0987e9274a0623dce91b4f47b3a06f80 100644 --- a/deepes/include/paddle/es_agent.h +++ b/deepes/include/paddle/es_agent.h @@ -38,14 +38,14 @@ int64_t ShapeProduction(const shape_t& shape); */ class ESAgent { public: - ESAgent(); + ESAgent() = delete; ~ESAgent(); - ESAgent( - std::shared_ptr predictor, - std::string config_path); + + ESAgent(const std::string& model_dir, const std::string& config_path); + ESAgent(const CxxConfig& cxx_config); /** * @breif Clone a sampling agent * @@ -83,15 +83,16 @@ class ESAgent { std::shared_ptr _predictor; std::shared_ptr _sampling_predictor; - bool _is_sampling_agent; std::shared_ptr _sampling_method; std::shared_ptr _optimizer; std::shared_ptr _config; - int64_t _param_size; + std::shared_ptr _cxx_config; std::vector _param_names; // malloc memory of noise and neg_gradients in advance. float* _noise; float* _neg_gradients; + int64_t _param_size; + bool _is_sampling_agent; }; } diff --git a/deepes/include/utils.h b/deepes/include/utils.h index 5835a43defd6a4abfeae7a68a5671f3c3239dcfc..76ba45b23b4729170d3bdcb657cecf345fa9107f 100644 --- a/deepes/include/utils.h +++ b/deepes/include/utils.h @@ -20,6 +20,7 @@ #include #include "deepes.pb.h" #include +#include namespace DeepES{ @@ -29,6 +30,8 @@ namespace DeepES{ */ bool compute_centered_ranks(std::vector &reward); +std::string read_file(const std::string& filename); + /* Load a protobuf-based configuration from the file. * Args: * config_file: file path. diff --git a/deepes/src/paddle/async_es_agent.cc b/deepes/src/paddle/async_es_agent.cc index f128ddcde2556009f56f9a9aea829d3ee46ce7b3..35e31b0cc8808751d71206237d249f3d409ac47f 100644 --- a/deepes/src/paddle/async_es_agent.cc +++ b/deepes/src/paddle/async_es_agent.cc @@ -16,8 +16,8 @@ namespace DeepES { AsyncESAgent::AsyncESAgent( - std::shared_ptr predictor, - std::string config_path): ESAgent(predictor, config_path) { + const std::string& model_dir, + const std::string& config_path): ESAgent(model_dir, config_path) { _config_path = config_path; } AsyncESAgent::~AsyncESAgent() { @@ -154,15 +154,16 @@ std::shared_ptr AsyncESAgent::_load_previous_model(std::string return predictor; } +AsyncESAgent::AsyncESAgent(const CxxConfig& cxx_config): ESAgent(cxx_config){ +} + std::shared_ptr AsyncESAgent::clone() { - std::shared_ptr new_sampling_predictor = _predictor->Clone(); - std::shared_ptr new_agent = std::make_shared(); + std::shared_ptr new_agent = std::make_shared(*_cxx_config); float* noise = new float [_param_size]; new_agent->_predictor = _predictor; - new_agent->_sampling_predictor = new_sampling_predictor; new_agent->_is_sampling_agent = true; new_agent->_sampling_method = _sampling_method; diff --git a/deepes/src/paddle/es_agent.cc b/deepes/src/paddle/es_agent.cc index 6cd6f93c896ee23ad15d0b947f823ea52011d806..fba218f20abdd1da91e74281e8b18a4cc040ba4f 100644 --- a/deepes/src/paddle/es_agent.cc +++ b/deepes/src/paddle/es_agent.cc @@ -23,22 +23,31 @@ int64_t ShapeProduction(const shape_t& shape) { return res; } -ESAgent::ESAgent() {} - ESAgent::~ESAgent() { delete[] _noise; if (!_is_sampling_agent) delete[] _neg_gradients; } -ESAgent::ESAgent( - std::shared_ptr predictor, - std::string config_path) { +ESAgent::ESAgent(const std::string& model_dir, const std::string& config_path) { + // 1. Create CxxConfig + _cxx_config = std::make_shared(); + std::string model_path = model_dir + "/model"; + std::string param_path = model_dir + "/param"; + std::string model_buffer = read_file(model_path); + std::string param_buffer = read_file(param_path); + _cxx_config->set_model_buffer(model_buffer.c_str(), model_buffer.size(), + param_buffer.c_str(), param_buffer.size()); + _cxx_config->set_valid_places({ + Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kHost), PRECISION(kFloat)} + }); + + _predictor = CreatePaddlePredictor(*_cxx_config); _is_sampling_agent = false; - _predictor = predictor; // Original agent can't be used to sample, so keep it same with _predictor for evaluating. - _sampling_predictor = predictor; + _sampling_predictor = _predictor; _config = std::make_shared(); load_proto_conf(config_path, *_config); @@ -55,16 +64,21 @@ ESAgent::ESAgent( _neg_gradients = new float [_param_size]; } -std::shared_ptr ESAgent::clone() { - std::shared_ptr new_sampling_predictor = _predictor->Clone(); +ESAgent::ESAgent(const CxxConfig& cxx_config) { + _sampling_predictor = CreatePaddlePredictor(cxx_config); +} - std::shared_ptr new_agent = std::make_shared(); +std::shared_ptr ESAgent::clone() { + if (_is_sampling_agent) { + LOG(ERROR) << "[DeepES] only original ESAgent can call `clone` function."; + return nullptr; + } + std::shared_ptr new_agent = std::make_shared(*_cxx_config); float* noise = new float [_param_size]; new_agent->_predictor = _predictor; - new_agent->_sampling_predictor = new_sampling_predictor; - + new_agent->_cxx_config = _cxx_config; new_agent->_is_sampling_agent = true; new_agent->_sampling_method = _sampling_method; new_agent->_param_names = _param_names; diff --git a/deepes/src/utils.cc b/deepes/src/utils.cc index cd5b055405ceefc41d7f8be007b52e9e4ddd7221..f988fe1cb838d14678425839cab35a74e0c6b327 100644 --- a/deepes/src/utils.cc +++ b/deepes/src/utils.cc @@ -52,4 +52,16 @@ std::vector list_all_model_dirs(std::string path) { return model_dirs; } +std::string read_file(const std::string& filename) { + std::ifstream ifile(filename.c_str()); + if (!ifile.is_open()) { + LOG(FATAL) << "Open file: [" << filename << "] failed."; + } + std::ostringstream buf; + char ch; + while (buf && ifile.get(ch)) buf.put(ch); + ifile.close(); + return buf.str(); +} + }//namespace