From 8f12e35da33d475c212855a8ff4a066ddb586bf9 Mon Sep 17 00:00:00 2001
From: guosheng
Date: Thu, 30 Apr 2020 10:57:12 +0800
Subject: [PATCH] Fix seq2seq and transformer cpu.

---
 examples/seq2seq/README.md      | 14 +++++++++++++-
 examples/seq2seq/predict.py     |  4 +---
 examples/seq2seq/train.py       |  3 ++-
 examples/transformer/predict.py |  2 +-
 examples/transformer/train.py   |  3 ++-
 5 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md
index 808d451..6aa8280 100644
--- a/examples/seq2seq/README.md
+++ b/examples/seq2seq/README.md
@@ -22,10 +22,22 @@ Sequence to Sequence (Seq2Seq) uses an encoder-decoder (Encoder-Decoder) structure
 
 This directory contains a classic Seq2Seq example: machine translation. It implements a base model (without an attention mechanism) and a translation model with attention. The Seq2Seq translation model mimics how humans translate: first parse the source language and understand its meaning, then produce the target-language sentence from that meaning. For more on the principles and mathematical formulation of machine translation, we recommend the [machine translation case study](https://www.paddlepaddle.org.cn/documentation/docs/zh/user_guides/nlp_case/machine_translation/README.cn.html) on the PaddlePaddle website.
 
+
 ## Model Overview
 
 In this model, the encoder is a multi-layer LSTM-based RNN encoder; the decoder is an RNN decoder with an attention mechanism, and a decoder without attention is also provided for comparison. At prediction time, beam search is used to generate the target translation.
 
+## Downloading the Code
+
+Clone the repository locally and set the `PYTHONPATH` environment variable:
+
+```shell
+git clone https://github.com/PaddlePaddle/hapi
+cd hapi
+export PYTHONPATH=$PYTHONPATH:`pwd`
+cd examples/seq2seq
+```
+
 ## Data Description
 
 This tutorial uses the English-to-Vietnamese portion of the [IWSLT'15 English-Vietnamese data](https://nlp.stanford.edu/projects/nmt/) dataset as the training corpus, tst2012 as the development set, and tst2013 as the test set.
@@ -96,7 +108,7 @@ python train.py \
 
 ```sh
 export CUDA_VISIBLE_DEVICES=0
-python infer.py \
+python predict.py \
     --attention True \
     --src_lang en --tar_lang vi \
     --num_layers 2 \
diff --git a/examples/seq2seq/predict.py b/examples/seq2seq/predict.py
index d1e3e87..930c2e5 100644
--- a/examples/seq2seq/predict.py
+++ b/examples/seq2seq/predict.py
@@ -78,8 +78,6 @@ def do_predict(args):
         dataset=dataset,
         batch_sampler=batch_sampler,
         places=device,
-        feed_list=None
-        if fluid.in_dygraph_mode() else [x.forward() for x in inputs],
         collate_fn=partial(
             prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=eos_id),
         num_workers=0,
@@ -98,7 +96,7 @@ def do_predict(args):
         beam_size=args.beam_size,
         max_out_len=256)
 
-    model.prepare(inputs=inputs)
+    model.prepare(inputs=inputs, device=device)
 
     # load the trained model
     assert args.reload_model, (
diff --git a/examples/seq2seq/train.py b/examples/seq2seq/train.py
index b7dc769..55a31d3 100644
--- a/examples/seq2seq/train.py
+++ b/examples/seq2seq/train.py
@@ -73,7 +73,8 @@ def do_train(args):
         CrossEntropyCriterion(),
         ppl_metric,
         inputs=inputs,
-        labels=labels)
+        labels=labels,
+        device=device)
     model.fit(train_data=train_loader,
               eval_data=eval_loader,
               epochs=args.max_epoch,
diff --git a/examples/transformer/predict.py b/examples/transformer/predict.py
index a6e1431..f99bf77 100644
--- a/examples/transformer/predict.py
+++ b/examples/transformer/predict.py
@@ -119,7 +119,7 @@ def do_predict(args):
         args.eos_idx,
         beam_size=args.beam_size,
         max_out_len=args.max_out_len)
-    transformer.prepare(inputs=inputs)
+    transformer.prepare(inputs=inputs, device=device)
 
     # load the trained model
     assert args.init_from_params, (
diff --git a/examples/transformer/train.py b/examples/transformer/train.py
index 94b52b4..39bee1d 100644
--- a/examples/transformer/train.py
+++ b/examples/transformer/train.py
@@ -138,7 +138,8 @@ def do_train(args):
             parameter_list=transformer.parameters()),
         CrossEntropyCriterion(args.label_smooth_eps),
         inputs=inputs,
-        labels=labels)
+        labels=labels,
+        device=device)
 
     ## init from some checkpoint, to resume the previous training
     if args.init_from_checkpoint:
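
Note: every Python change in this patch follows the same pattern — the `device` already resolved at the top of `do_train`/`do_predict` is now passed explicitly to `prepare()`, so a CPU-only run uses the intended place rather than whatever `prepare()` would pick by default. Below is a minimal sketch of that pattern; the `hapi.model.set_device` import path and the `use_cuda` flag are assumptions based on the surrounding hapi examples, and `prepare_on_device` is a hypothetical wrapper, not code from this patch.

```python
# Minimal sketch of the device-handling pattern this patch applies.
# Assumptions (not part of the patch): set_device is importable from
# hapi.model, and a `use_cuda` flag selects GPU vs. CPU as in these examples.
from hapi.model import set_device


def prepare_on_device(model, inputs, labels=None, use_cuda=False):
    # Resolve the run place once ("gpu" or "cpu") and reuse it everywhere,
    # including prepare(); omitting `device` there is what broke CPU-only runs.
    device = set_device("gpu" if use_cuda else "cpu")
    model.prepare(inputs=inputs, labels=labels, device=device)
    return model, device
```

For prediction, `labels` is simply omitted, matching the `model.prepare(inputs=inputs, device=device)` and `transformer.prepare(inputs=inputs, device=device)` calls in the two predict.py hunks above.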