提交 74ee4f3f 编写于 作者: T tangwei

add optimizer config, add workspace

上级 3006e6b2
......@@ -18,14 +18,15 @@ train:
strategy: "async"
epochs: 10
workspace: "fleetrec.models.rank.dnn"
reader:
batch_size: 2
class: "fleetrec.models.rank.criteo_reader"
train_data_path: "fleetrec::models/rank/dnn/data/train"
class: "{workspace}/../criteo_reader.py"
train_data_path: "{workspace}/data/train"
model:
models: "fleetrec.models.rank.dnn.model"
models: "{workspace}/model.py"
hyper_parameters:
sparse_inputs_slots: 27
sparse_feature_number: 1000001
......@@ -33,22 +34,14 @@ train:
dense_input_dim: 13
fc_sizes: [512, 256, 128, 32]
learning_rate: 0.001
optimizer: adam
save:
increment:
dirname: "models_for_increment"
dirname: "increment"
epoch_interval: 2
save_last: True
inference:
dirname: "models_for_inference"
dirname: "inference"
epoch_interval: 4
feed_varnames: ["C1", "C2", "C3"]
fetch_varnames: "predict"
save_last: True
evaluate:
batch_size: 32
train_thread_num: 12
reader: "reader.py"
......@@ -67,6 +67,8 @@ class TrainerFactory(object):
raise ValueError("fleetrec's config only support yaml")
envs.set_global_envs(_config)
envs.update_workspace()
trainer = TrainerFactory._build_trainer(config)
return trainer
......
import abc
import paddle.fluid as fluid
from fleetrec.core.utils import envs
class Model(object):
"""R
"""
......@@ -33,11 +37,35 @@ class Model(object):
def get_fetch_period(self):
return self._fetch_interval
def _build_optimizer(self, name, lr):
name = name.upper()
optimizers = ["SGD", "ADAM", "ADAGRAD"]
if name not in optimizers:
raise ValueError("configured optimizer can only supported SGD/Adam/Adagrad")
if name == "SGD":
optimizer_i = fluid.optimizer.Adam(lr, lazy_mode=True)
elif name == "ADAM":
optimizer_i = fluid.optimizer.Adam(lr, lazy_mode=True)
elif name == "ADAGRAD":
optimizer_i = fluid.optimizer.Adam(lr, lazy_mode=True)
else:
raise ValueError("configured optimizer can only supported SGD/Adam/Adagrad")
return optimizer_i
def optimizer(self):
learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None, self._namespace)
optimizer = envs.get_global_env("hyper_parameters.optimizer", None, self._namespace)
return self._build_optimizer(optimizer, learning_rate)
@abc.abstractmethod
def train_net(self):
"""R
"""
pass
@abc.abstractmethod
def infer_net(self):
pass
......@@ -46,9 +46,11 @@ def set_runtime_environs(environs):
for k, v in environs.items():
os.environ[k] = str(v)
def get_runtime_environ(key):
return os.getenv(key, None)
def get_trainer():
train_mode = get_runtime_environ("train.trainer.trainer")
return train_mode
......@@ -83,6 +85,25 @@ def get_global_envs():
return global_envs
def update_workspace():
workspace = global_envs.get("train.workspace", None)
if not workspace:
return
workspace = ""
# is fleet inner models
if workspace.startswith("fleetrec."):
fleet_package = get_runtime_environ("PACKAGE_BASE")
workspace_dir = workspace.split("fleetrec.")[1].replace(".", "/")
path = os.path.join(fleet_package, workspace_dir)
else:
path = workspace
for name, value in global_envs.items():
if isinstance(value, str):
value = value.replace("{workspace}", path)
global_envs[name] = value
def pretty_print_envs(envs, header=None):
spacing = 5
max_k = 45
......
......@@ -63,12 +63,9 @@ class Model(ModelBase):
feed_list=self._data_var, capacity=64, use_double_buffer=False, iterable=False)
def net(self):
trainer = envs.get_trainer()
is_distributed = True if trainer == "CtrTrainer" else False
is_distributed = True if envs.get_trainer() == "CtrTrainer" else False
sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number", None, self._namespace)
sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self._namespace)
sparse_feature_dim = 9 if trainer == "CtrTrainer" else sparse_feature_dim
def embedding_layer(input):
emb = fluid.layers.embedding(
......@@ -106,8 +103,7 @@ class Model(ModelBase):
size=2,
act="softmax",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fcs[-1].shape[1]))),
)
scale=1 / math.sqrt(fcs[-1].shape[1]))))
self.predict = predict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册