未验证 提交 954d3100 编写于 作者: X Xu Jingxin 提交者: GitHub

feature(xjx): cli in new pipeline (#160)

* Add the `ditask` CLI entry point

* Import `cli_ditask` in `__init__`

* Use the current working directory as the default package path

* Fix code style

* Add a `--topology` option to ditask
上级 92d973c1
from .cli import cli from .cli import cli
from .cli_ditask import cli_ditask
from .serial_entry import serial_pipeline from .serial_entry import serial_pipeline
from .serial_entry_onpolicy import serial_pipeline_onpolicy from .serial_entry_onpolicy import serial_pipeline_onpolicy
from .serial_entry_offline import serial_pipeline_offline from .serial_entry_offline import serial_pipeline_offline
......
import click
import os
import sys
import importlib
import importlib.util
from click.core import Context, Option
from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
from ding.framework import Parallel
def print_version(ctx: Context, param: Option, value: bool) -> None:
    """
    Overview:
        Click option callback that prints the package title, version and author, then exits.
    Arguments:
        - ctx (:obj:`Context`): Click context; ``resilient_parsing`` is checked and ``exit`` is called on it.
        - param (:obj:`Option`): The option that triggered the callback (unused).
        - value (:obj:`bool`): Value of the flag; nothing is printed when it is falsy.
    """
    # Only act when the flag was actually given and click is not in resilient-parsing mode.
    should_print = value and not ctx.resilient_parsing
    if should_print:
        version_line = '{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__)
        author_line = 'Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__)
        for line in (version_line, author_line):
            click.echo(line)
        ctx.exit()
# Click context settings: accept both -h and --help as the help option.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
def _split_csv(value):
    """Split a comma-separated option string into stripped, non-empty items."""
    return [item.strip() for item in value.split(",") if item.strip()]


def _parse_int_list(value, param_hint):
    """Parse a comma-separated string of integers; bad input becomes a click usage error."""
    try:
        return [int(item) for item in _split_csv(value)]
    except ValueError:
        raise click.BadParameter(
            "expected comma-separated integers, got {!r}".format(value), param_hint=param_hint
        )


@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    '-v',
    '--version',
    is_flag=True,
    callback=print_version,
    expose_value=False,
    is_eager=True,
    help="Show package's version information."
)
@click.option('-p', '--package', type=str, help="Your code package path, could be a directory or a zip file.")
@click.option('--parallel-workers', type=int, default=1, help="Parallel worker number, default: 1")
@click.option(
    '--protocol',
    type=click.Choice(["tcp", "ipc"]),
    default="tcp",
    help="Network protocol in parallel mode, default: tcp"
)
@click.option(
    "--ports",
    type=str,
    default="50515",
    help="The port addresses that the tasks listen to, e.g. 50515,50516, default: 50515"
)
@click.option("--attach-to", type=str, help="The addresses to connect to.")
@click.option("--address", type=str, help="The address to listen to (without port).")
@click.option("--labels", type=str, help="Labels.")
@click.option("--node-ids", type=str, help="Candidate node ids.")
@click.option(
    "--topology",
    type=click.Choice(["alone", "mesh", "star"]),
    default="alone",
    help="Network topology, default: alone."
)
@click.option("-m", "--main", type=str, help="Main function of entry module.")
def cli_ditask(
    package: str, main: str, parallel_workers: int, protocol: str, ports: str, attach_to: str, address: str,
    labels: str, node_ids: str, topology: str
):
    # NOTE: documented with comments rather than a docstring on purpose, because click
    # renders a command's docstring as its --help text and that output is left unchanged.
    #
    # Overview:
    #     Entry of the ``ditask`` command. Resolves the user's entry function from
    #     ``--package`` / ``--main``, parses the comma-separated options and runs the
    #     function through ``Parallel.runner``.
    # Arguments:
    #     - package: Path appended to ``sys.path``; defaults to the current working directory.
    #     - main: Dotted path ``<module>.<function>``; when omitted, the module is named after
    #       the package path's basename (extension removed) and the function is ``main``.
    #     - parallel_workers, protocol, address, topology: Passed to ``Parallel.runner`` as is.
    #     - ports: Comma-separated ints; a single port is passed as an int, several as a list.
    #     - attach_to: Comma-separated addresses, passed as a list of stripped strings.
    #     - labels: Comma-separated labels, passed as a set of stripped strings.
    #     - node_ids: Comma-separated ints, passed as a list.
    # Raises:
    #     - click.BadParameter: If ``--main`` is not ``<module>.<function>`` or an integer
    #       list option contains a non-integer item.

    # Parse entry point
    if not package:
        package = os.getcwd()
    sys.path.append(package)
    if main is None:
        # abspath normalises inputs such as "pkg/" or "." whose plain basename would be
        # "" or ".", neither of which is an importable module name.
        mod_name, _ = os.path.splitext(os.path.basename(os.path.abspath(package)))
        func_name = "main"
    else:
        mod_name, sep, func_name = main.rpartition(".")
        if not (sep and mod_name and func_name):
            raise click.BadParameter(
                "expected '<module>.<function>', got {!r}".format(main), param_hint="--main"
            )
        # Also expose the root package directory so modules inside it can import siblings.
        root_mod_name = mod_name.split(".", 1)[0]
        sys.path.append(os.path.join(package, root_mod_name))
    mod = importlib.import_module(mod_name)
    main_func = getattr(mod, func_name)

    # Parse arguments
    port_list = _parse_int_list(ports, "--ports")
    # A single port is handed over as an int, several ports as a list.
    ports = port_list[0] if len(port_list) == 1 else port_list
    if attach_to:
        attach_to = _split_csv(attach_to)
    if labels:
        labels = set(_split_csv(labels))
    if node_ids:
        node_ids = _parse_int_list(node_ids, "--node-ids")

    Parallel.runner(
        n_parallel_workers=parallel_workers,
        ports=ports,
        protocol=protocol,
        topology=topology,
        attach_to=attach_to,
        address=address,
        labels=labels,
        node_ids=node_ids
    )(main_func)
...@@ -9,7 +9,7 @@ import logging ...@@ -9,7 +9,7 @@ import logging
import tempfile import tempfile
import socket import socket
from os import path from os import path
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union, Set
from threading import Thread from threading import Thread
from pynng.nng import Bus0, Socket from pynng.nng import Bus0, Socket
from ding.utils.design_helper import SingletonMetaclass from ding.utils.design_helper import SingletonMetaclass
...@@ -30,10 +30,18 @@ class Parallel(metaclass=SingletonMetaclass): ...@@ -30,10 +30,18 @@ class Parallel(metaclass=SingletonMetaclass):
self.attach_to = None self.attach_to = None
self.finished = False self.finished = False
self.node_id = None self.node_id = None
self.labels = set()
def run(self, node_id: int, listen_to: str, attach_to: List[str] = None) -> None: def run(
self,
node_id: int,
listen_to: str,
attach_to: Optional[List[str]] = None,
labels: Optional[Set[str]] = None
) -> None:
self.node_id = node_id self.node_id = node_id
self.attach_to = attach_to = attach_to or [] self.attach_to = attach_to = attach_to or []
self.labels = labels or set()
self._listener = Thread( self._listener = Thread(
target=self.listen, target=self.listen,
kwargs={ kwargs={
...@@ -52,7 +60,9 @@ class Parallel(metaclass=SingletonMetaclass): ...@@ -52,7 +60,9 @@ class Parallel(metaclass=SingletonMetaclass):
protocol: str = "ipc", protocol: str = "ipc",
address: Optional[str] = None, address: Optional[str] = None,
ports: Optional[List[int]] = None, ports: Optional[List[int]] = None,
topology: str = "mesh" topology: str = "mesh",
labels: Optional[Set[str]] = None,
node_ids: Optional[List[int]] = None
) -> Callable: ) -> Callable:
""" """
Overview: Overview:
...@@ -66,6 +76,9 @@ class Parallel(metaclass=SingletonMetaclass): ...@@ -66,6 +76,9 @@ class Parallel(metaclass=SingletonMetaclass):
- topology (:obj:`str`): Network topology, includes: - topology (:obj:`str`): Network topology, includes:
`mesh` (default): fully connected between each other; `mesh` (default): fully connected between each other;
`star`: only connect to the first node; `star`: only connect to the first node;
`alone`: do not connect to any node, except the node attached to;
- labels (:obj:`Optional[Set[str]]`): Labels.
- node_ids (:obj:`Optional[List[int]]`): Candidate node ids.
Returns: Returns:
- _runner (:obj:`Callable`): The wrapper function for main. - _runner (:obj:`Callable`): The wrapper function for main.
""" """
...@@ -91,21 +104,29 @@ class Parallel(metaclass=SingletonMetaclass): ...@@ -91,21 +104,29 @@ class Parallel(metaclass=SingletonMetaclass):
atexit.register(cleanup_nodes) atexit.register(cleanup_nodes)
def topology_network(node_id: int) -> List[str]: def topology_network(i: int) -> List[str]:
if topology == "mesh": if topology == "mesh":
return nodes[:node_id] + attach_to return nodes[:i] + attach_to
elif topology == "star": elif topology == "star":
return nodes[:min(1, node_id)] return nodes[:min(1, i)] + attach_to
elif topology == "alone":
return attach_to
else: else:
raise ValueError("Unknown topology: {}".format(topology)) raise ValueError("Unknown topology: {}".format(topology))
params_group = [] params_group = []
for node_id in range(n_parallel_workers): candidate_node_ids = node_ids or range(n_parallel_workers)
assert len(candidate_node_ids) == n_parallel_workers, \
"The number of workers must be the same as the number of node_ids, \
now there are {} workers and {} nodes"\
.format(n_parallel_workers, len(candidate_node_ids))
for i in range(n_parallel_workers):
runner_args = [] runner_args = []
runner_kwargs = { runner_kwargs = {
"node_id": node_id, "node_id": candidate_node_ids[i],
"listen_to": nodes[node_id], "listen_to": nodes[i],
"attach_to": topology_network(node_id) + attach_to "attach_to": topology_network(i) + attach_to,
"labels": labels
} }
params = [(runner_args, runner_kwargs), (main_process, args, kwargs)] params = [(runner_args, runner_kwargs), (main_process, args, kwargs)]
params_group.append(params) params_group.append(params)
...@@ -151,6 +172,8 @@ class Parallel(metaclass=SingletonMetaclass): ...@@ -151,6 +172,8 @@ class Parallel(metaclass=SingletonMetaclass):
elif protocol == "tcp": elif protocol == "tcp":
address = address or Parallel.get_ip() address = address or Parallel.get_ip()
ports = ports or range(50515, 50515 + n_workers) ports = ports or range(50515, 50515 + n_workers)
if isinstance(ports, int):
ports = range(ports, ports + n_workers)
assert len(ports) == n_workers, "The number of ports must be the same as the number of workers, \ assert len(ports) == n_workers, "The number of ports must be the same as the number of workers, \
now there are {} ports and {} workers".format(len(ports), n_workers) now there are {} ports and {} workers".format(len(ports), n_workers)
nodes = ["tcp://{}:{}".format(address, port) for port in ports] nodes = ["tcp://{}:{}".format(address, port) for port in ports]
......
...@@ -98,6 +98,8 @@ class Task: ...@@ -98,6 +98,8 @@ class Task:
if self.router.is_active: if self.router.is_active:
self.labels.add("distributed") self.labels.add("distributed")
self.labels.add("node.{}".format(self.router.node_id)) self.labels.add("node.{}".format(self.router.node_id))
for label in self.router.labels:
self.labels.add(label)
else: else:
self.labels.add("standalone") self.labels.add("standalone")
......
...@@ -158,7 +158,7 @@ setup( ...@@ -158,7 +158,7 @@ setup(
'kubernetes', 'kubernetes',
] ]
}, },
entry_points={'console_scripts': ['ding=ding.entry.cli:cli']}, entry_points={'console_scripts': ['ding=ding.entry.cli:cli', 'ditask=ding.entry.cli_ditask:cli_ditask']},
classifiers=[ classifiers=[
'Development Status :: 5 - Production/Stable', 'Development Status :: 5 - Production/Stable',
"Intended Audience :: Science/Research", "Intended Audience :: Science/Research",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册