提交 6b78ba4e 编写于 作者: X xixiaoyao

fix multi head inference

上级 01a745b7
......@@ -125,27 +125,8 @@ class MultiHeadTrainer(Trainer):
branch_index=task_id_var,
branch_fns=task_fns
)
# self._task_id_var = task_id_var
# self._loss_var = loss_var
# self._fetch_list = [loss_var.name]
if not self._multi_task:
self._init_exe_prog(for_train=False)
# return self.build_forward()
# """
# Build computation graph for evaluation and prediction.
# Arguments:
# - pred_backbone: a Backbone object with phase == 'predict'. For evaluating model during training, the predict backbone should keep the same with train backbone.
# - pred_head: a Head object with phase == 'predict'. For evaluating model during training, the predict head should keep the same with train head.
#
# Return:
# - output_vars: dict type. Each value is a computational graph variable(node) argumented by pred_head outputs_attr.
# """
# for i in self._trainers:
# assert i._predict_vars is not None, "{} need to build_predict_forward before "
#
# return output_vars
def merge_inference_readers(self, readers):
......
......@@ -215,16 +215,23 @@ class Trainer(object):
self._pred_name_to_position = pred_name_to_position
self._pred_input_names = pred_input_names
if not self._lock_prog:
pred_prog = fluid.Program()
self._pred_prog = pred_prog
pred_init_prog = fluid.Program()
self._pred_init_prog = pred_init_prog
with fluid.program_guard(pred_prog, pred_init_prog):
pred_net_inputs = reader_helper.create_net_inputs(pred_input_attrs)
pred_bb_output_vars = pred_backbone.build(pred_net_inputs)
self._pred_net_inputs = pred_net_inputs
else:
pred_net_inputs = reader_helper.create_net_inputs(pred_input_attrs)
pred_bb_output_vars = pred_backbone.build(pred_net_inputs)
self._pred_net_inputs = pred_net_inputs
# prepare predict vars for saving inference model
if not self._lock_prog:
with fluid.program_guard(pred_prog, pred_init_prog):
cur_inputs = helper.decode_inputs(pred_net_inputs, self.name)
self._pred_input_name_list, self._pred_input_varname_list = \
......@@ -234,6 +241,15 @@ class Trainer(object):
scope = self.name + '.'
with fluid.unique_name.guard(scope):
output_vars = self._build_head(pred_task_inputs, phase='predict', scope=scope)
else:
cur_inputs = helper.decode_inputs(pred_net_inputs, self.name)
self._pred_input_name_list, self._pred_input_varname_list = \
zip(*[[k, v.name] for k,v in cur_inputs.items()])
pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs}
scope = self.name + '.'
with fluid.unique_name.guard(scope):
output_vars = self._build_head(pred_task_inputs, phase='predict', scope=scope)
if output_vars is not None:
self._pred_fetch_name_list, self._pred_fetch_list = zip(*output_vars.items())
......@@ -385,20 +401,32 @@ class Trainer(object):
"""
assert self._train_init_prog is not None or self._pred_init_prog is not None, "model graph not built. You should at least build_forward or build_predict_forward to load its checkpoint."
# if self._train_init_prog is not None:
# saver.init_pretraining_params(
# self._exe,
# model_path,
# convert=False,
# main_program=self._train_init_prog,
# strict=True)
# elif self._pred_init_prog is not None:
# saver.init_pretraining_params(
# self._exe,
# model_path,
# convert=False,
# main_program=self._pred_init_prog,
# strict=True)
if self._train_init_prog is not None:
saver.init_pretraining_params(
print('loading checkpoint into train program')
saver.init_checkpoint(
self._exe,
model_path,
convert=False,
main_program=self._train_init_prog,
strict=True)
main_program=self._train_init_prog)
elif self._pred_init_prog is not None:
saver.init_pretraining_params(
saver.init_checkpoint(
self._exe,
model_path,
convert=False,
main_program=self._pred_init_prog,
strict=True)
main_program=self._pred_init_prog)
else:
raise Exception("model not found. You should at least build_forward or build_predict_forward to load its checkpoint.")
......@@ -529,6 +557,7 @@ class Trainer(object):
iterator = self._predict_iterator
self._distribute_pred_prog = fluid.CompiledProgram(self._pred_prog).with_data_parallel()
if output_dir is not None and not os.path.exists(output_dir):
os.makedirs(output_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册