Unverified commit f089d02a, authored by Will-Nie, committed by GitHub

polish(nyp): fix unittest for trex training and collecting (#144)

* add trex algorithm for pong

* sort style

* add atari, ll,cp; fix device, collision; add_ppo

* add accuracy evaluation

* correct style

* add seed to make sure results are replicable

* remove useless part in cum return  of model part

* add mujoco onppo training pipeline; ppo config

* improve style

* add sac training config for mujoco

* add log, add save data; polish config

* logger; hyperparameter;walker

* correct style

* modify else condition

* change rnd to trex

* revise according to comments, add episode collect

* new collect mode for trex, fix all bugs, comments

* final change

* polish after the final comment

* add readme/test

* add test for serial entry of trex/gcl

* sort style

* change mujoco to cartpole for test for trex_onppo

* remove files generated by testing

* revise tests for entry

* sort style

* revise tests

* modify pytest

* fix(nyz): speed up ppg/ppo and marl algo unittest

* polish(nyz): speed up trex unittest and fix trex entry default config bug

* fix(nyz): fix same name bug

* fix(nyz): fix remove conflict bug(ci skip)
Co-authored-by: niuyazhe <niuyazhe@sensetime.com>
Parent 973e33e2
......@@ -20,6 +20,7 @@ unittest:
--cov-report=xml \
--cov-report term-missing \
--cov=${COV_DIR} \
${DURATIONS_COMMAND} \
${WORKERS_COMMAND} \
-sv -m unittest \
......
......@@ -14,6 +14,6 @@ from .serial_entry_trex import serial_pipeline_reward_model_trex
from .serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy
from .parallel_entry import parallel_pipeline
from .application_entry import eval, collect_demo_data, collect_episodic_demo_data, \
epsiode_to_transitions
episode_to_transitions
from .application_entry_trex_collect_data import trex_collecting_data, collect_episodic_demo_data_for_trex
from .serial_entry_guided_cost import serial_pipeline_guided_cost
......@@ -231,7 +231,7 @@ def collect_episodic_demo_data(
print('Collect episodic demo data successfully')
def epsiode_to_transitions(data_path: str, expert_data_path: str, nstep: int) -> None:
def episode_to_transitions(data_path: str, expert_data_path: str, nstep: int) -> None:
r"""
Overview:
Transfer episoded data into nstep transitions
......
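For reference, a minimal usage sketch of the renamed helper (paths are placeholders, mirroring the updated unittest further down in this diff; `./expert.data` is assumed to already hold episodes dumped by `collect_episodic_demo_data`):

```python
from ding.entry import episode_to_transitions

episode_to_transitions(
    data_path='./expert.data',         # pickled list of collected episodes
    expert_data_path='./expert.data',  # output path, overwritten with n-step transitions
    nstep=3,                           # n-step horizon used when flattening episodes
)
```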
......@@ -44,6 +44,7 @@ def serial_pipeline_reward_model_trex(
else:
cfg, create_cfg = input_cfg
create_cfg.policy.type = create_cfg.policy.type + '_command'
create_cfg.reward_model = dict(type='trex')
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
......
......@@ -43,6 +43,7 @@ def serial_pipeline_reward_model_trex_onpolicy(
else:
cfg, create_cfg = input_cfg
create_cfg.policy.type = create_cfg.policy.type + '_command'
create_cfg.reward_model = dict(type='trex')
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
......
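Both TREX entries now inject `create_cfg.reward_model = dict(type='trex')` themselves (the "trex entry default config bug" noted in the commit message). A hedged invocation sketch of the on-policy entry, assuming the demonstration data has already been collected and the placeholder paths in the config have been filled in:

```python
# Sketch only, mirroring the updated unittest: the cartpole TREX on-policy config
# is added later in this diff, and its expert_model_path / offline_data_path /
# reward_model_path placeholders must be set (and populated via trex_collecting_data)
# before the pipeline can run.
from copy import deepcopy
from ding.entry.serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy
from dizoo.classic_control.cartpole.config import (
    cartpole_trex_ppo_onpolicy_config, cartpole_trex_ppo_onpolicy_create_config
)

config = [deepcopy(cartpole_trex_ppo_onpolicy_config), deepcopy(cartpole_trex_ppo_onpolicy_create_config)]
# create_cfg.reward_model is now filled in as dict(type='trex') inside the entry,
# so callers no longer have to register the reward model type themselves.
serial_pipeline_reward_model_trex_onpolicy(config, seed=0, max_iterations=1)
```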
......@@ -10,7 +10,7 @@ from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import ca
from dizoo.classic_control.cartpole.envs import CartPoleEnv
from ding.entry import serial_pipeline, eval, collect_demo_data
from ding.config import compile_config
from ding.entry.application_entry import collect_episodic_demo_data, epsiode_to_transitions
from ding.entry.application_entry import collect_episodic_demo_data, episode_to_transitions
@pytest.fixture(scope='module')
......@@ -65,8 +65,11 @@ class TestApplication:
def test_collect_episodic_demo_data(self, setup_state_dict):
config = deepcopy(cartpole_trex_ppo_offpolicy_config), deepcopy(cartpole_trex_ppo_offpolicy_create_config)
config[0].exp_name = 'cartpole_trex_offppo_episodic'
collect_count = 16
expert_data_path = './expert.data'
if not os.path.exists('./test_episode'):
os.mkdir('./test_episode')
expert_data_path = './test_episode/expert.data'
collect_episodic_demo_data(
config,
seed=0,
......@@ -79,11 +82,13 @@ class TestApplication:
assert isinstance(exp_data, list)
assert isinstance(exp_data[0][0], dict)
def test_epsiode_to_transitions(self):
expert_data_path = './expert.data'
epsiode_to_transitions(data_path=expert_data_path, expert_data_path=expert_data_path, nstep=3)
def test_episode_to_transitions(self, setup_state_dict):
self.test_collect_episodic_demo_data(setup_state_dict)
expert_data_path = './test_episode/expert.data'
episode_to_transitions(data_path=expert_data_path, expert_data_path=expert_data_path, nstep=3)
with open(expert_data_path, 'rb') as f:
exp_data = pickle.load(f)
assert isinstance(exp_data, list)
assert isinstance(exp_data[0], dict)
os.popen('rm -rf ./expert.data ckpt* log')
os.popen('rm -rf ./test_episode/expert.data ckpt* log')
os.popen('rm -rf ./test_episode')
......@@ -17,6 +17,7 @@ from ding.entry import serial_pipeline
@pytest.mark.unittest
def test_collect_episodic_demo_data_for_trex():
expert_policy_state_dict_path = './expert_policy.pth'
expert_policy_state_dict_path = os.path.abspath('./expert_policy.pth')
config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
expert_policy = serial_pipeline(config, seed=0)
torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
......@@ -39,11 +40,12 @@ def test_collect_episodic_demo_data_for_trex():
os.popen('rm -rf {}'.format(expert_policy_state_dict_path))
# @pytest.mark.unittest
@pytest.mark.unittest
def test_trex_collecting_data():
expert_policy_state_dict_path = './cartpole_ppo_offpolicy'
expert_policy_state_dict_path = os.path.abspath(expert_policy_state_dict_path)
config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
expert_policy = serial_pipeline(config, seed=0)
args = EasyDict(
......@@ -54,10 +56,14 @@ def test_trex_collecting_data():
'device': 'cpu'
}
)
args.cfg[0].reward_model.offline_data_path = 'cartpole_trex_offppo_offline_data'
args.cfg[0].reward_model.offline_data_path = './cartpole_trex_offppo'
args.cfg[0].reward_model.offline_data_path = os.path.abspath(args.cfg[0].reward_model.offline_data_path)
args.cfg[0].reward_model.reward_model_path = args.cfg[0].reward_model.offline_data_path + '/cartpole.params'
args.cfg[0].reward_model.expert_model_path = './cartpole_ppo_offpolicy'
args.cfg[0].reward_model.expert_model_path = os.path.abspath(args.cfg[0].reward_model.expert_model_path)
args.cfg[0].reward_model.checkpoint_max = 100
args.cfg[0].reward_model.checkpoint_step = 100
args.cfg[0].reward_model.num_snippets = 100
trex_collecting_data(args=args)
os.popen('rm -rf {}'.format(expert_policy_state_dict_path))
os.popen('rm -rf {}'.format(args.cfg[0].reward_model.offline_data_path))
......@@ -252,6 +252,8 @@ def test_collaq():
config = [deepcopy(cooperative_navigation_collaq_config), deepcopy(cooperative_navigation_collaq_create_config)]
config[0].policy.cuda = False
config[0].policy.learn.update_per_collect = 1
config[0].env.n_evaluator_episode = 2
config[0].policy.collect.n_sample = 100
try:
serial_pipeline(config, seed=0, max_iterations=1)
except Exception:
......@@ -265,6 +267,8 @@ def test_coma():
config = [deepcopy(cooperative_navigation_coma_config), deepcopy(cooperative_navigation_coma_create_config)]
config[0].policy.cuda = False
config[0].policy.learn.update_per_collect = 1
config[0].env.n_evaluator_episode = 2
config[0].policy.collect.n_sample = 100
try:
serial_pipeline(config, seed=0, max_iterations=1)
except Exception:
......@@ -278,6 +282,8 @@ def test_qmix():
config = [deepcopy(cooperative_navigation_qmix_config), deepcopy(cooperative_navigation_qmix_create_config)]
config[0].policy.cuda = False
config[0].policy.learn.update_per_collect = 1
config[0].env.n_evaluator_episode = 2
config[0].policy.collect.n_sample = 100
try:
serial_pipeline(config, seed=0, max_iterations=1)
except Exception:
......@@ -291,6 +297,8 @@ def test_wqmix():
config = [deepcopy(cooperative_navigation_wqmix_config), deepcopy(cooperative_navigation_wqmix_create_config)]
config[0].policy.cuda = False
config[0].policy.learn.update_per_collect = 1
config[0].env.n_evaluator_episode = 2
config[0].policy.collect.n_sample = 100
try:
serial_pipeline(config, seed=0, max_iterations=1)
except Exception:
......@@ -304,6 +312,8 @@ def test_qtran():
config = [deepcopy(cooperative_navigation_qtran_config), deepcopy(cooperative_navigation_qtran_create_config)]
config[0].policy.cuda = False
config[0].policy.learn.update_per_collect = 1
config[0].env.n_evaluator_episode = 2
config[0].policy.collect.n_sample = 100
try:
serial_pipeline(config, seed=0, max_iterations=1)
except Exception:
......@@ -316,6 +326,8 @@ def test_qtran():
def test_atoc():
config = [deepcopy(cooperative_navigation_atoc_config), deepcopy(cooperative_navigation_atoc_create_config)]
config[0].policy.cuda = False
config[0].env.n_evaluator_episode = 2
config[0].policy.collect.n_sample = 100
try:
serial_pipeline(config, seed=0, max_iterations=1)
except Exception:
......
......@@ -14,20 +14,25 @@ from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import
from ding.entry.application_entry_trex_collect_data import trex_collecting_data
# @pytest.mark.unittest
@pytest.mark.unittest
def test_serial_pipeline_reward_model_trex():
config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
expert_policy = serial_pipeline(config, seed=0)
config = [deepcopy(cartpole_trex_ppo_offpolicy_config), deepcopy(cartpole_trex_ppo_offpolicy_create_config)]
config[0].reward_model.offline_data_path = 'dizoo/classic_control/cartpole/config/cartpole_trex_offppo'
config[0].reward_model.offline_data_path = './cartpole_trex_offppo'
config[0].reward_model.offline_data_path = os.path.abspath(config[0].reward_model.offline_data_path)
config[0].reward_model.reward_model_path = config[0].reward_model.offline_data_path + '/cartpole.params'
config[0].reward_model.expert_model_path = './cartpole_ppo_offpolicy'
config[0].reward_model.expert_model_path = os.path.abspath(config[0].reward_model.expert_model_path)
config[0].reward_model.checkpoint_max = 100
config[0].reward_model.checkpoint_step = 100
config[0].reward_model.num_snippets = 100
args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'})
trex_collecting_data(args=args)
try:
serial_pipeline_reward_model_trex(config, seed=0, max_iterations=1)
os.popen('rm -rf {}'.format(config[0].reward_model.offline_data_path))
except Exception:
assert False, "pipeline fail"
......@@ -7,25 +7,31 @@ import torch
from ding.entry import serial_pipeline_onpolicy
from ding.entry.serial_entry_trex_onpolicy import serial_pipeline_reward_model_trex_onpolicy
from dizoo.mujoco.config import hopper_ppo_default_config, hopper_ppo_create_default_config
from dizoo.mujoco.config import hopper_trex_ppo_default_config, hopper_trex_ppo_create_default_config
from dizoo.classic_control.cartpole.config import cartpole_ppo_config, cartpole_ppo_create_config
from dizoo.classic_control.cartpole.config import cartpole_trex_ppo_onpolicy_config, \
cartpole_trex_ppo_onpolicy_create_config
from ding.entry.application_entry_trex_collect_data import trex_collecting_data
# @pytest.mark.unittest
@pytest.mark.unittest
def test_serial_pipeline_reward_model_trex():
config = [deepcopy(hopper_ppo_default_config), deepcopy(hopper_ppo_create_default_config)]
config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
expert_policy = serial_pipeline_onpolicy(config, seed=0, max_iterations=90)
config = [deepcopy(hopper_trex_ppo_default_config), deepcopy(hopper_trex_ppo_create_default_config)]
config[0].reward_model.offline_data_path = 'dizoo/mujoco/config/hopper_trex_onppo'
config = [deepcopy(cartpole_trex_ppo_onpolicy_config), deepcopy(cartpole_trex_ppo_onpolicy_create_config)]
config[0].reward_model.offline_data_path = './cartpole_trex_onppo'
config[0].reward_model.offline_data_path = os.path.abspath(config[0].reward_model.offline_data_path)
config[0].reward_model.reward_model_path = config[0].reward_model.offline_data_path + '/hopper.params'
config[0].reward_model.expert_model_path = './hopper_onppo'
config[0].reward_model.reward_model_path = config[0].reward_model.offline_data_path + '/cartpole.params'
config[0].reward_model.expert_model_path = './cartpole_ppo'
config[0].reward_model.expert_model_path = os.path.abspath(config[0].reward_model.expert_model_path)
config[0].reward_model.checkpoint_max = 100
config[0].reward_model.checkpoint_step = 100
config[0].reward_model.num_snippets = 100
args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'})
trex_collecting_data(args=args)
try:
serial_pipeline_reward_model_trex_onpolicy(config, seed=0, max_iterations=1)
os.popen('rm -rf {}'.format(config[0].reward_model.offline_data_path))
os.popen('rm -rf {}'.format(config[0].reward_model.expert_model_path))
except Exception:
assert False, "pipeline fail"
......@@ -125,11 +125,8 @@ class TrexModel(nn.Module):
if mode == 'sum':
sum_rewards = torch.sum(r)
sum_abs_rewards = torch.sum(torch.abs(r))
# print(sum_rewards)
# print(r)
return sum_rewards, sum_abs_rewards
elif mode == 'batch':
# print(r)
return r, torch.abs(r)
else:
raise KeyError("not support mode: {}, please choose mode=sum or mode=batch".format(mode))
......@@ -157,6 +154,8 @@ class TrexRewardModel(BaseRewardModel):
batch_size=64,
target_new_data_count=64,
hidden_size=128,
num_trajs=0, # number of downsampled full trajectories
num_snippets=6000, # number of short subtrajectories to sample
)
def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa
......@@ -173,7 +172,7 @@ class TrexRewardModel(BaseRewardModel):
assert device in ["cpu", "cuda"] or "cuda" in device
self.device = device
self.tb_logger = tb_logger
self.reward_model = TrexModel(self.cfg.policy.model.get('obs_shape'))
self.reward_model = TrexModel(self.cfg.policy.model.obs_shape)
self.reward_model.to(self.device)
self.pre_expert_data = []
self.train_data = []
......@@ -184,8 +183,8 @@ class TrexRewardModel(BaseRewardModel):
self.learning_rewards = []
self.training_obs = []
self.training_labels = []
self.num_trajs = 0 # number of downsampled full trajectories
self.num_snippets = 6000 # number of short subtrajectories to sample
self.num_trajs = self.cfg.reward_model.num_trajs
self.num_snippets = self.cfg.reward_model.num_snippets
# minimum number of short subtrajectories to sample
self.min_snippet_length = config.reward_model.min_snippet_length
# maximum number of short subtrajectories to sample
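The previously hard-coded sampling counts are now read from the reward-model config, which is why the unittests above shrink `num_snippets` to 100. A short sketch of overriding the new fields (values match the unittest, not the defaults of 0 and 6000):

```python
from dizoo.classic_control.cartpole.config import cartpole_trex_ppo_onpolicy_config

# Only the newly configurable fields are shown here.
cartpole_trex_ppo_onpolicy_config.reward_model.num_trajs = 0      # downsampled full trajectories
cartpole_trex_ppo_onpolicy_config.reward_model.num_snippets = 100  # short subtrajectories to sample
```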
......@@ -240,19 +239,15 @@ class TrexRewardModel(BaseRewardModel):
#collect training data
max_traj_length = 0
num_demos = len(demonstrations)
assert num_demos >= 2
#add full trajs (for use on Enduro)
si = np.random.randint(6, size=num_trajs)
sj = np.random.randint(6, size=num_trajs)
step = np.random.randint(3, 7, size=num_trajs)
for n in range(num_trajs):
ti = 0
tj = 0
#only add trajectories that are different returns
while (ti == tj):
#pick two random demonstrations
ti = np.random.randint(num_demos)
tj = np.random.randint(num_demos)
#pick two random demonstrations
ti, tj = np.random.choice(num_demos, size=(2, ), replace=False)
#create random partial trajs by finding random start frame and random skip frame
traj_i = demonstrations[ti][si[n]::step[n]] # slice(start,stop,step)
traj_j = demonstrations[tj][sj[n]::step[n]]
......@@ -266,13 +261,8 @@ class TrexRewardModel(BaseRewardModel):
#fixed size snippets with progress prior
rand_length = np.random.randint(min_snippet_length, max_snippet_length, size=num_snippets)
for n in range(num_snippets):
ti = 0
tj = 0
#only add trajectories that are different returns
while (ti == tj):
#pick two random demonstrations
ti = np.random.randint(num_demos)
tj = np.random.randint(num_demos)
#pick two random demonstrations
ti, tj = np.random.choice(num_demos, size=(2, ), replace=False)
#create random snippets
#find min length of both demos to ensure we can pick a demo no earlier
#than that chosen in worse preferred demo
......@@ -285,8 +275,8 @@ class TrexRewardModel(BaseRewardModel):
tj_start = np.random.randint(min_length - rand_length[n] + 1)
# print(tj_start, len(demonstrations[ti]))
ti_start = np.random.randint(tj_start, len(demonstrations[ti]) - rand_length[n] + 1)
traj_i = demonstrations[ti][ti_start:ti_start + rand_length[n]:2
] # skip everyother framestack to reduce size
# skip everyother framestack to reduce size
traj_i = demonstrations[ti][ti_start:ti_start + rand_length[n]:2]
traj_j = demonstrations[tj][tj_start:tj_start + rand_length[n]:2]
max_traj_length = max(max_traj_length, len(traj_i), len(traj_j))
......@@ -334,7 +324,6 @@ class TrexRewardModel(BaseRewardModel):
item_loss = loss.item()
cum_loss += item_loss
if i % 100 == 99:
# print(i)
self._logger.info("epoch {}:{} loss {}".format(epoch, i, cum_loss))
self._logger.info("abs_returns: {}".format(abs_rewards))
cum_loss = 0.0
......
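The two rejection-sampling loops for picking distinct demonstration indices are collapsed into a single draw without replacement; a standalone NumPy sketch of the equivalence (toy `num_demos`, independent of the reward model):

```python
import numpy as np

num_demos = 10  # toy value; the added `assert num_demos >= 2` guards the real code path

# Old pattern (removed): resample until the two indices differ.
ti = tj = 0
while ti == tj:
    ti = np.random.randint(num_demos)
    tj = np.random.randint(num_demos)

# New pattern: draw two distinct indices in one call.
ti, tj = np.random.choice(num_demos, size=(2, ), replace=False)
```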
......@@ -18,4 +18,5 @@ from .cartpole_dqn_gail_config import cartpole_dqn_gail_config, cartpole_dqn_gai
from .cartpole_gcl_config import cartpole_gcl_ppo_onpolicy_config, cartpole_gcl_ppo_onpolicy_create_config
from .cartpole_trex_dqn_config import cartpole_trex_dqn_config, cartpole_trex_dqn_create_config
from .cartpole_trex_offppo_config import cartpole_trex_ppo_offpolicy_config, cartpole_trex_ppo_offpolicy_create_config
from .cartpole_trex_onppo_config import cartpole_trex_ppo_onpolicy_config, cartpole_trex_ppo_onpolicy_create_config
# from .cartpole_ppo_default_loader import cartpole_ppo_default_loader
......@@ -31,6 +31,7 @@ cartpole_ppg_config = dict(
discount_factor=0.9,
gae_lambda=0.95,
),
eval=dict(evaluator=dict(eval_freq=40, )),
other=dict(
replay_buffer=dict(
multi_buffer=True,
......
......@@ -25,6 +25,9 @@ cartpole_ppo_config = dict(
value_weight=0.5,
entropy_weight=0.01,
clip_ratio=0.2,
learner=dict(
hook=dict(save_ckpt_after_iter=100)
),
),
collect=dict(
n_sample=256,
......
......@@ -24,6 +24,9 @@ cartpole_ppo_offpolicy_config = dict(
value_weight=0.5,
entropy_weight=0.01,
clip_ratio=0.2,
learner=dict(
hook=dict(save_ckpt_after_iter=1000)
),
),
collect=dict(
n_sample=128,
......@@ -31,6 +34,7 @@ cartpole_ppo_offpolicy_config = dict(
discount_factor=0.9,
gae_lambda=0.95,
),
eval=dict(evaluator=dict(eval_freq=40, )),
other=dict(replay_buffer=dict(replay_buffer_size=5000))
),
)
......
......@@ -60,7 +60,7 @@ cartpole_trex_dqn_create_config = dict(
type='cartpole',
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
env_manager=dict(type='subprocess'),
env_manager=dict(type='base'),
policy=dict(type='dqn'),
)
cartpole_trex_dqn_create_config = EasyDict(cartpole_trex_dqn_create_config)
......
......@@ -58,7 +58,7 @@ cartpole_trex_ppo_offpolicy_create_config = dict(
type='cartpole',
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
env_manager=dict(type='subprocess'),
env_manager=dict(type='base'),
policy=dict(type='ppo_offpolicy'),
)
cartpole_trex_ppo_offpolicy_create_config = EasyDict(cartpole_trex_ppo_offpolicy_create_config)
......
from easydict import EasyDict
cartpole_trex_ppo_onpolicy_config = dict(
exp_name='cartpole_trex_onppo',
env=dict(
manager=dict(shared_memory=True, force_reproducibility=True),
collector_env_num=8,
evaluator_env_num=5,
n_evaluator_episode=5,
stop_value=195,
),
reward_model=dict(
type='trex',
algo_for_model='ppo',
env_id='CartPole-v0',
min_snippet_length=5,
max_snippet_length=100,
checkpoint_min=0,
checkpoint_max=100,
checkpoint_step=100,
learning_rate=1e-5,
update_per_collect=1,
expert_model_path='abs model path',
reward_model_path='abs data path + ./cartpole.params',
offline_data_path='abs data path',
),
policy=dict(
cuda=False,
continuous=False,
model=dict(
obs_shape=4,
action_shape=2,
encoder_hidden_size_list=[64, 64, 128],
critic_head_hidden_size=128,
actor_head_hidden_size=128,
),
learn=dict(
epoch_per_collect=2,
batch_size=64,
learning_rate=0.001,
value_weight=0.5,
entropy_weight=0.01,
clip_ratio=0.2,
learner=dict(
hook=dict(save_ckpt_after_iter=1000)
),
),
collect=dict(
n_sample=256,
unroll_len=1,
discount_factor=0.9,
gae_lambda=0.95,
),
eval=dict(
evaluator=dict(
eval_freq=100,
),
),
),
)
cartpole_trex_ppo_onpolicy_config = EasyDict(cartpole_trex_ppo_onpolicy_config)
main_config = cartpole_trex_ppo_onpolicy_config
cartpole_trex_ppo_onpolicy_create_config = dict(
env=dict(
type='cartpole',
import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
),
env_manager=dict(type='base'),
policy=dict(type='ppo'),
)
cartpole_trex_ppo_onpolicy_create_config = EasyDict(cartpole_trex_ppo_onpolicy_create_config)
create_config = cartpole_trex_ppo_onpolicy_create_config
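The `expert_model_path`, `reward_model_path`, and `offline_data_path` fields above are deliberate placeholders. A sketch of how the unittest fills them with absolute paths before collecting preference data (directory names are examples; an expert PPO run is assumed to have already saved checkpoints under `./cartpole_ppo`):

```python
import os
from copy import deepcopy
from easydict import EasyDict
from ding.entry.application_entry_trex_collect_data import trex_collecting_data
from dizoo.classic_control.cartpole.config import (
    cartpole_trex_ppo_onpolicy_config, cartpole_trex_ppo_onpolicy_create_config
)

config = [deepcopy(cartpole_trex_ppo_onpolicy_config), deepcopy(cartpole_trex_ppo_onpolicy_create_config)]
rm = config[0].reward_model
rm.offline_data_path = os.path.abspath('./cartpole_trex_onppo')    # where collected episodes are stored
rm.reward_model_path = rm.offline_data_path + '/cartpole.params'   # trained reward model weights
rm.expert_model_path = os.path.abspath('./cartpole_ppo')           # checkpoints of the expert PPO run
trex_collecting_data(args=EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'}))
```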
......@@ -55,7 +55,7 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
cfg.policy.other.replay_buffer.value, tb_logger, exp_name=cfg.exp_name, instance_name='value_buffer'
)
while True:
for _ in range(max_iterations):
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
......
......@@ -2,7 +2,7 @@ from easydict import EasyDict
n_agent = 5
collector_env_num = 4
evaluator_env_num = 5
evaluator_env_num = 2
communication = True
cooperative_navigation_atoc_config = dict(
env=dict(
......@@ -13,7 +13,7 @@ cooperative_navigation_atoc_config = dict(
evaluator_env_num=evaluator_env_num,
agent_obs_only=True,
discrete_action=False,
n_evaluator_episode=5,
n_evaluator_episode=10,
stop_value=0,
),
policy=dict(
......
......@@ -11,9 +11,9 @@ cooperative_navigation_collaq_config = dict(
max_step=100,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
manager=dict(shared_memory=False, ),
n_evaluator_episode=5,
n_evaluator_episode=10,
stop_value=0,
manager=dict(shared_memory=False, ),
),
policy=dict(
cuda=True,
......