提交 2f6d1e10 编写于 作者: Z zhoubo01

fix comments

上级 752974cb
......@@ -11,5 +11,5 @@ optimizer {
epsilon: 1e-08
}
async_es {
model_iter_id: 0
model_iter_id: 99
}
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef _ASYNC_ES_AGENT_H
#define _ASYNC_ES_AGENT_H
#ifndef ASYNC_ES_AGENT_H
#define ASYNC_ES_AGENT_H
#include "es_agent.h"
#include <map>
......@@ -49,7 +49,10 @@ class AsyncESAgent: public ESAgent {
std::shared_ptr<AsyncESAgent> clone();
/**
* @brief: Clone an agent for sampling.
* @brief: update parameters given data collected during evaluation.
* @args:
* noisy_info: sampling information returned by add_noise function.
* noisy_reward: evaluation rewards.
*/
bool update(
std::vector<SamplingInfo>& noisy_info,
......
......@@ -21,15 +21,13 @@
#include "gaussian_sampling.h"
#include "deepes.pb.h"
#include <vector>
using namespace paddle::lite_api;
using namespace paddle::lite_api;
namespace DeepES {
int64_t ShapeProduction(const shape_t& shape);
typedef paddle::lite_api::PaddlePredictor PaddlePredictor;
/**
* @brief DeepES agent with PaddleLite as backend.
* Users mainly focus on the following functions:
......
......@@ -40,7 +40,7 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) {
bool success = true;
std::ifstream fin(config_file);
if (!fin || fin.fail()) {
LOG(FATAL) << "open prototxt config failed: " << config_file;
LOG(ERROR) << "open prototxt config failed: " << config_file;
success = false;
} else {
fin.seekg(0, std::ios::end);
......@@ -52,7 +52,7 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) {
std::string proto_str(file_content_buffer, file_size);
if (!google::protobuf::TextFormat::ParseFromString(proto_str, &proto_config)) {
LOG(FATAL) << "Failed to load config: " << config_file;
LOG(ERROR) << "Failed to load config: " << config_file;
success = false;
}
delete[] file_content_buffer;
......@@ -66,7 +66,7 @@ bool save_proto_conf(const std::string& config_file, T&proto_config) {
bool success = true;
std::ofstream ofs(config_file, std::ofstream::out);
if (!ofs || ofs.fail()) {
LOG(FATAL) << "open prototxt config failed: " << config_file;
LOG(ERROR) << "open prototxt config failed: " << config_file;
success = false;
} else {
std::string config_str;
......@@ -76,6 +76,7 @@ bool save_proto_conf(const std::string& config_file, T&proto_config) {
}
ofs << config_str;
}
return success;
}
std::vector<std::string> list_all_model_dirs(std::string path);
......
......@@ -32,8 +32,6 @@ else
exit 0
fi
#export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
#----------------protobuf-------------#
cp ./src/proto/deepes.proto ./
protoc deepes.proto --cpp_out ./
......
......@@ -30,7 +30,7 @@ AsyncESAgent::~AsyncESAgent() {
bool AsyncESAgent::_save() {
bool success = true;
if (_is_sampling_agent) {
LOG(ERROR) << "[DeepES] Original AsyncESAgent cannot call add_noise function, please use cloned AsyncESAgent.";
LOG(ERROR) << "[DeepES] Original AsyncESAgent cannot call `save`.Please use cloned AsyncESAgent.";
success = false;
return success;
}
......@@ -49,7 +49,7 @@ bool AsyncESAgent::_save() {
model_name = "model_iter_id-"+ std::to_string(model_iter_id);
std::string model_path = _config->async_es().model_warehouse() + "/" + model_name;
LOG(INFO) << "[save]model_path: " << model_path;
_predictor->SaveOptimizedModel(model_path, LiteModelType::kProtobuf);
_predictor->SaveOptimizedModel(model_path, paddle::lite_api::LiteModelType::kProtobuf);
// save config
auto async_es = _config->mutable_async_es();
async_es->set_model_iter_id(model_iter_id);
......@@ -93,15 +93,17 @@ bool AsyncESAgent::_compute_model_diff() {
std::shared_ptr<PaddlePredictor> old_predictor = kv.second;
float* diff = new float[_param_size];
memset(diff, 0, _param_size * sizeof(float));
for (std::string param_name: _param_names) {
int offset = 0;
for (const std::string& param_name: _param_names) {
auto des_tensor = old_predictor->GetTensor(param_name);
auto src_tensor = _predictor->GetTensor(param_name);
const float* des_data = des_tensor->data<float>();
const float* src_data = src_tensor->data<float>();
int64_t tensor_size = ShapeProduction(src_tensor->shape());
for (int i = 0; i < tensor_size; ++i) {
diff[i] = des_data[i] - src_data[i];
diff[i + offset] = des_data[i] - src_data[i];
}
offset += tensor_size;
}
_param_delta[model_iter_id] = diff;
}
......@@ -206,6 +208,7 @@ bool AsyncESAgent::update(
float reward = noisy_rewards[i];
int model_iter_id = noisy_info[i].model_iter_id();
bool success = _sampling_method->resampling(key, _noise, _param_size);
CHECK(success) << "[DeepES] resampling error occurs at sample: " << i;
float* delta = _param_delta[model_iter_id];
// compute neg_gradients
if (model_iter_id == current_model_iter_id) {
......
......@@ -17,9 +17,6 @@
namespace DeepES {
typedef paddle::lite_api::Tensor Tensor;
typedef paddle::lite_api::shape_t shape_t;
int64_t ShapeProduction(const shape_t& shape) {
int64_t res = 1;
for (auto i : shape) res *= i;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册