Unverified commit 63105fef authored by Will-Nie, committed by GitHub

feature(nyp): add Trex algorithm (#119)

* add trex algorithm for pong

* sort style

* add atari, ll,cp; fix device, collision; add_ppo

* add accuracy evaluation

* correct style

* add seed to make sure results are replicable

* remove useless part in cum return  of model part

* add mujoco onppo training pipeline; ppo config

* improve style

* add sac training config for mujoco

* add log, add save data; polish config

* logger; hyperparameter; walker

* correct style

* modify else condition

* change rnd to trex

* revise according to comments, add episode collect

* new collect mode for trex, fix all bugs, comments

* final change

* polish after the final comment

* add readme/test

* add test for serial entry of trex/gcl

* sort style
Parent 18b3720a
@@ -136,14 +136,15 @@ ding -m serial -e cartpole -p dqn -s 0
| 26 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
| 27 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/r2d3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/r2d3.py) | python3 -u pong_r2d3_r2d2expert_config.py |
| 28 | [Guided Cost Learning](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py |
| 29 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 30 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 31 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
| 32 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 33 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py |
| 34 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py |
| 35 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 36 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
| 29 | [TREX](https://arxiv.org/abs/1904.06387) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/trex](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/trex_reward_model.py) | python3 mujoco_trex_main.py |
| 30 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
| 31 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py |
| 32 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
| 33 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
| 34 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py |
| 35 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py |
| 36 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
| 37 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space, which is the only label used in normal DRL algorithms (1-18)
@@ -10,6 +10,10 @@ from .serial_entry_mbrl import serial_pipeline_mbrl
from .serial_entry_dqfd import serial_pipeline_dqfd
from .serial_entry_r2d3 import serial_pipeline_r2d3
from .serial_entry_sqil import serial_pipeline_sqil
from .serial_entry_trex import serial_pipeline_reward_model_trex
from .serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy
from .parallel_entry import parallel_pipeline
from .application_entry import eval, collect_demo_data
from .application_entry import eval, collect_demo_data, collect_episodic_demo_data, \
epsiode_to_transitions
from .application_entry_trex_collect_data import trex_collecting_data, collect_episodic_demo_data_for_trex
from .serial_entry_guided_cost import serial_pipeline_guided_cost
@@ -2,14 +2,17 @@ from typing import Union, Optional, List, Any, Tuple
import pickle
import torch
from functools import partial
import os
from ding.config import compile_config, read_config
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, EpisodeSerialCollector
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.torch_utils import to_device
from ding.utils import set_pkg_seed
from ding.utils.data import offline_data_save_type
from ding.rl_utils import get_nstep_return_data
from ding.utils.data import default_collate
def eval(
@@ -141,7 +144,7 @@ def collect_demo_data(
policy.collect_mode.load_state_dict(state_dict)
collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)
policy_kwargs = None if not hasattr(cfg.policy.other.get('eps', None), 'collect') \
policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \
else {'eps': cfg.policy.other.eps.get('collect', 0.2)}
# Let's collect some expert demonstrations
@@ -151,3 +154,96 @@ def collect_demo_data(
# Save data transitions.
offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
print('Collect demo data successfully')
def collect_episodic_demo_data(
input_cfg: Union[str, dict],
seed: int,
collect_count: int,
expert_data_path: str,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
state_dict: Optional[dict] = None,
state_dict_path: Optional[str] = None,
) -> None:
r"""
Overview:
Collect episodic demonstration data by the trained policy.
Arguments:
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
``str`` type means config file path. \
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- collect_count (:obj:`int`): The number of episodes to collect.
- expert_data_path (:obj:`str`): File path to which the expert demo data will be written.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
- state_dict_path (:obj:`str`): The absolute path of the state_dict.
"""
if isinstance(input_cfg, str):
cfg, create_cfg = read_config(input_cfg)
else:
cfg, create_cfg = input_cfg
create_cfg.policy.type += '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(
cfg,
collector=EpisodeSerialCollector,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
save_path='collect_demo_data_config.py'
)
# Create components: env, policy, collector
if env_setting is None:
env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env)
else:
env_fn, collector_env_cfg, _ = env_setting
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
collector_env.seed(seed)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval'])
collect_demo_policy = policy.collect_mode
if state_dict is None:
assert state_dict_path is not None
state_dict = torch.load(state_dict_path, map_location='cpu')
policy.collect_mode.load_state_dict(state_dict)
collector = EpisodeSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)
policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \
else {'eps': cfg.policy.other.eps.get('collect', 0.2)}
# Let's collect some expert demonstrations
exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs)
if cfg.policy.cuda:
exp_data = to_device(exp_data, 'cpu')
# Save data transitions.
offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
print('Collect episodic demo data successfully')
def epsiode_to_transitions(data_path: str, expert_data_path: str, nstep: int) -> None:
r"""
Overview:
Transform episodic data into nstep transitions.
Arguments:
- data_path (:obj:`str`): Path of the pkl file that stores the episodic data.
- expert_data_path (:obj:`str`): File path to which the processed transition data will be written.
- nstep (:obj:`int`): The number of steps n in each transition, i.e. {s_t, a_t, s_{t+n}}.
"""
with open(data_path, 'rb') as f:
_dict = pickle.load(f) # class is list; length is cfg.reward_model.collect_count
post_process_data = []
for i in range(len(_dict)):
data = get_nstep_return_data(_dict[i], nstep)
post_process_data.extend(data)
offline_data_save_type(
post_process_data,
expert_data_path,
)
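For reference, a minimal usage sketch of the two helpers added above, assuming a hypothetical config file and checkpoint path (neither is shipped with this commit): collect episodic demonstrations with a trained policy, then flatten them into n-step transitions.

```python
# Hedged usage sketch; the config file and checkpoint paths are hypothetical placeholders.
from ding.entry import collect_episodic_demo_data, epsiode_to_transitions

collect_episodic_demo_data(
    'path/to/cartpole_offppo_config.py',      # hypothetical config file path
    seed=0,
    collect_count=16,                         # number of episodes to roll out
    expert_data_path='./episodes.pkl',        # where the raw episodes are pickled
    state_dict_path='./expert_ckpt.pth.tar',  # hypothetical trained policy checkpoint
)
# Flatten the pickled episodes into n-step transitions usable by off-policy learners.
epsiode_to_transitions(data_path='./episodes.pkl', expert_data_path='./transitions.pkl', nstep=3)
```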
import argparse
import torch
from typing import Union, Optional, List, Any
from functools import partial
import os
from copy import deepcopy
from ding.config import compile_config, read_config
from ding.worker import EpisodeSerialCollector
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.torch_utils import to_device
from ding.utils import set_pkg_seed
from ding.utils.data import offline_data_save_type
from ding.utils.data import default_collate
def collect_episodic_demo_data_for_trex(
input_cfg: Union[str, dict],
seed: int,
collect_count: int,
rank: int,
save_cfg_path: str,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
state_dict: Optional[dict] = None,
state_dict_path: Optional[str] = None,
) -> None:
r"""
Overview:
Collect episodic demonstration data by the trained policy for trex specifically.
Arguments:
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
``str`` type means config file path. \
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- collect_count (:obj:`int`): The number of episodes to collect.
- rank (:obj:`int`): The rank of the collected episode (used for logging).
- save_cfg_path (:obj:`str`): Directory where the collector config will be saved.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
- state_dict_path (:obj:`str`): The absolute path of the state_dict.
"""
if isinstance(input_cfg, str):
cfg, create_cfg = read_config(input_cfg)
else:
cfg, create_cfg = input_cfg
create_cfg.policy.type += '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg.env.collector_env_num = 1
if not os.path.exists(save_cfg_path):
os.mkdir(save_cfg_path)
cfg = compile_config(
cfg,
collector=EpisodeSerialCollector,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
save_path=save_cfg_path + '/collect_demo_data_config.py'
)
# Create components: env, policy, collector
if env_setting is None:
env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env)
else:
env_fn, collector_env_cfg, _ = env_setting
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
collector_env.seed(seed)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval'])
collect_demo_policy = policy.collect_mode
if state_dict is None:
assert state_dict_path is not None
state_dict = torch.load(state_dict_path, map_location='cpu')
policy.collect_mode.load_state_dict(state_dict)
collector = EpisodeSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)
policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \
else {'eps': cfg.policy.other.eps.get('collect', 0.2)}
# Let's collect some sub-optimal demonstrations
exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs)
if cfg.policy.cuda:
exp_data = to_device(exp_data, 'cpu')
# Save data transitions.
print('Collect {}th episodic demo data successfully'.format(rank))
return exp_data
def trex_get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='abs path for a config')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args
def trex_collecting_data(args=trex_get_args()):
if isinstance(args.cfg, str):
cfg, create_cfg = read_config(args.cfg)
else:
cfg, create_cfg = deepcopy(args.cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
compiled_cfg = compile_config(cfg, seed=args.seed, auto=True, create_cfg=create_cfg, save_cfg=False)
offline_data_path = compiled_cfg.reward_model.offline_data_path
expert_model_path = compiled_cfg.reward_model.expert_model_path
checkpoint_min = compiled_cfg.reward_model.checkpoint_min
checkpoint_max = compiled_cfg.reward_model.checkpoint_max
checkpoint_step = compiled_cfg.reward_model.checkpoint_step
checkpoints = []
for i in range(checkpoint_min, checkpoint_max + checkpoint_step, checkpoint_step):
checkpoints.append(str(i))
data_for_save = {}
learning_returns = []
learning_rewards = []
episodes_data = []
for checkpoint in checkpoints:
model_path = expert_model_path + \
'/ckpt/iteration_' + checkpoint + '.pth.tar'
seed = args.seed + (int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step)
exp_data = collect_episodic_demo_data_for_trex(
deepcopy(args.cfg),
seed,
state_dict_path=model_path,
save_cfg_path=offline_data_path,
collect_count=1,
rank=(int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step) + 1
)
data_for_save[(int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step)] = exp_data[0]
obs = list(default_collate(exp_data[0])['obs'].numpy())
learning_rewards.append(default_collate(exp_data[0])['reward'].tolist())
sum_reward = torch.sum(default_collate(exp_data[0])['reward']).item()
learning_returns.append(sum_reward)
episodes_data.append(obs)
offline_data_save_type(
data_for_save,
offline_data_path + '/suboptimal_data.pkl',
data_type=cfg.policy.collect.get('data_type', 'naive')
)
# if not compiled_cfg.reward_model.auto: more feature
offline_data_save_type(
episodes_data, offline_data_path + '/episodes_data.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
)
offline_data_save_type(
learning_returns,
offline_data_path + '/learning_returns.pkl',
data_type=cfg.policy.collect.get('data_type', 'naive')
)
offline_data_save_type(
learning_rewards,
offline_data_path + '/learning_rewards.pkl',
data_type=cfg.policy.collect.get('data_type', 'naive')
)
offline_data_save_type(
checkpoints, offline_data_path + '/checkpoints.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
)
return checkpoints, episodes_data, learning_returns, learning_rewards
if __name__ == '__main__':
trex_collecting_data()
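Since this entry can also be driven programmatically, here is a hedged sketch of a typical call; the config path below is a hypothetical placeholder, and the listed `reward_model` fields are the ones the function reads from the compiled config.

```python
# Hedged sketch of calling trex_collecting_data programmatically; the config path is a placeholder.
import torch
from easydict import EasyDict
from ding.entry.application_entry_trex_collect_data import trex_collecting_data

args = EasyDict({
    'cfg': '/abs/path/to/some_trex_config.py',  # hypothetical TREX config with a reward_model section
    'seed': 0,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
})
# The config's reward_model section is expected to define, among others:
#   offline_data_path            directory where suboptimal_data.pkl / episodes_data.pkl / ... are written
#   expert_model_path            experiment dir containing ckpt/iteration_<k>.pth.tar
#   checkpoint_min / max / step  which expert checkpoints to roll out (one episode each)
checkpoints, episodes_data, learning_returns, learning_rewards = trex_collecting_data(args=args)
```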
@@ -54,8 +54,8 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
'--mode',
type=click.Choice(
[
'serial', 'serial_onpolicy', 'serial_sqil', 'serial_dqfd', 'parallel', 'dist', 'eval',
'serial_reward_model', 'serial_gail'
'serial', 'serial_onpolicy', 'serial_sqil', 'serial_dqfd', 'serial_trex', 'serial_trex_onpolicy',
'parallel', 'dist', 'eval', 'serial_reward_model', 'serial_gail'
]
),
help='serial-train or parallel-train or dist-train or eval'
@@ -182,6 +182,16 @@ def cli(
+ "the models used in q learning now; However, one should still type the DQFD config in this "\
+ "place, i.e., {}{}".format(config[:config.find('_dqfd')], '_dqfd_config.py')
serial_pipeline_dqfd(config, expert_config, seed, max_iterations=train_iter)
elif mode == 'serial_trex':
from .serial_entry_trex import serial_pipeline_reward_model_trex
if config is None:
config = get_predefined_config(env, policy)
serial_pipeline_reward_model_trex(config, seed, max_iterations=train_iter)
elif mode == 'serial_trex_onpolicy':
from .serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy
if config is None:
config = get_predefined_config(env, policy)
serial_pipeline_reward_model_trex_onpolicy(config, seed, max_iterations=train_iter)
elif mode == 'parallel':
from .parallel_entry import parallel_pipeline
parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
from typing import Union, Optional, List, Any, Tuple
import os
import torch
import logging
from functools import partial
from tensorboardX import SummaryWriter
from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed
# from dizoo.atari.config.serial.pong.pong_trex_sql_config import main_config, create_config
from dizoo.box2d.lunarlander.config.lunarlander_trex_offppo_config import main_config, create_config
def serial_pipeline_reward_model_trex(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
max_iterations: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
Serial pipeline entry of TREX reward model (off-policy).
Arguments:
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
``str`` type means config file path. \
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- max_iterations (:obj:`Optional[int]`): Learner's max iteration. Pipeline will stop \
when reaching this iteration.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
if isinstance(input_cfg, str):
cfg, create_cfg = read_config(input_cfg)
else:
cfg, create_cfg = input_cfg
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
else:
env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = create_serial_collector(
cfg.policy.collect.collector,
env=collector_env,
policy=policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name
)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
commander = BaseSerialCommander(
cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
)
reward_model = create_reward_model(cfg, policy.collect_mode.get_attribute('device'), tb_logger)
reward_model.train()
# ==========
# Main loop
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
# Accumulate plenty of data at the beginning of training.
if cfg.policy.get('random_collect_size', 0) > 0:
if cfg.policy.get('transition_with_policy_data', False):
collector.reset_policy(policy.collect_mode)
else:
action_space = collector_env.env_info().act_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
for _ in range(max_iterations):
collect_kwargs = commander.step()
# Evaluate policy performance
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Collect data by default config n_sample/n_episode
if hasattr(cfg.policy.collect, "each_iter_n_sample"): # TODO(pu)
new_data = collector.collect(
n_sample=cfg.policy.collect.each_iter_n_sample,
train_iter=learner.train_iter,
policy_kwargs=collect_kwargs
)
else:
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# Learn policy from collected data
for i in range(cfg.policy.learn.update_per_collect):
# Learner will train ``update_per_collect`` times in one iteration.
train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
if train_data is None:
# It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
logging.warning(
"Replay buffer's data can only train for {} steps. ".format(i) +
"You can modify data collect config, e.g. increasing n_sample, n_episode."
)
break
# update train_data reward
reward_model.estimate(train_data)
learner.train(train_data, collector.envstep)
if learner.policy.get_attribute('priority'):
replay_buffer.update(learner.priority_info)
# Learner's after_run hook.
learner.call_hook('after_run')
return policy
if __name__ == '__main__':
serial_pipeline_reward_model_trex([main_config, create_config])
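Putting the pieces together, the off-policy TREX workflow exercised by the new unit tests looks roughly as follows; this is a hedged sketch mirroring `test_serial_pipeline_reward_model_trex` further below, with placeholder paths rather than recommended settings.

```python
# Hedged end-to-end sketch of the off-policy TREX workflow (expert training ->
# ranked data collection -> reward model + policy training); paths are placeholders.
import os
from copy import deepcopy
from easydict import EasyDict
from ding.entry import serial_pipeline
from ding.entry.serial_entry_trex import serial_pipeline_reward_model_trex
from ding.entry.application_entry_trex_collect_data import trex_collecting_data
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import \
    cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import \
    cartpole_trex_ppo_offpolicy_config, cartpole_trex_ppo_offpolicy_create_config

# 1. Train an expert whose intermediate checkpoints will later be ranked by return.
serial_pipeline([deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)], seed=0)
# 2. Roll out the saved checkpoints and dump ranked episodes for reward learning.
config = [deepcopy(cartpole_trex_ppo_offpolicy_config), deepcopy(cartpole_trex_ppo_offpolicy_create_config)]
config[0].reward_model.offline_data_path = os.path.abspath('./cartpole_trex_offppo')    # placeholder dir
config[0].reward_model.reward_model_path = config[0].reward_model.offline_data_path + '/cartpole.params'
config[0].reward_model.expert_model_path = os.path.abspath('./cartpole_ppo_offpolicy')  # expert exp dir
trex_collecting_data(args=EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'}))
# 3. Train the TREX reward model, then the downstream policy with the learned reward.
serial_pipeline_reward_model_trex(config, seed=0)
```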
from typing import Union, Optional, List, Any, Tuple
import os
import torch
import logging
from functools import partial
from tensorboardX import SummaryWriter
from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy, PolicyFactory
from ding.reward_model import create_reward_model
from ding.utils import set_pkg_seed
from dizoo.mujoco.config.hopper_trex_onppo_default_config import main_config, create_config
def serial_pipeline_reward_model_trex_onpolicy(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
max_iterations: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
Serial pipeline entry of TREX reward model for on-policy algorithms (such as PPO).
Arguments:
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
``str`` type means config file path. \
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- max_iterations (:obj:`Optional[int]`): Learner's max iteration. Pipeline will stop \
when reaching this iteration.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
if isinstance(input_cfg, str):
cfg, create_cfg = read_config(input_cfg)
else:
cfg, create_cfg = input_cfg
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
else:
env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = create_serial_collector(
cfg.policy.collect.collector,
env=collector_env,
policy=policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name
)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
reward_model = create_reward_model(cfg, policy.collect_mode.get_attribute('device'), tb_logger)
reward_model.train()
# ==========
# Main loop
# ==========
# Learner's before_run hook.
learner.call_hook('before_run')
# Accumulate plenty of data at the beginning of training.
for _ in range(max_iterations):
# Evaluate policy performance
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Collect data by default config n_sample/n_episode
new_data = collector.collect(train_iter=learner.train_iter)
# Learn policy from collected data with modified rewards
reward_model.estimate(new_data)
learner.train(new_data, collector.envstep)
# Learner's after_run hook.
learner.call_hook('after_run')
return policy
if __name__ == '__main__':
serial_pipeline_reward_model_trex_onpolicy([main_config, create_config])
@@ -3,10 +3,14 @@ import pytest
import os
import pickle
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \
cartpole_ppo_offpolicy_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_ppo_offpolicy_config,\
cartpole_trex_ppo_offpolicy_create_config
from dizoo.classic_control.cartpole.envs import CartPoleEnv
from ding.entry import serial_pipeline, eval, collect_demo_data
from ding.config import compile_config
from ding.entry.application_entry import collect_episodic_demo_data, epsiode_to_transitions
@pytest.fixture(scope='module')
@@ -58,4 +62,28 @@ class TestApplication:
exp_data = pickle.load(f)
assert isinstance(exp_data, list)
assert isinstance(exp_data[0], dict)
def test_collect_episodic_demo_data(self, setup_state_dict):
config = deepcopy(cartpole_trex_ppo_offpolicy_config), deepcopy(cartpole_trex_ppo_offpolicy_create_config)
collect_count = 16
expert_data_path = './expert.data'
collect_episodic_demo_data(
config,
seed=0,
state_dict=setup_state_dict['collect'],
expert_data_path=expert_data_path,
collect_count=collect_count,
)
with open(expert_data_path, 'rb') as f:
exp_data = pickle.load(f)
assert isinstance(exp_data, list)
assert isinstance(exp_data[0][0], dict)
def test_epsiode_to_transitions(self):
expert_data_path = './expert.data'
epsiode_to_transitions(data_path=expert_data_path, expert_data_path=expert_data_path, nstep=3)
with open(expert_data_path, 'rb') as f:
exp_data = pickle.load(f)
assert isinstance(exp_data, list)
assert isinstance(exp_data[0], dict)
os.popen('rm -rf ./expert.data ckpt* log')
from easydict import EasyDict
import pytest
from copy import deepcopy
import os
from itertools import product
import torch
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_ppo_offpolicy_config,\
cartpole_trex_ppo_offpolicy_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config,\
cartpole_ppo_offpolicy_create_config
from ding.entry.application_entry_trex_collect_data import collect_episodic_demo_data_for_trex, trex_collecting_data
from ding.entry import serial_pipeline
@pytest.mark.unittest
def test_collect_episodic_demo_data_for_trex():
expert_policy_state_dict_path = './expert_policy.pth'
expert_policy_state_dict_path = os.path.abspath('ding/entry/expert_policy.pth')
config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
expert_policy = serial_pipeline(config, seed=0)
torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
config = deepcopy(cartpole_trex_ppo_offpolicy_config), deepcopy(cartpole_trex_ppo_offpolicy_create_config)
collect_count = 1
save_cfg_path = './cartpole_trex_offppo'
save_cfg_path = os.path.abspath(save_cfg_path)
exp_data = collect_episodic_demo_data_for_trex(
config,
seed=0,
state_dict_path=expert_policy_state_dict_path,
save_cfg_path=save_cfg_path,
collect_count=collect_count,
rank=1,
)
assert isinstance(exp_data, list)
assert isinstance(exp_data[0][0], dict)
os.popen('rm -rf {}'.format(save_cfg_path))
os.popen('rm -rf {}'.format(expert_policy_state_dict_path))
@pytest.mark.unittest
def test_trex_collecting_data():
expert_policy_state_dict_path = './cartpole_ppo_offpolicy'
expert_policy_state_dict_path = os.path.abspath(expert_policy_state_dict_path)
config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
expert_policy = serial_pipeline(config, seed=0)
args = EasyDict(
{
'cfg': [deepcopy(cartpole_trex_ppo_offpolicy_config),
deepcopy(cartpole_trex_ppo_offpolicy_create_config)],
'seed': 0,
'device': 'cpu'
}
)
args.cfg[0].reward_model.offline_data_path = 'dizoo/classic_control/cartpole/config/cartpole_trex_offppo'
args.cfg[0].reward_model.offline_data_path = os.path.abspath(args.cfg[0].reward_model.offline_data_path)
args.cfg[0].reward_model.reward_model_path = args.cfg[0].reward_model.offline_data_path + '/cartpole.params'
args.cfg[0].reward_model.expert_model_path = './cartpole_ppo_offpolicy'
args.cfg[0].reward_model.expert_model_path = os.path.abspath(args.cfg[0].reward_model.expert_model_path)
trex_collecting_data(args=args)
os.popen('rm -rf {}'.format(expert_policy_state_dict_path))
os.popen('rm -rf {}'.format(args.cfg[0].reward_model.offline_data_path))
import pytest
import torch
from copy import deepcopy
from ding.entry import serial_pipeline_onpolicy, serial_pipeline_guided_cost
from dizoo.classic_control.cartpole.config import cartpole_ppo_config, cartpole_ppo_create_config
from dizoo.classic_control.cartpole.config import cartpole_gcl_ppo_onpolicy_config, \
cartpole_gcl_ppo_onpolicy_create_config
@pytest.mark.unittest
def test_guided_cost():
expert_policy_state_dict_path = './expert_policy.pth'
config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
expert_policy = serial_pipeline_onpolicy(config, seed=0)
torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
config = [deepcopy(cartpole_gcl_ppo_onpolicy_config), deepcopy(cartpole_gcl_ppo_onpolicy_create_config)]
config[0].policy.collect.demonstration_info_path = expert_policy_state_dict_path
config[0].policy.learn.update_per_collect = 1
try:
serial_pipeline_guided_cost(config, seed=0, max_iterations=1)
except Exception:
assert False, "pipeline fail"
import pytest
from copy import deepcopy
import os
from easydict import EasyDict
import torch
from ding.entry import serial_pipeline
from ding.entry.serial_entry_trex import serial_pipeline_reward_model_trex
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_ppo_offpolicy_config,\
cartpole_trex_ppo_offpolicy_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config,\
cartpole_ppo_offpolicy_create_config
from ding.entry.application_entry_trex_collect_data import trex_collecting_data
@pytest.mark.unittest
def test_serial_pipeline_reward_model_trex():
config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
expert_policy = serial_pipeline(config, seed=0)
config = [deepcopy(cartpole_trex_ppo_offpolicy_config), deepcopy(cartpole_trex_ppo_offpolicy_create_config)]
config[0].reward_model.offline_data_path = 'dizoo/classic_control/cartpole/config/cartpole_trex_offppo'
config[0].reward_model.offline_data_path = os.path.abspath(config[0].reward_model.offline_data_path)
config[0].reward_model.reward_model_path = config[0].reward_model.offline_data_path + '/cartpole.params'
config[0].reward_model.expert_model_path = './cartpole_ppo_offpolicy'
config[0].reward_model.expert_model_path = os.path.abspath(config[0].reward_model.expert_model_path)
args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'})
trex_collecting_data(args=args)
try:
serial_pipeline_reward_model_trex(config, seed=0, max_iterations=1)
except Exception:
assert False, "pipeline fail"
import pytest
from copy import deepcopy
import os
from easydict import EasyDict
import torch
from ding.entry import serial_pipeline_onpolicy
from ding.entry.serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy
from dizoo.mujoco.config import hopper_ppo_default_config, hopper_ppo_create_default_config
from dizoo.mujoco.config import hopper_trex_ppo_default_config, hopper_trex_ppo_create_default_config
from ding.entry.application_entry_trex_collect_data import trex_collecting_data
@pytest.mark.unittest
def test_serial_pipeline_reward_model_trex():
config = [deepcopy(hopper_ppo_default_config), deepcopy(hopper_ppo_create_default_config)]
expert_policy = serial_pipeline_onpolicy(config, seed=0, max_iterations=90)
config = [deepcopy(hopper_trex_ppo_default_config), deepcopy(hopper_trex_ppo_create_default_config)]
config[0].reward_model.offline_data_path = 'dizoo/mujoco/config/hopper_trex_onppo'
config[0].reward_model.offline_data_path = os.path.abspath(config[0].reward_model.offline_data_path)
config[0].reward_model.reward_model_path = config[0].reward_model.offline_data_path + '/hopper.params'
config[0].reward_model.expert_model_path = './hopper_onppo'
config[0].reward_model.expert_model_path = os.path.abspath(config[0].reward_model.expert_model_path)
args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'})
trex_collecting_data(args=args)
try:
serial_pipeline_reward_model_trex_onpolicy(config, seed=0, max_iterations=1)
except Exception:
assert False, "pipeline fail"
@@ -4,6 +4,7 @@ from .pdeil_irl_model import PdeilRewardModel
from .gail_irl_model import GailRewardModel
from .pwil_irl_model import PwilRewardModel
from .red_irl_model import RedRewardModel
from .trex_reward_model import TrexRewardModel
# sparse reward
from .her_reward_model import HerRewardModel
# exploration
@@ -90,7 +90,10 @@ def create_reward_model(cfg: dict, device: str, tb_logger: 'SummaryWriter') -> B
cfg = copy.deepcopy(cfg)
if 'import_names' in cfg:
import_module(cfg.pop('import_names'))
reward_model_type = cfg.pop('type')
if hasattr(cfg, 'reward_model'):
reward_model_type = cfg.reward_model.pop('type')
else:
reward_model_type = cfg.pop('type')
return REWARD_MODEL_REGISTRY.build(reward_model_type, cfg, device=device, tb_logger=tb_logger)
from collections.abc import Iterable
from easydict import EasyDict
import numpy as np
import pickle
from copy import deepcopy
from typing import Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal, Independent
from torch.distributions.categorical import Categorical
from ding.utils import REWARD_MODEL_REGISTRY
from ding.model.template.q_learning import DQN
from ding.model.template.vac import VAC
from ding.model.template.qac import QAC
from ding.utils import SequenceType
from ding.model.common import FCEncoder
from ding.utils.data import offline_data_save_type
from ding.utils import build_logger
from dizoo.atari.envs.atari_wrappers import wrap_deepmind
from dizoo.mujoco.envs.mujoco_wrappers import wrap_mujoco
from .base_reward_model import BaseRewardModel
from .rnd_reward_model import collect_states
class ConvEncoder(nn.Module):
r"""
Overview:
The ``Convolution Encoder`` used in models, used to encode raw 2-dim observations.
Interfaces:
``__init__``, ``forward``
"""
def __init__(
self,
obs_shape: SequenceType,
hidden_size_list: SequenceType = [16, 16, 16, 16, 64, 1],
activation: Optional[nn.Module] = nn.LeakyReLU(),
norm_type: Optional[str] = None
) -> None:
r"""
Overview:
Init the Convolution Encoder according to arguments.
Arguments:
- obs_shape (:obj:`SequenceType`): Sequence of ``in_channel`` and input spatial sizes, e.g. ``[4, 84, 84]``
- hidden_size_list (:obj:`SequenceType`): The collection of ``hidden_size``
- activation (:obj:`nn.Module`):
The type of activation to use in the conv ``layers``,
if ``None`` then default set to ``nn.ReLU()``
- norm_type (:obj:`str`):
The type of normalization to use, see ``ding.torch_utils.ResBlock`` for more details
"""
super(ConvEncoder, self).__init__()
self.obs_shape = obs_shape
self.act = activation
self.hidden_size_list = hidden_size_list
layers = []
kernel_size = [7, 5, 3, 3]
stride = [3, 2, 1, 1]
input_size = obs_shape[0] # in_channel
for i in range(len(kernel_size)):
layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i]))
layers.append(self.act)
input_size = hidden_size_list[i]
layers.append(nn.Flatten())
self.main = nn.Sequential(*layers)
flatten_size = self._get_flatten_size()
self.mid = nn.Sequential(
nn.Linear(flatten_size, hidden_size_list[-2]), self.act,
nn.Linear(hidden_size_list[-2], hidden_size_list[-1])
)
def _get_flatten_size(self) -> int:
r"""
Overview:
Get the encoding size after ``self.main``, i.e. the number of ``in_features`` to feed to ``nn.Linear``.
Returns:
- outputs (:obj:`int`): The flatten size, i.e. the number of in_features.
"""
test_data = torch.randn(1, *self.obs_shape)
with torch.no_grad():
output = self.main(test_data)
return output.shape[1]
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
Overview:
Return embedding tensor of the env observation
Arguments:
- x (:obj:`torch.Tensor`): Env raw observation
Returns:
- outputs (:obj:`torch.Tensor`): Embedding tensor
"""
x = self.main(x)
x = self.mid(x)
return x
class TrexModel(nn.Module):
def __init__(self, obs_shape):
super(TrexModel, self).__init__()
if isinstance(obs_shape, int) or len(obs_shape) == 1:
self.encoder = nn.Sequential(FCEncoder(obs_shape, [512, 64]), nn.Linear(64, 1))
# Conv Encoder
elif len(obs_shape) == 3:
self.encoder = ConvEncoder(obs_shape)
else:
raise KeyError(
"not support obs_shape for pre-defined encoder: {}, please customize your own Trex model".
format(obs_shape)
)
def cum_return(self, traj: torch.Tensor, mode: str = 'sum') -> Tuple[torch.Tensor, torch.Tensor]:
'''calculate cumulative return of trajectory'''
r = self.encoder(traj)
if mode == 'sum':
sum_rewards = torch.sum(r)
sum_abs_rewards = torch.sum(torch.abs(r))
# print(sum_rewards)
# print(r)
return sum_rewards, sum_abs_rewards
elif mode == 'batch':
# print(r)
return r, torch.abs(r)
else:
raise KeyError("not support mode: {}, please choose mode=sum or mode=batch".format(mode))
def forward(self, traj_i: torch.Tensor, traj_j: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
'''compute cumulative return for each trajectory and return logits'''
cum_r_i, abs_r_i = self.cum_return(traj_i)
cum_r_j, abs_r_j = self.cum_return(traj_j)
return torch.cat((cum_r_i.unsqueeze(0), cum_r_j.unsqueeze(0)), 0), abs_r_i + abs_r_j
@REWARD_MODEL_REGISTRY.register('trex')
class TrexRewardModel(BaseRewardModel):
"""
Overview:
The Trex reward model class (https://arxiv.org/pdf/1904.06387.pdf)
Interface:
``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
``__init__``, ``_train``,
"""
config = dict(
type='trex',
learning_rate=1e-5,
update_per_collect=100,
batch_size=64,
target_new_data_count=64,
hidden_size=128,
)
def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa
"""
Overview:
Initialize ``self.`` See ``help(type(self))`` for accurate signature.
Arguments:
- config (:obj:`EasyDict`): Training config
- device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
- tb_logger (:obj:`SummaryWriter`): Logger, by default set as ``SummaryWriter`` for model summary
"""
super(TrexRewardModel, self).__init__()
self.cfg = config
assert device in ["cpu", "cuda"] or "cuda" in device
self.device = device
self.tb_logger = tb_logger
self.reward_model = TrexModel(self.cfg.policy.model.get('obs_shape'))
self.reward_model.to(self.device)
self.pre_expert_data = []
self.train_data = []
self.expert_data_loader = None
self.opt = optim.Adam(self.reward_model.parameters(), config.reward_model.learning_rate)
self.train_iter = 0
self.learning_returns = []
self.learning_rewards = []
self.training_obs = []
self.training_labels = []
self.num_trajs = 0 # number of downsampled full trajectories
self.num_snippets = 6000 # number of short subtrajectories to sample
# minimum length of short subtrajectories to sample
self.min_snippet_length = config.reward_model.min_snippet_length
# maximum length of short subtrajectories to sample
self.max_snippet_length = config.reward_model.max_snippet_length
self.l1_reg = 0
self.data_for_save = {}
self._logger, self._tb_logger = build_logger(
path='./{}/log/{}'.format(self.cfg.exp_name, 'trex_reward_model'), name='trex_reward_model'
)
self.load_expert_data()
def load_expert_data(self) -> None:
"""
Overview:
Load the expert demonstrations, returns, rewards and checkpoints from ``config.reward_model.offline_data_path``
Effects:
This is a side effect function which updates the expert data attributes \
(e.g. ``self.pre_expert_data``, ``self.learning_returns``) and builds the training pairs
"""
with open(self.cfg.reward_model.offline_data_path + '/episodes_data.pkl', 'rb') as f:
self.pre_expert_data = pickle.load(f)
with open(self.cfg.reward_model.offline_data_path + '/learning_returns.pkl', 'rb') as f:
self.learning_returns = pickle.load(f)
with open(self.cfg.reward_model.offline_data_path + '/learning_rewards.pkl', 'rb') as f:
self.learning_rewards = pickle.load(f)
with open(self.cfg.reward_model.offline_data_path + '/checkpoints.pkl', 'rb') as f:
self.checkpoints = pickle.load(f)
self.training_obs, self.training_labels = self.create_training_data()
self._logger.info("num_training_obs: {}".format(len(self.training_obs)))
self._logger.info("num_labels: {}".format(len(self.training_labels)))
def create_training_data(self):
demonstrations = self.pre_expert_data
num_trajs = self.num_trajs
num_snippets = self.num_snippets
min_snippet_length = self.min_snippet_length
max_snippet_length = self.max_snippet_length
demo_lengths = [len(d) for d in demonstrations]
self._logger.info("demo_lengths: {}".format(demo_lengths))
max_snippet_length = min(np.min(demo_lengths), max_snippet_length)
self._logger.info("min snippet length: {}".format(min_snippet_length))
self._logger.info("max snippet length: {}".format(max_snippet_length))
self._logger.info(len(self.learning_returns))
self._logger.info(len(demonstrations))
self._logger.info("learning returns: {}".format([a[0] for a in zip(self.learning_returns, demonstrations)]))
demonstrations = [x for _, x in sorted(zip(self.learning_returns, demonstrations), key=lambda pair: pair[0])]
sorted_returns = sorted(self.learning_returns)
self._logger.info("sorted learning returns: {}".format(sorted_returns))
#collect training data
max_traj_length = 0
num_demos = len(demonstrations)
#add full trajs (for use on Enduro)
si = np.random.randint(6, size=num_trajs)
sj = np.random.randint(6, size=num_trajs)
step = np.random.randint(3, 7, size=num_trajs)
for n in range(num_trajs):
ti = 0
tj = 0
#only add trajectories that have different returns
while (ti == tj):
#pick two random demonstrations
ti = np.random.randint(num_demos)
tj = np.random.randint(num_demos)
#create random partial trajs by finding random start frame and random skip frame
traj_i = demonstrations[ti][si[n]::step[n]] # slice(start,stop,step)
traj_j = demonstrations[tj][sj[n]::step[n]]
label = int(ti <= tj)
self.training_obs.append((traj_i, traj_j))
self.training_labels.append(label)
max_traj_length = max(max_traj_length, len(traj_i), len(traj_j))
#fixed size snippets with progress prior
rand_length = np.random.randint(min_snippet_length, max_snippet_length, size=num_snippets)
for n in range(num_snippets):
ti = 0
tj = 0
#only add trajectories that have different returns
while (ti == tj):
#pick two random demonstrations
ti = np.random.randint(num_demos)
tj = np.random.randint(num_demos)
#create random snippets
#find min length of both demos to ensure we can pick a snippet start no earlier
#than the one chosen in the less preferred demo
min_length = min(len(demonstrations[ti]), len(demonstrations[tj]))
if ti < tj: # pick tj snippet to be later than ti
ti_start = np.random.randint(min_length - rand_length[n] + 1)
# print(ti_start, len(demonstrations[tj]))
tj_start = np.random.randint(ti_start, len(demonstrations[tj]) - rand_length[n] + 1)
else: # ti is better so pick later snippet in ti
tj_start = np.random.randint(min_length - rand_length[n] + 1)
# print(tj_start, len(demonstrations[ti]))
ti_start = np.random.randint(tj_start, len(demonstrations[ti]) - rand_length[n] + 1)
traj_i = demonstrations[ti][ti_start:ti_start + rand_length[n]:2
]  # skip every other framestack to reduce size
traj_j = demonstrations[tj][tj_start:tj_start + rand_length[n]:2]
max_traj_length = max(max_traj_length, len(traj_i), len(traj_j))
label = int(ti <= tj)
self.training_obs.append((traj_i, traj_j))
self.training_labels.append(label)
self._logger.info(("maximum traj length: {}".format(max_traj_length)))
return self.training_obs, self.training_labels
def train(self):
# check if gpu available
device = self.device # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assume that we are on a CUDA machine, then this should print a CUDA device:
self._logger.info("device: {}".format(device))
training_inputs, training_outputs = self.training_obs, self.training_labels
loss_criterion = nn.CrossEntropyLoss()
cum_loss = 0.0
training_data = list(zip(training_inputs, training_outputs))
for epoch in range(self.cfg.reward_model.update_per_collect): # todo
np.random.shuffle(training_data)
training_obs, training_labels = zip(*training_data)
for i in range(len(training_labels)):
# traj_i and traj_j have the same length within a pair; however, the length changes as i increases
traj_i, traj_j = training_obs[i] # traj_i is a list of array generated by env.step
traj_i = np.array(traj_i)
traj_j = np.array(traj_j)
traj_i = torch.from_numpy(traj_i).float().to(device)
traj_j = torch.from_numpy(traj_j).float().to(device)
# training_labels[i] is a boolean integer: 0 or 1
labels = torch.tensor([training_labels[i]]).to(device)
# forward + backward + zero out gradient + optimize
outputs, abs_rewards = self.reward_model.forward(traj_i, traj_j)
outputs = outputs.unsqueeze(0)
loss = loss_criterion(outputs, labels) + self.l1_reg * abs_rewards
self.opt.zero_grad()
loss.backward()
self.opt.step()
# print stats to see if learning
item_loss = loss.item()
cum_loss += item_loss
if i % 100 == 99:
# print(i)
self._logger.info("epoch {}:{} loss {}".format(epoch, i, cum_loss))
self._logger.info("abs_returns: {}".format(abs_rewards))
cum_loss = 0.0
self._logger.info("check pointing")
torch.save(self.reward_model.state_dict(), self.cfg.reward_model.reward_model_path)
torch.save(self.reward_model.state_dict(), self.cfg.reward_model.reward_model_path)
self._logger.info("finished training")
# print out predicted cumulative returns and actual returns
sorted_returns = sorted(self.learning_returns)
with torch.no_grad():
pred_returns = [self.predict_traj_return(self.reward_model, traj) for traj in self.pre_expert_data]
for i, p in enumerate(pred_returns):
self._logger.info("{} {} {}".format(i, p, sorted_returns[i]))
info = {
#"demo_length": [len(d) for d in self.pre_expert_data],
#"min_snippet_length": self.min_snippet_length,
#"max_snippet_length": min(np.min([len(d) for d in self.pre_expert_data]), self.max_snippet_length),
#"len_num_training_obs": len(self.training_obs),
#"lem_num_labels": len(self.training_labels),
"accuracy": self.calc_accuracy(self.reward_model, self.training_obs, self.training_labels),
}
self._logger.info(
"accuracy and comparison:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))
)
def predict_traj_return(self, net, traj):
device = self.device
# torch.set_printoptions(precision=20)
# torch.use_deterministic_algorithms(True)
with torch.no_grad():
rewards_from_obs = net.cum_return(
torch.from_numpy(np.array(traj)).float().to(device), mode='batch'
)[0].squeeze().tolist()
# rewards_from_obs1 = net.cum_return(torch.from_numpy(np.array([traj[0]])).float().to(device))[0].item()
# different precision
return sum(rewards_from_obs) # rewards_from_obs is a list of floats
def calc_accuracy(self, reward_network, training_inputs, training_outputs):
device = self.device
loss_criterion = nn.CrossEntropyLoss()
num_correct = 0.
with torch.no_grad():
for i in range(len(training_inputs)):
label = training_outputs[i]
traj_i, traj_j = training_inputs[i]
traj_i = np.array(traj_i)
traj_j = np.array(traj_j)
traj_i = torch.from_numpy(traj_i).float().to(device)
traj_j = torch.from_numpy(traj_j).float().to(device)
#forward to get logits
outputs, abs_return = reward_network.forward(traj_i, traj_j)
_, pred_label = torch.max(outputs, 0)
if pred_label.item() == label:
num_correct += 1.
return num_correct / len(training_inputs)
def estimate(self, data: list) -> None:
"""
Overview:
Estimate reward by rewriting the reward key in each row of the data.
Arguments:
- data (:obj:`list`): the list of data used for estimation, with at least \
``obs`` and ``action`` keys.
Effects:
- This is a side effect function which updates the reward values in place.
"""
res = collect_states(data)
res = torch.stack(res).to(self.device)
with torch.no_grad():
sum_rewards, sum_abs_rewards = self.reward_model.cum_return(res, mode='batch')
for item, rew in zip(data, sum_rewards): # TODO optimise this loop as well ?
item['reward'] = rew
def collect_data(self, data: list) -> None:
"""
Overview:
Collecting training data formatted by ``fn:concat_state_action_pairs``.
Arguments:
- data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc)
Effects:
- This is a side effect function which updates the data attribute in ``self``
"""
pass
def clear_data(self) -> None:
"""
Overview:
Clearing training data. \
This is a side effect function which clears the data attribute in ``self``
"""
self.training_obs.clear()
self.training_labels.clear()
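For reference, a hedged sketch of the pairwise ranking objective that ``train`` above optimizes, using the ``TrexModel`` defined in this file; the observation shape and trajectory lengths are arbitrary illustration values, not taken from any config.

```python
# Hedged sketch of the TREX pairwise ranking loss computed from TrexModel's logits;
# obs_shape and trajectory lengths are made up for illustration.
import torch
import torch.nn as nn
from ding.reward_model.trex_reward_model import TrexModel

model = TrexModel(obs_shape=4)      # 1-dim observations -> FC encoder branch
traj_i = torch.randn(20, 4)         # less preferred trajectory (20 steps)
traj_j = torch.randn(30, 4)         # preferred trajectory (30 steps)
logits, abs_rewards = model(traj_i, traj_j)  # logits: predicted returns of (traj_i, traj_j)
label = torch.tensor([1])           # index of the preferred trajectory, here traj_j
# Cross-entropy over the two predicted returns; train() adds an optional l1 term on abs_rewards.
loss = nn.CrossEntropyLoss()(logits.unsqueeze(0), label)
loss.backward()
```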
from copy import deepcopy
from easydict import EasyDict
pong_trex_ppo_config = dict(
exp_name='pong_trex_offppo',
env=dict(
collector_env_num=16,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=20,
env_id='PongNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
reward_model=dict(
type='trex',
algo_for_model='ppo',
env_id='PongNoFrameskip-v4',
min_snippet_length=50,
max_snippet_length=100,
checkpoint_min=0,
checkpoint_max=100,
checkpoint_step=100,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./pong.params',
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
random_collect_size=2048,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
critic_head_layer_num=1, # Todo, to solve generality problem
),
learn=dict(
update_per_collect=24,
batch_size=128,
# (bool) Whether to normalize advantage. Default to False.
adv_norm=False,
learning_rate=0.0002,
# (float) loss weight of the value network, the weight of policy network is set to 1
value_weight=0.5,
# (float) loss weight of the entropy regularization, the weight of policy network is set to 1
entropy_weight=0.01,
clip_ratio=0.1,
),
collect=dict(
# (int) collect n_sample data, train model n_iteration times
n_sample=1024,
# (float) the trade-off factor lambda to balance 1step td and mc
gae_lambda=0.95,
discount_factor=0.99,
),
eval=dict(evaluator=dict(eval_freq=1000, )),
other=dict(replay_buffer=dict(
replay_buffer_size=100000,
max_use=5,
), ),
),
)
main_config = EasyDict(pong_trex_ppo_config)
pong_trex_ppo_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
# env_manager=dict(type='subprocess'),
env_manager=dict(type='base'),
policy=dict(type='ppo_offpolicy'),
)
create_config = EasyDict(pong_trex_ppo_create_config)
from copy import deepcopy
from easydict import EasyDict
pong_trex_sql_config = dict(
exp_name='pong_trex_sql',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=20,
env_id='PongNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
reward_model=dict(
type='trex',
algo_for_model='sql',
env_id='PongNoFrameskip-v4',
min_snippet_length=50,
max_snippet_length=100,
checkpoint_min=10000,
checkpoint_max=50000,
checkpoint_step=10000,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./pong.params',
offline_data_path='abs data path',
),
policy=dict(
cuda=False,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=1,
discount_factor=0.99,
learn=dict(update_per_collect=10, batch_size=32, learning_rate=0.0001, target_update_freq=500, alpha=0.12),
collect=dict(n_sample=96, ),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=250000,
),
replay_buffer=dict(replay_buffer_size=100000, ),
),
),
)
pong_trex_sql_config = EasyDict(pong_trex_sql_config)
main_config = pong_trex_sql_config
pong_trex_sql_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='base', force_reproducibility=True),
policy=dict(type='sql'),
)
pong_trex_sql_create_config = EasyDict(pong_trex_sql_create_config)
create_config = pong_trex_sql_create_config
from copy import deepcopy
from easydict import EasyDict
qbert_trex_dqn_config = dict(
exp_name='qbert_trex_dqn',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=30000,
env_id='QbertNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
reward_model=dict(
type='trex',
algo_for_model='dqn',
env_id='QbertNoFrameskip-v4',
min_snippet_length=30,
max_snippet_length=100,
checkpoint_min=0,
checkpoint_max=100,
checkpoint_step=100,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./qbert.params',
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=1,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
),
collect=dict(n_sample=100, ),
eval=dict(evaluator=dict(eval_freq=4000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=1000000,
),
replay_buffer=dict(replay_buffer_size=400000, ),
),
),
)
qbert_trex_dqn_config = EasyDict(qbert_trex_dqn_config)
main_config = qbert_trex_dqn_config
qbert_trex_dqn_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='base'),
policy=dict(type='dqn'),
)
qbert_trex_dqn_create_config = EasyDict(qbert_trex_dqn_create_config)
create_config = qbert_trex_dqn_create_config
from copy import deepcopy
from easydict import EasyDict
qbert_trex_ppo_config = dict(
exp_name='qbert_trex_offppo',
env=dict(
collector_env_num=16,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=10000000000,
env_id='QbertNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
reward_model=dict(
type='trex',
algo_for_model='ppo',
env_id='QbertNoFrameskip-v4',
min_snippet_length=30,
max_snippet_length=100,
checkpoint_min=0,
checkpoint_max=100,
checkpoint_step=100,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./qbert.params',
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[32, 64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
critic_head_layer_num=2,
),
learn=dict(
update_per_collect=24,
batch_size=128,
# (bool) Whether to normalize advantage. Default to False.
adv_norm=False,
learning_rate=0.0001,
# (float) loss weight of the value network, the weight of policy network is set to 1
value_weight=1.0,
# (float) loss weight of the entropy regularization, the weight of policy network is set to 1
entropy_weight=0.03,
clip_ratio=0.1,
),
collect=dict(
# (int) collect n_sample data, train model n_iteration times
n_sample=1024,
# (float) the trade-off factor lambda to balance 1-step TD and MC return
gae_lambda=0.95,
discount_factor=0.99,
),
eval=dict(evaluator=dict(eval_freq=1000, )),
other=dict(replay_buffer=dict(
replay_buffer_size=100000,
max_use=3,
), ),
),
)
main_config = EasyDict(qbert_trex_ppo_config)
qbert_trex_ppo_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='base'),
policy=dict(type='ppo_offpolicy'),
)
create_config = EasyDict(qbert_trex_ppo_create_config)
from copy import deepcopy
from easydict import EasyDict
space_invaders_trex_dqn_config = dict(
exp_name='space_invaders_trex_dqn',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=10000000000,
env_id='SpaceInvadersNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
reward_model=dict(
type='trex',
algo_for_model='dqn',
env_id='SpaceInvadersNoFrameskip-v4',
min_snippet_length=30,
max_snippet_length=100,
checkpoint_min=0,
checkpoint_max=100,
checkpoint_step=100,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./spaceinvaders.params',
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
priority=False,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[128, 128, 512],
),
nstep=1,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
),
collect=dict(n_sample=100, ),
eval=dict(evaluator=dict(eval_freq=100, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=1000000,
),
replay_buffer=dict(replay_buffer_size=400000, ),
),
),
)
space_invaders_trex_dqn_config = EasyDict(space_invaders_trex_dqn_config)
main_config = space_invaders_trex_dqn_config
space_invaders_trex_dqn_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='base'),
policy=dict(type='dqn'),
)
space_invaders_trex_dqn_create_config = EasyDict(space_invaders_trex_dqn_create_config)
create_config = space_invaders_trex_dqn_create_config
from copy import deepcopy
from easydict import EasyDict
space_invaders_trex_ppo_config = dict(
exp_name='space_invaders_trex_offppo',
env=dict(
collector_env_num=16,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=10000000000,
env_id='SpaceInvadersNoFrameskip-v4',
frame_stack=4,
manager=dict(shared_memory=False, )
),
reward_model=dict(
type='trex',
algo_for_model='ppo',
env_id='SpaceInvadersNoFrameskip-v4',
min_snippet_length=30,
max_snippet_length=100,
checkpoint_min=0,
checkpoint_max=100,
checkpoint_step=100,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./spaceinvaders.params',
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
encoder_hidden_size_list=[32, 64, 64, 128],
actor_head_hidden_size=128,
critic_head_hidden_size=128,
critic_head_layer_num=2,
),
learn=dict(
update_per_collect=24,
batch_size=128,
# (bool) Whether to normalize advantage. Default to False.
adv_norm=False,
learning_rate=0.0001,
# (float) loss weight of the value network, the weight of policy network is set to 1
value_weight=1.0,
# (float) loss weight of the entropy regularization, the weight of policy network is set to 1
entropy_weight=0.03,
clip_ratio=0.1,
),
collect=dict(
# (int) collect n_sample data, train model n_iteration times
n_sample=1024,
# (float) the trade-off factor lambda to balance 1-step TD and MC return
gae_lambda=0.95,
discount_factor=0.99,
),
eval=dict(evaluator=dict(eval_freq=1000, )),
other=dict(replay_buffer=dict(
replay_buffer_size=100000,
max_use=5,
), ),
),
)
main_config = EasyDict(space_invaders_trex_ppo_config)
space_invaders_trex_ppo_create_config = dict(
env=dict(
type='atari',
import_names=['dizoo.atari.envs.atari_env'],
),
env_manager=dict(type='base'),
policy=dict(type='ppo_offpolicy'),
)
create_config = EasyDict(space_invaders_trex_ppo_create_config)
......@@ -2,3 +2,5 @@ from .lunarlander_dqn_config import lunarlander_dqn_default_config, lunarlander_
from .lunarlander_dqn_gail_config import lunarlander_dqn_gail_create_config, lunarlander_dqn_gail_default_config
from .lunarlander_ppo_config import lunarlander_ppo_config
from .lunarlander_qrdqn_config import lunarlander_qrdqn_config, lunarlander_qrdqn_create_config
from .lunarlander_trex_dqn_config import lunarlander_trex_dqn_default_config, lunarlander_trex_dqn_create_config
from .lunarlander_trex_offppo_config import lunarlander_trex_ppo_config, lunarlander_trex_ppo_create_config
......@@ -2,7 +2,9 @@ from easydict import EasyDict
from ding.entry import serial_pipeline
lunarlander_ppo_config = dict(
exp_name='lunarlander_ppo',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
......
from easydict import EasyDict
nstep = 1
lunarlander_trex_dqn_default_config = dict(
exp_name='lunarlander_trex_dqn',
env=dict(
# Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
manager=dict(shared_memory=True, force_reproducibility=True),
# Env number respectively for collector and evaluator.
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=200,
),
reward_model=dict(
type='trex',
algo_for_model='dqn',
env_id='LunarLander-v2',
min_snippet_length=30,
max_snippet_length=100,
checkpoint_min=100,
checkpoint_max=900,
checkpoint_step=100,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./lunarlander.params',
offline_data_path='abs data path',
),
policy=dict(
# Whether to use cuda for network.
cuda=False,
model=dict(
obs_shape=8,
action_shape=4,
encoder_hidden_size_list=[512, 64],
# Whether to use dueling head.
dueling=True,
),
# Reward's future discount factor, aka. gamma.
discount_factor=0.99,
# How many steps in td error.
nstep=nstep,
# learn_mode config
learn=dict(
update_per_collect=10,
batch_size=64,
learning_rate=0.001,
# Frequency of target network update.
target_update_freq=100,
),
# collect_mode config
collect=dict(
# You can use either "n_sample" or "n_episode" in collector.collect.
# Get "n_sample" samples per collect.
n_sample=64,
# Cut trajectories into pieces with length "unroll_len".
unroll_len=1,
),
# command_mode config
other=dict(
# Epsilon greedy with decay.
eps=dict(
# Decay type. Support ['exp', 'linear'].
type='exp',
start=0.95,
end=0.1,
decay=50000,
),
replay_buffer=dict(replay_buffer_size=100000, )
),
),
)
lunarlander_trex_dqn_default_config = EasyDict(lunarlander_trex_dqn_default_config)
main_config = lunarlander_trex_dqn_default_config
lunarlander_trex_dqn_create_config = dict(
env=dict(
type='lunarlander',
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqn'),
)
lunarlander_trex_dqn_create_config = EasyDict(lunarlander_trex_dqn_create_config)
create_config = lunarlander_trex_dqn_create_config
from easydict import EasyDict
lunarlander_trex_ppo_config = dict(
exp_name='lunarlander_trex_offppo',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=200,
),
reward_model=dict(
type='trex',
algo_for_model='ppo',
env_id='LunarLander-v2',
min_snippet_length=30,
max_snippet_length=100,
checkpoint_min=1000,
checkpoint_max=9000,
checkpoint_step=1000,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./lunarlander.params',
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
model=dict(
obs_shape=8,
action_shape=4,
),
learn=dict(
update_per_collect=4,
batch_size=64,
learning_rate=0.001,
value_weight=0.5,
entropy_weight=0.01,
clip_ratio=0.2,
nstep=1,
nstep_return=False,
adv_norm=True,
),
collect=dict(
n_sample=128,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.95,
),
),
)
lunarlander_trex_ppo_config = EasyDict(lunarlander_trex_ppo_config)
main_config = lunarlander_trex_ppo_config
lunarlander_trex_ppo_create_config = dict(
env=dict(
type='lunarlander',
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='ppo_offpolicy'),
)
lunarlander_trex_ppo_create_config = EasyDict(lunarlander_trex_ppo_create_config)
create_config = lunarlander_trex_ppo_create_config
......@@ -15,4 +15,7 @@ from .cartpole_dqfd_config import cartpole_dqfd_config, cartpole_dqfd_create_con
from .cartpole_sqil_config import cartpole_sqil_config, cartpole_sqil_create_config
from .cartpole_sql_config import cartpole_sql_config, cartpole_sql_create_config
from .cartpole_dqn_gail_config import cartpole_dqn_gail_config, cartpole_dqn_gail_create_config
from .cartpole_gcl_config import cartpole_gcl_ppo_onpolicy_config, cartpole_gcl_ppo_onpolicy_create_config
from .cartpole_trex_dqn_config import cartpole_trex_dqn_config, cartpole_trex_dqn_create_config
from .cartpole_trex_offppo_config import cartpole_trex_ppo_offpolicy_config, cartpole_trex_ppo_offpolicy_create_config
# from .cartpole_ppo_default_loader import cartpole_ppo_default_loader
from easydict import EasyDict
from ding.entry import serial_pipeline_guided_cost
cartpole_ppo_offpolicy_config = dict(
cartpole_gcl_ppo_onpolicy_config = dict(
exp_name='cartpole_guided_cost',
env=dict(
collector_env_num=8,
......@@ -52,9 +51,9 @@ cartpole_ppo_offpolicy_config = dict(
),
),
)
cartpole_ppo_offpolicy_config = EasyDict(cartpole_ppo_offpolicy_config)
main_config = cartpole_ppo_offpolicy_config
cartpole_ppo_offpolicy_create_config = dict(
cartpole_gcl_ppo_onpolicy_config = EasyDict(cartpole_gcl_ppo_onpolicy_config)
main_config = cartpole_gcl_ppo_onpolicy_config
cartpole_gcl_ppo_onpolicy_create_config = dict(
env=dict(
type='cartpole',
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
......@@ -63,8 +62,5 @@ cartpole_ppo_offpolicy_create_config = dict(
policy=dict(type='ppo'),
reward_model=dict(type='guided_cost'),
)
cartpole_ppo_offpolicy_create_config = EasyDict(cartpole_ppo_offpolicy_create_config)
create_config = cartpole_ppo_offpolicy_create_config
if __name__ == "__main__":
serial_pipeline_guided_cost([main_config, create_config], seed=0)
cartpole_gcl_ppo_onpolicy_create_config = EasyDict(cartpole_gcl_ppo_onpolicy_create_config)
create_config = cartpole_gcl_ppo_onpolicy_create_config
from easydict import EasyDict
cartpole_trex_dqn_config = dict(
exp_name='cartpole_trex_dqn',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=195,
replay_path='cartpole_dqn/video',
),
reward_model=dict(
type='trex',
algo_for_model='dqn',
env_id='CartPole-v0',
min_snippet_length=5,
max_snippet_length=100,
checkpoint_min=0,
checkpoint_max=500,
checkpoint_step=100,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./cartpole.params',
offline_data_path='abs data path',
),
policy=dict(
load_path='',
cuda=False,
model=dict(
obs_shape=4,
action_shape=2,
encoder_hidden_size_list=[128, 128, 64],
dueling=True,
),
nstep=1,
discount_factor=0.97,
learn=dict(
batch_size=64,
learning_rate=0.001,
),
collect=dict(n_sample=8),
eval=dict(evaluator=dict(eval_freq=40, )),
other=dict(
eps=dict(
type='exp',
start=0.95,
end=0.1,
decay=10000,
),
replay_buffer=dict(replay_buffer_size=20000, ),
),
),
)
cartpole_trex_dqn_config = EasyDict(cartpole_trex_dqn_config)
main_config = cartpole_trex_dqn_config
cartpole_trex_dqn_create_config = dict(
env=dict(
type='cartpole',
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqn'),
)
cartpole_trex_dqn_create_config = EasyDict(cartpole_trex_dqn_create_config)
create_config = cartpole_trex_dqn_create_config
from easydict import EasyDict
cartpole_trex_ppo_offpolicy_config = dict(
exp_name='cartpole_trex_offppo',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=195,
),
reward_model=dict(
type='trex',
algo_for_model='ppo',
env_id='CartPole-v0',
min_snippet_length=5,
max_snippet_length=100,
checkpoint_min=0,
checkpoint_max=1000,
checkpoint_step=1000,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./cartpole.params',
offline_data_path='abs data path',
),
policy=dict(
cuda=False,
model=dict(
obs_shape=4,
action_shape=2,
encoder_hidden_size_list=[64, 64, 128],
critic_head_hidden_size=128,
actor_head_hidden_size=128,
critic_head_layer_num=1,
),
learn=dict(
update_per_collect=6,
batch_size=64,
learning_rate=0.001,
value_weight=0.5,
entropy_weight=0.01,
clip_ratio=0.2,
),
collect=dict(
n_sample=128,
unroll_len=1,
discount_factor=0.9,
gae_lambda=0.95,
),
other=dict(replay_buffer=dict(replay_buffer_size=5000))
),
)
cartpole_trex_ppo_offpolicy_config = EasyDict(cartpole_trex_ppo_offpolicy_config)
main_config = cartpole_trex_ppo_offpolicy_config
cartpole_trex_ppo_offpolicy_create_config = dict(
env=dict(
type='cartpole',
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='ppo_offpolicy'),
)
cartpole_trex_ppo_offpolicy_create_config = EasyDict(cartpole_trex_ppo_offpolicy_create_config)
create_config = cartpole_trex_ppo_offpolicy_create_config
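In all of these configs, ``checkpoint_min`` / ``checkpoint_max`` / ``checkpoint_step`` select which saved expert checkpoints supply the ranked demonstrations (earlier checkpoints are assumed weaker than later ones). A hedged sketch of enumerating such a range; the file-name pattern is an assumption, not necessarily the repository's convention:
import os

def enumerate_expert_checkpoints(expert_model_path, checkpoint_min, checkpoint_max, checkpoint_step):
    # Yield checkpoint paths from the earliest (weakest) to the latest (strongest) saved iteration.
    for iteration in range(checkpoint_min, checkpoint_max + 1, checkpoint_step):
        yield os.path.join(expert_model_path, 'iteration_{}.pth.tar'.format(iteration))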
from .ant_ddpg_default_config import ant_ddpg_default_config
from .ant_sac_default_config import ant_sac_default_config
from .ant_td3_default_config import ant_td3_default_config
from .ant_trex_onppo_default_config import ant_trex_ppo_default_config, ant_trex_ppo_create_default_config
from .ant_trex_sac_default_config import ant_trex_sac_default_config, ant_trex_sac_default_create_config
from .halfcheetah_ddpg_default_config import halfcheetah_ddpg_default_config
from .halfcheetah_sac_default_config import halfcheetah_sac_default_config
from .halfcheetah_td3_default_config import halfcheetah_td3_default_config
from .halfcheetah_trex_onppo_default_config import halfCheetah_trex_ppo_default_config, halfCheetah_trex_ppo_create_default_config
from .halfcheetah_trex_sac_default_config import halfcheetah_trex_sac_default_config, halfcheetah_trex_sac_default_create_config
from .halfcheetah_onppo_default_config import halfcheetah_ppo_default_config, halfcheetah_ppo_create_default_config
from .hopper_onppo_default_config import hopper_ppo_default_config, hopper_ppo_create_default_config
from .hopper_ddpg_default_config import hopper_ddpg_default_config
from .hopper_sac_default_config import hopper_sac_default_config
from .hopper_td3_default_config import hopper_td3_default_config
from .hopper_trex_onppo_default_config import hopper_trex_ppo_default_config, hopper_trex_ppo_create_default_config
from .hopper_trex_sac_default_config import hopper_trex_sac_default_config, hopper_trex_sac_default_create_config
from .walker2d_ddpg_default_config import walker2d_ddpg_default_config, walker2d_ddpg_default_create_config
from .walker2d_sac_default_config import walker2d_sac_default_config
from .walker2d_td3_default_config import walker2d_td3_default_config
from .walker2d_ddpg_gail_config import walker2d_ddpg_gail_default_config, walker2d_ddpg_gail_default_create_config
from .walker2d_trex_onppo_default_config import walker_trex_ppo_default_config, walker_trex_ppo_create_default_config
from .walker2d_trex_sac_default_config import walker2d_trex_sac_default_config, walker2d_trex_sac_default_create_config
from easydict import EasyDict
ant_ppo_default_config = dict(
exp_name='ant_onppo',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
env_id='Ant-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=8,
evaluator_env_num=10,
use_act_scale=True,
n_evaluator_episode=10,
stop_value=6000,
),
policy=dict(
cuda=True,
recompute_adv=True,
model=dict(
obs_shape=111,
action_shape=8,
continuous=True,
),
continuous=True,
learn=dict(
epoch_per_collect=10,
batch_size=64,
learning_rate=3e-4,
value_weight=0.5,
entropy_weight=0.0,
clip_ratio=0.2,
adv_norm=True,
value_norm=True,
),
collect=dict(
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.97,
),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
ant_ppo_default_config = EasyDict(ant_ppo_default_config)
main_config = ant_ppo_default_config
ant_ppo_create_default_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='ppo', ),
)
ant_ppo_create_default_config = EasyDict(ant_ppo_create_default_config)
create_config = ant_ppo_create_default_config
from easydict import EasyDict
ant_sac_default_config = dict(
exp_name='ant_sac',
env=dict(
env_id='Ant-v3',
norm_obs=dict(use_norm=False, ),
......
from easydict import EasyDict
ant_trex_ppo_default_config = dict(
exp_name='ant_trex_onppo',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
env_id='Ant-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=8,
evaluator_env_num=10,
use_act_scale=True,
n_evaluator_episode=10,
stop_value=6000,
),
reward_model=dict(
type='trex',
algo_for_model='ppo',
env_id='Ant-v3',
min_snippet_length=10,
max_snippet_length=100,
checkpoint_min=100,
checkpoint_max=900,
checkpoint_step=100,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./ant.params',
continuous=True,
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
recompute_adv=True,
model=dict(
obs_shape=111,
action_shape=8,
continuous=True,
),
continuous=True,
learn=dict(
epoch_per_collect=10,
batch_size=64,
learning_rate=3e-4,
value_weight=0.5,
entropy_weight=0.0,
clip_ratio=0.2,
adv_norm=True,
value_norm=True,
),
collect=dict(
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.97,
),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
ant_trex_ppo_default_config = EasyDict(ant_trex_ppo_default_config)
main_config = ant_trex_ppo_default_config
ant_trex_ppo_create_default_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='ppo', ),
)
ant_trex_ppo_create_default_config = EasyDict(ant_trex_ppo_create_default_config)
create_config = ant_trex_ppo_create_default_config
from easydict import EasyDict
ant_trex_sac_default_config = dict(
exp_name='ant_trex_sac',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
env_id='Ant-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
),
reward_model=dict(
type='trex',
algo_for_model='sac',
env_id='Ant-v3',
min_snippet_length=30,
max_snippet_length=100,
checkpoint_min=1000,
checkpoint_max=9000,
checkpoint_step=1000,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./ant.params',
continuous=True,
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
random_collect_size=10000,
model=dict(
obs_shape=111,
action_shape=8,
twin_critic=True,
actor_head_type='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
learn=dict(
update_per_collect=1,
batch_size=256,
learning_rate_q=1e-3,
learning_rate_policy=1e-3,
learning_rate_alpha=3e-4,
ignore_done=False,
target_theta=0.005,
discount_factor=0.99,
alpha=0.2,
reparameterization=True,
auto_alpha=False,
),
collect=dict(
n_sample=1,
unroll_len=1,
),
command=dict(),
eval=dict(),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
),
)
ant_trex_sac_default_config = EasyDict(ant_trex_sac_default_config)
main_config = ant_trex_sac_default_config
ant_trex_sac_default_create_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='sac',
import_names=['ding.policy.sac'],
),
replay_buffer=dict(type='naive', ),
)
ant_trex_sac_default_create_config = EasyDict(ant_trex_sac_default_create_config)
create_config = ant_trex_sac_default_create_config
......@@ -4,7 +4,7 @@ from ding.entry import serial_pipeline_onpolicy
collector_env_num = 1
evaluator_env_num = 1
halfcheetah_ppo_default_config = dict(
exp_name="result_mujoco/halfcheetah_onppo_noig",
exp_name="Halfcheetah_onppo",
# exp_name="debug/debug_halfcheetah_onppo_ig",
env=dict(
......
from easydict import EasyDict
HalfCheetah_ppo_default_config = dict(
exp_name='HalfCheetah_onppo',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
env_id='HalfCheetah-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=8,
evaluator_env_num=10,
use_act_scale=True,
n_evaluator_episode=10,
stop_value=3000,
),
policy=dict(
cuda=True,
recompute_adv=True,
model=dict(
obs_shape=17,
action_shape=6,
continuous=True,
),
continuous=True,
learn=dict(
epoch_per_collect=10,
batch_size=64,
learning_rate=3e-4,
value_weight=0.5,
entropy_weight=0.0,
clip_ratio=0.2,
adv_norm=True,
value_norm=True,
),
collect=dict(
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.97,
),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
HalfCheetah_ppo_default_config = EasyDict(HalfCheetah_ppo_default_config)
main_config = HalfCheetah_ppo_default_config
HalfCheetah_ppo_create_default_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='ppo', ),
)
HalfCheetah_ppo_create_default_config = EasyDict(HalfCheetah_ppo_create_default_config)
create_config = HalfCheetah_ppo_create_default_config
from easydict import EasyDict
halfcheetah_sac_default_config = dict(
exp_name='halfcheetah_sac',
env=dict(
env_id='HalfCheetah-v3',
norm_obs=dict(use_norm=False, ),
......
from easydict import EasyDict
halfCheetah_trex_ppo_default_config = dict(
exp_name='HalfCheetah_trex_onppo',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
env_id='HalfCheetah-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=8,
evaluator_env_num=10,
use_act_scale=True,
n_evaluator_episode=10,
stop_value=3000,
),
reward_model=dict(
type='trex',
algo_for_model='ppo',
env_id='HalfCheetah-v3',
min_snippet_length=30,
max_snippet_length=100,
checkpoint_min=100,
checkpoint_max=900,
checkpoint_step=100,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + /HalfCheetah.params',
continuous=True,
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
recompute_adv=True,
model=dict(
obs_shape=17,
action_shape=6,
continuous=True,
),
continuous=True,
learn=dict(
epoch_per_collect=10,
batch_size=64,
learning_rate=3e-4,
value_weight=0.5,
entropy_weight=0.0,
clip_ratio=0.2,
adv_norm=True,
value_norm=True,
),
collect=dict(
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.97,
),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
halfCheetah_trex_ppo_default_config = EasyDict(halfCheetah_trex_ppo_default_config)
main_config = halfCheetah_trex_ppo_default_config
halfCheetah_trex_ppo_create_default_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='ppo', ),
)
halfCheetah_trex_ppo_create_default_config = EasyDict(halfCheetah_trex_ppo_create_default_config)
create_config = halfCheetah_trex_ppo_create_default_config
from easydict import EasyDict
halfcheetah_trex_sac_default_config = dict(
exp_name='halfcheetah_trex_sac',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
env_id='HalfCheetah-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=12000,
),
reward_model=dict(
type='trex',
algo_for_model='sac',
env_id='HalfCheetah-v3',
learning_rate=1e-5,
min_snippet_length=30,
max_snippet_length=100,
checkpoint_min=1000,
checkpoint_max=9000,
checkpoint_step=1000,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + /HalfCheetah.params',
continuous=True,
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
random_collect_size=10000,
model=dict(
obs_shape=17,
action_shape=6,
twin_critic=True,
actor_head_type='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
learn=dict(
update_per_collect=1,
batch_size=256,
learning_rate_q=1e-3,
learning_rate_policy=1e-3,
learning_rate_alpha=3e-4,
ignore_done=True,
target_theta=0.005,
discount_factor=0.99,
alpha=0.2,
reparameterization=True,
auto_alpha=False,
),
collect=dict(
n_sample=1,
unroll_len=1,
),
command=dict(),
eval=dict(),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
),
)
halfcheetah_trex_sac_default_config = EasyDict(halfcheetah_trex_sac_default_config)
main_config = halfcheetah_trex_sac_default_config
halfcheetah_trex_sac_default_create_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='sac',
import_names=['ding.policy.sac'],
),
replay_buffer=dict(type='naive', ),
)
halfcheetah_trex_sac_default_create_config = EasyDict(halfcheetah_trex_sac_default_create_config)
create_config = halfcheetah_trex_sac_default_create_config
from easydict import EasyDict
hopper_sac_default_config = dict(
exp_name='hopper_sac',
env=dict(
env_id='Hopper-v3',
norm_obs=dict(use_norm=False, ),
......
from easydict import EasyDict
hopper_trex_ppo_default_config = dict(
exp_name='hopper_trex_onppo',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
env_id='Hopper-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=8,
evaluator_env_num=10,
use_act_scale=True,
n_evaluator_episode=10,
stop_value=3000,
),
reward_model=dict(
type='trex',
algo_for_model='ppo',
env_id='Hopper-v3',
min_snippet_length=30,
max_snippet_length=100,
checkpoint_min=1000,
checkpoint_max=9000,
checkpoint_step=1000,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + /hopper.params',
continuous=True,
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
recompute_adv=True,
model=dict(
obs_shape=11,
action_shape=3,
continuous=True,
),
continuous=True,
learn=dict(
epoch_per_collect=10,
batch_size=64,
learning_rate=3e-4,
value_weight=0.5,
entropy_weight=0.0,
clip_ratio=0.2,
adv_norm=True,
value_norm=True,
),
collect=dict(
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.97,
),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
hopper_trex_ppo_default_config = EasyDict(hopper_trex_ppo_default_config)
main_config = hopper_trex_ppo_default_config
hopper_trex_ppo_create_default_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='ppo', ),
)
hopper_trex_ppo_create_default_config = EasyDict(hopper_trex_ppo_create_default_config)
create_config = hopper_trex_ppo_create_default_config
from easydict import EasyDict
hopper_trex_sac_default_config = dict(
exp_name='hopper_trex_sac',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
env_id='Hopper-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
),
reward_model=dict(
type='trex',
algo_for_model='sac',
env_id='Hopper-v3',
min_snippet_length=30,
max_snippet_length=100,
checkpoint_min=1000,
checkpoint_max=9000,
checkpoint_step=1000,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + /hopper.params',
continuous=True,
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
random_collect_size=10000,
model=dict(
obs_shape=11,
action_shape=3,
twin_critic=True,
actor_head_type='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
learn=dict(
update_per_collect=1,
batch_size=256,
learning_rate_q=1e-3,
learning_rate_policy=1e-3,
learning_rate_alpha=3e-4,
ignore_done=False,
target_theta=0.005,
discount_factor=0.99,
alpha=0.2,
reparameterization=True,
auto_alpha=False,
),
collect=dict(
n_sample=1,
unroll_len=1,
),
command=dict(),
eval=dict(),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
),
)
hopper_trex_sac_default_config = EasyDict(hopper_trex_sac_default_config)
main_config = hopper_trex_sac_default_config
hopper_trex_sac_default_create_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='sac',
import_names=['ding.policy.sac'],
),
replay_buffer=dict(type='naive', ),
)
hopper_trex_sac_default_create_config = EasyDict(hopper_trex_sac_default_create_config)
create_config = hopper_trex_sac_default_create_config
from easydict import EasyDict
walker_ppo_default_config = dict(
exp_name='walker2d_onppo',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
env_id='Walker2d-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=8,
evaluator_env_num=10,
use_act_scale=True,
n_evaluator_episode=10,
stop_value=3000,
),
policy=dict(
cuda=True,
recompute_adv=True,
model=dict(
obs_shape=17,
action_shape=6,
continuous=True,
),
continuous=True,
learn=dict(
epoch_per_collect=10,
batch_size=64,
learning_rate=3e-4,
value_weight=0.5,
entropy_weight=0.0,
clip_ratio=0.2,
adv_norm=True,
value_norm=True,
),
collect=dict(
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.97,
),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
walker_ppo_default_config = EasyDict(walker_ppo_default_config)
main_config = walker_ppo_default_config
walker_ppo_create_default_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='ppo',
),
)
walker_ppo_create_default_config = EasyDict(walker_ppo_create_default_config)
create_config = walker_ppo_create_default_config
from easydict import EasyDict
walker_trex_ppo_default_config = dict(
exp_name='walker2d_trex_onppo',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
env_id='Walker2d-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=8,
evaluator_env_num=10,
use_act_scale=True,
n_evaluator_episode=10,
stop_value=3000,
),
reward_model=dict(
type='trex',
algo_for_model='ppo',
env_id='Walker2d-v3',
min_snippet_length=30,
max_snippet_length=100,
checkpoint_min=0,
checkpoint_max=1000,
checkpoint_step=1000,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./walker2d.params',
continuous=True,
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
recompute_adv=True,
model=dict(
obs_shape=17,
action_shape=6,
continuous=True,
),
continuous=True,
learn=dict(
epoch_per_collect=10,
batch_size=64,
learning_rate=3e-4,
value_weight=0.5,
entropy_weight=0.0,
clip_ratio=0.2,
adv_norm=True,
value_norm=True,
),
collect=dict(
n_sample=2048,
unroll_len=1,
discount_factor=0.99,
gae_lambda=0.97,
),
eval=dict(evaluator=dict(eval_freq=5000, )),
),
)
walker_trex_ppo_default_config = EasyDict(walker_trex_ppo_default_config)
main_config = walker_trex_ppo_default_config
walker_trex_ppo_create_default_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='ppo',
),
)
walker_trex_ppo_create_default_config = EasyDict(walker_trex_ppo_create_default_config)
create_config = walker_trex_ppo_create_default_config
from easydict import EasyDict
walker2d_trex_sac_default_config = dict(
exp_name='walker2d_trex_sac',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
env_id='Walker2d-v3',
norm_obs=dict(use_norm=False, ),
norm_reward=dict(use_norm=False, ),
collector_env_num=1,
evaluator_env_num=8,
use_act_scale=True,
n_evaluator_episode=8,
stop_value=6000,
),
reward_model=dict(
type='trex',
algo_for_model='sac',
env_id='Walker2d-v3',
min_snippet_length=30,
max_snippet_length=100,
checkpoint_min=100,
checkpoint_max=900,
checkpoint_step=100,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./walker2d.params',
continuous=True,
offline_data_path='abs data path',
),
policy=dict(
cuda=True,
random_collect_size=10000,
model=dict(
obs_shape=17,
action_shape=6,
twin_critic=True,
actor_head_type='reparameterization',
actor_head_hidden_size=256,
critic_head_hidden_size=256,
),
learn=dict(
update_per_collect=1,
batch_size=256,
learning_rate_q=1e-3,
learning_rate_policy=1e-3,
learning_rate_alpha=3e-4,
ignore_done=False,
target_theta=0.005,
discount_factor=0.99,
alpha=0.2,
reparameterization=True,
auto_alpha=False,
),
collect=dict(
n_sample=1,
unroll_len=1,
),
command=dict(),
eval=dict(),
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
),
)
walker2d_trex_sac_default_config = EasyDict(walker2d_trex_sac_default_config)
main_config = walker2d_trex_sac_default_config
walker2d_trex_sac_default_create_config = dict(
env=dict(
type='mujoco',
import_names=['dizoo.mujoco.envs.mujoco_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='sac',
import_names=['ding.policy.sac'],
),
replay_buffer=dict(type='naive', ),
)
walker2d_trex_sac_default_create_config = EasyDict(walker2d_trex_sac_default_create_config)
create_config = walker2d_trex_sac_default_create_config
import argparse
import torch
from ding.entry import trex_collecting_data
from dizoo.mujoco.config.halfcheetah_trex_onppo_default_config import main_config, create_config
from ding.entry import serial_pipeline_reward_model_trex_onpolicy, serial_pipeline_reward_model_trex
# Note: serial_pipeline_reward_model_trex_onpolicy is for on-policy PPO, whereas serial_pipeline_reward_model_trex is for SAC.
# Note: before running this file, please set the corresponding paths in the config; every path except exp_name should be an absolute path.
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='please enter abs path for halfcheetah_trex_onppo_default_config.py or halfcheetah_trex_sac_default_config.py')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
args = parser.parse_args()
trex_collecting_data(args)
# If running SAC, import the relevant config and use serial_pipeline_reward_model_trex instead.
serial_pipeline_reward_model_trex_onpolicy([main_config, create_config])
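Following the comments above, a sketch of the corresponding SAC launch: the flow is identical, only the imported config and the pipeline function change (the call signature of ``serial_pipeline_reward_model_trex`` is assumed here to mirror the on-policy pipeline):
import argparse
import torch
from ding.entry import trex_collecting_data, serial_pipeline_reward_model_trex
from dizoo.mujoco.config.halfcheetah_trex_sac_default_config import main_config, create_config

parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='please enter abs path for halfcheetah_trex_sac_default_config.py')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_args()
# First generate ranked demonstration data from the expert checkpoints, then train the reward model and the SAC policy.
trex_collecting_data(args)
serial_pipeline_reward_model_trex([main_config, create_config])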