diff --git a/ding/entry/cli.py b/ding/entry/cli.py
index 8eac954e851400b93165204385505866c3c9ecad..e55ce67ffc0fb9ae636f201195b602ad1bc53993 100644
--- a/ding/entry/cli.py
+++ b/ding/entry/cli.py
@@ -107,6 +107,12 @@ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
 @click.option(
     '--memory', type=str, default=None, help='The requested Memory, read the value from DIJob yaml by default'
 )
+@click.option(
+    '--profile',
+    type=str,
+    default=None,
+    help='Profile the time cost with cProfile and save the result files into the specified folder path'
+)
 def cli(
     # serial/eval
     mode: str,
@@ -143,7 +149,12 @@ def cli(
     gpus: int,
     memory: str,
     restart_pod_name: str,
+    profile: str,
 ):
+    if profile is not None:
+        from ..utils.profiler_helper import Profiler
+        profiler = Profiler()
+        profiler.profile(profile)
     if mode == 'serial':
         from .serial_entry import serial_pipeline
         if config is None:
diff --git a/ding/entry/main.py b/ding/entry/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..08168ef56f97660641adb233b1d298391852678b
--- /dev/null
+++ b/ding/entry/main.py
@@ -0,0 +1,172 @@
+"""
+Main entry
+"""
+from collections import deque
+import torch
+import numpy as np
+from rich import print
+from functools import partial
+from ding.model import QAC
+from ding.utils import set_pkg_seed
+from ding.envs import BaseEnvManager, get_vec_env_setting
+from ding.config import compile_config
+from ding.policy import SACPolicy
+from ding.torch_utils import to_ndarray, to_tensor
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.worker.collector.base_serial_evaluator import VectorEvalMonitor
+from ding.framework import Task
+from dizoo.classic_control.pendulum.config.pendulum_sac_config import main_config, create_config
+
+
+class DequeBuffer:
+    """
+    For demonstration only
+    """
+
+    def __init__(self, maxlen=20000) -> None:
+        self.memory = deque(maxlen=maxlen)
+        self.n_counter = 0
+
+    def push(self, data):
+        self.memory.append(data)
+        self.n_counter += 1
+
+    def sample(self, size):
+        if size > len(self.memory):
+            print('[Warning] not enough data: {}/{}'.format(size, len(self.memory)))
+            return None
+        indices = list(np.random.choice(a=len(self.memory), size=size, replace=False))
+        return [self.memory[i] for i in indices]
+        # return random.sample(self.memory, size)
+
+    def count(self):
+        return len(self.memory)
+
+
+class Pipeline:
+
+    def __init__(self, cfg, model: torch.nn.Module):
+        self.cfg = cfg
+        self.model = model
+        self.policy = SACPolicy(cfg.policy, model=model)
+        if 'eps' in cfg.policy.other:
+            eps_cfg = cfg.policy.other.eps
+            self.epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+    def act(self, env):
+
+        def _act(ctx):
+            ctx.setdefault("collect_env_step", 0)
+            ctx.keep("collect_env_step")
+            ctx.obs = env.ready_obs
+            policy_kwargs = {}
+            if hasattr(self, 'epsilon_greedy'):
+                policy_kwargs['eps'] = self.epsilon_greedy(ctx.collect_env_step)
+            policy_output = self.policy.collect_mode.forward(ctx.obs, **policy_kwargs)
+            ctx.action = to_ndarray({env_id: output['action'] for env_id, output in policy_output.items()})
+            ctx.policy_output = policy_output
+
+        return _act
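+
+    # Middleware sketch (illustrative only, not used by this pipeline): every step
+    # above is a closure that receives only `ctx` and communicates through it
+    # (ctx.obs, ctx.action, ...). A step may also `yield` once to split itself into
+    # a forward and a backward phase:
+    #
+    #   def _step(ctx):
+    #       ...    # runs on Task.forward
+    #       yield
+    #       ...    # runs on Task.backward, in reverse registration order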
+
+    def collect(self, env, buffer_, task: Task):
+
+        def _collect(ctx):
+            timesteps = env.step(ctx.action)
+            ctx.collect_env_step += len(timesteps)
+            timesteps = to_tensor(timesteps, dtype=torch.float32)
+            ctx.collect_transitions = []
+            for env_id, timestep in timesteps.items():
+                transition = self.policy.collect_mode.process_transition(
+                    ctx.obs[env_id], ctx.policy_output[env_id], timestep
+                )
+                ctx.collect_transitions.append(transition)
+                buffer_.push(transition)
+
+        return _collect
+
+    def learn(self, buffer_: DequeBuffer, task: Task):
+
+        def _learn(ctx):
+            ctx.setdefault("train_iter", 0)
+            ctx.keep("train_iter")
+            for i in range(self.cfg.policy.learn.update_per_collect):
+                data = buffer_.sample(self.policy.learn_mode.get_attribute('batch_size'))
+                if not data:
+                    break
+                learn_output = self.policy.learn_mode.forward(data)
+                if ctx.train_iter % 20 == 0:
+                    print(
+                        'Current Training: Train Iter({})\tLoss({:.3f})'.format(
+                            ctx.train_iter, learn_output['total_loss']
+                        )
+                    )
+                ctx.train_iter += 1
+
+        return _learn
+
+    def evaluate(self, env):
+
+        def _eval(ctx):
+            ctx.setdefault("train_iter", 0)
+            ctx.setdefault("last_eval_iter", -1)
+            ctx.keep("train_iter", "last_eval_iter")
+            if ctx.train_iter == ctx.last_eval_iter or (
+                    (ctx.train_iter - ctx.last_eval_iter) < self.cfg.policy.eval.evaluator.eval_freq
+                    and ctx.train_iter != 0):
+                return
+            env.reset()
+            eval_monitor = VectorEvalMonitor(env.env_num, self.cfg.env.n_evaluator_episode)
+            while not eval_monitor.is_finished():
+                obs = env.ready_obs
+                obs = to_tensor(obs, dtype=torch.float32)
+                policy_output = self.policy.eval_mode.forward(obs)
+                action = to_ndarray({i: a['action'] for i, a in policy_output.items()})
+                timesteps = env.step(action)
+                timesteps = to_tensor(timesteps, dtype=torch.float32)
+                for env_id, timestep in timesteps.items():
+                    if timestep.done:
+                        self.policy.eval_mode.reset([env_id])
+                        reward = timestep.info['final_eval_reward']
+                        eval_monitor.update_reward(env_id, reward)
+            episode_reward = eval_monitor.get_episode_reward()
+            eval_reward = np.mean(episode_reward)
+            stop_flag = eval_reward >= self.cfg.env.stop_value and ctx.train_iter > 0
+            print('Current Evaluation: Train Iter({})\tEval Reward({:.3f})'.format(ctx.train_iter, eval_reward))
+            ctx.last_eval_iter = ctx.train_iter
+            if stop_flag:
+                ctx.finish = True
+
+        return _eval
+
+
+def main(cfg, model, seed=0):
+    with Task(async_mode=False) as task:
+        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+
+        collector_env = BaseEnvManager(env_fn=[partial(env_fn, cfg=c) for c in collector_env_cfg], cfg=cfg.env.manager)
+        evaluator_env = BaseEnvManager(env_fn=[partial(env_fn, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager)
+
+        collector_env.seed(seed)
+        evaluator_env.seed(seed, dynamic_seed=False)
+        set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+        collector_env.launch()
+        evaluator_env.launch()
+
+        replay_buffer = DequeBuffer()
+        sac = Pipeline(cfg, model)
+
+        # task.use_step_wrapper(StepTimer(print_per_step=1))
+        task.use(sac.evaluate(evaluator_env), filter_labels=["standalone", "node.0"])
+        task.use(
+            task.sequence(sac.act(collector_env), sac.collect(collector_env, replay_buffer, task=task)),
+            filter_labels=["standalone", "node.[1-9]*"]
+        )
+        task.use(sac.learn(replay_buffer, task=task), filter_labels=["standalone", "node.0"])
+        task.run(max_step=100000)
+
+
+if __name__ == "__main__":
+    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+    model = QAC(**cfg.policy.model)
+    main(cfg, model)
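+
+# Parallel sketch (a minimal illustration, assuming the default mesh topology and
+# mirroring ding/framework/tests): the same `main` could be spawned on several
+# connected processes via Parallel.runner, e.g.:
+#
+#   from ding.framework import Parallel
+#   Parallel.runner(n_parallel_workers=2)(main, cfg, model)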
diff --git a/ding/framework/__init__.py b/ding/framework/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb712ee8418be2ff17314248c447aa73ac21de16
--- /dev/null
+++ b/ding/framework/__init__.py
@@ -0,0 +1,3 @@
+from .context import Context
+from .task import Task
+from .parallel import Parallel
diff --git a/ding/framework/context.py b/ding/framework/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b06ce252f29e9aae0aba353ba0bf2e01dbe4a6e
--- /dev/null
+++ b/ding/framework/context.py
@@ -0,0 +1,34 @@
+class Context(dict):
+    """
+    Overview:
+        Context is an object that passes contextual data between middlewares, and its life
+        cycle is only one training iteration. It is a dict that mirrors its own attributes,
+        so you can set any properties on it as you wish.
+    """
+
+    def __init__(self, total_step: int = 0, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.__dict__ = self
+        self.total_step = total_step
+
+        # Reserved properties
+        self.finish = False
+        self._kept_keys = {"finish"}
+
+    def renew(self) -> 'Context':  # noqa
+        """
+        Overview:
+            Renew a context from self for the next iteration, copying the kept properties
+            to the new instance.
+        """
+        ctx = Context()
+        for key in self._kept_keys:
+            ctx[key] = self[key]
+        return ctx
+
+    def keep(self, *keys: str) -> None:
+        """
+        Overview:
+            Keep this key/these keys until the next iteration.
+        """
+        for key in keys:
+            self._kept_keys.add(key)
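+
+# Usage sketch (illustrative, not part of the class contract):
+#   ctx = Context()
+#   ctx.setdefault("train_iter", 0)  # plain dict methods still work
+#   ctx.keep("train_iter")           # survives ctx.renew()
+#   ctx = ctx.renew()                # next iteration, train_iter is carried over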
diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..66243ff452f079dd69f6b44377669f2be5929fc7
--- /dev/null
+++ b/ding/framework/parallel.py
@@ -0,0 +1,231 @@
+import atexit
+import os
+import random
+import threading
+import time
+from mpire.pool import WorkerPool
+import pynng
+import pickle
+import logging
+import tempfile
+import socket
+from os import path
+from typing import Callable, Dict, List, Optional, Tuple, Union
+from threading import Thread
+from pynng.nng import Bus0, Socket
+from ding.utils.design_helper import SingletonMetaclass
+from rich import print
+
+# To avoid ipc address conflicts, this module-level random instance should always
+# use a random seed, independent of any globally seeded `random` module.
+random = random.Random()
+
+
+class Parallel(metaclass=SingletonMetaclass):
+
+    def __init__(self) -> None:
+        self._listener = None
+        self._sock: Socket = None
+        self._rpc = {}
+        self._bind_addr = None
+        self._lock = threading.Lock()
+        self.is_active = False
+        self.attach_to = None
+        self.finished = False
+
+    def run(self, node_id: int, listen_to: str, attach_to: List[str] = None) -> None:
+        self.node_id = node_id
+        self.attach_to = attach_to = attach_to or []
+        self._listener = Thread(
+            target=self.listen,
+            kwargs={
+                "listen_to": listen_to,
+                "attach_to": attach_to
+            },
+            name="parallel_listener",
+            daemon=True
+        )
+        self._listener.start()
+
+    @staticmethod
+    def runner(
+        n_parallel_workers: int,
+        attach_to: Optional[List[str]] = None,
+        protocol: str = "ipc",
+        address: Optional[str] = None,
+        ports: Optional[List[int]] = None,
+        topology: str = "mesh"
+    ) -> Callable:
+        """
+        Overview:
+            This method allows you to configure the parallel parameters; at this point you
+            are still in the parent process.
+        Arguments:
+            - n_parallel_workers (:obj:`int`): Workers to spawn.
+            - attach_to (:obj:`Optional[List[str]]`): The addresses of the nodes you want to attach to.
+            - protocol (:obj:`str`): Network protocol.
+            - address (:obj:`Optional[str]`): Bind address, ip or file path.
+            - ports (:obj:`Optional[List[int]]`): Candidate ports.
+            - topology (:obj:`str`): Network topology, including:
+                `mesh` (default): fully connected between each other;
+                `star`: only connect to the first node.
+        Returns:
+            - _runner (:obj:`Callable`): The wrapper function for main.
+        """
+        attach_to = attach_to or []
+        assert n_parallel_workers > 0, "Parallel worker number should be bigger than 0"
+
+        def _runner(main_process: Callable, *args, **kwargs) -> None:
+            """
+            Overview:
+                Prepare to run in subprocesses.
+            Arguments:
+                - main_process (:obj:`Callable`): The main function, your program starts from here.
+            """
+            nodes = Parallel.get_node_addrs(n_parallel_workers, protocol=protocol, address=address, ports=ports)
+            logging.info("Bind subprocesses on these addresses: {}".format(nodes))
+            print("Bind subprocesses on these addresses: {}".format(nodes))
+
+            def cleanup_nodes():
+                for node in nodes:
+                    protocol, file_path = node.split("://")
+                    if protocol == "ipc" and path.exists(file_path):
+                        os.remove(file_path)
+
+            atexit.register(cleanup_nodes)
+
+            def topology_network(node_id: int) -> List[str]:
+                # Note: attach_to is appended by the caller below, so it is not
+                # added again here (the mesh branch previously duplicated it).
+                if topology == "mesh":
+                    return nodes[:node_id]
+                elif topology == "star":
+                    return nodes[:min(1, node_id)]
+                else:
+                    raise ValueError("Unknown topology: {}".format(topology))
+
+            params_group = []
+            for node_id in range(n_parallel_workers):
+                runner_args = []
+                runner_kwargs = {
+                    "node_id": node_id,
+                    "listen_to": nodes[node_id],
+                    "attach_to": topology_network(node_id) + attach_to
+                }
+                params = [(runner_args, runner_kwargs), (main_process, args, kwargs)]
+                params_group.append(params)
+
+            if n_parallel_workers == 1:
+                Parallel.subprocess_runner(*params_group[0])
+            else:
+                with WorkerPool(n_jobs=n_parallel_workers, start_method="spawn", daemon=False) as pool:
+                    # Cleanup the pool just in case the program crashes.
+                    atexit.register(pool.__exit__)
+                    pool.map(Parallel.subprocess_runner, params_group)
+
+        return _runner
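+
+    # Topology sketch (illustrative): with the default "mesh", node k dials nodes
+    # 0..k-1 so every pair of nodes ends up connected; with "star", nodes 1..n-1
+    # only dial node 0, e.g.:
+    #
+    #   Parallel.runner(n_parallel_workers=3, topology="star")(main)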
+
+    @staticmethod
+    def subprocess_runner(runner_params: Tuple[Union[List, Dict]], main_params: Tuple[Union[List, Dict]]) -> None:
+        """
+        Overview:
+            Actually run in a subprocess.
+        Arguments:
+            - runner_params (:obj:`Tuple[Union[List, Dict]]`): Args and kwargs for the runner.
+            - main_params (:obj:`Tuple[Union[List, Dict]]`): Args and kwargs for the main function.
+        """
+        main_process, args, kwargs = main_params
+        runner_args, runner_kwargs = runner_params
+
+        with Parallel() as router:
+            router.is_active = True
+            router.run(*runner_args, **runner_kwargs)
+            main_process(*args, **kwargs)
+
+    @staticmethod
+    def get_node_addrs(
+        n_workers: int,
+        protocol: str = "ipc",
+        address: Optional[str] = None,
+        ports: Optional[List[int]] = None
+    ) -> List[str]:
+        if protocol == "ipc":
+            node_name = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=4))
+            tmp_dir = tempfile.gettempdir()
+            nodes = ["ipc://{}/ditask_{}_{}.ipc".format(tmp_dir, node_name, i) for i in range(n_workers)]
+        elif protocol == "tcp":
+            address = address or Parallel.get_ip()
+            ports = ports or range(50515, 50515 + n_workers)
+            assert len(ports) == n_workers, "The number of ports must be the same as the number of workers, \
+now there are {} ports and {} workers".format(len(ports), n_workers)
+            nodes = ["tcp://{}:{}".format(address, port) for port in ports]
+        else:
+            raise Exception("Unknown protocol {}".format(protocol))
+        return nodes
+
+    def listen(self, listen_to: str, attach_to: List[str] = None):
+        attach_to = attach_to or []
+        self._bind_addr = listen_to
+
+        with Bus0() as sock:
+            self._sock = sock
+            sock.listen(self._bind_addr)
+            time.sleep(0.1)  # Wait for peers to bind
+            for contact in attach_to:
+                sock.dial(contact)
+
+            while True:
+                try:
+                    msg = sock.recv_msg()
+                    self.recv_rpc(msg.bytes)
+                except pynng.Timeout:
+                    logging.warning("Timeout on node {} when waiting for messages from the bus".format(self._bind_addr))
+                except pynng.Closed:
+                    if not self.finished:
+                        logging.error("The socket was not closed under normal circumstances!")
+                    break
+                except Exception as e:
+                    logging.error("Met exception when listening for new messages: {}".format(e))
+                    break
+
+    def register_rpc(self, fn_name: str, fn: Callable) -> None:
+        self._rpc[fn_name] = fn
+
+    def send_rpc(self, func_name: str, *args, **kwargs) -> None:
+        if self.is_active:
+            payload = {"f": func_name, "a": args, "k": kwargs}
+            return self._sock and self._sock.send(pickle.dumps(payload, protocol=-1))
+
+    def recv_rpc(self, msg: bytes):
+        try:
+            payload = pickle.loads(msg)
+        except Exception as e:
+            logging.warning("Error when unpacking message on node {}, msg: {}".format(self._bind_addr, e))
+            return
+        if payload["f"] in self._rpc:
+            self._rpc[payload["f"]](*payload["a"], **payload["k"])
+        else:
+            logging.warning("There is no function named {} in the rpc table".format(payload["f"]))
+
+    @staticmethod
+    def get_ip():
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        try:
+            # The address doesn't even have to be reachable
+            s.connect(('10.255.255.255', 1))
+            IP = s.getsockname()[0]
+        except Exception:
+            IP = '127.0.0.1'
+        finally:
+            s.close()
+        return IP
+
+    def __enter__(self) -> "Parallel":
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop()
+
+    def stop(self):
+        logging.info("Stopping parallel worker on address: {}".format(self._bind_addr))
+        self.finished = True
+        time.sleep(0.03)
+        if self._sock:
+            self._sock.close()
+        if self._listener:
+            self._listener.join(timeout=1)
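+
+# RPC round-trip sketch (mirrors ding/framework/tests/test_parallel.py):
+#
+#   def main():
+#       with Parallel() as router:
+#           router.register_rpc("ping", lambda: print("pong"))
+#           router.send_rpc("ping")  # broadcast to every connected node
+#
+#   Parallel.runner(n_parallel_workers=2)(main)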
diff --git a/ding/framework/task.py b/ding/framework/task.py
new file mode 100644
index 0000000000000000000000000000000000000000..6362f01cfcdf7c55440c495d5914833e8d382920
--- /dev/null
+++ b/ding/framework/task.py
@@ -0,0 +1,312 @@
+from collections import defaultdict
+import logging
+import time
+import asyncio
+import concurrent.futures
+import fnmatch
+from types import GeneratorType
+from typing import Awaitable, Callable, Dict, Generator, Iterable, List, Optional, Set
+from ding.framework.context import Context
+from ding.framework.parallel import Parallel
+
+
+def enable_async(func: Callable) -> Callable:
+    """
+    Overview:
+        Empower the function with async ability.
+    Arguments:
+        - func (:obj:`Callable`): The original function.
+    Returns:
+        - runtime_handler (:obj:`Callable`): The wrapped function.
+    """
+
+    def runtime_handler(task: "Task", *args, **kwargs) -> "Task":
+        """
+        Overview:
+            If the task's async mode is enabled, execute the step asynchronously in the
+            current loop executor, otherwise execute it synchronously.
+        Arguments:
+            - task (:obj:`Task`): The task instance.
+        Returns:
+            - result (:obj:`Union[Any, Awaitable]`): The result or future object of the middleware.
+        """
+        if "async_mode" in kwargs:
+            async_mode = kwargs.pop("async_mode")
+        else:
+            async_mode = task.async_mode
+        if async_mode:
+            t = task._loop.run_in_executor(task._thread_pool, func, task, *args, **kwargs)
+            task._async_stack.append(t)
+            return task
+        else:
+            return func(task, *args, **kwargs)
+
+    return runtime_handler
+
+
+class Task:
+    """
+    Task will manage the execution order of the entire pipeline, register new middleware,
+    and generate new context objects.
+    """
+
+    def __init__(
+        self,
+        async_mode: bool = False,
+        n_async_workers: int = 3,
+        middleware: Optional[List[Callable]] = None,
+        step_wrappers: Optional[List[Callable]] = None,
+        event_listeners: Optional[Dict[str, List]] = None,
+        once_listeners: Optional[Dict[str, List]] = None,
+        attach_callback: Optional[Callable] = None,
+        labels: Optional[Set[str]] = None,
+        **_
+    ) -> None:
+        self.middleware = middleware or []
+        self.step_wrappers = step_wrappers or []
+        self.ctx = Context()
+        self.parallel_ctx = Context()
+        self._backward_stack = []
+
+        # Async segment
+        self.async_mode = async_mode
+        self.n_async_workers = n_async_workers
+        self._async_stack = []
+        self._loop = None
+        self._thread_pool = None
+        self.event_listeners = event_listeners or defaultdict(list)
+        self.once_listeners = once_listeners or defaultdict(list)
+        self.labels = labels or set()
+
+        # Parallel segment
+        self.router = Parallel()
+        if async_mode or self.router.is_active:
+            self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=n_async_workers)
+            self._loop = asyncio.new_event_loop()
+
+        if self.router.is_active:
+            self.router.register_rpc("task.emit", self.emit)
+            if attach_callback:
+                self.wait_for_attach_callback(attach_callback)
+            self.on("sync_parallel_ctx", self.sync_parallel_ctx)
+
+        self.init_labels()
+
+    def init_labels(self):
+        if self.async_mode:
+            self.labels.add("async")
+        if self.router.is_active:
+            self.labels.add("distributed")
+            self.labels.add("node.{}".format(self.router.node_id))
+        else:
+            self.labels.add("standalone")
+
+    def use(self, fn: Callable, filter_labels: Optional[Iterable[str]] = None) -> 'Task':
+        """
+        Overview:
+            Register middleware to the task. The middleware will be executed in its
+            registration order.
+        Arguments:
+            - fn (:obj:`Callable`): A middleware is a function with only one argument: ctx.
+        """
+        if not filter_labels or any([fnmatch.filter(self.labels, v) for v in filter_labels]):
+            self.middleware.append(fn)
+        return self
+
+    def use_step_wrapper(self, fn: Callable) -> 'Task':
+        self.step_wrappers.append(fn)
+        return self
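+
+    # Label sketch: filter_labels are matched against this task's labels with
+    # fnmatch, so "node.0" selects only the first parallel process and
+    # "standalone" selects a non-distributed run, e.g. (learn_step is a
+    # placeholder middleware):
+    #
+    #   task.use(learn_step, filter_labels=["standalone", "node.0"])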
+
+    def run(self, max_step: int = int(1e10)) -> None:
+        """
+        Overview:
+            Execute the iterations; the loop will break when max_step is reached or
+            task.finish becomes true.
+        Arguments:
+            - max_step (:obj:`int`): Max number of iterations.
+        """
+        if len(self.middleware) == 0:
+            return
+        for i in range(max_step):
+            for fn in self.middleware:
+                self.forward(fn)
+            self.backward()
+            if i == max_step - 1:
+                self.ctx.finish = True
+            self.renew()
+            if self.finish:
+                break
+
+    @enable_async
+    def forward(self, fn: Callable, ctx: Context = None, backward_stack: List[Generator] = None) -> 'Task':
+        """
+        Overview:
+            This function will execute the middleware until the first yield statement,
+            or to the end of the middleware.
+        Arguments:
+            - fn (:obj:`Callable`): A middleware function that takes the ctx argument.
+        """
+        if not backward_stack:
+            backward_stack = self._backward_stack
+        if not ctx:
+            ctx = self.ctx
+        for wrapper in self.step_wrappers:
+            fn = wrapper(fn)
+        g = fn(ctx)
+        if isinstance(g, GeneratorType):
+            try:
+                next(g)
+                backward_stack.append(g)
+            except StopIteration:
+                pass
+        return self
+
+    @enable_async
+    def backward(self, backward_stack: List[Generator] = None) -> 'Task':
+        """
+        Overview:
+            Execute the rest of the middleware, in the reverse order of registration.
+        """
+        if not backward_stack:
+            backward_stack = self._backward_stack
+        while backward_stack:
+            # FILO
+            g = backward_stack.pop()
+            try:
+                next(g)
+            except StopIteration:
+                continue
+        return self
+
+    def sequence(self, *fns: Callable) -> Callable:
+        """
+        Overview:
+            Wrap functions and keep them running in sequence, usually in order to avoid
+            confusing dependencies in async mode.
+        Arguments:
+            - fns (:obj:`Callable`): A sequence of middleware, wrapped into one middleware function.
+        """
+
+        def _sequence(ctx):
+            backward_stack = []
+            for fn in fns:
+                self.forward(fn, ctx=ctx, backward_stack=backward_stack, async_mode=False)
+            yield
+            self.backward(backward_stack=backward_stack, async_mode=False)
+
+        name = ",".join([fn.__name__ for fn in fns])
+        _sequence.__name__ = "sequence<{}>".format(name)
+        return _sequence
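+
+    # Sequence sketch: task.sequence(a, b) yields once between the combined forward
+    # parts and the combined backward parts, so `a` and `b` never interleave with
+    # other middleware even in async mode, e.g. (from ding/entry/main.py):
+    #
+    #   task.use(task.sequence(sac.act(env), sac.collect(env, buffer_, task=task)))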
+
+    def renew(self) -> 'Task':
+        """
+        Overview:
+            Renew the context instance; this function should be called after backward
+            at the end of an iteration.
+        """
+        # Sync should be called before backward, otherwise it is possible
+        # that some generators have not been pushed onto the backward_stack.
+        self.sync()
+        self.backward()
+        self.sync()
+        # Renew context
+        old_ctx = self.ctx
+        if self.router.is_active:
+            # Send the context to the other parallel processes
+            self.async_executor(self.router.send_rpc, "task.emit", "sync_parallel_ctx", old_ctx)
+
+        new_ctx = old_ctx.renew()
+        new_ctx.total_step = old_ctx.total_step + 1
+        self.ctx = new_ctx
+        return self
+
+    def __enter__(self) -> "Task":
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop()
+
+    def stop(self) -> None:
+        if self._thread_pool:
+            self._thread_pool.shutdown()
+
+    def sync(self) -> 'Task':
+        if self._loop:
+            self._loop.run_until_complete(self.sync_tasks())
+        return self
+
+    async def sync_tasks(self) -> Awaitable[None]:
+        while self._async_stack:
+            # FIFO
+            t = self._async_stack.pop(0)
+            await t
+
+    def wait_for_attach_callback(self, attach_callback: Callable, n_timeout: int = 30):
+        if len(self.router.attach_to) > 0:
+            logging.warning(
+                "The attach mode will wait for the latest context; an exception will \
+be thrown after the timeout of {}s is reached".format(n_timeout)
+            )
+            is_timeout = True
+            ctx = None
+
+            def on_sync_parallel_ctx(new_ctx):
+                nonlocal ctx
+                ctx = new_ctx
+
+            self.once("sync_parallel_ctx", on_sync_parallel_ctx)
+            for _ in range(n_timeout * 10):
+                if ctx:
+                    is_timeout = False
+                    break
+                time.sleep(0.1)
+            if is_timeout:
+                # If an attach callback is defined, the attach mode should wait until the
+                # callback has finished, otherwise it may overwrite the training results
+                # of other processes
+                raise TimeoutError("Attach timeout, did not receive the latest context.")
+            attach_callback(ctx)
+
+    def async_executor(self, fn: Callable, *args, **kwargs) -> None:
+        """
+        Overview:
+            Execute a task in the background, then append the future instance to _async_stack.
+        Arguments:
+            - fn (:obj:`Callable`): A synchronous function.
+        """
+        if not self._loop:
+            raise Exception("Event loop was not initialized, please call this function in async or parallel mode")
+        t = self._loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
+        self._async_stack.append(t)
+
+    def emit(self, event_name, *args, **kwargs):
+        if event_name in self.event_listeners:
+            for fn in self.event_listeners[event_name]:
+                fn(*args, **kwargs)
+        if event_name in self.once_listeners:
+            while self.once_listeners[event_name]:
+                fn = self.once_listeners[event_name].pop()
+                fn(*args, **kwargs)
+
+    def on(self, event: str, fn: Callable) -> None:
+        self.event_listeners[event].append(fn)
+
+    def once(self, event: str, fn: Callable) -> None:
+        self.once_listeners[event].append(fn)
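+
+    # Event sketch (illustrative): emit() fires both kinds of listeners; on()
+    # handlers persist, once() handlers run a single time, e.g.:
+    #
+    #   task.on("sync_parallel_ctx", lambda ctx: print(ctx.total_step))
+    #   task.emit("sync_parallel_ctx", task.ctx)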
+
+    @property
+    def finish(self) -> bool:
+        """
+        Overview:
+            Link the ctx's finish state, so that it can easily be checked externally.
+        """
+        return self.ctx.finish
+
+    def __copy__(self):
+        return Task(**self.__dict__)
+
+    def sync_parallel_ctx(self, ctx):
+        """
+        Overview:
+            Sync the parallel ctx
+        """
+        self.parallel_ctx = ctx
+        if self.parallel_ctx.finish:
+            self.ctx.finish = True
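+
+# Shutdown sketch: when any node sets ctx.finish, renew() broadcasts the context
+# through the "task.emit" RPC, every receiver's sync_parallel_ctx() mirrors the
+# flag into its local ctx, and all nodes then leave Task.run() on their next check.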
diff --git a/ding/framework/tests/test_context.py b/ding/framework/tests/test_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..406680f5600c466afd92d6f00102821ca3061b8d
--- /dev/null
+++ b/ding/framework/tests/test_context.py
@@ -0,0 +1,14 @@
+import pytest
+from ding.framework import Context
+import pickle
+
+
+@pytest.mark.unittest
+def test_pickable():
+    ctx = Context(hello="world", keep_me=True)
+    ctx.keep("keep_me")
+    _ctx = pickle.loads(pickle.dumps(ctx))
+    assert _ctx.hello == "world"
+
+    _ctx = ctx.renew()
+    assert _ctx.keep_me
diff --git a/ding/framework/tests/test_parallel.py b/ding/framework/tests/test_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ccc7c84bda10f7a8a68f871b4631cf5e5b71b95
--- /dev/null
+++ b/ding/framework/tests/test_parallel.py
@@ -0,0 +1,60 @@
+from collections import defaultdict
+import pytest
+import time
+import os
+from ding.framework import Parallel
+from ding.utils.design_helper import SingletonMetaclass
+
+
+def parallel_main():
+    msg = defaultdict(bool)
+
+    def test_callback(key):
+        msg[key] = True
+
+    with Parallel() as router:
+        router.register_rpc("test_callback", test_callback)
+        # Wait for the nodes to bind
+        time.sleep(0.7)
+
+        router.send_rpc("test_callback", "ping")
+
+        for _ in range(30):
+            if msg["ping"]:
+                break
+            time.sleep(0.03)
+        assert msg["ping"]
+
+
+@pytest.mark.unittest
+def test_parallel_run():
+    Parallel.runner(n_parallel_workers=2)(parallel_main)
+    Parallel.runner(n_parallel_workers=2, protocol="tcp")(parallel_main)
+
+
+def parallel_main_alone(pid):
+    assert os.getpid() == pid
+    router = Parallel()
+    time.sleep(0.3)  # Waiting to bind the listening address
+    assert router._bind_addr
+
+
+def test_parallel_run_alone():
+    try:
+        Parallel.runner(n_parallel_workers=1)(parallel_main_alone, os.getpid())
+    finally:
+        del SingletonMetaclass.instances[Parallel]
+
+
+def star_parallel_main():
+    with Parallel() as router:
+        if router.node_id != 0:
+            assert len(router.attach_to) == 1
+
+        # Wait for the other nodes
+        time.sleep(2)
+
+
+@pytest.mark.unittest
+def test_parallel_topology():
+    Parallel.runner(n_parallel_workers=3, topology="star")(star_parallel_main)
diff --git a/ding/framework/tests/test_task.py b/ding/framework/tests/test_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f9c774b4d9ef2b0a615e182cae0ceb36f51e68d
--- /dev/null
+++ b/ding/framework/tests/test_task.py
@@ -0,0 +1,226 @@
+import pytest
+import time
+import copy
+import random
+from mpire import WorkerPool
+from ding.framework import Task
+from ding.framework.context import Context
+from ding.framework.parallel import Parallel
+from ding.utils.design_helper import SingletonMetaclass
+
+
+@pytest.mark.unittest
+def test_serial_pipeline():
+
+    def step0(ctx):
+        ctx.setdefault("pipeline", [])
+        ctx.pipeline.append(0)
+
+    def step1(ctx):
+        ctx.pipeline.append(1)
+
+    # Execute step0, step1 twice
+    task = Task()
+    for _ in range(2):
+        task.forward(step0)
+        task.forward(step1)
+    assert task.ctx.pipeline == [0, 1, 0, 1]
+
+    # Renew and execute step0, step1
+    task.renew()
+    assert task.ctx.total_step == 1
+    task.forward(step0)
+    task.forward(step1)
+    assert task.ctx.pipeline == [0, 1]
+
+    # Test context inheritance
+    task.renew()
+
+
+@pytest.mark.unittest
+def test_serial_yield_pipeline():
+
+    def step0(ctx):
+        ctx.setdefault("pipeline", [])
+        ctx.pipeline.append(0)
+        yield
+        ctx.pipeline.append(0)
+
+    def step1(ctx):
+        ctx.pipeline.append(1)
+
+    task = Task()
+    task.forward(step0)
+    task.forward(step1)
+    task.backward()
+    assert task.ctx.pipeline == [0, 1, 0]
+    assert len(task._backward_stack) == 0
+
+
+@pytest.mark.unittest
+def test_async_pipeline():
+
+    def step0(ctx):
+        ctx.setdefault("pipeline", [])
+        ctx.pipeline.append(0)
+
+    def step1(ctx):
+        ctx.pipeline.append(1)
+
+    # Execute step0, step1 twice
+    task = Task(async_mode=True)
+    for _ in range(2):
+        task.forward(step0)
+        time.sleep(0.1)
+        task.forward(step1)
+        time.sleep(0.1)
+    task.backward()
+    assert task.ctx.pipeline == [0, 1, 0, 1]
+    task.renew()
+    assert task.ctx.total_step == 1
+
+
+@pytest.mark.unittest
+def test_async_yield_pipeline():
+
+    def step0(ctx):
+        ctx.setdefault("pipeline", [])
+        time.sleep(0.1)
+        ctx.pipeline.append(0)
+        yield
+        ctx.pipeline.append(0)
+
+    def step1(ctx):
+        time.sleep(0.2)
+        ctx.pipeline.append(1)
+
+    task = Task(async_mode=True)
+    task.forward(step0)
+    task.forward(step1)
+    time.sleep(0.3)
+    task.backward().sync()
+    assert task.ctx.pipeline == [0, 1, 0]
+    assert len(task._backward_stack) == 0
+
+
+def parallel_main():
+    task = Task()
+    sync_count = 0
+
+    def on_sync_parallel_ctx(ctx):
+        nonlocal sync_count
+        assert isinstance(ctx, Context)
+        sync_count += 1
+
+    task.on("sync_parallel_ctx", on_sync_parallel_ctx)
+    task.use(lambda _: time.sleep(0.2 + random.random() / 10))
+    task.run(max_step=10)
+    assert sync_count > 0
+
+
+def parallel_main_eager():
+    task = Task()
+    sync_count = 0
+
+    def on_sync_parallel_ctx(ctx):
+        nonlocal sync_count
+        assert isinstance(ctx, Context)
+        sync_count += 1
+
+    task.on("sync_parallel_ctx", on_sync_parallel_ctx)
+    for _ in range(10):
+        task.forward(lambda _: time.sleep(0.2 + random.random() / 10))
+        task.renew()
+    assert sync_count > 0
+
+
+@pytest.mark.unittest
+def test_parallel_pipeline():
+    Parallel.runner(n_parallel_workers=2)(parallel_main_eager)
+    Parallel.runner(n_parallel_workers=2)(parallel_main)
+
+
+@pytest.mark.unittest
+def test_copy_task():
+    t1 = Task(async_mode=True, n_async_workers=1)
+    t2 = copy.copy(t1)
+    assert t2.async_mode
+    assert t1 is not t2
+
+
+def attach_mode_main_task():
+    with Task() as task:
+        task.use(lambda _: time.sleep(0.1))
+        task.run(max_step=10)
+
+
+def attach_mode_attach_task():
+    ctx = None
+
+    def attach_callback(new_ctx):
+        nonlocal ctx
+        ctx = new_ctx
+
+    with Task(attach_callback=attach_callback) as task:
+        task.use(lambda _: time.sleep(0.1))
+        task.run(max_step=10)
+    assert ctx is not None
+
+
+def attach_mode_main(job):
+    if job == "run_task":
+        Parallel.runner(
+            n_parallel_workers=2, protocol="tcp", address="127.0.0.1", ports=[50501, 50502]
+        )(attach_mode_main_task)
+    elif job == "run_attach_task":
+        time.sleep(0.3)
+        try:
+            Parallel.runner(
+                n_parallel_workers=1,
+                protocol="tcp",
+                address="127.0.0.1",
+                ports=[50503],
+                attach_to=["tcp://127.0.0.1:50501", "tcp://127.0.0.1:50502"]
+            )(attach_mode_attach_task)
+        finally:
+            del SingletonMetaclass.instances[Parallel]
+    else:
+        raise Exception("Unknown task: {}".format(job))
+
+
+@pytest.mark.unittest
+def test_attach_mode():
+    with WorkerPool(n_jobs=2, daemon=False, start_method="spawn") as pool:
+        pool.map(attach_mode_main, ["run_task", "run_attach_task"])
+
+
+@pytest.mark.unittest
+def test_label():
+    task = Task()
+    result = {}
+    task.use(lambda _: result.setdefault("not_me", True), filter_labels=["async"])
+    task.use(lambda _: result.setdefault("has_me", True))
+    task.run(max_step=1)
+
+    assert "not_me" not in result
+    assert "has_me" in result
+
+
+def sync_parallel_ctx_main():
+    with Task() as task:
+        task.use(lambda _: time.sleep(1))
+        if task.router.node_id == 0:  # Fast
+            task.run(max_step=2)
+        else:  # Slow
+            task.run(max_step=10)
+        assert task.parallel_ctx
+        assert task.ctx.finish
+        assert task.ctx.total_step < 9
+
+
+@pytest.mark.unittest
+def test_sync_parallel_ctx():
+    Parallel.runner(n_parallel_workers=2)(sync_parallel_ctx_main)
diff --git a/ding/framework/tests/test_wrapper.py b/ding/framework/tests/test_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..34fc84fba3699aa24d50a4aad4cf52458141517f
--- /dev/null
+++ b/ding/framework/tests/test_wrapper.py
@@ -0,0 +1,66 @@
+# Scenarios covered below:
+#   1. Wrapper in `use` (lazy) mode
+#   2. Wrapper in `forward` (eager) mode
+#   3. Wrapper in wrapper
+
+import pytest
+from ding.framework.task import Task
+from ding.framework.wrapper import StepTimer
+
+
+@pytest.mark.unittest
+def test_step_timer():
+
+    def step1(_):
+        1
+
+    def step2(_):
+        2
+
+    def step3(_):
+        3
+
+    def step4(_):
+        4
+
+    # Lazy mode (with use statement)
+    step_timer = StepTimer()
+    task = Task()
+    task.use_step_wrapper(step_timer)
+    task.use(step1)
+    task.use(step2)
+    task.use(task.sequence(step3, step4))
+    task.run(3)
+
+    assert len(step_timer.records) == 5
+    for records in step_timer.records.values():
+        assert len(records) == 3
+
+    # Eager mode (with forward statement)
+    step_timer = StepTimer()
+    task = Task()
+    task.use_step_wrapper(step_timer)
+    for _ in range(3):
+        task.forward(step1)  # Step 1
+        task.forward(step2)  # Step 2
+        task.renew()
+
+    assert len(step_timer.records) == 2
+    for records in step_timer.records.values():
+        assert len(records) == 3
+
+    # Wrapper in wrapper
+    step_timer1 = StepTimer()
+    step_timer2 = StepTimer()
+    task = Task()
+    task.use_step_wrapper(step_timer1)
+    task.use_step_wrapper(step_timer2)
+    task.use(step1)
+    task.use(step2)
+    task.run(3)
+
+    assert len(step_timer1.records) == 2
+    assert len(step_timer2.records) == 2
+    for records in step_timer1.records.values():
+        assert len(records) == 3
+    for records in step_timer2.records.values():
+        assert len(records) == 3
diff --git a/ding/framework/wrapper/__init__.py b/ding/framework/wrapper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4883440435c131c4e7bf65eab526e1cf36edc49b
--- /dev/null
+++ b/ding/framework/wrapper/__init__.py
@@ -0,0 +1 @@
+from .step_timer import StepTimer
diff --git a/ding/framework/wrapper/step_timer.py b/ding/framework/wrapper/step_timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab504ba3a7c1bf20f32e7dfbad5b3b64c830fa70
--- /dev/null
+++ b/ding/framework/wrapper/step_timer.py
@@ -0,0 +1,48 @@
+from collections import deque, defaultdict
+from types import GeneratorType
+from typing import Callable
+from rich import print
+import numpy as np
+import time
+
+
+class StepTimer:
+
+    def __init__(self, print_per_step: int = 1, smooth_window: int = 10) -> None:
+        self.print_per_step = print_per_step
+        self.records = defaultdict(lambda: deque(maxlen=print_per_step * smooth_window))
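+
+    # Note: records are keyed by id(fn), so each wrapped step keeps its own deque;
+    # maxlen = print_per_step * smooth_window bounds the moving window used for
+    # the printed mean below.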
+
+    def __call__(self, fn: Callable) -> Callable:
+        step_name = getattr(fn, "__name__", type(fn).__name__)
+        step_id = id(fn)
+
+        def executor(ctx):
+            start_time = time.time()
+            time_cost = 0
+            g = fn(ctx)
+            if isinstance(g, GeneratorType):
+                try:
+                    next(g)
+                except StopIteration:
+                    pass
+                time_cost = time.time() - start_time
+                yield
+                start_time = time.time()
+                try:
+                    next(g)
+                except StopIteration:
+                    pass
+                time_cost += time.time() - start_time
+            else:
+                time_cost = time.time() - start_time
+            self.records[step_id].append(time_cost * 1000)
+            if ctx.total_step % self.print_per_step == 0:
+                print(
+                    "[Step Timer] {}: Cost: {:.2f}ms, Mean: {:.2f}ms".format(
+                        step_name, time_cost * 1000, np.mean(self.records[step_id])
+                    )
+                )
+
+        executor.__name__ = "StepTimer<{}>".format(step_name)
+
+        return executor
diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py
index 286ecd1f910e5735637e2984fe6a755657dcbb70..5ff3b04ae5f39b9cd0c182df8cbf3963c8fe5548 100644
--- a/ding/utils/__init__.py
+++ b/ding/utils/__init__.py
@@ -24,6 +24,7 @@ from .system_helper import get_ip, get_pid, get_task_uid, PropagatingThread, fin
 from .time_helper import build_time_helper, EasyTimer, WatchDog
 from .type_helper import SequenceType
 from .scheduler_helper import Scheduler
+from .profiler_helper import Profiler, register_profiler
 
 if ding.enable_linklink:
     from .linklink_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
diff --git a/ding/utils/profiler_helper.py b/ding/utils/profiler_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..20e1024629fa7fdf535c32c6518759ab55f9d3a6
--- /dev/null
+++ b/ding/utils/profiler_helper.py
@@ -0,0 +1,41 @@
+import atexit
+import pstats
+import io
+import cProfile
+import os
+
+
+def register_profiler(write_profile, pr, folder_path):
+    atexit.register(write_profile, pr, folder_path)
+
+
+class Profiler:
+
+    def __init__(self):
+        self.pr = cProfile.Profile()
+
+    def mkdir(self, directory):
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+
+    def write_profile(self, pr, folder_path):
+        pr.disable()
+        s_tottime = io.StringIO()
+        s_cumtime = io.StringIO()
+
+        ps = pstats.Stats(pr, stream=s_tottime).sort_stats('tottime')
+        ps.print_stats()
+        with open(folder_path + "/profile_tottime.txt", 'w+') as f:
+            f.write(s_tottime.getvalue())
+
+        ps = pstats.Stats(pr, stream=s_cumtime).sort_stats('cumtime')
+        ps.print_stats()
+        with open(folder_path + "/profile_cumtime.txt", 'w+') as f:
+            f.write(s_cumtime.getvalue())
+
+        pr.dump_stats(folder_path + "/profile.prof")
+
+    def profile(self, folder_path="./tmp"):
+        self.mkdir(folder_path)
+        self.pr.enable()
+        register_profiler(self.write_profile, self.pr, folder_path)
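+
+# Usage sketch (this is how ding/entry/cli.py wires it up for --profile):
+#   profiler = Profiler()
+#   profiler.profile("./tmp")  # enables cProfile; the stats files are written at exit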
diff --git a/ding/utils/tests/test_profiler_helper.py b/ding/utils/tests/test_profiler_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd1b075cb2cd6b70fbbdd1dce013e0fcb6fcdaff
--- /dev/null
+++ b/ding/utils/tests/test_profiler_helper.py
@@ -0,0 +1,42 @@
+import pytest
+from unittest.mock import patch
+import pathlib as pl
+import os
+import shutil
+
+from ding.utils.profiler_helper import Profiler, register_profiler
+
+
+@pytest.mark.unittest
+class TestProfilerModule:
+
+    def assertIsFile(self, path):
+        if not pl.Path(path).resolve().is_file():
+            raise AssertionError("File does not exist: %s" % str(path))
+
+    def test(self):
+        profiler = Profiler()
+
+        def register_mock(write_profile, pr, folder_path):
+            profiler.write_profile(pr, folder_path)
+
+        def clean_up(dir):
+            if os.path.exists(dir):
+                shutil.rmtree(dir)
+
+        dir = "./tmp_test/"
+        clean_up(dir)
+
+        with patch('ding.utils.profiler_helper.register_profiler', register_mock):
+            profiler.profile(dir)
+            file_path = os.path.join(dir, "profile_tottime.txt")
+            self.assertIsFile(file_path)
+            file_path = os.path.join(dir, "profile_cumtime.txt")
+            self.assertIsFile(file_path)
+            file_path = os.path.join(dir, "profile.prof")
+            self.assertIsFile(file_path)
+
+        clean_up(dir)
diff --git a/ding/worker/learner/base_learner.py b/ding/worker/learner/base_learner.py
index 20b3b7afcf5c78edac9312181452ed6ad63ed89a..13f0fcb9d3c14b3992ba7839fb8e6c553b2b85c1 100644
--- a/ding/worker/learner/base_learner.py
+++ b/ding/worker/learner/base_learner.py
@@ -238,6 +238,8 @@ class BaseLearner(object):
             self.call_hook('after_iter')
             self._last_iter.add(1)
 
+        return log_vars
+
     @auto_checkpoint
     def start(self) -> None:
         """
diff --git a/setup.py b/setup.py
index f5745074bc60fff6b76f5c59e752de722269fd66..ad29a48297087486c3a8d3edadea0e8779822378 100755
--- a/setup.py
+++ b/setup.py
@@ -73,6 +73,9 @@ setup(
         'scipy',
         'trueskill',
         'h5py',
+        'rich',
+        'mpire',
+        'pynng'
     ],
     extras_require={
         'test': [
@@ -135,7 +138,6 @@ setup(
         # 'gym_soccer_env': [
         #     'gym-soccer @ git+https://github.com/LikeJulia/gym-soccer@dev-install-packages#egg=gym-soccer',
         # ],
-
         'sc2_env': [
             'absl-py>=0.1.0',
             'future',