未验证 提交 fa420300 编写于 作者: B Bo Zhou 提交者: GitHub

update comments for ES (#211)

* update comments for ES

* check dependence on paddle or torch

* update readme

* update readme#2

* users can still use parl.remote when no DL framework was found

* yapf
上级 7f2abd56
......@@ -72,6 +72,7 @@ pip install parl
# 算法示例
- [QuickStart](examples/QuickStart/)
- [DQN](examples/DQN/)
- [ES(深度进化算法)](examples/ES/)
- [DDPG](examples/DDPG/)
- [PPO](examples/PPO/)
- [IMPALA](examples/IMPALA/)
......
......@@ -75,6 +75,7 @@ pip install parl
# Examples
- [QuickStart](examples/QuickStart/)
- [DQN](examples/DQN/)
- [ES](examples/ES/)
- [DDPG](examples/DDPG/)
- [PPO](examples/PPO/)
- [IMPALA](examples/IMPALA/)
......
......@@ -55,7 +55,7 @@ class MujocoAgent(parl.Agent):
noises(np.float32): [batch_size, weights_total_size]
"""
g, count = utils.batched_weighted_sum(
g = utils.batched_weighted_sum(
# mirrored sampling: evaluate pairs of perturbations \epsilon, −\epsilon
noisy_rewards[:, 0] - noisy_rewards[:, 1],
noises,
......
......@@ -19,6 +19,10 @@ def compute_ranks(x):
def compute_centered_ranks(x):
    """Return ranks that are normalized to [-0.5, 0.5] with the rewards as input.

    Args:
        x(np.array): an array of rewards.
    """
    # Rank the flattened rewards, restore the original shape, then rescale:
    # dividing by (x.size - 1) maps ranks to [0, 1]; subtracting 0.5 centers
    # them on zero so the weighted gradient estimate is unbiased in sign.
    y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
    y /= (x.size - 1)
    y -= 0.5
......@@ -26,6 +30,7 @@ def compute_centered_ranks(x):
def itergroups(items, group_size):
"""An iterator that iterates a list with batch data."""
assert group_size >= 1
group = []
for x in items:
......@@ -38,16 +43,22 @@ def itergroups(items, group_size):
def batched_weighted_sum(weights, vecs, batch_size):
    """Compute the gradients for updating the parameters.

    The gradient is the weighted sum of the noise vectors, accumulated
    batch-by-batch to keep peak memory bounded.

    Args:
        weights(np.array): the normalized rewards computed by the function
            `compute_centered_ranks`.
        vecs(np.array): the noise added to the parameters.
        batch_size(int): the batch_size for speeding up the computation.

    Return:
        total(np.array): aggregated gradient.
    """
    total = 0
    for batch_weights, batch_vecs in zip(
            itergroups(weights, batch_size), itergroups(vecs, batch_size)):
        # Both iterators must stay in lockstep; the last batch may be short.
        assert len(batch_weights) == len(batch_vecs) <= batch_size
        # Accumulate the weighted sum of this batch of noise vectors.
        total += np.dot(
            np.asarray(batch_weights, dtype=np.float32),
            np.asarray(batch_vecs, dtype=np.float32))
    # NOTE(review): the diff artifact left an old `return total, num_items_summed`
    # followed by an unreachable `return total`; the updated caller
    # (`g = utils.batched_weighted_sum(...)`) expects the single-value return,
    # so the dead num_items_summed bookkeeping is removed.
    return total
def unflatten(flat_array, array_shapes):
......
......@@ -13,8 +13,13 @@
# limitations under the License.
from parl.utils.utils import _HAS_FLUID, _HAS_TORCH
from parl.utils import logger
# Select the deep-learning backend at import time: prefer PaddlePaddle
# (fluid), fall back to PyTorch if only torch is installed. When neither
# is available, only emit a warning instead of failing, so the
# parallel-computation features (presumably parl.remote, per the commit
# message) remain importable without any DL framework — TODO confirm.
if _HAS_FLUID:
    from parl.algorithms.fluid import *
elif _HAS_TORCH:
    from parl.algorithms.torch import *
else:
    logger.warning(
        "No deep learning framework was found, but it's ok for parallel computation."
    )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册