diff --git a/.gitignore b/.gitignore
index 562404d1cbdfdffd2a33a785e21b24be48f913f3..413a0b66f3d51687deec36d8aa395f351fa457ad 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1419,3 +1419,4 @@ formatted_*
 eval_config.py
 collect_demo_data_config.py
 default*
+!ding/**/*.py
diff --git a/ding/framework/context.py b/ding/framework/context.py
index 3b06ce252f29e9aae0aba353ba0bf2e01dbe4a6e..2683a6af4260452c6639b4cb13641018ce7755cb 100644
--- a/ding/framework/context.py
+++ b/ding/framework/context.py
@@ -22,7 +22,8 @@ class Context(dict):
         """
         ctx = Context()
         for key in self._kept_keys:
-            ctx[key] = self[key]
+            if key in self:
+                ctx[key] = self[key]
         return ctx
 
     def keep(self, *keys: str) -> None:
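The `renew()` guard above changes one observable behavior: a key registered with `keep()` but never written no longer raises `KeyError` when the context is renewed. A minimal sketch of the intended semantics, using only the `Context` API visible in this hunk; the key name `env_step` is hypothetical:

```python
from ding.framework import Context

ctx = Context()
ctx.keep("env_step")  # mark the key as persistent across renew()
assert "env_step" not in ctx.renew()  # before this patch: KeyError, since the key was never set

ctx["env_step"] = 100
assert ctx.renew()["env_step"] == 100  # keys that do exist are still carried over
```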
diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py
index 66243ff452f079dd69f6b44377669f2be5929fc7..0bd190cfc1b254f0bed9f9ede15dd066f4bd3b02 100644
--- a/ding/framework/parallel.py
+++ b/ding/framework/parallel.py
@@ -1,7 +1,6 @@
 import atexit
 import os
 import random
-import threading
 import time
 from mpire.pool import WorkerPool
 import pynng
@@ -27,10 +26,10 @@ class Parallel(metaclass=SingletonMetaclass):
         self._sock: Socket = None
         self._rpc = {}
         self._bind_addr = None
-        self._lock = threading.Lock()
         self.is_active = False
         self.attach_to = None
         self.finished = False
+        self.node_id = None
 
     def run(self, node_id: int, listen_to: str, attach_to: List[str] = None) -> None:
         self.node_id = node_id
@@ -187,6 +186,10 @@ now there are {} ports and {} workers".format(len(ports), n_workers)
     def register_rpc(self, fn_name: str, fn: Callable) -> None:
         self._rpc[fn_name] = fn
 
+    def unregister_rpc(self, fn_name: str) -> None:
+        if fn_name in self._rpc:
+            del self._rpc[fn_name]
+
     def send_rpc(self, func_name: str, *args, **kwargs) -> None:
         if self.is_active:
             payload = {"f": func_name, "a": args, "k": kwargs}
@@ -198,7 +201,8 @@ now there are {} ports and {} workers".format(len(ports), n_workers)
         except Exception as e:
             logging.warning("Error when unpacking message on node {}, msg: {}".format(self._bind_addr, e))
         if payload["f"] in self._rpc:
-            self._rpc[payload["f"]](*payload["a"], **payload["k"])
+            fn = self._rpc[payload["f"]]
+            fn(*payload["a"], **payload["k"])
         else:
             logging.warning("There was no function named {} in rpc table".format(payload["f"]))
 
@@ -224,6 +228,7 @@ now there are {} ports and {} workers".format(len(ports), n_workers)
     def stop(self):
         logging.info("Stopping parallel worker on address: {}".format(self._bind_addr))
         self.finished = True
+        self._rpc.clear()
         time.sleep(0.03)
         if self._sock:
             self._sock.close()
diff --git a/ding/framework/task.py b/ding/framework/task.py
index 6362f01cfcdf7c55440c495d5914833e8d382920..92bef9ad10974a7afd2adb775b03a60aaf44037c 100644
--- a/ding/framework/task.py
+++ b/ding/framework/task.py
@@ -225,8 +225,14 @@ class Task:
             self.stop()
 
     def stop(self) -> None:
+        self.emit("exit")
         if self._thread_pool:
             self._thread_pool.shutdown()
+        # The middleware and listeners may contain methods that reference the task;
+        # if we do not clear them when the task exits, gc may never reclaim the task object.
+        self.middleware.clear()
+        self.event_listeners.clear()
+        self.once_listeners.clear()
 
     def sync(self) -> 'Task':
         if self._loop:
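Why `stop()` now clears the middleware and listener lists: middleware registered with `use()` typically closes over the task itself, so the task holds a reference chain back to itself and, as the comment above notes, may never be reclaimed. A sketch of the effect, assuming a `Task` can be constructed and stopped outside a pipeline (as the tests below do) and that nothing else holds a reference:

```python
import weakref
from ding.framework import Task

task = Task()
task.use(lambda _: task.ctx)  # the middleware closure references `task` itself
ref = weakref.ref(task)
task.stop()  # emits "exit", then drops middleware and listeners, breaking the cycle
del task
assert ref() is None  # the Task object is now freed immediately
```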
task.on("sync_parallel_ctx", on_sync_parallel_ctx) - task.use(lambda _: time.sleep(0.2 + random.random() / 10)) - task.run(max_step=10) - assert sync_count > 0 + with Task() as task: + task.on("sync_parallel_ctx", on_sync_parallel_ctx) + task.use(lambda _: time.sleep(0.2 + random.random() / 10)) + task.run(max_step=10) + assert sync_count > 0 def parallel_main_eager(): - task = Task() sync_count = 0 def on_sync_parallel_ctx(ctx): @@ -130,11 +129,12 @@ def parallel_main_eager(): assert isinstance(ctx, Context) sync_count += 1 - task.on("sync_parallel_ctx", on_sync_parallel_ctx) - for _ in range(10): - task.forward(lambda _: time.sleep(0.2 + random.random() / 10)) - task.renew() - assert sync_count > 0 + with Task() as task: + task.on("sync_parallel_ctx", on_sync_parallel_ctx) + for _ in range(10): + task.forward(lambda _: time.sleep(0.2 + random.random() / 10)) + task.renew() + assert sync_count > 0 @pytest.mark.unittest @@ -143,14 +143,6 @@ def test_parallel_pipeline(): Parallel.runner(n_parallel_workers=2)(parallel_main) -@pytest.mark.unittest -def test_copy_task(): - t1 = Task(async_mode=True, n_async_workers=1) - t2 = copy.copy(t1) - assert t2.async_mode - assert t1 is not t2 - - def attach_mode_main_task(): with Task() as task: task.use(lambda _: time.sleep(0.1)) @@ -199,14 +191,14 @@ def test_attach_mode(): @pytest.mark.unittest def test_label(): - task = Task() - result = {} - task.use(lambda _: result.setdefault("not_me", True), filter_labels=["async"]) - task.use(lambda _: result.setdefault("has_me", True)) - task.run(max_step=1) - - assert "not_me" not in result - assert "has_me" in result + with Task() as task: + result = {} + task.use(lambda _: result.setdefault("not_me", True), filter_labels=["async"]) + task.use(lambda _: result.setdefault("has_me", True)) + task.run(max_step=1) + + assert "not_me" not in result + assert "has_me" in result def sync_parallel_ctx_main(): diff --git a/ding/framework/tests/test_wrapper.py b/ding/framework/tests/test_wrapper.py index 34fc84fba3699aa24d50a4aad4cf52458141517f..ad26a18321a384023f9b2782d0d95fe0f663bca7 100644 --- a/ding/framework/tests/test_wrapper.py +++ b/ding/framework/tests/test_wrapper.py @@ -24,12 +24,12 @@ def test_step_timer(): # Lazy mode (with use statment) step_timer = StepTimer() - task = Task() - task.use_step_wrapper(step_timer) - task.use(step1) - task.use(step2) - task.use(task.sequence(step3, step4)) - task.run(3) + with Task() as task: + task.use_step_wrapper(step_timer) + task.use(step1) + task.use(step2) + task.use(task.sequence(step3, step4)) + task.run(3) assert len(step_timer.records) == 5 for records in step_timer.records.values(): @@ -37,12 +37,12 @@ def test_step_timer(): # Eager mode (with forward statment) step_timer = StepTimer() - task = Task() - task.use_step_wrapper(step_timer) - for _ in range(3): - task.forward(step1) # Step 1 - task.forward(step2) # Step 2 - task.renew() + with Task() as task: + task.use_step_wrapper(step_timer) + for _ in range(3): + task.forward(step1) # Step 1 + task.forward(step2) # Step 2 + task.renew() assert len(step_timer.records) == 2 for records in step_timer.records.values(): @@ -51,15 +51,19 @@ def test_step_timer(): # Wrapper in wrapper step_timer1 = StepTimer() step_timer2 = StepTimer() - task = Task() - task.use_step_wrapper(step_timer1) - task.use_step_wrapper(step_timer2) - task.use(step1) - task.use(step2) - task.run(3) + with Task() as task: + task.use_step_wrapper(step_timer1) + task.use_step_wrapper(step_timer2) + task.use(step1) + task.use(step2) + task.run(3) - 
diff --git a/ding/framework/tests/test_wrapper.py b/ding/framework/tests/test_wrapper.py
index 34fc84fba3699aa24d50a4aad4cf52458141517f..ad26a18321a384023f9b2782d0d95fe0f663bca7 100644
--- a/ding/framework/tests/test_wrapper.py
+++ b/ding/framework/tests/test_wrapper.py
@@ -24,12 +24,12 @@ def test_step_timer():
 
     # Lazy mode (with use statment)
     step_timer = StepTimer()
-    task = Task()
-    task.use_step_wrapper(step_timer)
-    task.use(step1)
-    task.use(step2)
-    task.use(task.sequence(step3, step4))
-    task.run(3)
+    with Task() as task:
+        task.use_step_wrapper(step_timer)
+        task.use(step1)
+        task.use(step2)
+        task.use(task.sequence(step3, step4))
+        task.run(3)
 
     assert len(step_timer.records) == 5
     for records in step_timer.records.values():
@@ -37,12 +37,12 @@
 
     # Eager mode (with forward statment)
     step_timer = StepTimer()
-    task = Task()
-    task.use_step_wrapper(step_timer)
-    for _ in range(3):
-        task.forward(step1)  # Step 1
-        task.forward(step2)  # Step 2
-        task.renew()
+    with Task() as task:
+        task.use_step_wrapper(step_timer)
+        for _ in range(3):
+            task.forward(step1)  # Step 1
+            task.forward(step2)  # Step 2
+            task.renew()
 
     assert len(step_timer.records) == 2
     for records in step_timer.records.values():
@@ -51,15 +51,19 @@
     # Wrapper in wrapper
     step_timer1 = StepTimer()
    step_timer2 = StepTimer()
-    task = Task()
-    task.use_step_wrapper(step_timer1)
-    task.use_step_wrapper(step_timer2)
-    task.use(step1)
-    task.use(step2)
-    task.run(3)
+    with Task() as task:
+        task.use_step_wrapper(step_timer1)
+        task.use_step_wrapper(step_timer2)
+        task.use(step1)
+        task.use(step2)
+        task.run(3)
 
-    assert len(step_timer1.records) == 2
-    assert len(step_timer2.records) == 2
+    try:
+        assert len(step_timer1.records) == 2
+        assert len(step_timer2.records) == 2
+    except AssertionError:
+        print("ExceptionStepTimer", step_timer2.records)
+        raise
     for records in step_timer1.records.values():
         assert len(records) == 3
     for records in step_timer2.records.values():
diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py
index 5ff3b04ae5f39b9cd0c182df8cbf3963c8fe5548..1b5b0240834dd726e9869f5f0d1ae3b64872a54e 100644
--- a/ding/utils/__init__.py
+++ b/ding/utils/__init__.py
@@ -25,6 +25,7 @@ from .time_helper import build_time_helper, EasyTimer, WatchDog
 from .type_helper import SequenceType
 from .scheduler_helper import Scheduler
 from .profiler_helper import Profiler, register_profiler
+from .log_writer_helper import DistributedWriter
 
 if ding.enable_linklink:
     from .linklink_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
diff --git a/ding/utils/log_helper.py b/ding/utils/log_helper.py
index 41e76e4d5d356c3ba2e29cb53efc7b5c09d4a477..11ae1ce91e0a4a08802acf75a37b6a59e5c9dcd4 100644
--- a/ding/utils/log_helper.py
+++ b/ding/utils/log_helper.py
@@ -4,7 +4,7 @@ import os
 import numpy as np
 import yaml
 from tabulate import tabulate
-from tensorboardX import SummaryWriter
+from .log_writer_helper import DistributedWriter
 from typing import Optional, Tuple, Union, Dict, Any
 
 
@@ -32,7 +32,7 @@ def build_logger(
         name = 'default'
     logger = LoggerFactory.create_logger(path, name=name) if need_text else None
     tb_name = name + '_tb_logger'
-    tb_logger = SummaryWriter(os.path.join(path, tb_name)) if need_tb else None
+    tb_logger = DistributedWriter(os.path.join(path, tb_name)) if need_tb else None
     return logger, tb_logger
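Because `DistributedWriter` subclasses `SummaryWriter`, swapping it into `build_logger` is transparent for single-process callers. A sketch of unchanged usage; the keyword names follow the hunk above, and the path and logger name are hypothetical values:

```python
from ding.utils.log_helper import build_logger

logger, tb_logger = build_logger(path="./log", name="train")
tb_logger.add_scalar("reward", 1.0, 0)  # behaves like a plain SummaryWriter here
tb_logger.close()
# In a parallel pipeline, one extra call opts in to single-writer mode:
# tb_logger.plugin(task, is_writer=("node.0" in task.labels))
```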
+ """ + + def __init__(self, *args, **kwargs): + self._default_writer_to_disk = kwargs.get("write_to_disk") if "write_to_disk" in kwargs else True + # We need to write data to files lazily, so we should not use file writer in __init__, + # On the contrary, we will initialize the file writer when the user calls the + # add_* function for the first time + kwargs["write_to_disk"] = False + super().__init__(*args, **kwargs) + self._in_parallel = False + self._task = None + self._is_writer = False + self._lazy_initialized = False + + def plugin(self, task: "Task", is_writer: bool = False) -> "DistributedWriter": + """ + Overview: + Plugin ``task``, so when using this writer in the task pipeline, it will automatically send requests\ + to the main writer instead of writing it to the disk. So we can collect data from multiple processes\ + and write them into one file. + Examples: + >>> DistributedWriter().plugin(task, is_writer=("node.0" in task.labels)) + """ + if task.router.is_active: + self._in_parallel = True + self._task = task + self._is_writer = is_writer + if is_writer: + self.initialize() + self._lazy_initialized = True + task.router.register_rpc("distributed_writer", self._on_distributed_writer) + task.once("exit", lambda: self.close()) + return self + + def _on_distributed_writer(self, fn_name: str, *args, **kwargs): + if self._is_writer: + getattr(self, fn_name)(*args, **kwargs) + + def initialize(self): + self.close() + self._write_to_disk = self._default_writer_to_disk + self._get_file_writer() + self._lazy_initialized = True + + def __del__(self): + self.close() + + +def enable_parallel(fn_name, fn): + + def _parallel_fn(self: DistributedWriter, *args, **kwargs): + if not self._lazy_initialized: + self.initialize() + if self._in_parallel and not self._is_writer: + self._task.router.send_rpc("distributed_writer", fn_name, *args, **kwargs) + else: + fn(self, *args, **kwargs) + + return _parallel_fn + + +ready_to_parallel_fns = [ + 'add_audio', + 'add_custom_scalars', + 'add_custom_scalars_marginchart', + 'add_custom_scalars_multilinechart', + 'add_embedding', + 'add_figure', + 'add_graph', + 'add_graph_deprecated', + 'add_histogram', + 'add_histogram_raw', + 'add_hparams', + 'add_image', + 'add_image_with_boxes', + 'add_images', + 'add_mesh', + 'add_onnx_graph', + 'add_openvino_graph', + 'add_pr_curve', + 'add_pr_curve_raw', + 'add_scalar', + 'add_scalars', + 'add_text', + 'add_video', +] +for fn_name in ready_to_parallel_fns: + setattr(DistributedWriter, fn_name, enable_parallel(fn_name, getattr(DistributedWriter, fn_name))) diff --git a/ding/utils/tests/test_log_helper.py b/ding/utils/tests/test_log_helper.py index 116890d7af25bb6c5138b345a2eeecd9f47380f3..439015ac0c610b69e33bc57522d54828dffe8b7c 100644 --- a/ding/utils/tests/test_log_helper.py +++ b/ding/utils/tests/test_log_helper.py @@ -1,9 +1,6 @@ import random -from collections import deque -import numpy as np import pytest from easydict import EasyDict -import logging from ding.utils.log_helper import build_logger, pretty_print from ding.utils.file_helper import remove_file diff --git a/ding/utils/tests/test_log_writer_helper.py b/ding/utils/tests/test_log_writer_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc5c8a448b306d051d1988df75316b4dda21069 --- /dev/null +++ b/ding/utils/tests/test_log_writer_helper.py @@ -0,0 +1,39 @@ +import pytest +import time +import tempfile +import shutil +import os +from os import path +from ding.framework import Parallel +from ding.framework.task import Task 
diff --git a/ding/utils/tests/test_log_helper.py b/ding/utils/tests/test_log_helper.py
index 116890d7af25bb6c5138b345a2eeecd9f47380f3..439015ac0c610b69e33bc57522d54828dffe8b7c 100644
--- a/ding/utils/tests/test_log_helper.py
+++ b/ding/utils/tests/test_log_helper.py
@@ -1,9 +1,6 @@
 import random
-from collections import deque
-import numpy as np
 import pytest
 from easydict import EasyDict
-import logging
 from ding.utils.log_helper import build_logger, pretty_print
 from ding.utils.file_helper import remove_file
 
diff --git a/ding/utils/tests/test_log_writer_helper.py b/ding/utils/tests/test_log_writer_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfc5c8a448b306d051d1988df75316b4dda21069
--- /dev/null
+++ b/ding/utils/tests/test_log_writer_helper.py
@@ -0,0 +1,39 @@
+import pytest
+import time
+import tempfile
+import shutil
+import os
+from os import path
+from ding.framework import Parallel
+from ding.framework.task import Task
+from ding.utils import DistributedWriter
+
+
+def main_distributed_writer(tempdir):
+    with Task() as task:
+        time.sleep(task.router.node_id)  # Sleep 0s and 1s, so any files the two nodes create would differ
+
+        tblogger = DistributedWriter(tempdir).plugin(task, is_writer=("node.0" in task.labels))
+
+        def _add_scalar(ctx):
+            n = 10
+            for i in range(n):
+                tblogger.add_scalar(str(task.router.node_id), task.router.node_id, ctx.total_step * n + i)
+
+        task.use(_add_scalar)
+        task.use(lambda _: time.sleep(0.2))
+        task.run(max_step=10)
+
+        # Keep node 0 (the writer) alive longer, so it can replay the calls forwarded by node 1
+        time.sleep(0.3 + (1 - task.router.node_id) * 2)
+
+
+@pytest.mark.unittest
+def test_distributed_writer():
+    tempdir = path.join(tempfile.gettempdir(), "tblogger")
+    try:
+        Parallel.runner(n_parallel_workers=2)(main_distributed_writer, tempdir)
+        assert path.exists(tempdir)
+        assert len(os.listdir(tempdir)) == 1
+    finally:
+        if path.exists(tempdir):
+            shutil.rmtree(tempdir)
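A note on what `test_distributed_writer` asserts: both workers call `add_scalar`, but only node 0 is plugged in as the writer, so exactly one event file may appear in `tempdir`; the staggered sleeps keep node 0 alive long enough to replay the RPCs forwarded by node 1. Assuming the marker setup used by the other tests in this repo, it can be run with `pytest ding/utils/tests/test_log_writer_helper.py -m unittest`.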