Unverified · Commit d8bde45c authored by Xu Jingxin, committed by GitHub

feature(xjx): new main framework and profile helper (#142)

* Init base buffer and storage

* Use ratelimit as middleware

* Pass style check

* Keep the original return value

* Add buffer.view

* Add replace flag on sample, rewrite middleware processing

* Test slicing

* Add buffer copy middleware

* Add update/delete api in buffer, rename middleware

* Implement update and delete api of buffer

* add naive use time count middleware in buffer

* Rename next to chain

* feature(nyz): add staleness check middleware and polish buffer

* feature(nyz): add naive priority experience replay

* Sample by indices

* Combine buffer and storage layers

* Support indices when deleting items from the queue

* Use dataclass to save buffered data, remove return_index and return_meta

* Add ignore_insufficient

* polish(nyz): add return index in push and copy same data in sample

* Drop useless import

* Fix sample with indices, ensure return size is equal to input size or indices size

* Make sure sampled data in buffer is different from each other

* Support sample by grouped meta key

* Support sample by rolling window

* Add import/export data in buffer

* Padding after sampling from buffer

* Polish use_time_check

* Use buffer as dataset

* Set collate_fn in buffer test

* Init framework

* Remove set_default, add keep

* Move backward_stack to task

* Fix total_step

* Pydash pick is too slow

* Add step records

* Add async mode

* Reuse forward and backward functions in sequence

* Fix sample profile

* demo(nyz): add atari pong runnable demo

* Fix forward bug

* Add task test

* Test pong

* feature(nyz): add deque buffer compatibility wrapper and demo

* polish(nyz): polish code style and add pong dqn new deque buffer demo

* Use sync mode

* Config worker number

* Init parallel mode

* Add prev property on context

* Mesh workers

* First version of parallel mode

* Make send rpc async

* Don't pickle prev

* Support tcp

* More cleanup on system exit

* Test parallel and task

* Enable task copy

* Test attach mode

* Add with statement

* Polish code

* Raise exception when timeout in attach mode

* Add event listeners

* feature(nyz): add pendulum sac new pipeline demo

* Fix main

* Add profiler and step profiler

* Rewrite parallel, cleanup res after task finished

* Add comments

* Remove ctx.prev

* Enable standalone parallel mode

* Remove hooks on ctx

* Add max mean

* demo(nyz): add pong dqn new pipeline demo

* Ensure parallel sock closed before program exit

* Fix parallel test

* Fix pong

* feature(zjow): add feature of profile in ding (#135)

* add profiling feature in ding cli.

* fix ding --profile cli.

* reformat files.

* reformat files again.

* reformat files again.

* Remove flameprof

* Change kept_keys to set

* Use finish as a property

* Use wrapper

* Reformat step timer output

* Test random seed

* Revert learning rate

* Add topology on parallel

* Use labels on task

* Star in parallel mode

* Don't use daemon process

* Auto sync finish state

* Return logvars

* Fix test wrapper

* Fix test profiler helper

* Pass flake_check

* Lazy launch

* Reporter

* Replace main with main_sac

* Fix parallel ctx

* Fix test

* Fix merge issues
Co-authored-by: niuyazhe <niuyazhe@sensetime.com>
Co-authored-by: zjowowen <93968541+zjowowen@users.noreply.github.com>
Parent aa612443
......@@ -107,6 +107,12 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
@click.option(
'--memory', type=str, default=None, help='The requested Memory, read the value from DIJob yaml by default'
)
@click.option(
'--profile',
type=str,
default=None,
help='Profile time cost with cProfile and save the result files into the specified folder path'
)
def cli(
# serial/eval
mode: str,
......@@ -143,7 +149,12 @@ def cli(
gpus: int,
memory: str,
restart_pod_name: str,
profile: str,
):
if profile is not None:
from ..utils.profiler_helper import Profiler
profiler = Profiler()
profiler.profile(profile)
if mode == 'serial':
from .serial_entry import serial_pipeline
if config is None:
......
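The --profile flag above simply activates the new profiler helper before the chosen pipeline starts. The same thing can be done programmatically; a minimal sketch, assuming Profiler exported from ding.utils (the folder name is only an example):

from ding.utils import Profiler

profiler = Profiler()
profiler.profile("./profile_output")   # enable cProfile and register an atexit dump
# ... run the training entry as usual; the stats are written into the folder on exit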
"""
Main entry
"""
from collections import deque
import torch
import numpy as np
import time
from rich import print
from functools import partial
from ding.model import QAC
from ding.utils import set_pkg_seed
from ding.envs import BaseEnvManager, get_vec_env_setting
from ding.config import compile_config
from ding.policy import SACPolicy
from ding.torch_utils import to_ndarray, to_tensor
from ding.rl_utils import get_epsilon_greedy_fn
from ding.worker.collector.base_serial_evaluator import VectorEvalMonitor
from ding.framework import Task
from dizoo.classic_control.pendulum.config.pendulum_sac_config import main_config, create_config
class DequeBuffer:
"""
For demonstration only
"""
def __init__(self, maxlen=20000) -> None:
self.memory = deque(maxlen=maxlen)
self.n_counter = 0
def push(self, data):
self.memory.append(data)
self.n_counter += 1
def sample(self, size):
if size > len(self.memory):
print('[Warning] not enough data: {}/{}'.format(size, len(self.memory)))
return None
indices = list(np.random.choice(a=len(self.memory), size=size, replace=False))
return [self.memory[i] for i in indices]
# return random.sample(self.memory, size)
def count(self):
return len(self.memory)
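A quick sketch of how this demonstration buffer behaves (sizes and payloads below are arbitrary):

buf = DequeBuffer(maxlen=1000)
for i in range(64):
    buf.push({"obs": i})        # any transition object can be stored
batch = buf.sample(32)          # a list of 32 distinct items, or None with a warning
assert buf.count() == 64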
class Pipeline:
def __init__(self, cfg, model: torch.nn.Module):
self.cfg = cfg
self.model = model
self.policy = SACPolicy(cfg.policy, model=model)
if 'eps' in cfg.policy.other:
eps_cfg = cfg.policy.other.eps
self.epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
def act(self, env):
def _act(ctx):
ctx.setdefault("collect_env_step", 0)
ctx.keep("collect_env_step")
ctx.obs = env.ready_obs
policy_kwargs = {}
if hasattr(self, 'epsilon_greedy'):
policy_kwargs['eps'] = self.epsilon_greedy(ctx.collect_env_step)
policy_output = self.policy.collect_mode.forward(ctx.obs, **policy_kwargs)
ctx.action = to_ndarray({env_id: output['action'] for env_id, output in policy_output.items()})
ctx.policy_output = policy_output
return _act
def collect(self, env, buffer_, task: Task):
def _collect(ctx):
timesteps = env.step(ctx.action)
ctx.collect_env_step += len(timesteps)
timesteps = to_tensor(timesteps, dtype=torch.float32)
ctx.collect_transitions = []
for env_id, timestep in timesteps.items():
transition = self.policy.collect_mode.process_transition(
ctx.obs[env_id], ctx.policy_output[env_id], timestep
)
ctx.collect_transitions.append(transition)
buffer_.push(transition)
return _collect
def learn(self, buffer_: DequeBuffer, task: Task):
def _learn(ctx):
ctx.setdefault("train_iter", 0)
ctx.keep("train_iter")
for i in range(self.cfg.policy.learn.update_per_collect):
data = buffer_.sample(self.policy.learn_mode.get_attribute('batch_size'))
if not data:
break
learn_output = self.policy.learn_mode.forward(data)
if ctx.train_iter % 20 == 0:
print(
'Current Training: Train Iter({})\tLoss({:.3f})'.format(
ctx.train_iter, learn_output['total_loss']
)
)
ctx.train_iter += 1
return _learn
def evaluate(self, env):
def _eval(ctx):
ctx.setdefault("train_iter", 0)
ctx.setdefault("last_eval_iter", -1)
ctx.keep("train_iter", "last_eval_iter")
if ctx.train_iter == ctx.last_eval_iter or (
(ctx.train_iter - ctx.last_eval_iter) < self.cfg.policy.eval.evaluator.eval_freq
and ctx.train_iter != 0):
return
env.reset()
eval_monitor = VectorEvalMonitor(env.env_num, self.cfg.env.n_evaluator_episode)
while not eval_monitor.is_finished():
obs = env.ready_obs
obs = to_tensor(obs, dtype=torch.float32)
policy_output = self.policy.eval_mode.forward(obs)
action = to_ndarray({i: a['action'] for i, a in policy_output.items()})
timesteps = env.step(action)
timesteps = to_tensor(timesteps, dtype=torch.float32)
for env_id, timestep in timesteps.items():
if timestep.done:
self.policy.eval_mode.reset([env_id])
reward = timestep.info['final_eval_reward']
eval_monitor.update_reward(env_id, reward)
episode_reward = eval_monitor.get_episode_reward()
eval_reward = np.mean(episode_reward)
stop_flag = eval_reward >= self.cfg.env.stop_value and ctx.train_iter > 0
print('Current Evaluation: Train Iter({})\tEval Reward({:.3f})'.format(ctx.train_iter, eval_reward))
ctx.last_eval_iter = ctx.train_iter
if stop_flag:
ctx.finish = True
return _eval
def main(cfg, model, seed=0):
with Task(async_mode=False) as task:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = BaseEnvManager(env_fn=[partial(env_fn, cfg=c) for c in collector_env_cfg], cfg=cfg.env.manager)
evaluator_env = BaseEnvManager(env_fn=[partial(env_fn, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager)
collector_env.seed(seed)
evaluator_env.seed(seed, dynamic_seed=False)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
collector_env.launch()
evaluator_env.launch()
replay_buffer = DequeBuffer()
sac = Pipeline(cfg, model)
# task.use_step_wrapper(StepTimer(print_per_step=1))
task.use(sac.evaluate(evaluator_env), filter_labels=["standalone", "node.0"])
task.use(
task.sequence(sac.act(collector_env), sac.collect(collector_env, replay_buffer, task=task)),
filter_labels=["standalone", "node.[1-9]*"]
)
task.use(sac.learn(replay_buffer, task=task), filter_labels=["standalone", "node.0"])
task.run(max_step=100000)
if __name__ == "__main__":
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
model = QAC(**cfg.policy.model)
main(cfg, model)
from .context import Context
from .task import Task
from .parallel import Parallel
class Context(dict):
"""
Overview:
Context is an object that passes contextual data between middlewares; its life cycle
is only one training iteration. It is a dict that reflects itself, so you can set
any properties as you wish.
"""
def __init__(self, total_step: int = 0, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.__dict__ = self
self.total_step = total_step
# Reserved properties
self.finish = False
self._kept_keys = {"finish"}
def renew(self) -> 'Context': # noqa
"""
Overview:
Renew context from self, add total_step and shift kept properties to the new instance.
"""
ctx = Context()
for key in self._kept_keys:
ctx[key] = self[key]
return ctx
def keep(self, *keys: str) -> None:
"""
Overview:
Keep this key/keys until next iteration.
"""
for key in keys:
self._kept_keys.add(key)
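A minimal sketch of the Context life cycle described above (the key name is an arbitrary example):

ctx = Context()
ctx.setdefault("train_iter", 0)    # dict items are also reachable as attributes
ctx.keep("train_iter")             # carry this key over to the next iteration
ctx.train_iter += 1
next_ctx = ctx.renew()             # kept keys (and the reserved "finish") survive
assert next_ctx.train_iter == 1 and next_ctx.finish is False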
import atexit
import os
import random
import threading
import time
from mpire.pool import WorkerPool
import pynng
import pickle
import logging
import tempfile
import socket
from os import path
from typing import Callable, Dict, List, Optional, Tuple, Union
from threading import Thread
from pynng.nng import Bus0, Socket
from ding.utils.design_helper import SingletonMetaclass
from rich import print
# To avoid ipc address conflicts, this random instance should always use a random seed
random = random.Random()
class Parallel(metaclass=SingletonMetaclass):
def __init__(self) -> None:
self._listener = None
self._sock: Socket = None
self._rpc = {}
self._bind_addr = None
self._lock = threading.Lock()
self.is_active = False
self.attach_to = None
self.finished = False
def run(self, node_id: int, listen_to: str, attach_to: List[str] = None) -> None:
self.node_id = node_id
self.attach_to = attach_to = attach_to or []
self._listener = Thread(
target=self.listen,
kwargs={
"listen_to": listen_to,
"attach_to": attach_to
},
name="paralllel_listener",
daemon=True
)
self._listener.start()
@staticmethod
def runner(
n_parallel_workers: int,
attach_to: Optional[List[str]] = None,
protocol: str = "ipc",
address: Optional[str] = None,
ports: Optional[List[int]] = None,
topology: str = "mesh"
) -> Callable:
"""
Overview:
This method allows you to configure the parallel parameters; at this point you are still in the parent process.
Arguments:
- n_parallel_workers (:obj:`int`): Workers to spawn.
- attach_to (:obj:`Optional[List[str]]`): The addresses of the nodes you want to attach to.
- protocol (:obj:`str`): Network protocol.
- address (:obj:`Optional[str]`): Bind address, ip or file path.
- ports (:obj:`Optional[List[int]]`): Candidate ports.
- topology (:obj:`str`): Network topology, includes:
`mesh` (default): fully connected between each other;
`star`: only connect to the first node;
Returns:
- _runner (:obj:`Callable`): The wrapper function for main.
"""
attach_to = attach_to or []
assert n_parallel_workers > 0, "Parallel worker number should be bigger than 0"
def _runner(main_process: Callable, *args, **kwargs) -> None:
"""
Overview:
Prepare to run in subprocess.
Arguments:
- main_process (:obj:`Callable`): The main function, your program start from here.
"""
nodes = Parallel.get_node_addrs(n_parallel_workers, protocol=protocol, address=address, ports=ports)
logging.info("Bind subprocesses on these addresses: {}".format(nodes))
print("Bind subprocesses on these addresses: {}".format(nodes))
def cleanup_nodes():
for node in nodes:
protocol, file_path = node.split("://")
if protocol == "ipc" and path.exists(file_path):
os.remove(file_path)
atexit.register(cleanup_nodes)
def topology_network(node_id: int) -> List[str]:
if topology == "mesh":
return nodes[:node_id] + attach_to
elif topology == "star":
return nodes[:min(1, node_id)]
else:
raise ValueError("Unknown topology: {}".format(topology))
params_group = []
for node_id in range(n_parallel_workers):
runner_args = []
runner_kwargs = {
"node_id": node_id,
"listen_to": nodes[node_id],
"attach_to": topology_network(node_id) + attach_to
}
params = [(runner_args, runner_kwargs), (main_process, args, kwargs)]
params_group.append(params)
if n_parallel_workers == 1:
Parallel.subprocess_runner(*params_group[0])
else:
with WorkerPool(n_jobs=n_parallel_workers, start_method="spawn", daemon=False) as pool:
# Cleanup the pool just in case the program crashes.
atexit.register(pool.__exit__)
pool.map(Parallel.subprocess_runner, params_group)
return _runner
@staticmethod
def subprocess_runner(runner_params: Tuple[Union[List, Dict]], main_params: Tuple[Union[List, Dict]]) -> None:
"""
Overview:
Actually run in the subprocess.
Arguments:
- runner_params (:obj:`Tuple[Union[List, Dict]]`): Args and kwargs for runner.
- main_params (:obj:`Tuple[Union[List, Dict]]`): Args and kwargs for main function.
"""
main_process, args, kwargs = main_params
runner_args, runner_kwargs = runner_params
with Parallel() as router:
router.is_active = True
router.run(*runner_args, **runner_kwargs)
main_process(*args, **kwargs)
@staticmethod
def get_node_addrs(
n_workers: int,
protocol: str = "ipc",
address: Optional[str] = None,
ports: Optional[List[int]] = None
) -> None:
if protocol == "ipc":
node_name = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=4))
tmp_dir = tempfile.gettempdir()
nodes = ["ipc://{}/ditask_{}_{}.ipc".format(tmp_dir, node_name, i) for i in range(n_workers)]
elif protocol == "tcp":
address = address or Parallel.get_ip()
ports = ports or range(50515, 50515 + n_workers)
assert len(ports) == n_workers, "The number of ports must be the same as the number of workers, \
now there are {} ports and {} workers".format(len(ports), n_workers)
nodes = ["tcp://{}:{}".format(address, port) for port in ports]
else:
raise Exception("Unknown protocol {}".format(protocol))
return nodes
def listen(self, listen_to: str, attach_to: List[str] = None):
attach_to = attach_to or []
self._bind_addr = listen_to
with Bus0() as sock:
self._sock = sock
sock.listen(self._bind_addr)
time.sleep(0.1) # Wait for peers to bind
for contact in attach_to:
sock.dial(contact)
while True:
try:
msg = sock.recv_msg()
self.recv_rpc(msg.bytes)
except pynng.Timeout:
logging.warning("Timeout on node {} when waiting for message from bus".format(self._bind_addr))
except pynng.Closed:
if not self.finished:
logging.error("The socket was not closed under normal circumstances!")
break
except Exception as e:
logging.error("Meet exception when listening for new messages", e)
break
def register_rpc(self, fn_name: str, fn: Callable) -> None:
self._rpc[fn_name] = fn
def send_rpc(self, func_name: str, *args, **kwargs) -> None:
if self.is_active:
payload = {"f": func_name, "a": args, "k": kwargs}
return self._sock and self._sock.send(pickle.dumps(payload, protocol=-1))
def recv_rpc(self, msg: bytes):
try:
payload = pickle.loads(msg)
except Exception as e:
logging.warning("Error when unpacking message on node {}, msg: {}".format(self._bind_addr, e))
if payload["f"] in self._rpc:
self._rpc[payload["f"]](*payload["a"], **payload["k"])
else:
logging.warning("There was no function named {} in rpc table".format(payload["f"]))
@staticmethod
def get_ip():
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
# doesn't even have to be reachable
s.connect(('10.255.255.255', 1))
IP = s.getsockname()[0]
except Exception:
IP = '127.0.0.1'
finally:
s.close()
return IP
def __enter__(self) -> "Parallel":
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
def stop(self):
logging.info("Stopping parallel worker on address: {}".format(self._bind_addr))
self.finished = True
time.sleep(0.03)
if self._sock:
self._sock.close()
if self._listener:
self._listener.join(timeout=1)
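For reference, a minimal sketch of driving Parallel from a user main function (worker count, sleep times and the rpc name are illustrative; the tests further below show the complete patterns):

def my_parallel_main():
    # Executed inside every spawned worker; Parallel is a per-process singleton.
    with Parallel() as router:
        router.register_rpc("greet", lambda name: print("hello from", name))
        time.sleep(0.5)                                   # give peers time to bind and dial
        router.send_rpc("greet", "node_{}".format(router.node_id))
        time.sleep(0.5)

# In the parent process: configure the workers, then call the returned wrapper.
# Parallel.runner(n_parallel_workers=2, topology="mesh")(my_parallel_main)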
from collections import defaultdict
import logging
import time
import asyncio
import concurrent.futures
import fnmatch
from types import GeneratorType
from typing import Awaitable, Callable, Dict, Generator, Iterable, List, Optional, Set
from ding.framework.context import Context
from ding.framework.parallel import Parallel
def enable_async(func: Callable) -> Callable:
"""
Overview:
Empower the function with async ability.
Arguments:
- func (:obj:`Callable`): The original function.
Returns:
- runtime_handler (:obj:`Callable`): The wrap function.
"""
def runtime_handler(task: "Task", *args, **kwargs) -> "Task":
"""
Overview:
If the task's async mode is enabled, execute the step asynchronously in the current loop executor;
otherwise execute it synchronously.
Arguments:
- task (:obj:`Task`): The task instance.
Returns:
- result (:obj:`Union[Any, Awaitable]`): The result or future object of middleware.
"""
if "async_mode" in kwargs:
async_mode = kwargs.pop("async_mode")
else:
async_mode = task.async_mode
if async_mode:
t = task._loop.run_in_executor(task._thread_pool, func, task, *args, **kwargs)
task._async_stack.append(t)
return task
else:
return func(task, *args, **kwargs)
return runtime_handler
class Task:
"""
Task will manage the execution order of the entire pipeline, register new middleware,
and generate new context objects.
"""
def __init__(
self,
async_mode: bool = False,
n_async_workers: int = 3,
middleware: Optional[List[Callable]] = None,
step_wrappers: Optional[List[Callable]] = None,
event_listeners: Optional[Dict[str, List]] = None,
once_listeners: Optional[Dict[str, List]] = None,
attach_callback: Optional[Callable] = None,
labels: Optional[Set[str]] = None,
**_
) -> None:
self.middleware = middleware or []
self.step_wrappers = step_wrappers or []
self.ctx = Context()
self.parallel_ctx = Context()
self._backward_stack = []
# Async segment
self.async_mode = async_mode
self.n_async_workers = n_async_workers
self._async_stack = []
self._loop = None
self._thread_pool = None
self.event_listeners = event_listeners or defaultdict(list)
self.once_listeners = once_listeners or defaultdict(list)
self.labels = labels or set()
# Parallel segment
self.router = Parallel()
if async_mode or self.router.is_active:
self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)
self._loop = asyncio.new_event_loop()
if self.router.is_active:
self.router.register_rpc("task.emit", self.emit)
if attach_callback:
self.wait_for_attach_callback(attach_callback)
self.on("sync_parallel_ctx", self.sync_parallel_ctx)
self.init_labels()
def init_labels(self):
if self.async_mode:
self.labels.add("async")
if self.router.is_active:
self.labels.add("distributed")
self.labels.add("node.{}".format(self.router.node_id))
else:
self.labels.add("standalone")
def use(self, fn: Callable, filter_labels: Optional[Iterable[str]] = None) -> 'Task':
"""
Overview:
Register middleware to task. The middleware will be executed in its registration order.
Arguments:
- fn (:obj:`Callable`): A middleware is a function with only one argument: ctx.
"""
if not filter_labels or any([fnmatch.filter(self.labels, v) for v in filter_labels]):
self.middleware.append(fn)
return self
def use_step_wrapper(self, fn: Callable) -> 'Task':
self.step_wrappers.append(fn)
return self
def run(self, max_step: int = int(1e10)) -> None:
"""
Overview:
Execute the iterations. The loop will break when max_step is reached or task.finish is true.
Arguments:
- max_step (:obj:`int`): Max step of iterations.
"""
if len(self.middleware) == 0:
return
for i in range(max_step):
for fn in self.middleware:
self.forward(fn)
self.backward()
if i == max_step - 1:
self.ctx.finish = True
self.renew()
if self.finish:
break
@enable_async
def forward(self, fn: Callable, ctx: Context = None, backward_stack: List[Generator] = None) -> 'Task':
"""
Overview:
This function will execute the middleware until the first yield statement,
or the end of the middleware.
Arguments:
- fn (:obj:`Callable`): A middleware function that takes ctx as its only argument.
"""
if not backward_stack:
backward_stack = self._backward_stack
if not ctx:
ctx = self.ctx
for wrapper in self.step_wrappers:
fn = wrapper(fn)
g = fn(ctx)
if isinstance(g, GeneratorType):
try:
next(g)
backward_stack.append(g)
except StopIteration:
pass
return self
@enable_async
def backward(self, backward_stack: List[Generator] = None) -> 'Task':
"""
Overview:
Execute the remaining part of each middleware, in reverse order of registration.
"""
if not backward_stack:
backward_stack = self._backward_stack
while backward_stack:
# FILO
g = backward_stack.pop()
try:
next(g)
except StopIteration:
continue
return self
def sequence(self, *fns: List[Callable]) -> Callable:
"""
Overview:
Wrap functions so that they run in sequence, usually to avoid dependency confusion
in async mode.
Arguments:
- fns (:obj:`Callable`): A sequence of middleware to be wrapped into one middleware function.
"""
def _sequence(ctx):
backward_stack = []
for fn in fns:
self.forward(fn, ctx=ctx, backward_stack=backward_stack, async_mode=False)
yield
self.backward(backward_stack=backward_stack, async_mode=False)
name = ",".join([fn.__name__ for fn in fns])
_sequence.__name__ = "sequence<{}>".format(name)
return _sequence
def renew(self) -> 'Task':
"""
Overview:
Renew the context instance. This function should be called after backward, at the end of an iteration.
"""
# Sync should be called before backward, otherwise it is possible
# that some generators have not been pushed to backward_stack.
self.sync()
self.backward()
self.sync()
# Renew context
old_ctx = self.ctx
if self.router.is_active:
# Send context to other parallel processes
self.async_executor(self.router.send_rpc, "task.emit", "sync_parallel_ctx", old_ctx)
new_ctx = old_ctx.renew()
new_ctx.total_step = old_ctx.total_step + 1
self.ctx = new_ctx
return self
def __enter__(self) -> "Task":
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
def stop(self) -> None:
if self._thread_pool:
self._thread_pool.shutdown()
def sync(self) -> 'Task':
if self._loop:
self._loop.run_until_complete(self.sync_tasks())
return self
async def sync_tasks(self) -> Awaitable[None]:
while self._async_stack:
# FIFO
t = self._async_stack.pop(0)
await t
def wait_for_attach_callback(self, attach_callback: Callable, n_timeout: int = 30):
if len(self.router.attach_to) > 0:
logging.warning(
"The attach mode will wait for the latest context, an exception will \
be thrown after the timeout {}s is reached".format(n_timeout)
)
is_timeout = True
ctx = None
def on_sync_parallel_ctx(new_ctx):
nonlocal ctx
ctx = new_ctx
self.once("sync_parallel_ctx", on_sync_parallel_ctx)
for _ in range(n_timeout * 10):
if ctx:
is_timeout = False
break
time.sleep(0.1)
if is_timeout:
# If attach callback is defined, the attach mode should wait for callback finished,
# otherwise it may overwrite the training results of other processes
raise TimeoutError("Attach timeout, not received the latest context.")
attach_callback(ctx)
def async_executor(self, fn: Callable, *args, **kwargs) -> None:
"""
Overview:
Execute task in background, then append the future instance to _async_stack.
Arguments:
- fn (:obj:`Callable`): Synchronous function.
"""
if not self._loop:
raise Exception("Event loop was not initialized, please call this function in async or parallel mode")
t = self._loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
self._async_stack.append(t)
def emit(self, event_name, *args, **kwargs):
if event_name in self.event_listeners:
for fn in self.event_listeners[event_name]:
fn(*args, **kwargs)
if event_name in self.once_listeners:
while self.once_listeners[event_name]:
fn = self.once_listeners[event_name].pop()
fn(*args, **kwargs)
def on(self, event: str, fn: Callable) -> None:
self.event_listeners[event].append(fn)
def once(self, event: str, fn: Callable) -> None:
self.once_listeners[event].append(fn)
@property
def finish(self) -> bool:
"""
Overview:
Expose the ctx's finish state so that it can be easily checked externally.
"""
return self.ctx.finish
def __copy__(self):
return Task(**self.__dict__)
def sync_parallel_ctx(self, ctx):
"""
Overview:
Sync parallel ctx
"""
self.parallel_ctx = ctx
if self.parallel_ctx.finish:
self.ctx.finish = True
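Taken together, a minimal sketch of the pipeline style that Task enables (the middleware body is a placeholder; in async mode the same forward calls are queued on the thread pool and joined by sync):

def my_middleware(ctx):
    ctx.setdefault("step_count", 0)     # forward part: runs before the yield
    ctx.step_count += 1
    yield
    # backward part: runs after the forward parts of all later middleware
    print("finished iteration", ctx.total_step)

with Task(async_mode=False) as task:
    task.use(my_middleware)
    task.run(max_step=3)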
import pytest
from ding.framework import Context
import pickle
@pytest.mark.unittest
def test_pickable():
ctx = Context(hello="world", keep_me=True)
ctx.keep("keep_me")
_ctx = pickle.loads(pickle.dumps(ctx))
assert _ctx.hello == "world"
_ctx = ctx.renew()
assert _ctx.keep_me
from collections import defaultdict
import pytest
import time
import os
from ding.framework import Parallel
from ding.utils.design_helper import SingletonMetaclass
def parallel_main():
msg = defaultdict(bool)
def test_callback(key):
msg[key] = True
with Parallel() as router:
router.register_rpc("test_callback", test_callback)
# Wait for nodes to bind
time.sleep(0.7)
router.send_rpc("test_callback", "ping")
for _ in range(30):
if msg["ping"]:
break
time.sleep(0.03)
assert msg["ping"]
@pytest.mark.unittest
def test_parallel_run():
Parallel.runner(n_parallel_workers=2)(parallel_main)
Parallel.runner(n_parallel_workers=2, protocol="tcp")(parallel_main)
def parallel_main_alone(pid):
assert os.getpid() == pid
router = Parallel()
time.sleep(0.3) # Waiting to bind listening address
assert router._bind_addr
def test_parallel_run_alone():
try:
Parallel.runner(n_parallel_workers=1)(parallel_main_alone, os.getpid())
finally:
del SingletonMetaclass.instances[Parallel]
def star_parallel_main():
with Parallel() as router:
if router.node_id != 0:
assert len(router.attach_to) == 1
# Wait for other nodes
time.sleep(2)
@pytest.mark.unittest
def test_parallel_topology():
Parallel.runner(n_parallel_workers=3, topology="star")(star_parallel_main)
import pytest
import time
import copy
import random
from mpire import WorkerPool
from ding.framework import Task
from ding.framework.context import Context
from ding.framework.parallel import Parallel
from ding.utils.design_helper import SingletonMetaclass
@pytest.mark.unittest
def test_serial_pipeline():
def step0(ctx):
ctx.setdefault("pipeline", [])
ctx.pipeline.append(0)
def step1(ctx):
ctx.pipeline.append(1)
# Execute step1, step2 twice
task = Task()
for _ in range(2):
task.forward(step0)
task.forward(step1)
assert task.ctx.pipeline == [0, 1, 0, 1]
# Renew and execute step1, step2
task.renew()
assert task.ctx.total_step == 1
task.forward(step0)
task.forward(step1)
assert task.ctx.pipeline == [0, 1]
# Test context inheritance
task.renew()
@pytest.mark.unittest
def test_serial_yield_pipeline():
def step0(ctx):
ctx.setdefault("pipeline", [])
ctx.pipeline.append(0)
yield
ctx.pipeline.append(0)
def step1(ctx):
ctx.pipeline.append(1)
task = Task()
task.forward(step0)
task.forward(step1)
task.backward()
assert task.ctx.pipeline == [0, 1, 0]
assert len(task._backward_stack) == 0
@pytest.mark.unittest
def test_async_pipeline():
def step0(ctx):
ctx.setdefault("pipeline", [])
ctx.pipeline.append(0)
def step1(ctx):
ctx.pipeline.append(1)
# Execute step1, step2 twice
task = Task(async_mode=True)
for _ in range(2):
task.forward(step0)
time.sleep(0.1)
task.forward(step1)
time.sleep(0.1)
task.backward()
assert task.ctx.pipeline == [0, 1, 0, 1]
task.renew()
assert task.ctx.total_step == 1
@pytest.mark.unittest
def test_async_yield_pipeline():
def step0(ctx):
ctx.setdefault("pipeline", [])
time.sleep(0.1)
ctx.pipeline.append(0)
yield
ctx.pipeline.append(0)
def step1(ctx):
time.sleep(0.2)
ctx.pipeline.append(1)
task = Task(async_mode=True)
task.forward(step0)
task.forward(step1)
time.sleep(0.3)
task.backward().sync()
assert task.ctx.pipeline == [0, 1, 0]
assert len(task._backward_stack) == 0
def parallel_main():
task = Task()
sync_count = 0
def on_sync_parallel_ctx(ctx):
nonlocal sync_count
assert isinstance(ctx, Context)
sync_count += 1
task.on("sync_parallel_ctx", on_sync_parallel_ctx)
task.use(lambda _: time.sleep(0.2 + random.random() / 10))
task.run(max_step=10)
assert sync_count > 0
def parallel_main_eager():
task = Task()
sync_count = 0
def on_sync_parallel_ctx(ctx):
nonlocal sync_count
assert isinstance(ctx, Context)
sync_count += 1
task.on("sync_parallel_ctx", on_sync_parallel_ctx)
for _ in range(10):
task.forward(lambda _: time.sleep(0.2 + random.random() / 10))
task.renew()
assert sync_count > 0
@pytest.mark.unittest
def test_parallel_pipeline():
Parallel.runner(n_parallel_workers=2)(parallel_main_eager)
Parallel.runner(n_parallel_workers=2)(parallel_main)
@pytest.mark.unittest
def test_copy_task():
t1 = Task(async_mode=True, n_async_workers=1)
t2 = copy.copy(t1)
assert t2.async_mode
assert t1 is not t2
def attach_mode_main_task():
with Task() as task:
task.use(lambda _: time.sleep(0.1))
task.run(max_step=10)
def attach_mode_attach_task():
ctx = None
def attach_callback(new_ctx):
nonlocal ctx
ctx = new_ctx
with Task(attach_callback=attach_callback) as task:
task.use(lambda _: time.sleep(0.1))
task.run(max_step=10)
assert ctx is not None
def attach_mode_main(job):
if job == "run_task":
Parallel.runner(
n_parallel_workers=2, protocol="tcp", address="127.0.0.1", ports=[50501, 50502]
)(attach_mode_main_task)
elif "run_attach_task":
time.sleep(0.3)
try:
Parallel.runner(
n_parallel_workers=1,
protocol="tcp",
address="127.0.0.1",
ports=[50503],
attach_to=["tcp://127.0.0.1:50501", "tcp://127.0.0.1:50502"]
)(attach_mode_attach_task)
finally:
del SingletonMetaclass.instances[Parallel]
else:
raise Exception("Unknown task")
@pytest.mark.unittest
def test_attach_mode():
with WorkerPool(n_jobs=2, daemon=False, start_method="spawn") as pool:
pool.map(attach_mode_main, ["run_task", "run_attach_task"])
@pytest.mark.unittest
def test_label():
task = Task()
result = {}
task.use(lambda _: result.setdefault("not_me", True), filter_labels=["async"])
task.use(lambda _: result.setdefault("has_me", True))
task.run(max_step=1)
assert "not_me" not in result
assert "has_me" in result
def sync_parallel_ctx_main():
with Task() as task:
task.use(lambda _: time.sleep(1))
if task.router.node_id == 0: # Fast
task.run(max_step=2)
else: # Slow
task.run(max_step=10)
assert task.parallel_ctx
assert task.ctx.finish
assert task.ctx.total_step < 9
@pytest.mark.unittest
def test_sync_parallel_ctx():
Parallel.runner(n_parallel_workers=2)(sync_parallel_ctx_main)
import pytest
from ding.framework.task import Task
from ding.framework.wrapper import StepTimer
@pytest.mark.unittest
def test_step_timer():
def step1(_):
1
def step2(_):
2
def step3(_):
3
def step4(_):
4
# Lazy mode (with use statement)
step_timer = StepTimer()
task = Task()
task.use_step_wrapper(step_timer)
task.use(step1)
task.use(step2)
task.use(task.sequence(step3, step4))
task.run(3)
assert len(step_timer.records) == 5
for records in step_timer.records.values():
assert len(records) == 3
# Eager mode (with forward statement)
step_timer = StepTimer()
task = Task()
task.use_step_wrapper(step_timer)
for _ in range(3):
task.forward(step1) # Step 1
task.forward(step2) # Step 2
task.renew()
assert len(step_timer.records) == 2
for records in step_timer.records.values():
assert len(records) == 3
# Wrapper in wrapper
step_timer1 = StepTimer()
step_timer2 = StepTimer()
task = Task()
task.use_step_wrapper(step_timer1)
task.use_step_wrapper(step_timer2)
task.use(step1)
task.use(step2)
task.run(3)
assert len(step_timer1.records) == 2
assert len(step_timer2.records) == 2
for records in step_timer1.records.values():
assert len(records) == 3
for records in step_timer2.records.values():
assert len(records) == 3
from .step_timer import StepTimer
from collections import deque, defaultdict
from types import GeneratorType
from typing import Callable
from rich import print
import numpy as np
import time
class StepTimer:
def __init__(self, print_per_step: int = 1, smooth_window: int = 10) -> None:
self.print_per_step = print_per_step
self.records = defaultdict(lambda: deque(maxlen=print_per_step * smooth_window))
def __call__(self, fn: Callable) -> Callable:
step_name = getattr(fn, "__name__", type(fn).__name__)
step_id = id(fn)
def executor(ctx):
start_time = time.time()
time_cost = 0
g = fn(ctx)
if isinstance(g, GeneratorType):
try:
next(g)
except StopIteration:
pass
time_cost = time.time() - start_time
yield
start_time = time.time()
try:
next(g)
except StopIteration:
pass
time_cost += time.time() - start_time
else:
time_cost = time.time() - start_time
self.records[step_id].append(time_cost * 1000)
if ctx.total_step % self.print_per_step == 0:
print(
"[Step Timer] {}: Cost: {:.2f}ms, Mean: {:.2f}ms".format(
step_name, time_cost * 1000, np.mean(self.records[step_id])
)
)
executor.__name__ = "StepTimer<{}>".format(step_name)
return executor
......@@ -24,6 +24,7 @@ from .system_helper import get_ip, get_pid, get_task_uid, PropagatingThread, fin
from .time_helper import build_time_helper, EasyTimer, WatchDog
from .type_helper import SequenceType
from .scheduler_helper import Scheduler
from .profiler_helper import Profiler, register_profiler
if ding.enable_linklink:
from .linklink_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
......
import atexit
import pstats
import io
import cProfile
import os
def register_profiler(write_profile, pr, folder_path):
atexit.register(write_profile, pr, folder_path)
class Profiler:
def __init__(self):
self.pr = cProfile.Profile()
def mkdir(self, directory):
if not os.path.exists(directory):
os.makedirs(directory)
def write_profile(self, pr, folder_path):
pr.disable()
s_tottime = io.StringIO()
s_cumtime = io.StringIO()
ps = pstats.Stats(pr, stream=s_tottime).sort_stats('tottime')
ps.print_stats()
with open(folder_path + "/profile_tottime.txt", 'w+') as f:
f.write(s_tottime.getvalue())
ps = pstats.Stats(pr, stream=s_cumtime).sort_stats('cumtime')
ps.print_stats()
with open(folder_path + "/profile_cumtime.txt", 'w+') as f:
f.write(s_cumtime.getvalue())
pr.dump_stats(folder_path + "/profile.prof")
def profile(self, folder_path="./tmp"):
self.mkdir(folder_path)
self.pr.enable()
register_profiler(self.write_profile, self.pr, folder_path)
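A minimal usage sketch of the profiler helper above (the folder name is arbitrary); note that the stats files only appear when the process exits, since write_profile is registered via atexit:

profiler = Profiler()
profiler.profile("./tmp_profile")
# ... run the workload to be measured ...
# On interpreter exit, ./tmp_profile contains:
#   profile_tottime.txt   (text stats sorted by tottime)
#   profile_cumtime.txt   (text stats sorted by cumtime)
#   profile.prof          (binary dump loadable with pstats)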
from easydict import EasyDict
import pytest
import unittest
from unittest import mock
from unittest.mock import patch
import pathlib as pl
import os
import shutil
from ding.utils.profiler_helper import Profiler, register_profiler
@pytest.mark.unittest
class TestProfilerModule:
def assertIsFile(self, path):
if not pl.Path(path).resolve().is_file():
raise AssertionError("File does not exist: %s" % str(path))
def test(self):
profiler = Profiler()
def register_mock(write_profile, pr, folder_path):
profiler.write_profile(pr, folder_path)
def clean_up(dir):
if os.path.exists(dir):
shutil.rmtree(dir)
dir = "./tmp_test/"
clean_up(dir)
with patch('ding.utils.profiler_helper.register_profiler', register_mock):
profiler.profile(dir)
file_path = os.path.join(dir, "profile_tottime.txt")
self.assertIsFile(file_path)
file_path = os.path.join(dir, "profile_cumtime.txt")
self.assertIsFile(file_path)
file_path = os.path.join(dir, "profile.prof")
self.assertIsFile(file_path)
clean_up(dir)
......@@ -238,6 +238,8 @@ class BaseLearner(object):
self.call_hook('after_iter')
self._last_iter.add(1)
return log_vars
@auto_checkpoint
def start(self) -> None:
"""
......
......@@ -73,6 +73,9 @@ setup(
'scipy',
'trueskill',
'h5py',
'rich',
'mpire',
'pynng'
],
extras_require={
'test': [
......@@ -135,7 +138,6 @@ setup(
# 'gym_soccer_env': [
# 'gym-soccer @ git+https://github.com/LikeJulia/gym-soccer@dev-install-packages#egg=gym-soccer',
# ],
'sc2_env': [
'absl-py>=0.1.0',
'future',
......