from typing import Union, Dict, Optional

from easydict import EasyDict
import torch
import torch.nn as nn

from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \
    FCEncoder, ConvEncoder


@MODEL_REGISTRY.register('vac')
class VAC(nn.Module):
    r"""
    Overview:
        The VAC (value actor-critic) model.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            action_space: str = 'discrete',
            share_encoder: bool = True,
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            sigma_type: Optional[str] = 'independent',
            fixed_sigma_value: Optional[float] = 0.3,
            bound_type: Optional[str] = None,
    ) -> None:
        r"""
        Overview:
            Initialize the VAC model according to the given arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape.
            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action space shape.
            - action_space (:obj:`str`): The action space type, chosen from ['discrete', 'continuous', 'hybrid'].
            - share_encoder (:obj:`bool`): Whether the actor and the critic share the same encoder.
            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the actor's ``Head``.
            - actor_head_layer_num (:obj:`int`): The number of layers used in the actor head network to compute
                the action logit output.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to the critic's ``Head``.
            - critic_head_layer_num (:obj:`int`): The number of layers used in the critic head network to compute
                the value output.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP``
                after each ``layer_fn``; if ``None``, default set to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to use,
                see ``ding.torch_utils.fc_block`` for more details.
            - sigma_type (:obj:`Optional[str]`): The type of sigma used in ``ReparameterizationHead``,
                only effective for the 'continuous' and 'hybrid' action spaces.
            - fixed_sigma_value (:obj:`Optional[float]`): The sigma value used when ``sigma_type`` is 'fixed'.
            - bound_type (:obj:`Optional[str]`): The type of action bound used in ``ReparameterizationHead``,
                defaults to ``None`` (no bound).
        """
        super(VAC, self).__init__()
        obs_shape: int = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.obs_shape, self.action_shape = obs_shape, action_shape
        # Encoder Type
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            encoder_cls = FCEncoder
        elif len(obs_shape) == 3:
            encoder_cls = ConvEncoder
        else:
            raise RuntimeError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own VAC".format(obs_shape)
            )
        self.share_encoder = share_encoder
        if self.share_encoder:
            self.encoder = encoder_cls(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
        else:
            self.actor_encoder = encoder_cls(
                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
            )
            self.critic_encoder = encoder_cls(
                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
            )
        # Head Type
        self.critic_head = RegressionHead(
            critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
        )
        self.action_space = action_space
        assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
        if self.action_space == 'continuous':
            self.multi_head = False
            self.actor_head = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type
            )
        elif self.action_space == 'discrete':
            actor_head_cls = DiscreteHead
            multi_head = not isinstance(action_shape, int)
            self.multi_head = multi_head
            if multi_head:
                self.actor_head = MultiHead(
                    actor_head_cls,
                    actor_head_hidden_size,
                    action_shape,
                    layer_num=actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
            else:
                self.actor_head = actor_head_cls(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
        elif self.action_space == 'hybrid':  # HPPO
            # Hybrid action space: action_type (discrete) + action_args (continuous), e.g. an action such as
            # {'action_type': torch.LongTensor([0]), 'action_args': torch.FloatTensor([0.1, -0.27])}.
            action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
            action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
            actor_action_args = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape.action_args_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                fixed_sigma_value=fixed_sigma_value,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type,
            )
            actor_action_type = DiscreteHead(
                actor_head_hidden_size,
                action_shape.action_type_shape,
                actor_head_layer_num,
                activation=activation,
                norm_type=norm_type,
            )
            self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])

        # must use list, not nn.ModuleList
        if self.share_encoder:
            self.actor = [self.encoder, self.actor_head]
            self.critic = [self.encoder, self.critic_head]
        else:
            self.actor = [self.actor_encoder, self.actor_head]
            self.critic = [self.critic_encoder, self.critic_head]
        # Convenient for calling some APIs (e.g. self.critic.parameters()),
        # but may cause misunderstanding when `print(self)`.
        self.actor = nn.ModuleList(self.actor)
        self.critic = nn.ModuleList(self.critic)
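
    # Illustrative note (assumption, not from the original file): per the encoder dispatch in
    # ``__init__`` above, a scalar or 1D ``obs_shape`` (e.g. 8) selects ``FCEncoder``, while a
    # 3D image-like shape selects ``ConvEncoder``, e.g.:
    #   model = VAC(obs_shape=(4, 84, 84), action_shape=6)  # ConvEncoder is chosen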
    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        r"""
        Overview:
            Encode the observation and predict the output of the given ``mode``, dispatching to
            ``compute_actor``, ``compute_critic`` or ``compute_actor_critic``.
        Arguments:
            - inputs (:obj:`torch.Tensor`): The input observation tensor, i.e. ``(B, N=obs_shape)``.
            - mode (:obj:`str`): The forward mode, chosen from ``self.mode``.
        Returns:
            - outputs (:obj:`Dict`): The output dict of the corresponding encoder and head computation.

            Forward with ``'compute_actor'``, Necessary Keys:
                - logit (:obj:`torch.Tensor`): The action logit tensor.

            Forward with ``'compute_critic'``, Necessary Keys:
                - value (:obj:`torch.Tensor`): The predicted state value tensor.
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N corresponds to ``obs_shape``
            - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
            - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.

        Actor Examples:
            >>> model = VAC(64, 128)
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([4, 128])

        Critic Examples:
            >>> model = VAC(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> critic_outputs = model(inputs, 'compute_critic')
            >>> critic_outputs['value']
            tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=<SqueezeBackward1>)

        Actor-Critic Examples:
            >>> model = VAC(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> outputs = model(inputs, 'compute_actor_critic')
            >>> outputs['value']
            tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=<SqueezeBackward1>)
            >>> assert outputs['logit'].shape == torch.Size([4, 64])
        """
        assert mode in self.mode, "unsupported forward mode: {}, should be one of {}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, x: torch.Tensor) -> Dict:
        r"""
        Overview:
            Execute the ``'compute_actor'`` mode: encode the observation and predict the action logit.
        Arguments:
            - x (:obj:`torch.Tensor`): The input observation tensor, i.e. ``(B, N=obs_shape)``.
        Returns:
            - outputs (:obj:`Dict`): The output dict of the actor encoder and the actor head.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): The action logit tensor.
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
        Examples:
            >>> model = VAC(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['logit'].shape == torch.Size([4, 64])
        """
        if self.share_encoder:
            x = self.encoder(x)
        else:
            x = self.actor_encoder(x)
        if self.action_space == 'discrete':
            return self.actor_head(x)
        elif self.action_space == 'continuous':
            x = self.actor_head(x)  # {'mu': mu, 'sigma': sigma}
            return {'logit': x}
        elif self.action_space == 'hybrid':
            action_type = self.actor_head[0](x)
            action_args = self.actor_head[1](x)
            return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}}
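
    # Illustrative sketch (assumed output structure, not from the original file): in the
    # 'continuous' branch of ``compute_actor`` above, ``ReparameterizationHead`` yields Gaussian
    # parameters, so the wrapped result is ``{'logit': {'mu': ..., 'sigma': ...}}``, e.g.:
    #   model = VAC(64, 6, action_space='continuous')
    #   out = model.compute_actor(torch.randn(4, 64))
    #   assert set(out['logit'].keys()) == {'mu', 'sigma'}
    #   assert out['logit']['mu'].shape == torch.Size([4, 6])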
    def compute_critic(self, x: torch.Tensor) -> Dict:
        r"""
        Overview:
            Execute the ``'compute_critic'`` mode: encode the observation and predict the state value.
        Arguments:
            - x (:obj:`torch.Tensor`): The input observation tensor, i.e. ``(B, N=obs_shape)``.
        Returns:
            - outputs (:obj:`Dict`): The output dict of the critic encoder and the critic head.
        Necessary Keys:
            - value (:obj:`torch.Tensor`): The predicted state value tensor.
        Shapes:
            - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
        Examples:
            >>> model = VAC(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> critic_outputs = model(inputs, 'compute_critic')
            >>> critic_outputs['value']
            tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=<SqueezeBackward1>)
        """
        if self.share_encoder:
            x = self.encoder(x)
        else:
            x = self.critic_encoder(x)
        x = self.critic_head(x)
        return {'value': x['pred']}

    def compute_actor_critic(self, x: torch.Tensor) -> Dict:
        r"""
        Overview:
            Execute the ``'compute_actor_critic'`` mode: encode the observation once (when the encoder
            is shared) and predict both the action logit and the state value.
        Arguments:
            - x (:obj:`torch.Tensor`): The input observation tensor, i.e. ``(B, N=obs_shape)``.
        Returns:
            - outputs (:obj:`Dict`): The output dict of the encoder(s) and both heads.
        ReturnsKeys:
            - logit (:obj:`torch.Tensor`): The action logit tensor.
            - value (:obj:`torch.Tensor`): The predicted state value tensor.
        Shapes:
            - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
            - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
        Examples:
            >>> model = VAC(64, 64)
            >>> inputs = torch.randn(4, 64)
            >>> outputs = model(inputs, 'compute_actor_critic')
            >>> outputs['value']
            tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=<SqueezeBackward1>)
            >>> assert outputs['logit'].shape == torch.Size([4, 64])

        .. note::
            The ``compute_actor_critic`` interface aims to save computation when the encoder is shared,
            returning the combined dictionary.
        """
        if self.share_encoder:
            actor_embedding = critic_embedding = self.encoder(x)
        else:
            actor_embedding = self.actor_encoder(x)
            critic_embedding = self.critic_encoder(x)
        value = self.critic_head(critic_embedding)['pred']
        if self.action_space == 'discrete':
            logit = self.actor_head(actor_embedding)['logit']
            return {'logit': logit, 'value': value}
        elif self.action_space == 'continuous':
            x = self.actor_head(actor_embedding)
            return {'logit': x, 'value': value}
        elif self.action_space == 'hybrid':
            action_type = self.actor_head[0](actor_embedding)
            action_args = self.actor_head[1](actor_embedding)
            return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}, 'value': value}
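

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustrative, not part of the original file). Run it as a
    # module so the relative imports resolve, e.g.:
    #   python -m ding.model.template.vac
    # All concrete shapes below are arbitrary example values.

    # Discrete action space: 'logit' has one entry per action.
    model = VAC(obs_shape=8, action_shape=4, action_space='discrete')
    obs = torch.randn(2, 8)
    assert model(obs, 'compute_actor')['logit'].shape == (2, 4)
    assert model(obs, 'compute_critic')['value'].shape == (2, )

    # Continuous action space: 'logit' is a dict of Gaussian parameters.
    model = VAC(obs_shape=8, action_shape=3, action_space='continuous')
    outputs = model(obs, 'compute_actor_critic')
    assert outputs['logit']['mu'].shape == (2, 3)
    assert outputs['logit']['sigma'].shape == (2, 3)
    assert outputs['value'].shape == (2, )

    # Hybrid action space: discrete 'action_type' plus continuous 'action_args'.
    action_shape = EasyDict({'action_type_shape': 3, 'action_args_shape': 2})
    model = VAC(obs_shape=8, action_shape=action_shape, action_space='hybrid')
    outputs = model(obs, 'compute_actor')
    assert outputs['logit']['action_type'].shape == (2, 3)
    assert outputs['logit']['action_args']['mu'].shape == (2, 2)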