提交 2555c0e2 编写于 作者: X xixiaoyao

add ernie argument

上级 bbbb7357
......@@ -31,7 +31,7 @@ class ERNIE(Backbone):
def __init__(self, hidden_size, num_hidden_layers, num_attention_heads, vocab_size, \
max_position_embeddings, sent_type_vocab_size, task_type_vocab_size, \
hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, initializer_range, is_pairwise=False, phase='train'):
hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, initializer_range, is_pairwise=False, use_task_emb=True, phase='train'):
# self._is_training = phase == 'train' # backbone一般不用关心运行阶段,因为outputs在任何阶段基本不会变
......@@ -54,6 +54,7 @@ class ERNIE(Backbone):
self._task_emb_name = "task_embedding"
self._emb_dtype = "float32"
self._is_pairwise = is_pairwise
self._use_task_emb = use_task_emb
self._phase=phase
self._param_initializer = fluid.initializer.TruncatedNormal(
scale=initializer_range)
......@@ -85,6 +86,10 @@ class ERNIE(Backbone):
task_type_vocab_size = config['task_type_vocab_size']
else:
task_type_vocab_size = config['type_vocab_size']
if 'use_task_emb' in config:
use_task_emb = config['use_task_emb']
else:
use_task_emb = True
hidden_act = config['hidden_act']
hidden_dropout_prob = config['hidden_dropout_prob']
attention_probs_dropout_prob = config['attention_probs_dropout_prob']
......@@ -96,7 +101,7 @@ class ERNIE(Backbone):
return cls(hidden_size, num_hidden_layers, num_attention_heads, vocab_size, \
max_position_embeddings, sent_type_vocab_size, task_type_vocab_size, \
hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, initializer_range, is_pairwise, phase=phase)
hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, initializer_range, is_pairwise, use_task_emb=use_task_emb, phase=phase)
@property
def inputs_attr(self):
......@@ -180,6 +185,7 @@ class ERNIE(Backbone):
emb_out = emb_out + position_emb_out
emb_out = emb_out + sent_emb_out
if self._use_task_emb:
task_emb_out = fluid.embedding(
task_ids,
size=[self._task_types, self._emb_size],
......
......@@ -122,11 +122,11 @@ class Head(object):
output_dir: 积累结果的保存路径。
"""
if output_dir is not None:
for i in self._results_buffer:
print(i)
else:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
with open(os.path.join(output_dir, self._phase), 'w') as writer:
for i in self._results_buffer:
writer.write(json.dumps(i)+'\n')
else:
return self._results_buffer
......@@ -159,8 +159,6 @@ class Match(Head):
else:
return {'probs': pos_score}
def batch_postprocess(self, rt_outputs):
if not self._is_training:
probs = []
......@@ -171,6 +169,10 @@ class Match(Head):
logits = rt_outputs['logits']
self._preds_logits.extend(logits.tolist())
def reset(self):
self._preds_logits = []
self._preds = []
def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training:
......
......@@ -587,6 +587,9 @@ class Trainer(object):
results = self._pred_head.epoch_postprocess({'reader':reader_outputs}, output_dir=output_dir)
return results
def reset_buffer(self):
self._pred_head.reset()
def _check_phase(self, phase):
assert phase in ['train', 'predict'], "Supported phase: train, predict,"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册