From 63105fef8526c8a45f2f18886d9334d6a346b51d Mon Sep 17 00:00:00 2001 From: Will-Nie <61083608+Will-Nie@users.noreply.github.com> Date: Wed, 8 Dec 2021 22:39:21 +0800 Subject: [PATCH] feature(nyp): add Trex algorithm (#119) * add trex algorithm for pong * sort style * add atari, ll,cp; fix device, collision; add_ppo * add accuracy evaluation * correct style * add seed to make sure results are replicable * remove useless part in cum return of model part * add mujoco onppo training pipeline; ppo config * improve style * add sac training config for mujoco * add log, add save data; polish config * logger; hyperparameter;walker * correct style * modify else condition * change rnd to trex * revise according to comments, add eposode collect * new collect mode for trex, fix all bugs, commnets * final change * polish after the final comment * add readme/test * add test for serial entry of trex/gcl * sort style --- README.md | 17 +- ding/entry/__init__.py | 6 +- ding/entry/application_entry.py | 100 +++- .../application_entry_trex_collect_data.py | 166 +++++++ ding/entry/cli.py | 14 +- ding/entry/serial_entry_trex.py | 138 ++++++ ding/entry/serial_entry_trex_onpolicy.py | 100 ++++ ding/entry/tests/test_application_entry.py | 30 +- ...est_application_entry_trex_collect_data.py | 65 +++ .../tests/test_serial_entry_guided_cost.py | 23 + ding/entry/tests/test_serial_entry_trex.py | 33 ++ .../tests/test_serial_entry_trex_onpolicy.py | 31 ++ ding/reward_model/__init__.py | 1 + ding/reward_model/base_reward_model.py | 5 +- ding/reward_model/trex_reward_model.py | 431 ++++++++++++++++++ .../serial/pong/pong_trex_offppo_config.py | 78 ++++ .../serial/pong/pong_trex_sql_config.py | 64 +++ .../serial/qbert/qbert_trex_dqn_config.py | 70 +++ .../serial/qbert/qbert_trex_offppo_config.py | 76 +++ .../spaceinvaders_trex_dqn_config.py | 70 +++ .../spaceinvaders_trex_offppo_config.py | 76 +++ dizoo/box2d/lunarlander/config/__init__.py | 2 + .../config/lunarlander_ppo_config.py | 2 + .../config/lunarlander_trex_dqn_config.py | 86 ++++ .../config/lunarlander_trex_offppo_config.py | 63 +++ .../cartpole/config/__init__.py | 3 + .../cartpole/config/cartpole_gcl_config.py | 16 +- .../config/cartpole_trex_dqn_config.py | 67 +++ .../config/cartpole_trex_offppo_config.py | 65 +++ dizoo/mujoco/config/__init__.py | 10 + dizoo/mujoco/config/ant_ppo_default_config.py | 56 +++ dizoo/mujoco/config/ant_sac_default_config.py | 1 + .../config/ant_trex_onppo_default_config.py | 72 +++ .../config/ant_trex_sac_default_config.py | 82 ++++ .../halfcheetah_onppo_default_config.py | 2 +- .../config/halfcheetah_ppo_default_config.py | 56 +++ .../config/halfcheetah_sac_default_config.py | 1 + .../halfcheetah_trex_onppo_default_config.py | 72 +++ .../halfcheetah_trex_sac_default_config.py | 82 ++++ .../config/hopper_sac_default_config.py | 1 + .../hopper_trex_onppo_default_config.py | 72 +++ .../config/hopper_trex_sac_default_config.py | 82 ++++ .../config/walker2d_ppo_default_config.py | 58 +++ .../walker2d_trex_onppo_default_config.py | 74 +++ .../walker2d_trex_sac_default_config.py | 82 ++++ dizoo/mujoco/entry/mujoco_trex_main.py | 22 + 46 files changed, 2697 insertions(+), 26 deletions(-) create mode 100644 ding/entry/application_entry_trex_collect_data.py create mode 100644 ding/entry/serial_entry_trex.py create mode 100644 ding/entry/serial_entry_trex_onpolicy.py create mode 100644 ding/entry/tests/test_application_entry_trex_collect_data.py create mode 100644 ding/entry/tests/test_serial_entry_guided_cost.py create mode 100644 
ding/entry/tests/test_serial_entry_trex.py create mode 100644 ding/entry/tests/test_serial_entry_trex_onpolicy.py create mode 100644 ding/reward_model/trex_reward_model.py create mode 100644 dizoo/atari/config/serial/pong/pong_trex_offppo_config.py create mode 100644 dizoo/atari/config/serial/pong/pong_trex_sql_config.py create mode 100644 dizoo/atari/config/serial/qbert/qbert_trex_dqn_config.py create mode 100644 dizoo/atari/config/serial/qbert/qbert_trex_offppo_config.py create mode 100644 dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_dqn_config.py create mode 100644 dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_offppo_config.py create mode 100644 dizoo/box2d/lunarlander/config/lunarlander_trex_dqn_config.py create mode 100644 dizoo/box2d/lunarlander/config/lunarlander_trex_offppo_config.py create mode 100644 dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py create mode 100644 dizoo/classic_control/cartpole/config/cartpole_trex_offppo_config.py create mode 100644 dizoo/mujoco/config/ant_ppo_default_config.py create mode 100644 dizoo/mujoco/config/ant_trex_onppo_default_config.py create mode 100644 dizoo/mujoco/config/ant_trex_sac_default_config.py create mode 100644 dizoo/mujoco/config/halfcheetah_ppo_default_config.py create mode 100644 dizoo/mujoco/config/halfcheetah_trex_onppo_default_config.py create mode 100644 dizoo/mujoco/config/halfcheetah_trex_sac_default_config.py create mode 100644 dizoo/mujoco/config/hopper_trex_onppo_default_config.py create mode 100644 dizoo/mujoco/config/hopper_trex_sac_default_config.py create mode 100644 dizoo/mujoco/config/walker2d_ppo_default_config.py create mode 100644 dizoo/mujoco/config/walker2d_trex_onppo_default_config.py create mode 100644 dizoo/mujoco/config/walker2d_trex_sac_default_config.py create mode 100644 dizoo/mujoco/entry/mujoco_trex_main.py diff --git a/README.md b/README.md index 6b0a12f..5c79574 100644 --- a/README.md +++ b/README.md @@ -136,14 +136,15 @@ ding -m serial -e cartpole -p dqn -s 0 | 26 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 | | 27 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/r2d3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/r2d3.py) | python3 -u pong_r2d3_r2d2expert_config.py | | 28 | [Guided Cost Learning](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py | -| 29 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py | -| 30 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py | -| 31 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py | 
-| 32 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py | -| 33 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py | -| 34 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py | -| 35 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` | -| 36 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` | +| 29 | [TREX](https://arxiv.org/abs/1904.06387) | ![IL](https://img.shields.io/badge/-IL-purple) | [reward_model/trex](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/trex_reward_model.py) | python3 mujoco_trex_main.py +| 30 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py | +| 31 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_ppo_rnd_main.py | +| 32 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py | +| 33 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py | +| 34 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u mujoco_td3_bc_main.py | +| 35 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [model/template/model_based/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/model_based/mbpo.py) | python3 -u sac_halfcheetah_mopo_default_config.py | +| 36 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` | +| 37 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) means 
discrete action space, which is only label in normal DRL algorithms (1-18) diff --git a/ding/entry/__init__.py b/ding/entry/__init__.py index 2358329..d7ec000 100644 --- a/ding/entry/__init__.py +++ b/ding/entry/__init__.py @@ -10,6 +10,10 @@ from .serial_entry_mbrl import serial_pipeline_mbrl from .serial_entry_dqfd import serial_pipeline_dqfd from .serial_entry_r2d3 import serial_pipeline_r2d3 from .serial_entry_sqil import serial_pipeline_sqil +from .serial_entry_trex import serial_pipeline_reward_model_trex +from .serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy from .parallel_entry import parallel_pipeline -from .application_entry import eval, collect_demo_data +from .application_entry import eval, collect_demo_data, collect_episodic_demo_data, \ + epsiode_to_transitions +from .application_entry_trex_collect_data import trex_collecting_data, collect_episodic_demo_data_for_trex from .serial_entry_guided_cost import serial_pipeline_guided_cost diff --git a/ding/entry/application_entry.py b/ding/entry/application_entry.py index ba886c1..fc7540c 100644 --- a/ding/entry/application_entry.py +++ b/ding/entry/application_entry.py @@ -2,14 +2,17 @@ from typing import Union, Optional, List, Any, Tuple import pickle import torch from functools import partial +import os from ding.config import compile_config, read_config -from ding.worker import SampleSerialCollector, InteractionSerialEvaluator +from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, EpisodeSerialCollector from ding.envs import create_env_manager, get_vec_env_setting from ding.policy import create_policy from ding.torch_utils import to_device from ding.utils import set_pkg_seed from ding.utils.data import offline_data_save_type +from ding.rl_utils import get_nstep_return_data +from ding.utils.data import default_collate def eval( @@ -141,7 +144,7 @@ def collect_demo_data( policy.collect_mode.load_state_dict(state_dict) collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy) - policy_kwargs = None if not hasattr(cfg.policy.other.get('eps', None), 'collect') \ + policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \ else {'eps': cfg.policy.other.eps.get('collect', 0.2)} # Let's collect some expert demonstrations @@ -151,3 +154,96 @@ def collect_demo_data( # Save data transitions. offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive')) print('Collect demo data successfully') + + +def collect_episodic_demo_data( + input_cfg: Union[str, dict], + seed: int, + collect_count: int, + expert_data_path: str, + env_setting: Optional[List[Any]] = None, + model: Optional[torch.nn.Module] = None, + state_dict: Optional[dict] = None, + state_dict_path: Optional[str] = None, +) -> None: + r""" + Overview: + Collect episodic demonstration data by the trained policy. + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - collect_count (:obj:`int`): The count of collected data. + - expert_data_path (:obj:`str`): File path of the expert demo data will be written to. + - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ + ``BaseEnv`` subclass, collector env config, and evaluator env config. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. 
+ - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model. + - state_dict_path (:obj:`Optional[str]`): The absolute path of the state dict file. + """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + else: + cfg, create_cfg = input_cfg + create_cfg.policy.type += '_command' + env_fn = None if env_setting is None else env_setting[0] + cfg = compile_config( + cfg, + collector=EpisodeSerialCollector, + seed=seed, + env=env_fn, + auto=True, + create_cfg=create_cfg, + save_cfg=True, + save_path='collect_demo_data_config.py' + ) + + # Create components: env, policy, collector + if env_setting is None: + env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env) + else: + env_fn, collector_env_cfg, _ = env_setting + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + collector_env.seed(seed) + set_pkg_seed(seed, use_cuda=cfg.policy.cuda) + policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval']) + collect_demo_policy = policy.collect_mode + if state_dict is None: + assert state_dict_path is not None + state_dict = torch.load(state_dict_path, map_location='cpu') + policy.collect_mode.load_state_dict(state_dict) + collector = EpisodeSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy) + + policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \ + else {'eps': cfg.policy.other.eps.get('collect', 0.2)} + + # Let's collect some expert demonstrations + exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs) + if cfg.policy.cuda: + exp_data = to_device(exp_data, 'cpu') + # Save data transitions. + offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive')) + print('Collect episodic demo data successfully') + + + def epsiode_to_transitions(data_path: str, expert_data_path: str, nstep: int) -> None: + r""" + Overview: + Transfer episodic data into n-step transitions. + Arguments: + - data_path (:obj:`str`): Path of the pkl file that stores the episodic data. + - expert_data_path (:obj:`str`): File path that the expert demo data will be written to. + - nstep (:obj:`int`): Number of steps n used to build {s_{t}, a_{t}, s_{t+n}} transitions. 
+ + """ + with open(data_path, 'rb') as f: + _dict = pickle.load(f) # class is list; length is cfg.reward_model.collect_count + post_process_data = [] + for i in range(len(_dict)): + data = get_nstep_return_data(_dict[i], nstep) + post_process_data.extend(data) + offline_data_save_type( + post_process_data, + expert_data_path, + ) diff --git a/ding/entry/application_entry_trex_collect_data.py b/ding/entry/application_entry_trex_collect_data.py new file mode 100644 index 0000000..9c49221 --- /dev/null +++ b/ding/entry/application_entry_trex_collect_data.py @@ -0,0 +1,166 @@ +import argparse +import torch +from typing import Union, Optional, List, Any +from functools import partial +import os +from copy import deepcopy + +from ding.config import compile_config, read_config +from ding.worker import EpisodeSerialCollector +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.torch_utils import to_device +from ding.utils import set_pkg_seed +from ding.utils.data import offline_data_save_type +from ding.utils.data import default_collate + + +def collect_episodic_demo_data_for_trex( + input_cfg: Union[str, dict], + seed: int, + collect_count: int, + rank: int, + save_cfg_path: str, + env_setting: Optional[List[Any]] = None, + model: Optional[torch.nn.Module] = None, + state_dict: Optional[dict] = None, + state_dict_path: Optional[str] = None, +) -> None: + r""" + Overview: + Collect episodic demonstration data by the trained policy for trex specifically. + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - collect_count (:obj:`int`): The count of collected data. + - rank (:obj:`int`) the episode ranking. + - save_cfg_path(:obj:'str') where to save the collector config + - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ + ``BaseEnv`` subclass, collector env config, and evaluator env config. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model. 
+ - state_dict_path (:obj:'str') the abs path of the state dict + """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + else: + cfg, create_cfg = input_cfg + create_cfg.policy.type += '_command' + env_fn = None if env_setting is None else env_setting[0] + cfg.env.collector_env_num = 1 + if not os.path.exists(save_cfg_path): + os.mkdir(save_cfg_path) + cfg = compile_config( + cfg, + collector=EpisodeSerialCollector, + seed=seed, + env=env_fn, + auto=True, + create_cfg=create_cfg, + save_cfg=True, + save_path=save_cfg_path + '/collect_demo_data_config.py' + ) + + # Create components: env, policy, collector + if env_setting is None: + env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env) + else: + env_fn, collector_env_cfg, _ = env_setting + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + collector_env.seed(seed) + set_pkg_seed(seed, use_cuda=cfg.policy.cuda) + policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval']) + collect_demo_policy = policy.collect_mode + if state_dict is None: + assert state_dict_path is not None + state_dict = torch.load(state_dict_path, map_location='cpu') + policy.collect_mode.load_state_dict(state_dict) + collector = EpisodeSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy) + + policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \ + else {'eps': cfg.policy.other.eps.get('collect', 0.2)} + + # Let's collect some sub-optimal demostrations + exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs) + if cfg.policy.cuda: + exp_data = to_device(exp_data, 'cpu') + # Save data transitions. + print('Collect {}th episodic demo data successfully'.format(rank)) + return exp_data + + +def trex_get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--cfg', type=str, default='abs path for a config') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') + args = parser.parse_known_args()[0] + return args + + +def trex_collecting_data(args=trex_get_args()): + if isinstance(args.cfg, str): + cfg, create_cfg = read_config(args.cfg) + else: + cfg, create_cfg = deepcopy(args.cfg) + create_cfg.policy.type = create_cfg.policy.type + '_command' + compiled_cfg = compile_config(cfg, seed=args.seed, auto=True, create_cfg=create_cfg, save_cfg=False) + offline_data_path = compiled_cfg.reward_model.offline_data_path + expert_model_path = compiled_cfg.reward_model.expert_model_path + checkpoint_min = compiled_cfg.reward_model.checkpoint_min + checkpoint_max = compiled_cfg.reward_model.checkpoint_max + checkpoint_step = compiled_cfg.reward_model.checkpoint_step + checkpoints = [] + for i in range(checkpoint_min, checkpoint_max + checkpoint_step, checkpoint_step): + checkpoints.append(str(i)) + data_for_save = {} + learning_returns = [] + learning_rewards = [] + episodes_data = [] + for checkpoint in checkpoints: + model_path = expert_model_path + \ + '/ckpt/iteration_' + checkpoint + '.pth.tar' + seed = args.seed + (int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step) + exp_data = collect_episodic_demo_data_for_trex( + deepcopy(args.cfg), + seed, + state_dict_path=model_path, + save_cfg_path=offline_data_path, + collect_count=1, + rank=(int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step) + 1 + ) + data_for_save[(int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step)] = 
exp_data[0] + obs = list(default_collate(exp_data[0])['obs'].numpy()) + learning_rewards.append(default_collate(exp_data[0])['reward'].tolist()) + sum_reward = torch.sum(default_collate(exp_data[0])['reward']).item() + learning_returns.append(sum_reward) + episodes_data.append(obs) + offline_data_save_type( + data_for_save, + offline_data_path + '/suboptimal_data.pkl', + data_type=cfg.policy.collect.get('data_type', 'naive') + ) + # if not compiled_cfg.reward_model.auto: more feature + offline_data_save_type( + episodes_data, offline_data_path + '/episodes_data.pkl', data_type=cfg.policy.collect.get('data_type', 'naive') + ) + offline_data_save_type( + learning_returns, + offline_data_path + '/learning_returns.pkl', + data_type=cfg.policy.collect.get('data_type', 'naive') + ) + offline_data_save_type( + learning_rewards, + offline_data_path + '/learning_rewards.pkl', + data_type=cfg.policy.collect.get('data_type', 'naive') + ) + offline_data_save_type( + checkpoints, offline_data_path + '/checkpoints.pkl', data_type=cfg.policy.collect.get('data_type', 'naive') + ) + return checkpoints, episodes_data, learning_returns, learning_rewards + + +if __name__ == '__main__': + trex_collecting_data() diff --git a/ding/entry/cli.py b/ding/entry/cli.py index b0af753..8eac954 100644 --- a/ding/entry/cli.py +++ b/ding/entry/cli.py @@ -54,8 +54,8 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) '--mode', type=click.Choice( [ - 'serial', 'serial_onpolicy', 'serial_sqil', 'serial_dqfd', 'parallel', 'dist', 'eval', - 'serial_reward_model', 'serial_gail' + 'serial', 'serial_onpolicy', 'serial_sqil', 'serial_dqfd', 'serial_trex', 'serial_trex_onpolicy', + 'parallel', 'dist', 'eval', 'serial_reward_model', 'serial_gail' ] ), help='serial-train or parallel-train or dist-train or eval' @@ -182,6 +182,16 @@ def cli( + "the models used in q learning now; However, one should still type the DQFD config in this "\ + "place, i.e., {}{}".format(config[:config.find('_dqfd')], '_dqfd_config.py') serial_pipeline_dqfd(config, expert_config, seed, max_iterations=train_iter) + elif mode == 'serial_trex': + from .serial_entry_trex import serial_pipeline_reward_model_trex + if config is None: + config = get_predefined_config(env, policy) + serial_pipeline_reward_model_trex(config, seed, max_iterations=train_iter) + elif mode == 'serial_trex_onpolicy': + from .serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy + if config is None: + config = get_predefined_config(env, policy) + serial_pipeline_reward_model_trex_onpolicy(config, seed, max_iterations=train_iter) elif mode == 'parallel': from .parallel_entry import parallel_pipeline parallel_pipeline(config, seed, enable_total_log, disable_flask_log) diff --git a/ding/entry/serial_entry_trex.py b/ding/entry/serial_entry_trex.py new file mode 100644 index 0000000..c559da0 --- /dev/null +++ b/ding/entry/serial_entry_trex.py @@ -0,0 +1,138 @@ +from typing import Union, Optional, List, Any, Tuple +import os +import torch +import logging +from functools import partial +from tensorboardX import SummaryWriter + +from ding.envs import get_vec_env_setting, create_env_manager +from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \ + create_serial_collector +from ding.config import read_config, compile_config +from ding.policy import create_policy, PolicyFactory +from ding.reward_model import create_reward_model +from ding.utils import set_pkg_seed +# from 
dizoo.atari.config.serial.pong.pong_trex_sql_config import main_config, create_config +from dizoo.box2d.lunarlander.config.lunarlander_trex_offppo_config import main_config, create_config + + +def serial_pipeline_reward_model_trex( + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + env_setting: Optional[List[Any]] = None, + model: Optional[torch.nn.Module] = None, + max_iterations: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + serial_pipeline_reward_model_trex. + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ + ``BaseEnv`` subclass, collector env config, and evaluator env config. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - max_iterations (:obj:`Optional[torch.nn.Module]`): Learner's max iteration. Pipeline will stop \ + when reaching this iteration. + Returns: + - policy (:obj:`Policy`): Converged policy. + """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + else: + cfg, create_cfg = input_cfg + create_cfg.policy.type = create_cfg.policy.type + '_command' + env_fn = None if env_setting is None else env_setting[0] + cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + if env_setting is None: + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + else: + env_fn, collector_env_cfg, evaluator_env_cfg = env_setting + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command']) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + collector = create_serial_collector( + cfg.policy.collect.collector, + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name + ) + evaluator = InteractionSerialEvaluator( + cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name + ) + replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name) + commander = BaseSerialCommander( + cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode + ) + reward_model = create_reward_model(cfg, policy.collect_mode.get_attribute('device'), tb_logger) + reward_model.train() + # ========== + # Main loop + # ========== + # Learner's before_run hook. + learner.call_hook('before_run') + + # Accumulate plenty of data at the beginning of training. 
+ if cfg.policy.get('random_collect_size', 0) > 0: + if cfg.policy.get('transition_with_policy_data', False): + collector.reset_policy(policy.collect_mode) + else: + action_space = collector_env.env_info().act_space + random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space) + collector.reset_policy(random_policy) + collect_kwargs = commander.step() + new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs) + replay_buffer.push(new_data, cur_collector_envstep=0) + collector.reset_policy(policy.collect_mode) + for _ in range(max_iterations): + collect_kwargs = commander.step() + # Evaluate policy performance + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + # Collect data by default config n_sample/n_episode + if hasattr(cfg.policy.collect, "each_iter_n_sample"): # TODO(pu) + new_data = collector.collect( + n_sample=cfg.policy.collect.each_iter_n_sample, + train_iter=learner.train_iter, + policy_kwargs=collect_kwargs + ) + else: + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) + # Learn policy from collected data + for i in range(cfg.policy.learn.update_per_collect): + # Learner will train ``update_per_collect`` times in one iteration. + train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter) + if train_data is None: + # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times + logging.warning( + "Replay buffer's data can only train for {} steps. ".format(i) + + "You can modify data collect config, e.g. increasing n_sample, n_episode." + ) + break + # update train_data reward + reward_model.estimate(train_data) + learner.train(train_data, collector.envstep) + if learner.policy.get_attribute('priority'): + replay_buffer.update(learner.priority_info) + + # Learner's after_run hook. + learner.call_hook('after_run') + return policy + + +if __name__ == '__main__': + serial_pipeline_reward_model_trex([main_config, create_config]) diff --git a/ding/entry/serial_entry_trex_onpolicy.py b/ding/entry/serial_entry_trex_onpolicy.py new file mode 100644 index 0000000..da69808 --- /dev/null +++ b/ding/entry/serial_entry_trex_onpolicy.py @@ -0,0 +1,100 @@ +from typing import Union, Optional, List, Any, Tuple +import os +import torch +import logging +from functools import partial +from tensorboardX import SummaryWriter + +from ding.envs import get_vec_env_setting, create_env_manager +from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \ + create_serial_collector +from ding.config import read_config, compile_config +from ding.policy import create_policy, PolicyFactory +from ding.reward_model import create_reward_model +from ding.utils import set_pkg_seed +from dizoo.mujoco.config.hopper_trex_onppo_default_config import main_config, create_config + + +def serial_pipeline_reward_model_trex_onpolicy( + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + env_setting: Optional[List[Any]] = None, + model: Optional[torch.nn.Module] = None, + max_iterations: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + Serial pipeline entry for onpolicy algorithm(such as PPO). + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. 
\ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ + ``BaseEnv`` subclass, collector env config, and evaluator env config. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - max_iterations (:obj:`Optional[torch.nn.Module]`): Learner's max iteration. Pipeline will stop \ + when reaching this iteration. + Returns: + - policy (:obj:`Policy`): Converged policy. + """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + else: + cfg, create_cfg = input_cfg + create_cfg.policy.type = create_cfg.policy.type + '_command' + env_fn = None if env_setting is None else env_setting[0] + cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create main components: env, policy + if env_setting is None: + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + else: + env_fn, collector_env_cfg, evaluator_env_cfg = env_setting + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command']) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + collector = create_serial_collector( + cfg.policy.collect.collector, + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name + ) + evaluator = InteractionSerialEvaluator( + cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name + ) + reward_model = create_reward_model(cfg, policy.collect_mode.get_attribute('device'), tb_logger) + reward_model.train() + # ========== + # Main loop + # ========== + # Learner's before_run hook. + learner.call_hook('before_run') + + # Accumulate plenty of data at the beginning of training. + for _ in range(max_iterations): + # Evaluate policy performance + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + # Collect data by default config n_sample/n_episode + new_data = collector.collect(train_iter=learner.train_iter) + # Learn policy from collected data with modified rewards + reward_model.estimate(new_data) + learner.train(new_data, collector.envstep) + + # Learner's after_run hook. 
+ learner.call_hook('after_run') + return policy + + +if __name__ == '__main__': + serial_pipeline_reward_model_trex_onpolicy([main_config, create_config]) diff --git a/ding/entry/tests/test_application_entry.py b/ding/entry/tests/test_application_entry.py index 8689838..3329afe 100644 --- a/ding/entry/tests/test_application_entry.py +++ b/ding/entry/tests/test_application_entry.py @@ -3,10 +3,14 @@ import pytest import os import pickle -from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config # noqa +from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \ + cartpole_ppo_offpolicy_create_config # noqa +from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_ppo_offpolicy_config,\ + cartpole_trex_ppo_offpolicy_create_config from dizoo.classic_control.cartpole.envs import CartPoleEnv from ding.entry import serial_pipeline, eval, collect_demo_data from ding.config import compile_config +from ding.entry.application_entry import collect_episodic_demo_data, epsiode_to_transitions @pytest.fixture(scope='module') @@ -58,4 +62,28 @@ class TestApplication: exp_data = pickle.load(f) assert isinstance(exp_data, list) assert isinstance(exp_data[0], dict) + + def test_collect_episodic_demo_data(self, setup_state_dict): + config = deepcopy(cartpole_trex_ppo_offpolicy_config), deepcopy(cartpole_trex_ppo_offpolicy_create_config) + collect_count = 16 + expert_data_path = './expert.data' + collect_episodic_demo_data( + config, + seed=0, + state_dict=setup_state_dict['collect'], + expert_data_path=expert_data_path, + collect_count=collect_count, + ) + with open(expert_data_path, 'rb') as f: + exp_data = pickle.load(f) + assert isinstance(exp_data, list) + assert isinstance(exp_data[0][0], dict) + + def test_epsiode_to_transitions(self): + expert_data_path = './expert.data' + epsiode_to_transitions(data_path=expert_data_path, expert_data_path=expert_data_path, nstep=3) + with open(expert_data_path, 'rb') as f: + exp_data = pickle.load(f) + assert isinstance(exp_data, list) + assert isinstance(exp_data[0], dict) os.popen('rm -rf ./expert.data ckpt* log') diff --git a/ding/entry/tests/test_application_entry_trex_collect_data.py b/ding/entry/tests/test_application_entry_trex_collect_data.py new file mode 100644 index 0000000..fef055d --- /dev/null +++ b/ding/entry/tests/test_application_entry_trex_collect_data.py @@ -0,0 +1,65 @@ +from easydict import EasyDict +import pytest +from copy import deepcopy +import os +from itertools import product + +import torch + +from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_ppo_offpolicy_config,\ + cartpole_trex_ppo_offpolicy_create_config +from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config,\ + cartpole_ppo_offpolicy_create_config +from ding.entry.application_entry_trex_collect_data import collect_episodic_demo_data_for_trex, trex_collecting_data +from ding.entry import serial_pipeline + + +@pytest.mark.unittest +def test_collect_episodic_demo_data_for_trex(): + expert_policy_state_dict_path = './expert_policy.pth' + expert_policy_state_dict_path = os.path.abspath('ding/entry/expert_policy.pth') + config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] + expert_policy = serial_pipeline(config, seed=0) + torch.save(expert_policy.collect_mode.state_dict(), 
expert_policy_state_dict_path) + + config = deepcopy(cartpole_trex_ppo_offpolicy_config), deepcopy(cartpole_trex_ppo_offpolicy_create_config) + collect_count = 1 + save_cfg_path = './cartpole_trex_offppo' + save_cfg_path = os.path.abspath(save_cfg_path) + exp_data = collect_episodic_demo_data_for_trex( + config, + seed=0, + state_dict_path=expert_policy_state_dict_path, + save_cfg_path=save_cfg_path, + collect_count=collect_count, + rank=1, + ) + assert isinstance(exp_data, list) + assert isinstance(exp_data[0][0], dict) + os.popen('rm -rf {}'.format(save_cfg_path)) + os.popen('rm -rf {}'.format(expert_policy_state_dict_path)) + + +@pytest.mark.unittest +def test_trex_collecting_data(): + expert_policy_state_dict_path = './cartpole_ppo_offpolicy' + expert_policy_state_dict_path = os.path.abspath(expert_policy_state_dict_path) + config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] + expert_policy = serial_pipeline(config, seed=0) + + args = EasyDict( + { + 'cfg': [deepcopy(cartpole_trex_ppo_offpolicy_config), + deepcopy(cartpole_trex_ppo_offpolicy_create_config)], + 'seed': 0, + 'device': 'cpu' + } + ) + args.cfg[0].reward_model.offline_data_path = 'dizoo/classic_control/cartpole/config/cartpole_trex_offppo' + args.cfg[0].reward_model.offline_data_path = os.path.abspath(args.cfg[0].reward_model.offline_data_path) + args.cfg[0].reward_model.reward_model_path = args.cfg[0].reward_model.offline_data_path + '/cartpole.params' + args.cfg[0].reward_model.expert_model_path = './cartpole_ppo_offpolicy' + args.cfg[0].reward_model.expert_model_path = os.path.abspath(args.cfg[0].reward_model.expert_model_path) + trex_collecting_data(args=args) + os.popen('rm -rf {}'.format(expert_policy_state_dict_path)) + os.popen('rm -rf {}'.format(args.cfg[0].reward_model.offline_data_path)) diff --git a/ding/entry/tests/test_serial_entry_guided_cost.py b/ding/entry/tests/test_serial_entry_guided_cost.py new file mode 100644 index 0000000..ce13dec --- /dev/null +++ b/ding/entry/tests/test_serial_entry_guided_cost.py @@ -0,0 +1,23 @@ +import pytest +import torch +from copy import deepcopy +from ding.entry import serial_pipeline_onpolicy, serial_pipeline_guided_cost +from dizoo.classic_control.cartpole.config import cartpole_ppo_config, cartpole_ppo_create_config +from dizoo.classic_control.cartpole.config import cartpole_gcl_ppo_onpolicy_config, \ + cartpole_gcl_ppo_onpolicy_create_config + + +@pytest.mark.unittest +def test_guided_cost(): + expert_policy_state_dict_path = './expert_policy.pth' + config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)] + expert_policy = serial_pipeline_onpolicy(config, seed=0) + torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path) + + config = [deepcopy(cartpole_gcl_ppo_onpolicy_config), deepcopy(cartpole_gcl_ppo_onpolicy_create_config)] + config[0].policy.collect.demonstration_info_path = expert_policy_state_dict_path + config[0].policy.learn.update_per_collect = 1 + try: + serial_pipeline_guided_cost(config, seed=0, max_iterations=1) + except Exception: + assert False, "pipeline fail" diff --git a/ding/entry/tests/test_serial_entry_trex.py b/ding/entry/tests/test_serial_entry_trex.py new file mode 100644 index 0000000..658d44c --- /dev/null +++ b/ding/entry/tests/test_serial_entry_trex.py @@ -0,0 +1,33 @@ +import pytest +from copy import deepcopy +import os +from easydict import EasyDict + +import torch + +from ding.entry import serial_pipeline +from ding.entry.serial_entry_trex 
import serial_pipeline_reward_model_trex +from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_ppo_offpolicy_config,\ + cartpole_trex_ppo_offpolicy_create_config +from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config,\ + cartpole_ppo_offpolicy_create_config +from ding.entry.application_entry_trex_collect_data import trex_collecting_data + + +@pytest.mark.unittest +def test_serial_pipeline_reward_model_trex(): + config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] + expert_policy = serial_pipeline(config, seed=0) + + config = [deepcopy(cartpole_trex_ppo_offpolicy_config), deepcopy(cartpole_trex_ppo_offpolicy_create_config)] + config[0].reward_model.offline_data_path = 'dizoo/classic_control/cartpole/config/cartpole_trex_offppo' + config[0].reward_model.offline_data_path = os.path.abspath(config[0].reward_model.offline_data_path) + config[0].reward_model.reward_model_path = config[0].reward_model.offline_data_path + '/cartpole.params' + config[0].reward_model.expert_model_path = './cartpole_ppo_offpolicy' + config[0].reward_model.expert_model_path = os.path.abspath(config[0].reward_model.expert_model_path) + args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'}) + trex_collecting_data(args=args) + try: + serial_pipeline_reward_model_trex(config, seed=0, max_iterations=1) + except Exception: + assert False, "pipeline fail" diff --git a/ding/entry/tests/test_serial_entry_trex_onpolicy.py b/ding/entry/tests/test_serial_entry_trex_onpolicy.py new file mode 100644 index 0000000..5820a71 --- /dev/null +++ b/ding/entry/tests/test_serial_entry_trex_onpolicy.py @@ -0,0 +1,31 @@ +import pytest +from copy import deepcopy +import os +from easydict import EasyDict + +import torch + +from ding.entry import serial_pipeline_onpolicy +from ding.entry.serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy +from dizoo.mujoco.config import hopper_ppo_default_config, hopper_ppo_create_default_config +from dizoo.mujoco.config import hopper_trex_ppo_default_config, hopper_trex_ppo_create_default_config +from ding.entry.application_entry_trex_collect_data import trex_collecting_data + + +@pytest.mark.unittest +def test_serial_pipeline_reward_model_trex(): + config = [deepcopy(hopper_ppo_default_config), deepcopy(hopper_ppo_create_default_config)] + expert_policy = serial_pipeline_onpolicy(config, seed=0, max_iterations=90) + + config = [deepcopy(hopper_trex_ppo_default_config), deepcopy(hopper_trex_ppo_create_default_config)] + config[0].reward_model.offline_data_path = 'dizoo/mujoco/config/hopper_trex_onppo' + config[0].reward_model.offline_data_path = os.path.abspath(config[0].reward_model.offline_data_path) + config[0].reward_model.reward_model_path = config[0].reward_model.offline_data_path + '/hopper.params' + config[0].reward_model.expert_model_path = './hopper_onppo' + config[0].reward_model.expert_model_path = os.path.abspath(config[0].reward_model.expert_model_path) + args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'}) + trex_collecting_data(args=args) + try: + serial_pipeline_reward_model_trex_onpolicy(config, seed=0, max_iterations=1) + except Exception: + assert False, "pipeline fail" diff --git a/ding/reward_model/__init__.py b/ding/reward_model/__init__.py index faf497b..30ca69a 100644 --- a/ding/reward_model/__init__.py +++ b/ding/reward_model/__init__.py @@ -4,6 +4,7 @@ from .pdeil_irl_model import 
PdeilRewardModel from .gail_irl_model import GailRewardModel from .pwil_irl_model import PwilRewardModel from .red_irl_model import RedRewardModel +from .trex_reward_model import TrexRewardModel # sparse reward from .her_reward_model import HerRewardModel # exploration diff --git a/ding/reward_model/base_reward_model.py b/ding/reward_model/base_reward_model.py index 1793fdf..08eaec5 100644 --- a/ding/reward_model/base_reward_model.py +++ b/ding/reward_model/base_reward_model.py @@ -90,7 +90,10 @@ def create_reward_model(cfg: dict, device: str, tb_logger: 'SummaryWriter') -> B cfg = copy.deepcopy(cfg) if 'import_names' in cfg: import_module(cfg.pop('import_names')) - reward_model_type = cfg.pop('type') + if hasattr(cfg, 'reward_model'): + reward_model_type = cfg.reward_model.pop('type') + else: + reward_model_type = cfg.pop('type') return REWARD_MODEL_REGISTRY.build(reward_model_type, cfg, device=device, tb_logger=tb_logger) diff --git a/ding/reward_model/trex_reward_model.py b/ding/reward_model/trex_reward_model.py new file mode 100644 index 0000000..16b654f --- /dev/null +++ b/ding/reward_model/trex_reward_model.py @@ -0,0 +1,431 @@ +from collections.abc import Iterable +from easydict import EasyDict +import numpy as np +import pickle +from copy import deepcopy +from typing import Tuple, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.distributions import Normal, Independent +from torch.distributions.categorical import Categorical + +from ding.utils import REWARD_MODEL_REGISTRY +from ding.model.template.q_learning import DQN +from ding.model.template.vac import VAC +from ding.model.template.qac import QAC +from ding.utils import SequenceType +from ding.model.common import FCEncoder +from ding.utils.data import offline_data_save_type +from ding.utils import build_logger +from dizoo.atari.envs.atari_wrappers import wrap_deepmind +from dizoo.mujoco.envs.mujoco_wrappers import wrap_mujoco + +from .base_reward_model import BaseRewardModel +from .rnd_reward_model import collect_states + + +class ConvEncoder(nn.Module): + r""" + Overview: + The ``Convolution Encoder`` used in models. Used to encoder raw 2-dim observation. + Interfaces: + ``__init__``, ``forward`` + """ + + def __init__( + self, + obs_shape: SequenceType, + hidden_size_list: SequenceType = [16, 16, 16, 16, 64, 1], + activation: Optional[nn.Module] = nn.LeakyReLU(), + norm_type: Optional[str] = None + ) -> None: + r""" + Overview: + Init the Convolution Encoder according to arguments. 
+ Arguments: + - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, some ``output size`` + - hidden_size_list (:obj:`SequenceType`): The collection of ``hidden_size`` + - activation (:obj:`nn.Module`): + The type of activation to use in the conv ``layers``, + if ``None`` then default set to ``nn.ReLU()`` + - norm_type (:obj:`str`): + The type of normalization to use, see ``ding.torch_utils.ResBlock`` for more details + """ + super(ConvEncoder, self).__init__() + self.obs_shape = obs_shape + self.act = activation + self.hidden_size_list = hidden_size_list + + layers = [] + kernel_size = [7, 5, 3, 3] + stride = [3, 2, 1, 1] + input_size = obs_shape[0] # in_channel + for i in range(len(kernel_size)): + layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i])) + layers.append(self.act) + input_size = hidden_size_list[i] + layers.append(nn.Flatten()) + self.main = nn.Sequential(*layers) + + flatten_size = self._get_flatten_size() + self.mid = nn.Sequential( + nn.Linear(flatten_size, hidden_size_list[-2]), self.act, + nn.Linear(hidden_size_list[-2], hidden_size_list[-1]) + ) + + def _get_flatten_size(self) -> int: + r""" + Overview: + Get the encoding size after ``self.main`` to get the number of ``in-features`` to feed to ``nn.Linear``. + Arguments: + - x (:obj:`torch.Tensor`): Encoded Tensor after ``self.main`` + Returns: + - outputs (:obj:`torch.Tensor`): Size int, also number of in-feature + """ + test_data = torch.randn(1, *self.obs_shape) + with torch.no_grad(): + output = self.main(test_data) + return output.shape[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + Overview: + Return embedding tensor of the env observation + Arguments: + - x (:obj:`torch.Tensor`): Env raw observation + Returns: + - outputs (:obj:`torch.Tensor`): Embedding tensor + """ + x = self.main(x) + x = self.mid(x) + return x + + +class TrexModel(nn.Module): + + def __init__(self, obs_shape): + super(TrexModel, self).__init__() + if isinstance(obs_shape, int) or len(obs_shape) == 1: + self.encoder = nn.Sequential(FCEncoder(obs_shape, [512, 64]), nn.Linear(64, 1)) + # Conv Encoder + elif len(obs_shape) == 3: + self.encoder = ConvEncoder(obs_shape) + else: + raise KeyError( + "not support obs_shape for pre-defined encoder: {}, please customize your own Trex model". 
+                format(obs_shape)
+            )
+
+    def cum_return(self, traj: torch.Tensor, mode: str = 'sum') -> Tuple[torch.Tensor, torch.Tensor]:
+        '''calculate cumulative return of trajectory'''
+        r = self.encoder(traj)
+        if mode == 'sum':
+            sum_rewards = torch.sum(r)
+            sum_abs_rewards = torch.sum(torch.abs(r))
+            # print(sum_rewards)
+            # print(r)
+            return sum_rewards, sum_abs_rewards
+        elif mode == 'batch':
+            # print(r)
+            return r, torch.abs(r)
+        else:
+            raise KeyError("unsupported mode: {}, please choose mode='sum' or mode='batch'".format(mode))
+
+    def forward(self, traj_i: torch.Tensor, traj_j: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        '''compute cumulative return for each trajectory and return logits'''
+        cum_r_i, abs_r_i = self.cum_return(traj_i)
+        cum_r_j, abs_r_j = self.cum_return(traj_j)
+        return torch.cat((cum_r_i.unsqueeze(0), cum_r_j.unsqueeze(0)), 0), abs_r_i + abs_r_j
+
+
+@REWARD_MODEL_REGISTRY.register('trex')
+class TrexRewardModel(BaseRewardModel):
+    """
+    Overview:
+        The Trex reward model class (https://arxiv.org/pdf/1904.06387.pdf)
+    Interface:
+        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, ``__init__``
+    """
+    config = dict(
+        type='trex',
+        learning_rate=1e-5,
+        update_per_collect=100,
+        batch_size=64,
+        target_new_data_count=64,
+        hidden_size=128,
+    )
+
+    def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
+        """
+        Overview:
+            Initialize ``self.`` See ``help(type(self))`` for accurate signature.
+        Arguments:
+            - config (:obj:`EasyDict`): Training config
+            - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
+            - tb_logger (:obj:`SummaryWriter`): Logger, a 'SummaryWriter' by default, used for model summary
+        """
+        super(TrexRewardModel, self).__init__()
+        self.cfg = config
+        assert device in ["cpu", "cuda"] or "cuda" in device
+        self.device = device
+        self.tb_logger = tb_logger
+        self.reward_model = TrexModel(self.cfg.policy.model.get('obs_shape'))
+        self.reward_model.to(self.device)
+        self.pre_expert_data = []
+        self.train_data = []
+        self.expert_data_loader = None
+        self.opt = optim.Adam(self.reward_model.parameters(), config.reward_model.learning_rate)
+        self.train_iter = 0
+        self.learning_returns = []
+        self.learning_rewards = []
+        self.training_obs = []
+        self.training_labels = []
+        self.num_trajs = 0  # number of downsampled full trajectories
+        self.num_snippets = 6000  # number of short subtrajectories to sample
+        # minimum length of the short subtrajectories to sample
+        self.min_snippet_length = config.reward_model.min_snippet_length
+        # maximum length of the short subtrajectories to sample
+        self.max_snippet_length = config.reward_model.max_snippet_length
+        self.l1_reg = 0
+        self.data_for_save = {}
+        self._logger, self._tb_logger = build_logger(
+            path='./{}/log/{}'.format(self.cfg.exp_name, 'trex_reward_model'), name='trex_reward_model'
+        )
+        self.load_expert_data()
+
+    def load_expert_data(self) -> None:
+        """
+        Overview:
+            Load the expert episode data and learning statistics from ``config.reward_model.offline_data_path``
+        Effects:
+            This is a side effect function which updates the expert data attributes \
+                (i.e. ``self.pre_expert_data``, ``self.learning_returns``, ``self.learning_rewards`` \
+                and ``self.checkpoints``)
+        """
+        with open(self.cfg.reward_model.offline_data_path + '/episodes_data.pkl', 'rb') as f:
+            self.pre_expert_data = pickle.load(f)
+        with open(self.cfg.reward_model.offline_data_path + '/learning_returns.pkl', 'rb') as f:
+            self.learning_returns = pickle.load(f)
+        with open(self.cfg.reward_model.offline_data_path + '/learning_rewards.pkl', 'rb') as f:
+            self.learning_rewards = pickle.load(f)
+        with open(self.cfg.reward_model.offline_data_path + '/checkpoints.pkl', 'rb') as f:
+            self.checkpoints = pickle.load(f)
+        self.training_obs, self.training_labels = self.create_training_data()
+        self._logger.info("num_training_obs: {}".format(len(self.training_obs)))
+        self._logger.info("num_labels: {}".format(len(self.training_labels)))
+
+    def create_training_data(self):
+        demonstrations = self.pre_expert_data
+        num_trajs = self.num_trajs
+        num_snippets = self.num_snippets
+        min_snippet_length = self.min_snippet_length
+        max_snippet_length = self.max_snippet_length
+
+        demo_lengths = [len(d) for d in demonstrations]
+        self._logger.info("demo_lengths: {}".format(demo_lengths))
+        max_snippet_length = min(np.min(demo_lengths), max_snippet_length)
+        self._logger.info("min snippet length: {}".format(min_snippet_length))
+        self._logger.info("max snippet length: {}".format(max_snippet_length))
+
+        self._logger.info(len(self.learning_returns))
+        self._logger.info(len(demonstrations))
+        self._logger.info("learning returns: {}".format([a[0] for a in zip(self.learning_returns, demonstrations)]))
+        demonstrations = [x for _, x in sorted(zip(self.learning_returns, demonstrations), key=lambda pair: pair[0])]
+        sorted_returns = sorted(self.learning_returns)
+        self._logger.info("sorted learning returns: {}".format(sorted_returns))
+
+        # collect training data
+        max_traj_length = 0
+        num_demos = len(demonstrations)
+
+        # add full trajs (for use on Enduro)
+        si = np.random.randint(6, size=num_trajs)
+        sj = np.random.randint(6, size=num_trajs)
+        step = np.random.randint(3, 7, size=num_trajs)
+        for n in range(num_trajs):
+            ti = 0
+            tj = 0
+            # only compare two different demonstrations
+            while (ti == tj):
+                # pick two random demonstrations
+                ti = np.random.randint(num_demos)
+                tj = np.random.randint(num_demos)
+            # create random partial trajs by finding random start frame and random skip frame
+            traj_i = demonstrations[ti][si[n]::step[n]]  # slice(start,stop,step)
+            traj_j = demonstrations[tj][sj[n]::step[n]]
+
+            label = int(ti <= tj)
+
+            self.training_obs.append((traj_i, traj_j))
+            self.training_labels.append(label)
+            max_traj_length = max(max_traj_length, len(traj_i), len(traj_j))
+
+        # fixed size snippets with progress prior
+        rand_length = np.random.randint(min_snippet_length, max_snippet_length, size=num_snippets)
+        for n in range(num_snippets):
+            ti = 0
+            tj = 0
+            # only compare two different demonstrations
+            while (ti == tj):
+                # pick two random demonstrations
+                ti = np.random.randint(num_demos)
+                tj = np.random.randint(num_demos)
+            # create random snippets
+            # find min length of both demos to ensure the snippet from the better demo starts
+            # no earlier than the snippet chosen from the less preferred demo
+            min_length = min(len(demonstrations[ti]), len(demonstrations[tj]))
+            if ti < tj:  # pick tj snippet to be later than ti
+                ti_start = np.random.randint(min_length - rand_length[n] + 1)
+                # print(ti_start, len(demonstrations[tj]))
+                tj_start = np.random.randint(ti_start, len(demonstrations[tj]) - rand_length[n] + 1)
+            else:  # ti is better so pick later snippet in ti
+                tj_start = np.random.randint(min_length - rand_length[n] + 1)
+                # print(tj_start, len(demonstrations[ti]))
+                ti_start = np.random.randint(tj_start, len(demonstrations[ti]) - rand_length[n] + 1)
+            traj_i = demonstrations[ti][ti_start:ti_start + rand_length[n]:2]  # skip every other framestack to reduce size
+            traj_j = demonstrations[tj][tj_start:tj_start + rand_length[n]:2]
+
+            max_traj_length = max(max_traj_length, len(traj_i), len(traj_j))
+            label = int(ti <= tj)
+            self.training_obs.append((traj_i, traj_j))
+            self.training_labels.append(label)
+
+        self._logger.info("maximum traj length: {}".format(max_traj_length))
+        return self.training_obs, self.training_labels
+
+    def train(self):
+        # check if gpu available
+        device = self.device  # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        # Assume that we are on a CUDA machine, then this should print a CUDA device:
+        self._logger.info("device: {}".format(device))
+        training_inputs, training_outputs = self.training_obs, self.training_labels
+        loss_criterion = nn.CrossEntropyLoss()
+
+        cum_loss = 0.0
+        training_data = list(zip(training_inputs, training_outputs))
+        for epoch in range(self.cfg.reward_model.update_per_collect):  # todo
+            np.random.shuffle(training_data)
+            training_obs, training_labels = zip(*training_data)
+            for i in range(len(training_labels)):
+
+                # traj_i and traj_j have the same length; however, the length changes as i increases
+                traj_i, traj_j = training_obs[i]  # traj_i is a list of array generated by env.step
+                traj_i = np.array(traj_i)
+                traj_j = np.array(traj_j)
+                traj_i = torch.from_numpy(traj_i).float().to(device)
+                traj_j = torch.from_numpy(traj_j).float().to(device)
+
+                # training_labels[i] is a boolean integer: 0 or 1
+                labels = torch.tensor([training_labels[i]]).to(device)
+
+                # forward + backward + zero out gradient + optimize
+                outputs, abs_rewards = self.reward_model.forward(traj_i, traj_j)
+                outputs = outputs.unsqueeze(0)
+                loss = loss_criterion(outputs, labels) + self.l1_reg * abs_rewards
+                self.opt.zero_grad()
+                loss.backward()
+                self.opt.step()
+
+                # print stats to see if learning
+                item_loss = loss.item()
+                cum_loss += item_loss
+                if i % 100 == 99:
+                    # print(i)
+                    self._logger.info("epoch {}:{} loss {}".format(epoch, i, cum_loss))
+                    self._logger.info("abs_rewards: {}".format(abs_rewards))
+                    cum_loss = 0.0
+                    self._logger.info("checkpointing")
+                    torch.save(self.reward_model.state_dict(), self.cfg.reward_model.reward_model_path)
+        torch.save(self.reward_model.state_dict(), self.cfg.reward_model.reward_model_path)
+        self._logger.info("finished training")
+        # print out predicted cumulative returns and actual returns
+        sorted_returns = sorted(self.learning_returns)
+        with torch.no_grad():
+            pred_returns = [self.predict_traj_return(self.reward_model, traj) for traj in self.pre_expert_data]
+        for i, p in enumerate(pred_returns):
+            self._logger.info("{} {} {}".format(i, p, sorted_returns[i]))
+        info = {
+            # "demo_length": [len(d) for d in self.pre_expert_data],
+            # "min_snippet_length": self.min_snippet_length,
+            # "max_snippet_length": min(np.min([len(d) for d in self.pre_expert_data]), self.max_snippet_length),
+            # "len_num_training_obs": len(self.training_obs),
+            # "len_num_labels": len(self.training_labels),
+            "accuracy": self.calc_accuracy(self.reward_model, self.training_obs, self.training_labels),
+        }
+        self._logger.info(
+            "accuracy and comparison:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))
+        )
+
+    def predict_traj_return(self, net, traj):
+        device = self.device
+        # torch.set_printoptions(precision=20)
+        # torch.use_deterministic_algorithms(True)
+        with torch.no_grad():
+            rewards_from_obs = net.cum_return(
+                torch.from_numpy(np.array(traj)).float().to(device), mode='batch'
+            )[0].squeeze().tolist()
+            # rewards_from_obs1 = net.cum_return(torch.from_numpy(np.array([traj[0]])).float().to(device))[0].item()
+            # different precision
+        return sum(rewards_from_obs)  # rewards_from_obs is a list of floats
+
+    def calc_accuracy(self, reward_network, training_inputs, training_outputs):
+        device = self.device
+        loss_criterion = nn.CrossEntropyLoss()
+        num_correct = 0.
+        with torch.no_grad():
+            for i in range(len(training_inputs)):
+                label = training_outputs[i]
+                traj_i, traj_j = training_inputs[i]
+                traj_i = np.array(traj_i)
+                traj_j = np.array(traj_j)
+                traj_i = torch.from_numpy(traj_i).float().to(device)
+                traj_j = torch.from_numpy(traj_j).float().to(device)
+
+                # forward to get logits
+                outputs, abs_return = reward_network.forward(traj_i, traj_j)
+                _, pred_label = torch.max(outputs, 0)
+                if pred_label.item() == label:
+                    num_correct += 1.
+        return num_correct / len(training_inputs)
+
+    def estimate(self, data: list) -> None:
+        """
+        Overview:
+            Estimate reward by rewriting the reward key in each row of the data.
+        Arguments:
+            - data (:obj:`list`): the list of data used for estimation, with at least \
+                ``obs`` and ``action`` keys.
+        Effects:
+            - This is a side effect function which updates the reward values in place.
+        """
+        res = collect_states(data)
+        res = torch.stack(res).to(self.device)
+        with torch.no_grad():
+            sum_rewards, sum_abs_rewards = self.reward_model.cum_return(res, mode='batch')
+
+        for item, rew in zip(data, sum_rewards):  # TODO: optimise this loop as well?
+            item['reward'] = rew
+
+    def collect_data(self, data: list) -> None:
+        """
+        Overview:
+            Collecting training data. TREX loads its demonstration data offline in \
+                ``load_expert_data``, so no extra data needs to be collected here and this method is a no-op.
+        Arguments:
+            - data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc)
+        """
+        pass
+
+    def clear_data(self) -> None:
+        """
+        Overview:
+            Clearing training data.
\ + This is a side effect function which clears the data attribute in ``self`` + """ + self.training_obs.clear() + self.training_labels.clear() diff --git a/dizoo/atari/config/serial/pong/pong_trex_offppo_config.py b/dizoo/atari/config/serial/pong/pong_trex_offppo_config.py new file mode 100644 index 0000000..c28e4f3 --- /dev/null +++ b/dizoo/atari/config/serial/pong/pong_trex_offppo_config.py @@ -0,0 +1,78 @@ +from copy import deepcopy +from easydict import EasyDict + +pong_trex_ppo_config = dict( + exp_name='pong_trex_offppo', + env=dict( + collector_env_num=16, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=20, + env_id='PongNoFrameskip-v4', + frame_stack=4, + manager=dict(shared_memory=False, ) + ), + reward_model=dict( + type='trex', + algo_for_model='ppo', + env_id='PongNoFrameskip-v4', + min_snippet_length=50, + max_snippet_length=100, + checkpoint_min=0, + checkpoint_max=100, + checkpoint_step=100, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + ./pong.params', + offline_data_path='abs data path', + ), + policy=dict( + cuda=True, + random_collect_size=2048, + model=dict( + obs_shape=[4, 84, 84], + action_shape=6, + encoder_hidden_size_list=[64, 64, 128], + actor_head_hidden_size=128, + critic_head_hidden_size=128, + critic_head_layer_num=1, # Todo, to solve generality problem + ), + learn=dict( + update_per_collect=24, + batch_size=128, + # (bool) Whether to normalize advantage. Default to False. + adv_norm=False, + learning_rate=0.0002, + # (float) loss weight of the value network, the weight of policy network is set to 1 + value_weight=0.5, + # (float) loss weight of the entropy regularization, the weight of policy network is set to 1 + entropy_weight=0.01, + clip_ratio=0.1, + ), + collect=dict( + # (int) collect n_sample data, train model n_iteration times + n_sample=1024, + # (float) the trade-off factor lambda to balance 1step td and mc + gae_lambda=0.95, + discount_factor=0.99, + ), + eval=dict(evaluator=dict(eval_freq=1000, )), + other=dict(replay_buffer=dict( + replay_buffer_size=100000, + max_use=5, + ), ), + ), +) +main_config = EasyDict(pong_trex_ppo_config) + +pong_trex_ppo_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + # env_manager=dict(type='subprocess'), + env_manager=dict(type='base'), + policy=dict(type='ppo_offpolicy'), +) +create_config = EasyDict(pong_trex_ppo_create_config) diff --git a/dizoo/atari/config/serial/pong/pong_trex_sql_config.py b/dizoo/atari/config/serial/pong/pong_trex_sql_config.py new file mode 100644 index 0000000..eaac65e --- /dev/null +++ b/dizoo/atari/config/serial/pong/pong_trex_sql_config.py @@ -0,0 +1,64 @@ +from copy import deepcopy +from easydict import EasyDict + +pong_trex_sql_config = dict( + exp_name='pong_trex_sql', + env=dict( + collector_env_num=8, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=20, + env_id='PongNoFrameskip-v4', + frame_stack=4, + manager=dict(shared_memory=False, ) + ), + reward_model=dict( + type='trex', + algo_for_model='sql', + env_id='PongNoFrameskip-v4', + min_snippet_length=50, + max_snippet_length=100, + checkpoint_min=10000, + checkpoint_max=50000, + checkpoint_step=10000, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + ./pong.params', + offline_data_path='abs data path', + ), + policy=dict( + cuda=False, + priority=False, + model=dict( + obs_shape=[4, 84, 84], + 
action_shape=6, + encoder_hidden_size_list=[128, 128, 512], + ), + nstep=1, + discount_factor=0.99, + learn=dict(update_per_collect=10, batch_size=32, learning_rate=0.0001, target_update_freq=500, alpha=0.12), + collect=dict(n_sample=96, ), + other=dict( + eps=dict( + type='exp', + start=1., + end=0.05, + decay=250000, + ), + replay_buffer=dict(replay_buffer_size=100000, ), + ), + ), +) +pong_trex_sql_config = EasyDict(pong_trex_sql_config) +main_config = pong_trex_sql_config +pong_trex_sql_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='base', force_reproducibility=True), + policy=dict(type='sql'), +) +pong_trex_sql_create_config = EasyDict(pong_trex_sql_create_config) +create_config = pong_trex_sql_create_config diff --git a/dizoo/atari/config/serial/qbert/qbert_trex_dqn_config.py b/dizoo/atari/config/serial/qbert/qbert_trex_dqn_config.py new file mode 100644 index 0000000..8675ea4 --- /dev/null +++ b/dizoo/atari/config/serial/qbert/qbert_trex_dqn_config.py @@ -0,0 +1,70 @@ +from copy import deepcopy +from easydict import EasyDict + +qbert_trex_dqn_config = dict( + exp_name='qbert_trex_dqn', + env=dict( + collector_env_num=8, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=30000, + env_id='QbertNoFrameskip-v4', + frame_stack=4, + manager=dict(shared_memory=False, ) + ), + reward_model=dict( + type='trex', + algo_for_model='dqn', + env_id='QbertNoFrameskip-v4', + min_snippet_length=30, + max_snippet_length=100, + checkpoint_min=0, + checkpoint_max=100, + checkpoint_step=100, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + ./qbert.params', + offline_data_path='abs data path', + ), + policy=dict( + cuda=True, + priority=False, + model=dict( + obs_shape=[4, 84, 84], + action_shape=6, + encoder_hidden_size_list=[128, 128, 512], + ), + nstep=1, + discount_factor=0.99, + learn=dict( + update_per_collect=10, + batch_size=32, + learning_rate=0.0001, + target_update_freq=500, + ), + collect=dict(n_sample=100, ), + eval=dict(evaluator=dict(eval_freq=4000, )), + other=dict( + eps=dict( + type='exp', + start=1., + end=0.05, + decay=1000000, + ), + replay_buffer=dict(replay_buffer_size=400000, ), + ), + ), +) +qbert_trex_dqn_config = EasyDict(qbert_trex_dqn_config) +main_config = qbert_trex_dqn_config +qbert_trex_dqn_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='dqn'), +) +qbert_trex_dqn_create_config = EasyDict(qbert_trex_dqn_create_config) +create_config = qbert_trex_dqn_create_config diff --git a/dizoo/atari/config/serial/qbert/qbert_trex_offppo_config.py b/dizoo/atari/config/serial/qbert/qbert_trex_offppo_config.py new file mode 100644 index 0000000..be9cc83 --- /dev/null +++ b/dizoo/atari/config/serial/qbert/qbert_trex_offppo_config.py @@ -0,0 +1,76 @@ +from copy import deepcopy +from easydict import EasyDict + +qbert_trex_ppo_config = dict( + exp_name='qbert_trex_offppo', + env=dict( + collector_env_num=16, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=10000000000, + env_id='QbertNoFrameskip-v4', + frame_stack=4, + manager=dict(shared_memory=False, ) + ), + reward_model=dict( + type='trex', + algo_for_model='ppo', + env_id='QbertNoFrameskip-v4', + min_snippet_length=30, + max_snippet_length=100, + checkpoint_min=0, + checkpoint_max=100, + checkpoint_step=100, + learning_rate=1e-5, + update_per_collect=1, + 
expert_model_path='abs model path', + reward_model_path='abs data path + ./qbert.params', + offline_data_path='abs data path', + ), + policy=dict( + cuda=True, + model=dict( + obs_shape=[4, 84, 84], + action_shape=6, + encoder_hidden_size_list=[32, 64, 64, 128], + actor_head_hidden_size=128, + critic_head_hidden_size=128, + critic_head_layer_num=2, + ), + learn=dict( + update_per_collect=24, + batch_size=128, + # (bool) Whether to normalize advantage. Default to False. + adv_norm=False, + learning_rate=0.0001, + # (float) loss weight of the value network, the weight of policy network is set to 1 + value_weight=1.0, + # (float) loss weight of the entropy regularization, the weight of policy network is set to 1 + entropy_weight=0.03, + clip_ratio=0.1, + ), + collect=dict( + # (int) collect n_sample data, train model n_iteration times + n_sample=1024, + # (float) the trade-off factor lambda to balance 1step td and mc + gae_lambda=0.95, + discount_factor=0.99, + ), + eval=dict(evaluator=dict(eval_freq=1000, )), + other=dict(replay_buffer=dict( + replay_buffer_size=100000, + max_use=3, + ), ), + ), +) +main_config = EasyDict(qbert_trex_ppo_config) + +qbert_trex_ppo_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='ppo_offpolicy'), +) +create_config = EasyDict(qbert_trex_ppo_create_config) diff --git a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_dqn_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_dqn_config.py new file mode 100644 index 0000000..86b6498 --- /dev/null +++ b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_dqn_config.py @@ -0,0 +1,70 @@ +from copy import deepcopy +from easydict import EasyDict + +space_invaders_trex_dqn_config = dict( + exp_name='space_invaders_trex_dqn', + env=dict( + collector_env_num=8, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=10000000000, + env_id='SpaceInvadersNoFrameskip-v4', + frame_stack=4, + manager=dict(shared_memory=False, ) + ), + reward_model=dict( + type='trex', + algo_for_model='dqn', + env_id='SpaceInvadersNoFrameskip-v4', + min_snippet_length=30, + max_snippet_length=100, + checkpoint_min=0, + checkpoint_max=100, + checkpoint_step=100, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + ./spaceinvaders.params', + offline_data_path='abs data path', + ), + policy=dict( + cuda=True, + priority=False, + model=dict( + obs_shape=[4, 84, 84], + action_shape=6, + encoder_hidden_size_list=[128, 128, 512], + ), + nstep=1, + discount_factor=0.99, + learn=dict( + update_per_collect=10, + batch_size=32, + learning_rate=0.0001, + target_update_freq=500, + ), + collect=dict(n_sample=100, ), + eval=dict(evaluator=dict(eval_freq=100, )), + other=dict( + eps=dict( + type='exp', + start=1., + end=0.05, + decay=1000000, + ), + replay_buffer=dict(replay_buffer_size=400000, ), + ), + ), +) +space_invaders_trex_dqn_config = EasyDict(space_invaders_trex_dqn_config) +main_config = space_invaders_trex_dqn_config +space_invaders_trex_dqn_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='dqn'), +) +space_invaders_trex_dqn_create_config = EasyDict(space_invaders_trex_dqn_create_config) +create_config = space_invaders_trex_dqn_create_config diff --git a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_offppo_config.py 
b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_offppo_config.py new file mode 100644 index 0000000..8f7bcb2 --- /dev/null +++ b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_offppo_config.py @@ -0,0 +1,76 @@ +from copy import deepcopy +from easydict import EasyDict + +space_invaders_trex_ppo_config = dict( + exp_name='space_invaders_trex_offppo', + env=dict( + collector_env_num=16, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=10000000000, + env_id='SpaceInvadersNoFrameskip-v4', + frame_stack=4, + manager=dict(shared_memory=False, ) + ), + reward_model=dict( + type='trex', + algo_for_model='ppo', + env_id='SpaceInvadersNoFrameskip-v4', + min_snippet_length=30, + max_snippet_length=100, + checkpoint_min=0, + checkpoint_max=100, + checkpoint_step=100, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + ./spaceinvaders.params', + offline_data_path='abs data path', + ), + policy=dict( + cuda=True, + model=dict( + obs_shape=[4, 84, 84], + action_shape=6, + encoder_hidden_size_list=[32, 64, 64, 128], + actor_head_hidden_size=128, + critic_head_hidden_size=128, + critic_head_layer_num=2, + ), + learn=dict( + update_per_collect=24, + batch_size=128, + # (bool) Whether to normalize advantage. Default to False. + adv_norm=False, + learning_rate=0.0001, + # (float) loss weight of the value network, the weight of policy network is set to 1 + value_weight=1.0, + # (float) loss weight of the entropy regularization, the weight of policy network is set to 1 + entropy_weight=0.03, + clip_ratio=0.1, + ), + collect=dict( + # (int) collect n_sample data, train model n_iteration times + n_sample=1024, + # (float) the trade-off factor lambda to balance 1step td and mc + gae_lambda=0.95, + discount_factor=0.99, + ), + eval=dict(evaluator=dict(eval_freq=1000, )), + other=dict(replay_buffer=dict( + replay_buffer_size=100000, + max_use=5, + ), ), + ), +) +main_config = EasyDict(space_invaders_trex_ppo_config) + +space_invaders_trex_ppo_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='ppo_offpolicy'), +) +create_config = EasyDict(space_invaders_trex_ppo_create_config) diff --git a/dizoo/box2d/lunarlander/config/__init__.py b/dizoo/box2d/lunarlander/config/__init__.py index b6a6df1..bee16fe 100644 --- a/dizoo/box2d/lunarlander/config/__init__.py +++ b/dizoo/box2d/lunarlander/config/__init__.py @@ -2,3 +2,5 @@ from .lunarlander_dqn_config import lunarlander_dqn_default_config, lunarlander_ from .lunarlander_dqn_gail_config import lunarlander_dqn_gail_create_config, lunarlander_dqn_gail_default_config from .lunarlander_ppo_config import lunarlander_ppo_config from .lunarlander_qrdqn_config import lunarlander_qrdqn_config, lunarlander_qrdqn_create_config +from .lunarlander_trex_dqn_config import lunarlander_trex_dqn_default_config, lunarlander_trex_dqn_create_config +from .lunarlander_trex_offppo_config import lunarlander_trex_ppo_config, lunarlander_trex_ppo_create_config diff --git a/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py b/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py index bcdda8b..4d3af31 100644 --- a/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py +++ b/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py @@ -2,7 +2,9 @@ from easydict import EasyDict from ding.entry import serial_pipeline lunarlander_ppo_config = dict( + exp_name='lunarlander_ppo', env=dict( + 
manager=dict(shared_memory=True, force_reproducibility=True), collector_env_num=8, evaluator_env_num=5, n_evaluator_episode=5, diff --git a/dizoo/box2d/lunarlander/config/lunarlander_trex_dqn_config.py b/dizoo/box2d/lunarlander/config/lunarlander_trex_dqn_config.py new file mode 100644 index 0000000..947ecfe --- /dev/null +++ b/dizoo/box2d/lunarlander/config/lunarlander_trex_dqn_config.py @@ -0,0 +1,86 @@ +from easydict import EasyDict + +nstep = 1 +lunarlander_trex_dqn_default_config = dict( + exp_name='lunarlander_trex_dqn', + env=dict( + # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess' + manager=dict(shared_memory=True, force_reproducibility=True), + # Env number respectively for collector and evaluator. + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=200, + ), + reward_model=dict( + type='trex', + algo_for_model='dqn', + env_id='LunarLander-v2', + min_snippet_length=30, + max_snippet_length=100, + checkpoint_min=100, + checkpoint_max=900, + checkpoint_step=100, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + ./lunarlander.params', + offline_data_path='abs data path', + ), + policy=dict( + # Whether to use cuda for network. + cuda=False, + model=dict( + obs_shape=8, + action_shape=4, + encoder_hidden_size_list=[512, 64], + # Whether to use dueling head. + dueling=True, + ), + # Reward's future discount factor, aka. gamma. + discount_factor=0.99, + # How many steps in td error. + nstep=nstep, + # learn_mode config + learn=dict( + update_per_collect=10, + batch_size=64, + learning_rate=0.001, + # Frequency of target network update. + target_update_freq=100, + ), + # collect_mode config + collect=dict( + # You can use either "n_sample" or "n_episode" in collector.collect. + # Get "n_sample" samples per collect. + n_sample=64, + # Cut trajectories into pieces with length "unroll_len". + unroll_len=1, + ), + # command_mode config + other=dict( + # Epsilon greedy with decay. + eps=dict( + # Decay type. Support ['exp', 'linear']. 
+ type='exp', + start=0.95, + end=0.1, + decay=50000, + ), + replay_buffer=dict(replay_buffer_size=100000, ) + ), + ), +) +lunarlander_trex_dqn_default_config = EasyDict(lunarlander_trex_dqn_default_config) +main_config = lunarlander_trex_dqn_default_config + +lunarlander_trex_dqn_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='dqn'), +) +lunarlander_trex_dqn_create_config = EasyDict(lunarlander_trex_dqn_create_config) +create_config = lunarlander_trex_dqn_create_config diff --git a/dizoo/box2d/lunarlander/config/lunarlander_trex_offppo_config.py b/dizoo/box2d/lunarlander/config/lunarlander_trex_offppo_config.py new file mode 100644 index 0000000..70d90c6 --- /dev/null +++ b/dizoo/box2d/lunarlander/config/lunarlander_trex_offppo_config.py @@ -0,0 +1,63 @@ +from easydict import EasyDict + +lunarlander_trex_ppo_config = dict( + exp_name='lunarlander_trex_offppo', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=200, + ), + reward_model=dict( + type='trex', + algo_for_model='ppo', + env_id='LunarLander-v2', + min_snippet_length=30, + max_snippet_length=100, + checkpoint_min=1000, + checkpoint_max=9000, + checkpoint_step=1000, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + ./lunarlander.params', + offline_data_path='abs data path', + ), + policy=dict( + cuda=True, + model=dict( + obs_shape=8, + action_shape=4, + ), + learn=dict( + update_per_collect=4, + batch_size=64, + learning_rate=0.001, + value_weight=0.5, + entropy_weight=0.01, + clip_ratio=0.2, + nstep=1, + nstep_return=False, + adv_norm=True, + ), + collect=dict( + n_sample=128, + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.95, + ), + ), +) +lunarlander_trex_ppo_config = EasyDict(lunarlander_trex_ppo_config) +main_config = lunarlander_trex_ppo_config +lunarlander_trex_ppo_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='ppo_offpolicy'), +) +lunarlander_trex_ppo_create_config = EasyDict(lunarlander_trex_ppo_create_config) +create_config = lunarlander_trex_ppo_create_config diff --git a/dizoo/classic_control/cartpole/config/__init__.py b/dizoo/classic_control/cartpole/config/__init__.py index 16943bf..623fa94 100644 --- a/dizoo/classic_control/cartpole/config/__init__.py +++ b/dizoo/classic_control/cartpole/config/__init__.py @@ -15,4 +15,7 @@ from .cartpole_dqfd_config import cartpole_dqfd_config, cartpole_dqfd_create_con from .cartpole_sqil_config import cartpole_sqil_config, cartpole_sqil_create_config from .cartpole_sql_config import cartpole_sql_config, cartpole_sql_create_config from .cartpole_dqn_gail_config import cartpole_dqn_gail_config, cartpole_dqn_gail_create_config +from .cartpole_gcl_config import cartpole_gcl_ppo_onpolicy_config, cartpole_gcl_ppo_onpolicy_create_config +from .cartpole_trex_dqn_config import cartpole_trex_dqn_config, cartpole_trex_dqn_create_config +from .cartpole_trex_offppo_config import cartpole_trex_ppo_offpolicy_config, cartpole_trex_ppo_offpolicy_create_config # from .cartpole_ppo_default_loader import cartpole_ppo_default_loader diff --git a/dizoo/classic_control/cartpole/config/cartpole_gcl_config.py 
b/dizoo/classic_control/cartpole/config/cartpole_gcl_config.py index f94604a..71972d9 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_gcl_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_gcl_config.py @@ -1,7 +1,6 @@ from easydict import EasyDict -from ding.entry import serial_pipeline_guided_cost -cartpole_ppo_offpolicy_config = dict( +cartpole_gcl_ppo_onpolicy_config = dict( exp_name='cartpole_guided_cost', env=dict( collector_env_num=8, @@ -52,9 +51,9 @@ cartpole_ppo_offpolicy_config = dict( ), ), ) -cartpole_ppo_offpolicy_config = EasyDict(cartpole_ppo_offpolicy_config) -main_config = cartpole_ppo_offpolicy_config -cartpole_ppo_offpolicy_create_config = dict( +cartpole_gcl_ppo_onpolicy_config = EasyDict(cartpole_gcl_ppo_onpolicy_config) +main_config = cartpole_gcl_ppo_onpolicy_config +cartpole_gcl_ppo_onpolicy_create_config = dict( env=dict( type='cartpole', import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], @@ -63,8 +62,5 @@ cartpole_ppo_offpolicy_create_config = dict( policy=dict(type='ppo'), reward_model=dict(type='guided_cost'), ) -cartpole_ppo_offpolicy_create_config = EasyDict(cartpole_ppo_offpolicy_create_config) -create_config = cartpole_ppo_offpolicy_create_config - -if __name__ == "__main__": - serial_pipeline_guided_cost([main_config, create_config], seed=0) +cartpole_gcl_ppo_onpolicy_create_config = EasyDict(cartpole_gcl_ppo_onpolicy_create_config) +create_config = cartpole_gcl_ppo_onpolicy_create_config diff --git a/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py b/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py new file mode 100644 index 0000000..268f2d9 --- /dev/null +++ b/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py @@ -0,0 +1,67 @@ +from easydict import EasyDict + +cartpole_trex_dqn_config = dict( + exp_name='cartpole_trex_dqn', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=195, + replay_path='cartpole_dqn/video', + ), + reward_model=dict( + type='trex', + algo_for_model='dqn', + env_id='CartPole-v0', + min_snippet_length=5, + max_snippet_length=100, + checkpoint_min=0, + checkpoint_max=500, + checkpoint_step=100, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + ./cartpole.params', + offline_data_path='abs data path', + ), + policy=dict( + load_path='', + cuda=False, + model=dict( + obs_shape=4, + action_shape=2, + encoder_hidden_size_list=[128, 128, 64], + dueling=True, + ), + nstep=1, + discount_factor=0.97, + learn=dict( + batch_size=64, + learning_rate=0.001, + ), + collect=dict(n_sample=8), + eval=dict(evaluator=dict(eval_freq=40, )), + other=dict( + eps=dict( + type='exp', + start=0.95, + end=0.1, + decay=10000, + ), + replay_buffer=dict(replay_buffer_size=20000, ), + ), + ), +) +cartpole_trex_dqn_config = EasyDict(cartpole_trex_dqn_config) +main_config = cartpole_trex_dqn_config +cartpole_trex_dqn_create_config = dict( + env=dict( + type='cartpole', + import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='dqn'), +) +cartpole_trex_dqn_create_config = EasyDict(cartpole_trex_dqn_create_config) +create_config = cartpole_trex_dqn_create_config diff --git a/dizoo/classic_control/cartpole/config/cartpole_trex_offppo_config.py b/dizoo/classic_control/cartpole/config/cartpole_trex_offppo_config.py new file mode 
100644 index 0000000..e343210 --- /dev/null +++ b/dizoo/classic_control/cartpole/config/cartpole_trex_offppo_config.py @@ -0,0 +1,65 @@ +from easydict import EasyDict + +cartpole_trex_ppo_offpolicy_config = dict( + exp_name='cartpole_trex_offppo', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=195, + ), + reward_model=dict( + type='trex', + algo_for_model='ppo', + env_id='CartPole-v0', + min_snippet_length=5, + max_snippet_length=100, + checkpoint_min=0, + checkpoint_max=1000, + checkpoint_step=1000, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + ./cartpole.params', + offline_data_path='abs data path', + ), + policy=dict( + cuda=False, + model=dict( + obs_shape=4, + action_shape=2, + encoder_hidden_size_list=[64, 64, 128], + critic_head_hidden_size=128, + actor_head_hidden_size=128, + critic_head_layer_num=1, + ), + learn=dict( + update_per_collect=6, + batch_size=64, + learning_rate=0.001, + value_weight=0.5, + entropy_weight=0.01, + clip_ratio=0.2, + ), + collect=dict( + n_sample=128, + unroll_len=1, + discount_factor=0.9, + gae_lambda=0.95, + ), + other=dict(replay_buffer=dict(replay_buffer_size=5000)) + ), +) +cartpole_trex_ppo_offpolicy_config = EasyDict(cartpole_trex_ppo_offpolicy_config) +main_config = cartpole_trex_ppo_offpolicy_config +cartpole_trex_ppo_offpolicy_create_config = dict( + env=dict( + type='cartpole', + import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='ppo_offpolicy'), +) +cartpole_trex_ppo_offpolicy_create_config = EasyDict(cartpole_trex_ppo_offpolicy_create_config) +create_config = cartpole_trex_ppo_offpolicy_create_config diff --git a/dizoo/mujoco/config/__init__.py b/dizoo/mujoco/config/__init__.py index 5ee769b..4550574 100644 --- a/dizoo/mujoco/config/__init__.py +++ b/dizoo/mujoco/config/__init__.py @@ -1,13 +1,23 @@ from .ant_ddpg_default_config import ant_ddpg_default_config from .ant_sac_default_config import ant_sac_default_config from .ant_td3_default_config import ant_td3_default_config +from .ant_trex_onppo_default_config import ant_trex_ppo_default_config, ant_trex_ppo_create_default_config +from .ant_trex_sac_default_config import ant_trex_sac_default_config, ant_trex_sac_default_create_config from .halfcheetah_ddpg_default_config import halfcheetah_ddpg_default_config from .halfcheetah_sac_default_config import halfcheetah_sac_default_config from .halfcheetah_td3_default_config import halfcheetah_td3_default_config +from .halfcheetah_trex_onppo_default_config import halfCheetah_trex_ppo_default_config, halfCheetah_trex_ppo_create_default_config +from .halfcheetah_trex_sac_default_config import halfcheetah_trex_sac_default_config, halfcheetah_trex_sac_default_create_config +from .halfcheetah_onppo_default_config import halfcheetah_ppo_default_config, halfcheetah_ppo_create_default_config +from .hopper_onppo_default_config import hopper_ppo_default_config, hopper_ppo_create_default_config from .hopper_ddpg_default_config import hopper_ddpg_default_config from .hopper_sac_default_config import hopper_sac_default_config from .hopper_td3_default_config import hopper_td3_default_config +from .hopper_trex_onppo_default_config import hopper_trex_ppo_default_config, hopper_trex_ppo_create_default_config +from .hopper_trex_sac_default_config import hopper_trex_sac_default_config, 
hopper_trex_sac_default_create_config from .walker2d_ddpg_default_config import walker2d_ddpg_default_config, walker2d_ddpg_default_create_config from .walker2d_sac_default_config import walker2d_sac_default_config from .walker2d_td3_default_config import walker2d_td3_default_config from .walker2d_ddpg_gail_config import walker2d_ddpg_gail_default_config, walker2d_ddpg_gail_default_create_config +from .walker2d_trex_onppo_default_config import walker_trex_ppo_default_config, walker_trex_ppo_create_default_config +from .walker2d_trex_sac_default_config import walker2d_trex_sac_default_config, walker2d_trex_sac_default_create_config diff --git a/dizoo/mujoco/config/ant_ppo_default_config.py b/dizoo/mujoco/config/ant_ppo_default_config.py new file mode 100644 index 0000000..0fb7910 --- /dev/null +++ b/dizoo/mujoco/config/ant_ppo_default_config.py @@ -0,0 +1,56 @@ +from easydict import EasyDict + +ant_ppo_default_config = dict( + exp_name = 'ant_onppo', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + env_id='Ant-v3', + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + collector_env_num=8, + evaluator_env_num=10, + use_act_scale=True, + n_evaluator_episode=10, + stop_value=6000, + ), + policy=dict( + cuda=True, + recompute_adv=True, + model=dict( + obs_shape=111, + action_shape=8, + continuous=True, + ), + continuous=True, + learn=dict( + epoch_per_collect=10, + batch_size=64, + learning_rate=3e-4, + value_weight=0.5, + entropy_weight=0.0, + clip_ratio=0.2, + adv_norm=True, + value_norm=True, + ), + collect=dict( + n_sample=2048, + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.97, + ), + eval=dict(evaluator=dict(eval_freq=5000, )), + ), +) +ant_ppo_default_config = EasyDict(ant_ppo_default_config) +main_config = ant_ppo_default_config + +ant_ppo_create_default_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='ppo', ), +) +ant_ppo_create_default_config = EasyDict(ant_ppo_create_default_config) +create_config = ant_ppo_create_default_config diff --git a/dizoo/mujoco/config/ant_sac_default_config.py b/dizoo/mujoco/config/ant_sac_default_config.py index 930b59f..81bc183 100644 --- a/dizoo/mujoco/config/ant_sac_default_config.py +++ b/dizoo/mujoco/config/ant_sac_default_config.py @@ -1,6 +1,7 @@ from easydict import EasyDict ant_sac_default_config = dict( + exp_name='ant_sac', env=dict( env_id='Ant-v3', norm_obs=dict(use_norm=False, ), diff --git a/dizoo/mujoco/config/ant_trex_onppo_default_config.py b/dizoo/mujoco/config/ant_trex_onppo_default_config.py new file mode 100644 index 0000000..1870715 --- /dev/null +++ b/dizoo/mujoco/config/ant_trex_onppo_default_config.py @@ -0,0 +1,72 @@ +from easydict import EasyDict + +ant_trex_ppo_default_config = dict( + exp_name='ant_trex_onppo', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + env_id='Ant-v3', + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + collector_env_num=8, + evaluator_env_num=10, + use_act_scale=True, + n_evaluator_episode=10, + stop_value=6000, + ), + reward_model=dict( + type='trex', + algo_for_model='ppo', + env_id='Ant-v3', + min_snippet_length=10, + max_snippet_length=100, + checkpoint_min=100, + checkpoint_max=900, + checkpoint_step=100, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + ./ant.params', + continuous=True, + offline_data_path='asb 
data path', + ), + policy=dict( + cuda=True, + recompute_adv=True, + model=dict( + obs_shape=111, + action_shape=8, + continuous=True, + ), + continuous=True, + learn=dict( + epoch_per_collect=10, + batch_size=64, + learning_rate=3e-4, + value_weight=0.5, + entropy_weight=0.0, + clip_ratio=0.2, + adv_norm=True, + value_norm=True, + ), + collect=dict( + n_sample=2048, + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.97, + ), + eval=dict(evaluator=dict(eval_freq=5000, )), + ), +) +ant_trex_ppo_default_config = EasyDict(ant_trex_ppo_default_config) +main_config = ant_trex_ppo_default_config + +ant_trex_ppo_create_default_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='ppo', ), +) +ant_trex_ppo_create_default_config = EasyDict(ant_trex_ppo_create_default_config) +create_config = ant_trex_ppo_create_default_config diff --git a/dizoo/mujoco/config/ant_trex_sac_default_config.py b/dizoo/mujoco/config/ant_trex_sac_default_config.py new file mode 100644 index 0000000..cd85d9a --- /dev/null +++ b/dizoo/mujoco/config/ant_trex_sac_default_config.py @@ -0,0 +1,82 @@ +from easydict import EasyDict + +ant_trex_sac_default_config = dict( + exp_name='ant_trex_sac', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + env_id='Ant-v3', + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=6000, + ), + reward_model=dict( + type='trex', + algo_for_model='sac', + env_id='Ant-v3', + min_snippet_length=30, + max_snippet_length=100, + checkpoint_min=1000, + checkpoint_max=9000, + checkpoint_step=1000, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + ./ant.params', + continuous=True, + offline_data_path='asb data path', + ), + policy=dict( + cuda=True, + random_collect_size=10000, + model=dict( + obs_shape=111, + action_shape=8, + twin_critic=True, + actor_head_type='reparameterization', + actor_head_hidden_size=256, + critic_head_hidden_size=256, + ), + learn=dict( + update_per_collect=1, + batch_size=256, + learning_rate_q=1e-3, + learning_rate_policy=1e-3, + learning_rate_alpha=3e-4, + ignore_done=False, + target_theta=0.005, + discount_factor=0.99, + alpha=0.2, + reparameterization=True, + auto_alpha=False, + ), + collect=dict( + n_sample=1, + unroll_len=1, + ), + command=dict(), + eval=dict(), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), + ), +) + +ant_trex_sac_default_config = EasyDict(ant_trex_sac_default_config) +main_config = ant_trex_sac_default_config + +ant_trex_sac_default_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sac', + import_names=['ding.policy.sac'], + ), + replay_buffer=dict(type='naive', ), +) +ant_trex_sac_default_create_config = EasyDict(ant_trex_sac_default_create_config) +create_config = ant_trex_sac_default_create_config diff --git a/dizoo/mujoco/config/halfcheetah_onppo_default_config.py b/dizoo/mujoco/config/halfcheetah_onppo_default_config.py index 20acc02..e95d955 100644 --- a/dizoo/mujoco/config/halfcheetah_onppo_default_config.py +++ b/dizoo/mujoco/config/halfcheetah_onppo_default_config.py @@ -4,7 +4,7 @@ from ding.entry import serial_pipeline_onpolicy collector_env_num = 1 evaluator_env_num = 1 
halfcheetah_ppo_default_config = dict( - exp_name="result_mujoco/halfcheetah_onppo_noig", + exp_name="Halfcheetah_onppo", # exp_name="debug/debug_halfcheetah_onppo_ig", env=dict( diff --git a/dizoo/mujoco/config/halfcheetah_ppo_default_config.py b/dizoo/mujoco/config/halfcheetah_ppo_default_config.py new file mode 100644 index 0000000..7c557f2 --- /dev/null +++ b/dizoo/mujoco/config/halfcheetah_ppo_default_config.py @@ -0,0 +1,56 @@ +from easydict import EasyDict + +HalfCheetah_ppo_default_config = dict( + exp_name='HalfCheetah_onppo', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + env_id='HalfCheetah-v3', + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + collector_env_num=8, + evaluator_env_num=10, + use_act_scale=True, + n_evaluator_episode=10, + stop_value=3000, + ), + policy=dict( + cuda=True, + recompute_adv=True, + model=dict( + obs_shape=17, + action_shape=6, + continuous=True, + ), + continuous=True, + learn=dict( + epoch_per_collect=10, + batch_size=64, + learning_rate=3e-4, + value_weight=0.5, + entropy_weight=0.0, + clip_ratio=0.2, + adv_norm=True, + value_norm=True, + ), + collect=dict( + n_sample=2048, + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.97, + ), + eval=dict(evaluator=dict(eval_freq=5000, )), + ), +) +HalfCheetah_ppo_default_config = EasyDict(HalfCheetah_ppo_default_config) +main_config = HalfCheetah_ppo_default_config + +HalfCheetah_ppo_create_default_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='ppo', ), +) +HalfCheetah_ppo_create_default_config = EasyDict(HalfCheetah_ppo_create_default_config) +create_config = HalfCheetah_ppo_create_default_config diff --git a/dizoo/mujoco/config/halfcheetah_sac_default_config.py b/dizoo/mujoco/config/halfcheetah_sac_default_config.py index 444d70b..19253c4 100644 --- a/dizoo/mujoco/config/halfcheetah_sac_default_config.py +++ b/dizoo/mujoco/config/halfcheetah_sac_default_config.py @@ -1,6 +1,7 @@ from easydict import EasyDict halfcheetah_sac_default_config = dict( + exp_name = 'halfcheetah_sac', env=dict( env_id='HalfCheetah-v3', norm_obs=dict(use_norm=False, ), diff --git a/dizoo/mujoco/config/halfcheetah_trex_onppo_default_config.py b/dizoo/mujoco/config/halfcheetah_trex_onppo_default_config.py new file mode 100644 index 0000000..72ca698 --- /dev/null +++ b/dizoo/mujoco/config/halfcheetah_trex_onppo_default_config.py @@ -0,0 +1,72 @@ +from easydict import EasyDict + +halfCheetah_trex_ppo_default_config = dict( + exp_name='HalfCheetah_trex_onppo', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + env_id='HalfCheetah-v3', + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + collector_env_num=8, + evaluator_env_num=10, + use_act_scale=True, + n_evaluator_episode=10, + stop_value=3000, + ), + reward_model=dict( + type='trex', + algo_for_model='ppo', + env_id='HalfCheetah-v3', + min_snippet_length=30, + max_snippet_length=100, + checkpoint_min=100, + checkpoint_max=900, + checkpoint_step=100, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + /HalfCheetah.params', + continuous=True, + offline_data_path='asb data path', + ), + policy=dict( + cuda=True, + recompute_adv=True, + model=dict( + obs_shape=17, + action_shape=6, + continuous=True, + ), + continuous=True, + learn=dict( + epoch_per_collect=10, + batch_size=64, + learning_rate=3e-4, + 
value_weight=0.5, + entropy_weight=0.0, + clip_ratio=0.2, + adv_norm=True, + value_norm=True, + ), + collect=dict( + n_sample=2048, + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.97, + ), + eval=dict(evaluator=dict(eval_freq=5000, )), + ), +) +halfCheetah_trex_ppo_default_config = EasyDict(halfCheetah_trex_ppo_default_config) +main_config = halfCheetah_trex_ppo_default_config + +halfCheetah_trex_ppo_create_default_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='ppo', ), +) +halfCheetah_trex_ppo_create_default_config = EasyDict(halfCheetah_trex_ppo_create_default_config) +create_config = halfCheetah_trex_ppo_create_default_config diff --git a/dizoo/mujoco/config/halfcheetah_trex_sac_default_config.py b/dizoo/mujoco/config/halfcheetah_trex_sac_default_config.py new file mode 100644 index 0000000..13194f7 --- /dev/null +++ b/dizoo/mujoco/config/halfcheetah_trex_sac_default_config.py @@ -0,0 +1,82 @@ +from easydict import EasyDict + +halfcheetah_trex_sac_default_config = dict( + exp_name='halfcheetah_trex_sac', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + env_id='HalfCheetah-v3', + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=12000, + ), + reward_model=dict( + type='trex', + algo_for_model='sac', + env_id='HalfCheetah-v3', + learning_rate=1e-5, + min_snippet_length=30, + max_snippet_length=100, + checkpoint_min=1000, + checkpoint_max=9000, + checkpoint_step=1000, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + /HalfCheetah.params', + continuous=True, + offline_data_path='asb data path', + ), + policy=dict( + cuda=True, + random_collect_size=10000, + model=dict( + obs_shape=17, + action_shape=6, + twin_critic=True, + actor_head_type='reparameterization', + actor_head_hidden_size=256, + critic_head_hidden_size=256, + ), + learn=dict( + update_per_collect=1, + batch_size=256, + learning_rate_q=1e-3, + learning_rate_policy=1e-3, + learning_rate_alpha=3e-4, + ignore_done=True, + target_theta=0.005, + discount_factor=0.99, + alpha=0.2, + reparameterization=True, + auto_alpha=False, + ), + collect=dict( + n_sample=1, + unroll_len=1, + ), + command=dict(), + eval=dict(), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), + ), +) + +halfcheetah_trex_sac_default_config = EasyDict(halfcheetah_trex_sac_default_config) +main_config = halfcheetah_trex_sac_default_config + +halfcheetah_trex_sac_default_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sac', + import_names=['ding.policy.sac'], + ), + replay_buffer=dict(type='naive', ), +) +halfcheetah_trex_sac_default_create_config = EasyDict(halfcheetah_trex_sac_default_create_config) +create_config = halfcheetah_trex_sac_default_create_config diff --git a/dizoo/mujoco/config/hopper_sac_default_config.py b/dizoo/mujoco/config/hopper_sac_default_config.py index bc9b8c6..ae777b2 100644 --- a/dizoo/mujoco/config/hopper_sac_default_config.py +++ b/dizoo/mujoco/config/hopper_sac_default_config.py @@ -1,6 +1,7 @@ from easydict import EasyDict hopper_sac_default_config = dict( + exp_name = 'hopper_sac', env=dict( env_id='Hopper-v3', norm_obs=dict(use_norm=False, ), diff --git 
a/dizoo/mujoco/config/hopper_trex_onppo_default_config.py b/dizoo/mujoco/config/hopper_trex_onppo_default_config.py new file mode 100644 index 0000000..c1b6272 --- /dev/null +++ b/dizoo/mujoco/config/hopper_trex_onppo_default_config.py @@ -0,0 +1,72 @@ +from easydict import EasyDict + +hopper_trex_ppo_default_config = dict( + exp_name='hopper_trex_onppo', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + env_id='Hopper-v3', + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + collector_env_num=8, + evaluator_env_num=10, + use_act_scale=True, + n_evaluator_episode=10, + stop_value=3000, + ), + reward_model=dict( + type='trex', + algo_for_model='ppo', + env_id='Hopper-v3', + min_snippet_length=30, + max_snippet_length=100, + checkpoint_min=1000, + checkpoint_max=9000, + checkpoint_step=1000, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + /hopper.params', + continuous=True, + offline_data_path='asb data path', + ), + policy=dict( + cuda=True, + recompute_adv=True, + model=dict( + obs_shape=11, + action_shape=3, + continuous=True, + ), + continuous=True, + learn=dict( + epoch_per_collect=10, + batch_size=64, + learning_rate=3e-4, + value_weight=0.5, + entropy_weight=0.0, + clip_ratio=0.2, + adv_norm=True, + value_norm=True, + ), + collect=dict( + n_sample=2048, + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.97, + ), + eval=dict(evaluator=dict(eval_freq=5000, )), + ), +) +hopper_trex_ppo_default_config = EasyDict(hopper_trex_ppo_default_config) +main_config = hopper_trex_ppo_default_config + +hopper_trex_ppo_create_default_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='ppo', ), +) +hopper_trex_ppo_create_default_config = EasyDict(hopper_trex_ppo_create_default_config) +create_config = hopper_trex_ppo_create_default_config diff --git a/dizoo/mujoco/config/hopper_trex_sac_default_config.py b/dizoo/mujoco/config/hopper_trex_sac_default_config.py new file mode 100644 index 0000000..c1176b2 --- /dev/null +++ b/dizoo/mujoco/config/hopper_trex_sac_default_config.py @@ -0,0 +1,82 @@ +from easydict import EasyDict + +hopper_trex_sac_default_config = dict( + exp_name='hopper_trex_sac', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + env_id='Hopper-v3', + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=6000, + ), + reward_model=dict( + type='trex', + algo_for_model='sac', + env_id='Hopper-v3', + min_snippet_length=30, + max_snippet_length=100, + checkpoint_min=1000, + checkpoint_max=9000, + checkpoint_step=1000, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + /hopper.params', + continuous=True, + offline_data_path='asb data path', + ), + policy=dict( + cuda=True, + random_collect_size=10000, + model=dict( + obs_shape=11, + action_shape=3, + twin_critic=True, + actor_head_type='reparameterization', + actor_head_hidden_size=256, + critic_head_hidden_size=256, + ), + learn=dict( + update_per_collect=1, + batch_size=256, + learning_rate_q=1e-3, + learning_rate_policy=1e-3, + learning_rate_alpha=3e-4, + ignore_done=False, + target_theta=0.005, + discount_factor=0.99, + alpha=0.2, + reparameterization=True, + auto_alpha=False, + ), + 
collect=dict( + n_sample=1, + unroll_len=1, + ), + command=dict(), + eval=dict(), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), + ), +) + +hopper_trex_sac_default_config = EasyDict(hopper_trex_sac_default_config) +main_config = hopper_trex_sac_default_config + +hopper_trex_sac_default_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sac', + import_names=['ding.policy.sac'], + ), + replay_buffer=dict(type='naive', ), +) +hopper_trex_sac_default_create_config = EasyDict(hopper_trex_sac_default_create_config) +create_config = hopper_trex_sac_default_create_config diff --git a/dizoo/mujoco/config/walker2d_ppo_default_config.py b/dizoo/mujoco/config/walker2d_ppo_default_config.py new file mode 100644 index 0000000..d3a3132 --- /dev/null +++ b/dizoo/mujoco/config/walker2d_ppo_default_config.py @@ -0,0 +1,58 @@ +from easydict import EasyDict + +walker_ppo_default_config = dict( + exp_name='walker2d_onppo', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + env_id='Walker2d-v3', + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + collector_env_num=8, + evaluator_env_num=10, + use_act_scale=True, + n_evaluator_episode=10, + stop_value=3000, + ), + policy=dict( + cuda=True, + recompute_adv=True, + model=dict( + obs_shape=17, + action_shape=6, + continuous=True, + ), + continuous=True, + learn=dict( + epoch_per_collect=10, + batch_size=64, + learning_rate=3e-4, + value_weight=0.5, + entropy_weight=0.0, + clip_ratio=0.2, + adv_norm=True, + value_norm=True, + ), + collect=dict( + n_sample=2048, + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.97, + ), + eval=dict(evaluator=dict(eval_freq=5000, )), + ), +) +walker_ppo_default_config = EasyDict(walker_ppo_default_config) +main_config = walker_ppo_default_config + +walker_ppo_create_default_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='ppo', + ), +) +walker_ppo_create_default_config = EasyDict(walker_ppo_create_default_config) +create_config = walker_ppo_create_default_config diff --git a/dizoo/mujoco/config/walker2d_trex_onppo_default_config.py b/dizoo/mujoco/config/walker2d_trex_onppo_default_config.py new file mode 100644 index 0000000..18ab44e --- /dev/null +++ b/dizoo/mujoco/config/walker2d_trex_onppo_default_config.py @@ -0,0 +1,74 @@ +from easydict import EasyDict + +walker_trex_ppo_default_config = dict( + exp_name = 'walker2d_trex_onppo', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + env_id='Walker2d-v3', + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + collector_env_num=8, + evaluator_env_num=10, + use_act_scale=True, + n_evaluator_episode=10, + stop_value=3000, + ), + reward_model=dict( + type='trex', + algo_for_model='ppo', + env_id='Walker2d-v3', + min_snippet_length=30, + max_snippet_length=100, + checkpoint_min=0, + checkpoint_max=1000, + checkpoint_step=1000, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path+ ./walker2d.params', + continuous=True, + offline_data_path='asb data path', + ), + policy=dict( + cuda=True, + recompute_adv=True, + model=dict( + obs_shape=17, + action_shape=6, + continuous=True, + ), + continuous=True, + learn=dict( + epoch_per_collect=10, + batch_size=64, + learning_rate=3e-4, + 
value_weight=0.5, + entropy_weight=0.0, + clip_ratio=0.2, + adv_norm=True, + value_norm=True, + ), + collect=dict( + n_sample=2048, + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.97, + ), + eval=dict(evaluator=dict(eval_freq=5000, )), + ), +) +walker_trex_ppo_default_config = EasyDict(walker_trex_ppo_default_config) +main_config = walker_trex_ppo_default_config + +walker_trex_ppo_create_default_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='ppo', + ), +) +walker_trex_ppo_create_default_config = EasyDict(walker_trex_ppo_create_default_config) +create_config = walker_trex_ppo_create_default_config diff --git a/dizoo/mujoco/config/walker2d_trex_sac_default_config.py b/dizoo/mujoco/config/walker2d_trex_sac_default_config.py new file mode 100644 index 0000000..e02fe2b --- /dev/null +++ b/dizoo/mujoco/config/walker2d_trex_sac_default_config.py @@ -0,0 +1,82 @@ +from easydict import EasyDict + +walker2d_trex_sac_default_config = dict( + exp_name = 'walker2d_trex_sac', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + env_id='Walker2d-v3', + norm_obs=dict(use_norm=False, ), + norm_reward=dict(use_norm=False, ), + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=6000, + ), + reward_model=dict( + type='trex', + algo_for_model='sac', + env_id='Walker2d-v3', + min_snippet_length=30, + max_snippet_length=100, + checkpoint_min=100, + checkpoint_max=900, + checkpoint_step=100, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + ./walker2d.params', + continuous=True, + offline_data_path='asb data path', + ), + policy=dict( + cuda=True, + random_collect_size=10000, + model=dict( + obs_shape=17, + action_shape=6, + twin_critic=True, + actor_head_type='reparameterization', + actor_head_hidden_size=256, + critic_head_hidden_size=256, + ), + learn=dict( + update_per_collect=1, + batch_size=256, + learning_rate_q=1e-3, + learning_rate_policy=1e-3, + learning_rate_alpha=3e-4, + ignore_done=False, + target_theta=0.005, + discount_factor=0.99, + alpha=0.2, + reparameterization=True, + auto_alpha=False, + ), + collect=dict( + n_sample=1, + unroll_len=1, + ), + command=dict(), + eval=dict(), + other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ), + ), +) + +walker2d_trex_sac_default_config = EasyDict(walker2d_trex_sac_default_config) +main_config = walker2d_trex_sac_default_config + +walker2d_trex_sac_default_create_config = dict( + env=dict( + type='mujoco', + import_names=['dizoo.mujoco.envs.mujoco_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sac', + import_names=['ding.policy.sac'], + ), + replay_buffer=dict(type='naive', ), +) +walker2d_trex_sac_default_create_config = EasyDict(walker2d_trex_sac_default_create_config) +create_config = walker2d_trex_sac_default_create_config diff --git a/dizoo/mujoco/entry/mujoco_trex_main.py b/dizoo/mujoco/entry/mujoco_trex_main.py new file mode 100644 index 0000000..065b24e --- /dev/null +++ b/dizoo/mujoco/entry/mujoco_trex_main.py @@ -0,0 +1,22 @@ +import argparse +import torch +from ding.entry import trex_collecting_data +from dizoo.mujoco.config.halfcheetah_trex_onppo_default_config import main_config, create_config +from ding.entry import serial_pipeline_reward_model_trex_onpolicy, serial_pipeline_reward_model_trex + +# Note serial_pipeline_reward_model_trex_onpolicy 
is for on-policy PPO, whereas serial_pipeline_reward_model_trex is for SAC.
+# Note: before running this file, please fill in the corresponding paths in the config; all paths except exp_name should be absolute paths.
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--cfg', type=str, default='please enter abs path for halfcheetah_trex_onppo_default_config.py or halfcheetah_trex_sac_default_config.py')
+parser.add_argument('--seed', type=int, default=0)
+parser.add_argument(
+    '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+)
+args = parser.parse_args()
+
+trex_collecting_data(args)
+# if running SAC, please import the relevant config and use serial_pipeline_reward_model_trex
+serial_pipeline_reward_model_trex_onpolicy([main_config, create_config])
+
-- GitLab
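For reference, below is a minimal sketch of the equivalent off-policy (SAC) entry script. It mirrors mujoco_trex_main.py above, swapping in the SAC config added by this patch and the off-policy pipeline; it assumes serial_pipeline_reward_model_trex accepts the same [main_config, create_config] argument as the on-policy entry (as the comment in mujoco_trex_main.py suggests), and the --cfg default below is a placeholder path, not a value shipped in this patch.

import argparse
import torch
from ding.entry import trex_collecting_data, serial_pipeline_reward_model_trex
from dizoo.mujoco.config.halfcheetah_trex_sac_default_config import main_config, create_config

parser = argparse.ArgumentParser()
# placeholder: point --cfg at the absolute path of halfcheetah_trex_sac_default_config.py
parser.add_argument('--cfg', type=str, default='/abs/path/to/halfcheetah_trex_sac_default_config.py')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_args()

# step 1: roll out the saved expert checkpoints (checkpoint_min/max/step in the config)
# and dump the ranked episode data to reward_model.offline_data_path
trex_collecting_data(args)
# step 2: train the TREX reward model on the ranked data, then train SAC with the learned reward
serial_pipeline_reward_model_trex([main_config, create_config])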