Unverified commit 7a359597, authored by Chengmo, committed by GitHub

fix split files at PY3 (#103)

* fix split files at PY3

* fix linux at PY3

* fix desc error

* fix collective cards and worknum
Co-authored-by: tangwei <tangwei12@baidu.com>
Parent 947395bb
@@ -15,13 +15,13 @@
from __future__ import print_function
import os
import warnings
import paddle.fluid as fluid
from paddlerec.core.utils import envs
from paddlerec.core.utils import dataloader_instance
from paddlerec.core.reader import SlotReader
from paddlerec.core.trainer import EngineMode
+from paddlerec.core.utils.util import split_files
__all__ = ["DatasetBase", "DataLoader", "QueueDataset"]
@@ -123,7 +123,8 @@ class QueueDataset(DatasetBase):
for x in os.listdir(train_data_path)
]
if context["engine"] == EngineMode.LOCAL_CLUSTER:
-file_list = context["fleet"].split_files(file_list)
+file_list = split_files(file_list, context["fleet"].worker_index(),
+                        context["fleet"].worker_num())
dataset.set_filelist(file_list)
for model_dict in context["phases"]:
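The call-site change above is the heart of the PR: instead of delegating sharding to fleet.split_files, each worker now slices the file list itself from its worker index and the total worker count. A minimal sketch of the new calling pattern (the file names and worker values below are illustrative, not taken from the diff):

```python
# Sketch of the new sharding call; values are illustrative.
from paddlerec.core.utils.util import split_files

file_list = ["part-0", "part-1", "part-2", "part-3", "part-4"]
worker_index, worker_num = 1, 2  # normally context["fleet"].worker_index() / .worker_num()

my_files = split_files(file_list, worker_index, worker_num)
print(my_files)  # worker 1 of 2 -> ['part-3', 'part-4']
```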
@@ -19,6 +19,7 @@ from paddlerec.core.utils.envs import get_global_env
from paddlerec.core.utils.envs import get_runtime_environ
from paddlerec.core.reader import SlotReader
from paddlerec.core.trainer import EngineMode
+from paddlerec.core.utils.util import split_files
def dataloader_by_name(readerclass,
@@ -39,7 +40,8 @@ def dataloader_by_name(readerclass,
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
if context["engine"] == EngineMode.LOCAL_CLUSTER:
-files = context["fleet"].split_files(files)
+files = split_files(files, context["fleet"].worker_index(),
+                    context["fleet"].worker_num())
print("file_list : {}".format(files))
reader = reader_class(yaml_file)
@@ -80,7 +82,8 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
if context["engine"] == EngineMode.LOCAL_CLUSTER:
-files = context["fleet"].split_files(files)
+files = split_files(files, context["fleet"].worker_index(),
+                    context["fleet"].worker_num())
print("file_list: {}".format(files))
sparse = get_global_env(name + "sparse_slots", "#")
@@ -133,7 +136,8 @@ def slotdataloader(readerclass, train, yaml_file, context):
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
if context["engine"] == EngineMode.LOCAL_CLUSTER:
-files = context["fleet"].split_files(files)
+files = split_files(files, context["fleet"].worker_index(),
+                    context["fleet"].worker_num())
print("file_list: {}".format(files))
sparse = get_global_env("sparse_slots", "#", namespace)
@@ -18,6 +18,7 @@ import copy
import os
import socket
import sys
+import six
import traceback
import six
@@ -102,6 +103,12 @@ def set_global_envs(envs):
name = ".".join(["dataset", dataset["name"], "type"])
global_envs[name] = "DataLoader"
+if get_platform() == "LINUX" and six.PY3:
+    print("QueueDataset can not support PY3, change to DataLoader")
+    for dataset in envs["dataset"]:
+        name = ".".join(["dataset", dataset["name"], "type"])
+        global_envs[name] = "DataLoader"
def get_global_env(env_name, default_value=None, namespace=None):
"""
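The block added above rewrites every dataset's type to DataLoader when running on Linux under Python 3, since (per the commit message) QueueDataset cannot run on PY3. A small sketch of the resulting flattened keys, using a made-up config shaped like the envs dict the code iterates over:

```python
import six

# Made-up config shaped like the "dataset" section the code iterates over.
envs = {"dataset": [{"name": "train_sample", "type": "QueueDataset"}]}
global_envs = {}

if six.PY3:  # the real guard also requires get_platform() == "LINUX"
    for dataset in envs["dataset"]:
        name = ".".join(["dataset", dataset["name"], "type"])
        global_envs[name] = "DataLoader"

print(global_envs)  # {'dataset.train_sample.type': 'DataLoader'}
```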
@@ -19,11 +19,8 @@ import time
import numpy as np
from paddle import fluid
from paddlerec.core.utils import fs as fs
def save_program_proto(path, program=None):
if program is None:
_program = fluid.default_main_program()
else:
@@ -171,6 +168,39 @@ def print_cost(cost, params):
return log_str
+def split_files(files, trainer_id, trainers):
+    """
+    Split files before distributed training.
+    Example 1: files is [a, b, c, d, e] and trainer_num = 2, then trainer
+    0 gets [a, b, c] and trainer 1 gets [d, e].
+    Example 2: files is [a, b] and trainer_num = 3, then trainer 0 gets
+    [a], trainer 1 gets [b], trainer 2 gets [].
+
+    Args:
+        files(list): file list to be read.
+        trainer_id(int): index of the current trainer.
+        trainers(int): total number of trainers.
+
+    Returns:
+        list: the files that belong to this worker.
+    """
+    if not isinstance(files, list):
+        raise TypeError("files should be a list of files to be read.")
+
+    # Give everyone len(files) // trainers files, then hand the remainder
+    # out one by one starting from trainer 0.
+    remainder = len(files) % trainers
+    blocksize = int(len(files) / trainers)
+
+    blocks = [blocksize] * trainers
+    for i in range(remainder):
+        blocks[i] += 1
+
+    trainer_files = [[]] * trainers
+    begin = 0
+    for i in range(trainers):
+        trainer_files[i] = files[begin:begin + blocks[i]]
+        begin += blocks[i]
+
+    return trainer_files[trainer_id]
class CostPrinter(object):
"""
For count cost time && print cost log
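The new helper can be checked directly against its own docstring; a small self-test (assuming the import path added at the top of this PR):

```python
from paddlerec.core.utils.util import split_files

# Example 1 from the docstring: 5 files across 2 trainers.
files = ["a", "b", "c", "d", "e"]
assert split_files(files, 0, 2) == ["a", "b", "c"]
assert split_files(files, 1, 2) == ["d", "e"]

# Example 2: more trainers than files leaves the last trainer empty.
files = ["a", "b"]
assert split_files(files, 0, 3) == ["a"]
assert split_files(files, 1, 3) == ["b"]
assert split_files(files, 2, 3) == []
```

Note that `trainer_files = [[]] * trainers` is safe here only because every slot is reassigned in the loop; if the inner lists were mutated instead, the aliasing introduced by list multiplication would be a bug.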
@@ -139,8 +139,8 @@ def get_engine(args, running_config, mode):
engine = "LOCAL_CLUSTER_TRAIN"
if engine not in engine_choices:
-raise ValueError("{} can not be chosen in {}".format(engine_class,
-                 engine_choices))
+raise ValueError("{} can only be chosen in {}".format(engine_class,
+                 engine_choices))
run_engine = engines[transpiler].get(engine, None)
return run_engine
@@ -439,8 +439,8 @@ def local_cluster_engine(args):
if fleet_mode == "COLLECTIVE":
cluster_envs["selected_gpus"] = selected_gpus
gpus = selected_gpus.split(",")
-gpu_num = get_worker_num(run_extras, len(gpus))
-cluster_envs["selected_gpus"] = ','.join(gpus[:gpu_num])
+worker_num = get_worker_num(run_extras, len(gpus))
+cluster_envs["selected_gpus"] = ','.join(gpus[:worker_num])
cluster_envs["server_num"] = server_num
cluster_envs["worker_num"] = worker_num
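The rename in the last hunk makes the intent of the COLLECTIVE branch explicit: the selected GPU list is trimmed to the resolved worker count, one card per collective worker. A toy illustration (values are made up):

```python
# Toy illustration of trimming the GPU list to the worker count.
selected_gpus = "0,1,2,3"
worker_num = 2  # e.g. resolved by get_worker_num(run_extras, len(gpus))

gpus = selected_gpus.split(",")
print(",".join(gpus[:worker_num]))  # -> "0,1"
```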