未验证 提交 aa612443 编写于 作者: K Ke Li 提交者: GitHub

fix(lk): fix port conflict in gym_soccer (#139)

* feature(lk): fix port conflict

* polish(lk): polish code style and format

* fix(lk): change to subprocess
上级 ff31a86b
from easydict import EasyDict
from ding.entry import serial_pipeline
# Main experiment config: PDQN policy on the gym_soccer environment.
gym_soccer_pdqn_config = EasyDict({
    'exp_name': 'gym_soccer_pdqn_seed1',
    'env': {
        'collector_env_num': 8,
        'evaluator_env_num': 3,
        # (bool) Scale output action into legal range [-1, 1].
        'act_scale': True,
        # One of ['Soccer-v0', 'SoccerEmptyGoal-v0', 'SoccerAgainstKeeper-v0'].
        'env_id': 'Soccer-v0',
        'n_evaluator_episode': 5,
        'stop_value': 0.99,  # 1
    },
    'policy': {
        'cuda': True,
        'priority': False,
        # (bool) Whether to use Importance Sampling Weight to correct the biased update.
        # If True, ``priority`` must be True as well.
        'priority_IS_weight': False,
        'discount_factor': 0.99,
        'nstep': 1,
        'model': {
            'obs_shape': 10,
            'action_shape': {
                'action_type_shape': 3,
                'action_args_shape': 5,
            },
            # 'multi_pass': True,
            # 'action_mask': [],
        },
        'learn': {
            # (bool) Whether to use multi gpu.
            'multi_gpu': False,
            # Number of updates (iterations) to train after one collection by the collector.
            # A bigger "update_per_collect" means more off-policy.
            # collect data -> update policy -> collect data -> ...
            'update_per_collect': 500,  # 100, 10,
            'batch_size': 320,  # 32,
            'learning_rate_dis': 3e-4,  # 1e-5, 3e-4, alpha
            'learning_rate_cont': 3e-4,  # beta
            'target_theta': 0.001,  # 0.005,
            # 'cont_update_freq': 10,
            # 'disc_update_freq': 10,
            'update_circle': 10,
        },
        # collect_mode config
        'collect': {
            # (int) Only one of [n_sample, n_episode] should be set.
            'n_sample': 3200,  # 128,
            # (int) Cut trajectories into pieces with length "unroll_len".
            'unroll_len': 1,
            'noise_sigma': 0.1,  # 0.05,
            'collector': {'collect_print_freq': 1000},
        },
        'eval': {'evaluator': {'eval_freq': 1000}},
        # other config
        'other': {
            # Epsilon greedy with decay.
            'eps': {
                # (str) Decay type. Support ['exp', 'linear'].
                'type': 'exp',
                'start': 1,  # 0.95,
                'end': 0.1,  # 0.05,
                # (int) Decay length (env step).
                'decay': int(1e5),
            },
            'replay_buffer': {'replay_buffer_size': int(1e6)},
        },
    },
})
main_config = gym_soccer_pdqn_config

# Creation config: names the env, env manager and policy types to instantiate.
gym_soccer_pdqn_create_config = EasyDict({
    'env': {
        'type': 'gym_soccer',
        'import_names': ['dizoo.gym_soccer.envs.gym_soccer_env'],
    },
    'env_manager': {'type': 'subprocess'},
    'policy': {'type': 'pdqn'},
})
create_config = gym_soccer_pdqn_create_config

if __name__ == "__main__":
    # Launch the serial training pipeline with a fixed seed.
    serial_pipeline([main_config, create_config], seed=1)
......@@ -10,6 +10,7 @@ from ding.envs.common.env_element import EnvElementInfo
from ding.torch_utils import to_list, to_ndarray, to_tensor
from ding.utils import ENV_REGISTRY
from gym.utils import seeding
import copy
@ENV_REGISTRY.register('gym_soccer')
......@@ -22,11 +23,11 @@ class GymSoccerEnv(BaseEnv):
self._env_id = cfg.env_id
assert self._env_id in self.default_env_id
self._init_flag = False
self._replay_path = None
self._replay_path = './game_log'
def reset(self) -> np.array:
if not self._init_flag:
self._env = gym.make(self._env_id, replay_path=self._replay_path)
self._env = gym.make(self._env_id, replay_path=self._replay_path, port=self._cfg.port) # TODO
self._init_flag = True
self._final_eval_reward = 0
obs = self._env.reset()
......@@ -124,3 +125,41 @@ class GymSoccerEnv(BaseEnv):
if replay_path is None:
replay_path = './game_log'
self._replay_path = replay_path
def create_collector_env_cfg(cfg: dict) -> List[dict]:
    """
    Overview:
        Build one config per collector environment, each carrying its own ``port``.
    Arguments:
        - cfg (:obj:`Dict`): Env config, same config where ``self.__init__()`` takes arguments from.
            It must support attribute assignment (``cfg.port = ...``). Its ``collector_env_num``
            key is popped, so the object passed in is modified.
    Returns:
        - cfg_list (:obj:`List[Dict]`): ``collector_env_num`` deep copies of ``cfg``; each copy has
            a distinct integer ``port`` drawn at random from ``[6000, 8000)``.
    """
    collector_env_num = cfg.pop('collector_env_num')
    # Collector envs draw from the lower part of the 6000-9998 pool and evaluator envs from the
    # upper part, so a collector can never be handed the same port as an evaluator.
    port_pool = np.arange(6000, 8000)
    # ``replace=False`` guarantees the ports are unique among the collector envs.
    port_candidates = np.random.choice(port_pool, size=collector_env_num, replace=False)
    cfg_list = []
    for port in port_candidates:
        cfg_copy = copy.deepcopy(cfg)
        # Store a plain ``int`` rather than a numpy scalar.
        cfg_copy.port = int(port)
        cfg_list.append(cfg_copy)
    return cfg_list


def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
    """
    Overview:
        Build one config per evaluator environment, each carrying its own ``port``.
    Arguments:
        - cfg (:obj:`Dict`): Env config, same config where ``self.__init__()`` takes arguments from.
            It must support attribute assignment (``cfg.port = ...``). Its ``evaluator_env_num``
            key is popped, so the object passed in is modified.
    Returns:
        - cfg_list (:obj:`List[Dict]`): ``evaluator_env_num`` deep copies of ``cfg``; each copy has
            a distinct integer ``port`` drawn at random from ``[8000, 9999)``.
    """
    evaluator_env_num = cfg.pop('evaluator_env_num')
    # Evaluator envs draw from the upper part of the 6000-9998 pool and collector envs from the
    # lower part, so an evaluator can never be handed the same port as a collector.
    port_pool = np.arange(8000, 9999)
    # ``replace=False`` guarantees the ports are unique among the evaluator envs.
    port_candidates = np.random.choice(port_pool, size=evaluator_env_num, replace=False)
    cfg_list = []
    for port in port_candidates:
        cfg_copy = copy.deepcopy(cfg)
        # Store a plain ``int`` rather than a numpy scalar.
        cfg_copy.port = int(port)
        cfg_list.append(cfg_copy)
    return cfg_list
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册