diff --git a/core/trainers/framework/dataset.py b/core/trainers/framework/dataset.py
index ae3b7c38824eec0b509579bf29f122ab58fb3a30..449aae4124c514fa30de921fd5c2ebcc94588897 100644
--- a/core/trainers/framework/dataset.py
+++ b/core/trainers/framework/dataset.py
@@ -26,6 +26,7 @@ from paddle.fluid.contrib.utils.hdfs_utils import HDFSClient
 
 __all__ = ["DatasetBase", "DataLoader", "QueueDataset", "InMemoryDataset"]
 
+
 class DatasetBase(object):
     """R
     """
@@ -152,9 +153,10 @@ class QueueDataset(DatasetBase):
                 break
         return dataset
 
+
 class InMemoryDataset(QueueDataset):
     def _get_dataset(self, dataset_name, context):
-        with open("context.txt", "w+") as fout:
+        with open("context.txt", "w+") as fout:
             fout.write(str(context))
         name = "dataset." + dataset_name + "."
         reader_class = envs.get_global_env(name + "data_converter")
@@ -197,7 +199,10 @@ class InMemoryDataset(QueueDataset):
             "hadoop.job.ugi": hdfs_ugi
         }
         hdfs_client = HDFSClient(hadoop_home, hdfs_configs)
-        file_list = ["{}/{}".format(hdfs_addr, x) for x in hdfs_client.lsr(train_data_path)]
+        file_list = [
+            "{}/{}".format(hdfs_addr, x)
+            for x in hdfs_client.lsr(train_data_path)
+        ]
         if context["engine"] == EngineMode.LOCAL_CLUSTER:
             file_list = split_files(file_list, context["fleet"].worker_index(),
                                     context["fleet"].worker_num())
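
Note on the last hunk: in LOCAL_CLUSTER mode each worker keeps only its own shard of the HDFS file list, so workers read disjoint training files. The snippet below is a minimal, self-contained sketch of one plausible sharding strategy (round-robin by file index); the actual split_files helper in this repo may partition differently (e.g., into contiguous blocks), and the worker_index/worker_num arguments are assumed to come from the fleet runtime as shown in the diff.

# Hypothetical re-implementation of split_files() for illustration only;
# this is not the repo's actual helper.
from typing import List


def split_files(file_list: List[str], worker_index: int,
                worker_num: int) -> List[str]:
    """Return the shard of file_list owned by worker_index.

    Each worker keeps the files whose position modulo worker_num equals
    its own index, so the shards are disjoint and cover the whole list.
    """
    return [
        path for i, path in enumerate(file_list)
        if i % worker_num == worker_index
    ]


if __name__ == "__main__":
    # With 8 files and 3 workers, worker 1 ends up with files 1, 4 and 7.
    files = ["hdfs://ns1/data/part-{:05d}".format(i) for i in range(8)]
    print(split_files(files, worker_index=1, worker_num=3))
    # ['hdfs://ns1/data/part-00001', 'hdfs://ns1/data/part-00004',
    #  'hdfs://ns1/data/part-00007']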