Unverified commit 16833c62, authored by 蒲源, committed by GitHub

polish(pu): polish eps_greedy_multinomial_sample in model_wrapper (#154)

* polish(pu): polish eps_greedy_multinomial_sample in model_wrappers

* polish(pu): delete masac wrapper

* polish(pu): delete sql wrapper
Parent e6604502
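The behavior at the heart of this commit is temperature-scaled multinomial sampling: when an alpha keyword is passed to a sample wrapper, the action is drawn from softmax(logit / alpha) instead of the plain softmax, so a larger alpha flattens the distribution (more exploration) while a smaller alpha sharpens it toward the greedy action. Below is a minimal, self-contained sketch of that idea; the helper name multinomial_sample is illustrative, not DI-engine's actual sample_action.

    import torch
    from torch.distributions import Categorical

    def multinomial_sample(logit: torch.Tensor, alpha: float = None) -> torch.Tensor:
        # Illustrative stand-in for a multinomial sample helper.
        # Dividing the logit by alpha before the softmax acts as a temperature:
        # alpha > 1 flattens the distribution (more exploration),
        # alpha < 1 sharpens it toward argmax (closer to greedy).
        if alpha is not None:
            logit = logit / alpha
        return Categorical(torch.softmax(logit, dim=-1)).sample()

    logit = torch.tensor([[2.0, 1.0, 0.1]])
    print(multinomial_sample(logit))             # plain multinomial sample
    print(multinomial_sample(logit, alpha=0.1))  # near-greedy: index 0 almost surely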
@@ -5,7 +5,6 @@ import numpy as np
 import torch
 from ding.torch_utils import get_tensor_data
 from ding.rl_utils import create_noise_generator
-from torch.distributions import Categorical

 class IModelWrapper(ABC):
@@ -210,12 +209,16 @@ class HybridArgmaxSampleWrapper(IModelWrapper):
 class MultinomialSampleWrapper(IModelWrapper):
     r"""
     Overview:
-        Used to helper the model get the corresponding action from the output['logits']
+        Used to help the model get the corresponding action from the output['logits']
     Interfaces:
         register
     """

     def forward(self, *args, **kwargs):
+        if 'alpha' in kwargs.keys():
+            alpha = kwargs.pop('alpha')
+        else:
+            alpha = None
         output = self._model.forward(*args, **kwargs)
         assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
         logit = output['logit']
@@ -227,7 +230,11 @@ class MultinomialSampleWrapper(IModelWrapper):
             if isinstance(mask, torch.Tensor):
                 mask = [mask]
             logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
-        action = [sample_action(logit=l) for l in logit]
+        if alpha is None:
+            action = [sample_action(logit=l) for l in logit]
+        else:
+            # Note that if alpha is passed in here, we will divide logit by alpha.
+            action = [sample_action(logit=l / alpha) for l in logit]
         if len(action) == 1:
             action, logit = action[0], logit[0]
         output['action'] = action
@@ -272,17 +279,21 @@ class EpsGreedySampleWrapper(IModelWrapper):
         return output

-class HybridEpsGreedySampleWrapper(IModelWrapper):
+class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
     r"""
     Overview:
-        Epsilon greedy sampler used in collector_model to help balance exploration and exploitation.
-        In hybrid action space, i.e.{'action_type': discrete, 'action_args', continuous}
+        Epsilon greedy sampler coupled with multinomial sample used in collector_model
+        to help balance exploration and exploitation.
     Interfaces:
-        register, forward
+        register
     """

     def forward(self, *args, **kwargs):
         eps = kwargs.pop('eps')
+        if 'alpha' in kwargs.keys():
+            alpha = kwargs.pop('alpha')
+        else:
+            alpha = None
         output = self._model.forward(*args, **kwargs)
         assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
         logit = output['logit']
@@ -299,7 +310,11 @@ class HybridEpsGreedySampleWrapper(IModelWrapper):
         action = []
         for i, l in enumerate(logit):
             if np.random.random() > eps:
-                action.append(l.argmax(dim=-1))
+                if alpha is None:
+                    action = [sample_action(logit=l) for l in logit]
+                else:
+                    # Note that if alpha is passed in here, we will divide logit by alpha.
+                    action = [sample_action(logit=l / alpha) for l in logit]
             else:
                 if mask:
                     action.append(sample_action(prob=mask[i].float()))
@@ -307,22 +322,21 @@ class HybridEpsGreedySampleWrapper(IModelWrapper):
                     action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
         if len(action) == 1:
             action, logit = action[0], logit[0]
-        output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit}
+        output['action'] = action
         return output
-class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
+class HybridEpsGreedySampleWrapper(IModelWrapper):
     r"""
     Overview:
-        Epsilon greedy sampler coupled with multinomial sample used in collector_model
-        to help balance exploration and exploitation.
+        Epsilon greedy sampler used in collector_model to help balance exploration and exploitation.
+        In hybrid action space, i.e.{'action_type': discrete, 'action_args', continuous}
     Interfaces:
-        register
+        register, forward
     """

     def forward(self, *args, **kwargs):
         eps = kwargs.pop('eps')
-        alpha = kwargs.pop('alpha')
         output = self._model.forward(*args, **kwargs)
         assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
         logit = output['logit']
@@ -339,12 +353,7 @@ class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
         action = []
         for i, l in enumerate(logit):
             if np.random.random() > eps:
-                prob = torch.softmax(output['logit'] / alpha, dim=-1)
-                prob = prob / torch.sum(prob, 1, keepdim=True)
-                pi_action = torch.zeros(prob.shape)
-                pi_action = Categorical(prob)
-                pi_action = pi_action.sample()
-                action.append(pi_action)
+                action.append(l.argmax(dim=-1))
             else:
                 if mask:
                     action.append(sample_action(prob=mask[i].float()))
@@ -352,8 +361,8 @@ class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
                     action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
         if len(action) == 1:
             action, logit = action[0], logit[0]
-        output['action'] = action
+        output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit}
         return output

 class HybridEpsGreedyMultinomialSampleWrapper(IModelWrapper):
@@ -387,11 +396,7 @@ class HybridEpsGreedyMultinomialSampleWrapper(IModelWrapper):
         action = []
         for i, l in enumerate(logit):
             if np.random.random() > eps:
-                prob = torch.softmax(l, dim=-1)
-                prob = prob / torch.sum(prob, 1, keepdim=True)
-                pi_action = Categorical(prob)
-                pi_action = pi_action.sample()
-                action.append(pi_action)
+                action = [sample_action(logit=l) for l in logit]
             else:
                 if mask:
                     action.append(sample_action(prob=mask[i].float()))
@@ -414,7 +419,7 @@ class EpsGreedySampleNGUWrapper(IModelWrapper):
     def forward(self, *args, **kwargs):
         kwargs.pop('eps')
-        eps = {i: 0.4 ** (1 + 8 * i / (args[0]['obs'].shape[0] - 1)) for i in range(args[0]['obs'].shape[0])}  # TODO
+        eps = {i: 0.4 ** (1 + 8 * i / (args[0]['obs'].shape[0] - 1)) for i in range(args[0]['obs'].shape[0])}
         output = self._model.forward(*args, **kwargs)
         assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
         logit = output['logit']
...
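As the wrappers above show, EpsGreedyMultinomialSampleWrapper combines two mechanisms: with probability 1 - eps it takes a multinomial sample from the (optionally temperature-scaled) policy distribution, and with probability eps it explores, sampling from the action mask if one is given, or uniformly at random otherwise. Below is a hedged standalone sketch of that control flow for a single unmasked logit tensor; sample_action here is an illustrative helper, not the library's.

    import numpy as np
    import torch
    from torch.distributions import Categorical

    def sample_action(logit: torch.Tensor) -> torch.Tensor:
        # Illustrative helper: multinomial sample from softmax(logit).
        return Categorical(torch.softmax(logit, dim=-1)).sample()

    def eps_greedy_multinomial(logit: torch.Tensor, eps: float, alpha: float = None) -> torch.Tensor:
        if np.random.random() > eps:
            # Exploit: multinomial sample, optionally temperature-scaled by alpha.
            return sample_action(logit if alpha is None else logit / alpha)
        # Explore: uniform random action over the last (action) dimension.
        return torch.randint(0, logit.shape[-1], size=logit.shape[:-1])

    logit = torch.tensor([[2.0, 1.0, 0.1]])
    print(eps_greedy_multinomial(logit, eps=0.05))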
@@ -19,7 +19,7 @@ from .common_utils import default_preprocess_learn
 class SACDiscretePolicy(Policy):
     r"""
     Overview:
-        Policy class of Discrete SAC algorithm.
+        Policy class of discrete SAC algorithm.
     Config:
         == ==================== ======== ============= ================================= =======================
@@ -407,7 +407,10 @@ class SACDiscretePolicy(Policy):
         """
         self._unroll_len = self._cfg.collect.unroll_len
         self._multi_agent = self._cfg.multi_agent
-        self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
+        # Empirically, we found that eps_greedy_multinomial_sample works better than multinomial_sample
+        # and eps_greedy_sample, and we don't divide logit by alpha,
+        # for the details please refer to ding/model/wrapper/model_wrappers
+        self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_multinomial_sample')
         self._collect_model.reset()

     def _forward_collect(self, data: dict, eps: float) -> dict:
@@ -516,7 +519,7 @@ class SACDiscretePolicy(Policy):
 class SACPolicy(Policy):
     r"""
     Overview:
-        Policy class of SAC algorithm.
+        Policy class of continuous SAC algorithm.
         https://arxiv.org/pdf/1801.01290.pdf
...
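For context on the policy-side change, collect-time usage of the new wrapper would look roughly like the sketch below. model_wrap and the wrapper name 'eps_greedy_multinomial_sample' come from the diff itself; the toy model and the shapes are placeholder assumptions for illustration.

    import torch
    import torch.nn as nn
    from ding.model import model_wrap

    class ToyModel(nn.Module):
        # Placeholder network; the wrappers only require a dict output with a 'logit' key.
        def __init__(self, obs_dim: int = 4, act_dim: int = 3):
            super().__init__()
            self.head = nn.Linear(obs_dim, act_dim)

        def forward(self, obs: torch.Tensor) -> dict:
            return {'logit': self.head(obs)}

    collect_model = model_wrap(ToyModel(), wrapper_name='eps_greedy_multinomial_sample')
    collect_model.reset()
    # As in SACDiscretePolicy above, no alpha is passed, so the logit is used as-is.
    output = collect_model.forward(torch.randn(2, 4), eps=0.1)
    print(output['action'])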