From fa420300e2b10dd5304328c15d8eabd3d5e33454 Mon Sep 17 00:00:00 2001 From: Bo Zhou <2466956298@qq.com> Date: Mon, 16 Mar 2020 14:59:12 +0800 Subject: [PATCH] update comments for ES (#211) * update comments for ES * check dependence on paddle or torch * update readme * update readme#2 * users can still use parl.remote when no DL framework was found * yapf --- README.cn.md | 1 + README.md | 1 + examples/ES/mujoco_agent.py | 2 +- examples/ES/utils.py | 17 ++++++++++++++--- parl/algorithms/__init__.py | 5 +++++ 5 files changed, 22 insertions(+), 4 deletions(-) diff --git a/README.cn.md b/README.cn.md index ff206c8..99cba49 100644 --- a/README.cn.md +++ b/README.cn.md @@ -72,6 +72,7 @@ pip install parl # 算法示例 - [QuickStart](examples/QuickStart/) - [DQN](examples/DQN/) +- [ES(深度进化算法)](examples/ES/) - [DDPG](examples/DDPG/) - [PPO](examples/PPO/) - [IMPALA](examples/IMPALA/) diff --git a/README.md b/README.md index e2327bf..a17d057 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ pip install parl # Examples - [QuickStart](examples/QuickStart/) - [DQN](examples/DQN/) +- [ES](examples/ES/) - [DDPG](examples/DDPG/) - [PPO](examples/PPO/) - [IMPALA](examples/IMPALA/) diff --git a/examples/ES/mujoco_agent.py b/examples/ES/mujoco_agent.py index 58260d7..f914e49 100644 --- a/examples/ES/mujoco_agent.py +++ b/examples/ES/mujoco_agent.py @@ -55,7 +55,7 @@ class MujocoAgent(parl.Agent): noises(np.float32): [batch_size, weights_total_size] """ - g, count = utils.batched_weighted_sum( + g = utils.batched_weighted_sum( # mirrored sampling: evaluate pairs of perturbations \epsilon, −\epsilon noisy_rewards[:, 0] - noisy_rewards[:, 1], noises, diff --git a/examples/ES/utils.py b/examples/ES/utils.py index 29d43e0..265da51 100644 --- a/examples/ES/utils.py +++ b/examples/ES/utils.py @@ -19,6 +19,10 @@ def compute_ranks(x): def compute_centered_ranks(x): + """Return ranks that are normalized to [-0.5, 0.5] with the rewards as input. + Args: + x(np.array): an array of rewards. 
+ """ y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32) y /= (x.size - 1) y -= 0.5 @@ -26,6 +30,7 @@ def compute_centered_ranks(x): def itergroups(items, group_size): + """An iterator that iterates a list with batch data.""" assert group_size >= 1 group = [] for x in items: @@ -38,16 +43,22 @@ def itergroups(items, group_size): def batched_weighted_sum(weights, vecs, batch_size): + """Compute the gradients for updating the parameters. + Args: + weights(np.array): the normalized rewards computed by the function `compute_centered_ranks`. + vecs(np.array): the noise added to the parameters. + batch_size(int): the batch_size for speeding up the computation. + Return: + total(np.array): aggregated gradient. + """ total = 0 - num_items_summed = 0 for batch_weights, batch_vecs in zip( itergroups(weights, batch_size), itergroups(vecs, batch_size)): assert len(batch_weights) == len(batch_vecs) <= batch_size total += np.dot( np.asarray(batch_weights, dtype=np.float32), np.asarray(batch_vecs, dtype=np.float32)) - num_items_summed += len(batch_weights) - return total, num_items_summed + return total def unflatten(flat_array, array_shapes): diff --git a/parl/algorithms/__init__.py b/parl/algorithms/__init__.py index 20c3d3d..8565455 100644 --- a/parl/algorithms/__init__.py +++ b/parl/algorithms/__init__.py @@ -13,8 +13,13 @@ # limitations under the License. from parl.utils.utils import _HAS_FLUID, _HAS_TORCH +from parl.utils import logger if _HAS_FLUID: from parl.algorithms.fluid import * elif _HAS_TORCH: from parl.algorithms.torch import * +else: + logger.warning( + "No deep learning framework was found, but it's ok for parallel computation." + ) -- GitLab