Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Edaker
PaddleHub
提交
9e2c2b34
P
PaddleHub
项目概览
Edaker
/
PaddleHub
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleHub
通知
4
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
9e2c2b34
编写于
4月 04, 2019
作者:
Z
Zeyu Chen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove useless code of ernie finetuning
上级
e0de0d86
变更
12
隐藏空白更改
内联
并排
Showing
12 changed files
with
433 additions
and
101 deletions
+433
-101
demo/ernie-classification/finetune_with_hub.py
demo/ernie-classification/finetune_with_hub.py
+1
-4
demo/ernie-classification/run_fintune_with_hub.sh
demo/ernie-classification/run_fintune_with_hub.sh
+1
-9
demo/ernie-seq-label/finetune_with_hub.py
demo/ernie-seq-label/finetune_with_hub.py
+97
-0
demo/ernie-seq-label/run_fintune_with_hub.sh
demo/ernie-seq-label/run_fintune_with_hub.sh
+11
-0
paddlehub/__init__.py
paddlehub/__init__.py
+1
-0
paddlehub/dataset/msra_ner.py
paddlehub/dataset/msra_ner.py
+20
-2
paddlehub/finetune/finetune.py
paddlehub/finetune/finetune.py
+254
-3
paddlehub/finetune/network.py
paddlehub/finetune/network.py
+37
-3
paddlehub/reader/__init__.py
paddlehub/reader/__init__.py
+2
-0
paddlehub/reader/batching.py
paddlehub/reader/batching.py
+8
-73
paddlehub/reader/nlp_reader.py
paddlehub/reader/nlp_reader.py
+0
-6
paddlehub/version.py
paddlehub/version.py
+1
-1
未找到文件。
demo/ernie-classification/finetune_with_hub.py
浏览文件 @
9e2c2b34
...
...
@@ -36,7 +36,6 @@ parser.add_argument("--data_dir", type=str, default=None, help="Path to training
parser
.
add_argument
(
"--checkpoint_dir"
,
type
=
str
,
default
=
None
,
help
=
"Directory to model checkpoint"
)
parser
.
add_argument
(
"--max_seq_len"
,
type
=
int
,
default
=
512
,
help
=
"Number of words of the longest seqence."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
32
,
help
=
"Total examples' number in batch for training."
)
args
=
parser
.
parse_args
()
# yapf: enable.
...
...
@@ -55,9 +54,7 @@ if __name__ == '__main__':
# loading Paddlehub BERT
module
=
hub
.
Module
(
module_dir
=
args
.
hub_module_dir
)
# Use BERTTokenizeReader to tokenize the dataset according to model's
# vocabulary
reader
=
hub
.
reader
.
BERTTokenizeReader
(
reader
=
hub
.
reader
.
ClassifyReader
(
dataset
=
hub
.
dataset
.
ChnSentiCorp
(),
# download chnsenticorp dataset
vocab_path
=
module
.
get_vocab_path
(),
max_seq_len
=
args
.
max_seq_len
)
...
...
demo/ernie-classification/run_fintune_with_hub.sh
浏览文件 @
9e2c2b34
export
CUDA_VISIBLE_DEVICES
=
5
export
CUDA_VISIBLE_DEVICES
=
3
DATA_PATH
=
./chnsenticorp_data
HUB_MODULE_DIR
=
"./hub_module/bert_chinese_L-12_H-768_A-12.hub_module"
#HUB_MODULE_DIR="./hub_module/ernie_stable.hub_module"
CKPT_DIR
=
"./ckpt"
#rm -rf $CKPT_DIR
python
-u
finetune_with_hub.py
\
--batch_size
32
\
--hub_module_dir
=
$HUB_MODULE_DIR
\
--data_dir
${
DATA_PATH
}
\
--weight_decay
0.01
\
--checkpoint_dir
$CKPT_DIR
\
--num_epoch
3
\
...
...
demo/ernie-seq-label/finetune_with_hub.py
0 → 100644
浏览文件 @
9e2c2b34
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning on sequence labeling tasks."""
# NOTE: docstring fixed — the original said "classification tasks" but this
# demo finetunes ERNIE on MSRA-NER sequence labeling.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import argparse

import numpy as np
import paddle
import paddle.fluid as fluid
import paddlehub as hub

# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=3, help="Number of epoches for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
# NOTE(review): --hub_module_dir is accepted but never read below (the module
# is loaded by name="ernie"); kept for command-line compatibility.
parser.add_argument("--hub_module_dir", type=str, default=None, help="PaddleHub module directory")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
args = parser.parse_args()
# yapf: enable.

if __name__ == '__main__':
    # BERT-style finetune strategy: L2 weight decay + linear warmup/decay LR.
    strategy = hub.BERTFinetuneStrategy(
        weight_decay=args.weight_decay,
        learning_rate=args.learning_rate,
        warmup_strategy="linear_warmup_decay",
    )

    config = hub.RunConfig(
        eval_interval=100,
        use_cuda=True,
        num_epoch=args.num_epoch,
        batch_size=args.batch_size,
        strategy=strategy)

    # Loading PaddleHub ERNIE pretrained model.
    module = hub.Module(name="ernie")

    # Tokenize the MSRA-NER dataset with the module's own vocabulary so that
    # token ids match the pretrained embedding table.
    reader = hub.reader.SequenceLabelReader(
        dataset=hub.dataset.MSRA_NER(),
        vocab_path=module.get_vocab_path(),
        max_seq_len=args.max_seq_len)
    num_labels = len(reader.get_labels())

    input_dict, output_dict, program = module.context(
        sign_name="tokens", trainable=True, max_seq_len=args.max_seq_len)

    with fluid.program_guard(program):
        # Per-token gold labels, padded to max_seq_len, plus the true length.
        label = fluid.layers.data(
            name="label", shape=[args.max_seq_len, 1], dtype='int64')
        seq_len = fluid.layers.data(name="seq_len", shape=[1], dtype='int64')

        # Use "pooled_output" for classification tasks on an entire sentence.
        # Use "sequence_output" for token-level output.
        sequence_output = output_dict["sequence_output"]

        # Setup feed list for data feeder.
        # Must feed all the tensors the pretrained module needs, plus the
        # task-specific label and sequence-length tensors.
        feed_list = [
            input_dict["input_ids"].name,
            input_dict["position_ids"].name,
            input_dict["segment_ids"].name,
            input_dict["input_mask"].name,
            label.name,
            seq_len
        ]

        # Define a sequence labeling finetune task by PaddleHub's API.
        seq_label_task = hub.append_sequence_labeler(
            feature=sequence_output,
            labels=label,
            seq_len=seq_len,
            num_classes=num_labels)

        # Finetune and evaluate by PaddleHub's API;
        # will finish training, evaluation, testing, save model automatically.
        hub.finetune_and_eval(
            task=seq_label_task,
            data_reader=reader,
            feed_list=feed_list,
            config=config)
demo/ernie-seq-label/run_fintune_with_hub.sh
0 → 100644
浏览文件 @
9e2c2b34
# Finetune ERNIE on the MSRA-NER sequence labeling task via PaddleHub.
export CUDA_VISIBLE_DEVICES=3

CKPT_DIR="./ckpt"

# Quote the checkpoint path so a directory containing spaces still works
# (shellcheck SC2086).
python -u finetune_with_hub.py \
                   --batch_size 16 \
                   --weight_decay 0.01 \
                   --checkpoint_dir "$CKPT_DIR" \
                   --num_epoch 3 \
                   --max_seq_len 256 \
                   --learning_rate 5e-5
paddlehub/__init__.py
浏览文件 @
9e2c2b34
...
...
@@ -33,6 +33,7 @@ from .module.manager import default_module_manager
from
.io.type
import
DataType
from
.finetune.network
import
append_mlp_classifier
from
.finetune.network
import
append_sequence_labeler
from
.finetune.finetune
import
finetune_and_eval
from
.finetune.config
import
RunConfig
from
.finetune.task
import
Task
...
...
paddlehub/dataset/msra_ner.py
浏览文件 @
9e2c2b34
...
...
@@ -30,6 +30,8 @@ class MSRA_NER(object):
self
.
_load_label_map
()
self
.
_load_train_examples
()
self
.
_load_test_examples
()
self
.
_load_dev_examples
()
def
_load_label_map
(
self
):
self
.
label_map_file
=
os
.
path
.
join
(
self
.
dataset_dir
,
"label_map.json"
)
...
...
@@ -40,12 +42,28 @@ class MSRA_NER(object):
train_file
=
os
.
path
.
join
(
self
.
dataset_dir
,
"train.tsv"
)
self
.
train_examples
=
self
.
_read_tsv
(
train_file
)
def _load_dev_examples(self):
    """Locate dev.tsv inside the dataset directory and parse it.

    Stores the file path in ``self.dev_file`` and the parsed rows in
    ``self.dev_examples``.
    """
    dev_path = os.path.join(self.dataset_dir, "dev.tsv")
    self.dev_file = dev_path
    self.dev_examples = self._read_tsv(dev_path)
def _load_test_examples(self):
    """Locate test.tsv inside the dataset directory and parse it.

    Stores the file path in ``self.test_file`` and the parsed rows in
    ``self.test_examples``.
    """
    test_path = os.path.join(self.dataset_dir, "test.tsv")
    self.test_file = test_path
    self.test_examples = self._read_tsv(test_path)
def get_train_examples(self):
    """Return the examples loaded from train.tsv."""
    examples = self.train_examples
    return examples
def get_dev_examples(self):
    """Return the examples loaded from dev.tsv."""
    examples = self.dev_examples
    return examples
def get_test_examples(self):
    """Return the examples loaded from test.tsv."""
    examples = self.test_examples
    return examples
def get_labels(self):
    """Return the BIO label set of the MSRA-NER task.

    Bug fix: the original body had an unreachable-code defect — a leftover
    ``return ["0", "1"]`` (binary classification labels) preceded the real
    NER label list, so the NER labels could never be returned.
    """
    return ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"]
def get_label_map(self):
    """Return the label mapping loaded from label_map.json."""
    mapping = self.label_map
    return mapping
def
_read_tsv
(
self
,
input_file
,
quotechar
=
None
):
"""Reads a tab separated value file."""
...
...
paddlehub/finetune/finetune.py
浏览文件 @
9e2c2b34
...
...
@@ -22,6 +22,7 @@ import multiprocessing
import
paddle
import
paddle.fluid
as
fluid
import
numpy
as
np
from
paddlehub.common.logger
import
logger
from
paddlehub.finetune.strategy
import
BERTFinetuneStrategy
,
DefaultStrategy
...
...
@@ -61,7 +62,149 @@ def _do_memory_optimization(task, config):
(
lower_mem
,
upper_mem
,
unit
)),
def _finetune_seq_label_task(task, data_reader, feed_list, config=None, do_eval=False):
    """
    Finetune sequence labeling task; evaluation metrics are F1, precision
    and recall.

    Bug fix: the strategy-dispatch used ``hub.BERTFinetuneStrategy`` /
    ``hub.DefaultStrategy``, but this module imports those classes directly
    (``from paddlehub.finetune.strategy import BERTFinetuneStrategy,
    DefaultStrategy``) and never imports ``hub`` — referencing ``hub.*``
    here would raise NameError at runtime.

    Args:
        task: the sequence labeling Task holding the program and variables.
        data_reader: reader providing ``data_generator(batch_size, phase)``.
        feed_list: names/variables fed to ``fluid.DataFeeder``.
        config: RunConfig with epochs, batch size, intervals and strategy.
        do_eval: when True, run dev/test evaluation at eval_interval and a
            final test evaluation after training.
    """
    main_program = task.main_program()
    startup_program = task.startup_program()
    loss = task.variable("loss")

    num_epoch = config.num_epoch
    batch_size = config.batch_size

    place, dev_count = _get_running_device_info(config)
    with fluid.program_guard(main_program, startup_program):
        exe = fluid.Executor(place=place)
        data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)

        # Select strategy (fixed: use the directly-imported classes).
        if isinstance(config.strategy, BERTFinetuneStrategy):
            # execute() wires warmup/decay LR + weight decay into the program.
            config.strategy.execute(loss, main_program, data_reader, config)
        elif isinstance(config.strategy, DefaultStrategy):
            config.strategy.execute(loss)
        #TODO: add more finetune strategy

        _do_memory_optimization(task, config)

        # Try to restore model training checkpoint.
        current_epoch, global_step = load_checkpoint(config.checkpoint_dir, exe)

        train_time_used = 0
        logger.info("PaddleHub finetune start")

        # Finetune loop
        for epoch in range(current_epoch, num_epoch + 1):
            train_reader = data_reader.data_generator(
                batch_size=batch_size, phase='train')
            num_trained_examples = loss_sum = 0
            for batch in train_reader():
                num_batch_examples = len(batch)
                train_time_begin = time.time()
                loss_v = exe.run(
                    feed=data_feeder.feed(batch), fetch_list=[loss.name])
                train_time_used += time.time() - train_time_begin
                global_step += 1
                num_trained_examples += num_batch_examples
                # loss_v[0] is the mean batch loss; weight by batch size so
                # the running average is per-example.
                loss_sum += loss_v[0] * num_batch_examples

                # Log finetune status.
                if global_step % config.log_interval == 0:
                    avg_loss = loss_sum / num_trained_examples
                    speed = config.log_interval / train_time_used
                    logger.info("step %d: loss=%.5f [step/sec: %.2f]" %
                                (global_step, avg_loss, speed))
                    train_time_used = 0
                    num_trained_examples = loss_sum = 0

                if config.save_ckpt_interval and \
                        global_step % config.save_ckpt_interval == 0:
                    # NOTE: current saved checkpoint machanism is not
                    # completed, it can't restore correct dataset training
                    # status.
                    save_checkpoint(
                        checkpoint_dir=config.checkpoint_dir,
                        current_epoch=epoch,
                        global_step=global_step,
                        exe=exe)

                if do_eval and global_step % config.eval_interval == 0:
                    evaluate_seq_label(
                        task, data_reader, feed_list, phase="dev", config=config)
                    evaluate_seq_label(
                        task, data_reader, feed_list, phase="test", config=config)

        # NOTE: current saved checkpoint machanism is not completed, it can't
        # restore dataset training status.
        save_checkpoint(
            checkpoint_dir=config.checkpoint_dir,
            current_epoch=num_epoch + 1,
            global_step=global_step,
            exe=exe)

        if do_eval:
            evaluate_seq_label(
                task, data_reader, feed_list, phase="test", config=config)
        logger.info("PaddleHub finetune finished.")
def evaluate_seq_label(task, data_reader, feed_list, phase="test", config=None):
    """Evaluate a sequence labeling task on one dataset split.

    Runs the task's inference program over the ``phase`` split, accumulates
    chunk counts via ``chunk_eval`` and logs F1/precision/recall.

    Args:
        task: Task exposing variables "labels", "infers", "seq_len", "loss".
        data_reader: reader providing ``data_generator(batch_size, phase)``.
        feed_list: names/variables fed to ``fluid.DataFeeder``.
        phase: which split to evaluate ("dev" or "test").
        config: RunConfig providing batch_size and device info.
    """
    fetch_list = [
        task.variable("labels").name,
        task.variable("infers").name,
        task.variable("seq_len").name,
        task.variable("loss").name
    ]
    logger.info("Evaluation on {} dataset start".format(phase))
    inference_program = task.inference_program()
    batch_size = config.batch_size
    place, dev_count = _get_running_device_info(config)
    exe = fluid.Executor(place=place)
    with fluid.program_guard(inference_program):
        data_feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
        # NOTE(review): num_eval_examples and acc_sum are initialized but
        # never updated below — dead locals kept to leave code unchanged.
        num_eval_examples = acc_sum = loss_sum = 0
        test_reader = data_reader.data_generator(
            batch_size=batch_size, phase=phase)
        eval_time_begin = time.time()
        eval_step = 0
        total_label, total_infer, total_correct = 0.0, 0.0, 0.0
        for batch in test_reader():
            num_batch_examples = len(batch)
            eval_step += 1
            # Fetched loss is discarded; only the tag tensors are used.
            np_labels, np_infers, np_lens, _ = exe.run(
                feed=data_feeder.feed(batch), fetch_list=fetch_list)
            # NOTE(review): tag_num is hard-coded to 7, matching the 7-label
            # MSRA-NER BIO scheme — confirm if other label sets are used.
            label_num, infer_num, correct_num = chunk_eval(
                np_labels, np_infers, np_lens, 7, dev_count)
            total_infer += infer_num
            total_label += label_num
            total_correct += correct_num
        precision, recall, f1 = calculate_f1(total_label, total_infer,
                                             total_correct)
        eval_time_used = time.time() - eval_time_begin
        eval_speed = eval_step / eval_time_used
    logger.info(
        "[%s evaluation] F1-Score=%f, precision=%f, recall=%f [step/sec: %.2f]"
        % (phase, f1, precision, recall, eval_speed))
def
_finetune_cls_task
(
task
,
data_reader
,
feed_list
,
config
=
None
,
do_eval
=
False
):
main_program
=
task
.
main_program
()
startup_program
=
task
.
startup_program
()
loss
=
task
.
variable
(
"loss"
)
...
...
@@ -175,11 +318,15 @@ def _finetune_model(task, data_reader, feed_list, config=None, do_eval=False):
def finetune_and_eval(task, data_reader, feed_list, config=None):
    """Finetune the task with periodic and final evaluation.

    Dispatches to the sequence-labeling trainer for "sequence_labeling"
    tasks and to the classification trainer for everything else.
    """
    is_seq_label = task.task_type == "sequence_labeling"
    trainer = _finetune_seq_label_task if is_seq_label else _finetune_cls_task
    trainer(task, data_reader, feed_list, config, do_eval=True)
def finetune(task, data_reader, feed_list, config=None):
    """Finetune the task without running evaluation.

    Consistency fix: the original always called ``_finetune_cls_task`` even
    for sequence-labeling tasks, while ``finetune_and_eval`` dispatches on
    ``task.task_type``. Dispatch the same way here so both entry points
    train the same kind of task with the matching trainer.
    """
    if task.task_type == "sequence_labeling":
        _finetune_seq_label_task(task, data_reader, feed_list, config,
                                 do_eval=False)
    else:
        _finetune_cls_task(task, data_reader, feed_list, config, do_eval=False)
def
evaluate
(
task
,
data_reader
,
feed_list
,
phase
=
"test"
,
config
=
None
):
...
...
@@ -217,3 +364,107 @@ def evaluate(task, data_reader, feed_list, phase="test", config=None):
(
phase
,
avg_loss
,
avg_acc
,
eval_speed
))
return
avg_loss
,
avg_acc
,
eval_speed
# Sequence label evaluation functions
def chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count=1):
    """Count gold, predicted and correctly-predicted BIO chunks in a batch.

    Tag encoding: even ids are B- tags, odd ids are I- tags (``tag // 2`` is
    the entity type), and the last id (``tag_num - 1``) is the "O" tag.
    Each sequence is assumed to start with [CLS] and end with [SEP], which
    are skipped (hence the ``+ 1`` start offset and ``lens[i] - 2`` span).

    Improvements over the original: removed a dead ``cur_chunk = {}``
    assignment that was immediately overwritten, removed an unused outer
    ``null_index`` local, and replaced a manual max-loop with ``max()``.

    Args:
        np_labels: ndarray of gold tag ids, flattened over [batch, seq].
        np_infers: ndarray of predicted tag ids, same layout.
        np_lens: ndarray of per-sequence lengths (including [CLS]/[SEP]).
        tag_num: total number of tags; ``tag_num - 1`` is the "O" tag.
        dev_count: number of devices the batch was sharded across.

    Returns:
        Tuple ``(num_label, num_infer, num_correct)`` for ``calculate_f1``.
    """

    def extract_bio_chunk(seq):
        """Extract chunks as ``{"st", "en", "type"}`` dicts from tag ids."""
        chunks = []
        cur_chunk = None
        null_index = tag_num - 1
        for index in range(len(seq)):
            tag = seq[index]
            tag_type = tag // 2  # entity type shared by the B/I tag pair
            tag_pos = tag % 2    # 0 -> B tag, 1 -> I tag
            if tag == null_index:
                # "O" closes any open chunk.
                if cur_chunk is not None:
                    chunks.append(cur_chunk)
                cur_chunk = None
                continue
            if tag_pos == 0:
                # A B-tag always starts a fresh chunk.
                if cur_chunk is not None:
                    chunks.append(cur_chunk)
                cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
            else:
                if cur_chunk is None:
                    # Dangling I-tag: tolerate by opening a chunk.
                    cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
                    continue
                if cur_chunk["type"] == tag_type:
                    cur_chunk["en"] = index + 1
                else:
                    # I-tag of a different type closes and restarts.
                    chunks.append(cur_chunk)
                    cur_chunk = {"st": index, "en": index + 1, "type": tag_type}
        if cur_chunk is not None:
            chunks.append(cur_chunk)
        return chunks

    num_label = 0
    num_infer = 0
    num_correct = 0
    labels = np_labels.reshape([-1]).astype(np.int32).tolist()
    infers = np_infers.reshape([-1]).astype(np.int32).tolist()
    all_lens = np_lens.reshape([dev_count, -1]).astype(np.int32).tolist()

    base_index = 0
    for dev_index in range(dev_count):
        lens = all_lens[dev_index]
        # Sequences on one device are padded to the device-local max length.
        max_len = max(lens) if lens else 0
        for i in range(len(lens)):
            # Skip [CLS] at the front and [SEP] at the end of each sequence.
            seq_st = base_index + i * max_len + 1
            seq_en = seq_st + (lens[i] - 2)
            infer_chunks = extract_bio_chunk(infers[seq_st:seq_en])
            label_chunks = extract_bio_chunk(labels[seq_st:seq_en])
            num_infer += len(infer_chunks)
            num_label += len(label_chunks)

            # Two-pointer sweep: chunks are emitted in increasing start
            # order, so matching starts can be aligned in one pass.
            infer_index = 0
            label_index = 0
            while label_index < len(label_chunks) \
                    and infer_index < len(infer_chunks):
                if infer_chunks[infer_index]["st"] \
                        < label_chunks[label_index]["st"]:
                    infer_index += 1
                elif infer_chunks[infer_index]["st"] \
                        > label_chunks[label_index]["st"]:
                    label_index += 1
                else:
                    # Same start: correct only if end and type also match.
                    if infer_chunks[infer_index]["en"] \
                            == label_chunks[label_index]["en"] \
                            and infer_chunks[infer_index]["type"] \
                            == label_chunks[label_index]["type"]:
                        num_correct += 1
                    infer_index += 1
                    label_index += 1
        base_index += max_len * len(lens)

    return num_label, num_infer, num_correct
def calculate_f1(num_label, num_infer, num_correct):
    """Compute precision, recall and F1 from chunk counts.

    Zero denominators yield 0.0 for the affected metric; F1 is 0.0 when
    there are no correct chunks (which also avoids dividing by zero).
    """
    precision = num_correct * 1.0 / num_infer if num_infer else 0.0
    recall = num_correct * 1.0 / num_label if num_label else 0.0
    if num_correct == 0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1
paddlehub/finetune/network.py
浏览文件 @
9e2c2b34
...
...
@@ -55,7 +55,6 @@ def append_mlp_classifier(feature, label, num_classes=2, hidden_units=None):
accuracy
=
fluid
.
layers
.
accuracy
(
input
=
probs
,
label
=
label
,
total
=
num_example
)
# TODO: encapsulate to Task
graph_var_dict
=
{
"loss"
:
loss
,
"probs"
:
probs
,
...
...
@@ -77,5 +76,40 @@ def append_mlp_multi_classifier(feature,
pass
def
append_sequence_labler
(
feature
,
label
):
pass
def append_sequence_labeler(feature, labels, seq_len, num_classes=None):
    """Append a token-level classification head and build a labeling Task.

    Adds a per-token FC projection over ``feature`` and a softmax
    cross-entropy loss against ``labels``, then wraps everything in a
    "sequence_labeling" Task.

    Args:
        feature: token-level features (e.g. a module's "sequence_output").
        labels: gold per-token label tensor.
        seq_len: true sequence-length tensor, passed through to the Task.
        num_classes: number of output tags.
            NOTE(review): the default None would be passed as the fc size
            and presumably fail — callers appear expected to always pass it;
            confirm before relying on the default.

    Returns:
        A Task exposing "loss", "probs", "labels", "infers" and "seq_len".
    """
    # Per-token projection: num_flatten_dims=2 keeps the [batch, seq] axes
    # and projects only the feature axis to num_classes.
    logits = fluid.layers.fc(
        input=feature,
        size=num_classes,
        num_flatten_dims=2,
        param_attr=fluid.ParamAttr(
            name="cls_seq_label_out_w",
            initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
        bias_attr=fluid.ParamAttr(
            name="cls_seq_label_out_b",
            initializer=fluid.initializer.Constant(0.)))

    # Flattened copies of gold labels and argmax predictions, exposed so the
    # evaluator can fetch them per token.
    ret_labels = fluid.layers.reshape(x=labels, shape=[-1, 1])
    ret_infers = fluid.layers.reshape(
        x=fluid.layers.argmax(logits, axis=2), shape=[-1, 1])

    # Flatten [batch, seq, ...] to 2-D so softmax_with_cross_entropy sees
    # one row per token.
    labels = fluid.layers.flatten(labels, axis=2)
    ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
        logits=fluid.layers.flatten(logits, axis=2),
        label=labels,
        return_softmax=True)
    loss = fluid.layers.mean(x=ce_loss)

    # accuracy = fluid.layers.accuracy(
    #     input=probs, label=labels, total=num_example)

    graph_var_dict = {
        "loss": loss,
        "probs": probs,
        "labels": ret_labels,
        "infers": ret_infers,
        "seq_len": seq_len
    }

    task = Task("sequence_labeling", graph_var_dict,
                fluid.default_main_program(),
                fluid.default_startup_program())

    return task
paddlehub/reader/__init__.py
浏览文件 @
9e2c2b34
...
...
@@ -13,3 +13,5 @@
# limitations under the License.
from
.nlp_reader
import
BERTTokenizeReader
from
.task_reader
import
ClassifyReader
from
.task_reader
import
SequenceLabelReader
paddlehub/reader/batching.py
浏览文件 @
9e2c2b34
...
...
@@ -20,63 +20,8 @@ from __future__ import print_function
import
numpy
as
np
def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
    """
    Add mask for batch_tokens, return out, mask_label, mask_pos;
    Note: mask_pos responding the batch_tokens after padded;

    Mutates each sentence in ``batch_tokens`` in place (masked / replaced
    token ids are written back into the input lists).

    Args:
        batch_tokens: list of token-id lists, one per sentence; the first
            token of each sentence is expected to be [CLS] (see the low=1
            note below).
        total_token_num: total token count over the whole batch; must match
            sum(len(sent)) — NOTE(review): not validated here, confirm at
            the caller.
        vocab_size: upper bound (exclusive) for random replacement ids.
        CLS, SEP, MASK: special token ids; [CLS]/[SEP] are never masked.

    Returns:
        (batch_tokens, mask_label, mask_pos) where mask_label/mask_pos are
        int64 column vectors of the original ids and their positions in the
        padded [batch, max_len] layout.
    """
    max_len = max([len(sent) for sent in batch_tokens])
    mask_label = []
    mask_pos = []
    # One uniform draw per token decides mask / replace / keep below.
    prob_mask = np.random.rand(total_token_num)
    # Note: the first token is [CLS], so [low=1]
    replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
    pre_sent_len = 0
    prob_index = 0
    for sent_index, sent in enumerate(batch_tokens):
        mask_flag = False
        # Advance into prob_mask by the previous sentence's length so each
        # token gets its own pre-drawn probability.
        prob_index += pre_sent_len
        for token_index, token in enumerate(sent):
            prob = prob_mask[prob_index + token_index]
            if prob > 0.15:
                continue
            elif 0.03 < prob <= 0.15:
                # mask (12% of tokens overall)
                if token != SEP and token != CLS:
                    mask_label.append(sent[token_index])
                    sent[token_index] = MASK
                    mask_flag = True
                    mask_pos.append(sent_index * max_len + token_index)
            elif 0.015 < prob <= 0.03:
                # random replace (1.5% of tokens overall)
                if token != SEP and token != CLS:
                    mask_label.append(sent[token_index])
                    sent[token_index] = replace_ids[prob_index + token_index]
                    mask_flag = True
                    mask_pos.append(sent_index * max_len + token_index)
            else:
                # keep the original token (but still predict it)
                if token != SEP and token != CLS:
                    mask_label.append(sent[token_index])
                    mask_pos.append(sent_index * max_len + token_index)
        pre_sent_len = len(sent)

        # ensure at least mask one word in a sentence
        # NOTE(review): if a sentence has no maskable token between its
        # [CLS]/[SEP] boundaries this loop would never terminate — confirm
        # callers guarantee at least one ordinary token.
        while not mask_flag:
            token_index = int(
                np.random.randint(1, high=len(sent) - 1, size=1))
            if sent[token_index] != SEP and sent[token_index] != CLS:
                mask_label.append(sent[token_index])
                sent[token_index] = MASK
                mask_flag = True
                mask_pos.append(sent_index * max_len + token_index)
    mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
    mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
    return batch_tokens, mask_label, mask_pos
def
prepare_batch_data
(
insts
,
total_token_num
,
voc_size
=
0
,
max_seq_len
=
128
,
pad_id
=
None
,
cls_id
=
None
,
...
...
@@ -103,17 +48,7 @@ def prepare_batch_data(insts,
labels
=
np
.
array
(
labels
).
astype
(
"int64"
).
reshape
([
-
1
,
1
])
labels_list
.
append
(
labels
)
# First step: do mask without padding
if
mask_id
>=
0
:
out
,
mask_label
,
mask_pos
=
mask
(
batch_src_ids
,
total_token_num
,
vocab_size
=
voc_size
,
CLS
=
cls_id
,
SEP
=
sep_id
,
MASK
=
mask_id
)
else
:
out
=
batch_src_ids
out
=
batch_src_ids
# Second step: padding
src_id
,
self_input_mask
=
pad_batch_data
(
out
,
pad_idx
=
pad_id
,
max_seq_len
=
max_seq_len
,
return_input_mask
=
True
)
...
...
@@ -130,12 +65,7 @@ def prepare_batch_data(insts,
return_pos
=
False
,
return_input_mask
=
False
)
if
mask_id
>=
0
:
return_list
=
[
src_id
,
pos_id
,
sent_id
,
self_input_mask
,
mask_label
,
mask_pos
]
+
labels_list
else
:
return_list
=
[
src_id
,
pos_id
,
sent_id
,
self_input_mask
]
+
labels_list
return_list
=
[
src_id
,
pos_id
,
sent_id
,
self_input_mask
]
+
labels_list
return
return_list
if
len
(
return_list
)
>
1
else
return_list
[
0
]
...
...
@@ -146,7 +76,8 @@ def pad_batch_data(insts,
return_pos
=
False
,
return_input_mask
=
False
,
return_max_len
=
False
,
return_num_token
=
False
):
return_num_token
=
False
,
return_seq_lens
=
False
):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and input mask.
...
...
@@ -187,4 +118,8 @@ def pad_batch_data(insts,
num_token
+=
len
(
inst
)
return_list
+=
[
num_token
]
if
return_seq_lens
:
seq_lens
=
np
.
array
([
len
(
inst
)
for
inst
in
insts
])
return_list
+=
[
seq_lens
.
astype
(
"int64"
).
reshape
([
-
1
,
1
])]
return
return_list
if
len
(
return_list
)
>
1
else
return_list
[
0
]
paddlehub/reader/nlp_reader.py
浏览文件 @
9e2c2b34
...
...
@@ -83,20 +83,16 @@ class BERTTokenizeReader(object):
def generate_batch_data(self,
                        batch_data,
                        total_token_num,
                        voc_size=-1,
                        mask_id=-1,
                        return_input_mask=True,
                        return_max_len=False,
                        return_num_token=False):
    """Pad and package one tokenized batch via ``prepare_batch_data``.

    NOTE(review): the ``voc_size`` and ``mask_id`` parameters are accepted
    but NOT forwarded — the call below hard-codes ``voc_size=-1`` and
    ``mask_id=-1``. This looks deliberate for finetuning (a non-positive
    mask_id disables MLM masking in prepare_batch_data), but confirm before
    relying on these parameters having any effect.
    """
    return prepare_batch_data(
        batch_data,
        total_token_num,
        voc_size=-1,
        max_seq_len=self.max_seq_len,
        # Special-token ids come from this reader's vocabulary.
        pad_id=self.vocab["[PAD]"],
        cls_id=self.vocab["[CLS]"],
        sep_id=self.vocab["[SEP]"],
        mask_id=-1,
        return_input_mask=return_input_mask,
        return_max_len=return_max_len,
        return_num_token=return_num_token)
...
...
@@ -166,8 +162,6 @@ class BERTTokenizeReader(object):
batch_data
=
self
.
generate_batch_data
(
batch_data
,
total_token_num
,
voc_size
=-
1
,
mask_id
=-
1
,
return_input_mask
=
True
,
return_max_len
=
True
,
return_num_token
=
False
)
...
...
paddlehub/version.py
浏览文件 @
9e2c2b34
...
...
@@ -11,6 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PaddleHub version string """
# Version of the PaddleHub package itself.
hub_version = "0.3.1.alpha"
# Version of the module serialization protocol; bumped independently of
# hub_version when the on-disk module format changes.
module_proto_version = "0.1.0"
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录