Commit ca1c4695 authored by xjqbest

fix

Parent 07bd7092
@@ -35,7 +35,6 @@ class Reader(dg.MultiSlotDataGenerator):
         else:
             raise ValueError("reader config only support yaml")
-
     @abc.abstractmethod
     def init(self):
         """init"""
@@ -56,8 +55,6 @@ class SlotReader(dg.MultiSlotDataGenerator):
             _config = yaml.load(rb.read(), Loader=yaml.FullLoader)
         else:
             raise ValueError("reader config only support yaml")
-        #envs.set_global_envs(_config)
-        #envs.update_workspace()
     def init(self, sparse_slots, dense_slots, padding=0):
         from operator import mul
......
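For orientation, here is a hedged sketch of how a concrete reader built on `dg.MultiSlotDataGenerator` is typically wired up; the slot names, parsing, and state below are illustrative assumptions, not code from this commit:

```python
import paddle.fluid.incubate.data_generator as dg


class MyReader(dg.MultiSlotDataGenerator):
    def init(self):
        # hypothetical state; a real reader parses its YAML config here
        self.label_index = 0

    def generate_sample(self, line):
        def reader():
            fields = line.strip().split(" ")
            # MultiSlot format: one (slot_name, values) pair per slot
            yield [("click", [int(fields[self.label_index])]),
                   ("feat", [int(x) for x in fields[1:]])]

        return reader


if __name__ == "__main__":
    r = MyReader()
    r.init()
    r.run_from_stdin()  # consumes the lines piped in by the trainer
```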
@@ -69,13 +69,14 @@ class SingleTrainer(TranspileTrainer):
         reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
         if sparse_slots is None and dense_slots is None:
-            pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml)
+            pipe_cmd = "python {} {} {} {}".format(reader, reader_class,
+                                                   "TRAIN", self._config_yaml)
         else:
             if sparse_slots is None:
                 sparse_slots = "#"
             if dense_slots is None:
                 dense_slots = "#"
-            padding = envs.get_global_env(name +"padding", 0)
+            padding = envs.get_global_env(name + "padding", 0)
             pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
                 reader, "slot", "slot", self._config_yaml, "fake", \
                 sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"), str(padding))
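To make the argument plumbing above concrete, a standalone sketch of how the slot variant of `pipe_cmd` comes together; the paths and slot names are hypothetical:

```python
# "#" stands in for an absent slot list, and spaces inside the slot
# strings become "#" so each value survives shell word-splitting.
reader = "core/utils/dataset_instance.py"  # hypothetical script path
config_yaml = "config.yaml"                # hypothetical config file
sparse_slots = "click slot1 slot2"         # hypothetical sparse slots
dense_slots = None
padding = 0

sparse = (sparse_slots or "#").replace(" ", "#")
dense = (dense_slots or "#").replace(" ", "#")
pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
    reader, "slot", "slot", config_yaml, "fake", sparse, dense, str(padding))
print(pipe_cmd)
# python core/utils/dataset_instance.py slot slot config.yaml fake click#slot1#slot2 # 0
```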
@@ -145,19 +146,29 @@ class SingleTrainer(TranspileTrainer):
         scope = fluid.Scope()
         dataset_name = model_dict["dataset_name"]
         opt_name = envs.get_global_env("hyper_parameters.optimizer.class")
-        opt_lr = envs.get_global_env("hyper_parameters.optimizer.learning_rate")
-        opt_strategy = envs.get_global_env("hyper_parameters.optimizer.strategy")
+        opt_lr = envs.get_global_env(
+            "hyper_parameters.optimizer.learning_rate")
+        opt_strategy = envs.get_global_env(
+            "hyper_parameters.optimizer.strategy")
         with fluid.program_guard(train_program, startup_program):
             with fluid.unique_name.guard():
                 with fluid.scope_guard(scope):
-                    model_path = model_dict["model"].replace("{workspace}", envs.path_adapter(self._env["workspace"]))
-                    model = envs.lazy_instance_by_fliename(model_path, "Model")(self._env)
-                    model._data_var = model.input_data(dataset_name=model_dict["dataset_name"])
-                    if envs.get_global_env("dataset." + dataset_name + ".type") == "DataLoader":
+                    model_path = model_dict["model"].replace(
+                        "{workspace}",
+                        envs.path_adapter(self._env["workspace"]))
+                    model = envs.lazy_instance_by_fliename(
+                        model_path, "Model")(self._env)
+                    model._data_var = model.input_data(
+                        dataset_name=model_dict["dataset_name"])
+                    if envs.get_global_env("dataset." + dataset_name +
+                                           ".type") == "DataLoader":
                         model._init_dataloader()
-                        self._get_dataloader(dataset_name, model._data_loader)
-                    model.net(model._data_var, is_infer=model_dict.get("is_infer", False))
-                    optimizer = model._build_optimizer(opt_name, opt_lr, opt_strategy)
+                        self._get_dataloader(dataset_name,
+                                             model._data_loader)
+                    model.net(model._data_var,
+                              is_infer=model_dict.get("is_infer", False))
+                    optimizer = model._build_optimizer(opt_name, opt_lr,
+                                                       opt_strategy)
                     optimizer.minimize(model._cost)
         self._model[model_dict["name"]][0] = train_program
         self._model[model_dict["name"]][1] = startup_program
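The triple guard nesting above is the standard PaddlePaddle 1.x static-graph pattern for building several models side by side. A minimal self-contained sketch with a toy network (not this repo's Model):

```python
import paddle.fluid as fluid

train_program = fluid.Program()
startup_program = fluid.Program()
scope = fluid.Scope()

# program_guard routes new ops/vars into this Program pair,
# unique_name.guard restarts name generation so per-model names
# cannot clash, and scope_guard keeps the parameters in a private Scope.
with fluid.program_guard(train_program, startup_program):
    with fluid.unique_name.guard():
        with fluid.scope_guard(scope):
            x = fluid.data(name="x", shape=[None, 1], dtype="float32")
            cost = fluid.layers.reduce_mean(fluid.layers.fc(input=x, size=1))
            fluid.optimizer.SGD(learning_rate=0.01).minimize(cost)
```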
@@ -167,13 +178,14 @@ class SingleTrainer(TranspileTrainer):
         for dataset in self._env["dataset"]:
             if dataset["type"] != "DataLoader":
-                self._dataset[dataset["name"]] = self._create_dataset(dataset["name"])
+                self._dataset[dataset["name"]] = self._create_dataset(dataset[
+                    "name"])
         context['status'] = 'startup_pass'

     def startup(self, context):
         for model_dict in self._env["executor"]:
             with fluid.scope_guard(self._model[model_dict["name"]][2]):
                 self._exe.run(self._model[model_dict["name"]][1])
         context['status'] = 'train_pass'
@@ -289,7 +301,8 @@ class SingleTrainer(TranspileTrainer):
             return epoch_id % epoch_interval == 0

         def save_inference_model():
-            save_interval = envs.get_global_env("epoch.save_inference_interval", -1)
+            save_interval = int(
+                envs.get_global_env("epoch.save_inference_interval", -1))
             if not need_save(epoch_id, save_interval, False):
                 return
             feed_varnames = envs.get_global_env("epoch.save_inference_feed_varnames", None)
@@ -313,7 +326,8 @@ class SingleTrainer(TranspileTrainer):
                                               fetch_vars, self._exe)

         def save_persistables():
-            save_interval = int(envs.get_global_env("epoch.save_checkpoint_interval", -1))
+            save_interval = int(
+                envs.get_global_env("epoch.save_checkpoint_interval", -1))
             if not need_save(epoch_id, save_interval, False):
                 return
             dirname = envs.get_global_env("epoch.save_checkpoint_path", None)
......
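Both save_* helpers above gate on need_save. A plausible reading of that check, consistent with the `epoch_id % epoch_interval == 0` line visible in the hunk (the handling of the -1 default is an assumption):

```python
def need_save(epoch_id, epoch_interval, is_last=False):
    if is_last:
        return True  # always save the final epoch
    if epoch_interval <= 0:
        return False  # -1 (the default) disables periodic saving
    return epoch_id % epoch_interval == 0


assert need_save(4, 2)
assert not need_save(3, 2)
assert not need_save(3, -1)
```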
@@ -19,6 +19,7 @@ from paddlerec.core.utils.envs import get_global_env
 from paddlerec.core.utils.envs import get_runtime_environ
 from paddlerec.core.reader import SlotReader

+
 def dataloader_by_name(readerclass, dataset_name, yaml_file):
     reader_class = lazy_instance_by_fliename(readerclass, "TrainReader")
     name = "dataset." + dataset_name + "."
@@ -30,9 +31,9 @@ def dataloader_by_name(readerclass, dataset_name, yaml_file):
         data_path = os.path.join(package_base, data_path.split("::")[1])

     files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]

     reader = reader_class(yaml_file)
     reader.init()

     def gen_reader():
         for file in files:
             with open(file, 'r') as f:
@@ -67,7 +68,6 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file):
         data_path = os.path.join(package_base, data_path.split("::")[1])

     files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]

     sparse = get_global_env(name + "sparse_slots")
     dense = get_global_env(name + "dense_slots")
     padding = get_global_env(name + "padding", 0)
@@ -96,6 +96,7 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file):
         return gen_batch_reader()
     return gen_reader

+
 def dataloader(readerclass, train, yaml_file):
     if train == "TRAIN":
         reader_name = "TrainReader"
......
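The dataloader functions in this file all follow the same closure pattern: build a file list, then return a generator that yields parsed samples lazily. A simplified sketch, with hypothetical file names and parsing:

```python
def make_reader(files):
    def gen_reader():
        for path in files:
            with open(path, "r") as f:
                for line in f:
                    # hypothetical parsing; real readers delegate to the
                    # Reader/SlotReader classes shown earlier
                    yield line.rstrip("\n").split(" ")

    return gen_reader


# usage: reader = make_reader(["data/part-0", "data/part-1"])
#        for sample in reader(): ...
```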
@@ -20,6 +20,7 @@ import sys

 global_envs = {}

+
 def flatten_environs(envs, separator="."):
     flatten_dict = {}
     assert isinstance(envs, dict)
@@ -81,6 +82,7 @@ def set_global_envs(envs):
     fatten_env_namespace([], envs)

+
 def get_global_env(env_name, default_value=None, namespace=None):
     """
     get os environment value
......
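For context, a simplified model (assumed behaviour, not a verbatim copy of envs.py) of how set_global_envs flattens a nested config so that dotted lookups like the ones used throughout this commit resolve:

```python
global_envs = {}


def set_global_envs(envs):
    def fatten_env_namespace(prefix, nested):
        for k, v in nested.items():
            if isinstance(v, dict):
                fatten_env_namespace(prefix + [k], v)
            else:
                global_envs[".".join(prefix + [k])] = v

    fatten_env_namespace([], envs)


def get_global_env(env_name, default_value=None, namespace=None):
    return global_envs.get(env_name, default_value)


set_global_envs({"hyper_parameters": {"optimizer": {"learning_rate": 0.001}}})
assert get_global_env("hyper_parameters.optimizer.learning_rate") == 0.001
```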
@@ -27,9 +27,12 @@ class Model(ModelBase):
     def _init_hyper_parameters(self):
         self.is_distributed = True if envs.get_trainer(
         ) == "CtrTrainer" else False
-        self.sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number")
-        self.sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim")
-        self.learning_rate = envs.get_global_env("hyper_parameters.learning_rate")
+        self.sparse_feature_number = envs.get_global_env(
+            "hyper_parameters.sparse_feature_number")
+        self.sparse_feature_dim = envs.get_global_env(
+            "hyper_parameters.sparse_feature_dim")
+        self.learning_rate = envs.get_global_env(
+            "hyper_parameters.learning_rate")

     def net(self, input, is_infer=False):
         self.sparse_inputs = self._sparse_data_var[1:]
......
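A hedged sketch of where the three hyper-parameters read above usually land in a fluid CTR model (illustrative, not this repo's Model.net): the two sparse settings size the embedding table, and is_distributed switches to parameter-server lookup:

```python
import paddle.fluid as fluid


def build_embedding(slot_input, number, dim, is_distributed=False):
    return fluid.embedding(
        input=slot_input,
        size=[number, dim],  # vocabulary size x embedding width
        is_sparse=True,
        is_distributed=is_distributed)


main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    slot = fluid.data(name="slot", shape=[None, 1], dtype="int64")
    emb = build_embedding(slot, number=1000001, dim=9)  # hypothetical sizes
```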
@@ -68,10 +68,8 @@ def get_engine(args):
     if engine is None:
         engine = run_extras.get("epoch.trainer_class", None)
-
     if engine is None:
         engine = "single"
-
     engine = engine.upper()

     if engine not in engine_choices:
         raise ValueError("train.engin can not be chosen in {}".format(
             engine_choices))
@@ -135,6 +133,7 @@ def single_engine(args):
     trainer = TrainerFactory.create(args.model)
     return trainer

+
 def cluster_engine(args):
     def update_workspace(cluster_envs):
         workspace = cluster_envs.get("engine_workspace", None)
......
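The engine resolution above is a simple fallback chain. A self-contained sketch (the choice list here is a hypothetical subset):

```python
def resolve_engine(engine, run_extras, engine_choices=("SINGLE", "CLUSTER")):
    if engine is None:
        engine = run_extras.get("epoch.trainer_class", None)
    if engine is None:
        engine = "single"  # default when neither source names a trainer
    engine = engine.upper()
    if engine not in engine_choices:
        raise ValueError("train.engin can not be chosen in {}".format(
            engine_choices))
    return engine


assert resolve_engine(None, {}) == "SINGLE"
assert resolve_engine(None, {"epoch.trainer_class": "cluster"}) == "CLUSTER"
```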