提交 8f12e35d 编写于 作者: G guosheng

Fix seq2seq and transformer cpu.

上级 08295d55
......@@ -22,10 +22,22 @@ Sequence to Sequence (Seq2Seq),使用编码器-解码器(Encoder-Decoder)
本目录包含Seq2Seq的一个经典样例:机器翻译,实现了一个base model(不带attention机制),一个带attention机制的翻译模型。Seq2Seq翻译模型,模拟了人类在进行翻译类任务时的行为:先解析源语言,理解其含义,再根据该含义来写出目标语言的语句。更多关于机器翻译的具体原理和数学表达式,我们推荐参考飞桨官网[机器翻译案例](https://www.paddlepaddle.org.cn/documentation/docs/zh/user_guides/nlp_case/machine_translation/README.cn.html)
## 模型概览
本模型中,在编码器方面,我们采用了基于LSTM的多层的RNN encoder;在解码器方面,我们使用了带注意力(Attention)机制的RNN decoder,并同时提供了一个不带注意力机制的解码器实现作为对比。在预测时我们使用柱搜索(beam search)算法来生成翻译的目标语句。
## 代码下载
克隆代码库到本地,并设置`PYTHONPATH`环境变量
```shell
git clone https://github.com/PaddlePaddle/hapi
cd hapi
export PYTHONPATH=$PYTHONPATH:`pwd`
cd examples/seq2seq
```
## 数据介绍
本教程使用[IWSLT'15 English-Vietnamese data ](https://nlp.stanford.edu/projects/nmt/)数据集中的英语到越南语的数据作为训练语料,tst2012的数据作为开发集,tst2013的数据作为测试集
......@@ -96,7 +108,7 @@ python train.py \
```sh
export CUDA_VISIBLE_DEVICES=0
python infer.py \
python predict.py \
--attention True \
--src_lang en --tar_lang vi \
--num_layers 2 \
......
......@@ -78,8 +78,6 @@ def do_predict(args):
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
feed_list=None
if fluid.in_dygraph_mode() else [x.forward() for x in inputs],
collate_fn=partial(
prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=eos_id),
num_workers=0,
......@@ -98,7 +96,7 @@ def do_predict(args):
beam_size=args.beam_size,
max_out_len=256)
model.prepare(inputs=inputs)
model.prepare(inputs=inputs, device=device)
# load the trained model
assert args.reload_model, (
......
......@@ -73,7 +73,8 @@ def do_train(args):
CrossEntropyCriterion(),
ppl_metric,
inputs=inputs,
labels=labels)
labels=labels,
device=device)
model.fit(train_data=train_loader,
eval_data=eval_loader,
epochs=args.max_epoch,
......
......@@ -119,7 +119,7 @@ def do_predict(args):
args.eos_idx,
beam_size=args.beam_size,
max_out_len=args.max_out_len)
transformer.prepare(inputs=inputs)
transformer.prepare(inputs=inputs, device=device)
# load the trained model
assert args.init_from_params, (
......
......@@ -138,7 +138,8 @@ def do_train(args):
parameter_list=transformer.parameters()),
CrossEntropyCriterion(args.label_smooth_eps),
inputs=inputs,
labels=labels)
labels=labels,
device=device)
## init from some checkpoint, to resume the previous training
if args.init_from_checkpoint:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册