未验证 提交 d9d160a2 编写于 作者: K kinghuin 提交者: GitHub

fix ernie_gen bug. and plato and ddparser config (#817)

上级 cc78bd12
......@@ -6,17 +6,18 @@ $ hub run ddparser --input_text="百度是一家高科技公司"
# API
## parse(texts=[])
## parse(texts=[], return\_visual=False)
依存分析接口,输入文本,输出依存关系。
**参数**
* texts(list[list[str] or list[str]]): 待预测数据。各元素可以是未分词的字符串,也可以是已分词的token列表。
* texts(list\[str\] or list\[list\[str\]\]): 待预测数据。各元素可以是未分词的字符串,也可以是已分词的token列表。
* return\_visual(bool): 是否返回依存分析可视化结果。如果为True,返回结果中将包含'visual'字段。
**返回**
* results(list[dict]): 依存分析结果。每个元素都是dict类型,包含以下信息:
* results(list\[dict\]): 依存分析结果。每个元素都是dict类型,包含以下信息:
```python
{
'word': list[str], 分词结果
......@@ -34,9 +35,9 @@ $ hub run ddparser --input_text="百度是一家高科技公司"
**参数**
* word(list[list[str]): 分词信息。
* head(list[int]): 当前成分其支配者的id。
* deprel(list[str]): 当前成分与支配者的依存关系。
* word(list\[str\]): 分词信息。
* head(list\[int\]): 当前成分其支配者的id。
* deprel(list\[str\]): 当前成分与支配者的依存关系。
**返回**
......@@ -55,11 +56,12 @@ results = module.parse(texts=test_text)
print(results)
test_tokens = [['百度', '是', '一家', '高科技', '公司']]
results = module.parse(texts=test_text)
results = module.parse(texts=test_text, return_visual = True)
print(results)
result = results[0]
data = module.visualize(result['word'],result['head'],result['deprel'])
# or data = result['visual']
cv2.imwrite('test.jpg',data)
```
......@@ -81,7 +83,7 @@ Loading ddparser successful.
这样就完成了服务化API的部署,默认端口号为8866。
**NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA_VISIBLE_DEVICES环境变量,否则不用设置。
**NOTE:** 如使用GPU预测,则需要在启动服务之前设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。
## 第二步:发送预测请求
......@@ -105,12 +107,12 @@ data = {"texts": text, "return_visual": return_visual}
url = "http://0.0.0.0:8866/predict/ddparser"
headers = {"Content-Type": "application/json"}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
results, visuals = r.json()['results']
results = r.json()['results']
for i in range(len(results)):
print(results[i])
print(results[i]['word'])
# 不同于本地调用parse接口,serving返回的图像是list类型的,需要先用numpy加载再显示或保存。
cv2.imwrite('%s.jpg'%i, np.array(visuals[i]))
cv2.imwrite('%s.jpg'%i, np.array(results[i]['visual']))
```
关于PaddleHub Serving更多信息参考[服务部署](https://github.com/PaddlePaddle/PaddleHub/blob/release/v1.6/docs/tutorial/serving.md)
......
......@@ -32,15 +32,16 @@ class ddparser(hub.NLPPredictionModule):
"""
self.ddp = DDParserModel(prob=True, use_pos=True)
self.font = font_manager.FontProperties(
fname=os.path.join(self.directory, "SimHei.ttf"))
fname=os.path.join(self.directory, "SourceHanSans-Regular.ttf"))
@serving
def serving_parse(self, texts=[], return_visual=False):
results, visuals = self.parse(texts, return_visual)
for i, visual in enumerate(visuals):
visuals[i] = visual.tolist()
results = self.parse(texts, return_visual)
if return_visual:
for i, result in enumerate(results):
result['visual'] = result['visual'].tolist()
return results, visuals
return results
def parse(self, texts=[], return_visual=False):
"""
......@@ -57,11 +58,9 @@ class ddparser(hub.NLPPredictionModule):
'head': list[int], the head ids.
'deprel': list[str], the dependency relation.
'prob': list[float], the prediction probability of the dependency relation.
'postag': list[str], the POS tag. If the element of the texts is list, the key 'postag' will not be returned.
'postag': list[str], the POS tag. If the element of the texts is list, the key 'postag' will not be returned.
'visual' : numpy.array, the dependency visualization. Use cv2.imshow to show or cv2.imwrite to save it. If return_visual=False, the key 'visual' will not be returned.
}
visuals : list[numpy.array]: the dependency visualization. Use cv2.imshow to show or cv2.imwrite to save it. If return_visual=False, it will not be empty.
"""
if not texts:
......@@ -73,13 +72,11 @@ class ddparser(hub.NLPPredictionModule):
else:
raise ValueError("All of the elements should be string or list")
results = do_parse(texts)
visuals = []
if return_visual:
for result in results:
visuals.append(
self.visualize(result['word'], result['head'],
result['deprel']))
return results, visuals
result['visual'] = self.visualize(
result['word'], result['head'], result['deprel'])
return results
@runnable
def run_cmd(self, argvs):
......@@ -194,10 +191,11 @@ if __name__ == "__main__":
results = module.parse(texts=test_text)
print(results)
test_tokens = [['百度', '是', '一家', '高科技', '公司']]
results = module.parse(texts=test_text)
results = module.parse(texts=test_text, return_visual=True)
print(results)
result = results[0]
data = module.visualize(result['word'], result['head'], result['deprel'])
import cv2
import numpy as np
cv2.imwrite('test.jpg', np.array(data))
cv2.imwrite('test1.jpg', data)
cv2.imwrite('test2.jpg', result['visual'])
......@@ -97,3 +97,7 @@ paddlehub >= 1.7.0
* 1.0.0
初始发布
* 1.0.1
修复windows中的编码问题
......@@ -35,7 +35,7 @@ from ernie_gen_couplet.model.modeling_ernie_gen import ErnieModelForGeneration
@moduleinfo(
name="ernie_gen_couplet",
version="1.0.0",
version="1.0.1",
summary=
"ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning. This module has fine-tuned for couplet generation task.",
author="baidu-nlp",
......@@ -50,10 +50,10 @@ class ErnieGen(hub.NLPPredictionModule):
assets_path = os.path.join(self.directory, "assets")
gen_checkpoint_path = os.path.join(assets_path, "ernie_gen_couplet")
ernie_cfg_path = os.path.join(assets_path, 'ernie_config.json')
with open(ernie_cfg_path) as ernie_cfg_file:
with open(ernie_cfg_path, encoding='utf8') as ernie_cfg_file:
ernie_cfg = dict(json.loads(ernie_cfg_file.read()))
ernie_vocab_path = os.path.join(assets_path, 'vocab.txt')
with open(ernie_vocab_path) as ernie_vocab_file:
with open(ernie_vocab_path, encoding='utf8') as ernie_vocab_file:
ernie_vocab = {
j.strip().split('\t')[0]: i
for i, j in enumerate(ernie_vocab_file.readlines())
......
......@@ -97,3 +97,7 @@ paddlehub >= 1.7.0
* 1.0.0
初始发布
* 1.0.1
修复windows中的编码问题
......@@ -35,7 +35,7 @@ from ernie_gen_poetry.model.modeling_ernie_gen import ErnieModelForGeneration
@moduleinfo(
name="ernie_gen_poetry",
version="1.0.0",
version="1.0.1",
summary=
"ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning. This module has fine-tuned for poetry generation task.",
author="baidu-nlp",
......@@ -50,10 +50,10 @@ class ErnieGen(hub.NLPPredictionModule):
assets_path = os.path.join(self.directory, "assets")
gen_checkpoint_path = os.path.join(assets_path, "ernie_gen_poetry")
ernie_cfg_path = os.path.join(assets_path, 'ernie_config.json')
with open(ernie_cfg_path) as ernie_cfg_file:
with open(ernie_cfg_path, encoding='utf8') as ernie_cfg_file:
ernie_cfg = dict(json.loads(ernie_cfg_file.read()))
ernie_vocab_path = os.path.join(assets_path, 'vocab.txt')
with open(ernie_vocab_path) as ernie_vocab_file:
with open(ernie_vocab_path, encoding='utf8') as ernie_vocab_file:
ernie_vocab = {
j.strip().split('\t')[0]: i
for i, j in enumerate(ernie_vocab_file.readlines())
......
......@@ -10,7 +10,7 @@ PLATO2是一个超大规模生成式对话系统模型。它承袭了PLATO隐变
## 命令行预测
```shell
$ hub run plato2_en_base --input_text="Hello, how are you" --use_gpu
$ hub run plato2_en_base --input_text="Hello, how are you"
```
## API
......
......@@ -7,10 +7,12 @@ PLATO2是一个超大规模生成式对话系统模型。它承袭了PLATO隐变
更多详情参考论文[PLATO-2: Towards Building an Open-Domain Chatbot via Curriculum Learning](https://arxiv.org/abs/2006.16779)
**注:plato2\_en\_large 模型大小12GB,下载时间较长,请耐心等候。运行此模型要求显存至少16GB。**
## 命令行预测
```shell
$ hub run plato2_en_large --input_text="Hello, how are you" --use_gpu
$ hub run plato2_en_large --input_text="Hello, how are you"
```
## API
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册