Unverified  Commit 92d973c1 authored by Xu Jingxin, committed by GitHub

feature(xjx): multiprocess tblogger, fix circular reference problem (#156)

* Fix recur reference in task and parallel, add distributed logger

* Update logger

* Clear ref list when exit task/parallel

* Put task in with statement

* Fix test

* Fix test

* Test is hard

* More comments
Parent 1040c9fc
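The two changes fit together roughly as follows: tasks are created in with blocks so that their middleware and listeners are cleared on exit, and a DistributedWriter is plugged into the task so that only one node writes tensorboard data to disk. A minimal sketch assembled from the docstrings and tests in this diff (the log directory, the logged scalar, and the step budget are illustrative, not part of the commit):

import time

from ding.framework import Parallel
from ding.framework.task import Task
from ding.utils import DistributedWriter


def main():
    # Exiting the with block calls task.stop(), which now clears middleware/listeners
    # and emits the "exit" event the writer listens to.
    with Task() as task:
        # Only the node labelled "node.0" writes to disk; the others forward writes over RPC.
        writer = DistributedWriter("./tb_log").plugin(task, is_writer=("node.0" in task.labels))
        task.use(lambda ctx: writer.add_scalar("demo/total_step", ctx.total_step, ctx.total_step))
        task.use(lambda _: time.sleep(0.1))
        task.run(max_step=10)


if __name__ == "__main__":
    Parallel.runner(n_parallel_workers=2)(main)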
@@ -1419,3 +1419,4 @@ formatted_*
 eval_config.py
 collect_demo_data_config.py
 default*
+!ding/**/*.py
@@ -22,7 +22,8 @@ class Context(dict):
         """
         ctx = Context()
         for key in self._kept_keys:
-            ctx[key] = self[key]
+            if key in self:
+                ctx[key] = self[key]
         return ctx

     def keep(self, *keys: str) -> None:
......
 import atexit
 import os
 import random
-import threading
 import time
 from mpire.pool import WorkerPool
 import pynng
@@ -27,10 +26,10 @@ class Parallel(metaclass=SingletonMetaclass):
         self._sock: Socket = None
         self._rpc = {}
         self._bind_addr = None
-        self._lock = threading.Lock()
         self.is_active = False
         self.attach_to = None
         self.finished = False
+        self.node_id = None

     def run(self, node_id: int, listen_to: str, attach_to: List[str] = None) -> None:
         self.node_id = node_id
@@ -187,6 +186,10 @@ now there are {} ports and {} workers".format(len(ports), n_workers)
     def register_rpc(self, fn_name: str, fn: Callable) -> None:
         self._rpc[fn_name] = fn

+    def unregister_rpc(self, fn_name: str) -> None:
+        if fn_name in self._rpc:
+            del self._rpc[fn_name]
+
     def send_rpc(self, func_name: str, *args, **kwargs) -> None:
         if self.is_active:
             payload = {"f": func_name, "a": args, "k": kwargs}
@@ -198,7 +201,8 @@ now there are {} ports and {} workers".format(len(ports), n_workers)
             except Exception as e:
                 logging.warning("Error when unpacking message on node {}, msg: {}".format(self._bind_addr, e))
             if payload["f"] in self._rpc:
-                self._rpc[payload["f"]](*payload["a"], **payload["k"])
+                fn = self._rpc[payload["f"]]
+                fn(*payload["a"], **payload["k"])
             else:
                 logging.warning("There was no function named {} in rpc table".format(payload["f"]))
@@ -224,6 +228,7 @@ now there are {} ports and {} workers".format(len(ports), n_workers)
     def stop(self):
         logging.info("Stopping parallel worker on address: {}".format(self._bind_addr))
         self.finished = True
+        self._rpc.clear()
         time.sleep(0.03)
         if self._sock:
             self._sock.close()
......
@@ -225,8 +225,14 @@ class Task:
         self.stop()

     def stop(self) -> None:
+        self.emit("exit")
         if self._thread_pool:
             self._thread_pool.shutdown()
+        # The middleware and listeners may hold methods that reference the task.
+        # If we do not clear them when the task exits, gc may never collect the task object.
+        self.middleware.clear()
+        self.event_listeners.clear()
+        self.once_listeners.clear()

     def sync(self) -> 'Task':
         if self._loop:
......
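For reference, the kind of cycle these clear() calls break looks roughly like this (a standalone sketch, not code from this repository; SimpleTask stands in for Task):

import weakref


class SimpleTask:
    # Stand-in for Task: it keeps a list of middleware callables.
    def __init__(self):
        self.middleware = []

    def stop(self):
        # Dropping the closures removes their reference back to the task,
        # so plain reference counting can free it without waiting for gc.
        self.middleware.clear()


def make_task():
    task = SimpleTask()
    # The closure captures `task`, and the task stores the closure: a reference cycle.
    task.middleware.append(lambda ctx: task.stop())
    return task


t = make_task()
ref = weakref.ref(t)
t.stop()   # break the cycle, as Task.stop now does
del t
assert ref() is None  # freed immediately; without stop() it would linger until gc.collect()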
@@ -16,14 +16,14 @@ def parallel_main():
     router.register_rpc("test_callback", test_callback)
     # Wait for nodes to bind
     time.sleep(0.7)
-    router.send_rpc("test_callback", "ping")
     for _ in range(30):
+        router.send_rpc("test_callback", "ping")
         if msg["ping"]:
             break
         time.sleep(0.03)
     assert msg["ping"]
+    # Avoid failing to receive messages from each other after exiting parallel
+    time.sleep(0.7)


 @pytest.mark.unittest
......
@@ -23,21 +23,21 @@ def test_serial_pipeline():
         ctx.pipeline.append(1)

     # Execute step1, step2 twice
-    task = Task()
+    with Task() as task:
         for _ in range(2):
             task.forward(step0)
             task.forward(step1)
         assert task.ctx.pipeline == [0, 1, 0, 1]
         # Renew and execute step1, step2
         task.renew()
         assert task.ctx.total_step == 1
         task.forward(step0)
         task.forward(step1)
         assert task.ctx.pipeline == [0, 1]
         # Test context inheritance
         task.renew()


 @pytest.mark.unittest
@@ -52,12 +52,12 @@ def test_serial_yield_pipeline():
     def step1(ctx):
         ctx.pipeline.append(1)

-    task = Task()
+    with Task() as task:
         task.forward(step0)
         task.forward(step1)
         task.backward()
         assert task.ctx.pipeline == [0, 1, 0]
         assert len(task._backward_stack) == 0


 @pytest.mark.unittest
@@ -71,16 +71,16 @@ def test_async_pipeline():
         ctx.pipeline.append(1)

     # Execute step1, step2 twice
-    task = Task(async_mode=True)
+    with Task(async_mode=True) as task:
         for _ in range(2):
             task.forward(step0)
             time.sleep(0.1)
             task.forward(step1)
             time.sleep(0.1)
         task.backward()
         assert task.ctx.pipeline == [0, 1, 0, 1]
         task.renew()
         assert task.ctx.total_step == 1


 @pytest.mark.unittest
@@ -97,17 +97,16 @@ def test_async_yield_pipeline():
         time.sleep(0.2)
         ctx.pipeline.append(1)

-    task = Task(async_mode=True)
+    with Task(async_mode=True) as task:
         task.forward(step0)
         task.forward(step1)
         time.sleep(0.3)
         task.backward().sync()
         assert task.ctx.pipeline == [0, 1, 0]
         assert len(task._backward_stack) == 0


 def parallel_main():
-    task = Task()
     sync_count = 0

     def on_sync_parallel_ctx(ctx):
@@ -115,14 +114,14 @@ def parallel_main():
         assert isinstance(ctx, Context)
         sync_count += 1

-    task.on("sync_parallel_ctx", on_sync_parallel_ctx)
-    task.use(lambda _: time.sleep(0.2 + random.random() / 10))
-    task.run(max_step=10)
-    assert sync_count > 0
+    with Task() as task:
+        task.on("sync_parallel_ctx", on_sync_parallel_ctx)
+        task.use(lambda _: time.sleep(0.2 + random.random() / 10))
+        task.run(max_step=10)
+        assert sync_count > 0


 def parallel_main_eager():
-    task = Task()
     sync_count = 0

     def on_sync_parallel_ctx(ctx):
@@ -130,11 +129,12 @@ def parallel_main_eager():
         assert isinstance(ctx, Context)
         sync_count += 1

-    task.on("sync_parallel_ctx", on_sync_parallel_ctx)
-    for _ in range(10):
-        task.forward(lambda _: time.sleep(0.2 + random.random() / 10))
-        task.renew()
-    assert sync_count > 0
+    with Task() as task:
+        task.on("sync_parallel_ctx", on_sync_parallel_ctx)
+        for _ in range(10):
+            task.forward(lambda _: time.sleep(0.2 + random.random() / 10))
+            task.renew()
+        assert sync_count > 0


 @pytest.mark.unittest
@@ -143,14 +143,6 @@ def test_parallel_pipeline():
     Parallel.runner(n_parallel_workers=2)(parallel_main)


-@pytest.mark.unittest
-def test_copy_task():
-    t1 = Task(async_mode=True, n_async_workers=1)
-    t2 = copy.copy(t1)
-    assert t2.async_mode
-    assert t1 is not t2
-
-
 def attach_mode_main_task():
     with Task() as task:
         task.use(lambda _: time.sleep(0.1))
@@ -199,14 +191,14 @@ def test_attach_mode():

 @pytest.mark.unittest
 def test_label():
-    task = Task()
+    with Task() as task:
         result = {}
         task.use(lambda _: result.setdefault("not_me", True), filter_labels=["async"])
         task.use(lambda _: result.setdefault("has_me", True))
         task.run(max_step=1)
         assert "not_me" not in result
         assert "has_me" in result


 def sync_parallel_ctx_main():
......
@@ -24,12 +24,12 @@ def test_step_timer():
     # Lazy mode (with use statement)
     step_timer = StepTimer()
-    task = Task()
+    with Task() as task:
         task.use_step_wrapper(step_timer)
         task.use(step1)
         task.use(step2)
         task.use(task.sequence(step3, step4))
         task.run(3)
     assert len(step_timer.records) == 5
     for records in step_timer.records.values():
@@ -37,12 +37,12 @@ def test_step_timer():
     # Eager mode (with forward statement)
     step_timer = StepTimer()
-    task = Task()
+    with Task() as task:
         task.use_step_wrapper(step_timer)
         for _ in range(3):
             task.forward(step1)  # Step 1
             task.forward(step2)  # Step 2
             task.renew()
     assert len(step_timer.records) == 2
     for records in step_timer.records.values():
@@ -51,15 +51,19 @@ def test_step_timer():
     # Wrapper in wrapper
     step_timer1 = StepTimer()
     step_timer2 = StepTimer()
-    task = Task()
+    with Task() as task:
         task.use_step_wrapper(step_timer1)
         task.use_step_wrapper(step_timer2)
         task.use(step1)
         task.use(step2)
         task.run(3)
-    assert len(step_timer1.records) == 2
-    assert len(step_timer2.records) == 2
+    try:
+        assert len(step_timer1.records) == 2
+        assert len(step_timer2.records) == 2
+    except:
+        print("ExceptionStepTimer", step_timer2.records)
+        raise Exception("StepTimer error")

     for records in step_timer1.records.values():
         assert len(records) == 3
     for records in step_timer2.records.values():
......
@@ -25,6 +25,7 @@ from .time_helper import build_time_helper, EasyTimer, WatchDog
 from .type_helper import SequenceType
 from .scheduler_helper import Scheduler
 from .profiler_helper import Profiler, register_profiler
+from .log_writer_helper import DistributedWriter

 if ding.enable_linklink:
     from .linklink_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
......
@@ -4,7 +4,7 @@ import os
 import numpy as np
 import yaml
 from tabulate import tabulate
-from tensorboardX import SummaryWriter
+from .log_writer_helper import DistributedWriter
 from typing import Optional, Tuple, Union, Dict, Any
@@ -32,7 +32,7 @@ def build_logger(
         name = 'default'
     logger = LoggerFactory.create_logger(path, name=name) if need_text else None
     tb_name = name + '_tb_logger'
-    tb_logger = SummaryWriter(os.path.join(path, tb_name)) if need_tb else None
+    tb_logger = DistributedWriter(os.path.join(path, tb_name)) if need_tb else None
     return logger, tb_logger
......
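Call sites of build_logger stay the same; the returned tb_logger is simply a DistributedWriter now. A small usage sketch (the path, name, and logged value are made up; the keyword names follow the parameters visible in the hunk above):

from ding.utils.log_helper import build_logger

# build_logger keeps its interface; only the tensorboard writer class changed.
logger, tb_logger = build_logger("./exp_log", name="demo", need_text=True, need_tb=True)
tb_logger.add_scalar("demo/value", 1.0, 0)  # a DistributedWriter behaves like SummaryWriter here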
from tensorboardX import SummaryWriter
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    # TYPE_CHECKING is always False at runtime, but type checkers such as mypy still evaluate this block.
    # Importing Task only under TYPE_CHECKING gives us type hints without creating a circular import.
    # A good answer on Stack Overflow:
    # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
from ding.framework import Task
class DistributedWriter(SummaryWriter):
"""
Overview:
A simple subclass of SummaryWriter that supports writing to one process in multi-process mode.
The best way is to use it in conjunction with the ``task`` to take advantage of the message \
and event components of the task (see ``writer.plugin``).
"""
def __init__(self, *args, **kwargs):
self._default_writer_to_disk = kwargs.get("write_to_disk") if "write_to_disk" in kwargs else True
        # We need to write data to files lazily, so we should not create the file writer in __init__;
        # instead, the file writer is initialized the first time the user calls an add_* method.
kwargs["write_to_disk"] = False
super().__init__(*args, **kwargs)
self._in_parallel = False
self._task = None
self._is_writer = False
self._lazy_initialized = False
def plugin(self, task: "Task", is_writer: bool = False) -> "DistributedWriter":
"""
Overview:
Plugin ``task``, so when using this writer in the task pipeline, it will automatically send requests\
to the main writer instead of writing it to the disk. So we can collect data from multiple processes\
and write them into one file.
Examples:
>>> DistributedWriter().plugin(task, is_writer=("node.0" in task.labels))
"""
if task.router.is_active:
self._in_parallel = True
self._task = task
self._is_writer = is_writer
if is_writer:
self.initialize()
self._lazy_initialized = True
task.router.register_rpc("distributed_writer", self._on_distributed_writer)
task.once("exit", lambda: self.close())
return self
def _on_distributed_writer(self, fn_name: str, *args, **kwargs):
if self._is_writer:
getattr(self, fn_name)(*args, **kwargs)
def initialize(self):
self.close()
self._write_to_disk = self._default_writer_to_disk
self._get_file_writer()
self._lazy_initialized = True
def __del__(self):
self.close()
def enable_parallel(fn_name, fn):
def _parallel_fn(self: DistributedWriter, *args, **kwargs):
if not self._lazy_initialized:
self.initialize()
if self._in_parallel and not self._is_writer:
self._task.router.send_rpc("distributed_writer", fn_name, *args, **kwargs)
else:
fn(self, *args, **kwargs)
return _parallel_fn
ready_to_parallel_fns = [
'add_audio',
'add_custom_scalars',
'add_custom_scalars_marginchart',
'add_custom_scalars_multilinechart',
'add_embedding',
'add_figure',
'add_graph',
'add_graph_deprecated',
'add_histogram',
'add_histogram_raw',
'add_hparams',
'add_image',
'add_image_with_boxes',
'add_images',
'add_mesh',
'add_onnx_graph',
'add_openvino_graph',
'add_pr_curve',
'add_pr_curve_raw',
'add_scalar',
'add_scalars',
'add_text',
'add_video',
]
for fn_name in ready_to_parallel_fns:
setattr(DistributedWriter, fn_name, enable_parallel(fn_name, getattr(DistributedWriter, fn_name)))
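The effect of the wrapping above, assuming `task` is a running Task in parallel mode (the log directory and the logged scalar are illustrative):

writer = DistributedWriter("./tb_log").plugin(task, is_writer=("node.0" in task.labels))
writer.add_scalar("loss", 0.5, 1)
# On a non-writer node the call above never touches the local file writer; it becomes
#     task.router.send_rpc("distributed_writer", "add_scalar", "loss", 0.5, 1)
# and only the node plugged in with is_writer=True runs the real SummaryWriter.add_scalar.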
import random
from collections import deque
import numpy as np
import pytest
from easydict import EasyDict
import logging
from ding.utils.log_helper import build_logger, pretty_print
from ding.utils.file_helper import remove_file
......
import pytest
import time
import tempfile
import shutil
import os
from os import path
from ding.framework import Parallel
from ding.framework.task import Task
from ding.utils import DistributedWriter
def main_distributed_writer(tempdir):
with Task() as task:
time.sleep(task.router.node_id * 1) # Sleep 0 and 1, write to different files
tblogger = DistributedWriter(tempdir).plugin(task, is_writer=("node.0" in task.labels))
def _add_scalar(ctx):
n = 10
for i in range(n):
tblogger.add_scalar(str(task.router.node_id), task.router.node_id, ctx.total_step * n + i)
task.use(_add_scalar)
task.use(lambda _: time.sleep(0.2))
task.run(max_step=10)
time.sleep(0.3 + (1 - task.router.node_id) * 2)
@pytest.mark.unittest
def test_distributed_writer():
tempdir = path.join(tempfile.gettempdir(), "tblogger")
try:
Parallel.runner(n_parallel_workers=2)(main_distributed_writer, tempdir)
assert path.exists(tempdir)
assert len(os.listdir(tempdir)) == 1
finally:
if path.exists(tempdir):
shutil.rmtree(tempdir)