diff --git a/ding/entry/application_entry.py b/ding/entry/application_entry.py index 0d77c39c1ce078ad230c0685e36681febe225ec4..2367b8ad5af0106ee5df04ef5d9f2f5b80e4a36d 100644 --- a/ding/entry/application_entry.py +++ b/ding/entry/application_entry.py @@ -154,6 +154,7 @@ def collect_demo_data( if cfg.policy.cuda: exp_data = to_device(exp_data, 'cpu') # Save data transitions. + expert_data_path = os.path.join(cfg.exp_name, expert_data_path) offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive')) print('Collect demo data successfully') @@ -227,6 +228,7 @@ def collect_episodic_demo_data( if cfg.policy.cuda: exp_data = to_device(exp_data, 'cpu') # Save data transitions. + expert_data_path = os.path.join(cfg.exp_name, expert_data_path) offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive')) print('Collect episodic demo data successfully') diff --git a/ding/entry/tests/test_serial_entry.py b/ding/entry/tests/test_serial_entry.py index d80f4b8dcf1bebaa7c041a6a8d7f6d783280fbf9..b8c0904558f8a40731b0e63989b19e7d647d696d 100644 --- a/ding/entry/tests/test_serial_entry.py +++ b/ding/entry/tests/test_serial_entry.py @@ -2,6 +2,7 @@ import pytest import time import os from copy import deepcopy +import torch from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_offline from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config @@ -360,7 +361,9 @@ def test_sqn(): @pytest.mark.unittest def test_selfplay(): try: - selfplay_main(deepcopy(league_demo_ppo_config), seed=0, max_iterations=1) + config = deepcopy(league_demo_ppo_config) + config.exp_name = 'test_selfplay' + selfplay_main(config, seed=0, max_iterations=1) except Exception: assert False, "pipeline fail" @@ -368,7 +371,9 @@ def test_selfplay(): @pytest.mark.unittest def test_league(): try: - league_main(deepcopy(league_demo_ppo_config), seed=0, max_iterations=1) + config = deepcopy(league_demo_ppo_config) + config.exp_name = 'test_league' + league_main(config, seed=0, max_iterations=1) except Exception as e: assert False, "pipeline fail" @@ -395,14 +400,13 @@ def test_cql(): assert False, "pipeline fail" # collect expert data - import torch config = [ deepcopy(pendulum_sac_data_genearation_default_config), deepcopy(pendulum_sac_data_genearation_default_create_config) ] collect_count = 1000 expert_data_path = config[0].policy.collect.save_path - state_dict = torch.load('./sac/ckpt/iteration_0.pth.tar', map_location='cpu') + state_dict = torch.load('./sac_seed0/ckpt/iteration_0.pth.tar', map_location='cpu') try: collect_demo_data( config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict @@ -442,11 +446,10 @@ def test_discrete_cql(): except Exception: assert False, "pipeline fail" # collect expert data - import torch config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)] collect_count = 1000 expert_data_path = config[0].policy.collect.save_path - state_dict = torch.load('./cql_cartpole/ckpt/iteration_0.pth.tar', map_location='cpu') + state_dict = torch.load('./cql_cartpole_seed0/ckpt/iteration_0.pth.tar', map_location='cpu') try: collect_demo_data( config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict @@ -467,7 +470,7 @@ def test_discrete_cql(): os.popen('rm -rf cartpole cartpole_cql') -@pytest.mark.algotest +@pytest.mark.unittest def test_td3_bc(): # train expert config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)] @@ -479,11 +482,10 @@ def test_td3_bc(): assert False, "pipeline fail" # collect expert data - import torch config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)] collect_count = 1000 expert_data_path = config[0].policy.collect.save_path - state_dict = torch.load('./td3/ckpt/iteration_0.pth.tar', map_location='cpu') + state_dict = torch.load('./td3_seed0/ckpt/iteration_0.pth.tar', map_location='cpu') try: collect_demo_data( config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict diff --git a/ding/entry/tests/test_serial_entry_algo.py b/ding/entry/tests/test_serial_entry_algo.py index e8f743b74ccfe769c1778767eaf20b664b20cf6a..82d527bfdc25305c1baaf86b3f040d02cc354709 100644 --- a/ding/entry/tests/test_serial_entry_algo.py +++ b/ding/entry/tests/test_serial_entry_algo.py @@ -281,7 +281,9 @@ def test_acer(): @pytest.mark.algotest def test_selfplay(): try: - selfplay_main(deepcopy(league_demo_ppo_config), seed=0) + config = deepcopy(league_demo_ppo_config) + config.exp_name = 'test_selfplay' + selfplay_main(config, seed=0) except Exception: assert False, "pipeline fail" with open("./algo_record.log", "a+") as f: @@ -291,7 +293,9 @@ def test_selfplay(): @pytest.mark.algotest def test_league(): try: - league_main(deepcopy(league_demo_ppo_config), seed=0) + config = deepcopy(league_demo_ppo_config) + config.exp_name = 'test_league' + league_main(config, seed=0) except Exception: assert False, "pipeline fail" with open("./algo_record.log", "a+") as f: @@ -326,14 +330,13 @@ def test_cql(): assert False, "pipeline fail" # collect expert data - import torch config = [ deepcopy(pendulum_sac_data_genearation_default_config), deepcopy(pendulum_sac_data_genearation_default_create_config) ] collect_count = config[0].policy.other.replay_buffer.replay_buffer_size expert_data_path = config[0].policy.collect.save_path - state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu') + state_dict = torch.load('./sac_seed0/ckpt/iteration_0.pth.tar', map_location='cpu') try: collect_demo_data( config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict @@ -362,11 +365,10 @@ def test_discrete_cql(): assert False, "pipeline fail" # collect expert data - import torch config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)] collect_count = config[0].policy.other.replay_buffer.replay_buffer_size expert_data_path = config[0].policy.collect.save_path - state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu') + state_dict = torch.load('./cql_cartpole_seed0/ckpt/iteration_0.pth.tar', map_location='cpu') try: collect_demo_data( config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict @@ -406,11 +408,10 @@ def test_td3_bc(): assert False, "pipeline fail" # collect expert data - import torch config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)] collect_count = config[0].policy.other.replay_buffer.replay_buffer_size expert_data_path = config[0].policy.collect.save_path - state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu') + state_dict = torch.load('./td3_seed0/ckpt/iteration_0.pth.tar', map_location='cpu') try: collect_demo_data( config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py index fe646a89bb746a66b108dabb5445b6f55732cf52..456114aafe82e85760f433dbaeb731ea7e9e013f 100644 --- a/ding/policy/dqn.py +++ b/ding/policy/dqn.py @@ -71,12 +71,17 @@ class DQNPolicy(Policy): config = dict( type='dqn', + # (bool) Whether use cuda in policy cuda=False, + # (bool) Whether learning policy is the same as collecting data policy(on-policy) on_policy=False, + # (bool) Whether enable priority experience sample priority=False, # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. priority_IS_weight=False, + # (float) Discount factor(gamma) for returns discount_factor=0.97, + # (int) The number of step for calculating target q_value nstep=1, learn=dict( # (bool) Whether to use multi gpu diff --git a/dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py b/dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py index f366c85fdf3d7c8091b97ecee2c38adc6cb80335..6bc8a6cc193c6e37398be59e0ae54421745ea341 100644 --- a/dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py +++ b/dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py @@ -3,6 +3,7 @@ from ding.entry import serial_pipeline from easydict import EasyDict pong_qrdqn_config = dict( + exp_name='pong_qrdqn_generation', env=dict( collector_env_num=8, evaluator_env_num=8, @@ -39,7 +40,7 @@ pong_qrdqn_config = dict( collect=dict( n_sample=100, data_type='hdf5', - save_path='./expert/expert.pkl', + save_path='expert.pkl', ), eval=dict(evaluator=dict(eval_freq=4000, )), other=dict( diff --git a/dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py b/dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py index e802e33e80f5e3c1c9f2ab3c663b1c2fa3a5a323..487470dd84794a797be7b1d88e0b5cf2447571bf 100644 --- a/dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py +++ b/dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py @@ -3,6 +3,7 @@ from ding.entry import serial_pipeline from easydict import EasyDict qbert_qrdqn_config = dict( + exp_name='qbert_qrdqn_geneation', env=dict( collector_env_num=8, evaluator_env_num=8, @@ -39,7 +40,7 @@ qbert_qrdqn_config = dict( collect=dict( n_sample=100, data_type='hdf5', - save_path='./expert/expert.pkl', + save_path='expert.pkl', ), eval=dict(evaluator=dict(eval_freq=4000, )), other=dict( diff --git a/dizoo/classic_control/cartpole/config/cartpole_cql_config.py b/dizoo/classic_control/cartpole/config/cartpole_cql_config.py index 8fa011e44e8af25ed71548904296752cce654326..e39405b1e2c0a562d491d202b8a928d563c1acda 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_cql_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_cql_config.py @@ -29,7 +29,7 @@ cartpole_discrete_cql_config = dict( ), collect=dict( data_type='hdf5', - data_path='./cartpole_generation/expert_demos.hdf5', + data_path='./cartpole_generation_seed0/expert_demos.hdf5', # user-specific n_sample=80, unroll_len=1, ), diff --git a/dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py b/dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py index 0d6450b15fe33edd2159768c01ccd8540eb1431c..2924693bb1160e2cce1b7d1c2864835a442eba53 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py @@ -37,7 +37,7 @@ cartpole_qrdqn_generation_data_config = dict( n_sample=80, unroll_len=1, data_type='hdf5', - save_path='./cartpole_generation/expert.pkl', + save_path='expert.pkl', ), other=dict( eps=dict( diff --git a/dizoo/classic_control/pendulum/config/pendulum_cql_config.py b/dizoo/classic_control/pendulum/config/pendulum_cql_config.py index 7e91497b52afba1184b3ba4bcc43771ff2b3f05c..d14378b930aef1ac328b867d6f76e2f379d2c71a 100644 --- a/dizoo/classic_control/pendulum/config/pendulum_cql_config.py +++ b/dizoo/classic_control/pendulum/config/pendulum_cql_config.py @@ -37,7 +37,7 @@ pendulum_cql_default_config = dict( n_sample=1, unroll_len=1, data_type='hdf5', - data_path='./sac/expert_demos.hdf5', + data_path='./peudulum_sac_generation_seed0/expert_demos.hdf5', # user-specific ), command=dict(), eval=dict(evaluator=dict(eval_freq=100, )), diff --git a/dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_default_config.py b/dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_default_config.py index a673e8bde36e335a7a76b845790fc26d77204a0e..76e9e6eadb7cc1fbb50c70ca260e8034e4c79f4d 100644 --- a/dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_default_config.py +++ b/dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_default_config.py @@ -1,6 +1,7 @@ from easydict import EasyDict pendulum_sac_data_genearation_default_config = dict( + exp_name='peudulum_sac_generation', seed=0, env=dict( collector_env_num=10, @@ -43,7 +44,7 @@ pendulum_sac_data_genearation_default_config = dict( collect=dict( n_sample=1, unroll_len=1, - save_path='./sac/expert.pkl', + save_path='expert.pkl', data_type='hdf5', ), command=dict(), diff --git a/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py b/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py index 200aa632ab994570262b8102e81a77d8b70a9158..9e95815e70d7336966095df6e135216ce1fae3ca 100644 --- a/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py +++ b/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py @@ -44,7 +44,7 @@ pendulum_td3_bc_config = dict( noise_sigma=0.1, collector=dict(collect_print_freq=1000, ), data_type='hdf5', - data_path='./td3/expert_demos.hdf5', + data_path='./pendulum_td3_generation_seed0/expert_demos.hdf5', # user-specific normalize_states=True, ), eval=dict(evaluator=dict(eval_freq=100, ), ), diff --git a/dizoo/classic_control/pendulum/config/pendulum_td3_data_generation_config.py b/dizoo/classic_control/pendulum/config/pendulum_td3_data_generation_config.py index 357443486de70dd63b4e79a95e5e49395bf5bc57..f3d6694e5ecd86165eb6ec1257cdef428ea92283 100644 --- a/dizoo/classic_control/pendulum/config/pendulum_td3_data_generation_config.py +++ b/dizoo/classic_control/pendulum/config/pendulum_td3_data_generation_config.py @@ -1,7 +1,7 @@ from easydict import EasyDict pendulum_td3_generation_config = dict( - exp_name='td3', + exp_name='pendulum_td3_generation', env=dict( collector_env_num=8, evaluator_env_num=10, @@ -45,7 +45,7 @@ pendulum_td3_generation_config = dict( n_sample=10, noise_sigma=0.1, collector=dict(collect_print_freq=1000, ), - save_path='./td3/expert.pkl', + save_path='expert.pkl', data_type='hdf5', ), eval=dict(evaluator=dict(eval_freq=100, ), ), diff --git a/dizoo/mujoco/config/hopper_sac_data_generation_default_config.py b/dizoo/mujoco/config/hopper_sac_data_generation_default_config.py index 6a126d823108b8a3b4e75807b618546f9dd66858..8aa8647a66ad0a1693eb501dbfb8803fe0ee218c 100644 --- a/dizoo/mujoco/config/hopper_sac_data_generation_default_config.py +++ b/dizoo/mujoco/config/hopper_sac_data_generation_default_config.py @@ -1,6 +1,7 @@ from easydict import EasyDict hopper_sac_data_genearation_default_config = dict( + exp='hopper_sac_generation', env=dict( env_id='Hopper-v3', norm_obs=dict(use_norm=False, ), @@ -45,7 +46,7 @@ hopper_sac_data_genearation_default_config = dict( collect=dict( n_sample=1, unroll_len=1, - save_path='./default_experiment/expert_iteration_200000.pkl', + save_path='expert_iteration_200000.pkl', ), command=dict(), eval=dict(), diff --git a/dizoo/mujoco/config/hopper_td3_data_generation_config.py b/dizoo/mujoco/config/hopper_td3_data_generation_config.py index 9a0f71cf92f0a169a8a61115b898c61ad7ecc011..31bd9e96c5bf39572dcc96c1c3923dd3c435213f 100644 --- a/dizoo/mujoco/config/hopper_td3_data_generation_config.py +++ b/dizoo/mujoco/config/hopper_td3_data_generation_config.py @@ -1,6 +1,7 @@ from easydict import EasyDict halfcheetah_td3_default_config = dict( + exp_name='halfcheetah_td3_generation', env=dict( env_id='Hopper-v3', norm_obs=dict(use_norm=False, ), @@ -49,7 +50,7 @@ halfcheetah_td3_default_config = dict( n_sample=1, unroll_len=1, noise_sigma=0.1, - save_path='./td3/expert.pkl', + save_path='expert.pkl', data_type='hdf5', ), other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),