未验证 提交 954d3100 编写于 作者: X Xu Jingxin 提交者: GitHub

feature(xjx): cli in new pipeline (#160)

* Cli ditask

* Import ditask in init

* Add current path as default package path

* Fix style

* Add topology on ditask
上级 92d973c1
from .cli import cli
from .cli_ditask import cli_ditask
from .serial_entry import serial_pipeline
from .serial_entry_onpolicy import serial_pipeline_onpolicy
from .serial_entry_offline import serial_pipeline_offline
......
import click
import os
import sys
import importlib
import importlib.util
from click.core import Context, Option
from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
from ding.framework import Parallel
def print_version(ctx: Context, param: Option, value: bool) -> None:
    """
    Overview:
        Click option callback that prints the package title, version and author
        information, then exits through the click context.
    Arguments:
        - ctx (:obj:`Context`): The click context; its ``resilient_parsing`` flag is checked \
            and its ``exit`` method is called after printing.
        - param (:obj:`Option`): The option that triggered the callback (unused).
        - value (:obj:`bool`): The option value; nothing is printed when it is falsy.
    """
    # Only act when the flag was really given and click is not in resilient-parsing mode.
    if value and not ctx.resilient_parsing:
        version_line = '{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__)
        click.echo(version_line)
        author_line = 'Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__)
        click.echo(author_line)
        ctx.exit()
# Click context settings: make both ``-h`` and ``--help`` show the help page.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    '-v',
    '--version',
    is_flag=True,
    callback=print_version,
    expose_value=False,
    is_eager=True,
    help="Show package's version information."
)
@click.option('-p', '--package', type=str, help="Your code package path, could be a directory or a zip file.")
@click.option('--parallel-workers', type=int, default=1, help="Parallel worker number, default: 1")
@click.option(
    '--protocol',
    type=click.Choice(["tcp", "ipc"]),
    default="tcp",
    help="Network protocol in parallel mode, default: tcp"
)
@click.option(
    "--ports",
    type=str,
    default="50515",
    help="The port addresses that the tasks listen to, e.g. 50515,50516, default: 50515"
)
@click.option("--attach-to", type=str, help="The addresses to connect to.")
@click.option("--address", type=str, help="The address to listen to (without port).")
@click.option("--labels", type=str, help="Labels.")
@click.option("--node-ids", type=str, help="Candidate node ids.")
@click.option(
    "--topology",
    type=click.Choice(["alone", "mesh", "star"]),
    default="alone",
    help="Network topology, default: alone."
)
@click.option("-m", "--main", type=str, help="Main function of entry module.")
def cli_ditask(
    package: str, main: str, parallel_workers: int, protocol: str, ports: str, attach_to: str, address: str,
    labels: str, node_ids: str, topology: str
):
    """
    Overview:
        Command line entry of ``ditask``: import the user's entry function and hand it
        to ``Parallel.runner`` together with the parsed network options.
    Arguments:
        - package (:obj:`str`): Code package path (directory or zip file); defaults to the \
            current working directory when empty.
        - main (:obj:`str`): Dotted path ``module.function`` of the entry function. When omitted, \
            the module is named after the package path and the function is ``main``.
        - parallel_workers (:obj:`int`): Parallel worker number.
        - protocol (:obj:`str`): Network protocol, ``tcp`` or ``ipc``.
        - ports (:obj:`str`): Comma-separated ports the tasks listen to.
        - attach_to (:obj:`str`): Comma-separated addresses to connect to.
        - address (:obj:`str`): The address to listen to (without port).
        - labels (:obj:`str`): Comma-separated labels.
        - node_ids (:obj:`str`): Comma-separated candidate node ids.
        - topology (:obj:`str`): Network topology, ``alone``, ``mesh`` or ``star``.
    Raises:
        - click.BadParameter: If ``main`` is not of the form ``module.function``, or if \
            ``ports`` / ``node_ids`` contain something that is not an integer.
    """

    def split_csv(raw: str) -> list:
        # Split on commas, trim whitespace and drop empty items left by stray commas.
        stripped = (item.strip() for item in raw.split(","))
        return [item for item in stripped if item]

    def parse_ints(raw: str, param_hint: str) -> list:
        # Report malformed numbers as a CLI usage error instead of a raw traceback.
        try:
            return [int(item) for item in split_csv(raw)]
        except ValueError:
            raise click.BadParameter(
                "expected comma-separated integers, got {!r}".format(raw), param_hint=param_hint
            ) from None

    # Parse entry point
    if not package:
        package = os.getcwd()
    sys.path.append(package)
    if main is None:
        # normpath removes a trailing separator, which would otherwise make basename
        # return an empty string and the import below fail with "Empty module name".
        mod_name = os.path.basename(os.path.normpath(package))
        mod_name, _ = os.path.splitext(mod_name)
        func_name = "main"
    else:
        mod_name, _, func_name = main.rpartition(".")
        if not mod_name or not func_name:
            raise click.BadParameter(
                "expected the form 'module.function', got {!r}".format(main), param_hint="--main"
            )
        root_mod_name = mod_name.split(".", 1)[0]
        sys.path.append(os.path.join(package, root_mod_name))
    mod = importlib.import_module(mod_name)
    main_func = getattr(mod, func_name)

    # Parse arguments
    port_list = parse_ints(ports, "--ports")
    if not port_list:
        raise click.BadParameter("at least one port is required", param_hint="--ports")
    # A single port is passed as a bare int, several ports as a list.
    ports = port_list[0] if len(port_list) == 1 else port_list
    if attach_to:
        attach_to = split_csv(attach_to)
    if labels:
        labels = set(split_csv(labels))
    if node_ids:
        node_ids = parse_ints(node_ids, "--node-ids")

    Parallel.runner(
        n_parallel_workers=parallel_workers,
        ports=ports,
        protocol=protocol,
        topology=topology,
        attach_to=attach_to,
        address=address,
        labels=labels,
        node_ids=node_ids
    )(main_func)
......@@ -9,7 +9,7 @@ import logging
import tempfile
import socket
from os import path
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union, Set
from threading import Thread
from pynng.nng import Bus0, Socket
from ding.utils.design_helper import SingletonMetaclass
......@@ -30,10 +30,18 @@ class Parallel(metaclass=SingletonMetaclass):
self.attach_to = None
self.finished = False
self.node_id = None
self.labels = set()
def run(self, node_id: int, listen_to: str, attach_to: List[str] = None) -> None:
def run(
self,
node_id: int,
listen_to: str,
attach_to: Optional[List[str]] = None,
labels: Optional[Set[str]] = None
) -> None:
self.node_id = node_id
self.attach_to = attach_to = attach_to or []
self.labels = labels or set()
self._listener = Thread(
target=self.listen,
kwargs={
......@@ -52,7 +60,9 @@ class Parallel(metaclass=SingletonMetaclass):
protocol: str = "ipc",
address: Optional[str] = None,
ports: Optional[List[int]] = None,
topology: str = "mesh"
topology: str = "mesh",
labels: Optional[Set[str]] = None,
node_ids: Optional[List[int]] = None
) -> Callable:
"""
Overview:
......@@ -66,6 +76,9 @@ class Parallel(metaclass=SingletonMetaclass):
- topology (:obj:`str`): Network topology, includes:
`mesh` (default): fully connected between each other;
`star`: only connect to the first node;
`alone`: do not connect to any node, except the node attached to;
- labels (:obj:`Optional[Set[str]]`): Labels.
- node_ids (:obj:`Optional[List[int]]`): Candidate node ids.
Returns:
- _runner (:obj:`Callable`): The wrapper function for main.
"""
......@@ -91,21 +104,29 @@ class Parallel(metaclass=SingletonMetaclass):
atexit.register(cleanup_nodes)
def topology_network(node_id: int) -> List[str]:
def topology_network(i: int) -> List[str]:
if topology == "mesh":
return nodes[:node_id] + attach_to
return nodes[:i] + attach_to
elif topology == "star":
return nodes[:min(1, node_id)]
return nodes[:min(1, i)] + attach_to
elif topology == "alone":
return attach_to
else:
raise ValueError("Unknown topology: {}".format(topology))
params_group = []
for node_id in range(n_parallel_workers):
candidate_node_ids = node_ids or range(n_parallel_workers)
assert len(candidate_node_ids) == n_parallel_workers, \
"The number of workers must be the same as the number of node_ids, \
now there are {} workers and {} nodes"\
.format(n_parallel_workers, len(candidate_node_ids))
for i in range(n_parallel_workers):
runner_args = []
runner_kwargs = {
"node_id": node_id,
"listen_to": nodes[node_id],
"attach_to": topology_network(node_id) + attach_to
"node_id": candidate_node_ids[i],
"listen_to": nodes[i],
"attach_to": topology_network(i) + attach_to,
"labels": labels
}
params = [(runner_args, runner_kwargs), (main_process, args, kwargs)]
params_group.append(params)
......@@ -151,6 +172,8 @@ class Parallel(metaclass=SingletonMetaclass):
elif protocol == "tcp":
address = address or Parallel.get_ip()
ports = ports or range(50515, 50515 + n_workers)
if isinstance(ports, int):
ports = range(ports, ports + n_workers)
assert len(ports) == n_workers, "The number of ports must be the same as the number of workers, \
now there are {} ports and {} workers".format(len(ports), n_workers)
nodes = ["tcp://{}:{}".format(address, port) for port in ports]
......
......@@ -98,6 +98,8 @@ class Task:
if self.router.is_active:
self.labels.add("distributed")
self.labels.add("node.{}".format(self.router.node_id))
for label in self.router.labels:
self.labels.add(label)
else:
self.labels.add("standalone")
......
......@@ -158,7 +158,7 @@ setup(
'kubernetes',
]
},
entry_points={'console_scripts': ['ding=ding.entry.cli:cli']},
entry_points={'console_scripts': ['ding=ding.entry.cli:cli', 'ditask=ding.entry.cli_ditask:cli_ditask']},
classifiers=[
'Development Status :: 5 - Production/Stable',
"Intended Audience :: Science/Research",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册