diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md
index 808d4516c57593210988542150b74c671e41d5da..6aa82800f896e3a1cae8921cf9bc16e15d107216 100644
--- a/examples/seq2seq/README.md
+++ b/examples/seq2seq/README.md
@@ -22,10 +22,22 @@ Sequence to Sequence (Seq2Seq), using the Encoder-Decoder
 
 This directory contains a classic Seq2Seq example: machine translation. It implements a base model (without the attention mechanism) and a translation model with attention. The Seq2Seq translation model mimics how humans approach translation tasks: first parse the source language and understand its meaning, then write the target-language sentence based on that meaning. For more on the principles and mathematical formulation of machine translation, we recommend the [machine translation case study](https://www.paddlepaddle.org.cn/documentation/docs/zh/user_guides/nlp_case/machine_translation/README.cn.html) on the PaddlePaddle website.
 
+
 ## Model Overview
 
 In this model, the encoder is a multi-layer LSTM-based RNN encoder; the decoder is an RNN decoder with an attention mechanism, and a decoder implementation without attention is also provided for comparison. At prediction time we use the beam search algorithm to generate the translated target sentence.
 
+## Getting the Code
+
+Clone the repository locally and set the `PYTHONPATH` environment variable:
+
+```shell
+git clone https://github.com/PaddlePaddle/hapi
+cd hapi
+export PYTHONPATH=$PYTHONPATH:`pwd`
+cd examples/seq2seq
+```
+
 ## Data
 
 This tutorial uses the English-to-Vietnamese portion of the [IWSLT'15 English-Vietnamese data](https://nlp.stanford.edu/projects/nmt/) dataset as the training corpus, with tst2012 as the development set and tst2013 as the test set.
@@ -96,7 +108,7 @@ python train.py \
 
 ```sh
 export CUDA_VISIBLE_DEVICES=0
-python infer.py \
+python predict.py \
     --attention True \
     --src_lang en --tar_lang vi \
     --num_layers 2 \
diff --git a/examples/seq2seq/predict.py b/examples/seq2seq/predict.py
index d1e3e87fddf05d453ed984a49d42fcac0f833cab..930c2e5189469ed174da02e9cf5d6e6e8c2b5a05 100644
--- a/examples/seq2seq/predict.py
+++ b/examples/seq2seq/predict.py
@@ -78,8 +78,6 @@ def do_predict(args):
         dataset=dataset,
         batch_sampler=batch_sampler,
         places=device,
-        feed_list=None
-        if fluid.in_dygraph_mode() else [x.forward() for x in inputs],
         collate_fn=partial(
             prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=eos_id),
         num_workers=0,
@@ -98,7 +96,7 @@ def do_predict(args):
         beam_size=args.beam_size,
         max_out_len=256)
 
-    model.prepare(inputs=inputs)
+    model.prepare(inputs=inputs, device=device)
 
     # load the trained model
     assert args.reload_model, (
diff --git a/examples/seq2seq/train.py b/examples/seq2seq/train.py
index b7dc7698e31b1b5b935a63de66ee632956d3b102..55a31d39ad74686728593824151ac4bdf7b1b1ba 100644
--- a/examples/seq2seq/train.py
+++ b/examples/seq2seq/train.py
@@ -73,7 +73,8 @@ def do_train(args):
         CrossEntropyCriterion(),
         ppl_metric,
         inputs=inputs,
-        labels=labels)
+        labels=labels,
+        device=device)
     model.fit(train_data=train_loader,
               eval_data=eval_loader,
               epochs=args.max_epoch,
diff --git a/examples/transformer/predict.py b/examples/transformer/predict.py
index a6e14314f523d78dee2f770e69a21ae808cd8ad1..f99bf774cb2c9d6ceaa5b4cf69b941f9b2558358 100644
--- a/examples/transformer/predict.py
+++ b/examples/transformer/predict.py
@@ -119,7 +119,7 @@ def do_predict(args):
         args.eos_idx,
         beam_size=args.beam_size,
         max_out_len=args.max_out_len)
-    transformer.prepare(inputs=inputs)
+    transformer.prepare(inputs=inputs, device=device)
 
     # load the trained model
     assert args.init_from_params, (
diff --git a/examples/transformer/train.py b/examples/transformer/train.py
index 94b52b4423839a0d7e01f0243cbb3d0f5907a4b0..39bee1dea46ce459c5f9388ce1d0e08fce914ac4 100644
--- a/examples/transformer/train.py
+++ b/examples/transformer/train.py
@@ -138,7 +138,8 @@ def do_train(args):
             parameter_list=transformer.parameters()),
         CrossEntropyCriterion(args.label_smooth_eps),
         inputs=inputs,
-        labels=labels)
+        labels=labels,
+        device=device)
 
     ## init from some checkpoint, to resume the previous training
     if args.init_from_checkpoint:
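
Reviewer note: the common thread across the four Python files is that the resolved device is now passed explicitly to `Model.prepare` (alongside `inputs`/`labels`), and the prediction `DataLoader` no longer needs a `feed_list` in graph mode. Below is a minimal sketch of that call pattern; it is not part of the patch, and the `resolve_place` helper, the `use_gpu` flag, and the `model`/`inputs` placeholders are illustrative assumptions rather than names from this repository.

```python
# Minimal sketch of the pattern this patch standardizes (not part of the
# diff): resolve the execution place once, reuse it for the DataLoader's
# `places` argument, and pass it to Model.prepare via `device=`.
import paddle.fluid as fluid


def resolve_place(use_gpu):
    # Hypothetical helper; the real scripts build `device` elsewhere.
    return fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()


def prepare_for_predict(model, inputs, use_gpu=True):
    # `model` is assumed to be a hapi Model subclass (e.g. the seq2seq or
    # transformer inference model built earlier in predict.py).
    device = resolve_place(use_gpu)
    model.prepare(inputs=inputs, device=device)  # as added in this patch
    return device
```

Passing the device explicitly, presumably so `prepare` does not have to infer the place on its own, is why the training and prediction scripts of both examples are updated in the same way.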