Unverified commit 241c0282, authored by Meiyim, committed by GitHub

Dygraph fix4 (#464)

* upgrade version no

* seq2seq + length penalty

* upgrade to paddle 1.8

* fix readme

* update seq2seq to 1.8

* fix seq2seq beam-search

* + AIStudio tutorial: loading old-style checkpoint

* add missing seq2seq eval file for cnndm

* fix seq2seq decode post process

* + handlers
Parent 655cf256
......@@ -69,10 +69,11 @@ Don't have GPU? try ERNIE in [AIStudio](https://aistudio.baidu.com/aistudio/inde
(please choose the latest version and apply for a GPU environment)
1. [ERNIE for beginners](https://aistudio.baidu.com/studio/edu/group/quick/join/314947)
1. [Semantic Analysis](https://aistudio.baidu.com/aistudio/projectdetail/427482)
2. [Cloze Test](https://aistudio.baidu.com/aistudio/projectdetail/433491)
3. [Knowledge Distillation](https://aistudio.baidu.com/aistudio/projectdetail/439460)
4. [Ask Ernie](https://aistudio.baidu.com/aistudio/projectdetail/456443)
1. [Semantic analysis](https://aistudio.baidu.com/aistudio/projectdetail/427482)
2. [Cloze test](https://aistudio.baidu.com/aistudio/projectdetail/433491)
3. [Knowledge distillation](https://aistudio.baidu.com/aistudio/projectdetail/439460)
4. [Ask ERNIE](https://aistudio.baidu.com/aistudio/projectdetail/456443)
5. [Loading old-style checkpoints](https://aistudio.baidu.com/aistudio/projectdetail/493415)
# Setup
......
......@@ -70,6 +70,7 @@ print(pooled.numpy()) # convert results to numpy
2. [Cloze test](https://aistudio.baidu.com/aistudio/projectdetail/433491)
3. [Knowledge distillation](https://aistudio.baidu.com/aistudio/projectdetail/439460)
4. [Ask ERNIE](https://aistudio.baidu.com/aistudio/projectdetail/456443)
5. [Loading old-style checkpoints](https://aistudio.baidu.com/aistudio/projectdetail/493415)
# Setup
......@@ -230,7 +231,7 @@ sids = np.expand_dims(sids, 0)
result = client(ids, sids)
```
You can also download a pre-built `inference_model` of the ernie-1.0 base model from [here]((https://ernie.bj.bcebos.com/ernie1.0_zh_inference_model.tar.gz.)
You can also download a pre-built `inference_model` of the ernie-1.0 base model from [here](https://ernie.bj.bcebos.com/ernie1.0_zh_inference_model.tar.gz).
This model has not been fine-tuned; it is typically used for feature-based fine-tuning of a model stacked on top, or as a plain text-feature extractor.
Because this model was produced by the old-style API, you need to append an extra dimension to the input tensors when sending client requests:
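A minimal sketch of that extra dimension, reusing the `ids`/`sids` arrays and the `client` callable from the snippet above (shapes illustrative):

```python
import numpy as np

ids = np.expand_dims(ids, -1)    # [batch, seqlen] -> [batch, seqlen, 1]
sids = np.expand_dims(sids, -1)  # the old-style model expects the trailing unit dim
result = client(ids, sids)
```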
......
......@@ -43,7 +43,6 @@ import propeller.paddle as propeller
log.setLevel(logging.DEBUG)
logging.getLogger().addHandler(log.handlers[0])
logging.getLogger().setLevel(logging.DEBUG)
def model_fn(features, mode, params, run_config):
......
......@@ -34,7 +34,6 @@ from propeller import log
import propeller.paddle as propeller
log.setLevel(logging.DEBUG)
logging.getLogger().addHandler(log.handlers[0])
logging.getLogger().setLevel(logging.DEBUG)
......@@ -104,12 +103,12 @@ if __name__ == '__main__':
with FD.guard(place):
model = ErnieModelForSequenceClassification.from_pretrained(args.from_pretrained, num_labels=3, name='')
g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
if args.use_lr_decay:
opt = AdamW(learning_rate=LinearDecay(args.lr, int(args.warmup_proportion * args.max_steps), args.max_steps), parameter_list=model.parameters(), weight_decay=args.wd)
opt = AdamW(learning_rate=LinearDecay(args.lr, int(args.warmup_proportion * args.max_steps), args.max_steps), parameter_list=model.parameters(), weight_decay=args.wd, grad_clip=g_clip)
else:
opt = AdamW(args.lr, parameter_list=model.parameters(), weight_decay=args.wd)
opt = AdamW(args.lr, parameter_list=model.parameters(), weight_decay=args.wd, grad_clip=g_clip)
g_clip = F.dygraph_grad_clip.GradClipByGlobalNorm(1.0) #experimental
for epoch in range(args.epoch):
for step, d in enumerate(tqdm(train_ds.start(place), desc='training')):
ids, sids, label = d
......@@ -117,7 +116,7 @@ if __name__ == '__main__':
loss.backward()
if step % 10 == 0:
log.debug('train loss %.5f lr %.3e' % (loss.numpy(), opt.current_step_lr()))
opt.minimize(loss, grad_clip=g_clip)
opt.minimize(loss)
model.clear_gradients()
if step % 100 == 0:
acc = []
......
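The hunks above capture the Paddle 1.8 migration of gradient clipping: `F.dygraph_grad_clip.GradClipByGlobalNorm` plus `opt.minimize(loss, grad_clip=...)` becomes `F.clip.GradientClipByGlobalNorm` passed once to the optimizer constructor. A minimal sketch of the new pattern, assuming a `model`, a batch `(ids, sids, label)`, and the demo's `AdamW`/`LinearDecay` helpers:

```python
import paddle.fluid as F

# clipping is configured once and attached to the optimizer (Paddle >= 1.8)
g_clip = F.clip.GradientClipByGlobalNorm(1.0)
opt = AdamW(
    learning_rate=LinearDecay(5e-5, 1000, 10000),  # illustrative schedule
    parameter_list=model.parameters(),
    weight_decay=0.01,
    grad_clip=g_clip)

loss, _ = model(ids, sids, labels=label)
loss.backward()
opt.minimize(loss)        # no grad_clip argument here anymore
model.clear_gradients()
```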
......@@ -33,7 +33,6 @@ from propeller import log
import propeller.paddle as propeller
log.setLevel(logging.DEBUG)
logging.getLogger().addHandler(log.handlers[0])
logging.getLogger().setLevel(logging.DEBUG)
......@@ -102,15 +101,22 @@ if __name__ == '__main__':
model = ErnieModelForSequenceClassification.from_pretrained(args.from_pretrained, num_labels=3, name='')
model = FD.parallel.DataParallel(model, ctx)
opt = AdamW(learning_rate=LinearDecay(args.lr, int(args.warmup_proportion * args.max_steps), args.max_steps), parameter_list=model.parameters(), weight_decay=args.wd)
g_clip = F.dygraph_grad_clip.GradClipByGlobalNorm(1.0) #experimental
g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
opt = AdamW(learning_rate=LinearDecay(
args.lr,
int(args.warmup_proportion * args.max_steps),
args.max_steps),
parameter_list=model.parameters(),
weight_decay=args.wd,
grad_clip=g_clip)
for step, d in enumerate(tqdm(train_ds.start(place), desc='training')):
ids, sids, label = d
loss, _ = model(ids, sids, labels=label)
scaled_loss = model.scale_loss(loss)
scaled_loss.backward()
model.apply_collective_grads()
opt.minimize(scaled_loss, grad_clip=g_clip)
opt.minimize(scaled_loss)
model.clear_gradients()
if step % 10 == 0:
log.debug('train loss %.5f, lr %.3e' % (loss.numpy(), opt.current_step_lr()))
......
......@@ -48,7 +48,6 @@ from demo.mrc import mrc_reader
from demo.mrc import mrc_metrics
log.setLevel(logging.DEBUG)
logging.getLogger().addHandler(log.handlers[0])
logging.getLogger().setLevel(logging.DEBUG)
......@@ -84,7 +83,7 @@ def train(model, train_dataset, dev_dataset, dev_examples, dev_features, tokeniz
max_steps = len(train_features) * args.epoch // args.bsz
opt = AdamW(learning_rate=args.lr, parameter_list=model.parameters(), weight_decay=args.wd)
g_clip = F.dygraph_grad_clip.GradClipByGlobalNorm(1.0) #experimental
g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
train_dataset = train_dataset \
.repeat() \
......
......@@ -39,7 +39,6 @@ from propeller import log
import propeller.paddle as propeller
log.setLevel(logging.DEBUG)
logging.getLogger().addHandler(log.handlers[0])
logging.getLogger().setLevel(logging.DEBUG)
from ernie.modeling_ernie import ErnieModel, ErnieModelForSequenceClassification, ErnieModelForTokenClassification
......@@ -127,13 +126,14 @@ if __name__ == '__main__':
test_ds.data_shapes = shapes
test_ds.data_types = types
with FD.guard():
place = F.CUDAPlace(0)
with FD.guard(place):
model = ErnieModelForTokenClassification.from_pretrained(args.from_pretrained, num_labels=7, name='')
opt = AdamW(learning_rate=LinearDecay(args.lr, args.warmup_steps, args.max_steps), parameter_list=model.parameters(), weight_decay=0.01)
#opt = F.optimizer.AdamOptimizer(learning_rate=LinearDecay(args.lr, args.warmup_steps, args.max_steps), parameter_list=model.parameters())
for epoch in range(args.epoch):
for step, (ids, sids, aligned_label, label, orig_pos) in enumerate(tqdm(train_ds.start())):
for step, (ids, sids, aligned_label, label, orig_pos) in enumerate(tqdm(train_ds.start(place))):
loss, _ = model(ids, sids, labels=aligned_label)
loss.backward()
if step % 10 == 0 :
......@@ -144,7 +144,7 @@ if __name__ == '__main__':
all_pred, all_label = [], []
with FD.base._switch_tracer_mode_guard_(is_train=False):
model.eval()
for step, (ids, sids, aligned_label, label, orig_pos) in enumerate(tqdm(dev_ds.start())):
for step, (ids, sids, aligned_label, label, orig_pos) in enumerate(tqdm(dev_ds.start(place))):
loss, logits = model(ids, sids, labels=aligned_label)
#print('\n'.join(map(str, logits.numpy().tolist())))
......
......@@ -34,7 +34,6 @@ from propeller import log
import propeller.paddle as propeller
log.setLevel(logging.DEBUG)
logging.getLogger().addHandler(log.handlers[0])
logging.getLogger().setLevel(logging.DEBUG)
log = logging.getLogger()
......@@ -101,7 +100,7 @@ if __name__ == '__main__':
int(args.warmup_proportion * args.max_steps), args.max_steps),
parameter_list=model.parameters(),
weight_decay=args.wd)
g_clip = F.dygraph_grad_clip.GradClipByGlobalNorm(1.0) #experimental
g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
for epoch in range(args.epoch):
for step, d in enumerate(tqdm(train_ds.start(place), desc='training')):
ids, sids, label = d
......
......@@ -48,7 +48,6 @@ from propeller.paddle.data import Dataset
from propeller import log
log.setLevel(logging.DEBUG)
logging.getLogger().addHandler(log.handlers[0])
logging.getLogger().setLevel(logging.DEBUG)
if six.PY3:
......
......@@ -49,7 +49,6 @@ from propeller.paddle.data import Dataset
from propeller import log
log.setLevel(logging.DEBUG)
logging.getLogger().addHandler(log.handlers[0])
logging.getLogger().setLevel(logging.DEBUG)
if six.PY3:
......
......@@ -100,7 +100,7 @@ teacher_model = ErnieModelForSequenceClassification.from_pretrained('ernie-1.0',
teacher_model.train()
if not os.path.exists('./teacher_model.pdparams'):
opt = AdamW(learning_rate=LinearDecay(LR, 9600*EPOCH*0.1/BATCH, 9600*EPOCH/BATCH), parameter_list=teacher_model.parameters(), weight_decay=0.01)
g_clip = F.dygraph_grad_clip.GradClipByGlobalNorm(1.0)
g_clip = F.clip.GradientClipByGlobalNorm(1.0)
for epoch in range(EPOCH):
for step, (ids_student, ids, sids, labels) in enumerate(train_ds.start(place)):
loss, logits = teacher_model(ids, labels=labels)
......@@ -200,7 +200,7 @@ def KL(pred, target):
teacher_model.eval()
model = BOW()
opt = AdamW(learning_rate=LR, parameter_list=model.parameters(), weight_decay=0.01)
g_clip = F.dygraph_grad_clip.GradClipByGlobalNorm(1.0) #experimental
g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental
model.train()
for epoch in range(EPOCH):
for step, (ids_student, ids, sids, _ ) in enumerate(train_ds.start(place)):
......
......@@ -19,6 +19,7 @@ from __future__ import print_function
from __future__ import unicode_literals
import sys
import re
import argparse
import logging
import json
......@@ -81,7 +82,7 @@ def gen_bias(encoder_inputs, decoder_inputs, step):
@D.no_grad
def greedy_search_infilling(model, q_ids, q_sids, sos_id, eos_id, attn_id, max_encode_len=640, max_decode_len=100):
def greedy_search_infilling(model, q_ids, q_sids, sos_id, eos_id, attn_id, max_encode_len=640, max_decode_len=100, tgt_type_id=3):
model.eval()
#log.debug(q_ids.numpy().tolist())
_, logits, info = model(q_ids, q_sids)
......@@ -104,7 +105,7 @@ def greedy_search_infilling(model, q_ids, q_sids, sos_id, eos_id, attn_id, max_e
bias = gen_bias(q_ids, ids, step)
pos_ids = D.to_variable(np.tile(np.array([[step, step + 1]], dtype=np.int64), [d_batch, 1]))
pos_ids += seqlen
_, logits, info = model(ids, L.ones_like(ids) * 3, pos_ids=pos_ids, attn_bias=bias, past_cache=past_cache)
_, logits, info = model(ids, L.ones_like(ids) * tgt_type_id, pos_ids=pos_ids, attn_bias=bias, past_cache=past_cache)
gen_ids = L.argmax(logits, -1)
past_cached_k, past_cached_v = past_cache
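A hedged call sketch for the updated `greedy_search_infilling` signature, mirroring the commented-out block further down (`ernie`, `encoder_ids`, `encoder_sids`, and `tokenizer` are assumed from the surrounding script):

```python
result_ids = greedy_search_infilling(
    ernie, D.to_variable(encoder_ids), D.to_variable(encoder_sids),
    sos_id=tokenizer.cls_id,
    eos_id=tokenizer.sep_id,
    attn_id=tokenizer.vocab['[ATTN]'],
    max_encode_len=640,
    max_decode_len=120,
    tgt_type_id=3)  # target segment id; 3 in these demos
```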
......@@ -143,30 +144,30 @@ def mask_prob(p, onehot_eos, finished):
return p
def hyp_score(log_probs, length):
factor=1.
lp = L.pow((5.+L.cast(length, 'float32')) / 6., factor)
def hyp_score(log_probs, length, length_penalty):
lp = L.pow((5.+L.cast(length, 'float32')) / 6., length_penalty)
return log_probs / lp
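The new `length_penalty` argument is the GNMT-style normalizer `lp = ((5 + len) / 6) ** alpha`; a quick plain-Python check with illustrative numbers:

```python
def hyp_score(log_prob_sum, length, length_penalty):
    return log_prob_sum / (((5.0 + length) / 6.0) ** length_penalty)

print(hyp_score(-10.0, 10, 1.0))  # -4.0: longer hypotheses are penalized less
print(hyp_score(-10.0, 10, 0.0))  # -10.0: penalty disabled, raw log-prob sum
```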
def beam_search_step(state, logits, eos_id, beam_width, is_first_step):
def beam_search_step(state, logits, eos_id, beam_width, is_first_step, length_penalty):
"""logits.shape == [B*W, V]"""
_, vocab_size = logits.shape
bsz, beam_width = state.log_probs.shape
onehot_eos = L.cast(F.one_hot(L.ones([bsz * beam_width], 'int64') * eos_id, vocab_size), 'int64') #[1, V]
onehot_eos = L.cast(F.one_hot(L.ones([1], 'int64') * eos_id, vocab_size), 'int64') #[1, V]
probs = L.log(L.softmax(logits)) #[B*W, V]
probs = mask_prob(probs, onehot_eos, state.finished) #[B*W, V]
allprobs = L.reshape(state.log_probs, [-1, 1]) + probs #[B*W, V]
length_to_add = 1 - L.reshape(state.finished, [-1, 1]) #[B*W,1]
length_to_add = L.cast((length_to_add + 1 - onehot_eos) != 0, 'int64') #[B*W,V]
not_finished = 1 - L.reshape(state.finished, [-1, 1]) #[B*W,1]
not_eos = 1 - onehot_eos
length_to_add = not_finished * not_eos #[B*W,V]
alllen = L.reshape(state.lengths, [-1, 1]) + length_to_add
allprobs = L.reshape(allprobs, [-1, beam_width * vocab_size])
alllen = L.reshape(alllen, [-1, beam_width * vocab_size])
allscore = hyp_score(allprobs, alllen)
allscore = hyp_score(allprobs, alllen, length_penalty)
if is_first_step:
allscore = L.reshape(allscore, [bsz, beam_width, -1])[:,0,:] # first step: only consider beam 0
scores, idx = L.topk(allscore, k=beam_width) #[B, W]
......@@ -184,6 +185,7 @@ def beam_search_step(state, logits, eos_id, beam_width, is_first_step):
#log.debug(next_finished.numpy())
next_finished += L.cast(next_word_id==eos_id, 'int64')
next_finished = L.cast(next_finished > 0, 'int64')
#log.debug(next_word_id.numpy())
#log.debug(next_beam_id.numpy())
......@@ -194,7 +196,7 @@ def beam_search_step(state, logits, eos_id, beam_width, is_first_step):
@D.no_grad
def beam_search_infilling(model, q_ids, q_sids, sos_id, eos_id, attn_id, max_encode_len=640, max_decode_len=100, beam_width=5, tgt_type_id=3):
def beam_search_infilling(model, q_ids, q_sids, sos_id, eos_id, attn_id, max_encode_len=640, max_decode_len=100, beam_width=5, tgt_type_id=3, length_penalty=1.0):
model.eval()
#log.debug(q_ids.numpy().tolist())
_, __, info = model(q_ids, q_sids)
......@@ -206,6 +208,12 @@ def beam_search_infilling(model, q_ids, q_sids, sos_id, eos_id, attn_id, max_enc
finished=L.zeros([d_batch, beam_width], 'int64'))
outputs = []
def reorder_(t, parent_id):
"""reorder cache according to parent beam id"""
gather_idx = L.where(parent_id!=-1)[:, 0] * beam_width + L.reshape(parent_id, [-1])
t = L.gather(t, gather_idx)
return t
def tile_(t, times):
_shapes = list(t.shape[1:])
ret = L.reshape(L.expand(L.unsqueeze(t, [1]), [1, times,] + [1,] * len(_shapes)), [-1,] + _shapes)
......@@ -230,14 +238,21 @@ def beam_search_infilling(model, q_ids, q_sids, sos_id, eos_id, attn_id, max_enc
pos_ids += seqlen
_, logits, info = model(ids, L.ones_like(ids) * tgt_type_id, pos_ids=pos_ids, attn_bias=bias, past_cache=past_cache)
output, state = beam_search_step(state, logits[:, 1],
eos_id=eos_id,
beam_width=beam_width,
is_first_step=(step==0),
length_penalty=length_penalty)
outputs.append(output)
past_cached_k, past_cached_v = past_cache
cached_k, cached_v = info['caches']
cached_k = [L.concat([pk, k[:, :1, :]], 1) for pk, k in zip(past_cached_k, cached_k)] # concat cached
cached_v = [L.concat([pv, v[:, :1, :]], 1) for pv, v in zip(past_cached_v, cached_v)]
cached_k = [reorder_(L.concat([pk, k[:, :1, :]], 1), output.beam_parent_ids) for pk, k in zip(past_cached_k, cached_k)] # concat cached
cached_v = [reorder_(L.concat([pv, v[:, :1, :]], 1), output.beam_parent_ids) for pv, v in zip(past_cached_v, cached_v)]
past_cache = (cached_k, cached_v)
output, state = beam_search_step(state, logits[:, 1], eos_id=eos_id, beam_width=beam_width, is_first_step=(step==0))
outputs.append(output)
pred_ids_flatten = L.reshape(output.predicted_ids, [d_batch * beam_width])
ids = L.stack([pred_ids_flatten, attn_ids], 1)
......@@ -251,15 +266,19 @@ def beam_search_infilling(model, q_ids, q_sids, sos_id, eos_id, attn_id, max_enc
final_ids = L.gather_tree(final_ids, final_parent_ids)[:,:,0] #pick best beam
final_ids = L.transpose(L.reshape(final_ids, [-1, d_batch * 1]), [1, 0])
return final_ids
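The new `reorder_` helper is the substance of the beam-search fix: each step's attention cache must be permuted to follow the parent beam chosen by `beam_search_step`, otherwise surviving hypotheses keep attending to another beam's history. A plain numpy sketch of the same gather (shapes illustrative):

```python
import numpy as np

beam_width = 2
cache = np.array([[0.1], [0.2], [0.3], [0.4]])  # [batch * beam_width, ...]
parent = np.array([[1, 0], [0, 0]])             # parent beam per (sample, slot)

row = np.arange(parent.shape[0]).repeat(beam_width)  # sample index per slot
gather_idx = row * beam_width + parent.reshape(-1)   # [1, 0, 2, 2]
print(cache[gather_idx])  # sample 0 keeps beams 1,0; sample 1 keeps beam 0 twice
```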
en_pattern = re.compile(r'^[a-zA-Z0-9]*$')
def post_process(token):
if token.startswith('##'):
ret = token[2:]
else:
ret = ' ' + token
if en_pattern.match(token):
ret = ' ' + token
else:
ret = token
return ret
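With the fix, spaces are only inserted before alphanumeric (English) tokens, so Chinese output is no longer broken up by spurious spaces; a small usage sketch, relying on the `post_process` above:

```python
tokens = ['we', '##ather', 'is', 'nice', '天', '气']
print(''.join(map(post_process, tokens)).strip())
# -> 'weather is nice天气' (wordpieces glued, no spaces inside Chinese text)
```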
if __name__ == '__main__':
parser = argparse.ArgumentParser('seq2seq model with ERNIE')
......@@ -268,7 +287,9 @@ if __name__ == '__main__':
parser.add_argument('--max_encode_len', type=int, default=640)
parser.add_argument('--max_decode_len', type=int, default=120)
parser.add_argument('--tgt_type_id', type=int, default=3)
parser.add_argument('--beam_width', type=int, default=3)
parser.add_argument('--beam_width', type=int, default=5)
parser.add_argument('--attn_token', type=str, default='[ATTN]', help='if [ATTN] is not in the vocab, you can specify another special token (e.g. [MASK]) as the attn token')
parser.add_argument('--length_penalty', type=float, default=1.0)
parser.add_argument('--save_dir', type=str, required=True, help='model dir to be loaded')
args = parser.parse_args()
......@@ -282,12 +303,14 @@ if __name__ == '__main__':
rev_dict[tokenizer.pad_id] = '' # replace [PAD]
rev_dict[tokenizer.unk_id] = '' # replace [UNK]
sd, _ = D.load_dygraph(args.save_dir)
ernie.set_dict(sd)
def map_fn(src_ids):
src_ids = src_ids[: args.max_encode_len]
src_ids, src_sids = tokenizer.build_for_ernie(src_ids)
return (src_ids, src_sids)
bytes_vocab = {k.encode('utf8'): v for k, v in tokenizer.vocab.items()}
feature_column = propeller.data.FeatureColumns([
propeller.data.TextColumn('seg_a', unk_id=tokenizer.unk_id, vocab_dict=tokenizer.vocab, tokenizer=tokenizer.tokenize),
])
......@@ -297,17 +320,19 @@ if __name__ == '__main__':
#result_ids = greedy_search_infilling(ernie, D.to_variable(encoder_ids), D.to_variable(encoder_sids),
# eos_id=tokenizer.sep_id,
# sos_id=tokenizer.cls_id,
# attn_id=tokenizer.vocab['[ATTN]'],
# attn_id=tokenizer.vocab[args.attn_token],
# max_decode_len=args.max_decode_len,
# max_encode_len=args.max_encode_len,
# beam_width=args.beam_width)
# beam_width=args.beam_width,
# tgt_type_id=args.tgt_type_id)
result_ids = beam_search_infilling(ernie, D.to_variable(encoder_ids), D.to_variable(encoder_sids),
eos_id=tokenizer.sep_id,
sos_id=tokenizer.cls_id,
attn_id=tokenizer.vocab['[ATTN]'],
attn_id=tokenizer.vocab[args.attn_token],
max_decode_len=args.max_decode_len,
max_encode_len=args.max_encode_len,
beam_width=args.beam_width,
length_penalty=args.length_penalty,
tgt_type_id=args.tgt_type_id)
output_str = rev_lookup(result_ids.numpy())
......@@ -316,5 +341,6 @@ if __name__ == '__main__':
ostr = ostr[: ostr.index('[SEP]')]
ostr = ''.join(map(post_process, ostr))
ostr = ostr.strip()
print(ostr)
#!./python3.6/bin/python
import argparse
from pyrouge import Rouge155
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("home_dir", help="Path of the directory containing ROUGE-1.5.5.pl.",
type=str, action="store")
return parser.parse_args()
def main():
args = get_args()
Rouge155(args.home_dir)
if __name__ == "__main__":
main()
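The new eval helper above only instantiates `Rouge155`, which validates a ROUGE-1.5.5 installation; a hedged sketch of actually scoring decodes with pyrouge (directory layout and filename patterns illustrative):

```python
from pyrouge import Rouge155

r = Rouge155('/path/to/ROUGE-1.5.5')        # dir containing ROUGE-1.5.5.pl
r.system_dir = './decodes'                  # model outputs, one file per example
r.model_dir = './references'                # gold summaries
r.system_filename_pattern = r'(\d+)_decoded.txt'
r.model_filename_pattern = '#ID#_reference.txt'
print(r.convert_and_evaluate())             # prints the ROUGE-1/2/L report
```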
......@@ -68,6 +68,7 @@ def evaluate(model, datasets, step, args):
max_decode_len=args.max_decode_len,
max_encode_len=args.max_encode_len,
beam_width=args.beam_width,
length_penalty=args.length_penalty,
tgt_type_id=args.tgt_type_id,)
output_str = rev_lookup(output_ids.numpy())
for eid, ostr in zip(example_id.numpy().tolist(), output_str.tolist()):
......@@ -228,8 +229,8 @@ def seq2seq(model, tokenizer, args):
vocab_size, _ = model.word_emb.weight.shape
ctx = D.parallel.prepare_context()
model = D.parallel.DataParallel(model, ctx)
opt = AdamW(learning_rate=LinearDecay(args.lr, int(args.warmup_proportion * args.max_steps), args.max_steps), parameter_list=model.parameters(), weight_decay=args.wd)
g_clip = F.dygraph_grad_clip.GradClipByGlobalNorm(1.0)
g_clip = F.clip.GradientClipByGlobalNorm(1.0)
opt = AdamW(learning_rate=LinearDecay(args.lr, int(args.warmup_proportion * args.max_steps), args.max_steps), parameter_list=model.parameters(), weight_decay=args.wd, grad_clip=g_clip)
attn_id = tokenizer.vocab[args.attn_token]
for step, data in enumerate(train_ds.start(place)):
(example_id, src_ids, src_sids, src_pids,
......@@ -253,7 +254,7 @@ def seq2seq(model, tokenizer, args):
scaled_loss = model.scale_loss(loss)
scaled_loss.backward()
model.apply_collective_grads()
opt.minimize(scaled_loss, grad_clip=g_clip)
opt.minimize(scaled_loss)
model.clear_gradients()
if step % 10 == 0:
loss = loss.numpy()
......@@ -261,12 +262,13 @@ def seq2seq(model, tokenizer, args):
log.debug('[step %d]train loss %.5f, ppl %.5f, lr %.3e' % (step, loss, ppl, opt.current_step_lr()))
if args.save_dir is not None and step % 1000 == 0 and D.parallel.Env().dev_id == 0:
F.save_dygraph(model.state_dict(), args.save_dir)
if args.predict_output_dir is not None and (step + 1) % args.eval_steps == 0:
if args.predict_output_dir is not None and step > args.skip_eval_steps and step % args.eval_steps == 0:
assert os.path.exists(args.predict_output_dir), 'predict_output_dir not found: %s' % args.predict_output_dir
log.debug('doing predict on gpu %d...' % D.parallel.Env().dev_id)
evaluate(model, dev_ds, step, args)
if step > args.max_steps:
break
evaluate(model, dev_ds, step, args)
if args.save_dir is not None:
F.save_dygraph(model.state_dict(), args.save_dir)
......@@ -277,10 +279,10 @@ if __name__ == '__main__':
parser.add_argument('--from_pretrained', type=str, required=True, help='pretrained model directory or tag')
parser.add_argument('--bsz', type=int, default=8, help='batchsize')
parser.add_argument('--eval_bsz', type=int, default=20, help='evaluation batch size')
parser.add_argument('--epoch', type=int, default=30, help='epoch')
parser.add_argument('--data_dir', type=str, required=True, help='data directory includes train / develop data')
parser.add_argument('--max_steps', type=int, required=True, help='max_train_steps, set this to EPOCH * NUM_SAMPLES / BATCH_SIZE')
parser.add_argument('--eval_steps', type=int, default=5000, help='evaluation frequency')
parser.add_argument('--skip_eval_steps', type=int, default=1, help='skip evaluation for the first n steps')
parser.add_argument('--max_encode_len', type=int, default=640)
parser.add_argument('--max_decode_len', type=int, default=120)
parser.add_argument('--tgt_type_id', type=int, default=3)
......@@ -290,9 +292,11 @@ if __name__ == '__main__':
parser.add_argument('--use_random_noice', action='store_true', help='if set, replace target tokens with random tokens from the vocabulary; otherwise replace them with `[NOISE]`')
parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
parser.add_argument('--label_smooth', type=float, default=0.1)
parser.add_argument('--length_penalty', type=float, default=1.0)
parser.add_argument('--predict_output_dir', type=str, default=None, help='predict file output directory')
parser.add_argument('--attn_token', type=str, default='[ATTN]', help='if [ATTN] is not in the vocab, you can specify another special token (e.g. [MASK]) as the attn token')
parser.add_argument('--inference_model_dir', type=str, default=None, help='inference model output directory')
parser.add_argument('--init_checkpoint', type=str, default=None)
parser.add_argument('--save_dir', type=str, default=None, help='model output directory')
parser.add_argument('--wd', type=float, default=0.01, help='weight decay, aka L2 regularizer')
......@@ -307,4 +311,9 @@ if __name__ == '__main__':
rev_dict[tokenizer.pad_id] = '' # replace [PAD]
rev_dict[tokenizer.unk_id] = '' # replace [UNK]
if args.init_checkpoint is not None:
log.info('loading checkpoint from %s' % args.init_checkpoint)
sd, _ = D.load_dygraph(args.init_checkpoint)
ernie.set_dict(sd)
seq2seq(ernie, tokenizer, args)
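The new `--init_checkpoint` flag pairs with the periodic `F.save_dygraph` call above; a minimal save/load round trip, assuming the same `model`/`ernie` objects (path illustrative):

```python
# save: writes <prefix>.pdparams next to the given path prefix
F.save_dygraph(model.state_dict(), './checkpoints/seq2seq')

# load: returns (param_state_dict, optimizer_state_dict)
sd, _ = D.load_dygraph('./checkpoints/seq2seq')
ernie.set_dict(sd)
```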
......@@ -22,7 +22,7 @@ with open("README.md", "r", encoding='utf-8') as fh:
setuptools.setup(
name="paddle-ernie", # Replace with your own username
version="0.0.2dev1",
version="0.0.3dev1",
author="Baidu Ernie Team",
author_email="ernieernie.team@gmail.com",
description="A pretrained NLP model for every NLP tasks",
......