提交 2c77c937 编写于 作者: T tangwei

code clean

上级 7f99ff03
......@@ -31,8 +31,8 @@ train:
reader:
batch_size: 2
pipe_command: "python /paddle/eleps/fleetrec/models/ctr_dnn/dataset.py"
train_data_path: "/paddle/eleps/fleetrec/models/ctr_dnn/data/train"
class: "fleetrec.models.ctr_dnn.data_generator"
train_data_path: "/root/FleetRec/fleetrec/models/ctr_dnn/data/train/"
model:
models: "fleetrec.models.ctr_dnn.model"
......
......@@ -26,7 +26,8 @@ class Reader(dg.MultiSlotDataGenerator):
__metaclass__ = abc.ABCMeta
def __init__(self, config):
super().__init__()
dg.MultiSlotDataGenerator.__init__(self)
if os.path.exists(config) and os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
......
......@@ -30,14 +30,14 @@ class TrainerFactory(object):
pass
@staticmethod
def _build_trainer(config):
def _build_trainer(config, yaml_path):
print(envs.pretty_print_envs(envs.get_global_envs()))
train_mode = envs.get_global_env("train.trainer")
if train_mode == "SingleTraining":
trainer = SingleTrainer(config)
trainer = SingleTrainer(yaml_path)
elif train_mode == "ClusterTraining":
trainer = ClusterTrainer(config)
trainer = ClusterTrainer(yaml_path)
elif train_mode == "CtrTrainer":
trainer = CtrPaddleTrainer(config)
else:
......@@ -75,7 +75,7 @@ class TrainerFactory(object):
if mode == "ClusterTraining" and container == "local" and not instance:
trainer = TrainerFactory._build_engine(config)
else:
trainer = TrainerFactory._build_trainer(_config)
trainer = TrainerFactory._build_trainer(_config, config)
return trainer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册