未验证 提交 36d7d800 编写于 作者: S Steffy-zxf 提交者: GitHub

Add nlp module (#433)

* add nlpmodule base class
上级 1e5be079
......@@ -64,4 +64,4 @@ from .finetune.strategy import CombinedStrategy
from .autofinetune.evaluator import report_final_result
from .module.nlp_module import BERTModule
from .module.nlp_module import NLPPredictionModule, TransformerModule
#coding:utf-8
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
......@@ -135,19 +135,27 @@ class Module(object):
if "_is_initialize" in self.__dict__ and self._is_initialize:
return
mod = self.__class__.__module__ + "." + self.__class__.__name__
if mod in _module_runnable_func:
_run_func_name = _module_runnable_func[mod]
self._run_func = getattr(self, _run_func_name)
else:
self._run_func = None
self._serving_func_name = _module_serving_func.get(mod, None)
self._code_version = "v2"
_run_func_name = self._get_func_name(self.__class__,
_module_runnable_func)
self._run_func = getattr(self, _run_func_name)
self._serving_func_name = self._get_func_name(self.__class__,
_module_serving_func)
self._directory = directory
self._initialize(**kwargs)
self._is_initialize = True
self._code_version = "v2"
def _get_func_name(self, current_cls, module_func_dict):
mod = current_cls.__module__ + "." + current_cls.__name__
if mod in module_func_dict:
_func_name = module_func_dict[mod]
return _func_name
elif current_cls.__bases__:
for base_class in current_cls.__bases__:
return self._get_func_name(base_class, module_func_dict)
else:
return None
@classmethod
def init_with_name(cls, name, version=None, **kwargs):
fp_lock = open(os.path.join(CACHE_HOME, name), "a")
......
#coding:utf-8
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
......@@ -17,15 +17,198 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import ast
import json
import os
import re
import six
import paddlehub as hub
import numpy as np
import paddle.fluid as fluid
from paddlehub import logger
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
import paddlehub as hub
from paddlehub.common.logger import logger
from paddlehub.common.utils import sys_stdin_encoding
from paddlehub.io.parser import txt_parser
from paddlehub.module.module import runnable
class DataFormatError(Exception):
    """
    Raised when the input data does not match the expected format
    """

    def __init__(self, *args):
        # Exception.__init__ stores the positional arguments in self.args
        super(DataFormatError, self).__init__(*args)
class NLPBaseModule(hub.Module):
    """
    Base class shared by the NLP modules defined in this file
    """

    def _initialize(self):
        """
        Initialize the module with the necessary elements
        This method must be overridden.
        """
        raise NotImplementedError()

    def get_vocab_path(self):
        """
        Get the path to the vocabulary which was used to pretrain

        Returns:
             self.vocab_path(str): the path to vocabulary
        """
        return self.vocab_path
class NLPPredictionModule(NLPBaseModule):
def _set_config(self):
    """
    predictor config setting

    A CPU predictor is always created from self.pretrained_model_path and
    stored in self.cpu_predictor. A GPU predictor is additionally stored in
    self.gpu_predictor when the CUDA_VISIBLE_DEVICES environment variable is
    set and its first character is a digit.
    """
    cpu_config = AnalysisConfig(self.pretrained_model_path)
    cpu_config.disable_glog_info()
    cpu_config.disable_gpu()
    self.cpu_predictor = create_paddle_predictor(cpu_config)

    try:
        _places = os.environ["CUDA_VISIBLE_DEVICES"]
        int(_places[0])
        use_gpu = True
    except (KeyError, IndexError, ValueError):
        # KeyError: variable unset; IndexError: empty value;
        # ValueError: value does not start with a digit.
        # Anything else (e.g. KeyboardInterrupt) is left to propagate.
        use_gpu = False
    if use_gpu:
        gpu_config = AnalysisConfig(self.pretrained_model_path)
        gpu_config.disable_glog_info()
        gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
        self.gpu_predictor = create_paddle_predictor(gpu_config)
def texts2tensor(self, texts):
    """
    Transform the texts(dict) to PaddleTensor

    Args:
         texts(list): each element is a dict that must have a named 'processed' key whose value is word_ids, such as
                      texts = [{'processed': [23, 89, 43, 906]}]

    Returns:
         tensor(PaddleTensor): tensor with texts data
    """
    word_ids = []
    # Cumulative end offset of every text inside the flattened word_ids
    offsets = [0]
    for text in texts:
        processed = text['processed']
        word_ids += processed
        offsets.append(offsets[-1] + len(processed))

    tensor = PaddleTensor(np.array(word_ids).astype('int64'))
    tensor.name = "words"
    tensor.lod = [offsets]
    tensor.shape = [offsets[-1], 1]
    return tensor
def to_unicode(self, texts):
    """
    Convert each element's type(str) of texts(list) to unicode in python2.7

    Under python3 the list is returned unchanged.

    Args:
         texts(list): each element's type is str in python2.7

    Returns:
         texts(list): each element's type is unicode in python2.7
    """
    if six.PY2:
        unicode_texts = []
        for text in texts:
            # six.binary_type is `str` under python2.7, i.e. exactly the
            # byte strings that still need decoding. The previous check
            # (`not isinstance(text, six.string_types)`) was inverted: it
            # skipped every str and tried to decode non-string elements.
            if isinstance(text, six.binary_type):
                # NOTE(review): the second decode is applied to a unicode
                # object; presumably relies on the interpreter's default
                # encoding - confirm for non-ASCII input.
                unicode_texts.append(
                    text.decode(sys_stdin_encoding()).decode("utf8"))
            else:
                unicode_texts.append(text)
        texts = unicode_texts
    return texts
@runnable
def run_cmd(self, argvs):
    """
    Run as a command

    Args:
         argvs(list): command line arguments handed to the argument parser

    Returns:
         results: the value returned by self.predict, or None when the input
         data check fails (the parser help is printed in that case)
    """
    self.parser = argparse.ArgumentParser(
        description='Run the %s module.' % self.module_name,
        prog='hub run %s' % self.module_name,
        usage='%(prog)s',
        add_help=True)

    self.arg_input_group = self.parser.add_argument_group(
        title="Input options", description="Input data. Required")
    self.arg_config_group = self.parser.add_argument_group(
        title="Config options",
        description=
        "Run configuration for controlling module behavior, not required.")

    self.add_module_config_arg()
    self.add_module_input_arg()

    args = self.parser.parse_args(argvs)

    try:
        input_data = self.check_input_data(args)
    except (DataFormatError, RuntimeError):
        # A tuple is required to catch both types: the expression
        # `DataFormatError and RuntimeError` evaluates to RuntimeError only,
        # which let DataFormatError escape.
        self.parser.print_help()
        return None

    results = self.predict(
        texts=input_data, use_gpu=args.use_gpu, batch_size=args.batch_size)

    return results
class _BERTEmbeddingTask(hub.BaseTask):
def add_module_config_arg(self):
    """
    Add the command config options

    Registers --use_gpu (parsed with ast.literal_eval, default False) and
    --batch_size (int, default 1) on self.arg_config_group.
    """
    config_group = self.arg_config_group
    config_group.add_argument(
        '--use_gpu',
        default=False,
        type=ast.literal_eval,
        help="whether use GPU for prediction")
    config_group.add_argument(
        '--batch_size',
        default=1,
        type=int,
        help="batch size for prediction")
def add_module_input_arg(self):
    """
    Add the command input options

    Registers --input_file and --input_text (both str, default None) on
    self.arg_input_group.
    """
    input_group = self.arg_input_group
    input_group.add_argument(
        '--input_file',
        default=None,
        type=str,
        help="file contain input data")
    input_group.add_argument(
        '--input_text',
        default=None,
        type=str,
        help="text to predict")
def check_input_data(self, args):
    """
    Collect the input texts from the parsed command line arguments

    Args:
         args: parsed arguments providing the input_file and input_text options

    Returns:
         input_data(list): the result of txt_parser.parse for args.input_file,
         or a one-element list holding args.input_text

    Raises:
         RuntimeError: args.input_file is given but the file does not exist
         DataFormatError: no input data could be collected
    """
    input_data = []
    if args.input_file:
        if os.path.exists(args.input_file):
            input_data = txt_parser.parse(args.input_file, use_strip=True)
        else:
            print("File %s is not exist." % args.input_file)
            raise RuntimeError
    elif args.input_text:
        if args.input_text.strip() == '':
            # Whitespace-only text: report here, the empty check below raises
            print("ERROR: The input data is inconsistent with expectations.")
        elif six.PY2:
            input_data = [
                args.input_text.decode(sys_stdin_encoding()).decode("utf8")
            ]
        else:
            input_data = [args.input_text]

    if input_data == []:
        print("ERROR: The input data is inconsistent with expectations.")
        raise DataFormatError

    return input_data
class _TransformerEmbeddingTask(hub.BaseTask):
def __init__(self,
pooled_feature,
seq_feature,
......@@ -33,7 +216,7 @@ class _BERTEmbeddingTask(hub.BaseTask):
data_reader,
config=None):
main_program = pooled_feature.block.program
super(_BERTEmbeddingTask, self).__init__(
super(_TransformerEmbeddingTask, self).__init__(
main_program=main_program,
data_reader=data_reader,
feed_list=feed_list,
......@@ -57,21 +240,10 @@ class _BERTEmbeddingTask(hub.BaseTask):
return results
class BERTModule(hub.Module):
def _initialize(self):
"""
Must override this method.
some member variables are required, others are optional.
"""
# required config
self.MAX_SEQ_LEN = None
self.params_path = None
self.vocab_path = None
# optional config
self.spm_path = None
self.word_dict_path = None
raise NotImplementedError
class TransformerModule(NLPBaseModule):
"""
Transformer module base class that can be used by BERT, ERNIE, RoBERTa and so on.
"""
def init_pretraining_params(self, exe, pretraining_params_path,
main_program):
......@@ -157,7 +329,6 @@ class BERTModule(hub.Module):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
self.init_pretraining_params(
exe, self.params_path, main_program=startup_program)
......@@ -176,7 +347,7 @@ class BERTModule(hub.Module):
def get_embedding(self, texts, use_gpu=False, batch_size=1):
"""
get pooled_output and sequence_output for input texts.
Warnings: this method depends on Paddle Inference Library, it may not work properly in PaddlePaddle < 1.6.2.
Warnings: this method depends on Paddle Inference Library, it may not work properly in PaddlePaddle <= 1.6.2.
Args:
texts (list): each element is a text sample, each sample include text_a and text_b where text_b can be omitted.
......@@ -220,7 +391,7 @@ class BERTModule(hub.Module):
batch_size=batch_size)
self.emb_job = {}
self.emb_job["task"] = _BERTEmbeddingTask(
self.emb_job["task"] = _TransformerEmbeddingTask(
pooled_feature=pooled_feature,
seq_feature=seq_feature,
feed_list=feed_list,
......@@ -233,9 +404,6 @@ class BERTModule(hub.Module):
return self.emb_job["task"].predict(
data=texts, return_result=True, accelerate_mode=True)
def get_vocab_path(self):
    """
    Get the path to the vocabulary

    Returns:
         self.vocab_path: the vocabulary path stored on the module
    """
    vocab_path = self.vocab_path
    return vocab_path
def get_spm_path(self):
if hasattr(self, "spm_path"):
return self.spm_path
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册