diff --git a/ding/model/wrapper/model_wrappers.py b/ding/model/wrapper/model_wrappers.py index 31718f07d765b74ee1114627e1ed895d66f80278..11ecb8b97a2206c88eee5d05ceaca0d5bcd2db34 100644 --- a/ding/model/wrapper/model_wrappers.py +++ b/ding/model/wrapper/model_wrappers.py @@ -338,7 +338,7 @@ class EpsGreedyMultinomialSampleWrapper(IModelWrapper): for i, l in enumerate(logit): if np.random.random() > eps: prob = torch.softmax(output['logit'] / alpha, dim=-1) - prob = prob / torch.sum(prob, 1, keepdims=True) + prob = prob / torch.sum(prob, 1, keepdim=True) pi_action = torch.zeros(prob.shape) pi_action = Categorical(prob) pi_action = pi_action.sample() @@ -386,7 +386,7 @@ class HybridEpsGreedyMultinomialSampleWrapper(IModelWrapper): for i, l in enumerate(logit): if np.random.random() > eps: prob = torch.softmax(l, dim=-1) - prob = prob / torch.sum(prob, 1, keepdims=True) + prob = prob / torch.sum(prob, 1, keepdim=True) pi_action = Categorical(prob) pi_action = pi_action.sample() action.append(pi_action) @@ -441,51 +441,6 @@ class EpsGreedySampleNGUWrapper(IModelWrapper): return output -class EpsGreedySampleWrapperSql(IModelWrapper): - r""" - Overview: - Epsilon greedy sampler coupled with multinomial sample used in collector_model - to help balance exploration and exploitation. - Interfaces: - register - """ - - def forward(self, *args, **kwargs): - eps = kwargs.pop('eps') - alpha = kwargs.pop('alpha') - output = self._model.forward(*args, **kwargs) - assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output)) - logit = output['logit'] - assert isinstance(logit, torch.Tensor) or isinstance(logit, list) - if isinstance(logit, torch.Tensor): - logit = [logit] - if 'action_mask' in output: - mask = output['action_mask'] - if isinstance(mask, torch.Tensor): - mask = [mask] - logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)] - else: - mask = None - action = [] - for i, l in enumerate(logit): - if np.random.random() > eps: - prob = torch.softmax(output['logit'] / alpha, dim=-1) - prob = prob / torch.sum(prob, 1, keepdims=True) - pi_action = torch.zeros(prob.shape) - pi_action = Categorical(prob) - pi_action = pi_action.sample() - action.append(pi_action) - else: - if mask: - action.append(sample_action(prob=mask[i].float())) - else: - action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1])) - if len(action) == 1: - action, logit = action[0], logit[0] - output['action'] = action - return output - - class ActionNoiseWrapper(IModelWrapper): r""" Overview: @@ -629,7 +584,6 @@ wrapper_name_map = { 'hybrid_argmax_sample': HybridArgmaxSampleWrapper, 'eps_greedy_sample': EpsGreedySampleWrapper, 'eps_greedy_sample_ngu': EpsGreedySampleNGUWrapper, - 'eps_greedy_sample_sql': EpsGreedySampleWrapperSql, 'eps_greedy_multinomial_sample': EpsGreedyMultinomialSampleWrapper, 'hybrid_eps_greedy_sample': HybridEpsGreedySampleWrapper, 'hybrid_eps_greedy_multinomial_sample': HybridEpsGreedyMultinomialSampleWrapper,