Commit 3f22fd3e authored by zhoubo01

rename SamplingKey to SamplingInfo

Parent f46ad361
@@ -11,14 +11,14 @@ auto agent = ESAgent(config);
 for (int i = 0; i < 10; ++i) {
     auto sampling_agent = agent->clone(); // clone a sampling agent
-    SamplingKey key;
-    sampling_agent->add_noise(key); // perturb the parameters; the random seed is saved into key
+    SamplingInfo info;
+    sampling_agent->add_noise(info); // perturb the parameters; the random seed is saved into info
     int reward = evaluate(env, sampling_agent); // evaluate the perturbed parameters
-    noisy_keys.push_back(key); // record the seed behind this noise sample
+    noisy_info.push_back(info); // record the seed behind this noise sample
     noisy_rewards.push_back(reward); // record the evaluation result
 }
 // Update the parameters from the rewards and seeds, then repeat until convergence.
-agent->update(noisy_keys, noisy_rewards);
+agent->update(noisy_info, noisy_rewards);
 ```
 ## List of one-click runnable demos
......
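The `SamplingInfo` message only needs to carry integer keys because Gaussian noise is reproducible from its seed: `add_noise` records the seed of the perturbation it applied, and `update` later regenerates the identical noise from that key via `_sampling_method->resampling(...)`. A minimal standalone sketch of the idea, assuming a `std::mt19937`-based Gaussian sampler (the repo's actual sampling method is not shown in this diff):

```cpp
#include <cstdint>
#include <random>
#include <vector>

// Rebuild a Gaussian noise vector from an integer seed. Two calls with the
// same key produce identical noise, so only the key (plus the reward) ever
// needs to travel from a sampling agent back to the master agent.
std::vector<float> noise_from_key(int key, int64_t param_size, float std_dev = 1.0f) {
    std::mt19937 gen(static_cast<uint32_t>(key));
    std::normal_distribution<float> dist(0.0f, std_dev);
    std::vector<float> noise(param_size);
    for (auto& x : noise) {
        x = dist(gen);
    }
    return noise;
}
```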
@@ -95,25 +95,25 @@ int main(int argc, char* argv[]) {
         sampling_agents.push_back(agent->clone());
     }
-    std::vector<SamplingKey> noisy_keys;
+    std::vector<SamplingInfo> noisy_info;
     std::vector<float> noisy_rewards(ITER, 0.0f);
-    noisy_keys.resize(ITER);
+    noisy_info.resize(ITER);
     omp_set_num_threads(10);
-    for (int epoch = 0; epoch < 1000; ++epoch) {
+    for (int epoch = 0; epoch < 300; ++epoch) {
 #pragma omp parallel for schedule(dynamic, 1)
         for (int i = 0; i < ITER; ++i) {
             std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
-            SamplingKey key;
-            bool success = sampling_agent->add_noise(key);
+            SamplingInfo info;
+            bool success = sampling_agent->add_noise(info);
             float reward = evaluate(envs[i], sampling_agent);
-            noisy_keys[i] = key;
+            noisy_info[i] = info;
             noisy_rewards[i] = reward;
         }
         // NOTE: all parameters of sampling_agents will be updated
-        bool success = agent->update(noisy_keys, noisy_rewards);
+        bool success = agent->update(noisy_info, noisy_rewards);
         int reward = evaluate(envs[0], agent);
         LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward;
......
@@ -59,23 +59,23 @@ int main(int argc, char* argv[]) {
         sampling_agents.push_back(agent->clone());
     }
-    std::vector<SamplingKey> noisy_keys;
+    std::vector<SamplingInfo> noisy_info;
     std::vector<float> noisy_rewards(ITER, 0.0f);
-    noisy_keys.resize(ITER);
+    noisy_info.resize(ITER);
     for (int epoch = 0; epoch < 1000; ++epoch) {
 #pragma omp parallel for schedule(dynamic, 1)
         for (int i = 0; i < ITER; ++i) {
             auto sampling_agent = sampling_agents[i];
-            SamplingKey key;
-            bool success = sampling_agent->add_noise(key);
+            SamplingInfo info;
+            bool success = sampling_agent->add_noise(info);
             float reward = evaluate(envs[i], sampling_agent);
-            noisy_keys[i] = key;
+            noisy_info[i] = info;
             noisy_rewards[i] = reward;
         }
         // Will also update parameters of sampling_agents
-        bool success = agent->update(noisy_keys, noisy_rewards);
+        bool success = agent->update(noisy_info, noisy_rewards);
         // Use the original agent to evaluate (without noise).
         int reward = evaluate(envs[0], agent);
......
@@ -63,11 +63,11 @@ class ESAgent {
      * Parameters of cloned agents will also be updated.
      */
     bool update(
-        std::vector<SamplingKey>& noisy_keys,
+        std::vector<SamplingInfo>& noisy_info,
         std::vector<float>& noisy_rewards);
     // copied parameters = original parameters + noise
-    bool add_noise(SamplingKey& sampling_key);
+    bool add_noise(SamplingInfo& sampling_info);
     /**
      * @brief Get paddle predict
......
@@ -98,7 +98,7 @@ public:
      * Only a non-cloned ESAgent can call the `update` function.
      * Parameters of cloned agents will also be updated.
      */
-    bool update(std::vector<SamplingKey>& noisy_keys, std::vector<float>& noisy_rewards) {
+    bool update(std::vector<SamplingInfo>& noisy_info, std::vector<float>& noisy_rewards) {
         if (_is_sampling_agent) {
             LOG(ERROR) << "[DeepES] Cloned ESAgent cannot call update function, please use original ESAgent.";
             return false;
@@ -107,8 +107,8 @@ public:
         compute_centered_ranks(noisy_rewards);
         memset(_neg_gradients, 0, _param_size * sizeof(float));
-        for (int i = 0; i < noisy_keys.size(); ++i) {
-            int key = noisy_keys[i].key(0);
+        for (int i = 0; i < noisy_info.size(); ++i) {
+            int key = noisy_info[i].key(0);
             float reward = noisy_rewards[i];
             bool success = _sampling_method->resampling(key, _noise, _param_size);
             for (int64_t j = 0; j < _param_size; ++j) {
@@ -116,7 +116,7 @@ public:
             }
         }
         for (int64_t j = 0; j < _param_size; ++j) {
-            _neg_gradients[j] /= -1.0 * noisy_keys.size();
+            _neg_gradients[j] /= -1.0 * noisy_info.size();
         }
         //update
@@ -125,7 +125,7 @@ public:
         for (auto& param: params) {
             torch::Tensor tensor = param.value().view({-1});
             auto tensor_a = tensor.accessor<float,1>();
-            _optimizer->update(tensor_a, _neg_gradients+counter, tensor.size(0), param.key());
+            _optimizer->update(tensor_a, _neg_gradients+counter, tensor.size(0), param.info());
             counter += tensor.size(0);
         }
@@ -133,7 +133,7 @@ public:
     }
     // copied parameters = original parameters + noise
-    bool add_noise(SamplingKey& sampling_key) {
+    bool add_noise(SamplingInfo& sampling_info) {
         if (!_is_sampling_agent) {
             LOG(ERROR) << "[DeepES] Original ESAgent cannot call add_noise function, please use cloned ESAgent.";
             return false;
@@ -142,11 +142,11 @@ public:
         auto sampling_params = _sampling_model->named_parameters();
         auto params = _model->named_parameters();
         int key = _sampling_method->sampling(_noise, _param_size);
-        sampling_key.add_key(key);
+        sampling_info.add_key(key);
         int64_t counter = 0;
         for (auto& param: sampling_params) {
             torch::Tensor sampling_tensor = param.value().view({-1});
-            std::string param_name = param.key();
+            std::string param_name = param.info();
             torch::Tensor tensor = params.find(param_name)->view({-1});
             auto sampling_tensor_a = sampling_tensor.accessor<float,1>();
             auto tensor_a = tensor.accessor<float,1>();
......
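`compute_centered_ranks`, called at the top of `update` above, is the usual rank-based fitness shaping for evolution strategies: raw rewards are replaced by their ranks rescaled to [-0.5, 0.5], which keeps a single outlier reward from dominating the gradient estimate. Its implementation is not part of this diff; a minimal sketch of such a transform might look like this:

```cpp
#include <algorithm>
#include <numeric>
#include <vector>

// Replace each reward with its centered rank in [-0.5, 0.5]:
// the worst reward maps to -0.5, the best to 0.5 (ties broken arbitrarily).
void centered_ranks(std::vector<float>& rewards) {
    const size_t n = rewards.size();
    if (n < 2) return;
    std::vector<size_t> order(n);
    std::iota(order.begin(), order.end(), 0);
    // order[rank] = index of the rank-th smallest reward
    std::sort(order.begin(), order.end(),
              [&](size_t a, size_t b) { return rewards[a] < rewards[b]; });
    for (size_t rank = 0; rank < n; ++rank) {
        rewards[order[rank]] = static_cast<float>(rank) / (n - 1) - 0.5f;
    }
}
```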
@@ -78,7 +78,7 @@ std::shared_ptr<ESAgent> ESAgent::clone() {
 }
 bool ESAgent::update(
-        std::vector<SamplingKey>& noisy_keys,
+        std::vector<SamplingInfo>& noisy_info,
         std::vector<float>& noisy_rewards) {
     if (_is_sampling_agent) {
         LOG(ERROR) << "[DeepES] Cloned ESAgent cannot call update function, please use original ESAgent.";
@@ -88,8 +88,8 @@ bool ESAgent::update(
     compute_centered_ranks(noisy_rewards);
     memset(_neg_gradients, 0, _param_size * sizeof(float));
-    for (int i = 0; i < noisy_keys.size(); ++i) {
-        int key = noisy_keys[i].key(0);
+    for (int i = 0; i < noisy_info.size(); ++i) {
+        int key = noisy_info[i].key(0);
         float reward = noisy_rewards[i];
         bool success = _sampling_method->resampling(key, _noise, _param_size);
         for (int64_t j = 0; j < _param_size; ++j) {
@@ -97,7 +97,7 @@ bool ESAgent::update(
         }
     }
     for (int64_t j = 0; j < _param_size; ++j) {
-        _neg_gradients[j] /= -1.0 * noisy_keys.size();
+        _neg_gradients[j] /= -1.0 * noisy_info.size();
     }
     //update
@@ -114,14 +114,14 @@ bool ESAgent::update(
 }
-bool ESAgent::add_noise(SamplingKey& sampling_key) {
+bool ESAgent::add_noise(SamplingInfo& sampling_info) {
     if (!_is_sampling_agent) {
         LOG(ERROR) << "[DeepES] Original ESAgent cannot call add_noise function, please use cloned ESAgent.";
         return false;
     }
     int key = _sampling_method->sampling(_noise, _param_size);
-    sampling_key.add_key(key);
+    sampling_info.add_key(key);
     int64_t counter = 0;
     for (std::string param_name: _param_names) {
......
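The accumulation body elided between the hunks above presumably adds `reward * _noise[j]` into `_neg_gradients[j]`, so after the division by `-1.0 * noisy_info.size()` the optimizer receives the standard ES gradient estimate, g ≈ (1/n) Σ_i r_i · ε_i, negated because the optimizer minimizes. A self-contained sketch under that assumption, with a hypothetical `noise_from_key` helper standing in for the `resampling` call:

```cpp
#include <cstdint>
#include <functional>
#include <vector>

// ES gradient estimate: average shaped reward times the noise that earned it,
// negated so that a minimizing optimizer performs gradient ascent on reward.
std::vector<float> es_negative_gradient(
        const std::vector<int>& keys,
        const std::vector<float>& shaped_rewards,
        int64_t param_size,
        const std::function<std::vector<float>(int)>& noise_from_key) {
    std::vector<float> neg_grad(param_size, 0.0f);
    for (size_t i = 0; i < keys.size(); ++i) {
        const std::vector<float> noise = noise_from_key(keys[i]);
        for (int64_t j = 0; j < param_size; ++j) {
            neg_grad[j] += shaped_rewards[i] * noise[j];
        }
    }
    for (int64_t j = 0; j < param_size; ++j) {
        neg_grad[j] /= -1.0f * keys.size();  // negate: optimizer minimizes
    }
    return neg_grad;
}
```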
@@ -23,6 +23,8 @@ message DeepESConfig {
     optional GaussianSamplingConfig gaussian_sampling = 3;
     // Optimizer Configuration
     optional OptimizerConfig optimizer = 4;
+    // AsyncESAgent Configuration
+    optional AsyncESConfig async_es = 5;
 }
 message GaussianSamplingConfig {
@@ -40,6 +42,13 @@ message OptimizerConfig{
     optional float epsilon = 6 [default = 1e-8];
 }
-message SamplingKey{
+message SamplingInfo{
     repeated int32 key = 1;
+    optional int32 model_iter_id = 2;
 }
+message AsyncESConfig{
+    optional string model_warehouse = 1 [default = "./model_warehouse"];
+    repeated string model_md5 = 2;
+    optional int32 max_to_keep = 3 [default = 5];
+}
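For reference, the renamed `SamplingInfo` and the new `AsyncESConfig` are manipulated through the usual protobuf-generated C++ accessors. An illustrative sketch, assuming the generated header is named `deepes.pb.h` and ignoring any proto package namespace; the md5 string is a placeholder:

```cpp
#include "deepes.pb.h"  // assumed name of the header generated from this proto

void fill_messages_example() {
    SamplingInfo info;
    info.add_key(12345);        // the random seed recorded by add_noise
    info.set_model_iter_id(7);  // new field: which model iteration was sampled

    AsyncESConfig async_cfg;
    async_cfg.set_model_warehouse("./model_warehouse");  // matches the proto default
    async_cfg.set_max_to_keep(5);                        // matches the proto default
    async_cfg.add_model_md5("d41d8cd98f00b204e9800998ecf8427e");  // placeholder md5
}
```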