提交 2555c0e2 编写于 作者: X xixiaoyao

Add `use_task_emb` argument to the ERNIE backbone (makes the task-type embedding optional)

上级 bbbb7357
...@@ -31,7 +31,7 @@ class ERNIE(Backbone): ...@@ -31,7 +31,7 @@ class ERNIE(Backbone):
def __init__(self, hidden_size, num_hidden_layers, num_attention_heads, vocab_size, \ def __init__(self, hidden_size, num_hidden_layers, num_attention_heads, vocab_size, \
max_position_embeddings, sent_type_vocab_size, task_type_vocab_size, \ max_position_embeddings, sent_type_vocab_size, task_type_vocab_size, \
hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, initializer_range, is_pairwise=False, phase='train'): hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, initializer_range, is_pairwise=False, use_task_emb=True, phase='train'):
# self._is_training = phase == 'train' # backbone一般不用关心运行阶段,因为outputs在任何阶段基本不会变 # self._is_training = phase == 'train' # backbone一般不用关心运行阶段,因为outputs在任何阶段基本不会变
...@@ -54,6 +54,7 @@ class ERNIE(Backbone): ...@@ -54,6 +54,7 @@ class ERNIE(Backbone):
self._task_emb_name = "task_embedding" self._task_emb_name = "task_embedding"
self._emb_dtype = "float32" self._emb_dtype = "float32"
self._is_pairwise = is_pairwise self._is_pairwise = is_pairwise
self._use_task_emb = use_task_emb
self._phase=phase self._phase=phase
self._param_initializer = fluid.initializer.TruncatedNormal( self._param_initializer = fluid.initializer.TruncatedNormal(
scale=initializer_range) scale=initializer_range)
...@@ -85,6 +86,10 @@ class ERNIE(Backbone): ...@@ -85,6 +86,10 @@ class ERNIE(Backbone):
task_type_vocab_size = config['task_type_vocab_size'] task_type_vocab_size = config['task_type_vocab_size']
else: else:
task_type_vocab_size = config['type_vocab_size'] task_type_vocab_size = config['type_vocab_size']
if 'use_task_emb' in config:
use_task_emb = config['use_task_emb']
else:
use_task_emb = True
hidden_act = config['hidden_act'] hidden_act = config['hidden_act']
hidden_dropout_prob = config['hidden_dropout_prob'] hidden_dropout_prob = config['hidden_dropout_prob']
attention_probs_dropout_prob = config['attention_probs_dropout_prob'] attention_probs_dropout_prob = config['attention_probs_dropout_prob']
...@@ -96,7 +101,7 @@ class ERNIE(Backbone): ...@@ -96,7 +101,7 @@ class ERNIE(Backbone):
return cls(hidden_size, num_hidden_layers, num_attention_heads, vocab_size, \ return cls(hidden_size, num_hidden_layers, num_attention_heads, vocab_size, \
max_position_embeddings, sent_type_vocab_size, task_type_vocab_size, \ max_position_embeddings, sent_type_vocab_size, task_type_vocab_size, \
hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, initializer_range, is_pairwise, phase=phase) hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, initializer_range, is_pairwise, use_task_emb=use_task_emb, phase=phase)
@property @property
def inputs_attr(self): def inputs_attr(self):
...@@ -180,6 +185,7 @@ class ERNIE(Backbone): ...@@ -180,6 +185,7 @@ class ERNIE(Backbone):
emb_out = emb_out + position_emb_out emb_out = emb_out + position_emb_out
emb_out = emb_out + sent_emb_out emb_out = emb_out + sent_emb_out
if self._use_task_emb:
task_emb_out = fluid.embedding( task_emb_out = fluid.embedding(
task_ids, task_ids,
size=[self._task_types, self._emb_size], size=[self._task_types, self._emb_size],
......
...@@ -122,11 +122,11 @@ class Head(object): ...@@ -122,11 +122,11 @@ class Head(object):
output_dir: 积累结果的保存路径。 output_dir: 积累结果的保存路径。
""" """
if output_dir is not None: if output_dir is not None:
for i in self._results_buffer:
print(i)
else:
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
with open(os.path.join(output_dir, self._phase), 'w') as writer: with open(os.path.join(output_dir, self._phase), 'w') as writer:
for i in self._results_buffer: for i in self._results_buffer:
writer.write(json.dumps(i)+'\n') writer.write(json.dumps(i)+'\n')
else:
return self._results_buffer
...@@ -159,8 +159,6 @@ class Match(Head): ...@@ -159,8 +159,6 @@ class Match(Head):
else: else:
return {'probs': pos_score} return {'probs': pos_score}
def batch_postprocess(self, rt_outputs): def batch_postprocess(self, rt_outputs):
if not self._is_training: if not self._is_training:
probs = [] probs = []
...@@ -171,6 +169,10 @@ class Match(Head): ...@@ -171,6 +169,10 @@ class Match(Head):
logits = rt_outputs['logits'] logits = rt_outputs['logits']
self._preds_logits.extend(logits.tolist()) self._preds_logits.extend(logits.tolist())
def reset(self):
    """Discard all prediction results accumulated so far.

    Reinitializes the buffers that ``batch_postprocess`` appends
    predictions and logits to, so the next evaluation/predict pass
    starts from an empty state.
    """
    # Fresh, independent lists — never reuse or alias the old buffers.
    self._preds, self._preds_logits = [], []
def epoch_postprocess(self, post_inputs, output_dir=None): def epoch_postprocess(self, post_inputs, output_dir=None):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training: if not self._is_training:
......
...@@ -587,6 +587,9 @@ class Trainer(object): ...@@ -587,6 +587,9 @@ class Trainer(object):
results = self._pred_head.epoch_postprocess({'reader':reader_outputs}, output_dir=output_dir) results = self._pred_head.epoch_postprocess({'reader':reader_outputs}, output_dir=output_dir)
return results return results
def reset_buffer(self):
    """Clear the prediction head's accumulated result buffers.

    Delegates to the prediction head's ``reset()`` so results buffered
    during a previous predict pass do not leak into the next one.
    """
    head = self._pred_head
    head.reset()
def _check_phase(self, phase): def _check_phase(self, phase):
assert phase in ['train', 'predict'], "Supported phase: train, predict," assert phase in ['train', 'predict'], "Supported phase: train, predict,"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册