From 16833c62b43a223f79d6dac16f362de7b02057c1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <48008469+puyuan1996@users.noreply.github.com>
Date: Fri, 17 Dec 2021 15:07:25 +0800
Subject: [PATCH] polish(pu): polish eps_greedy_multinomial_sample in model_wrapper (#154)

* polish(pu):polish eps_greedy_multinomial_sample in model_wrappers

* polish(pu): delete masac wrapper

* polish(pu): delete sql wrapper
---
 ding/model/wrapper/model_wrappers.py | 61 +++++++++++++++-------------
 ding/policy/sac.py                   |  9 ++--
 2 files changed, 39 insertions(+), 31 deletions(-)

diff --git a/ding/model/wrapper/model_wrappers.py b/ding/model/wrapper/model_wrappers.py
index 57f6071..8c9c04b 100644
--- a/ding/model/wrapper/model_wrappers.py
+++ b/ding/model/wrapper/model_wrappers.py
@@ -5,7 +5,6 @@ import numpy as np
 import torch
 from ding.torch_utils import get_tensor_data
 from ding.rl_utils import create_noise_generator
-from torch.distributions import Categorical
 
 
 class IModelWrapper(ABC):
@@ -210,12 +209,16 @@ class HybridArgmaxSampleWrapper(IModelWrapper):
 class MultinomialSampleWrapper(IModelWrapper):
     r"""
     Overview:
-        Used to helper the model get the corresponding action from the output['logits']
+        Used to help the model get the corresponding action from the output['logits']
     Interfaces:
         register
     """
 
     def forward(self, *args, **kwargs):
+        if 'alpha' in kwargs.keys():
+            alpha = kwargs.pop('alpha')
+        else:
+            alpha = None
         output = self._model.forward(*args, **kwargs)
         assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
         logit = output['logit']
@@ -227,7 +230,11 @@ class MultinomialSampleWrapper(IModelWrapper):
             if isinstance(mask, torch.Tensor):
                 mask = [mask]
             logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
-        action = [sample_action(logit=l) for l in logit]
+        if alpha is None:
+            action = [sample_action(logit=l) for l in logit]
+        else:
+            # Note that if alpha is passed in here, we will divide logit by alpha.
+            action = [sample_action(logit=l / alpha) for l in logit]
         if len(action) == 1:
             action, logit = action[0], logit[0]
         output['action'] = action
@@ -272,17 +279,21 @@ class EpsGreedySampleWrapper(IModelWrapper):
         return output
 
 
-class HybridEpsGreedySampleWrapper(IModelWrapper):
+class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
     r"""
     Overview:
-        Epsilon greedy sampler used in collector_model to help balance exploration and exploitation.
-        In hybrid action space, i.e.{'action_type': discrete, 'action_args', continuous}
+        Epsilon greedy sampler coupled with multinomial sample used in collector_model
+        to help balance exploration and exploitation.
     Interfaces:
-        register, forward
+        register
     """
 
     def forward(self, *args, **kwargs):
         eps = kwargs.pop('eps')
+        if 'alpha' in kwargs.keys():
+            alpha = kwargs.pop('alpha')
+        else:
+            alpha = None
         output = self._model.forward(*args, **kwargs)
         assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
         logit = output['logit']
@@ -299,7 +310,11 @@ class HybridEpsGreedySampleWrapper(IModelWrapper):
         action = []
         for i, l in enumerate(logit):
             if np.random.random() > eps:
-                action.append(l.argmax(dim=-1))
+                if alpha is None:
+                    action = [sample_action(logit=l) for l in logit]
+                else:
+                    # Note that if alpha is passed in here, we will divide logit by alpha.
+                    action = [sample_action(logit=l / alpha) for l in logit]
             else:
                 if mask:
                     action.append(sample_action(prob=mask[i].float()))
@@ -307,22 +322,21 @@ class HybridEpsGreedySampleWrapper(IModelWrapper):
                     action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
         if len(action) == 1:
             action, logit = action[0], logit[0]
-        output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit}
+        output['action'] = action
         return output
 
 
-class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
+class HybridEpsGreedySampleWrapper(IModelWrapper):
     r"""
     Overview:
-        Epsilon greedy sampler coupled with multinomial sample used in collector_model
-        to help balance exploration and exploitation.
+        Epsilon greedy sampler used in collector_model to help balance exploration and exploitation.
+        In hybrid action space, i.e.{'action_type': discrete, 'action_args', continuous}
     Interfaces:
-        register
+        register, forward
     """
 
     def forward(self, *args, **kwargs):
         eps = kwargs.pop('eps')
-        alpha = kwargs.pop('alpha')
         output = self._model.forward(*args, **kwargs)
         assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
         logit = output['logit']
@@ -339,12 +353,7 @@ class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
         action = []
         for i, l in enumerate(logit):
             if np.random.random() > eps:
-                prob = torch.softmax(output['logit'] / alpha, dim=-1)
-                prob = prob / torch.sum(prob, 1, keepdim=True)
-                pi_action = torch.zeros(prob.shape)
-                pi_action = Categorical(prob)
-                pi_action = pi_action.sample()
-                action.append(pi_action)
+                action.append(l.argmax(dim=-1))
             else:
                 if mask:
                     action.append(sample_action(prob=mask[i].float()))
@@ -352,8 +361,8 @@ class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
                     action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
         if len(action) == 1:
             action, logit = action[0], logit[0]
-        output['action'] = action
-        return output
+        output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit}
+        return output
 
 
 class HybridEpsGreedyMultinomialSampleWrapper(IModelWrapper):
@@ -387,11 +396,7 @@ class HybridEpsGreedyMultinomialSampleWrapper(IModelWrapper):
         action = []
         for i, l in enumerate(logit):
             if np.random.random() > eps:
-                prob = torch.softmax(l, dim=-1)
-                prob = prob / torch.sum(prob, 1, keepdim=True)
-                pi_action = Categorical(prob)
-                pi_action = pi_action.sample()
-                action.append(pi_action)
+                action = [sample_action(logit=l) for l in logit]
             else:
                 if mask:
                     action.append(sample_action(prob=mask[i].float()))
@@ -414,7 +419,7 @@ class EpsGreedySampleNGUWrapper(IModelWrapper):
 
     def forward(self, *args, **kwargs):
         kwargs.pop('eps')
-        eps = {i: 0.4 ** (1 + 8 * i / (args[0]['obs'].shape[0] - 1)) for i in range(args[0]['obs'].shape[0])}  # TODO
+        eps = {i: 0.4 ** (1 + 8 * i / (args[0]['obs'].shape[0] - 1)) for i in range(args[0]['obs'].shape[0])}
         output = self._model.forward(*args, **kwargs)
         assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
         logit = output['logit']
diff --git a/ding/policy/sac.py b/ding/policy/sac.py
index ae371dd..78f154d 100644
--- a/ding/policy/sac.py
+++ b/ding/policy/sac.py
@@ -19,7 +19,7 @@ from .common_utils import default_preprocess_learn
 class SACDiscretePolicy(Policy):
     r"""
     Overview:
-        Policy class of Discrete SAC algorithm.
+        Policy class of discrete SAC algorithm.
 
     Config:
         == ==================== ======== ============= ================================= =======================
@@ -407,7 +407,10 @@ class SACDiscretePolicy(Policy):
         """
         self._unroll_len = self._cfg.collect.unroll_len
         self._multi_agent = self._cfg.multi_agent
-        self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
+        # Empirically, we found that eps_greedy_multinomial_sample works better than multinomial_sample
+        # and eps_greedy_sample, and we don't divide logit by alpha,
+        # for the details please refer to ding/model/wrapper/model_wrappers
+        self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_multinomial_sample')
         self._collect_model.reset()
 
     def _forward_collect(self, data: dict, eps: float) -> dict:
@@ -516,7 +519,7 @@ class SACDiscretePolicy(Policy):
 class SACPolicy(Policy):
     r"""
     Overview:
-        Policy class of SAC algorithm.
+        Policy class of continuous SAC algorithm.
         https://arxiv.org/pdf/1801.01290.pdf
-- 
GitLab
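
The snippet below is an illustrative usage sketch, not part of the patch. It assumes DI-engine at roughly this revision, where model_wrap is exported from ding.model and the eps_greedy_multinomial_sample wrapper registered by this change is available; ToyDiscreteModel and the tensor shapes and eps/alpha values are hypothetical stand-ins for a real collector model.

# Illustrative sketch only (not part of the patch). Assumes DI-engine at this
# revision, with model_wrap exported from ding.model and the
# 'eps_greedy_multinomial_sample' wrapper registered by this change.
# ToyDiscreteModel and all shapes/values below are hypothetical.
import torch
import torch.nn as nn

from ding.model import model_wrap


class ToyDiscreteModel(nn.Module):
    """Minimal stand-in for a collector model that returns a dict with 'logit'."""

    def __init__(self, obs_dim: int = 8, action_dim: int = 4):
        super().__init__()
        self.net = nn.Linear(obs_dim, action_dim)

    def forward(self, obs: torch.Tensor) -> dict:
        return {'logit': self.net(obs)}


model = ToyDiscreteModel()
# Same wrapper name that SACDiscretePolicy now uses for its collect model.
collect_model = model_wrap(model, wrapper_name='eps_greedy_multinomial_sample')
collect_model.reset()

obs = torch.randn(5, 8)  # batch of 5 observations
with torch.no_grad():
    # With probability eps the wrapper falls back to random (or mask-constrained)
    # actions; otherwise it draws actions from a multinomial over the logits.
    out = collect_model.forward(obs, eps=0.3)
    # Optionally pass a temperature alpha: the wrapper divides the logits by alpha
    # before sampling (the discrete SAC policy deliberately omits alpha).
    out_temp = collect_model.forward(obs, eps=0.3, alpha=0.5)

print(out['action'].shape, out_temp['action'].shape)  # torch.Size([5]) torch.Size([5])

A smaller alpha sharpens the softmax over the logits, pushing the multinomial sample toward the greedy action, while a larger alpha flattens it toward uniform exploration.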