From f089d02aaa327c0a6f191f9bc32e8c40ea8140f7 Mon Sep 17 00:00:00 2001 From: Will-Nie <61083608+Will-Nie@users.noreply.github.com> Date: Tue, 14 Dec 2021 17:41:33 +0800 Subject: [PATCH] polish(nyp): fix unittest for trex training and collecting (#144) * add trex algorithm for pong * sort style * add atari, ll,cp; fix device, collision; add_ppo * add accuracy evaluation * correct style * add seed to make sure results are replicable * remove useless part in cum return of model part * add mujoco onppo training pipeline; ppo config * improve style * add sac training config for mujoco * add log, add save data; polish config * logger; hyperparameter; walker * correct style * modify else condition * change rnd to trex * revise according to comments, add episode collect * new collect mode for trex, fix all bugs, comments * final change * polish after the final comment * add readme/test * add test for serial entry of trex/gcl * sort style * change mujoco to cartpole for test for trex_onppo * remove files generated by testing * revise tests for entry * sort style * revise tests * modify pytest * fix(nyz): speed up ppg/ppo and marl algo unittest * polish(nyz): speed up trex unittest and fix trex entry default config bug * fix(nyz): fix same name bug * fix(nyz): fix remove conflict bug(ci skip) Co-authored-by: niuyazhe --- Makefile | 1 + ding/entry/__init__.py | 2 +- ding/entry/application_entry.py | 2 +- ding/entry/serial_entry_trex.py | 1 + ding/entry/serial_entry_trex_onpolicy.py | 1 + ding/entry/tests/test_application_entry.py | 17 +++-- ...est_application_entry_trex_collect_data.py | 10 ++- ding/entry/tests/test_serial_entry.py | 12 ++++ ding/entry/tests/test_serial_entry_trex.py | 9 ++- .../tests/test_serial_entry_trex_onpolicy.py | 22 +++--- ding/reward_model/trex_reward_model.py | 35 ++++----- .../cartpole/config/__init__.py | 1 + .../cartpole/config/cartpole_ppg_config.py | 1 + .../cartpole/config/cartpole_ppo_config.py | 3 + .../config/cartpole_ppo_offpolicy_config.py | 4 ++ .../config/cartpole_trex_dqn_config.py | 2 +- .../config/cartpole_trex_offppo_config.py | 2 +- .../config/cartpole_trex_onppo_config.py | 72 +++++++++++++++++++ .../cartpole/entry/cartpole_ppg_main.py | 2 +- .../cooperative_navigation_atoc_config.py | 4 +- .../cooperative_navigation_collaq_config.py | 4 +- 21 files changed, 157 insertions(+), 50 deletions(-) create mode 100644 dizoo/classic_control/cartpole/config/cartpole_trex_onppo_config.py diff --git a/Makefile b/Makefile index 37b644e..3bbbda2 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,7 @@ unittest: --cov-report=xml \ --cov-report term-missing \ --cov=${COV_DIR} \ + ${DURATIONS_COMMAND} \ ${WORKERS_COMMAND} \ -sv -m unittest \ diff --git a/ding/entry/__init__.py b/ding/entry/__init__.py index d7ec000..f0c1930 100644 --- a/ding/entry/__init__.py +++ b/ding/entry/__init__.py @@ -14,6 +14,6 @@ from .serial_entry_trex import serial_pipeline_reward_model_trex from .serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy from .parallel_entry import parallel_pipeline from .application_entry import eval, collect_demo_data, collect_episodic_demo_data, \ - epsiode_to_transitions + episode_to_transitions from .application_entry_trex_collect_data import trex_collecting_data, collect_episodic_demo_data_for_trex from .serial_entry_guided_cost import serial_pipeline_guided_cost diff --git a/ding/entry/application_entry.py b/ding/entry/application_entry.py index 34cb3e5..0d77c39 100644 --- a/ding/entry/application_entry.py +++ 
b/ding/entry/application_entry.py @@ -231,7 +231,7 @@ def collect_episodic_demo_data( print('Collect episodic demo data successfully') -def epsiode_to_transitions(data_path: str, expert_data_path: str, nstep: int) -> None: +def episode_to_transitions(data_path: str, expert_data_path: str, nstep: int) -> None: r""" Overview: Transfer episoded data into nstep transitions diff --git a/ding/entry/serial_entry_trex.py b/ding/entry/serial_entry_trex.py index c559da0..02b71f8 100644 --- a/ding/entry/serial_entry_trex.py +++ b/ding/entry/serial_entry_trex.py @@ -44,6 +44,7 @@ def serial_pipeline_reward_model_trex( else: cfg, create_cfg = input_cfg create_cfg.policy.type = create_cfg.policy.type + '_command' + create_cfg.reward_model = dict(type='trex') env_fn = None if env_setting is None else env_setting[0] cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) # Create main components: env, policy diff --git a/ding/entry/serial_entry_trex_onpolicy.py b/ding/entry/serial_entry_trex_onpolicy.py index da69808..6bf6f47 100644 --- a/ding/entry/serial_entry_trex_onpolicy.py +++ b/ding/entry/serial_entry_trex_onpolicy.py @@ -43,6 +43,7 @@ def serial_pipeline_reward_model_trex_onpolicy( else: cfg, create_cfg = input_cfg create_cfg.policy.type = create_cfg.policy.type + '_command' + create_cfg.reward_model = dict(type='trex') env_fn = None if env_setting is None else env_setting[0] cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) # Create main components: env, policy diff --git a/ding/entry/tests/test_application_entry.py b/ding/entry/tests/test_application_entry.py index 3329afe..747e815 100644 --- a/ding/entry/tests/test_application_entry.py +++ b/ding/entry/tests/test_application_entry.py @@ -10,7 +10,7 @@ from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import ca from dizoo.classic_control.cartpole.envs import CartPoleEnv from ding.entry import serial_pipeline, eval, collect_demo_data from ding.config import compile_config -from ding.entry.application_entry import collect_episodic_demo_data, epsiode_to_transitions +from ding.entry.application_entry import collect_episodic_demo_data, episode_to_transitions @pytest.fixture(scope='module') @@ -65,8 +65,11 @@ class TestApplication: def test_collect_episodic_demo_data(self, setup_state_dict): config = deepcopy(cartpole_trex_ppo_offpolicy_config), deepcopy(cartpole_trex_ppo_offpolicy_create_config) + config[0].exp_name = 'cartpole_trex_offppo_episodic' collect_count = 16 - expert_data_path = './expert.data' + if not os.path.exists('./test_episode'): + os.mkdir('./test_episode') + expert_data_path = './test_episode/expert.data' collect_episodic_demo_data( config, seed=0, @@ -79,11 +82,13 @@ class TestApplication: assert isinstance(exp_data, list) assert isinstance(exp_data[0][0], dict) - def test_epsiode_to_transitions(self): - expert_data_path = './expert.data' - epsiode_to_transitions(data_path=expert_data_path, expert_data_path=expert_data_path, nstep=3) + def test_episode_to_transitions(self, setup_state_dict): + self.test_collect_episodic_demo_data(setup_state_dict) + expert_data_path = './test_episode/expert.data' + episode_to_transitions(data_path=expert_data_path, expert_data_path=expert_data_path, nstep=3) with open(expert_data_path, 'rb') as f: exp_data = pickle.load(f) assert isinstance(exp_data, list) assert isinstance(exp_data[0], dict) - os.popen('rm -rf ./expert.data ckpt* log') + os.popen('rm -rf ./test_episode/expert.data ckpt* 
log') + os.popen('rm -rf ./test_episode') diff --git a/ding/entry/tests/test_application_entry_trex_collect_data.py b/ding/entry/tests/test_application_entry_trex_collect_data.py index dfb98e6..c983050 100644 --- a/ding/entry/tests/test_application_entry_trex_collect_data.py +++ b/ding/entry/tests/test_application_entry_trex_collect_data.py @@ -17,6 +17,7 @@ from ding.entry import serial_pipeline @pytest.mark.unittest def test_collect_episodic_demo_data_for_trex(): expert_policy_state_dict_path = './expert_policy.pth' + expert_policy_state_dict_path = os.path.abspath('./expert_policy.pth') config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] expert_policy = serial_pipeline(config, seed=0) torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path) @@ -39,11 +40,12 @@ def test_collect_episodic_demo_data_for_trex(): os.popen('rm -rf {}'.format(expert_policy_state_dict_path)) -# @pytest.mark.unittest +@pytest.mark.unittest def test_trex_collecting_data(): expert_policy_state_dict_path = './cartpole_ppo_offpolicy' expert_policy_state_dict_path = os.path.abspath(expert_policy_state_dict_path) config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] + config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100 expert_policy = serial_pipeline(config, seed=0) args = EasyDict( @@ -54,10 +56,14 @@ def test_trex_collecting_data(): 'device': 'cpu' } ) - args.cfg[0].reward_model.offline_data_path = 'cartpole_trex_offppo_offline_data' + args.cfg[0].reward_model.offline_data_path = './cartpole_trex_offppo' args.cfg[0].reward_model.offline_data_path = os.path.abspath(args.cfg[0].reward_model.offline_data_path) args.cfg[0].reward_model.reward_model_path = args.cfg[0].reward_model.offline_data_path + '/cartpole.params' + args.cfg[0].reward_model.expert_model_path = './cartpole_ppo_offpolicy' args.cfg[0].reward_model.expert_model_path = os.path.abspath(args.cfg[0].reward_model.expert_model_path) + args.cfg[0].reward_model.checkpoint_max = 100 + args.cfg[0].reward_model.checkpoint_step = 100 + args.cfg[0].reward_model.num_snippets = 100 trex_collecting_data(args=args) os.popen('rm -rf {}'.format(expert_policy_state_dict_path)) os.popen('rm -rf {}'.format(args.cfg[0].reward_model.offline_data_path)) diff --git a/ding/entry/tests/test_serial_entry.py b/ding/entry/tests/test_serial_entry.py index 877809c..c566f8c 100644 --- a/ding/entry/tests/test_serial_entry.py +++ b/ding/entry/tests/test_serial_entry.py @@ -252,6 +252,8 @@ def test_collaq(): config = [deepcopy(cooperative_navigation_collaq_config), deepcopy(cooperative_navigation_collaq_create_config)] config[0].policy.cuda = False config[0].policy.learn.update_per_collect = 1 + config[0].env.n_evaluator_episode = 2 + config[0].policy.collect.n_sample = 100 try: serial_pipeline(config, seed=0, max_iterations=1) except Exception: @@ -265,6 +267,8 @@ def test_coma(): config = [deepcopy(cooperative_navigation_coma_config), deepcopy(cooperative_navigation_coma_create_config)] config[0].policy.cuda = False config[0].policy.learn.update_per_collect = 1 + config[0].env.n_evaluator_episode = 2 + config[0].policy.collect.n_sample = 100 try: serial_pipeline(config, seed=0, max_iterations=1) except Exception: @@ -278,6 +282,8 @@ def test_qmix(): config = [deepcopy(cooperative_navigation_qmix_config), deepcopy(cooperative_navigation_qmix_create_config)] config[0].policy.cuda = False config[0].policy.learn.update_per_collect = 1 + 
config[0].env.n_evaluator_episode = 2 + config[0].policy.collect.n_sample = 100 try: serial_pipeline(config, seed=0, max_iterations=1) except Exception: @@ -291,6 +297,8 @@ def test_wqmix(): config = [deepcopy(cooperative_navigation_wqmix_config), deepcopy(cooperative_navigation_wqmix_create_config)] config[0].policy.cuda = False config[0].policy.learn.update_per_collect = 1 + config[0].env.n_evaluator_episode = 2 + config[0].policy.collect.n_sample = 100 try: serial_pipeline(config, seed=0, max_iterations=1) except Exception: @@ -304,6 +312,8 @@ def test_qtran(): config = [deepcopy(cooperative_navigation_qtran_config), deepcopy(cooperative_navigation_qtran_create_config)] config[0].policy.cuda = False config[0].policy.learn.update_per_collect = 1 + config[0].env.n_evaluator_episode = 2 + config[0].policy.collect.n_sample = 100 try: serial_pipeline(config, seed=0, max_iterations=1) except Exception: @@ -316,6 +326,8 @@ def test_qtran(): def test_atoc(): config = [deepcopy(cooperative_navigation_atoc_config), deepcopy(cooperative_navigation_atoc_create_config)] config[0].policy.cuda = False + config[0].env.n_evaluator_episode = 2 + config[0].policy.collect.n_sample = 100 try: serial_pipeline(config, seed=0, max_iterations=1) except Exception: diff --git a/ding/entry/tests/test_serial_entry_trex.py b/ding/entry/tests/test_serial_entry_trex.py index 0fc1ce9..e0ad664 100644 --- a/ding/entry/tests/test_serial_entry_trex.py +++ b/ding/entry/tests/test_serial_entry_trex.py @@ -14,20 +14,25 @@ from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import from ding.entry.application_entry_trex_collect_data import trex_collecting_data -# @pytest.mark.unittest +@pytest.mark.unittest def test_serial_pipeline_reward_model_trex(): config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)] + config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100 expert_policy = serial_pipeline(config, seed=0) config = [deepcopy(cartpole_trex_ppo_offpolicy_config), deepcopy(cartpole_trex_ppo_offpolicy_create_config)] - config[0].reward_model.offline_data_path = 'dizoo/classic_control/cartpole/config/cartpole_trex_offppo' + config[0].reward_model.offline_data_path = './cartpole_trex_offppo' config[0].reward_model.offline_data_path = os.path.abspath(config[0].reward_model.offline_data_path) config[0].reward_model.reward_model_path = config[0].reward_model.offline_data_path + '/cartpole.params' config[0].reward_model.expert_model_path = './cartpole_ppo_offpolicy' config[0].reward_model.expert_model_path = os.path.abspath(config[0].reward_model.expert_model_path) + config[0].reward_model.checkpoint_max = 100 + config[0].reward_model.checkpoint_step = 100 + config[0].reward_model.num_snippets = 100 args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'}) trex_collecting_data(args=args) try: serial_pipeline_reward_model_trex(config, seed=0, max_iterations=1) + os.popen('rm -rf {}'.format(config[0].reward_model.offline_data_path)) except Exception: assert False, "pipeline fail" diff --git a/ding/entry/tests/test_serial_entry_trex_onpolicy.py b/ding/entry/tests/test_serial_entry_trex_onpolicy.py index 123d6be..0a915ad 100644 --- a/ding/entry/tests/test_serial_entry_trex_onpolicy.py +++ b/ding/entry/tests/test_serial_entry_trex_onpolicy.py @@ -7,25 +7,31 @@ import torch from ding.entry import serial_pipeline_onpolicy from ding.entry.serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy -from dizoo.mujoco.config import 
hopper_ppo_default_config, hopper_ppo_create_default_config -from dizoo.mujoco.config import hopper_trex_ppo_default_config, hopper_trex_ppo_create_default_config +from dizoo.classic_control.cartpole.config import cartpole_ppo_config, cartpole_ppo_create_config +from dizoo.classic_control.cartpole.config import cartpole_trex_ppo_onpolicy_config, \ + cartpole_trex_ppo_onpolicy_create_config from ding.entry.application_entry_trex_collect_data import trex_collecting_data -# @pytest.mark.unittest +@pytest.mark.unittest def test_serial_pipeline_reward_model_trex(): - config = [deepcopy(hopper_ppo_default_config), deepcopy(hopper_ppo_create_default_config)] + config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)] expert_policy = serial_pipeline_onpolicy(config, seed=0, max_iterations=90) - config = [deepcopy(hopper_trex_ppo_default_config), deepcopy(hopper_trex_ppo_create_default_config)] - config[0].reward_model.offline_data_path = 'dizoo/mujoco/config/hopper_trex_onppo' + config = [deepcopy(cartpole_trex_ppo_onpolicy_config), deepcopy(cartpole_trex_ppo_onpolicy_create_config)] + config[0].reward_model.offline_data_path = './cartpole_trex_onppo' config[0].reward_model.offline_data_path = os.path.abspath(config[0].reward_model.offline_data_path) - config[0].reward_model.reward_model_path = config[0].reward_model.offline_data_path + '/hopper.params' - config[0].reward_model.expert_model_path = './hopper_onppo' + config[0].reward_model.reward_model_path = config[0].reward_model.offline_data_path + '/cartpole.params' + config[0].reward_model.expert_model_path = './cartpole_ppo' config[0].reward_model.expert_model_path = os.path.abspath(config[0].reward_model.expert_model_path) + config[0].reward_model.checkpoint_max = 100 + config[0].reward_model.checkpoint_step = 100 + config[0].reward_model.num_snippets = 100 args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'}) trex_collecting_data(args=args) try: serial_pipeline_reward_model_trex_onpolicy(config, seed=0, max_iterations=1) + os.popen('rm -rf {}'.format(config[0].reward_model.offline_data_path)) + os.popen('rm -rf {}'.format(config[0].reward_model.expert_model_path)) except Exception: assert False, "pipeline fail" diff --git a/ding/reward_model/trex_reward_model.py b/ding/reward_model/trex_reward_model.py index 16b654f..f0c20f4 100644 --- a/ding/reward_model/trex_reward_model.py +++ b/ding/reward_model/trex_reward_model.py @@ -125,11 +125,8 @@ class TrexModel(nn.Module): if mode == 'sum': sum_rewards = torch.sum(r) sum_abs_rewards = torch.sum(torch.abs(r)) - # print(sum_rewards) - # print(r) return sum_rewards, sum_abs_rewards elif mode == 'batch': - # print(r) return r, torch.abs(r) else: raise KeyError("not support mode: {}, please choose mode=sum or mode=batch".format(mode)) @@ -157,6 +154,8 @@ class TrexRewardModel(BaseRewardModel): batch_size=64, target_new_data_count=64, hidden_size=128, + num_trajs=0, # number of downsampled full trajectories + num_snippets=6000, # number of short subtrajectories to sample ) def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa @@ -173,7 +172,7 @@ class TrexRewardModel(BaseRewardModel): assert device in ["cpu", "cuda"] or "cuda" in device self.device = device self.tb_logger = tb_logger - self.reward_model = TrexModel(self.cfg.policy.model.get('obs_shape')) + self.reward_model = TrexModel(self.cfg.policy.model.obs_shape) self.reward_model.to(self.device) self.pre_expert_data = [] self.train_data = [] @@ -184,8 +183,8 @@ class 
TrexRewardModel(BaseRewardModel): self.learning_rewards = [] self.training_obs = [] self.training_labels = [] - self.num_trajs = 0 # number of downsampled full trajectories - self.num_snippets = 6000 # number of short subtrajectories to sample + self.num_trajs = self.cfg.reward_model.num_trajs + self.num_snippets = self.cfg.reward_model.num_snippets # minimum number of short subtrajectories to sample self.min_snippet_length = config.reward_model.min_snippet_length # maximum number of short subtrajectories to sample @@ -240,19 +239,15 @@ class TrexRewardModel(BaseRewardModel): #collect training data max_traj_length = 0 num_demos = len(demonstrations) + assert num_demos >= 2 #add full trajs (for use on Enduro) si = np.random.randint(6, size=num_trajs) sj = np.random.randint(6, size=num_trajs) step = np.random.randint(3, 7, size=num_trajs) for n in range(num_trajs): - ti = 0 - tj = 0 - #only add trajectories that are different returns - while (ti == tj): - #pick two random demonstrations - ti = np.random.randint(num_demos) - tj = np.random.randint(num_demos) + #pick two random demonstrations + ti, tj = np.random.choice(num_demos, size=(2, ), replace=False) #create random partial trajs by finding random start frame and random skip frame traj_i = demonstrations[ti][si[n]::step[n]] # slice(start,stop,step) traj_j = demonstrations[tj][sj[n]::step[n]] @@ -266,13 +261,8 @@ class TrexRewardModel(BaseRewardModel): #fixed size snippets with progress prior rand_length = np.random.randint(min_snippet_length, max_snippet_length, size=num_snippets) for n in range(num_snippets): - ti = 0 - tj = 0 - #only add trajectories that are different returns - while (ti == tj): - #pick two random demonstrations - ti = np.random.randint(num_demos) - tj = np.random.randint(num_demos) + #pick two random demonstrations + ti, tj = np.random.choice(num_demos, size=(2, ), replace=False) #create random snippets #find min length of both demos to ensure we can pick a demo no earlier #than that chosen in worse preferred demo @@ -285,8 +275,8 @@ class TrexRewardModel(BaseRewardModel): tj_start = np.random.randint(min_length - rand_length[n] + 1) # print(tj_start, len(demonstrations[ti])) ti_start = np.random.randint(tj_start, len(demonstrations[ti]) - rand_length[n] + 1) - traj_i = demonstrations[ti][ti_start:ti_start + rand_length[n]:2 - ] # skip everyother framestack to reduce size + # skip everyother framestack to reduce size + traj_i = demonstrations[ti][ti_start:ti_start + rand_length[n]:2] traj_j = demonstrations[tj][tj_start:tj_start + rand_length[n]:2] max_traj_length = max(max_traj_length, len(traj_i), len(traj_j)) @@ -334,7 +324,6 @@ class TrexRewardModel(BaseRewardModel): item_loss = loss.item() cum_loss += item_loss if i % 100 == 99: - # print(i) self._logger.info("epoch {}:{} loss {}".format(epoch, i, cum_loss)) self._logger.info("abs_returns: {}".format(abs_rewards)) cum_loss = 0.0 diff --git a/dizoo/classic_control/cartpole/config/__init__.py b/dizoo/classic_control/cartpole/config/__init__.py index 623fa94..4c492a7 100644 --- a/dizoo/classic_control/cartpole/config/__init__.py +++ b/dizoo/classic_control/cartpole/config/__init__.py @@ -18,4 +18,5 @@ from .cartpole_dqn_gail_config import cartpole_dqn_gail_config, cartpole_dqn_gai from .cartpole_gcl_config import cartpole_gcl_ppo_onpolicy_config, cartpole_gcl_ppo_onpolicy_create_config from .cartpole_trex_dqn_config import cartpole_trex_dqn_config, cartpole_trex_dqn_create_config from .cartpole_trex_offppo_config import cartpole_trex_ppo_offpolicy_config, 
cartpole_trex_ppo_offpolicy_create_config +from .cartpole_trex_onppo_config import cartpole_trex_ppo_onpolicy_config, cartpole_trex_ppo_onpolicy_create_config # from .cartpole_ppo_default_loader import cartpole_ppo_default_loader diff --git a/dizoo/classic_control/cartpole/config/cartpole_ppg_config.py b/dizoo/classic_control/cartpole/config/cartpole_ppg_config.py index 8395ecf..be26893 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_ppg_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_ppg_config.py @@ -31,6 +31,7 @@ cartpole_ppg_config = dict( discount_factor=0.9, gae_lambda=0.95, ), + eval=dict(evaluator=dict(eval_freq=40, )), other=dict( replay_buffer=dict( multi_buffer=True, diff --git a/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py b/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py index 28c3182..f63ba82 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py @@ -25,6 +25,9 @@ cartpole_ppo_config = dict( value_weight=0.5, entropy_weight=0.01, clip_ratio=0.2, + learner=dict( + hook=dict(save_ckpt_after_iter=100) + ), ), collect=dict( n_sample=256, diff --git a/dizoo/classic_control/cartpole/config/cartpole_ppo_offpolicy_config.py b/dizoo/classic_control/cartpole/config/cartpole_ppo_offpolicy_config.py index cbf903d..59ba682 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_ppo_offpolicy_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_ppo_offpolicy_config.py @@ -24,6 +24,9 @@ cartpole_ppo_offpolicy_config = dict( value_weight=0.5, entropy_weight=0.01, clip_ratio=0.2, + learner=dict( + hook=dict(save_ckpt_after_iter=1000) + ), ), collect=dict( n_sample=128, @@ -31,6 +34,7 @@ cartpole_ppo_offpolicy_config = dict( discount_factor=0.9, gae_lambda=0.95, ), + eval=dict(evaluator=dict(eval_freq=40, )), other=dict(replay_buffer=dict(replay_buffer_size=5000)) ), ) diff --git a/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py b/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py index 268f2d9..3f42447 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py @@ -60,7 +60,7 @@ cartpole_trex_dqn_create_config = dict( type='cartpole', import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], ), - env_manager=dict(type='subprocess'), + env_manager=dict(type='base'), policy=dict(type='dqn'), ) cartpole_trex_dqn_create_config = EasyDict(cartpole_trex_dqn_create_config) diff --git a/dizoo/classic_control/cartpole/config/cartpole_trex_offppo_config.py b/dizoo/classic_control/cartpole/config/cartpole_trex_offppo_config.py index e343210..32a65a9 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_trex_offppo_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_trex_offppo_config.py @@ -58,7 +58,7 @@ cartpole_trex_ppo_offpolicy_create_config = dict( type='cartpole', import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], ), - env_manager=dict(type='subprocess'), + env_manager=dict(type='base'), policy=dict(type='ppo_offpolicy'), ) cartpole_trex_ppo_offpolicy_create_config = EasyDict(cartpole_trex_ppo_offpolicy_create_config) diff --git a/dizoo/classic_control/cartpole/config/cartpole_trex_onppo_config.py b/dizoo/classic_control/cartpole/config/cartpole_trex_onppo_config.py new file mode 100644 index 0000000..d9d159c --- /dev/null +++ 
b/dizoo/classic_control/cartpole/config/cartpole_trex_onppo_config.py @@ -0,0 +1,72 @@ +from easydict import EasyDict + +cartpole_trex_ppo_onpolicy_config = dict( + exp_name='cartpole_trex_onppo', + env=dict( + manager=dict(shared_memory=True, force_reproducibility=True), + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + stop_value=195, + ), + reward_model=dict( + type='trex', + algo_for_model='ppo', + env_id='CartPole-v0', + min_snippet_length=5, + max_snippet_length=100, + checkpoint_min=0, + checkpoint_max=100, + checkpoint_step=100, + learning_rate=1e-5, + update_per_collect=1, + expert_model_path='abs model path', + reward_model_path='abs data path + ./cartpole.params', + offline_data_path='abs data path', + ), + policy=dict( + cuda=False, + continuous=False, + model=dict( + obs_shape=4, + action_shape=2, + encoder_hidden_size_list=[64, 64, 128], + critic_head_hidden_size=128, + actor_head_hidden_size=128, + ), + learn=dict( + epoch_per_collect=2, + batch_size=64, + learning_rate=0.001, + value_weight=0.5, + entropy_weight=0.01, + clip_ratio=0.2, + learner=dict( + hook=dict(save_ckpt_after_iter=1000) + ), + ), + collect=dict( + n_sample=256, + unroll_len=1, + discount_factor=0.9, + gae_lambda=0.95, + ), + eval=dict( + evaluator=dict( + eval_freq=100, + ), + ), + ), +) +cartpole_trex_ppo_onpolicy_config = EasyDict(cartpole_trex_ppo_onpolicy_config) +main_config = cartpole_trex_ppo_onpolicy_config +cartpole_trex_ppo_onpolicy_create_config = dict( + env=dict( + type='cartpole', + import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='ppo'), +) +cartpole_trex_ppo_onpolicy_create_config = EasyDict(cartpole_trex_ppo_onpolicy_create_config) +create_config = cartpole_trex_ppo_onpolicy_create_config diff --git a/dizoo/classic_control/cartpole/entry/cartpole_ppg_main.py b/dizoo/classic_control/cartpole/entry/cartpole_ppg_main.py index 089a073..401dfc8 100644 --- a/dizoo/classic_control/cartpole/entry/cartpole_ppg_main.py +++ b/dizoo/classic_control/cartpole/entry/cartpole_ppg_main.py @@ -55,7 +55,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)): cfg.policy.other.replay_buffer.value, tb_logger, exp_name=cfg.exp_name, instance_name='value_buffer' ) - while True: + for _ in range(max_iterations): if evaluator.should_eval(learner.train_iter): stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) if stop: diff --git a/dizoo/multiagent_particle/config/cooperative_navigation_atoc_config.py b/dizoo/multiagent_particle/config/cooperative_navigation_atoc_config.py index 7f83975..f28e703 100644 --- a/dizoo/multiagent_particle/config/cooperative_navigation_atoc_config.py +++ b/dizoo/multiagent_particle/config/cooperative_navigation_atoc_config.py @@ -2,7 +2,7 @@ from easydict import EasyDict n_agent = 5 collector_env_num = 4 -evaluator_env_num = 5 +evaluator_env_num = 2 communication = True cooperative_navigation_atoc_config = dict( env=dict( @@ -13,7 +13,7 @@ cooperative_navigation_atoc_config = dict( evaluator_env_num=evaluator_env_num, agent_obs_only=True, discrete_action=False, - n_evaluator_episode=5, + n_evaluator_episode=10, stop_value=0, ), policy=dict( diff --git a/dizoo/multiagent_particle/config/cooperative_navigation_collaq_config.py b/dizoo/multiagent_particle/config/cooperative_navigation_collaq_config.py index fee8aa5..94ec937 100644 --- a/dizoo/multiagent_particle/config/cooperative_navigation_collaq_config.py +++ 
b/dizoo/multiagent_particle/config/cooperative_navigation_collaq_config.py @@ -11,9 +11,9 @@ cooperative_navigation_collaq_config = dict( max_step=100, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, - manager=dict(shared_memory=False, ), - n_evaluator_episode=5, + n_evaluator_episode=10, stop_value=0, + manager=dict(shared_memory=False, ), ), policy=dict( cuda=True, -- GitLab
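
Reviewer note: the sketch below only restates the collect-then-train flow that the rewritten ding/entry/tests/test_serial_entry_trex_onpolicy.py in this patch exercises. The paths, checkpoint_max / checkpoint_step / num_snippets values, and max_iterations are the test's throwaway settings (chosen to match the save_ckpt_after_iter=100 hook this patch adds to cartpole_ppo_config), not recommended defaults.

    import os
    from copy import deepcopy
    from easydict import EasyDict
    from ding.entry import serial_pipeline_onpolicy
    from ding.entry.serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy
    from ding.entry.application_entry_trex_collect_data import trex_collecting_data
    from dizoo.classic_control.cartpole.config import cartpole_ppo_config, cartpole_ppo_create_config
    from dizoo.classic_control.cartpole.config import cartpole_trex_ppo_onpolicy_config, \
        cartpole_trex_ppo_onpolicy_create_config

    # 1. Train a small on-policy PPO expert; the unittest assumes its checkpoints end up in ./cartpole_ppo.
    expert_config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
    serial_pipeline_onpolicy(expert_config, seed=0, max_iterations=90)

    # 2. Point the TREX config at the expert checkpoints and at a directory for ranked demonstration data.
    config = [deepcopy(cartpole_trex_ppo_onpolicy_config), deepcopy(cartpole_trex_ppo_onpolicy_create_config)]
    config[0].reward_model.expert_model_path = os.path.abspath('./cartpole_ppo')
    config[0].reward_model.offline_data_path = os.path.abspath('./cartpole_trex_onppo')
    config[0].reward_model.reward_model_path = config[0].reward_model.offline_data_path + '/cartpole.params'
    config[0].reward_model.checkpoint_max = 100   # small values keep the unittest fast
    config[0].reward_model.checkpoint_step = 100
    config[0].reward_model.num_snippets = 100

    # 3. Collect demonstrations from the saved checkpoints, then run the TREX on-policy training entry.
    trex_collecting_data(args=EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'}))
    serial_pipeline_reward_model_trex_onpolicy(config, seed=0, max_iterations=1)

The off-policy test in test_serial_entry_trex.py follows the same pattern, using serial_pipeline for the expert and serial_pipeline_reward_model_trex for reward-model training.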