Unverified commit f88bc0e0 authored by Robin Chen, committed by GitHub

fix(crb): add renew for env manager; update retry and timeout logic for subprocess env manager (#127)

* update base env manager and test

* add test reset once

* update subprocess env manager and test

* format code

* update pickling error handling
Parent 41dce176
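The main behavioral change is the new `retry_type` option for environment reset: `'reset'` retries `reset()` on the same environment instance, while `'renew'` closes the failed environment and rebuilds it from its factory before the next attempt. Below is a minimal sketch of that retry loop, written against the names that appear in the diff (`env_fn`, `max_retry`, `retry_waiting_time`); it is a simplified illustration, not the actual DI-engine implementation.

```python
import logging
import time


def reset_with_retry(envs, env_fn, env_id, max_retry=5, retry_type='reset', retry_waiting_time=0.1):
    """Simplified sketch of the retry loop introduced in this commit (not the real API)."""
    exceptions = []
    for _ in range(max_retry):
        try:
            return envs[env_id].reset()
        except BaseException as e:
            exceptions.append(e)
            if retry_type == 'renew':
                # 'renew': close the broken env and rebuild it from its factory before retrying
                envs[env_id].close()
                envs[env_id] = env_fn[env_id]()
            # 'reset' simply waits and calls reset() again on the same instance
            time.sleep(retry_waiting_time)
    logging.error("Env %s reset has exceeded max retries (%s)", env_id, max_retry)
    runtime_error = RuntimeError(
        "Env {} reset has exceeded max retries({}), and the latest exception is: {}".format(
            env_id, max_retry, repr(exceptions[-1])
        )
    )
    runtime_error.__traceback__ = exceptions[-1].__traceback__
    raise runtime_error
```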
......@@ -24,47 +24,15 @@ class EnvState(enum.IntEnum):
ERROR = 5
def retry_wrapper(func: Callable = None, max_retry: int = 10, waiting_time: float = 0.1) -> Callable:
"""
Overview:
Retry the function until it succeeds or the maximum number of retries is exceeded.
"""
if func is None:
return partial(retry_wrapper, max_retry=max_retry)
if max_retry == 1:
return func
@wraps(func)
def wrapper(*args, **kwargs):
exceptions = []
for _ in range(max_retry):
try:
ret = func(*args, **kwargs)
return ret
except BaseException as e:
exceptions.append(e)
time.sleep(waiting_time)
logging.error("Function {} has exceeded max retries({})".format(func, max_retry))
runtime_error = RuntimeError(
"Function {} has exceeded max retries({}), and the latest exception is: {}".format(
func, max_retry, repr(exceptions[-1])
)
)
runtime_error.__traceback__ = exceptions[-1].__traceback__
raise runtime_error
return wrapper
def timeout_wrapper(func: Callable = None, timeout: int = 10) -> Callable:
def timeout_wrapper(func: Callable = None, timeout: Optional[int] = None) -> Callable:
"""
Overview:
Watch the function that must be finished within a period of time. If it times out, raise the captured error.
"""
if func is None:
return partial(timeout_wrapper, timeout=timeout)
if timeout is None:
return func
windows_flag = platform.system().lower() == 'windows'
if windows_flag:
......@@ -109,9 +77,10 @@ class BaseEnvManager(object):
config = dict(
episode_num=float("inf"),
max_retry=1,
step_timeout=60,
retry_type='reset',
auto_reset=True,
reset_timeout=60,
step_timeout=None,
reset_timeout=None,
retry_waiting_time=0.1,
)
......@@ -138,9 +107,11 @@ class BaseEnvManager(object):
self._env_seed = {i: None for i in range(self._env_num)}
self._episode_num = self._cfg.episode_num
self._max_retry = self._cfg.max_retry
self._step_timeout = self._cfg.step_timeout
self._max_retry = max(self._cfg.max_retry, 1)
self._auto_reset = self._cfg.auto_reset
self._retry_type = self._cfg.retry_type
assert self._retry_type in ['reset', 'renew'], self._retry_type
self._step_timeout = self._cfg.step_timeout
self._reset_timeout = self._cfg.reset_timeout
self._retry_waiting_time = self._cfg.retry_waiting_time
......@@ -256,7 +227,6 @@ class BaseEnvManager(object):
def _reset(self, env_id: int) -> None:
@retry_wrapper(max_retry=self._max_retry, waiting_time=self._retry_waiting_time)
@timeout_wrapper(timeout=self._reset_timeout)
def reset_fn():
# if self._reset_param[env_id] is None, just reset specific env, not pass reset param
......@@ -266,14 +236,32 @@ class BaseEnvManager(object):
else:
return self._envs[env_id].reset()
try:
obs = reset_fn()
except Exception as e:
self._env_states[env_id] = EnvState.ERROR
self.close()
raise e
self._ready_obs[env_id] = obs
self._env_states[env_id] = EnvState.RUN
exceptions = []
for _ in range(self._max_retry):
try:
obs = reset_fn()
self._ready_obs[env_id] = obs
self._env_states[env_id] = EnvState.RUN
return
except BaseException as e:
if self._retry_type == 'renew':
err_env = self._envs[env_id]
err_env.close()
self._envs[env_id] = self._env_fn[env_id]()
exceptions.append(e)
time.sleep(self._retry_waiting_time)
continue
self._env_states[env_id] = EnvState.ERROR
self.close()
logging.error("Env {} reset has exceeded max retries({})".format(env_id, self._max_retry))
runtime_error = RuntimeError(
"Env {} reset has exceeded max retries({}), and the latest exception is: {}".format(
env_id, self._max_retry, repr(exceptions[-1])
)
)
runtime_error.__traceback__ = exceptions[-1].__traceback__
raise runtime_error
def step(self, actions: Dict[int, Any]) -> Dict[int, namedtuple]:
"""
......@@ -312,16 +300,25 @@ class BaseEnvManager(object):
def _step(self, env_id: int, act: Any) -> namedtuple:
@retry_wrapper(max_retry=self._max_retry, waiting_time=self._retry_waiting_time)
@timeout_wrapper(timeout=self._step_timeout)
def step_fn():
return self._envs[env_id].step(act)
try:
return step_fn()
except Exception as e:
self._env_states[env_id] = EnvState.ERROR
raise e
exceptions = []
for _ in range(self._max_retry):
try:
return step_fn()
except BaseException as e:
exceptions.append(e)
self._env_states[env_id] = EnvState.ERROR
logging.error("Env {} step has exceeded max retries({})".format(env_id, self._max_retry))
runtime_error = RuntimeError(
"Env {} step has exceeded max retries({}), and the latest exception is: {}".format(
env_id, self._max_retry, repr(exceptions[-1])
)
)
runtime_error.__traceback__ = exceptions[-1].__traceback__
raise runtime_error
def seed(self, seed: Union[Dict[int, int], List[int], int], dynamic_seed: bool = None) -> None:
"""
......
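Note that `base_env_manager.py` also changes `step_timeout` and `reset_timeout` to default to `None`, in which case `timeout_wrapper` returns the function unchanged. A small usage sketch, assuming the import path used by the tests and that the wrapper raises `TimeoutError` on expiry as the subprocess tests below expect:

```python
import time

# Assumption: timeout_wrapper is importable from this module, as in the diff's imports.
from ding.envs.env_manager.base_env_manager import timeout_wrapper


@timeout_wrapper(timeout=1)
def slow_step():
    time.sleep(3)  # exceeds the 1-second budget
    return 'done'


@timeout_wrapper(timeout=None)  # new default: no timeout, the decorator is a no-op
def plain_step():
    return 'done'


assert plain_step() == 'done'
try:
    slow_step()
except BaseException as e:  # a TimeoutError is expected here
    print('slow_step timed out:', repr(e))
```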
......@@ -15,7 +15,8 @@ from easydict import EasyDict
from types import MethodType
from ding.utils import PropagatingThread, LockContextType, LockContext, ENV_MANAGER_REGISTRY
from .base_env_manager import BaseEnvManager, EnvState, retry_wrapper, timeout_wrapper
from .base_env_manager import BaseEnvManager, EnvState, timeout_wrapper
from ding.envs.env.base_env import BaseEnvTimestep
_NTYPE_TO_CTYPE = {
np.bool_: ctypes.c_bool,
......@@ -163,9 +164,10 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
config = dict(
episode_num=float("inf"),
max_retry=5,
step_timeout=60,
step_timeout=None,
auto_reset=True,
reset_timeout=60,
retry_type='reset',
reset_timeout=None,
retry_waiting_time=0.1,
# subprocess specified args
shared_memory=True,
......@@ -200,7 +202,6 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
self._lock = LockContext(LockContextType.THREAD_LOCK)
self._connect_timeout = self._cfg.connect_timeout
self._connect_timeout = np.max([self._connect_timeout, self._step_timeout + 0.5, self._reset_timeout + 0.5])
self._async_args = {
'step': {
'wait_num': min(self._wait_num, self._env_num),
......@@ -246,7 +247,6 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
self.method_name_list,
self._reset_timeout,
self._step_timeout,
self._max_retry,
),
daemon=True,
name='subprocess_env_manager{}_{}'.format(env_id, time.time())
......@@ -345,7 +345,7 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
ret = self._pipe_parents[env_id].recv()
self._check_data({env_id: ret})
self._env_seed[env_id] = None # seed only use once
except Exception as e:
except BaseException as e:
logging.warning("subprocess reset set seed failed, ignore and continue...")
reset_thread = PropagatingThread(target=self._reset, args=(env_id, ))
reset_thread.daemon = True
......@@ -358,11 +358,10 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
def _reset(self, env_id: int) -> None:
@retry_wrapper(max_retry=self._max_retry, waiting_time=self._retry_waiting_time)
def reset_fn():
if self._pipe_parents[env_id].poll():
recv_data = self._pipe_parents[env_id].recv()
raise Exception("unread data left before sending to the pipe: {}".format(repr(recv_data)))
raise RuntimeError("unread data left before sending to the pipe: {}".format(repr(recv_data)))
# if self._reset_param[env_id] is None, just reset specific env, not pass reset param
if self._reset_param[env_id] is not None:
assert isinstance(self._reset_param[env_id], dict), type(self._reset_param[env_id])
......@@ -371,13 +370,7 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
self._pipe_parents[env_id].send(['reset', [], {}])
if not self._pipe_parents[env_id].poll(self._connect_timeout):
# terminate the old subprocess
self._pipe_parents[env_id].close()
if self._subprocesses[env_id].is_alive():
self._subprocesses[env_id].terminate()
# reset the subprocess
self._create_env_subprocess(env_id)
raise Exception("env reset timeout") # Leave it to retry_wrapper to try again
raise ConnectionError("env reset connection timeout") # Leave it to try again
obs = self._pipe_parents[env_id].recv()
self._check_data({env_id: obs}, close=False)
......@@ -387,16 +380,32 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
self._env_states[env_id] = EnvState.RUN
self._ready_obs[env_id] = obs
try:
reset_fn()
except Exception as e:
logging.error('VEC_ENV_MANAGER: env {} reset error'.format(env_id))
logging.error('\nEnv Process Reset Exception:\n' + ''.join(traceback.format_tb(e.__traceback__)) + repr(e))
if self._closed: # exception caused by main thread closing parent_remote
exceptions = []
for _ in range(self._max_retry):
try:
reset_fn()
return
else:
self.close()
raise e
except BaseException as e:
if self._retry_type == 'renew' or isinstance(e, pickle.UnpicklingError):
self._pipe_parents[env_id].close()
if self._subprocesses[env_id].is_alive():
self._subprocesses[env_id].terminate()
self._create_env_subprocess(env_id)
exceptions.append(e)
time.sleep(self._retry_waiting_time)
logging.error("Env {} reset has exceeded max retries({})".format(env_id, self._max_retry))
runtime_error = RuntimeError(
"Env {} reset has exceeded max retries({}), and the latest exception is: {}".format(
env_id, self._max_retry, repr(exceptions[-1])
)
)
runtime_error.__traceback__ = exceptions[-1].__traceback__
if self._closed: # exception caused by main thread closing parent_remote
return
else:
self.close()
raise runtime_error
def step(self, actions: Dict[int, Any]) -> Dict[int, namedtuple]:
"""
......@@ -441,7 +450,17 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
ready_conn, ready_ids = AsyncSubprocessEnvManager.wait(rest_conn, min(wait_num, len(rest_conn)), timeout)
cur_ready_env_ids = [cur_rest_env_ids[env_id] for env_id in ready_ids]
assert len(cur_ready_env_ids) == len(ready_conn)
timesteps.update({env_id: p.recv() for env_id, p in zip(cur_ready_env_ids, ready_conn)})
# timesteps.update({env_id: p.recv() for env_id, p in zip(cur_ready_env_ids, ready_conn)})
for env_id, p in zip(cur_ready_env_ids, ready_conn):
try:
timesteps.update({env_id: p.recv()})
except pickle.UnpicklingError as e:
timestep = BaseEnvTimestep(None, None, None, {'abnormal': True})
timesteps.update({env_id: timestep})
self._pipe_parents[env_id].close()
if self._subprocesses[env_id].is_alive():
self._subprocesses[env_id].terminate()
self._create_env_subprocess(env_id)
self._check_data(timesteps)
ready_env_ids += cur_ready_env_ids
cur_rest_env_ids = list(set(cur_rest_env_ids).difference(set(cur_ready_env_ids)))
......@@ -546,9 +565,8 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
env_fn_wrapper,
obs_buffer,
method_name_list,
reset_timeout=60,
step_timeout=60,
max_retry=1
reset_timeout=None,
step_timeout=None,
) -> None:
"""
Overview:
......@@ -559,7 +577,6 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
env = env_fn()
parent.close()
@retry_wrapper(max_retry=max_retry)
@timeout_wrapper(timeout=step_timeout)
def step_fn(*args, **kwargs):
timestep = env.step(*args, **kwargs)
......@@ -581,7 +598,7 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
obs_buffer.fill(ret)
ret = None
return ret
except Exception as e:
except BaseException as e:
env.close()
raise e
......@@ -606,8 +623,8 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
else:
raise KeyError("not support env cmd: {}".format(cmd))
child.send(ret)
except Exception as e:
# print("Sub env '{}' error when executing {}".format(str(env), cmd))
except BaseException as e:
logging.debug("Sub env '{}' error when executing {}".format(str(env), cmd))
# when an error occurs in the env, worker_fn sends the error to the env manager
# sending the error directly to another process would lose the stack trace, so we create a new Exception
child.send(
......@@ -620,7 +637,7 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
def _check_data(self, data: Dict, close: bool = True) -> None:
exceptions = []
for i, d in data.items():
if isinstance(d, Exception):
if isinstance(d, BaseException):
self._env_states[i] = EnvState.ERROR
exceptions.append(d)
# when receiving env Exception, env manager will safely close and raise this Exception to caller
......@@ -670,7 +687,9 @@ class AsyncSubprocessEnvManager(BaseEnvManager):
self._env_ref.close()
for _, p in self._pipe_parents.items():
p.send(['close', None, None])
for _, p in self._pipe_parents.items():
for env_id, p in self._pipe_parents.items():
if not p.poll(5):
continue
p.recv()
for i in range(self._env_num):
self._env_states[i] = EnvState.VOID
......@@ -714,9 +733,10 @@ class SyncSubprocessEnvManager(AsyncSubprocessEnvManager):
config = dict(
episode_num=float("inf"),
max_retry=5,
step_timeout=60,
step_timeout=None,
auto_reset=True,
reset_timeout=60,
reset_timeout=None,
retry_type='reset',
retry_waiting_time=0.1,
# subprocess specified args
shared_memory=True,
......
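For the subprocess managers, the per-call `step_timeout` and `reset_timeout` now default to `None` and are decoupled from `connect_timeout`, which bounds how long the parent waits on the pipe before treating a child process as stuck; `retry_type='renew'` (or an unpickling error) rebuilds the child subprocess instead of reusing it. A hedged configuration sketch mirroring the test fixtures; `MyEnv` and `env_cfgs` are hypothetical placeholders for a user-provided environment:

```python
from easydict import EasyDict

from ding.envs.env_manager import SyncSubprocessEnvManager
from ding.utils import deep_merge_dicts

# Merge user overrides into the class defaults, the same way the test fixtures do.
cfg = deep_merge_dicts(
    SyncSubprocessEnvManager.default_config(),
    EasyDict(
        max_retry=2,          # retry each reset/step up to 2 times
        retry_type='renew',   # rebuild the child subprocess on failure instead of a plain reset
        connect_timeout=8,    # seconds the parent waits on the pipe before giving up
        step_timeout=None,    # new default: no per-step timeout inside the worker
        reset_timeout=None,   # new default: no per-reset timeout inside the worker
        shared_memory=False,
    ),
)
# Construction mirrors the tests: a list of env factories plus the merged config, e.g.
# env_manager = SyncSubprocessEnvManager([partial(MyEnv, cfg=c) for c in env_cfgs], cfg)
# env_manager.launch(reset_param=...); env_manager.step(actions)
```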
......@@ -12,12 +12,7 @@ from ding.envs.env.base_env import BaseEnvTimestep, BaseEnvInfo
from ding.envs.env_manager.base_env_manager import EnvState
from ding.envs.env_manager import BaseEnvManager, SyncSubprocessEnvManager, AsyncSubprocessEnvManager
from ding.torch_utils import to_tensor, to_ndarray, to_list
from ding.utils import WatchDog, deep_merge_dicts
@pytest.fixture(scope='module')
def setup_watchdog():
return WatchDog
from ding.utils import deep_merge_dicts
class FakeEnv(object):
......@@ -26,17 +21,25 @@ class FakeEnv(object):
self._target_time = random.randint(3, 6)
self._current_time = 0
self._name = cfg['name']
self._id = time.time()
self._stat = None
self._seed = 0
self._data_count = 0
self.timeout_flag = False
self._launched = False
self._state = EnvState.INIT
self._dead_once = False
def reset(self, stat):
if isinstance(stat, str) and stat == 'error':
self.dead()
if isinstance(stat, str) and stat == "timeout":
if isinstance(stat, str) and stat == 'error_once':
if self._dead_once:
self._dead_once = False
self.dead()
else:
self._dead_once = True
if isinstance(stat, str) and stat == "wait":
if self.timeout_flag: # after step(), the reset can hang with a timeout status
time.sleep(5)
if isinstance(stat, str) and stat == "block":
......@@ -55,7 +58,7 @@ class FakeEnv(object):
self.dead()
if isinstance(action, str) and action == 'catched_error':
return BaseEnvTimestep(None, None, True, {'abnormal': True})
if isinstance(action, str) and action == "timeout":
if isinstance(action, str) and action == "wait":
if self.timeout_flag: # after step(), the reset can hang with a timeout status
time.sleep(3)
if isinstance(action, str) and action == 'block':
......@@ -111,6 +114,10 @@ class FakeEnv(object):
def name(self):
return self._name
@property
def time_id(self):
return self._id
def user_defined(self):
pass
......@@ -143,7 +150,7 @@ def setup_model_type():
return FakeModel
def get_manager_cfg(env_num=4):
def get_base_manager_cfg(env_num=4):
manager_cfg = {
'env_cfg': [{
'name': 'name{}'.format(i),
......@@ -156,9 +163,24 @@ def get_manager_cfg(env_num=4):
return EasyDict(manager_cfg)
def get_subprecess_manager_cfg(env_num=4):
manager_cfg = {
'env_cfg': [{
'name': 'name{}'.format(i),
} for i in range(env_num)],
'episode_num': 2,
#'step_timeout': 8,
#'reset_timeout': 10,
'connect_timeout': 8,
'step_timeout': 5,
'max_retry': 2,
}
return EasyDict(manager_cfg)
@pytest.fixture(scope='function')
def setup_base_manager_cfg():
manager_cfg = get_manager_cfg(4)
manager_cfg = get_base_manager_cfg(4)
env_cfg = manager_cfg.pop('env_cfg')
manager_cfg['env_fn'] = [partial(FakeEnv, cfg=c) for c in env_cfg]
return deep_merge_dicts(BaseEnvManager.default_config(), EasyDict(manager_cfg))
......@@ -166,7 +188,7 @@ def setup_base_manager_cfg():
@pytest.fixture(scope='function')
def setup_sync_manager_cfg():
manager_cfg = get_manager_cfg(4)
manager_cfg = get_subprecess_manager_cfg(4)
env_cfg = manager_cfg.pop('env_cfg')
# TODO(nyz) test fail when shared_memory = True
manager_cfg['shared_memory'] = False
......@@ -176,9 +198,8 @@ def setup_sync_manager_cfg():
@pytest.fixture(scope='function')
def setup_async_manager_cfg():
manager_cfg = get_manager_cfg(4)
manager_cfg = get_subprecess_manager_cfg(4)
env_cfg = manager_cfg.pop('env_cfg')
manager_cfg['env_fn'] = [partial(FakeAsyncEnv, cfg=c) for c in env_cfg]
manager_cfg['shared_memory'] = False
manager_cfg['connect_timeout'] = 30
return deep_merge_dicts(AsyncSubprocessEnvManager.default_config(), EasyDict(manager_cfg))
......@@ -21,9 +21,6 @@ class TestBaseEnvManager:
name = env_manager._name
assert len(name) == env_manager.env_num
assert all([isinstance(n, str) for n in name])
name = env_manager.name
assert len(name) == env_manager.env_num
assert all([isinstance(n, str) for n in name])
assert env_manager._max_retry == 5
assert env_manager._reset_timeout == 10
assert all([s == 314 for s in env_manager._seed])
......@@ -73,6 +70,21 @@ class TestBaseEnvManager:
timestep = env_manager.step({i: np.random.randn(4) for i in range(env_manager.env_num)})
assert len(timestep) == env_manager.env_num
# Test reset error once
reset_param = {i: {'stat': 'stat_test'} for i in range(env_manager.env_num)}
assert env_manager._retry_type == 'reset'
env_id_0 = env_manager.time_id[0]
reset_param[0] = {'stat': 'error_once'}
env_manager.reset(reset_param)
env_manager.reset(reset_param)
assert not env_manager._closed
assert env_manager.time_id[0] == env_id_0
env_manager._retry_type = 'renew'
env_id_0 = env_manager.time_id[0]
reset_param[0] = {'stat': 'error_once'}
env_manager.reset(reset_param)
assert not env_manager._closed
assert env_manager.time_id[0] != env_id_0
# Test step catched error
action = {i: np.random.randn(4) for i in range(env_manager.env_num)}
......@@ -96,27 +108,26 @@ class TestBaseEnvManager:
env_manager.close()
def test_block(self, setup_base_manager_cfg, setup_watchdog):
@pytest.mark.timeout(60)
def test_block(self, setup_base_manager_cfg):
env_fn = setup_base_manager_cfg.pop('env_fn')
setup_base_manager_cfg['max_retry'] = 1
env_manager = BaseEnvManager(env_fn, setup_base_manager_cfg)
watchdog = setup_watchdog(30)
assert env_manager._max_retry == 1
# Test reset timeout
watchdog.start()
with pytest.raises(RuntimeError):
reset_param = {i: {'stat': 'block'} for i in range(env_manager.env_num)}
obs = env_manager.launch(reset_param=reset_param)
assert env_manager._closed
reset_param = {i: {'stat': 'stat_test'} for i in range(env_manager.env_num)}
reset_param[0]['stat'] = 'timeout'
reset_param[0]['stat'] = 'wait'
obs = env_manager.launch(reset_param=reset_param)
assert not env_manager._closed
timestep = env_manager.step({i: np.random.randn(4) for i in range(env_manager.env_num)})
assert len(timestep) == env_manager.env_num
watchdog.stop()
# Test step timeout
watchdog.start()
action = {i: np.random.randn(4) for i in range(env_manager.env_num)}
action[0] = 'block'
with pytest.raises(RuntimeError):
......@@ -124,9 +135,8 @@ class TestBaseEnvManager:
assert all([env_manager._env_states[i] == EnvState.RUN for i in range(1, env_manager.env_num)])
obs = env_manager.reset(reset_param)
action[0] = 'timeout'
action[0] = 'wait'
timestep = env_manager.step(action)
assert len(timestep) == env_manager.env_num
watchdog.stop()
env_manager.close()
......@@ -29,8 +29,9 @@ class TestSubprocessEnvManager:
name = env_manager.name
assert len(name) == env_manager.env_num
assert all([isinstance(n, str) for n in name])
assert env_manager._max_retry == 5
assert env_manager._reset_timeout == 10
assert env_manager._max_retry == 2
assert env_manager._connect_timeout == 8
assert env_manager._step_timeout == 5
# Test attribute
with pytest.raises(AttributeError):
data = env_manager.xxx
......@@ -77,15 +78,30 @@ class TestSubprocessEnvManager:
with pytest.raises(AssertionError):
env_manager.reset(reset_param={i: {'stat': 'stat_test'} for i in range(env_manager.env_num)})
with pytest.raises(RuntimeError):
obs = env_manager.launch(reset_param={i: {'stat': 'error'} for i in range(env_manager.env_num)})
env_manager.launch(reset_param={i: {'stat': 'error'} for i in range(env_manager.env_num)})
assert env_manager._closed
time.sleep(0.5) # necessary time interval
obs = env_manager.launch(reset_param={i: {'stat': 'stat_test'} for i in range(env_manager.env_num)})
env_manager.launch(reset_param={i: {'stat': 'stat_test'} for i in range(env_manager.env_num)})
assert not env_manager._closed
timestep = env_manager.step({i: np.random.randn(4) for i in range(env_manager.env_num)})
assert len(timestep) == env_manager.env_num
# Test reset error once
reset_param = {i: {'stat': 'stat_test'} for i in range(env_manager.env_num)}
assert env_manager._retry_type == 'reset'
env_id_0 = env_manager.time_id[0]
reset_param[0] = {'stat': 'error_once'}
env_manager.reset(reset_param)
assert not env_manager._closed
assert env_manager.time_id[0] == env_id_0
env_manager._retry_type = 'renew'
env_id_0 = env_manager.time_id[0]
reset_param[0] = {'stat': 'error_once'}
env_manager.reset(reset_param)
assert not env_manager._closed
assert env_manager.time_id[0] != env_id_0
# Test step catched error
action = {i: np.random.randn(4) for i in range(env_manager.env_num)}
action[0] = 'catched_error'
......@@ -105,9 +121,9 @@ class TestSubprocessEnvManager:
assert len(env_manager.ready_obs) == 4
timestep = env_manager.step({i: np.random.randn(4) for i in range(env_manager.env_num)})
# Test step error
# # Test step error
action[0] = 'error'
with pytest.raises(Exception):
with pytest.raises(RuntimeError):
timestep = env_manager.step(action)
assert env_manager._closed
......@@ -117,21 +133,21 @@ class TestSubprocessEnvManager:
with pytest.raises(AssertionError): # Assert env manager is not closed
env_manager.step([])
@pytest.mark.tmp # gitlab ci and local test pass, github always fail
def test_block(self, setup_async_manager_cfg, setup_watchdog, setup_model_type):
#@pytest.mark.tmp # gitlab ci and local test pass, github always fail
@pytest.mark.unittest
@pytest.mark.timeout(100)
def test_block(self, setup_async_manager_cfg, setup_model_type):
env_fn = setup_async_manager_cfg.pop('env_fn')
env_manager = AsyncSubprocessEnvManager(env_fn, setup_async_manager_cfg)
watchdog = setup_watchdog(60)
model = setup_model_type()
# Test reset timeout
watchdog.start()
# Test connect timeout
with pytest.raises(RuntimeError):
reset_param = {i: {'stat': 'block'} for i in range(env_manager.env_num)}
obs = env_manager.launch(reset_param=reset_param)
assert env_manager._closed
time.sleep(0.5)
reset_param = {i: {'stat': 'stat_test'} for i in range(env_manager.env_num)}
reset_param[0]['stat'] = 'timeout'
reset_param[0]['stat'] = 'wait'
env_manager.launch(reset_param=reset_param)
time.sleep(0.5)
assert not env_manager._closed
......@@ -139,14 +155,27 @@ class TestSubprocessEnvManager:
timestep = env_manager.step({i: np.random.randn(4) for i in range(env_manager.env_num)})
obs = env_manager.ready_obs
assert len(obs) >= 1
watchdog.stop()
# Test reset timeout
env_manager._connect_timeout = 30
env_manager._reset_timeout = 8
with pytest.raises(RuntimeError):
reset_param = {i: {'stat': 'block'} for i in range(env_manager.env_num)}
obs = env_manager.reset(reset_param=reset_param)
assert env_manager._closed
time.sleep(0.5)
reset_param = {i: {'stat': 'stat_test'} for i in range(env_manager.env_num)}
reset_param[0]['stat'] = 'wait'
env_manager.launch(reset_param=reset_param)
time.sleep(0.5)
assert not env_manager._closed
# Test step timeout
watchdog.start()
env_manager._step_timeout = 5
obs = env_manager.reset({i: {'stat': 'stat_test'} for i in range(env_manager.env_num)})
action = {i: np.random.randn(4) for i in range(env_manager.env_num)}
action[0] = 'block'
with pytest.raises(RuntimeError):
with pytest.raises(TimeoutError):
timestep = env_manager.step(action)
obs = env_manager.ready_obs
while 0 not in obs:
......@@ -157,7 +186,7 @@ class TestSubprocessEnvManager:
obs = env_manager.launch(reset_param={i: {'stat': 'stat_test'} for i in range(env_manager.env_num)})
time.sleep(1)
action[0] = 'timeout'
action[0] = 'wait'
timestep = env_manager.step(action)
obs = env_manager.ready_obs
while 0 not in obs:
......@@ -165,7 +194,6 @@ class TestSubprocessEnvManager:
timestep = env_manager.step(action)
obs = env_manager.ready_obs
assert len(obs) >= 1
watchdog.stop()
env_manager.close()
......
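Finally, the WatchDog-based timing guards in the tests are replaced by `@pytest.mark.timeout(...)` markers. A minimal sketch of that pattern, assuming the pytest-timeout plugin (which provides this marker) is installed:

```python
import time

import pytest


@pytest.mark.timeout(5)  # fail the test if it runs for longer than 5 seconds
def test_finishes_in_time():
    time.sleep(0.1)
    assert True
```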