提交 118cc673 编写于 作者: N niuyazhe

polish(nyz): move actor_head_type to action_space field in qac and update readme new repo link

上级 a0435286
......@@ -54,11 +54,13 @@ Updated on 2021.12.03 DI-engine-v0.2.2 (beta)
- [DI-star](https://github.com/opendilab/DI-star): Decision AI in StarCraftII
- [DI-drive](https://github.com/opendilab/DI-drive): Auto-driving platform
- [GoBigger](https://github.com/opendilab/GoBigger): Multi-Agent Decision Intelligence Environment
- [DI-smartcross](https://github.com/opendilab/DI-smartcross): Decision AI in Traffic Light Control
- General nested data lib
- [treevalue](https://github.com/opendilab/treevalue): Tree-nested data structure
- [DI-treetensor](https://github.com/opendilab/DI-treetensor): Tree-nested PyTorch tensor Lib
- Docs and Tutorials
- [DI-engine-docs](https://github.com/opendilab/DI-engine-docs)
- [awesome-model-based-RL](https://github.com/opendilab/awesome-model-based-RL): A curated list of awesome Model-Based RL resources
**DI-engine** also has some **system optimization and design** for efficient and robust large-scale RL training:
......
......@@ -24,7 +24,6 @@ class MAQAC(nn.Module):
agent_obs_shape: Union[int, SequenceType],
global_obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
# actor_head_type: str,
twin_critic: bool = False,
actor_head_hidden_size: int = 64,
actor_head_layer_num: int = 1,
......@@ -39,7 +38,6 @@ class MAQAC(nn.Module):
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
- action_shape (:obj:`Union[int, SequenceType]`): Action's space.
- actor_head_type (:obj:`str`): Whether choose ``regression`` or ``reparameterization``.
- twin_critic (:obj:`bool`): Whether include twin critic.
- actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
- actor_head_layer_num (:obj:`int`):
......@@ -179,11 +177,6 @@ class MAQAC(nn.Module):
- obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
- action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
- q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
Examples:
>>> inputs = {'obs': torch.randn(4, N), 'action': torch.randn(4, 1)}
>>> model = QAC(obs_shape=(N, ),action_shape=1,actor_head_type='regression')
>>> model(inputs, mode='compute_critic')['q_value'] # q value
tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<SqueezeBackward1>)
"""
if self.twin_critic:
......@@ -208,7 +201,7 @@ class ContinuousMAQAC(nn.Module):
agent_obs_shape: Union[int, SequenceType],
global_obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType, EasyDict],
actor_head_type: str,
action_space: str,
twin_critic: bool = False,
actor_head_hidden_size: int = 64,
actor_head_layer_num: int = 1,
......@@ -222,9 +215,8 @@ class ContinuousMAQAC(nn.Module):
Init the QAC Model according to arguments.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
- action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's space, such as 4, (3, ),
EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
- actor_head_type (:obj:`str`): Whether choose ``regression`` or ``reparameterization`` or ``hybrid`` .
- action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's space, such as 4, (3, )
- action_space (:obj:`str`): Whether choose ``regression`` or ``reparameterization``.
- twin_critic (:obj:`bool`): Whether include twin critic.
- actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
- actor_head_layer_num (:obj:`int`):
......@@ -243,9 +235,9 @@ class ContinuousMAQAC(nn.Module):
global_obs_shape: int = squeeze(global_obs_shape)
action_shape = squeeze(action_shape)
self.action_shape = action_shape
self.actor_head_type = actor_head_type
assert self.actor_head_type in ['regression', 'reparameterization']
if self.actor_head_type == 'regression': # DDPG, TD3
self.action_space = action_space
assert self.action_space in ['regression', 'reparameterization']
if self.action_space == 'regression': # DDPG, TD3
self.actor = nn.Sequential(
nn.Linear(obs_shape, actor_head_hidden_size), activation,
RegressionHead(
......@@ -350,12 +342,6 @@ class ContinuousMAQAC(nn.Module):
>>> actor_outputs['logit'][1].shape # sigma
>>> torch.Size([4, 64])
Critic Examples:
>>> inputs = {'obs': torch.randn(4,N), 'action': torch.randn(4,1)}
>>> model = QAC(obs_shape=(N, ),action_shape=1,actor_head_type='regression')
>>> model(inputs, mode='compute_critic')['q_value'] # q value
tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<SqueezeBackward1>)
"""
assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
return getattr(self, mode)(inputs)
......@@ -404,7 +390,7 @@ class ContinuousMAQAC(nn.Module):
>>> torch.Size([4, 64])
"""
inputs = inputs['agent_state']
if self.actor_head_type == 'regression':
if self.action_space == 'regression':
x = self.actor(inputs)
return {'action': x['pred']}
else:
......@@ -434,12 +420,6 @@ class ContinuousMAQAC(nn.Module):
- obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
- action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
- q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
Examples:
>>> inputs = {'obs': torch.randn(4, N), 'action': torch.randn(4, 1)}
>>> model = QAC(obs_shape=(N, ),action_shape=1,actor_head_type='regression')
>>> model(inputs, mode='compute_critic')['q_value'] # q value
>>> tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<SqueezeBackward1>)
"""
obs, action = inputs['obs']['global_state'], inputs['action']
......
......@@ -325,7 +325,6 @@ class DiscreteQAC(nn.Module):
global_obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
encoder_hidden_size_list: SequenceType = [64],
#actor_head_type: str,
twin_critic: bool = False,
actor_head_hidden_size: int = 64,
actor_head_layer_num: int = 1,
......@@ -340,7 +339,6 @@ class DiscreteQAC(nn.Module):
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
- action_shape (:obj:`Union[int, SequenceType]`): Action's space.
- actor_head_type (:obj:`str`): Whether choose ``regression`` or ``reparameterization``.
- twin_critic (:obj:`bool`): Whether include twin critic.
- actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
- actor_head_layer_num (:obj:`int`):
......@@ -468,7 +466,7 @@ class DiscreteQAC(nn.Module):
Critic Examples:
>>> inputs = {'obs': torch.randn(4,N), 'action': torch.randn(4,1)}
>>> model = QAC(obs_shape=(N, ),action_shape=1,actor_head_type='regression')
>>> model = QAC(obs_shape=(N, ), action_shape=1, action_space='regression')
>>> model(inputs, mode='compute_critic')['q_value'] # q value
tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<SqueezeBackward1>)
......@@ -537,7 +535,7 @@ class DiscreteQAC(nn.Module):
Examples:
>>> inputs = {'obs': torch.randn(4, N), 'action': torch.randn(4, 1)}
>>> model = QAC(obs_shape=(N, ),action_shape=1,actor_head_type='regression')
>>> model = QAC(obs_shape=(N, ),action_shape=1, action_space='regression')
>>> model(inputs, mode='compute_critic')['q_value'] # q value
tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<SqueezeBackward1>)
......
......@@ -21,7 +21,7 @@ bipedalwalker_sac_config = dict(
obs_shape=24,
action_shape=4,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=128,
critic_head_hidden_size=128,
),
......
......@@ -20,7 +20,7 @@ bipedalwalker_td3_config = dict(
twin_critic=True,
actor_head_hidden_size=400,
critic_head_hidden_size=400,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=4,
......
......@@ -17,7 +17,7 @@ hopper_cql_default_config = dict(
obs_shape=11,
action_shape=3,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -17,7 +17,7 @@ hopper_expert_cql_default_config = dict(
obs_shape=11,
action_shape=3,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -17,7 +17,7 @@ hopper_medium_cql_default_config = dict(
obs_shape=11,
action_shape=3,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -20,7 +20,7 @@ ant_ddpg_default_config = dict(
twin_critic=False,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -19,7 +19,7 @@ ant_sac_default_config = dict(
obs_shape=111,
action_shape=8,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -20,7 +20,7 @@ ant_td3_default_config = dict(
twin_critic=True,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -36,7 +36,7 @@ ant_trex_sac_default_config = dict(
obs_shape=111,
action_shape=8,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -20,7 +20,7 @@ halfcheetah_ddpg_default_config = dict(
twin_critic=False,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -28,7 +28,7 @@ halfcheetah_gcl_default_config = dict(
obs_shape=17,
action_shape=6,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -19,7 +19,7 @@ halfcheetah_sac_default_config = dict(
obs_shape=17,
action_shape=6,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -20,7 +20,7 @@ halfcheetah_td3_default_config = dict(
twin_critic=True,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -36,7 +36,7 @@ halfcheetah_trex_sac_default_config = dict(
obs_shape=17,
action_shape=6,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -17,7 +17,7 @@ hopper_cql_default_config = dict(
obs_shape=11,
action_shape=3,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -21,7 +21,7 @@ hopper_d4pg_default_config = dict(
action_shape=3,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
critic_head_type='categorical',
v_min=-100,
v_max=100,
......
......@@ -20,7 +20,7 @@ hopper_ddpg_default_config = dict(
twin_critic=False,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -18,7 +18,7 @@ hopper_sac_data_genearation_default_config = dict(
obs_shape=11,
action_shape=3,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -19,7 +19,7 @@ hopper_sac_default_config = dict(
obs_shape=11,
action_shape=3,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -19,7 +19,7 @@ hopper_td3_bc_default_config = dict(
twin_critic=True,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
normalize_states=True,
......
......@@ -20,7 +20,7 @@ halfcheetah_td3_default_config = dict(
twin_critic=True,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -20,7 +20,7 @@ hopper_td3_default_config = dict(
twin_critic=True,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -36,7 +36,7 @@ hopper_trex_sac_default_config = dict(
obs_shape=11,
action_shape=3,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -44,7 +44,7 @@ main_config = dict(
obs_shape=obs_shape,
action_shape=action_shape,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -44,7 +44,7 @@ main_config = dict(
obs_shape=obs_shape,
action_shape=action_shape,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -21,7 +21,7 @@ walker2d_ddpg_default_config = dict(
twin_critic=False,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -34,7 +34,7 @@ walker2d_ddpg_gail_default_config = dict(
twin_critic=False,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -18,7 +18,7 @@ walker2d_sac_default_config = dict(
obs_shape=17,
action_shape=6,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -20,7 +20,7 @@ walker2d_td3_default_config = dict(
twin_critic=True,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -36,7 +36,7 @@ walker2d_trex_sac_default_config = dict(
obs_shape=17,
action_shape=6,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -22,7 +22,7 @@ ant_sac_default_config = dict(
global_obs_shape=111,
action_shape=4,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -20,7 +20,7 @@ ant_ddpg_default_config = dict(
twin_critic=False,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -18,7 +18,7 @@ ant_sac_default_config = dict(
obs_shape=111,
action_shape=8,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -20,7 +20,7 @@ ant_td3_default_config = dict(
twin_critic=True,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -20,7 +20,7 @@ halfcheetah_ddpg_default_config = dict(
twin_critic=False,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -18,7 +18,7 @@ halfcheetah_sac_default_config = dict(
obs_shape=17,
action_shape=6,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -20,7 +20,7 @@ halfcheetah_td3_default_config = dict(
twin_critic=True,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -20,7 +20,7 @@ hopper_ddpg_default_config = dict(
twin_critic=False,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -18,7 +18,7 @@ hopper_sac_default_config = dict(
obs_shape=11,
action_shape=3,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -20,7 +20,7 @@ hopper_td3_default_config = dict(
twin_critic=True,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -20,7 +20,7 @@ walker2d_ddpg_default_config = dict(
twin_critic=False,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
......@@ -18,7 +18,7 @@ walker2d_sac_default_config = dict(
obs_shape=17,
action_shape=6,
twin_critic=True,
actor_head_type='reparameterization',
action_space='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
......
......@@ -20,7 +20,7 @@ walker2d_td3_default_config = dict(
twin_critic=True,
actor_head_hidden_size=256,
critic_head_hidden_size=256,
actor_head_type='regression',
action_space='regression',
),
learn=dict(
update_per_collect=1,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册