Unverified · Commit a0435286 authored by Robin Chen, committed by GitHub

polish(nyz): update multi-discrete policies (#167)

Parent 2699aa5e
@@ -31,7 +31,10 @@ class PolicyFactory:
         def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:

             def discrete_random_action(min_val, max_val, shape):
-                return np.random.randint(min_val, max_val, shape)
+                action = np.random.randint(min_val, max_val, shape)
+                if len(action) > 1:
+                    action = list(np.expand_dims(action, axis=1))
+                return action

             def continuous_random_action(min_val, max_val, shape):
                 bounded_below = min_val != float("inf")
...
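Note on the PolicyFactory change above: for multi-discrete action spaces the random policy now returns a list of per-branch actions instead of a single stacked array. A minimal standalone sketch of the new behaviour (plain NumPy, not DI-engine code; the example shapes are illustrative only):

import numpy as np

def discrete_random_action(min_val, max_val, shape):
    # sample one random action per discrete branch
    action = np.random.randint(min_val, max_val, shape)
    # multi-discrete case: split into a list of per-branch arrays
    if len(action) > 1:
        action = list(np.expand_dims(action, axis=1))
    return action

discrete_random_action(0, 4, (1, ))   # e.g. array([2])                               -- single discrete branch
discrete_random_action(0, 4, (3, ))   # e.g. [array([1]), array([0]), array([3])]     -- three branches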
@@ -56,9 +56,10 @@ class MultiDiscreteDQNPolicy(DQNPolicy):
         value_gamma = data.get('value_gamma')
         if isinstance(q_value, list):
-            tl_num = len(q_value)
+            act_num = len(q_value)
             loss, td_error_per_sample = [], []
-            for i in range(tl_num):
+            q_value_list = []
+            for i in range(act_num):
                 td_data = q_nstep_td_data(
                     q_value[i], target_q_value[i], data['action'][i], target_q_action[i], data['reward'], data['done'],
                     data['weight']
@@ -68,8 +69,10 @@ class MultiDiscreteDQNPolicy(DQNPolicy):
                 )
                 loss.append(loss_)
                 td_error_per_sample.append(td_error_per_sample_.abs())
+                q_value_list.append(q_value[i].mean().item())
             loss = sum(loss) / (len(loss) + 1e-8)
             td_error_per_sample = sum(td_error_per_sample) / (len(td_error_per_sample) + 1e-8)
+            q_value_mean = sum(q_value_list) / act_num
         else:
             data_n = q_nstep_td_data(
                 q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
@@ -77,6 +80,7 @@ class MultiDiscreteDQNPolicy(DQNPolicy):
             loss, td_error_per_sample = q_nstep_td_error(
                 data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
             )
+            q_value_mean = q_value.mean().item()
         # ====================
         # Q-learning update
@@ -94,5 +98,6 @@ class MultiDiscreteDQNPolicy(DQNPolicy):
         return {
             'cur_lr': self._optimizer.defaults['lr'],
             'total_loss': loss.item(),
+            'q_value_mean': q_value_mean,
             'priority': td_error_per_sample.abs().tolist(),
         }
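The MultiDiscreteDQNPolicy hunks above average the per-branch n-step TD losses and absolute TD errors, and additionally track a scalar 'q_value_mean' for logging. A minimal sketch of that aggregation pattern with dummy tensors (stand-ins for the q_nstep_td_error outputs, not the actual policy code):

import torch

batch_size, act_num = 8, 3
per_branch_loss = [torch.rand(()) for _ in range(act_num)]          # loss_ from each branch
per_branch_td = [torch.rand(batch_size) for _ in range(act_num)]    # td_error_per_sample_ from each branch
per_branch_q = [torch.rand(batch_size, 4) for _ in range(act_num)]  # q_value[i], 4 actions per branch

loss = sum(per_branch_loss) / (len(per_branch_loss) + 1e-8)                              # branch-averaged loss
td_error_per_sample = sum(t.abs() for t in per_branch_td) / (len(per_branch_td) + 1e-8)  # branch-averaged |TD error|
q_value_mean = sum(q.mean().item() for q in per_branch_q) / act_num                      # scalar reported as 'q_value_mean'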
@@ -34,26 +34,19 @@ class MultiDiscretePPOPolicy(PPOPolicy):
         # ====================
         return_infos = []
         self._learn_model.train()
-        if self._value_norm:
-            unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
-            data['return'] = unnormalized_return / self._running_mean_std.std
-            self._running_mean_std.update(unnormalized_return.cpu().numpy())
-        else:
-            data['return'] = data['adv'] + data['value']

         for epoch in range(self._cfg.learn.epoch_per_collect):
             if self._recompute_adv:
                 with torch.no_grad():
-                    # obs = torch.cat([data['obs'], data['next_obs'][-1:]])
                     value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
                     next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
                     if self._value_norm:
                         value *= self._running_mean_std.std
                         next_value *= self._running_mean_std.std
-                    gae_data_ = gae_data(value, next_value, data['reward'], data['done'])
+                    compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], data['traj_flag'])
                     # GAE need (T, B) shape input and return (T, B) output
-                    data['adv'] = gae(gae_data_, self._gamma, self._gae_lambda)
+                    data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda)
                     # value = value[:-1]
                     unnormalized_returns = value + data['adv']
@@ -65,6 +58,14 @@ class MultiDiscretePPOPolicy(PPOPolicy):
                         data['value'] = value
                         data['return'] = unnormalized_returns
+            else:  # don't recompute adv
+                if self._value_norm:
+                    unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
+                    data['return'] = unnormalized_return / self._running_mean_std.std
+                    self._running_mean_std.update(unnormalized_return.cpu().numpy())
+                else:
+                    data['return'] = data['adv'] + data['value']
+
             for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
                 output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
                 adv = batch['adv']
...
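In the MultiDiscretePPOPolicy hunks above, the return target is no longer computed once before the epoch loop; when advantages are not recomputed it is rebuilt inside the loop from the stored adv and value, de-normalizing and re-normalizing with the running std when value norm is enabled. A minimal sketch of that else branch with dummy tensors (a plain tensor stands in for self._running_mean_std.std):

import torch

adv = torch.randn(16)            # stored advantage estimates
value = torch.randn(16)          # critic values saved at collect time (in normalized scale)
running_std = torch.tensor(2.0)  # stand-in for self._running_mean_std.std
value_norm = True

if value_norm:
    # de-normalize the stored value, rebuild the raw return, then re-normalize the target
    unnormalized_return = adv + value * running_std
    ret = unnormalized_return / running_std
    # the policy also updates the running statistics with unnormalized_return here
else:
    ret = adv + value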
@@ -67,10 +67,10 @@ class MultiDiscreteRainbowDQNPolicy(RainbowDQNPolicy):
                 value_gamma=value_gamma
             )
         else:
-            tl_num = len(q_dist)
+            act_num = len(q_dist)
             losses = []
             td_error_per_samples = []
-            for i in range(tl_num):
+            for i in range(act_num):
                 td_data = dist_nstep_td_data(
                     q_dist[i], target_q_dist[i], data['action'][i], target_q_action[i], data['reward'], data['done'],
                     data['weight']
@@ -87,7 +87,7 @@ class MultiDiscreteRainbowDQNPolicy(RainbowDQNPolicy):
                 losses.append(td_loss)
                 td_error_per_samples.append(td_error_per_sample)
             loss = sum(losses) / (len(losses) + 1e-8)
-            td_error_per_sample_mean = sum(td_error_per_samples)
+            td_error_per_sample_mean = sum(td_error_per_samples) / (len(td_error_per_samples) + 1e-8)
         # ====================
         # Rainbow update
         # ====================
...
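The Rainbow hunk above also fixes the priority statistic: the summed per-branch TD errors are now divided by the branch count, matching the loss aggregation, so priority magnitudes no longer grow with the number of discrete branches. A tiny sketch with dummy tensors:

import torch

td_error_per_samples = [torch.rand(8) for _ in range(3)]                  # one |TD error| vector per branch
before = sum(td_error_per_samples)                                        # scaled with the branch count
after = sum(td_error_per_samples) / (len(td_error_per_samples) + 1e-8)    # branch-averaged, as in the fix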