未验证 提交 c00b77fe 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #72 from wangxiao1021/api

fix #68 #71 and other bugs
......@@ -4,7 +4,7 @@
PaddlePALM (PArallel Learning from Multi-tasks) 是一个灵活,通用且易于使用的NLP大规模预训练和多任务学习框架。 PALM是一个旨在**快速开发高性能NLP模型**的上层框架。
使用PaddlePALM,可以非常轻松灵活的探索具有多种任务辅助训练的“高鲁棒性”阅读理解模型,基于PALM训练的模型[D-Net](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/Research/MRQA2019-D-NET)[EMNLP2019国际阅读理解评测](mrqa .github.io)中夺得冠军。
使用PaddlePALM,可以非常轻松灵活的探索具有多种任务辅助训练的“高鲁棒性”阅读理解模型,基于PALM训练的模型[D-Net](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/Research/MRQA2019-D-NET)[EMNLP2019国际阅读理解评测](https://mrqa.github.io/)中夺得冠军。
<p align="center">
<img src="https://tva1.sinaimg.cn/large/006tNbRwly1gbjkuuwrmlj30hs0hzdh2.jpg" alt="Sample" width="300" height="333">
......@@ -196,10 +196,10 @@ Available pretrain items:
更多实现细节请见示例:
- [Sentiment Classification](https://github.com/PaddlePaddle/PALM/tree/master/examples/classification)
- [Quora Question Pairs matching](https://github.com/PaddlePaddle/PALM/tree/master/examples/matching)
- [Tagging](https://github.com/PaddlePaddle/PALM/tree/master/examples/tagging)
- [SQuAD machine Reading Comprehension](https://github.com/PaddlePaddle/PALM/tree/master/examples/mrc).
- [情感分析](https://github.com/PaddlePaddle/PALM/tree/master/examples/classification)
- [Quora问题相似度匹配](https://github.com/PaddlePaddle/PALM/tree/master/examples/matching)
- [命名实体识别](https://github.com/PaddlePaddle/PALM/tree/master/examples/tagging)
- [类SQuAD机器阅读理解](https://github.com/PaddlePaddle/PALM/tree/master/examples/mrc)
#### 多任务学习
......@@ -218,7 +218,7 @@ multi_head_trainer的保存/加载和预测操作与trainer相同。
更多实现`multi_head_trainer`的细节,请见
- [ATIS: joint training of dialogue intent recognition and slot filling](https://github.com/PaddlePaddle/PALM/tree/master/examples/multi-task)
- [ATIS: 对话意图识别和插槽填充的联合训练](https://github.com/PaddlePaddle/PALM/tree/master/examples/multi-task)
#### 设置saver
......
# coding=utf-8
import paddlepalm as palm
import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__':
......
# coding=utf-8
import paddlepalm as palm
import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__':
......
# coding=utf-8
import paddlepalm as palm
import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__':
......
# coding=utf-8
import paddlepalm as palm
import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__':
......@@ -80,4 +79,4 @@ if __name__ == '__main__':
# save_steps = 10
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# step 8-3: start training
trainer.train(print_steps=print_steps)
\ No newline at end of file
trainer.train(print_steps=print_steps)
# coding=utf-8
import paddlepalm as palm
import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__':
......
# coding=utf-8
import paddlepalm as palm
import json
from paddlepalm.distribute import gpu_dev_count
if __name__ == '__main__':
......@@ -64,9 +63,9 @@ if __name__ == '__main__':
# step 7: fit prepared reader and data
trainer.fit_reader(seq_label_reader)
# # step 8-1*: load pretrained parameters
# step 8-1*: load pretrained parameters
trainer.load_pretrain(pre_params)
# # step 8-2*: set saver to save model
# step 8-2*: set saver to save model
save_steps = 1951
# print('save_steps: {}'.format(save_steps))
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
......
......@@ -98,7 +98,7 @@ class Classify(Head):
raise ValueError('argument output_dir not found in config. Please add it into config dict/file.')
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)):
label = np.argmax(np.array(self._preds[i]))
label = int(np.argmax(np.array(self._preds[i])))
result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]}
result = json.dumps(result)
writer.write(result+'\n')
......
......@@ -179,7 +179,7 @@ class Match(Head):
with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)):
if self._learning_strategy == 'pointwise':
label = np.argmax(np.array(self._preds[i]))
label = int(np.argmax(np.array(self._preds[i])))
result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]}
elif self._learning_strategy == 'pairwise':
result = {'index': i, 'probs': self._preds[i][0]}
......
......@@ -37,7 +37,7 @@ from paddlepalm.reader.utils.mlm_batching import prepare_batch_data
log = logging.getLogger(__name__)
if six.PY3:
if six.PY3 and hasattr(sys.stdout, 'buffer'):
import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册