From a2354d8bc664c06260b176fea453460fb8d55055 Mon Sep 17 00:00:00 2001
From: wangjiawei04
Date: Thu, 20 Aug 2020 11:31:17 +0800
Subject: [PATCH] add inmemory dataset

---
 core/trainers/framework/dataset.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/core/trainers/framework/dataset.py b/core/trainers/framework/dataset.py
index ae3b7c38..449aae41 100644
--- a/core/trainers/framework/dataset.py
+++ b/core/trainers/framework/dataset.py
@@ -26,6 +26,7 @@ from paddle.fluid.contrib.utils.hdfs_utils import HDFSClient
 
 __all__ = ["DatasetBase", "DataLoader", "QueueDataset", "InMemoryDataset"]
 
+
 class DatasetBase(object):
     """R
     """
@@ -152,9 +153,10 @@ class QueueDataset(DatasetBase):
                 break
         return dataset
 
+
 class InMemoryDataset(QueueDataset):
     def _get_dataset(self, dataset_name, context):
-        with open("context.txt", "w+") as fout:
+        with open("context.txt", "w+") as fout:
             fout.write(str(context))
         name = "dataset." + dataset_name + "."
         reader_class = envs.get_global_env(name + "data_converter")
@@ -197,7 +199,10 @@ class InMemoryDataset(QueueDataset):
             "hadoop.job.ugi": hdfs_ugi
         }
         hdfs_client = HDFSClient(hadoop_home, hdfs_configs)
-        file_list = ["{}/{}".format(hdfs_addr, x) for x in hdfs_client.lsr(train_data_path)]
+        file_list = [
+            "{}/{}".format(hdfs_addr, x)
+            for x in hdfs_client.lsr(train_data_path)
+        ]
         if context["engine"] == EngineMode.LOCAL_CLUSTER:
             file_list = split_files(file_list, context["fleet"].worker_index(),
                                     context["fleet"].worker_num())
-- 
GitLab
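
Note: the patch ends above. What follows is a minimal usage sketch, assuming the Paddle 1.x fluid API that this file builds on, of how an InMemoryDataset like the one registered here is typically created and loaded. Every concrete value below (slot shape, batch size, pipe command, file names) is an illustrative placeholder, not something taken from the patch.

    import paddle.fluid as fluid

    # Declare one sparse input slot for illustration; a real model defines its own.
    slot = fluid.data(name="slot_0", shape=[None, 1], dtype="int64", lod_level=1)

    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    dataset.set_batch_size(32)                          # placeholder batch size
    dataset.set_use_var([slot])
    dataset.set_pipe_command("python reader.py")        # placeholder reader script
    dataset.set_filelist(["part-00000", "part-00001"])  # placeholder data shards
    # Unlike QueueDataset, the samples are pulled into memory before training,
    # which is what load_into_memory() does here.
    dataset.load_into_memory()

In the patched trainer, the file list would instead come from hdfs_client.lsr() and, under LOCAL_CLUSTER, be sharded per worker via split_files(), as the last hunk shows.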