提交 fcbb1cc0 编写于 作者: C chengmo

for merge

上级 cd7cb08a
......@@ -20,13 +20,14 @@ import os
import copy
from fleetrec.core.engine.engine import Engine
from fleetrec.core.utils import envs
class LocalClusterEngine(Engine):
def start_procs(self):
worker_num = self.envs["worker_num"]
server_num = self.envs["server_num"]
start_port = self.envs["start_port"]
ports = [self.envs["start_port"]]
logs_dir = self.envs["log_dir"]
default_env = os.environ.copy()
......@@ -36,10 +37,19 @@ class LocalClusterEngine(Engine):
current_env.pop("https_proxy", None)
procs = []
log_fns = []
ports = range(start_port, start_port + server_num, 1)
for i in range(server_num - 1):
while True:
new_port = envs.find_free_port()
if new_port not in ports:
ports.append(new_port)
break
user_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]
user_endpoints_ips = [x.split(":")[0]
for x in user_endpoints.split(",")]
user_endpoints_port = [x.split(":")[1]
for x in user_endpoints.split(",")]
factory = "fleetrec.core.factory"
cmd = [sys.executable, "-u", "-m", factory, self.trainer]
......@@ -56,7 +66,8 @@ class LocalClusterEngine(Engine):
os.system("mkdir -p {}".format(logs_dir))
fn = open("%s/server.%d" % (logs_dir, i), "w")
log_fns.append(fn)
proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
proc = subprocess.Popen(
cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
procs.append(proc)
for i in range(worker_num):
......@@ -70,7 +81,8 @@ class LocalClusterEngine(Engine):
os.system("mkdir -p {}".format(logs_dir))
fn = open("%s/worker.%d" % (logs_dir, i), "w")
log_fns.append(fn)
proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
proc = subprocess.Popen(
cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
procs.append(proc)
# only wait worker to finish here
......
......@@ -15,7 +15,7 @@
import os
import copy
import sys
import socket
global_envs = {}
......@@ -78,7 +78,8 @@ def get_global_env(env_name, default_value=None, namespace=None):
"""
get os environment value
"""
_env_name = env_name if namespace is None else ".".join([namespace, env_name])
_env_name = env_name if namespace is None else ".".join(
[namespace, env_name])
return global_envs.get(_env_name, default_value)
......@@ -146,7 +147,8 @@ def pretty_print_envs(envs, header=None):
def lazy_instance_by_package(package, class_name):
    """Import a module by dotted name and return one attribute from it.

    Args:
        package: Dotted module path, e.g. ``"pkg.sub.module"``.
        class_name: Name of the attribute (typically a class) to fetch
            from the imported module.

    Returns:
        The object bound to ``class_name`` in the imported module.

    Raises:
        ImportError: If ``package`` cannot be imported.
        AttributeError: If the module has no attribute ``class_name``.
    """
    # A non-empty fromlist makes __import__ return the module named by the
    # full dotted path rather than the top-level package.
    model_package = __import__(
        package, globals(), locals(), package.split("."))
    return getattr(model_package, class_name)
......@@ -156,7 +158,8 @@ def lazy_instance_by_fliename(abs, class_name):
sys.path.append(dirname)
package = os.path.splitext(os.path.basename(abs))[0]
model_package = __import__(package, globals(), locals(), package.split("."))
model_package = __import__(
package, globals(), locals(), package.split("."))
instance = getattr(model_package, class_name)
return instance
......@@ -170,3 +173,13 @@ def get_platform():
return "DARWIN"
if 'Windows' in plats:
return "WINDOWS"
def find_free_port():
    """Return a TCP port number that the OS reports as currently free.

    A temporary IPv4 stream socket is bound to port 0, which lets the
    operating system choose an available port; that port number is read
    back and the socket is closed before returning.

    Returns:
        int: The port number the temporary socket was bound to.

    Note:
        The socket is closed on return, so the port is not reserved for
        the caller.
    """
    # Imported here because `closing` is not among the module-level imports
    # visible for this file; without it the call would raise NameError.
    from contextlib import closing

    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.bind(('', 0))
        return sock.getsockname()[1]
......@@ -97,6 +97,28 @@ python -m fleetrec.run -m fleetrec.models.rank.dnn -d cpu -e local_cluster
python -m fleetrec.run -m fleetrec.models.rank.dnn -d cpu -e cluster
```
<h2 align="center">支持模型列表</h2>
| 方向 | 模型 | 单机CPU训练 | 单机GPU训练 | 分布式CPU训练 | 大规模稀疏 | 分布式GPU训练 | 自定义数据集 |
| :------: | :--------------------: | :---------: | :---------: | :-----------: | :--------: | :-----------: | :----------: |
| 内容理解 | [Text-Classification]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 内容理解 | [TagSpace]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 召回 | [Word2Vec]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 召回 | [TDM]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 召回 | [SSR]() | ✓ | ✓ | ✓ | x | ✓ | ✓ |
| 召回 | [Gru4Rec]() | ✓ | ✓ | ✓ | x | ✓ | ✓ |
| 排序 | [CTR-Dnn]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 排序 | [DeepFm]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 排序 | [ListWise]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 排序 | [DSSM]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 排序 | [Multiview-Simnet]() | ✓ | x | ✓ | x | ✓ | ✓ |
| 融合 | [MMOE]() | ✓ | ✓ | ✓ | x | ✓ | ✓ |
| 融合 | [ESMM]() | ✓ | ✓ | ✓ | x | ✓ | ✓ |
| 融合 | [ESMM]() | ✓ | ✓ | ✓ | x | ✓ | ✓ |
<h2 align="center">文档</h2>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册