Commit ebeb23c2 authored by F frankwhzhang

some question

......@@ -100,6 +100,7 @@ class RunnerBase(object):
        fetch_period = int(
            envs.get_global_env("runner." + context["runner_name"] +
                                ".print_interval", 20))
        scope = context["model"][model_name]["scope"]
        program = context["model"][model_name]["main_program"]
        reader = context["dataset"][reader_name]
......@@ -139,6 +140,9 @@ class RunnerBase(object):
        fetch_period = int(
            envs.get_global_env("runner." + context["runner_name"] +
                                ".print_interval", 20))
        save_step_interval = int(
            envs.get_global_env("runner." + context["runner_name"] +
                                ".save_step_interval", -1))
        if context["is_infer"]:
            metrics = model_class.get_infer_results()
        else:
......@@ -202,6 +206,24 @@ class RunnerBase(object):
                        metrics_logging.insert(1, seconds)
                        begin_time = end_time
                        logging.info(metrics_format.format(*metrics_logging))

                    if save_step_interval >= 1 and batch_id % save_step_interval == 0 and context[
                            "is_infer"] == False:
                        if context["fleet_mode"].upper() == "PS":
                            train_prog = context["model"][model_dict["name"]][
                                "main_program"]
                        else:
                            train_prog = context["model"][model_dict["name"]][
                                "default_main_program"]
                        startup_prog = context["model"][model_dict["name"]][
                            "startup_program"]
                        with fluid.program_guard(train_prog, startup_prog):
                            self.save(
                                context,
                                is_fleet=context["is_fleet"],
                                epoch_id=None,
                                batch_id=batch_id)

                    batch_id += 1
            except fluid.core.EOFException:
                reader.reset()
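
Viewed in isolation, the trigger added to the training loop above is just a modulo check that only fires during training and is disabled by the default interval of -1. A minimal sketch of that behaviour (the function and its `save_fn` argument are illustrative, not part of the commit):

```python
# Minimal sketch of the step-level save trigger; names are illustrative.
def maybe_save_by_step(batch_id, save_step_interval, is_infer, save_fn):
    """Invoke save_fn every `save_step_interval` batches during training.

    The default interval of -1 (or any value < 1) disables step saving,
    and inference runs never trigger a save.
    """
    if is_infer or save_step_interval < 1:
        return
    if batch_id % save_step_interval == 0:
        save_fn(batch_id=batch_id)
```
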
......@@ -314,7 +336,7 @@ class RunnerBase(object):
            exec_strategy=_exe_strategy)
        return program

    def save(self, epoch_id, context, is_fleet=False):
    def save(self, context, is_fleet=False, epoch_id=None, batch_id=None):
        def need_save(epoch_id, epoch_interval, is_last=False):
            name = "runner." + context["runner_name"] + "."
            total_epoch = int(envs.get_global_env(name + "epochs", 1))
......@@ -371,7 +393,8 @@ class RunnerBase(object):
            assert dirname is not None
            dirname = os.path.join(dirname, str(epoch_id))
            logging.info("\tsave epoch_id:%d model into: \"%s\"" %
                         (epoch_id, dirname))
            if is_fleet:
                warnings.warn(
                    "Save inference model in cluster training is not recommended! Using save checkpoint instead.",
......@@ -394,14 +417,35 @@ class RunnerBase(object):
            if dirname is None or dirname == "":
                return
            dirname = os.path.join(dirname, str(epoch_id))
            logging.info("\tsave epoch_id:%d model into: \"%s\"" %
                         (epoch_id, dirname))
            if is_fleet:
                if context["fleet"].worker_index() == 0:
                    context["fleet"].save_persistables(context["exe"], dirname)
            else:
                fluid.io.save_persistables(context["exe"], dirname)

        def save_checkpoint_step():
            name = "runner." + context["runner_name"] + "."
            save_interval = int(
                envs.get_global_env(name + "save_step_interval", -1))
            dirname = envs.get_global_env(name + "save_step_path", None)
            if dirname is None or dirname == "":
                return
            dirname = os.path.join(dirname, str(batch_id))
            logging.info("\tsave batch_id:%d model into: \"%s\"" %
                         (batch_id, dirname))
            if is_fleet:
                if context["fleet"].worker_index() == 0:
                    context["fleet"].save_persistables(context["exe"], dirname)
            else:
                fluid.io.save_persistables(context["exe"], dirname)

        save_persistables()
        save_inference_model()
        if isinstance(epoch_id, int):
            save_persistables()
            save_inference_model()
        if isinstance(batch_id, int):
            save_checkpoint_step()
class SingleRunner(RunnerBase):
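
With the signature change above, save() is now called with keyword arguments from two kinds of call sites: the runners pass epoch_id at the end of each epoch, and the training loop passes batch_id every save_step_interval batches. The two styles, as they appear in this commit:

```python
# Per-epoch save, e.g. as issued by PSRunner once an epoch finishes:
self.save(context=context, is_fleet=True, epoch_id=epoch)

# In-epoch step save, issued from the DataLoader training loop:
self.save(
    context,
    is_fleet=context["is_fleet"],
    epoch_id=None,
    batch_id=batch_id)
```

Only an integer epoch_id reaches save_persistables() and save_inference_model(), and only an integer batch_id reaches the new save_checkpoint_step(), so the two paths do not interfere with each other.
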
......@@ -453,7 +497,7 @@ class SingleRunner(RunnerBase):
                    startup_prog = context["model"][model_dict["name"]][
                        "startup_program"]
                    with fluid.program_guard(train_prog, startup_prog):
                        self.save(epoch, context)
                        self.save(context=context, epoch_id=epoch)
        context["status"] = "terminal_pass"
......@@ -506,7 +550,7 @@ class PSRunner(RunnerBase):
                startup_prog = context["model"][model_dict["name"]][
                    "startup_program"]
                with fluid.program_guard(train_prog, startup_prog):
                    self.save(epoch, context, True)
                    self.save(context=context, is_fleet=True, epoch_id=epoch)
        context["status"] = "terminal_pass"
......@@ -539,7 +583,7 @@ class CollectiveRunner(RunnerBase):
                startup_prog = context["model"][model_dict["name"]][
                    "startup_program"]
                with fluid.program_guard(train_prog, startup_prog):
                    self.save(epoch, context, True)
                    self.save(context=context, is_fleet=True, epoch_id=epoch)
        context["status"] = "terminal_pass"
......
......@@ -20,7 +20,7 @@ import socket
import sys
import six
import traceback
import six
import warnings
global_envs = {}
global_envs_flatten = {}
......@@ -98,6 +98,25 @@ def set_global_envs(envs):
            value = os_path_adapter(workspace_adapter(value))
            global_envs[name] = value

    for runner in envs["runner"]:
        if "save_step_interval" in runner or "save_step_path" in runner:
            phase_name = runner["phases"]
            phase = [
                phase for phase in envs["phase"]
                if phase["name"] == phase_name[0]
            ]
            dataset_name = phase[0].get("dataset_name")
            dataset = [
                dataset for dataset in envs["dataset"]
                if dataset["name"] == dataset_name
            ]
            if dataset[0].get("type") == "QueueDataset":
                runner["save_step_interval"] = None
                runner["save_step_path"] = None
                warnings.warn(
                    "QueueDataset can not support save by step, please not config save_step_interval and save_step_path in your yaml"
                )

    if get_platform() != "LINUX":
        for dataset in envs["dataset"]:
            name = ".".join(["dataset", dataset["name"], "type"])
......
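
Pulled out of set_global_envs, the guard above amounts to the following standalone function (the function name and the plain-dict `envs` argument are illustrative; in the commit the logic lives inline). Step saving is only triggered from the DataLoader training loop, which is consistent with clearing the options when the runner's first phase feeds from a QueueDataset:

```python
import warnings


def _disable_step_save_for_queue_dataset(envs):
    # Illustrative refactor of the inline guard in set_global_envs.
    for runner in envs["runner"]:
        if "save_step_interval" not in runner and "save_step_path" not in runner:
            continue
        # Resolve the dataset behind the runner's first phase.
        first_phase = next(p for p in envs["phase"]
                           if p["name"] == runner["phases"][0])
        dataset = next(d for d in envs["dataset"]
                       if d["name"] == first_phase.get("dataset_name"))
        if dataset.get("type") == "QueueDataset":
            # QueueDataset training has no per-batch Python loop, so the
            # step-save options are cleared and the user is warned.
            runner["save_step_interval"] = None
            runner["save_step_path"] = None
            warnings.warn("QueueDataset can not support save by step, "
                          "please not config save_step_interval and "
                          "save_step_path in your yaml")
```
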
......@@ -27,6 +27,8 @@
| init_model_path | string | path | No | Path used to initialize the model |
| save_checkpoint_interval | int | >= 1 | No | Epoch interval between checkpoint saves |
| save_checkpoint_path | string | path | No | Directory for saved checkpoints |
| save_step_interval | int | >= 1 | No | Batch (step) interval between step saves |
| save_step_path | string | path | No | Directory for step-saved checkpoints |
| save_inference_interval | int | >= 1 | No | Epoch interval between inference-model saves |
| save_inference_path | string | path | No | Directory for saved inference models |
| save_inference_feed_varnames | list[string] | names of Variables in the network | No | Feed (input) variable names of the inference model |
......
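
The two new options mirror save_checkpoint_interval and save_checkpoint_path, except that the output subdirectory is keyed by batch number rather than epoch number, as in the diff above (os.path.join(dirname, str(batch_id))). A small sketch of the resulting layout, assuming the values from the yaml example below:

```python
import os

# With save_step_path: "step_save" and save_step_interval: 1, a step
# checkpoint is written to step_save/<batch_id>/ for every saved batch.
save_step_path = "step_save"  # value taken from the yaml example below
for batch_id in (0, 1, 2):
    print(os.path.join(save_step_path, str(batch_id)))
# step_save/0
# step_save/1
# step_save/2
```
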
......@@ -114,6 +114,23 @@ runner:
  print_interval: 1
  phases: [phase1]

- name: single_multi_gpu_train
  class: train
  # num of epochs
  epochs: 1
  # device to run training or infer
  device: gpu
  selected_gpus: "0,1" # use multiple GPUs for training
  save_checkpoint_interval: 1 # save model interval of epochs
  save_inference_interval: 4 # save inference
  save_step_interval: 1
  save_checkpoint_path: "increment_dnn" # save checkpoint path
  save_inference_path: "inference" # save inference path
  save_step_path: "step_save"
  save_inference_feed_varnames: [] # feed vars of save inference
  save_inference_fetch_varnames: [] # fetch vars of save inference
  print_interval: 1
  phases: [phase1]

# runner will run all the phase in each epoch
phase:
- name: phase1
......