Unverified · Commit b50e8aea · authored by timothijoe, committed by GitHub

feature(zt): add curiosity icm algorithm (#41)

* curisity_icm_v1

* modified version1

* modified v2

* one_hot function change

* add paper information

* format minigrid ppo curiosity

* flake8 ding checked

* 6th-Oct-gpu-modified

* reset configs in minigrid files

* minigird-env-doorkey88-100-300

* use modulelist instead of list in icm module

* change icm reward model

* delete origin curiosit_reward model and add icm_reward model

* modified icm reward model

* polish icm model by zt: (1) polish ding/reward_model/icm_reward_model.py and the related __init__.py; (2) add config files for Pong (dizoo/atari/config/serial/pong/pong_ppo_offpolicy_icm.py) and the MiniGrid env (dizoo/minigrid/config/doorkey8_icm_config.py, fourroom_icm_config.py, minigrid_icm_config.py); (3) add an ICM entry to the README

* remove some useless config files in minigrid

* remove redundant part in ppo.py, add cartpole_ppo_icm_config.py, changed test_icm.py and Readme
Parent 5216fb31
......@@ -126,11 +126,12 @@ ding -m serial -e cartpole -p dqn -s 0
| 26 | [GCL](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py
| 27 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 28 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 29 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 30 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py |
| 31 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py |
| 32 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 33 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
| 29 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
| 30 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 31 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py |
| 32 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py |
| 33 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 34 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space, which is the only label used for the standard DRL algorithms (1-16)
......
import pytest
from easydict import EasyDict
from copy import deepcopy
from ding.entry import serial_pipeline_reward_model
from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_create_config


@pytest.mark.unittest
def test_icm():
    config = [deepcopy(cartpole_ppo_icm_config), deepcopy(cartpole_ppo_icm_create_config)]
    try:
        serial_pipeline_reward_model(config, seed=0, max_iterations=2)
    except Exception:
        assert False, "pipeline fail"


if __name__ == '__main__':
    test_icm()
......@@ -667,25 +667,16 @@ class PPOOffPolicy(Policy):
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        if not self._nstep_return:
            transition = {
                'obs': obs,
                'logit': model_output['logit'],
                'action': model_output['action'],
                'value': model_output['value'],
                'reward': timestep.reward,
                'done': timestep.done,
            }
        else:
            transition = {
                'obs': obs,
                'next_obs': timestep.obs,
                'logit': model_output['logit'],
                'action': model_output['action'],
                'value': model_output['value'],
                'reward': timestep.reward,
                'done': timestep.done,
            }
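        # next_obs is now stored in every transition (not only in the former n-step branch),
        # since reward models such as ICM read (obs, next_obs, action) from each transition.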
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'logit': model_output['logit'],
            'action': model_output['action'],
            'value': model_output['value'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
......
......@@ -10,3 +10,4 @@ from .her_reward_model import HerRewardModel
from .rnd_reward_model import RndRewardModel
from .guided_cost_reward_model import GuidedCostRewardModel
from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel
from .icm_reward_model import ICMRewardModel
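# With this import, ICM is available under the reward-model registry key 'icm' used by the
# configs below. A minimal sketch of building it outside the serial pipeline, assuming this
# version exports the registry helper `create_reward_model` from `ding.reward_model`:
#   from ding.reward_model import create_reward_model
#   icm_reward_model = create_reward_model(cfg.reward_model, device, tb_logger)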
from typing import Union, Tuple
from easydict import EasyDict
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from ding.utils import SequenceType, REWARD_MODEL_REGISTRY
from ding.model import FCEncoder, ConvEncoder
from ding.torch_utils import one_hot
from .base_reward_model import BaseRewardModel


def collect_states(iterator: list) -> Tuple[list, list, list]:
    states = []
    next_states = []
    actions = []
    for item in iterator:
        state = item['obs']
        next_state = item['next_obs']
        action = item['action']
        states.append(state)
        next_states.append(next_state)
        actions.append(action)
    return states, next_states, actions


class ICMNetwork(nn.Module):
    r"""
    Intrinsic Curiosity Model (ICM module)
    Implementation of:
    [1] Curiosity-driven Exploration by Self-supervised Prediction
        Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017.
        https://arxiv.org/pdf/1705.05363.pdf
    [2] Code implementation reference:
        https://github.com/pathak22/noreward-rl
        https://github.com/jcwleo/curiosity-driven-exploration-pytorch

    1) Embedding observations into a latent space
    2) Predicting the action logit given two consecutive embedded observations
    3) Predicting the next embedded observation, given the embedded former observation and action
    """

    def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType, action_shape: int) -> None:
        super(ICMNetwork, self).__init__()
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.feature = FCEncoder(obs_shape, hidden_size_list)
        elif len(obs_shape) == 3:
            self.feature = ConvEncoder(obs_shape, hidden_size_list)
        else:
            raise KeyError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own ICM model".
                format(obs_shape)
            )
        self.action_shape = action_shape
        feature_output = hidden_size_list[-1]
        self.inverse_net = nn.Sequential(nn.Linear(feature_output * 2, 512), nn.ReLU(), nn.Linear(512, action_shape))
        self.residual = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(action_shape + 512, 512),
                    nn.LeakyReLU(),
                    nn.Linear(512, 512),
                ) for _ in range(8)
            ]
        )
        self.forward_net_1 = nn.Sequential(nn.Linear(action_shape + feature_output, 512), nn.LeakyReLU())
        self.forward_net_2 = nn.Linear(action_shape + 512, feature_output)

    def forward(self, state: torch.Tensor, next_state: torch.Tensor,
                action_long: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Overview:
            Use the observation, next observation and action to compute the outputs of the ICM module, \
                which are then used to update the ICM parameters.
        Arguments:
            - state (:obj:`torch.Tensor`):
                The current state batch
            - next_state (:obj:`torch.Tensor`):
                The next state batch
            - action_long (:obj:`torch.Tensor`):
                The action batch
        Returns:
            - real_next_state_feature (:obj:`torch.Tensor`):
                Run with the encoder. Return the real next_state's embedded feature.
            - pred_next_state_feature (:obj:`torch.Tensor`):
                Run with the encoder and residual network. Return the predicted next_state's embedded feature.
            - pred_action_logit (:obj:`torch.Tensor`):
                Run with the encoder. Return the predicted action logit.
        Shapes:
            - state (:obj:`torch.Tensor`): :math:`(B, N)`, where B is the batch size and N is ``obs_shape``
            - next_state (:obj:`torch.Tensor`): :math:`(B, N)`, where B is the batch size and N is ``obs_shape``
            - action_long (:obj:`torch.Tensor`): :math:`(B, )`, where B is the batch size
            - real_next_state_feature (:obj:`torch.Tensor`): :math:`(B, M)`, where B is the batch size \
                and M is the embedded feature size
            - pred_next_state_feature (:obj:`torch.Tensor`): :math:`(B, M)`, where B is the batch size \
                and M is the embedded feature size
            - pred_action_logit (:obj:`torch.Tensor`): :math:`(B, A)`, where B is the batch size \
                and A is ``action_shape``
        """
        action = one_hot(action_long, num=self.action_shape)
        encode_state = self.feature(state)
        encode_next_state = self.feature(next_state)
        # inverse model: predict the action logit from the two consecutive embedded states
        concat_state = torch.cat((encode_state, encode_next_state), 1)
        pred_action_logit = self.inverse_net(concat_state)
        # ---------------------
        # forward model: predict the next embedded state from the current embedding and the action
        pred_next_state_feature_orig = torch.cat((encode_state, action), 1)
        pred_next_state_feature_orig = self.forward_net_1(pred_next_state_feature_orig)
        # residual: the 8 residual blocks are applied as 4 pairs, each pair forming one
        # action-conditioned residual connection around the running feature
        for i in range(4):
            pred_next_state_feature = self.residual[i * 2](torch.cat((pred_next_state_feature_orig, action), 1))
            pred_next_state_feature_orig = self.residual[i * 2 + 1](
                torch.cat((pred_next_state_feature, action), 1)
            ) + pred_next_state_feature_orig
        pred_next_state_feature = self.forward_net_2(torch.cat((pred_next_state_feature_orig, action), 1))
        real_next_state_feature = encode_next_state
        return real_next_state_feature, pred_next_state_feature, pred_action_logit


@REWARD_MODEL_REGISTRY.register('icm')
class ICMRewardModel(BaseRewardModel):
    """
    Overview:
        The ICM reward model class (https://arxiv.org/pdf/1705.05363.pdf)
    Interface:
        ``estimate``, ``train``, ``collect_data``, ``clear_data``, ``__init__``, ``_train``
    """
    config = dict(
        # (str) the type of the exploration method
        type='icm',
        # (str) the intrinsic reward type, including add, new, or assign
        intrinsic_reward_type='add',
        # (float) learning rate of the optimizer
        learning_rate=1e-3,
        # (int or list) the observation shape
        obs_shape=6,
        # (int) the action shape, support discrete action only in this version
        action_shape=7,
        # (int) the batch size used in each training iteration
        batch_size=64,
        # (list) the MLP layer shape
        hidden_size_list=[64, 64, 128],
        # (int) how many update steps after each collect
        update_per_collect=100,
        # (float) the weight of the inverse (reverse) loss relative to the forward loss
        reverse_scale=1,
    )

    def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        super(ICMRewardModel, self).__init__()
        self.cfg = config
        assert device == "cpu" or device.startswith("cuda")
        self.device = device
        self.tb_logger = tb_logger
        self.reward_model = ICMNetwork(config.obs_shape, config.hidden_size_list, config.action_shape)
        self.reward_model.to(self.device)
        self.intrinsic_reward_type = config.intrinsic_reward_type
        assert self.intrinsic_reward_type in ['add', 'new', 'assign']
        self.train_data = []
        self.train_states = []
        self.train_next_states = []
        self.train_actions = []
        self.opt = optim.Adam(self.reward_model.parameters(), config.learning_rate)
        self.ce = nn.CrossEntropyLoss(reduction="mean")
        self.forward_mse = nn.MSELoss(reduction='none')
        self.reverse_scale = config.reverse_scale
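
    # _train() samples a batch and minimizes
    #     loss = reverse_scale * CrossEntropy(pred_action_logit, action)           (inverse model)
    #            + MSE(pred_next_state_feature, real_next_state_feature.detach())  (forward model)
    # The same forward prediction error is what estimate() below turns into the intrinsic reward.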
    def _train(self) -> None:
        train_data_list = [i for i in range(0, len(self.train_states))]
        train_data_index = random.sample(train_data_list, self.cfg.batch_size)
        data_states: list = [self.train_states[i] for i in train_data_index]
        data_states: torch.Tensor = torch.stack(data_states).to(self.device)
        data_next_states: list = [self.train_next_states[i] for i in train_data_index]
        data_next_states: torch.Tensor = torch.stack(data_next_states).to(self.device)
        data_actions: list = [self.train_actions[i] for i in train_data_index]
        data_actions: torch.Tensor = torch.cat(data_actions).to(self.device)
        real_next_state_feature, pred_next_state_feature, pred_action_logit = self.reward_model(
            data_states, data_next_states, data_actions
        )
        inverse_loss = self.ce(pred_action_logit, data_actions.long())
        forward_loss = self.forward_mse(pred_next_state_feature, real_next_state_feature.detach()).mean()
        loss = self.reverse_scale * inverse_loss + forward_loss
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

    def train(self) -> None:
        for _ in range(self.cfg.update_per_collect):
            self._train()
        self.clear_data()
    def estimate(self, data: list) -> None:
        states, next_states, actions = collect_states(data)
        states = torch.stack(states).to(self.device)
        next_states = torch.stack(next_states).to(self.device)
        actions = torch.cat(actions).to(self.device)
        with torch.no_grad():
            real_next_state_feature, pred_next_state_feature, _ = self.reward_model(states, next_states, actions)
            reward = self.forward_mse(real_next_state_feature, pred_next_state_feature).mean(dim=1)
            # min-max normalize the forward prediction error to [0, 1] within this batch
            reward = (reward - reward.min()) / (reward.max() - reward.min() + 1e-8)
            reward = reward.to(data[0]['reward'].device)
            reward = torch.chunk(reward, reward.shape[0], dim=0)
        for item, rew in zip(data, reward):
            if self.intrinsic_reward_type == 'add':
                item['reward'] += rew
            elif self.intrinsic_reward_type == 'new':
                item['intrinsic_reward'] = rew
            elif self.intrinsic_reward_type == 'assign':
                item['reward'] = rew
    def collect_data(self, data: list) -> None:
        self.train_data.extend(collect_states(data))
        states, next_states, actions = collect_states(data)
        self.train_states.extend(states)
        self.train_next_states.extend(next_states)
        self.train_actions.extend(actions)

    def clear_data(self) -> None:
        self.train_data.clear()
        self.train_states.clear()
        self.train_next_states.clear()
        self.train_actions.clear()
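
A minimal standalone sketch (not part of this commit) of exercising `ICMRewardModel` directly on fake CartPole-like transitions. The config values mirror `cartpole_ppo_icm_config` below; passing `tb_logger=None` is an assumption that holds here only because this file stores but never uses the logger.

import torch
from easydict import EasyDict
from ding.reward_model import ICMRewardModel

icm_cfg = EasyDict(
    type='icm',
    intrinsic_reward_type='add',
    learning_rate=1e-3,
    obs_shape=4,
    action_shape=2,
    batch_size=32,
    hidden_size_list=[64, 64, 128],
    update_per_collect=10,
    reverse_scale=1,
)
icm = ICMRewardModel(icm_cfg, device='cpu', tb_logger=None)  # tb_logger is stored but unused here
# fake transitions shaped like the ones PPO's _process_transition produces for CartPole
fake_data = [
    {
        'obs': torch.randn(4),
        'next_obs': torch.randn(4),
        'action': torch.randint(0, 2, (1, )),
        'reward': torch.zeros(1),
    } for _ in range(64)
]
icm.collect_data(fake_data)
icm.train()  # runs update_per_collect gradient steps on the inverse + forward models
icm.estimate(fake_data)  # adds the normalized forward prediction error to each item's 'reward'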
from copy import deepcopy
from ding.entry import serial_pipeline_reward_model
from easydict import EasyDict

pong_ppo_icm_config = dict(
    exp_name='pong_ppo_icm',
    env=dict(
        collector_env_num=16,
        evaluator_env_num=4,
        n_evaluator_episode=8,
        stop_value=20,
        env_id='PongNoFrameskip-v4',
        frame_stack=4,
    ),
    reward_model=dict(
        intrinsic_reward_type='add',
        learning_rate=0.001,
        obs_shape=[4, 84, 84],
        action_shape=6,
        batch_size=32,
        update_per_collect=10,
    ),
    policy=dict(
        cuda=True,
        # (bool) whether to use the on-policy training pipeline (behaviour policy and training policy are the same)
        on_policy=False,
        model=dict(
            obs_shape=[4, 84, 84],
            action_shape=6,
            encoder_hidden_size_list=[64, 64, 128],
            actor_head_hidden_size=128,
            critic_head_hidden_size=128,
        ),
        learn=dict(
            update_per_collect=24,
            batch_size=128,
            # (bool) Whether to normalize advantage. Default to False.
            adv_norm=False,
            learning_rate=0.0002,
            # (float) loss weight of the value network, the weight of policy network is set to 1
            value_weight=0.5,
            # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
            entropy_weight=0.015,
            clip_ratio=0.1,
        ),
        collect=dict(
            # (int) collect n_sample data, train model n_iteration times
            n_sample=1024,
            # (float) the trade-off factor lambda to balance 1-step TD and MC
            gae_lambda=0.95,
            discount_factor=0.99,
        ),
        eval=dict(evaluator=dict(eval_freq=1000, )),
        other=dict(replay_buffer=dict(
            replay_buffer_size=100000,
            max_use=3,
        ), ),
    ),
)
main_config = EasyDict(pong_ppo_icm_config)
pong_ppo_create_config = dict(
    env=dict(
        type='atari',
        import_names=['dizoo.atari.envs.atari_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo_offpolicy'),
    reward_model=dict(type='icm'),
)
create_config = EasyDict(pong_ppo_create_config)

if __name__ == '__main__':
    serial_pipeline_reward_model([main_config, create_config], seed=0)
from easydict import EasyDict
from ding.entry import serial_pipeline_reward_model

cartpole_ppo_icm_config = dict(
    exp_name='cartpole_ppo',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        n_evaluator_episode=5,
        stop_value=195,
    ),
    reward_model=dict(
        intrinsic_reward_type='add',
        learning_rate=0.001,
        obs_shape=4,
        action_shape=2,
        batch_size=32,
        update_per_collect=10,
    ),
    policy=dict(
        cuda=False,
        continuous=False,
        model=dict(
            obs_shape=4,
            action_shape=2,
            encoder_hidden_size_list=[64, 64, 128],
            critic_head_hidden_size=128,
            actor_head_hidden_size=128,
        ),
        learn=dict(
            epoch_per_collect=2,
            batch_size=64,
            learning_rate=0.001,
            value_weight=0.5,
            entropy_weight=0.01,
            clip_ratio=0.2,
        ),
        collect=dict(
            n_sample=256,
            unroll_len=1,
            discount_factor=0.9,
            gae_lambda=0.95,
        ),
        eval=dict(evaluator=dict(eval_freq=100, ), ),
    ),
)
cartpole_ppo_icm_config = EasyDict(cartpole_ppo_icm_config)
main_config = cartpole_ppo_icm_config
cartpole_ppo_icm_create_config = dict(
    env=dict(
        type='cartpole',
        import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo_offpolicy'),
    reward_model=dict(type='icm'),
)
cartpole_ppo_icm_create_config = EasyDict(cartpole_ppo_icm_create_config)
create_config = cartpole_ppo_icm_create_config

if __name__ == '__main__':
    serial_pipeline_reward_model([main_config, create_config], seed=0)
\ No newline at end of file
from easydict import EasyDict
from ding.entry import serial_pipeline_reward_model

minigrid_ppo_icm_config = dict(
    exp_name='doorkey8_ppo_icm',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        n_evaluator_episode=5,
        env_id='MiniGrid-DoorKey-8x8-v0',
        stop_value=0.96,
    ),
    reward_model=dict(
        intrinsic_reward_type='add',
        learning_rate=0.001,
        obs_shape=2739,
        batch_size=320,
        update_per_collect=10,
    ),
    policy=dict(
        cuda=True,
        model=dict(
            obs_shape=2739,
            action_shape=7,
            encoder_hidden_size_list=[256, 128, 64, 64],
        ),
        learn=dict(
            update_per_collect=10,
            batch_size=320,
            learning_rate=0.0003,
            value_weight=0.5,
            entropy_weight=0.001,
            clip_ratio=0.2,
            adv_norm=False,
        ),
        collect=dict(
            n_sample=3200,
            unroll_len=1,
            discount_factor=0.99,
            gae_lambda=0.95,
        ),
    ),
)
minigrid_ppo_icm_config = EasyDict(minigrid_ppo_icm_config)
main_config = minigrid_ppo_icm_config
minigrid_ppo_icm_create_config = dict(
    env=dict(
        type='minigrid',
        import_names=['dizoo.minigrid.envs.minigrid_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo_offpolicy'),
    reward_model=dict(type='icm'),
)
minigrid_ppo_icm_create_config = EasyDict(minigrid_ppo_icm_create_config)
create_config = minigrid_ppo_icm_create_config

if __name__ == "__main__":
    serial_pipeline_reward_model([main_config, create_config], seed=10)
\ No newline at end of file
from easydict import EasyDict
from ding.entry import serial_pipeline_reward_model

minigrid_ppo_icm_config = dict(
    exp_name='fourroom_ppo_icm',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        n_evaluator_episode=5,
        env_id='MiniGrid-FourRooms-v0',
        stop_value=0.96,
    ),
    reward_model=dict(
        intrinsic_reward_type='add',
        learning_rate=0.001,
        obs_shape=2739,
        action_shape=7,
        batch_size=32,
        update_per_collect=10,
    ),
    policy=dict(
        cuda=True,
        model=dict(
            obs_shape=2739,
            action_shape=7,
            encoder_hidden_size_list=[256, 128, 64, 64],
        ),
        learn=dict(
            update_per_collect=4,
            batch_size=64,
            learning_rate=0.0003,
            value_weight=0.5,
            entropy_weight=0.001,
            clip_ratio=0.2,
            adv_norm=False,
        ),
        collect=dict(
            n_sample=128,
            unroll_len=1,
            discount_factor=0.99,
            gae_lambda=0.95,
        ),
    ),
)
minigrid_ppo_icm_config = EasyDict(minigrid_ppo_icm_config)
main_config = minigrid_ppo_icm_config
minigrid_ppo_icm_create_config = dict(
    env=dict(
        type='minigrid',
        import_names=['dizoo.minigrid.envs.minigrid_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo_offpolicy'),
    reward_model=dict(type='icm'),
)
minigrid_ppo_icm_create_config = EasyDict(minigrid_ppo_icm_create_config)
create_config = minigrid_ppo_icm_create_config

if __name__ == "__main__":
    serial_pipeline_reward_model([main_config, create_config], seed=0)
\ No newline at end of file
from easydict import EasyDict
from ding.entry import serial_pipeline_reward_model

minigrid_ppo_icm_config = dict(
    exp_name='minigrid_empty8_ppo_icm',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        n_evaluator_episode=5,
        env_id='MiniGrid-Empty-8x8-v0',
        stop_value=0.96,
    ),
    reward_model=dict(
        intrinsic_reward_type='add',
        learning_rate=0.001,
        obs_shape=2739,
        batch_size=32,
        update_per_collect=10,
    ),
    policy=dict(
        cuda=True,
        model=dict(
            obs_shape=2739,
            action_shape=7,
            encoder_hidden_size_list=[256, 128, 64, 64],
        ),
        learn=dict(
            update_per_collect=4,
            batch_size=64,
            learning_rate=0.0003,
            value_weight=0.5,
            entropy_weight=0.001,
            clip_ratio=0.2,
            adv_norm=False,
        ),
        collect=dict(
            n_sample=128,
            unroll_len=1,
            discount_factor=0.99,
            gae_lambda=0.95,
        ),
    ),
)
minigrid_ppo_icm_config = EasyDict(minigrid_ppo_icm_config)
main_config = minigrid_ppo_icm_config
minigrid_ppo_icm_create_config = dict(
    env=dict(
        type='minigrid',
        import_names=['dizoo.minigrid.envs.minigrid_env'],
    ),
    env_manager=dict(type='base'),
    policy=dict(type='ppo_offpolicy'),
    reward_model=dict(type='icm'),
)
minigrid_ppo_icm_create_config = EasyDict(minigrid_ppo_icm_create_config)
create_config = minigrid_ppo_icm_create_config

if __name__ == "__main__":
    serial_pipeline_reward_model([main_config, create_config], seed=0)
......@@ -43,6 +43,29 @@ MINIGRID_INFO_DICT = {
        max_step=100,
        use_wrappers=None,
    ),
    'MiniGrid-DoorKey-8x8-v0': MiniGridEnvInfo(
        agent_num=1,
        obs_space=EnvElementInfo(shape=(2739, ), value={
            'min': 0,
            'max': 8,
            'dtype': np.float32
        }),
        act_space=EnvElementInfo(
            shape=(1, ),
            value={
                'min': 0,
                'max': 7,  # [0, 7)
                'dtype': np.int64,
            }
        ),
        rew_space=EnvElementInfo(shape=(1, ), value={
            'min': 0,
            'max': 1,
            'dtype': np.float32
        }),
        max_step=300,
        use_wrappers=None,
    ),
    'MiniGrid-FourRooms-v0': MiniGridEnvInfo(
        agent_num=1,
        obs_space=EnvElementInfo(shape=(2739, ), value={
......