diff --git a/README.md b/README.md
index 70badb986c48f3204592761a67e7e1d74302ff20..dcb402a1b4cc349f52de380665839a02308c1ce7 100644
--- a/README.md
+++ b/README.md
@@ -208,6 +208,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 21 | [gym_hybrid](https://github.com/thomashirtz/gym-hybrid) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | ![ori](dizoo/gym_hybrid/moving_v0.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_hybrid)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/env_tutorial/gym_hybrid_zh.html) |
| 22 | [GoBigger](https://github.com/opendilab/GoBigger) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen)![marl](https://img.shields.io/badge/-MARL-yellow)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![ori](./dizoo/gobigger_overview.gif) | [opendilab link](https://github.com/opendilab/GoBigger-Challenge-2021/tree/main/di_baseline)
[env tutorial](https://gobigger.readthedocs.io/en/latest/index.html)
[环境指南](https://gobigger.readthedocs.io/zh_CN/latest/) |
| 23 | [gym_soccer](https://github.com/openai/gym-soccer) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | ![ori](dizoo/gym_soccer/half_offensive.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_soccer)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/env_tutorial/gym_soccer_zh.html) |
+| 24 | [multiagent_mujoco](https://github.com/schroederdewitt/multiagent_mujoco) | ![continuous](https://img.shields.io/badge/-continuous-green) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![ori](./dizoo/mujoco/mujoco.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/multiagent_mujoco/envs)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/env_tutorial/mujoco_zh.html) |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space
diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py
index 19497793fe86b0f87234b2ce69831bc9fba1a960..0721c2739f52d61e5530754ae49127d82c01d71e 100644
--- a/ding/model/template/__init__.py
+++ b/ding/model/template/__init__.py
@@ -15,5 +15,5 @@ from .qtran import QTran
from .mavac import MAVAC
from .ngu import NGU
from .qac_dist import QACDIST
-from .maqac import MAQAC
+from .maqac import MAQAC, ContinuousMAQAC
from .model_based import EnsembleDynamicsModel
diff --git a/ding/model/template/maqac.py b/ding/model/template/maqac.py
index 1159cf310e69962731a57d5676fba4e80dd9703c..ca4273d826bc0b173d957ae32cebca61af1739ae 100644
--- a/ding/model/template/maqac.py
+++ b/ding/model/template/maqac.py
@@ -1,4 +1,6 @@
from typing import Union, Dict, Optional
+from easydict import EasyDict
+import numpy as np
import torch
import torch.nn as nn
@@ -94,6 +96,215 @@ class MAQAC(nn.Module):
Overview:
Use observation and action tensor to predict output.
Parameter updates follow QAC's MLP forward setup.
+ Arguments:
+ Forward with ``'compute_actor'``:
+ - inputs (:obj:`Dict`):
+ Input dict whose ``obs.agent_state`` tensor feeds the actor network and whose
+ ``obs.action_mask`` is forwarded to the output.
+ Forward with ``'compute_critic'``, inputs (:obj:`Dict`) Necessary Keys:
+ - ``obs.global_state`` tensor.
+ - mode (:obj:`str`): Name of the forward mode.
+ Returns:
+ - outputs (:obj:`Dict`): Outputs of network forward.
+ Forward with ``'compute_actor'``, Necessary Keys:
+ - logit (:obj:`torch.Tensor`): Discrete action logit.
+ - action_mask (:obj:`torch.Tensor`): Mask of currently available actions.
+ Forward with ``'compute_critic'``, Necessary Keys:
+ - q_value (:obj:`torch.Tensor`): Q values for every discrete action.
+ Actor Shapes:
+ - agent_state (:obj:`torch.Tensor`): :math:`(B, N0)`, where B is batch size and N0 is ``agent_obs_shape``
+ - logit (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where N2 is ``action_shape``
+ Critic Shapes:
+ - global_state (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``global_obs_shape``
+ - q_value (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where N2 is ``action_shape``
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_actor(self, inputs: torch.Tensor) -> Dict:
+ r"""
+ Overview:
+ Execute parameter updates with ``'compute_actor'`` mode;
+ use the encoded embedding tensor to predict the output.
+ Arguments:
+ - inputs (:obj:`Dict`):
+ Input dict whose ``obs.agent_state`` tensor feeds the actor network and whose
+ ``obs.action_mask`` is forwarded to the output unchanged.
+ Returns:
+ - outputs (:obj:`Dict`): Outputs of forward pass encoder and head.
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): Discrete action logit whose last dim equals ``action_shape``.
+ - action_mask (:obj:`torch.Tensor`): The mask of available actions, forwarded from the input.
+ Shapes:
+ - agent_state (:obj:`torch.Tensor`): :math:`(B, N0)`, where B is batch size and N0 is ``agent_obs_shape``
+ - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, where N2 is ``action_shape``
+ Examples:
+ >>> # illustrative shapes only; B is batch size, N0/N2 follow the model config
+ >>> inputs = {'obs': {'agent_state': torch.randn(B, N0),
+ >>> 'action_mask': torch.ones(B, N2)}}
+ >>> outputs = model(inputs, 'compute_actor')
+ >>> assert outputs['logit'].shape == torch.Size([B, N2])
+ """
+ action_mask = inputs['obs']['action_mask']
+ x = self.actor(inputs['obs']['agent_state'])
+ return {'logit': x['logit'], 'action_mask': action_mask}
+
+ def compute_critic(self, inputs: Dict) -> Dict:
+ r"""
+ Overview:
+ Execute parameter updates with ``'compute_critic'`` mode
+ Use encoded embedding tensor to predict output.
+ Arguments:
+ - inputs (:obj:`Dict`): Input dict whose ``obs.global_state`` tensor is fed to the critic.
+ Returns:
+ - outputs (:obj:`Dict`): Q-value output.
+ ReturnKeys:
+ - q_value (:obj:`torch.Tensor`): Q values for all discrete actions
+ (a list of two tensors when ``twin_critic`` is enabled).
+ Shapes:
+ - global_state (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``global_obs_shape``
+ - q_value (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where N2 is ``action_shape``
+ Examples:
+ >>> # illustrative shapes only; B is batch size, N1 follows the model config
+ >>> inputs = {'obs': {'global_state': torch.randn(B, N1)}}
+ >>> q_value = model(inputs, mode='compute_critic')['q_value']
+ """
+
+ if self.twin_critic:
+ x = [m(inputs['obs']['global_state'])['logit'] for m in self.critic]
+ else:
+ x = self.critic(inputs['obs']['global_state'])['logit']
+ return {'q_value': x}
+
+
+@MODEL_REGISTRY.register('maqac_continuous')
+class ContinuousMAQAC(nn.Module):
+ r"""
+ Overview:
+ The Continuous MAQAC model.
+ Interfaces:
+ ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
+ """
+ mode = ['compute_actor', 'compute_critic']
+
+ def __init__(
+ self,
+ agent_obs_shape: Union[int, SequenceType],
+ global_obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType, EasyDict],
+ actor_head_type: str,
+ twin_critic: bool = False,
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 1,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ ) -> None:
+ r"""
+ Overview:
+ Initialize the ContinuousMAQAC model according to the given arguments.
+ Arguments:
+ - agent_obs_shape (:obj:`Union[int, SequenceType]`): Each agent's local observation space.
+ - global_obs_shape (:obj:`Union[int, SequenceType]`): The global observation space.
+ - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's space, such as 4, (3, ), or
+ EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
+ - actor_head_type (:obj:`str`): Either ``regression`` (DDPG/TD3) or ``reparameterization`` (SAC).
+ - twin_critic (:obj:`bool`): Whether include twin critic.
+ - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
+ - actor_head_layer_num (:obj:`int`):
+ The num of layers used in the network to compute Q value output for actor's nn.
+ - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
+ - critic_head_layer_num (:obj:`int`):
+ The num of layers used in the network to compute Q value output for critic's nn.
+ - activation (:obj:`Optional[nn.Module]`):
+ The type of activation function to use in ``MLP`` after each ``layer_fn``,
+ if ``None`` then default set to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`):
+ The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
+ """
+ super(ContinuousMAQAC, self).__init__()
+ obs_shape: int = squeeze(agent_obs_shape)
+ global_obs_shape: int = squeeze(global_obs_shape)
+ action_shape = squeeze(action_shape)
+ self.action_shape = action_shape
+ self.actor_head_type = actor_head_type
+ assert self.actor_head_type in ['regression', 'reparameterization']
+ if self.actor_head_type == 'regression': # DDPG, TD3
+ self.actor = nn.Sequential(
+ nn.Linear(obs_shape, actor_head_hidden_size), activation,
+ RegressionHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ final_tanh=True,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ else: # SAC
+ self.actor = nn.Sequential(
+ nn.Linear(obs_shape, actor_head_hidden_size), activation,
+ ReparameterizationHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ sigma_type='conditioned',
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ self.twin_critic = twin_critic
+ critic_input_size = global_obs_shape + action_shape
+ if self.twin_critic:
+ self.critic = nn.ModuleList()
+ for _ in range(2):
+ self.critic.append(
+ nn.Sequential(
+ nn.Linear(critic_input_size, critic_head_hidden_size), activation,
+ RegressionHead(
+ critic_head_hidden_size,
+ 1,
+ critic_head_layer_num,
+ final_tanh=False,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ )
+ else:
+ self.critic = nn.Sequential(
+ nn.Linear(critic_input_size, critic_head_hidden_size), activation,
+ RegressionHead(
+ critic_head_hidden_size,
+ 1,
+ critic_head_layer_num,
+ final_tanh=False,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+
+ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
+ r"""
+ Overview:
+ Use observation and action tensors to predict output.
+ Parameter updates follow QAC's MLP forward setup.
Arguments:
Forward with ``'compute_actor'``:
- inputs (:obj:`torch.Tensor`):
@@ -167,11 +378,16 @@ class MAQAC(nn.Module):
- action (:obj:`torch.Tensor`): Continuous action tensor with same size as ``action_shape``.
- logit (:obj:`torch.Tensor`):
Logit tensor encoding ``mu`` and ``sigma``, both with same size as input ``x``.
+ - In the hybrid case, both ``logit`` (discrete part) and ``action_args`` (continuous part) are returned.
Shapes:
- inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``hidden_size``
- action (:obj:`torch.Tensor`): :math:`(B, N0)`
- - logit (:obj:`list`): 2 elements, mu and sigma, each is the shape of :math:`(B, N0)`.
+ - logit (:obj:`Union[list, torch.Tensor]`):
+ - case 1 (continuous action space, list): 2 elements, mu and sigma, each with shape :math:`(B, N0)`.
+ - case 2 (hybrid action space, torch.Tensor): :math:`(B, N1)`, where N1 is ``action_type_shape``
+ - q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, B is batch size.
+ - action_args (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where N2 is ``action_args_shape``
+ (action_args are continuous real values)
Examples:
>>> # Regression mode
>>> model = QAC(64, 64, 'regression')
@@ -187,9 +403,13 @@ class MAQAC(nn.Module):
>>> actor_outputs['logit'][1].shape # sigma
>>> torch.Size([4, 64])
"""
- action_mask = inputs['obs']['action_mask']
- x = self.actor(inputs['obs']['agent_state'])
- return {'logit': x['logit'], 'action_mask': action_mask}
+ inputs = inputs['agent_state']
+ if self.actor_head_type == 'regression':
+ x = self.actor(inputs)
+ return {'action': x['pred']}
+ else:
+ x = self.actor(inputs)
+ return {'logit': [x['mu'], x['sigma']]}
def compute_critic(self, inputs: Dict) -> Dict:
r"""
@@ -197,11 +417,17 @@ class MAQAC(nn.Module):
Execute parameter updates with ``'compute_critic'`` mode
Use encoded embedding tensor to predict output.
Arguments:
- - ``obs``, ``action`` encoded tensors.
+ - inputs (:obj:`Dict`): ``obs``, ``action`` and ``logit`` tensors.
- mode (:obj:`str`): Name of the forward mode.
Returns:
- outputs (:obj:`Dict`): Q-value output.
+ ArgumentsKeys:
+ - necessary:
+ - obs (:obj:`torch.Tensor`): 2-dim vector observation
+ - action (:obj:`Union[torch.Tensor, Dict]`): action from actor
+ - optional:
+ - logit (:obj:`torch.Tensor`): discrete action logit
ReturnKeys:
- q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
Shapes:
@@ -212,13 +438,16 @@ class MAQAC(nn.Module):
Examples:
>>> inputs = {'obs': torch.randn(4, N), 'action': torch.randn(4, 1)}
>>> model = QAC(obs_shape=(N, ),action_shape=1,actor_head_type='regression')
- >>> model(inputs, mode='compute_critic')['q_value'] # q value
- tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<...>)
-
+ >>> model(inputs, mode='compute_critic')['q_value'] # q value
+ tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=<...>)
"""
+ obs, action = inputs['obs']['global_state'], inputs['action']
+ if len(action.shape) == 1: # (B, ) -> (B, 1)
+ action = action.unsqueeze(1)
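+ # centralized critic: concatenate the global state with the (joint) action along the last dim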
+ x = torch.cat([obs, action], dim=-1)
if self.twin_critic:
- x = [m(inputs['obs']['global_state'])['logit'] for m in self.critic]
+ x = [m(x)['pred'] for m in self.critic]
else:
- x = self.critic(inputs['obs']['global_state'])['logit']
+ x = self.critic(x)['pred']
return {'q_value': x}
diff --git a/ding/policy/cql.py b/ding/policy/cql.py
index 1c8bdb1da47e3a7069b02323ee3bfec3d52668f8..a0a448b36473e3b59b2cf8fa39bbd06ce1f78351 100644
--- a/ding/policy/cql.py
+++ b/ding/policy/cql.py
@@ -78,6 +78,7 @@ class CQLPolicy(SACPolicy):
# on-policy setting influences the behaviour of buffer.
# Default False in SAC.
on_policy=False,
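+ # (bool type) multi_agent: Whether to use the multi-agent model variant. Default False.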
+ multi_agent=False,
# (bool type) priority: Determine whether to use priority in buffer sample.
# Default False in SAC.
priority=False,
diff --git a/ding/policy/sac.py b/ding/policy/sac.py
index 06fddabe21005180de4a01bcccde2db04ea4b730..ae371dda48c9da6b5ad6f7159787556234f53ba2 100644
--- a/ding/policy/sac.py
+++ b/ding/policy/sac.py
@@ -583,6 +583,7 @@ class SACPolicy(Policy):
# (int) Number of training samples(randomly collected) in replay buffer when training starts.
# Default 10000 in SAC.
random_collect_size=10000,
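+ # (bool type) multi_agent: Whether to use the multi-agent model variant (selects MAQAC in ``default_model``). Default False.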
+ multi_agent=False,
model=dict(
# (bool type) twin_critic: Determine whether to use double-soft-q-net for target q computation.
# Please refer to TD3 about Clipped Double-Q Learning trick, which learns two Q-functions instead of one .
@@ -1042,7 +1043,10 @@ class SACPolicy(Policy):
return {i: d for i, d in zip(data_id, output)}
def default_model(self) -> Tuple[str, List[str]]:
- return 'qac', ['ding.model.template.qac']
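+ # when the config enables multi_agent, switch to the continuous multi-agent QAC model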
+ if self._cfg.multi_agent:
+ return 'maqac_continuous', ['ding.model.template.maqac']
+ else:
+ return 'qac', ['ding.model.template.qac']
def _monitor_vars_learn(self) -> List[str]:
r"""
diff --git a/ding/rl_utils/td.py b/ding/rl_utils/td.py
index de20833616b906e0b7b64bddeceb1ccddc9b58e5..5a48dd4fa336758aaa4af0244f7c6632ea34e1bf 100644
--- a/ding/rl_utils/td.py
+++ b/ding/rl_utils/td.py
@@ -267,11 +267,17 @@ def v_1step_td_error(
) -> torch.Tensor:
v, next_v, reward, done, weight = data
if weight is None:
- weight = torch.ones_like(reward)
- if done is not None:
- target_v = gamma * (1 - done) * next_v + reward
+ weight = torch.ones_like(v)
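+ # v/next_v may carry an extra agent dim (B, A) while reward/done are (B, );
+ # in that case, unsqueeze reward/done so they broadcast over the agent dimension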
+ if len(v.shape) == len(reward.shape):
+ if done is not None:
+ target_v = gamma * (1 - done) * next_v + reward
+ else:
+ target_v = gamma * next_v + reward
else:
- target_v = gamma * next_v + reward
+ if done is not None:
+ target_v = gamma * (1 - done).unsqueeze(1) * next_v + reward.unsqueeze(1)
+ else:
+ target_v = gamma * next_v + reward.unsqueeze(1)
td_error_per_sample = criterion(v, target_v.detach())
return (td_error_per_sample * weight).mean(), td_error_per_sample
diff --git a/ding/rl_utils/tests/test_td.py b/ding/rl_utils/tests/test_td.py
index b03d40597b3590bbc0464984a9b11afefe32407c..06fbe60d2d8e974dc6a7dc4cb6e8a6f2ce8f6392 100644
--- a/ding/rl_utils/tests/test_td.py
+++ b/ding/rl_utils/tests/test_td.py
@@ -196,6 +196,26 @@ def test_v_1step_td():
assert isinstance(v.grad, torch.Tensor)
+@pytest.mark.unittest
+def test_v_1step_multi_agent_td():
+ batch_size = 5
+ agent_num = 2
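+ # v/next_v carry an extra agent dim (B, A); reward/done stay (B, ) and broadcast inside v_1step_td_error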
+ v = torch.randn(batch_size, agent_num).requires_grad_(True)
+ next_v = torch.randn(batch_size, agent_num)
+ reward = torch.rand(batch_size)
+ done = torch.zeros(batch_size)
+ data = v_1step_td_data(v, next_v, reward, done, None)
+ loss, td_error_per_sample = v_1step_td_error(data, 0.99)
+ assert loss.shape == ()
+ assert v.grad is None
+ loss.backward()
+ assert isinstance(v.grad, torch.Tensor)
+ data = v_1step_td_data(v, next_v, reward, None, None)
+ loss, td_error_per_sample = v_1step_td_error(data, 0.99)
+ loss.backward()
+ assert isinstance(v.grad, torch.Tensor)
+
+
@pytest.mark.unittest
def test_v_nstep_td():
batch_size = 5
diff --git a/dizoo/multiagent_mujoco/README.md b/dizoo/multiagent_mujoco/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6e82c6ecb1f79706c785a8270a304448df8ab5c2
--- /dev/null
+++ b/dizoo/multiagent_mujoco/README.md
@@ -0,0 +1,7 @@
+## Multi Agent Mujoco Env
+
+Multi-Agent Mujoco is an environment for continuous multi-agent robotic control, based on OpenAI's Mujoco Gym environments.
+
+The environment is described in the paper [Deep Multi-Agent Reinforcement Learning for Decentralized Continuous Cooperative Control](https://arxiv.org/abs/2003.06709) by Christian Schroeder de Witt, Bei Peng, Pierre-Alexandre Kamienny, Philip Torr, Wendelin Böhmer and Shimon Whiteson (Torr Vision Group and Whiteson Research Lab, University of Oxford, 2020).
+
+You can find more details in the [Multi-Agent Mujoco repository](https://github.com/schroederdewitt/multiagent_mujoco).
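+
+A minimal usage sketch, assuming MuJoCo and `mujoco-py` are installed; the config values below are illustrative and mirror `dizoo/multiagent_mujoco/config/ant_masac_default_config.py`:
+
+```python
+from easydict import EasyDict
+from dizoo.multiagent_mujoco.envs.multi_mujoco_env import MujocoEnv
+
+# illustrative values: 2 agents, each controlling 4 diagonally-paired Ant joints
+cfg = EasyDict(dict(
+    scenario='Ant-v2',
+    agent_conf='2x4d',
+    agent_obsk=2,
+    add_agent_id=False,
+    episode_limit=1000,
+))
+env = MujocoEnv(cfg)
+obs = env.reset()  # dict with per-agent 'agent_state' and shared 'global_state'
+```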
diff --git a/dizoo/multiagent_mujoco/__init__.py b/dizoo/multiagent_mujoco/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dizoo/multiagent_mujoco/config/ant_masac_default_config.py b/dizoo/multiagent_mujoco/config/ant_masac_default_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..885c7d870a48e9e9a7fd3c5af0933f337731e78a
--- /dev/null
+++ b/dizoo/multiagent_mujoco/config/ant_masac_default_config.py
@@ -0,0 +1,72 @@
+from easydict import EasyDict
+from ding.entry.serial_entry import serial_pipeline
+ant_sac_default_config = dict(
+ exp_name='multi_mujoco_ant_2x4',
+ env=dict(
+ scenario='Ant-v2',
+ agent_conf="2x4d",
+ agent_obsk=2,
+ add_agent_id=False,
+ episode_limit=1000,
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=0,
+ multi_agent=True,
+ model=dict(
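+ # shape values below follow Ant-v2 with agent_conf='2x4d' and agent_obsk=2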
+ agent_obs_shape=54,
+ global_obs_shape=111,
+ action_shape=4,
+ twin_critic=True,
+ actor_head_type='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=10,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=True,
+ log_space=True,
+ ),
+ collect=dict(
+ n_sample=400,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
+ ),
+)
+
+ant_sac_default_config = EasyDict(ant_sac_default_config)
+main_config = ant_sac_default_config
+
+ant_sac_default_create_config = dict(
+ env=dict(
+ type='mujoco_multi',
+ import_names=['dizoo.multiagent_mujoco.envs.multi_mujoco_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+ant_sac_default_create_config = EasyDict(ant_sac_default_create_config)
+create_config = ant_sac_default_create_config
+
+if __name__ == '__main__':
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/dizoo/multiagent_mujoco/envs/__init__.py b/dizoo/multiagent_mujoco/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a828ba4e982da51f5fc218a79e597e131a57567
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/__init__.py
@@ -0,0 +1,4 @@
+from .mujoco_multi import MujocoMulti
+from .coupled_half_cheetah import CoupledHalfCheetah
+from .manyagent_swimmer import ManyAgentSwimmerEnv
+from .manyagent_ant import ManyAgentAntEnv
diff --git a/dizoo/multiagent_mujoco/envs/assets/.gitignore b/dizoo/multiagent_mujoco/envs/assets/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..eb0d0a0f1a89ef2ca8e1433ffbe77cb361e0cf11
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/assets/.gitignore
@@ -0,0 +1 @@
+*.auto.xml
diff --git a/dizoo/multiagent_mujoco/envs/assets/__init__.py b/dizoo/multiagent_mujoco/envs/assets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dizoo/multiagent_mujoco/envs/assets/coupled_half_cheetah.xml b/dizoo/multiagent_mujoco/envs/assets/coupled_half_cheetah.xml
new file mode 100644
index 0000000000000000000000000000000000000000..b8c2f9f626b5969edc98f5984e13ca5a3bab36f7
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/assets/coupled_half_cheetah.xml
@@ -0,0 +1,140 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/dizoo/multiagent_mujoco/envs/assets/manyagent_ant.xml b/dizoo/multiagent_mujoco/envs/assets/manyagent_ant.xml
new file mode 100644
index 0000000000000000000000000000000000000000..103c74452687b247a06e7c5bd43d7d0582dc23d3
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/assets/manyagent_ant.xml
@@ -0,0 +1,134 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/dizoo/multiagent_mujoco/envs/assets/manyagent_ant.xml.template b/dizoo/multiagent_mujoco/envs/assets/manyagent_ant.xml.template
new file mode 100644
index 0000000000000000000000000000000000000000..3b6b4eb85a14d9416c398a01fd4ab4bc6d397575
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/assets/manyagent_ant.xml.template
@@ -0,0 +1,54 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ body }}
+
+
+
+ {{ actuators }}
+
+
\ No newline at end of file
diff --git a/dizoo/multiagent_mujoco/envs/assets/manyagent_ant__stage1.xml b/dizoo/multiagent_mujoco/envs/assets/manyagent_ant__stage1.xml
new file mode 100644
index 0000000000000000000000000000000000000000..c6ef416f3c33575eb088742242d339613a651e23
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/assets/manyagent_ant__stage1.xml
@@ -0,0 +1,85 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer.xml.template b/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer.xml.template
new file mode 100644
index 0000000000000000000000000000000000000000..9fb49a95230e5dc8983ef5c81788a5463ef9d99e
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer.xml.template
@@ -0,0 +1,34 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ body }}
+
+
+
+
+{{ actuators }}
+
+
\ No newline at end of file
diff --git a/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer__bckp2.xml b/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer__bckp2.xml
new file mode 100644
index 0000000000000000000000000000000000000000..bce5149599c5eec4cae496030c0523a58ba33b53
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer__bckp2.xml
@@ -0,0 +1,48 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer_bckp.xml b/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer_bckp.xml
new file mode 100644
index 0000000000000000000000000000000000000000..3477813790a32e81d4db1bc7b9a997d90f70c58b
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer_bckp.xml
@@ -0,0 +1,43 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/dizoo/multiagent_mujoco/envs/coupled_half_cheetah.py b/dizoo/multiagent_mujoco/envs/coupled_half_cheetah.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fe0a68507fd5272ff1c3d6bc7ea827e9fbac7eb
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/coupled_half_cheetah.py
@@ -0,0 +1,48 @@
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+import os
+
+
+class CoupledHalfCheetah(mujoco_env.MujocoEnv, utils.EzPickle):
+
+ def __init__(self, **kwargs):
+ mujoco_env.MujocoEnv.__init__(
+ self, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'coupled_half_cheetah.xml'), 5
+ )
+ utils.EzPickle.__init__(self)
+
+ def step(self, action):
+ xposbefore1 = self.sim.data.qpos[0]
+ xposbefore2 = self.sim.data.qpos[len(self.sim.data.qpos) // 2]
+ self.do_simulation(action, self.frame_skip)
+ xposafter1 = self.sim.data.qpos[0]
+ xposafter2 = self.sim.data.qpos[len(self.sim.data.qpos) // 2]
+ ob = self._get_obs()
+ reward_ctrl1 = -0.1 * np.square(action[0:len(action) // 2]).sum()
+ reward_ctrl2 = -0.1 * np.square(action[len(action) // 2:]).sum()
+ reward_run1 = (xposafter1 - xposbefore1) / self.dt
+ reward_run2 = (xposafter2 - xposbefore2) / self.dt
+ reward = (reward_ctrl1 + reward_ctrl2) / 2.0 + (reward_run1 + reward_run2) / 2.0
+ done = False
+ return ob, reward, done, dict(
+ reward_run1=reward_run1, reward_ctrl1=reward_ctrl1, reward_run2=reward_run2, reward_ctrl2=reward_ctrl2
+ )
+
+ def _get_obs(self):
+ return np.concatenate([
+ self.sim.data.qpos.flat[1:],
+ self.sim.data.qvel.flat,
+ ])
+
+ def reset_model(self):
+ qpos = self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq)
+ qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
+ self.set_state(qpos, qvel)
+ return self._get_obs()
+
+ def viewer_setup(self):
+ self.viewer.cam.distance = self.model.stat.extent * 0.5
+
+ def get_env_info(self):
+ return {"episode_limit": self.episode_limit}
diff --git a/dizoo/multiagent_mujoco/envs/manyagent_ant.py b/dizoo/multiagent_mujoco/envs/manyagent_ant.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bfb793780fa3c8ab53131038f594dfee730aab5
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/manyagent_ant.py
@@ -0,0 +1,120 @@
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+from jinja2 import Template
+import os
+
+
+class ManyAgentAntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+
+ def __init__(self, **kwargs):
+ agent_conf = kwargs.get("agent_conf")
+ n_agents = int(agent_conf.split("x")[0])
+ n_segs_per_agents = int(agent_conf.split("x")[1])
+ n_segs = n_agents * n_segs_per_agents
+
+ # Check whether asset file exists already, otherwise create it
+ asset_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), 'assets',
+ 'manyagent_ant_{}_agents_each_{}_segments.auto.xml'.format(n_agents, n_segs_per_agents)
+ )
+ # the asset is always (re)generated, even if one already exists at asset_path
+ print("Auto-Generating Manyagent Ant asset with {} segments at {}.".format(n_segs, asset_path))
+ self._generate_asset(n_segs=n_segs, asset_path=asset_path)
+
+
+ mujoco_env.MujocoEnv.__init__(self, asset_path, 4)
+ utils.EzPickle.__init__(self)
+
+ def _generate_asset(self, n_segs, asset_path):
+ template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'manyagent_ant.xml.template')
+ with open(template_path, "r") as f:
+ t = Template(f.read())
+ body_str_template = """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """
+
+ body_close_str_template = "\n"
+ actuator_str_template = """\t
+
+
+ \n"""
+
+ body_str = ""
+ for i in range(1, n_segs):
+ body_str += body_str_template.format(*([i] * 16))
+ body_str += body_close_str_template * (n_segs - 1)
+
+ actuator_str = ""
+ for i in range(n_segs):
+ actuator_str += actuator_str_template.format(*([i] * 8))
+
+ rt = t.render(body=body_str, actuators=actuator_str)
+ with open(asset_path, "w") as f:
+ f.write(rt)
+ pass
+
+ def step(self, a):
+ xposbefore = self.get_body_com("torso_0")[0]
+ self.do_simulation(a, self.frame_skip)
+ xposafter = self.get_body_com("torso_0")[0]
+ forward_reward = (xposafter - xposbefore) / self.dt
+ ctrl_cost = .5 * np.square(a).sum()
+ contact_cost = 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
+ survive_reward = 1.0
+ reward = forward_reward - ctrl_cost - contact_cost + survive_reward
+ state = self.state_vector()
+ notdone = np.isfinite(state).all() \
+ and state[2] >= 0.2 and state[2] <= 1.0
+ done = not notdone
+ ob = self._get_obs()
+ return ob, reward, done, dict(
+ reward_forward=forward_reward,
+ reward_ctrl=-ctrl_cost,
+ reward_contact=-contact_cost,
+ reward_survive=survive_reward
+ )
+
+ def _get_obs(self):
+ return np.concatenate(
+ [
+ self.sim.data.qpos.flat[2:],
+ self.sim.data.qvel.flat,
+ np.clip(self.sim.data.cfrc_ext, -1, 1).flat,
+ ]
+ )
+
+ def reset_model(self):
+ qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)
+ qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
+ self.set_state(qpos, qvel)
+ return self._get_obs()
+
+ def viewer_setup(self):
+ self.viewer.cam.distance = self.model.stat.extent * 0.5
diff --git a/dizoo/multiagent_mujoco/envs/manyagent_swimmer.py b/dizoo/multiagent_mujoco/envs/manyagent_swimmer.py
new file mode 100644
index 0000000000000000000000000000000000000000..70e8677a01347b0522f5644de50be0e2ca071757
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/manyagent_swimmer.py
@@ -0,0 +1,89 @@
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+import os
+from jinja2 import Template
+
+
+class ManyAgentSwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+
+ def __init__(self, **kwargs):
+ agent_conf = kwargs.get("agent_conf")
+ n_agents = int(agent_conf.split("x")[0])
+ n_segs_per_agents = int(agent_conf.split("x")[1])
+ n_segs = n_agents * n_segs_per_agents
+
+ # Check whether asset file exists already, otherwise create it
+ asset_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), 'assets',
+ 'manyagent_swimmer_{}_agents_each_{}_segments.auto.xml'.format(n_agents, n_segs_per_agents)
+ )
+ # the asset is always (re)generated, even if one already exists at asset_path
+ print("Auto-Generating Manyagent Swimmer asset with {} segments at {}.".format(n_segs, asset_path))
+ self._generate_asset(n_segs=n_segs, asset_path=asset_path)
+
+
+ mujoco_env.MujocoEnv.__init__(self, asset_path, 4)
+ utils.EzPickle.__init__(self)
+
+ def _generate_asset(self, n_segs, asset_path):
+ template_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), 'assets', 'manyagent_swimmer.xml.template'
+ )
+ with open(template_path, "r") as f:
+ t = Template(f.read())
+ body_str_template = """
+
+
+
+ """
+
+ body_end_str_template = """
+
+
+
+
+ """
+
+ body_close_str_template = "\n"
+ actuator_str_template = """\t \n"""
+
+ body_str = ""
+ for i in range(1, n_segs - 1):
+ body_str += body_str_template.format(i, (-1) ** (i + 1), i)
+ body_str += body_end_str_template.format(n_segs - 1)
+ body_str += body_close_str_template * (n_segs - 2)
+
+ actuator_str = ""
+ for i in range(n_segs):
+ actuator_str += actuator_str_template.format(i)
+
+ rt = t.render(body=body_str, actuators=actuator_str)
+ with open(asset_path, "w") as f:
+ f.write(rt)
+ pass
+
+ def step(self, a):
+ ctrl_cost_coeff = 0.0001
+ xposbefore = self.sim.data.qpos[0]
+ self.do_simulation(a, self.frame_skip)
+ xposafter = self.sim.data.qpos[0]
+ reward_fwd = (xposafter - xposbefore) / self.dt
+ reward_ctrl = -ctrl_cost_coeff * np.square(a).sum()
+ reward = reward_fwd + reward_ctrl
+ ob = self._get_obs()
+ return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)
+
+ def _get_obs(self):
+ qpos = self.sim.data.qpos
+ qvel = self.sim.data.qvel
+ return np.concatenate([qpos.flat[2:], qvel.flat])
+
+ def reset_model(self):
+ self.set_state(
+ self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq),
+ self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv)
+ )
+ return self._get_obs()
diff --git a/dizoo/multiagent_mujoco/envs/mujoco_multi.py b/dizoo/multiagent_mujoco/envs/mujoco_multi.py
new file mode 100755
index 0000000000000000000000000000000000000000..8b5b6838cb7071775babcfc16403a60d31db35d5
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/mujoco_multi.py
@@ -0,0 +1,247 @@
+from functools import partial
+import gym
+from gym.spaces import Box
+from gym.wrappers import TimeLimit
+import numpy as np
+
+from .multiagentenv import MultiAgentEnv
+from .obsk import get_joints_at_kdist, get_parts_and_edges, build_obs
+
+
+# using code from https://github.com/ikostrikov/pytorch-ddpg-naf
+class NormalizedActions(gym.ActionWrapper):
+
+ def _action(self, action):
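+ # map an action in [-1, 1] to the wrapped env's [low, high] range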
+ action = (action + 1) / 2
+ action *= (self.action_space.high - self.action_space.low)
+ action += self.action_space.low
+ return action
+
+ def action(self, action_):
+ return self._action(action_)
+
+ def _reverse_action(self, action):
+ action -= self.action_space.low
+ action /= (self.action_space.high - self.action_space.low)
+ action = action * 2 - 1
+ return action
+
+
+class MujocoMulti(MultiAgentEnv):
+
+ def __init__(self, batch_size=None, **kwargs):
+ super().__init__(batch_size, **kwargs)
+ self.add_agent_id = kwargs["env_args"]["add_agent_id"]
+ self.scenario = kwargs["env_args"]["scenario"] # e.g. Ant-v2
+ self.agent_conf = kwargs["env_args"]["agent_conf"] # e.g. '2x3'
+
+ self.agent_partitions, self.mujoco_edges, self.mujoco_globals = get_parts_and_edges(
+ self.scenario, self.agent_conf
+ )
+
+ self.n_agents = len(self.agent_partitions)
+ self.n_actions = max([len(l) for l in self.agent_partitions])
+ self.obs_add_global_pos = kwargs["env_args"].get("obs_add_global_pos", False)
+
+ self.agent_obsk = kwargs["env_args"].get(
+ "agent_obsk", None
+ ) # if None, fully observable else k>=0 implies observe nearest k agents or joints
+ self.agent_obsk_agents = kwargs["env_args"].get(
+ "agent_obsk_agents", False
+ ) # observe full k nearest agents (True) or just single joints (False)
+
+ if self.agent_obsk is not None:
+ self.k_categories_label = kwargs["env_args"].get("k_categories")
+ if self.k_categories_label is None:
+ if self.scenario in ["Ant-v2", "manyagent_ant"]:
+ self.k_categories_label = "qpos,qvel,cfrc_ext|qpos"
+ elif self.scenario in ["Humanoid-v2", "HumanoidStandup-v2"]:
+ self.k_categories_label = "qpos,qvel,cfrc_ext,cvel,cinert,qfrc_actuator|qpos"
+ elif self.scenario in ["Reacher-v2"]:
+ self.k_categories_label = "qpos,qvel,fingertip_dist|qpos"
+ elif self.scenario in ["coupled_half_cheetah"]:
+ self.k_categories_label = "qpos,qvel,ten_J,ten_length,ten_velocity|"
+ else:
+ self.k_categories_label = "qpos,qvel|qpos"
+
+ k_split = self.k_categories_label.split("|")
+ self.k_categories = [k_split[k if k < len(k_split) else -1].split(",") for k in range(self.agent_obsk + 1)]
+
+ self.global_categories_label = kwargs["env_args"].get("global_categories")
+ self.global_categories = self.global_categories_label.split(
+ ","
+ ) if self.global_categories_label is not None else []
+
+ if self.agent_obsk is not None:
+ self.k_dicts = [
+ get_joints_at_kdist(
+ agent_id,
+ self.agent_partitions,
+ self.mujoco_edges,
+ k=self.agent_obsk,
+ kagents=False,
+ ) for agent_id in range(self.n_agents)
+ ]
+
+ # the episode limit is taken from the env args
+ self.episode_limit = self.args.episode_limit
+
+ self.env_version = kwargs["env_args"].get("env_version", 2)
+ if self.env_version == 2:
+ try:
+ self.wrapped_env = NormalizedActions(gym.make(self.scenario))
+ except gym.error.Error: # env not in gym
+ if self.scenario in ["manyagent_ant"]:
+ from .manyagent_ant import ManyAgentAntEnv as this_env
+ elif self.scenario in ["manyagent_swimmer"]:
+ from .manyagent_swimmer import ManyAgentSwimmerEnv as this_env
+ elif self.scenario in ["coupled_half_cheetah"]:
+ from .coupled_half_cheetah import CoupledHalfCheetah as this_env
+ else:
+ raise NotImplementedError('Custom env not implemented!')
+ self.wrapped_env = NormalizedActions(
+ TimeLimit(this_env(**kwargs["env_args"]), max_episode_steps=self.episode_limit)
+ )
+ else:
+ assert False, "not implemented!"
+ self.timelimit_env = self.wrapped_env.env
+ self.timelimit_env._max_episode_steps = self.episode_limit
+ self.env = self.timelimit_env.env
+ self.timelimit_env.reset()
+ self.obs_size = self.get_obs_size()
+
+ # COMPATIBILITY
+ self.n = self.n_agents
+ self.observation_space = [
+ Box(low=np.array([-10] * self.n_agents), high=np.array([10] * self.n_agents)) for _ in range(self.n_agents)
+ ]
+
+ acdims = [len(ap) for ap in self.agent_partitions]
+ self.action_space = tuple(
+ [
+ Box(
+ self.env.action_space.low[sum(acdims[:a]):sum(acdims[:a + 1])],
+ self.env.action_space.high[sum(acdims[:a]):sum(acdims[:a + 1])]
+ ) for a in range(self.n_agents)
+ ]
+ )
+ pass
+
+ def step(self, actions):
+
+ # need to remove dummy actions that arise due to unequal action vector sizes across agents
+ flat_actions = np.concatenate([actions[i][:self.action_space[i].low.shape[0]] for i in range(self.n_agents)])
+ obs_n, reward_n, done_n, info_n = self.wrapped_env.step(flat_actions)
+ self.steps += 1
+
+ info = {}
+ info.update(info_n)
+
+ if done_n:
+ if self.steps < self.episode_limit:
+ info["episode_limit"] = False # the next state will be masked out
+ else:
+ info["episode_limit"] = True # the next state will not be masked out
+
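+ # pack per-agent observations and the shared global state in the dict layout DI-engine's MA policies expect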
+ obs = {'agent_state': self.get_obs(), 'global_state': self.get_state()}
+
+ return obs, reward_n, done_n, info
+
+ def get_obs(self):
+ """ Returns all agent observat3ions in a list """
+ obs_n = []
+ for a in range(self.n_agents):
+ obs_n.append(self.get_obs_agent(a))
+ return np.array(obs_n).astype(np.float32)
+
+ def get_obs_agent(self, agent_id):
+ if self.agent_obsk is None:
+ return self.env._get_obs()
+ else:
+ return build_obs(
+ self.env,
+ self.k_dicts[agent_id],
+ self.k_categories,
+ self.mujoco_globals,
+ self.global_categories,
+ vec_len=getattr(self, "obs_size", None)
+ )
+
+ def get_obs_size(self):
+ """ Returns the shape of the observation """
+ if self.agent_obsk is None:
+ return self.get_obs_agent(0).size
+ else:
+ return max([len(self.get_obs_agent(agent_id)) for agent_id in range(self.n_agents)])
+
+ def get_state(self, team=None):
+ # TODO: May want global states for different teams (so cannot see what the other team is communicating e.g.)
+ state_n = []
+ if self.add_agent_id:
+ state = self.env._get_obs()
+ for a in range(self.n_agents):
+ agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)
+ agent_id_feats[a] = 1.0
+ state_i = np.concatenate([state, agent_id_feats])
+ state_n.append(state_i)
+ else:
+ for a in range(self.n_agents):
+ state_n.append(self.env._get_obs())
+ return np.array(state_n).astype(np.float32)
+
+ def get_state_size(self):
+ """ Returns the shape of the state"""
+ return len(self.get_state())
+
+ def get_avail_actions(self): # all actions are always available
+ return np.ones(shape=(
+ self.n_agents,
+ self.n_actions,
+ ))
+
+ def get_avail_agent_actions(self, agent_id):
+ """ Returns the available actions for agent_id """
+ return np.ones(shape=(self.n_actions, ))
+
+ def get_total_actions(self):
+ """ Returns the total number of actions an agent could ever take """
+ return self.n_actions  # CAREFUL! for continuous actions this is the per-agent action dim, not a discrete action count
+ # return self.env.action_space.shape[0]
+
+ def get_stats(self):
+ return {}
+
+ # TODO: Temp hack
+ def get_agg_stats(self, stats):
+ return {}
+
+ def reset(self, **kwargs):
+ """ Returns initial observations and states"""
+ self.steps = 0
+ self.timelimit_env.reset()
+ obs = {'agent_state': self.get_obs(), 'global_state': self.get_state()}
+ return obs
+
+ def render(self, **kwargs):
+ self.env.render(**kwargs)
+
+ def close(self):
+ pass
+ #raise NotImplementedError
+
+ def seed(self, args):
+ pass
+
+ def get_env_info(self):
+
+ env_info = {
+ "state_shape": self.get_state_size(),
+ "obs_shape": self.get_obs_size(),
+ "n_actions": self.get_total_actions(),
+ "n_agents": self.n_agents,
+ "episode_limit": self.episode_limit,
+ "action_spaces": self.action_space,
+ "actions_dtype": np.float32,
+ "normalise_actions": False
+ }
+ return env_info
diff --git a/dizoo/multiagent_mujoco/envs/multi_mujoco_env.py b/dizoo/multiagent_mujoco/envs/multi_mujoco_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a2edb22e618aa214f002eb0f3c32e03ec846946
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/multi_mujoco_env.py
@@ -0,0 +1,91 @@
+from typing import Any, Union, List
+import copy
+import numpy as np
+
+from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
+from ding.envs.common.env_element import EnvElement, EnvElementInfo
+from ding.envs.common.common_function import affine_transform
+from ding.torch_utils import to_ndarray, to_list
+from .mujoco_multi import MujocoMulti
+from ding.utils import ENV_REGISTRY
+
+
+@ENV_REGISTRY.register('mujoco_multi')
+class MujocoEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+
+ def reset(self) -> np.ndarray:
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._cfg.seed = self._seed + np_seed
+ elif hasattr(self, '_seed'):
+ self._cfg.seed = self._seed
+ if not self._init_flag:
+ self._env = MujocoMulti(env_args=self._cfg)
+ self._init_flag = True
+ obs = self._env.reset()
+ self._final_eval_reward = 0.
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
+ action = to_ndarray(action)
+ obs, rew, done, info = self._env.step(action)
+ self._final_eval_reward += rew
+ rew = to_ndarray([rew])  # wrap the scalar reward into an array of shape (1, )
+ if done:
+ info['final_eval_reward'] = self._final_eval_reward
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def info(self) -> BaseEnvInfo:
+ env_info = self._env.get_env_info()
+ info = BaseEnvInfo(
+ agent_num=env_info['n_agents'],
+ obs_space=EnvElementInfo(
+ shape={
+ 'agent_state': env_info['obs_shape'],
+ 'global_state': env_info['state_shape'],
+ },
+ value={
+ 'min': np.float32("-inf"),
+ 'max': np.float32("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ act_space=EnvElementInfo(
+ shape=env_info['action_spaces'],
+ value={
+ 'min': np.float32("-inf"),
+ 'max': np.float32("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ rew_space=EnvElementInfo(
+ shape=1,
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf")
+ },
+ ),
+ use_wrappers=None,
+ )
+ return info
+
+ def __repr__(self) -> str:
+ return "DI-engine Multi-agent Mujoco Env({})".format(self._cfg.env_id)
diff --git a/dizoo/multiagent_mujoco/envs/multiagentenv.py b/dizoo/multiagent_mujoco/envs/multiagentenv.py
new file mode 100755
index 0000000000000000000000000000000000000000..07e65fc549a98d6a85d49d3ab77d7614ed9e7fca
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/multiagentenv.py
@@ -0,0 +1,85 @@
+from collections import namedtuple
+import numpy as np
+
+
+def convert(dictionary):
+ return namedtuple('GenericDict', dictionary.keys())(**dictionary)
+
+
+class MultiAgentEnv(object):
+
+ def __init__(self, batch_size=None, **kwargs):
+ # Unpack arguments from sacred
+ args = kwargs["env_args"]
+ if isinstance(args, dict):
+ args = convert(args)
+ self.args = args
+
+ if getattr(args, "seed", None) is not None:
+ self.seed = args.seed
+ self.rs = np.random.RandomState(self.seed) # initialise numpy random state
+
+ def step(self, actions):
+ """ Returns reward, terminated, info """
+ raise NotImplementedError
+
+ def get_obs(self):
+ """ Returns all agent observations in a list """
+ raise NotImplementedError
+
+ def get_obs_agent(self, agent_id):
+ """ Returns observation for agent_id """
+ raise NotImplementedError
+
+ def get_obs_size(self):
+ """ Returns the shape of the observation """
+ raise NotImplementedError
+
+ def get_state(self):
+ raise NotImplementedError
+
+ def get_state_size(self):
+ """ Returns the shape of the state"""
+ raise NotImplementedError
+
+ def get_avail_actions(self):
+ raise NotImplementedError
+
+ def get_avail_agent_actions(self, agent_id):
+ """ Returns the available actions for agent_id """
+ raise NotImplementedError
+
+ def get_total_actions(self):
+ """ Returns the total number of actions an agent could ever take """
+ # TODO: This is only suitable for a discrete 1 dimensional action space for each agent
+ raise NotImplementedError
+
+ def get_stats(self):
+ raise NotImplementedError
+
+ # TODO: Temp hack
+ def get_agg_stats(self, stats):
+ return {}
+
+ def reset(self):
+ """ Returns initial observations and states"""
+ raise NotImplementedError
+
+ def render(self):
+ raise NotImplementedError
+
+ def close(self):
+ raise NotImplementedError
+
+ def seed(self, seed):
+ raise NotImplementedError
+
+ def get_env_info(self):
+ env_info = {
+ "state_shape": self.get_state_size(),
+ "obs_shape": self.get_obs_size(),
+ "n_actions": self.get_total_actions(),
+ "n_agents": self.n_agents,
+ "episode_limit": self.episode_limit
+ }
+ return env_info
diff --git a/dizoo/multiagent_mujoco/envs/obsk.py b/dizoo/multiagent_mujoco/envs/obsk.py
new file mode 100644
index 0000000000000000000000000000000000000000..404f455abe0711a53febe8025c71f46584e5b70f
--- /dev/null
+++ b/dizoo/multiagent_mujoco/envs/obsk.py
@@ -0,0 +1,662 @@
+import itertools
+import numpy as np
+from copy import deepcopy
+
+
+class Node():
+
+ def __init__(self, label, qpos_ids, qvel_ids, act_ids, body_fn=None, bodies=None, extra_obs=None, tendons=None):
+ self.label = label
+ self.qpos_ids = qpos_ids
+ self.qvel_ids = qvel_ids
+ self.act_ids = act_ids
+ self.bodies = bodies
+ self.extra_obs = {} if extra_obs is None else extra_obs
+ self.body_fn = body_fn
+ self.tendons = tendons
+ pass
+
+ def __str__(self):
+ return self.label
+
+ def __repr__(self):
+ return self.label
+
+
+class HyperEdge():
+
+ def __init__(self, *edges):
+ self.edges = set(edges)
+
+ def __contains__(self, item):
+ return item in self.edges
+
+ def __str__(self):
+ return "HyperEdge({})".format(self.edges)
+
+ def __repr__(self):
+ return "HyperEdge({})".format(self.edges)
+
+
+def get_joints_at_kdist(
+ agent_id,
+ agent_partitions,
+ hyperedges,
+ k=0,
+ kagents=False,
+):
+ """ Identify all joints at distance <= k from agent agent_id
+
+ :param agent_id: id of agent to be considered
+ :param agent_partitions: list of joint tuples in order of agentids
+ :param hyperedges: list of HyperEdge objects connecting joints
+ :param k: kth degree
+ :param kagents: True (observe all joints of an agent if a single one is) or False (individual joint granularity)
+ :return:
+ dict with k as key, and list of joints at that distance
+ """
+ assert not kagents, "kagents not implemented!"
+
+ agent_joints = agent_partitions[agent_id]
+
+ def _adjacent(lst, kagents=False):
+ # return all sets adjacent to any element in lst
+ ret = set([])
+ for l in lst:
+ ret = ret.union(set(itertools.chain(*[e.edges.difference({l}) for e in hyperedges if l in e])))
+ return ret
+
+ seen = set([])
+ new = set([])
+ k_dict = {}
+ for _k in range(k + 1):
+ if not _k:
+ new = set(agent_joints)
+ else:
+ new = _adjacent(new) - seen
+ seen = seen.union(new)
+ k_dict[_k] = sorted(list(new), key=lambda x: x.label)
+ return k_dict
+
+
+def build_obs(env, k_dict, k_categories, global_dict, global_categories, vec_len=None):
+ """Given a k_dict from get_joints_at_kdist, extract observation vector.
+
+ :param env: wrapped mujoco env, used to read qpos/qvel and body attributes
+ :param k_dict: dict mapping distance k to the joints at that distance
+ :param k_categories: observation categories to include at each distance k
+ :param global_dict: global joints/bodies to observe
+ :param global_categories: observation categories for the global entries
+ :param vec_len: if None no padding, else zero-pad the observation vector to vec_len
+ :return:
+ observation vector
+ """
+
+ # TODO: This needs to be fixed, it was designed for half-cheetah only!
+ #if add_global_pos:
+ # obs_qpos_lst.append(global_qpos)
+ # obs_qvel_lst.append(global_qvel)
+
+ body_set_dict = {}
+ obs_lst = []
+ # Add parts attributes
+ for k in sorted(list(k_dict.keys())):
+ cats = k_categories[k]
+ for _t in k_dict[k]:
+ for c in cats:
+ if c in _t.extra_obs:
+ items = _t.extra_obs[c](env).tolist()
+ obs_lst.extend(items if isinstance(items, list) else [items])
+ else:
+ if c in ["qvel", "qpos"]: # this is a "joint position/velocity" item
+ items = getattr(env.sim.data, c)[getattr(_t, "{}_ids".format(c))]
+ obs_lst.extend(items if isinstance(items, list) else [items])
+ elif c in ["qfrc_actuator"]: # this is a "vel position" item
+ items = getattr(env.sim.data, c)[getattr(_t, "{}_ids".format("qvel"))]
+ obs_lst.extend(items if isinstance(items, list) else [items])
+ elif c in ["cvel", "cinert", "cfrc_ext"]: # this is a "body position" item
+ if _t.bodies is not None:
+ for b in _t.bodies:
+ if c not in body_set_dict:
+ body_set_dict[c] = set()
+ if b not in body_set_dict[c]:
+ items = getattr(env.sim.data, c)[b].tolist()
+ items = getattr(_t, "body_fn", lambda _id, x: x)(b, items)
+ obs_lst.extend(items if isinstance(items, list) else [items])
+ body_set_dict[c].add(b)
+
+ # Add global attributes
+ body_set_dict = {}
+ for c in global_categories:
+ if c in ["qvel", "qpos"]: # this is a "joint position" item
+ for j in global_dict.get("joints", []):
+ items = getattr(env.sim.data, c)[getattr(j, "{}_ids".format(c))]
+ obs_lst.extend(items if isinstance(items, list) else [items])
+ else:
+ for b in global_dict.get("bodies", []):
+ if c not in body_set_dict:
+ body_set_dict[c] = set()
+ if b not in body_set_dict[c]:
+ obs_lst.extend(getattr(env.sim.data, c)[b].tolist())
+ body_set_dict[c].add(b)
+
+ if vec_len is not None:
+ pad = np.array((vec_len - len(obs_lst)) * [0])
+ if len(pad):
+ return np.concatenate([np.array(obs_lst), pad])
+ return np.array(obs_lst)
+
+
+def build_actions(agent_partitions, k_dict):
+ # Composes agent actions output from networks
+ # into coherent joint action vector to be sent to the env.
+ pass
+
+
+def get_parts_and_edges(label, partitioning):
+ if label in ["half_cheetah", "HalfCheetah-v2"]:
+
+ # define Mujoco graph
+ bthigh = Node("bthigh", -6, -6, 0)
+ bshin = Node("bshin", -5, -5, 1)
+ bfoot = Node("bfoot", -4, -4, 2)
+ fthigh = Node("fthigh", -3, -3, 3)
+ fshin = Node("fshin", -2, -2, 4)
+ ffoot = Node("ffoot", -1, -1, 5)
+
+ edges = [
+ HyperEdge(bfoot, bshin),
+ HyperEdge(bshin, bthigh),
+ HyperEdge(bthigh, fthigh),
+ HyperEdge(fthigh, fshin),
+ HyperEdge(fshin, ffoot)
+ ]
+
+ root_x = Node("root_x", 0, 0, -1, extra_obs={"qpos": lambda env: np.array([])})
+ root_z = Node("root_z", 1, 1, -1)
+ root_y = Node("root_y", 2, 2, -1)
+ globals = {"joints": [root_x, root_y, root_z]}
+
+ if partitioning == "2x3":
+ parts = [(bfoot, bshin, bthigh), (ffoot, fshin, fthigh)]
+ elif partitioning == "6x1":
+ parts = [(bfoot, ), (bshin, ), (bthigh, ), (ffoot, ), (fshin, ), (fthigh, )]
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["Ant-v2"]:
+
+ # define Mujoco graph
+ torso = 1
+ front_left_leg = 2
+ aux_1 = 3
+ ankle_1 = 4
+ front_right_leg = 5
+ aux_2 = 6
+ ankle_2 = 7
+ back_leg = 8
+ aux_3 = 9
+ ankle_3 = 10
+ right_back_leg = 11
+ aux_4 = 12
+ ankle_4 = 13
+
+ hip1 = Node(
+ "hip1", -8, -8, 2, bodies=[torso, front_left_leg], body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #
+ ankle1 = Node(
+ "ankle1",
+ -7,
+ -7,
+ 3,
+ bodies=[front_left_leg, aux_1, ankle_1],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #,
+ hip2 = Node(
+ "hip2", -6, -6, 4, bodies=[torso, front_right_leg], body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #,
+ ankle2 = Node(
+ "ankle2",
+ -5,
+ -5,
+ 5,
+ bodies=[front_right_leg, aux_2, ankle_2],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #,
+ hip3 = Node("hip3", -4, -4, 6, bodies=[torso, back_leg], body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()) #,
+ ankle3 = Node(
+ "ankle3", -3, -3, 7, bodies=[back_leg, aux_3, ankle_3], body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #,
+ hip4 = Node(
+ "hip4", -2, -2, 0, bodies=[torso, right_back_leg], body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #,
+ ankle4 = Node(
+ "ankle4",
+ -1,
+ -1,
+ 1,
+ bodies=[right_back_leg, aux_4, ankle_4],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #,
+
+ edges = [
+ HyperEdge(ankle4, hip4),
+ HyperEdge(ankle1, hip1),
+ HyperEdge(ankle2, hip2),
+ HyperEdge(ankle3, hip3),
+ HyperEdge(hip4, hip1, hip2, hip3),
+ ]
+
+ free_joint = Node(
+ "free",
+ 0,
+ 0,
+ -1,
+ extra_obs={
+ "qpos": lambda env: env.sim.data.qpos[:7],
+ "qvel": lambda env: env.sim.data.qvel[:6],
+ "cfrc_ext": lambda env: np.clip(env.sim.data.cfrc_ext[0:1], -1, 1)
+ }
+ )
+ globals = {"joints": [free_joint]}
+
+ if partitioning == "2x4": # neighbouring legs together
+ parts = [(hip1, ankle1, hip2, ankle2), (hip3, ankle3, hip4, ankle4)]
+ elif partitioning == "2x4d": # diagonal legs together
+ parts = [(hip1, ankle1, hip3, ankle3), (hip2, ankle2, hip4, ankle4)]
+ elif partitioning == "4x2":
+ parts = [(hip1, ankle1), (hip2, ankle2), (hip3, ankle3), (hip4, ankle4)]
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["Hopper-v2"]:
+
+ # define Mujoco-Graph
+ thigh_joint = Node(
+ "thigh_joint",
+ -3,
+ -3,
+ 0,
+ extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[-3]]), -10, 10)}
+ )
+ leg_joint = Node(
+ "leg_joint", -2, -2, 1, extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[-2]]), -10, 10)}
+ )
+ foot_joint = Node(
+ "foot_joint",
+ -1,
+ -1,
+ 2,
+ extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[-1]]), -10, 10)}
+ )
+
+ edges = [HyperEdge(foot_joint, leg_joint), HyperEdge(leg_joint, thigh_joint)]
+
+ root_x = Node(
+ "root_x",
+ 0,
+ 0,
+ -1,
+ extra_obs={
+ "qpos": lambda env: np.array([]),
+ "qvel": lambda env: np.clip(np.array([env.sim.data.qvel[1]]), -10, 10)
+ }
+ )
+ root_z = Node(
+ "root_z", 1, 1, -1, extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[1]]), -10, 10)}
+ )
+ root_y = Node(
+ "root_y", 2, 2, -1, extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[2]]), -10, 10)}
+ )
+ globals = {"joints": [root_x, root_y, root_z]}
+
+ if partitioning == "3x1":
+ parts = [(thigh_joint, ), (leg_joint, ), (foot_joint, )]
+
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["Humanoid-v2", "HumanoidStandup-v2"]:
+
+ # define Mujoco-Graph
+ abdomen_y = Node("abdomen_y", -16, -16, 0) # act ordering bug in env -- double check!
+ abdomen_z = Node("abdomen_z", -17, -17, 1)
+ abdomen_x = Node("abdomen_x", -15, -15, 2)
+ right_hip_x = Node("right_hip_x", -14, -14, 3)
+ right_hip_z = Node("right_hip_z", -13, -13, 4)
+ right_hip_y = Node("right_hip_y", -12, -12, 5)
+ right_knee = Node("right_knee", -11, -11, 6)
+ left_hip_x = Node("left_hip_x", -10, -10, 7)
+ left_hip_z = Node("left_hip_z", -9, -9, 8)
+ left_hip_y = Node("left_hip_y", -8, -8, 9)
+ left_knee = Node("left_knee", -7, -7, 10)
+ right_shoulder1 = Node("right_shoulder1", -6, -6, 11)
+ right_shoulder2 = Node("right_shoulder2", -5, -5, 12)
+ right_elbow = Node("right_elbow", -4, -4, 13)
+ left_shoulder1 = Node("left_shoulder1", -3, -3, 14)
+ left_shoulder2 = Node("left_shoulder2", -2, -2, 15)
+ left_elbow = Node("left_elbow", -1, -1, 16)
+
+ edges = [
+ HyperEdge(abdomen_x, abdomen_y, abdomen_z),
+ HyperEdge(right_hip_x, right_hip_y, right_hip_z),
+ HyperEdge(left_hip_x, left_hip_y, left_hip_z),
+ HyperEdge(left_elbow, left_shoulder1, left_shoulder2),
+ HyperEdge(right_elbow, right_shoulder1, right_shoulder2),
+ HyperEdge(left_knee, left_hip_x, left_hip_y, left_hip_z),
+ HyperEdge(right_knee, right_hip_x, right_hip_y, right_hip_z),
+ HyperEdge(left_shoulder1, left_shoulder2, abdomen_x, abdomen_y, abdomen_z),
+ HyperEdge(right_shoulder1, right_shoulder2, abdomen_x, abdomen_y, abdomen_z),
+ HyperEdge(abdomen_x, abdomen_y, abdomen_z, left_hip_x, left_hip_y, left_hip_z),
+ HyperEdge(abdomen_x, abdomen_y, abdomen_z, right_hip_x, right_hip_y, right_hip_z),
+ ]
+
+ globals = {}
+
+ if partitioning == "9|8": # 17 in total, so one action is a dummy (to be handled by pymarl)
+ # isolate upper and lower body
+ parts = [
+ (
+ left_shoulder1, left_shoulder2, abdomen_x, abdomen_y, abdomen_z, right_shoulder1, right_shoulder2,
+ right_elbow, left_elbow
+ ), (left_hip_x, left_hip_y, left_hip_z, right_hip_x, right_hip_y, right_hip_z, right_knee, left_knee)
+ ]
+ # TODO: There could be tons of decompositions here
+
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["Reacher-v2"]:
+
+ # define Mujoco-Graph
+ body0 = 1
+ body1 = 2
+ fingertip = 3
+ joint0 = Node(
+ "joint0",
+ -4,
+ -4,
+ 0,
+ bodies=[body0, body1],
+ extra_obs={"qpos": (lambda env: np.array([np.sin(env.sim.data.qpos[-4]),
+ np.cos(env.sim.data.qpos[-4])]))}
+ )
+ joint1 = Node(
+ "joint1",
+ -3,
+ -3,
+ 1,
+ bodies=[body1, fingertip],
+ extra_obs={
+ "fingertip_dist": (lambda env: env.get_body_com("fingertip") - env.get_body_com("target")),
+ "qpos": (lambda env: np.array([np.sin(env.sim.data.qpos[-3]),
+ np.cos(env.sim.data.qpos[-3])]))
+ }
+ )
+ edges = [HyperEdge(joint0, joint1)]
+
+ worldbody = 0
+ target = 4
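+        # The target coordinates are unactuated globals; their velocities are masked out
+        # (empty qvel).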
+ target_x = Node("target_x", -2, -2, -1, extra_obs={"qvel": (lambda env: np.array([]))})
+ target_y = Node("target_y", -1, -1, -1, extra_obs={"qvel": (lambda env: np.array([]))})
+ globals = {"bodies": [worldbody, target], "joints": [target_x, target_y]}
+
+ if partitioning == "2x1":
+ # isolate upper and lower arms
+ parts = [(joint0, ), (joint1, )]
+ # TODO: There could be tons of decompositions here
+
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["Swimmer-v2"]:
+
+ # define Mujoco-Graph
+ joint0 = Node("rot2", -2, -2, 0) # TODO: double-check ids
+ joint1 = Node("rot3", -1, -1, 1)
+
+ edges = [HyperEdge(joint0, joint1)]
+ globals = {}
+
+ if partitioning == "2x1":
+ # isolate upper and lower body
+ parts = [(joint0, ), (joint1, )]
+ # TODO: There could be tons of decompositions here
+
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["Walker2d-v2"]:
+
+ # define Mujoco-Graph
+ thigh_joint = Node("thigh_joint", -6, -6, 0)
+ leg_joint = Node("leg_joint", -5, -5, 1)
+ foot_joint = Node("foot_joint", -4, -4, 2)
+ thigh_left_joint = Node("thigh_left_joint", -3, -3, 3)
+ leg_left_joint = Node("leg_left_joint", -2, -2, 4)
+ foot_left_joint = Node("foot_left_joint", -1, -1, 5)
+
+ edges = [
+ HyperEdge(foot_joint, leg_joint),
+ HyperEdge(leg_joint, thigh_joint),
+ HyperEdge(foot_left_joint, leg_left_joint),
+ HyperEdge(leg_left_joint, thigh_left_joint),
+ HyperEdge(thigh_joint, thigh_left_joint)
+ ]
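+        # two leg chains (foot-leg-thigh) joined at the thighs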
+ globals = {}
+
+ if partitioning == "2x3":
+ # isolate upper and lower body
+            parts = [(foot_joint, leg_joint, thigh_joint), (foot_left_joint, leg_left_joint, thigh_left_joint)]
+ # TODO: There could be tons of decompositions here
+
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["coupled_half_cheetah"]:
+
+ # define Mujoco graph
+ tendon = 0
+
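+        # Two half-cheetah morphologies coupled through tendon 0; each back thigh
+        # additionally observes the tendon Jacobian, length and velocity.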
+ bthigh = Node(
+ "bthigh",
+ -6,
+ -6,
+ 0,
+ tendons=[tendon],
+ extra_obs={
+ "ten_J": lambda env: env.sim.data.ten_J[tendon],
+ "ten_length": lambda env: env.sim.data.ten_length,
+ "ten_velocity": lambda env: env.sim.data.ten_velocity
+ }
+ )
+ bshin = Node("bshin", -5, -5, 1)
+ bfoot = Node("bfoot", -4, -4, 2)
+ fthigh = Node("fthigh", -3, -3, 3)
+ fshin = Node("fshin", -2, -2, 4)
+ ffoot = Node("ffoot", -1, -1, 5)
+
+ bthigh2 = Node(
+ "bthigh2",
+ -6,
+ -6,
+ 0,
+ tendons=[tendon],
+ extra_obs={
+ "ten_J": lambda env: env.sim.data.ten_J[tendon],
+ "ten_length": lambda env: env.sim.data.ten_length,
+ "ten_velocity": lambda env: env.sim.data.ten_velocity
+ }
+ )
+ bshin2 = Node("bshin2", -5, -5, 1)
+ bfoot2 = Node("bfoot2", -4, -4, 2)
+ fthigh2 = Node("fthigh2", -3, -3, 3)
+ fshin2 = Node("fshin2", -2, -2, 4)
+ ffoot2 = Node("ffoot2", -1, -1, 5)
+
+ edges = [
+ HyperEdge(bfoot, bshin),
+ HyperEdge(bshin, bthigh),
+ HyperEdge(bthigh, fthigh),
+ HyperEdge(fthigh, fshin),
+ HyperEdge(fshin, ffoot),
+ HyperEdge(bfoot2, bshin2),
+ HyperEdge(bshin2, bthigh2),
+ HyperEdge(bthigh2, fthigh2),
+ HyperEdge(fthigh2, fshin2),
+ HyperEdge(fshin2, ffoot2)
+ ]
+
+ root_x = Node("root_x", 0, 0, -1, extra_obs={"qpos": lambda env: np.array([])})
+ root_z = Node("root_z", 1, 1, -1)
+ root_y = Node("root_y", 2, 2, -1)
+ globals = {"joints": [root_x, root_y, root_z]}
+
+ if partitioning == "1p1":
+ parts = [(bfoot, bshin, bthigh, ffoot, fshin, fthigh), (bfoot2, bshin2, bthigh2, ffoot2, fshin2, fthigh2)]
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["manyagent_swimmer"]:
+
+ # Generate asset file
+ try:
+ n_agents = int(partitioning.split("x")[0])
+ n_segs_per_agents = int(partitioning.split("x")[1])
+ n_segs = n_agents * n_segs_per_agents
+        except Exception:
+            raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ # Note: Default Swimmer corresponds to n_segs = 3
+
+ # define Mujoco-Graph
+ joints = [Node("rot{:d}".format(i), -n_segs + i, -n_segs + i, i) for i in range(0, n_segs)]
+ edges = [HyperEdge(joints[i], joints[i + 1]) for i in range(n_segs - 1)]
+ globals = {}
+
+ parts = [tuple(joints[i * n_segs_per_agents:(i + 1) * n_segs_per_agents]) for i in range(n_agents)]
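+        # e.g. partitioning "2x3": six segments in total, agent 0 controls rot0..rot2 and
+        # agent 1 controls rot3..rot5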
+ return parts, edges, globals
+
+ elif label in ["manyagent_ant"]: # TODO: FIX!
+
+ # Generate asset file
+ try:
+ n_agents = int(partitioning.split("x")[0])
+ n_segs_per_agents = int(partitioning.split("x")[1])
+ n_segs = n_agents * n_segs_per_agents
+        except Exception:
+            raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ edges = []
+ joints = []
+ for si in range(n_segs):
+
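+            # each segment contributes 7 MuJoCo bodies (torso plus link/aux/ankle for two
+            # legs) and 4 actuated joints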
+ torso = 1 + si * 7
+ front_right_leg = 2 + si * 7
+ aux1 = 3 + si * 7
+ ankle1 = 4 + si * 7
+ back_leg = 5 + si * 7
+ aux2 = 6 + si * 7
+ ankle2 = 7 + si * 7
+
+ off = -4 * (n_segs - 1 - si)
+ hip1n = Node(
+ "hip1_{:d}".format(si),
+ -4 - off,
+ -4 - off,
+ 2 + 4 * si,
+ bodies=[torso, front_right_leg],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ )
+ ankle1n = Node(
+ "ankle1_{:d}".format(si),
+ -3 - off,
+ -3 - off,
+ 3 + 4 * si,
+ bodies=[front_right_leg, aux1, ankle1],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ )
+ hip2n = Node(
+ "hip2_{:d}".format(si),
+ -2 - off,
+ -2 - off,
+ 0 + 4 * si,
+ bodies=[torso, back_leg],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ )
+ ankle2n = Node(
+ "ankle2_{:d}".format(si),
+ -1 - off,
+ -1 - off,
+ 1 + 4 * si,
+ bodies=[back_leg, aux2, ankle2],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ )
+
+ edges += [HyperEdge(ankle1n, hip1n), HyperEdge(ankle2n, hip2n), HyperEdge(hip1n, hip2n)]
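+            # from the second segment on, also connect this segment's hips to the previous
+            # segment's hips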
+ if si:
+ edges += [HyperEdge(hip1m, hip2m, hip1n, hip2n)]
+
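+            # snapshot this segment's hips so the next iteration can link to them
+            # (deepcopy is assumed to be imported from the copy module at the top of this file)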
+ hip1m = deepcopy(hip1n)
+ hip2m = deepcopy(hip2n)
+ joints.append([hip1n, ankle1n, hip2n, ankle2n])
+
+ free_joint = Node(
+ "free",
+ 0,
+ 0,
+ -1,
+ extra_obs={
+ "qpos": lambda env: env.sim.data.qpos[:7],
+ "qvel": lambda env: env.sim.data.qvel[:6],
+ "cfrc_ext": lambda env: np.clip(env.sim.data.cfrc_ext[0:1], -1, 1)
+ }
+ )
+ globals = {"joints": [free_joint]}
+
+ parts = [
+ [x for sublist in joints[i * n_segs_per_agents:(i + 1) * n_segs_per_agents] for x in sublist]
+ for i in range(n_agents)
+ ]
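+        # each agent controls the flattened joints of n_segs_per_agents consecutive segments
+        # (4 joints per segment)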
+
+ return parts, edges, globals
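+        # Hypothetical usage, assuming the enclosing function is
+        # get_parts_and_edges(label, partitioning) as in upstream multiagent_mujoco:
+        #   parts, edges, globs = get_parts_and_edges("manyagent_ant", "2x3")
+        # yields two agents, each controlling 12 joints across three torso segments.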