提交 1fac53aa 编写于 作者: X xyzhou-puck

update nlp models

上级 85e422bb
......@@ -16,14 +16,60 @@
import paddle.fluid as fluid
from hapi.metrics import Accuracy
from hapi.configure import Config
from hapi.text.bert import BertEncoder
from paddle.fluid.dygraph import Linear, Layer
from hapi.model import set_device, Model, SoftmaxWithCrossEntropy, Input
from cls import ClsModelLayer
import hapi.text.tokenizer.tokenization as tokenization
from hapi.text.bert import Optimizer, BertConfig, BertDataLoader, BertInputExample
def train():
class ClsModelLayer(Model):
"""
classify model
"""
def __init__(self,
args,
config,
num_labels,
return_pooled_out=True,
use_fp16=False):
super(ClsModelLayer, self).__init__()
self.config = config
self.use_fp16 = use_fp16
self.loss_scaling = args.loss_scaling
self.bert_layer = BertEncoder(
config=self.config, return_pooled_out=True, use_fp16=self.use_fp16)
self.cls_fc = Linear(
input_dim=self.config["hidden_size"],
output_dim=num_labels,
param_attr=fluid.ParamAttr(
name="cls_out_w",
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(
name="cls_out_b", initializer=fluid.initializer.Constant(0.)))
def forward(self, src_ids, position_ids, sentence_ids, input_mask):
"""
forward
"""
enc_output, next_sent_feat = self.bert_layer(src_ids, position_ids,
sentence_ids, input_mask)
cls_feats = fluid.layers.dropout(
x=next_sent_feat,
dropout_prob=0.1,
dropout_implementation="upscale_in_train")
pred = self.cls_fc(cls_feats)
return pred
def main():
config = Config(yaml_file="./bert.yaml")
config.build()
......@@ -35,8 +81,6 @@ def train():
bert_config = BertConfig(config.bert_config_path)
bert_config.print_config()
trainer_count = fluid.dygraph.parallel.Env().nranks
tokenizer = tokenization.FullTokenizer(
vocab_file=config.vocab_path, do_lower_case=config.do_lower_case)
......@@ -52,14 +96,24 @@ def train():
return BertInputExample(
uid=uid, text_a=text_a, text_b=text_b, label=label)
bert_dataloader = BertDataLoader(
train_dataloader = BertDataLoader(
"./data/glue_data/MNLI/train.tsv",
tokenizer, ["contradiction", "entailment", "neutral"],
max_seq_length=64,
batch_size=32,
max_seq_length=config.max_seq_len,
batch_size=config.batch_size,
line_processor=mnli_line_processor)
num_train_examples = len(bert_dataloader.dataset)
dev_dataloader = BertDataLoader(
"./data/glue_data/MNLI/dev_matched.tsv",
tokenizer, ["contradiction", "entailment", "neutral"],
max_seq_length=config.max_seq_len,
batch_size=config.batch_size,
line_processor=mnli_line_processor,
shuffle=False,
phase="predict")
trainer_count = fluid.dygraph.parallel.Env().nranks
num_train_examples = len(train_dataloader.dataset)
max_train_steps = config.epoch * num_train_examples // config.batch_size // trainer_count
warmup_steps = int(max_train_steps * config.warmup_proportion)
......@@ -82,7 +136,6 @@ def train():
config,
bert_config,
len(["contradiction", "entailment", "neutral"]),
is_training=True,
return_pooled_out=True)
optimizer = Optimizer(
......@@ -106,10 +159,15 @@ def train():
cls_model.bert_layer.init_parameters(
config.init_pretraining_params, verbose=config.verbose)
cls_model.fit(train_data=bert_dataloader.dataloader, epochs=config.epoch)
# do train
cls_model.fit(train_data=train_dataloader.dataloader,
epochs=config.epoch,
save_dir=config.checkpoints)
return cls_model
# do eval
cls_model.evaluate(
eval_data=test_dataloader.dataloader, batch_size=config.batch_size)
if __name__ == '__main__':
cls_model = train()
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"dygraph transformer layers"
import six
import json
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear, Layer
from hapi.text.bert import BertEncoder
from hapi.model import Model
class ClsModelLayer(Model):
"""
classify model
"""
def __init__(self,
args,
config,
num_labels,
is_training=True,
return_pooled_out=True,
use_fp16=False):
super(ClsModelLayer, self).__init__()
self.config = config
self.is_training = is_training
self.use_fp16 = use_fp16
self.loss_scaling = args.loss_scaling
self.bert_layer = BertEncoder(
config=self.config, return_pooled_out=True, use_fp16=self.use_fp16)
self.cls_fc = Linear(
input_dim=self.config["hidden_size"],
output_dim=num_labels,
param_attr=fluid.ParamAttr(
name="cls_out_w",
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(
name="cls_out_b", initializer=fluid.initializer.Constant(0.)))
def forward(self, src_ids, position_ids, sentence_ids, input_mask):
"""
forward
"""
enc_output, next_sent_feat = self.bert_layer(src_ids, position_ids,
sentence_ids, input_mask)
cls_feats = fluid.layers.dropout(
x=next_sent_feat,
dropout_prob=0.1,
dropout_implementation="upscale_in_train")
logits = self.cls_fc(cls_feats)
return logits
......@@ -18,7 +18,7 @@ batch_size: 32
in_tokens: False
do_lower_case: True
random_seed: 5512
use_cuda: False
use_cuda: True
shuffle: True
do_train: True
do_test: True
......
......@@ -16,14 +16,60 @@
import paddle.fluid as fluid
from hapi.metrics import Accuracy
from hapi.configure import Config
from hapi.text.bert import BertEncoder
from paddle.fluid.dygraph import Linear, Layer
from hapi.model import set_device, Model, SoftmaxWithCrossEntropy, Input
from cls import ClsModelLayer
import hapi.text.tokenizer.tokenization as tokenization
from hapi.text.bert import Optimizer, BertConfig, BertDataLoader, BertInputExample
def train():
class ClsModelLayer(Model):
"""
classify model
"""
def __init__(self,
args,
config,
num_labels,
return_pooled_out=True,
use_fp16=False):
super(ClsModelLayer, self).__init__()
self.config = config
self.use_fp16 = use_fp16
self.loss_scaling = args.loss_scaling
self.bert_layer = BertEncoder(
config=self.config, return_pooled_out=True, use_fp16=self.use_fp16)
self.cls_fc = Linear(
input_dim=self.config["hidden_size"],
output_dim=num_labels,
param_attr=fluid.ParamAttr(
name="cls_out_w",
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(
name="cls_out_b", initializer=fluid.initializer.Constant(0.)))
def forward(self, src_ids, position_ids, sentence_ids, input_mask):
"""
forward
"""
enc_output, next_sent_feat = self.bert_layer(src_ids, position_ids,
sentence_ids, input_mask)
cls_feats = fluid.layers.dropout(
x=next_sent_feat,
dropout_prob=0.1,
dropout_implementation="upscale_in_train")
pred = self.cls_fc(cls_feats)
return pred
def main():
config = Config(yaml_file="./bert.yaml")
config.build()
......@@ -35,8 +81,6 @@ def train():
bert_config = BertConfig(config.bert_config_path)
bert_config.print_config()
trainer_count = fluid.dygraph.parallel.Env().nranks
tokenizer = tokenization.FullTokenizer(
vocab_file=config.vocab_path, do_lower_case=config.do_lower_case)
......@@ -52,15 +96,26 @@ def train():
return BertInputExample(
uid=uid, text_a=text_a, text_b=text_b, label=label)
bert_dataloader = BertDataLoader(
train_dataloader = BertDataLoader(
"./data/glue_data/MNLI/train.tsv",
tokenizer, ["contradiction", "entailment", "neutral"],
max_seq_length=64,
batch_size=32,
max_seq_length=config.max_seq_len,
batch_size=config.batch_size,
line_processor=mnli_line_processor,
mode="leveldb")
mode="leveldb",
phase="train")
num_train_examples = len(bert_dataloader.dataset)
dev_dataloader = BertDataLoader(
"./data/glue_data/MNLI/dev_matched.tsv",
tokenizer, ["contradiction", "entailment", "neutral"],
max_seq_length=config.max_seq_len,
batch_size=config.batch_size,
line_processor=mnli_line_processor,
shuffle=False,
phase="predict")
trainer_count = fluid.dygraph.parallel.Env().nranks
num_train_examples = len(train_dataloader.dataset)
max_train_steps = config.epoch * num_train_examples // config.batch_size // trainer_count
warmup_steps = int(max_train_steps * config.warmup_proportion)
......@@ -83,7 +138,6 @@ def train():
config,
bert_config,
len(["contradiction", "entailment", "neutral"]),
is_training=True,
return_pooled_out=True)
optimizer = Optimizer(
......@@ -107,10 +161,15 @@ def train():
cls_model.bert_layer.init_parameters(
config.init_pretraining_params, verbose=config.verbose)
cls_model.fit(train_data=bert_dataloader.dataloader, epochs=config.epoch)
# do train
cls_model.fit(train_data=train_dataloader.dataloader,
epochs=config.epoch,
save_dir=config.checkpoints)
return cls_model
# do eval
cls_model.evaluate(
eval_data=test_dataloader.dataloader, batch_size=config.batch_size)
if __name__ == '__main__':
cls_model = train()
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"dygraph transformer layers"
import six
import json
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear, Layer
from hapi.text.bert import BertEncoder
from hapi.model import Model
class ClsModelLayer(Model):
"""
classify model
"""
def __init__(self,
args,
config,
num_labels,
is_training=True,
return_pooled_out=True,
use_fp16=False):
super(ClsModelLayer, self).__init__()
self.config = config
self.is_training = is_training
self.use_fp16 = use_fp16
self.loss_scaling = args.loss_scaling
self.bert_layer = BertEncoder(
config=self.config, return_pooled_out=True, use_fp16=self.use_fp16)
self.cls_fc = Linear(
input_dim=self.config["hidden_size"],
output_dim=num_labels,
param_attr=fluid.ParamAttr(
name="cls_out_w",
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(
name="cls_out_b", initializer=fluid.initializer.Constant(0.)))
def forward(self, src_ids, position_ids, sentence_ids, input_mask):
"""
forward
"""
enc_output, next_sent_feat = self.bert_layer(src_ids, position_ids,
sentence_ids, input_mask)
cls_feats = fluid.layers.dropout(
x=next_sent_feat,
dropout_prob=0.1,
dropout_implementation="upscale_in_train")
logits = self.cls_fc(cls_feats)
return logits
grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
2020-04-13 13:08:30,568-WARNING: use_shared_memory can only be used in multi-process mode(num_workers > 0), set use_shared_memory as False
W0413 13:08:31.584532 119379 device_context.cc:237] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 9.0
W0413 13:08:31.589192 119379 device_context.cc:245] device: 0, cuDNN Version: 7.5.
----------------------------------------------------------------------
bert_config_path: ./data/pretrained_models/uncased_L-12_H-768_A-12//bert_config.json
init_checkpoint: None
init_pretraining_params: ./data/pretrained_models/uncased_L-12_H-768_A-12//dygraph_params/
checkpoints: ./data/saved_model/mnli_models
epoch: 3
learning_rate: 5e-05
lr_scheduler: linear_warmup_decay
weight_decay: 0.01
warmup_proportion: 0.1
save_steps: 1000
validation_steps: 100
loss_scaling: 1.0
skip_steps: 10
data_dir: ./data/glue_data/MNLI/
vocab_path: ./data/pretrained_models/uncased_L-12_H-768_A-12//vocab.txt
max_seq_len: 128
batch_size: 64
in_tokens: False
do_lower_case: True
random_seed: 5512
use_cuda: True
shuffle: True
do_train: True
do_test: True
use_data_parallel: False
verbose: False
----------------------------------------------------------------------
attention_probs_dropout_prob: 0.1
hidden_act: gelu
hidden_dropout_prob: 0.1
hidden_size: 768
initializer_range: 0.02
intermediate_size: 3072
max_position_embeddings: 512
num_attention_heads: 12
num_hidden_layers: 12
type_vocab_size: 2
vocab_size: 30522
------------------------------------------------
Trainer count: 1
Num train examples: 392703
Max train steps: 18407
Num warmup steps: 1840
Epoch 1/3
step 10/12272 - loss: 1.1000 - acc_top1: 0.3531 - acc_top2: 0.6813 - 1s/step
step 20/12272 - loss: 1.1878 - acc_top1: 0.3578 - acc_top2: 0.6875 - 1s/step
step 30/12272 - loss: 1.0812 - acc_top1: 0.3708 - acc_top2: 0.6948 - 1s/step
step 40/12272 - loss: 1.1244 - acc_top1: 0.3773 - acc_top2: 0.6992 - 1s/step
step 50/12272 - loss: 1.1202 - acc_top1: 0.3756 - acc_top2: 0.7006 - 1s/step
step 60/12272 - loss: 1.1291 - acc_top1: 0.3703 - acc_top2: 0.6990 - 1s/step
step 70/12272 - loss: 1.0991 - acc_top1: 0.3634 - acc_top2: 0.6946 - 1s/step
step 80/12272 - loss: 1.0988 - acc_top1: 0.3602 - acc_top2: 0.6914 - 1s/step
step 90/12272 - loss: 1.0718 - acc_top1: 0.3646 - acc_top2: 0.6889 - 1s/step
step 100/12272 - loss: 1.0949 - acc_top1: 0.3638 - acc_top2: 0.6878 - 1s/step
step 110/12272 - loss: 1.1120 - acc_top1: 0.3608 - acc_top2: 0.6895 - 1s/step
step 120/12272 - loss: 1.1105 - acc_top1: 0.3622 - acc_top2: 0.6922 - 1s/step
step 130/12272 - loss: 1.0958 - acc_top1: 0.3623 - acc_top2: 0.6940 - 1s/step
step 140/12272 - loss: 1.0995 - acc_top1: 0.3636 - acc_top2: 0.6926 - 1s/step
step 150/12272 - loss: 1.1272 - acc_top1: 0.3671 - acc_top2: 0.6950 - 1s/step
step 160/12272 - loss: 1.0850 - acc_top1: 0.3697 - acc_top2: 0.6975 - 1s/step
step 170/12272 - loss: 1.0607 - acc_top1: 0.3691 - acc_top2: 0.6991 - 1s/step
step 180/12272 - loss: 1.0623 - acc_top1: 0.3707 - acc_top2: 0.6991 - 1s/step
step 190/12272 - loss: 1.1092 - acc_top1: 0.3697 - acc_top2: 0.6997 - 1s/step
step 200/12272 - loss: 1.1046 - acc_top1: 0.3713 - acc_top2: 0.7030 - 1s/step
step 210/12272 - loss: 1.0945 - acc_top1: 0.3720 - acc_top2: 0.7043 - 1s/step
step 220/12272 - loss: 1.0935 - acc_top1: 0.3719 - acc_top2: 0.7051 - 1s/step
step 230/12272 - loss: 1.1567 - acc_top1: 0.3742 - acc_top2: 0.7048 - 1s/step
step 240/12272 - loss: 1.0745 - acc_top1: 0.3766 - acc_top2: 0.7081 - 1s/step
step 250/12272 - loss: 1.0664 - acc_top1: 0.3756 - acc_top2: 0.7090 - 1s/step
step 260/12272 - loss: 1.0770 - acc_top1: 0.3751 - acc_top2: 0.7085 - 1s/step
step 270/12272 - loss: 1.1008 - acc_top1: 0.3730 - acc_top2: 0.7088 - 1s/step
step 280/12272 - loss: 1.0850 - acc_top1: 0.3737 - acc_top2: 0.7098 - 1s/step
step 290/12272 - loss: 1.0759 - acc_top1: 0.3747 - acc_top2: 0.7100 - 1s/step
step 300/12272 - loss: 1.0352 - acc_top1: 0.3758 - acc_top2: 0.7108 - 1s/step
step 310/12272 - loss: 1.0224 - acc_top1: 0.3786 - acc_top2: 0.7127 - 1s/step
step 320/12272 - loss: 1.0919 - acc_top1: 0.3800 - acc_top2: 0.7137 - 1s/step
step 330/12272 - loss: 1.0884 - acc_top1: 0.3825 - acc_top2: 0.7145 - 1s/step
step 340/12272 - loss: 1.1380 - acc_top1: 0.3849 - acc_top2: 0.7157 - 1s/step
step 350/12272 - loss: 0.9523 - acc_top1: 0.3890 - acc_top2: 0.7176 - 1s/step
step 360/12272 - loss: 0.9963 - acc_top1: 0.3922 - acc_top2: 0.7191 - 1s/step
step 370/12272 - loss: 1.1187 - acc_top1: 0.3955 - acc_top2: 0.7205 - 1s/step
step 380/12272 - loss: 0.9634 - acc_top1: 0.3988 - acc_top2: 0.7229 - 1s/step
step 390/12272 - loss: 0.9944 - acc_top1: 0.4017 - acc_top2: 0.7254 - 1s/step
step 400/12272 - loss: 1.1071 - acc_top1: 0.4044 - acc_top2: 0.7272 - 1s/step
step 410/12272 - loss: 0.9307 - acc_top1: 0.4070 - acc_top2: 0.7293 - 1s/step
step 420/12272 - loss: 1.1307 - acc_top1: 0.4087 - acc_top2: 0.7315 - 1s/step
step 430/12272 - loss: 0.9936 - acc_top1: 0.4110 - acc_top2: 0.7334 - 1s/step
step 440/12272 - loss: 0.9791 - acc_top1: 0.4139 - acc_top2: 0.7357 - 1s/step
step 450/12272 - loss: 1.0112 - acc_top1: 0.4147 - acc_top2: 0.7372 - 1s/step
step 460/12272 - loss: 0.8554 - acc_top1: 0.4179 - acc_top2: 0.7395 - 1s/step
step 470/12272 - loss: 0.9411 - acc_top1: 0.4198 - acc_top2: 0.7406 - 1s/step
step 480/12272 - loss: 0.8481 - acc_top1: 0.4231 - acc_top2: 0.7424 - 1s/step
step 490/12272 - loss: 1.0338 - acc_top1: 0.4261 - acc_top2: 0.7441 - 1s/step
step 500/12272 - loss: 0.9651 - acc_top1: 0.4281 - acc_top2: 0.7459 - 1s/step
step 510/12272 - loss: 0.8091 - acc_top1: 0.4306 - acc_top2: 0.7479 - 1s/step
step 520/12272 - loss: 1.0528 - acc_top1: 0.4325 - acc_top2: 0.7489 - 1s/step
step 530/12272 - loss: 0.9898 - acc_top1: 0.4338 - acc_top2: 0.7500 - 1s/step
step 540/12272 - loss: 0.7900 - acc_top1: 0.4364 - acc_top2: 0.7519 - 1s/step
step 550/12272 - loss: 0.9055 - acc_top1: 0.4389 - acc_top2: 0.7534 - 1s/step
step 560/12272 - loss: 1.0092 - acc_top1: 0.4410 - acc_top2: 0.7549 - 1s/step
step 570/12272 - loss: 0.7068 - acc_top1: 0.4441 - acc_top2: 0.7570 - 1s/step
step 580/12272 - loss: 0.9695 - acc_top1: 0.4455 - acc_top2: 0.7581 - 1s/step
step 590/12272 - loss: 0.8640 - acc_top1: 0.4487 - acc_top2: 0.7600 - 1s/step
step 600/12272 - loss: 0.9068 - acc_top1: 0.4514 - acc_top2: 0.7618 - 1s/step
step 610/12272 - loss: 0.9023 - acc_top1: 0.4524 - acc_top2: 0.7627 - 1s/step
step 620/12272 - loss: 0.7377 - acc_top1: 0.4552 - acc_top2: 0.7640 - 1s/step
step 630/12272 - loss: 0.8900 - acc_top1: 0.4574 - acc_top2: 0.7659 - 1s/step
step 640/12272 - loss: 0.8902 - acc_top1: 0.4590 - acc_top2: 0.7669 - 1s/step
step 650/12272 - loss: 0.9069 - acc_top1: 0.4608 - acc_top2: 0.7686 - 1s/step
step 660/12272 - loss: 0.9630 - acc_top1: 0.4631 - acc_top2: 0.7699 - 1s/step
step 670/12272 - loss: 0.9005 - acc_top1: 0.4652 - acc_top2: 0.7712 - 1s/step
step 680/12272 - loss: 1.0725 - acc_top1: 0.4670 - acc_top2: 0.7725 - 1s/step
step 690/12272 - loss: 0.8322 - acc_top1: 0.4689 - acc_top2: 0.7739 - 1s/step
step 700/12272 - loss: 0.9874 - acc_top1: 0.4714 - acc_top2: 0.7753 - 1s/step
step 710/12272 - loss: 0.7915 - acc_top1: 0.4728 - acc_top2: 0.7765 - 1s/step
step 720/12272 - loss: 0.7174 - acc_top1: 0.4746 - acc_top2: 0.7777 - 1s/step
step 730/12272 - loss: 0.7635 - acc_top1: 0.4770 - acc_top2: 0.7793 - 1s/step
step 740/12272 - loss: 0.9180 - acc_top1: 0.4793 - acc_top2: 0.7804 - 1s/step
step 750/12272 - loss: 0.8424 - acc_top1: 0.4817 - acc_top2: 0.7815 - 1s/step
step 760/12272 - loss: 0.9357 - acc_top1: 0.4837 - acc_top2: 0.7829 - 1s/step
step 770/12272 - loss: 0.7643 - acc_top1: 0.4858 - acc_top2: 0.7839 - 1s/step
step 780/12272 - loss: 0.8910 - acc_top1: 0.4868 - acc_top2: 0.7849 - 1s/step
step 790/12272 - loss: 0.8781 - acc_top1: 0.4888 - acc_top2: 0.7862 - 1s/step
step 800/12272 - loss: 0.8005 - acc_top1: 0.4907 - acc_top2: 0.7877 - 1s/step
step 810/12272 - loss: 0.6740 - acc_top1: 0.4929 - acc_top2: 0.7889 - 1s/step
step 820/12272 - loss: 0.7026 - acc_top1: 0.4947 - acc_top2: 0.7898 - 1s/step
step 830/12272 - loss: 0.8666 - acc_top1: 0.4964 - acc_top2: 0.7908 - 1s/step
step 840/12272 - loss: 0.6296 - acc_top1: 0.4983 - acc_top2: 0.7920 - 1s/step
step 850/12272 - loss: 0.7907 - acc_top1: 0.4992 - acc_top2: 0.7930 - 1s/step
step 860/12272 - loss: 0.7292 - acc_top1: 0.5007 - acc_top2: 0.7935 - 1s/step
step 870/12272 - loss: 0.7498 - acc_top1: 0.5026 - acc_top2: 0.7944 - 1s/step
step 880/12272 - loss: 0.9928 - acc_top1: 0.5040 - acc_top2: 0.7953 - 1s/step
step 890/12272 - loss: 1.0025 - acc_top1: 0.5056 - acc_top2: 0.7962 - 1s/step
step 900/12272 - loss: 0.7810 - acc_top1: 0.5071 - acc_top2: 0.7969 - 1s/step
step 910/12272 - loss: 0.6114 - acc_top1: 0.5090 - acc_top2: 0.7978 - 1s/step
step 920/12272 - loss: 0.7780 - acc_top1: 0.5105 - acc_top2: 0.7988 - 1s/step
step 930/12272 - loss: 0.9457 - acc_top1: 0.5116 - acc_top2: 0.7995 - 1s/step
step 940/12272 - loss: 0.7907 - acc_top1: 0.5135 - acc_top2: 0.8006 - 1s/step
step 950/12272 - loss: 0.5520 - acc_top1: 0.5153 - acc_top2: 0.8013 - 1s/step
step 960/12272 - loss: 0.8251 - acc_top1: 0.5168 - acc_top2: 0.8022 - 1s/step
step 970/12272 - loss: 0.8482 - acc_top1: 0.5179 - acc_top2: 0.8031 - 1s/step
step 980/12272 - loss: 0.8010 - acc_top1: 0.5196 - acc_top2: 0.8038 - 1s/step
step 990/12272 - loss: 0.8326 - acc_top1: 0.5207 - acc_top2: 0.8047 - 1s/step
step 1000/12272 - loss: 0.6979 - acc_top1: 0.5222 - acc_top2: 0.8057 - 1s/step
step 1010/12272 - loss: 0.7506 - acc_top1: 0.5234 - acc_top2: 0.8065 - 1s/step
step 1020/12272 - loss: 0.8457 - acc_top1: 0.5248 - acc_top2: 0.8073 - 1s/step
step 1030/12272 - loss: 0.8698 - acc_top1: 0.5263 - acc_top2: 0.8082 - 1s/step
step 1040/12272 - loss: 0.7016 - acc_top1: 0.5279 - acc_top2: 0.8091 - 1s/step
step 1050/12272 - loss: 0.7766 - acc_top1: 0.5290 - acc_top2: 0.8099 - 1s/step
step 1060/12272 - loss: 0.7994 - acc_top1: 0.5300 - acc_top2: 0.8105 - 1s/step
step 1070/12272 - loss: 0.7053 - acc_top1: 0.5317 - acc_top2: 0.8115 - 1s/step
step 1080/12272 - loss: 0.9085 - acc_top1: 0.5330 - acc_top2: 0.8125 - 1s/step
step 1090/12272 - loss: 0.7556 - acc_top1: 0.5342 - acc_top2: 0.8134 - 1s/step
step 1100/12272 - loss: 0.9364 - acc_top1: 0.5355 - acc_top2: 0.8141 - 1s/step
step 1110/12272 - loss: 0.9403 - acc_top1: 0.5367 - acc_top2: 0.8148 - 1s/step
step 1120/12272 - loss: 0.8228 - acc_top1: 0.5375 - acc_top2: 0.8152 - 1s/step
step 1130/12272 - loss: 0.6802 - acc_top1: 0.5388 - acc_top2: 0.8160 - 1s/step
step 1140/12272 - loss: 0.8222 - acc_top1: 0.5397 - acc_top2: 0.8167 - 1s/step
step 1150/12272 - loss: 0.9321 - acc_top1: 0.5407 - acc_top2: 0.8172 - 1s/step
step 1160/12272 - loss: 0.7478 - acc_top1: 0.5417 - acc_top2: 0.8181 - 1s/step
step 1170/12272 - loss: 0.7976 - acc_top1: 0.5430 - acc_top2: 0.8188 - 1s/step
step 1180/12272 - loss: 0.7386 - acc_top1: 0.5441 - acc_top2: 0.8192 - 1s/step
step 1190/12272 - loss: 0.6448 - acc_top1: 0.5450 - acc_top2: 0.8200 - 1s/step
step 1200/12272 - loss: 0.7441 - acc_top1: 0.5463 - acc_top2: 0.8206 - 1s/step
step 1210/12272 - loss: 0.8171 - acc_top1: 0.5476 - acc_top2: 0.8213 - 1s/step
step 1220/12272 - loss: 0.7480 - acc_top1: 0.5487 - acc_top2: 0.8219 - 1s/step
step 1230/12272 - loss: 0.6363 - acc_top1: 0.5497 - acc_top2: 0.8225 - 1s/step
step 1240/12272 - loss: 0.6630 - acc_top1: 0.5507 - acc_top2: 0.8231 - 1s/step
step 1250/12272 - loss: 0.8668 - acc_top1: 0.5517 - acc_top2: 0.8237 - 1s/step
step 1260/12272 - loss: 0.6057 - acc_top1: 0.5527 - acc_top2: 0.8243 - 1s/step
step 1270/12272 - loss: 0.8432 - acc_top1: 0.5538 - acc_top2: 0.8248 - 1s/step
step 1280/12272 - loss: 0.8447 - acc_top1: 0.5546 - acc_top2: 0.8253 - 1s/step
step 1290/12272 - loss: 0.6928 - acc_top1: 0.5556 - acc_top2: 0.8261 - 1s/step
step 1300/12272 - loss: 0.7872 - acc_top1: 0.5567 - acc_top2: 0.8266 - 1s/step
step 1310/12272 - loss: 0.7968 - acc_top1: 0.5570 - acc_top2: 0.8269 - 1s/step
step 1320/12272 - loss: 0.8059 - acc_top1: 0.5580 - acc_top2: 0.8275 - 1s/step
step 1330/12272 - loss: 0.8603 - acc_top1: 0.5587 - acc_top2: 0.8278 - 1s/step
step 1340/12272 - loss: 0.7872 - acc_top1: 0.5599 - acc_top2: 0.8285 - 1s/step
step 1350/12272 - loss: 0.7037 - acc_top1: 0.5609 - acc_top2: 0.8290 - 1s/step
step 1360/12272 - loss: 0.8268 - acc_top1: 0.5618 - acc_top2: 0.8297 - 1s/step
step 1370/12272 - loss: 0.5962 - acc_top1: 0.5627 - acc_top2: 0.8303 - 1s/step
step 1380/12272 - loss: 0.7712 - acc_top1: 0.5638 - acc_top2: 0.8310 - 1s/step
step 1390/12272 - loss: 0.5770 - acc_top1: 0.5650 - acc_top2: 0.8315 - 1s/step
step 1400/12272 - loss: 0.7174 - acc_top1: 0.5656 - acc_top2: 0.8319 - 1s/step
step 1410/12272 - loss: 0.6224 - acc_top1: 0.5660 - acc_top2: 0.8323 - 1s/step
step 1420/12272 - loss: 0.6782 - acc_top1: 0.5671 - acc_top2: 0.8328 - 1s/step
step 1430/12272 - loss: 0.4087 - acc_top1: 0.5682 - acc_top2: 0.8335 - 1s/step
step 1440/12272 - loss: 0.7534 - acc_top1: 0.5692 - acc_top2: 0.8342 - 1s/step
step 1450/12272 - loss: 0.6446 - acc_top1: 0.5702 - acc_top2: 0.8345 - 1s/step
step 1460/12272 - loss: 0.6606 - acc_top1: 0.5712 - acc_top2: 0.8351 - 1s/step
step 1470/12272 - loss: 0.7308 - acc_top1: 0.5723 - acc_top2: 0.8357 - 1s/step
step 1480/12272 - loss: 0.9016 - acc_top1: 0.5727 - acc_top2: 0.8359 - 1s/step
step 1490/12272 - loss: 0.8445 - acc_top1: 0.5730 - acc_top2: 0.8362 - 1s/step
step 1500/12272 - loss: 0.8217 - acc_top1: 0.5737 - acc_top2: 0.8367 - 1s/step
step 1510/12272 - loss: 0.8413 - acc_top1: 0.5747 - acc_top2: 0.8370 - 1s/step
step 1520/12272 - loss: 0.4643 - acc_top1: 0.5757 - acc_top2: 0.8376 - 1s/step
step 1530/12272 - loss: 0.9351 - acc_top1: 0.5764 - acc_top2: 0.8381 - 1s/step
step 1540/12272 - loss: 0.7856 - acc_top1: 0.5773 - acc_top2: 0.8386 - 1s/step
step 1550/12272 - loss: 0.5921 - acc_top1: 0.5780 - acc_top2: 0.8390 - 1s/step
step 1560/12272 - loss: 0.4460 - acc_top1: 0.5788 - acc_top2: 0.8395 - 1s/step
step 1570/12272 - loss: 0.6814 - acc_top1: 0.5793 - acc_top2: 0.8401 - 1s/step
step 1580/12272 - loss: 0.4115 - acc_top1: 0.5805 - acc_top2: 0.8407 - 1s/step
step 1590/12272 - loss: 0.9326 - acc_top1: 0.5810 - acc_top2: 0.8410 - 1s/step
step 1600/12272 - loss: 0.6989 - acc_top1: 0.5818 - acc_top2: 0.8413 - 1s/step
step 1610/12272 - loss: 0.5238 - acc_top1: 0.5826 - acc_top2: 0.8418 - 1s/step
step 1620/12272 - loss: 0.5827 - acc_top1: 0.5832 - acc_top2: 0.8422 - 1s/step
step 1630/12272 - loss: 0.7703 - acc_top1: 0.5838 - acc_top2: 0.8425 - 1s/step
step 1640/12272 - loss: 0.7926 - acc_top1: 0.5844 - acc_top2: 0.8428 - 1s/step
step 1650/12272 - loss: 0.7143 - acc_top1: 0.5851 - acc_top2: 0.8434 - 1s/step
step 1660/12272 - loss: 0.6240 - acc_top1: 0.5858 - acc_top2: 0.8438 - 1s/step
step 1670/12272 - loss: 0.7869 - acc_top1: 0.5862 - acc_top2: 0.8440 - 1s/step
step 1680/12272 - loss: 0.6485 - acc_top1: 0.5868 - acc_top2: 0.8444 - 1s/step
step 1690/12272 - loss: 0.7539 - acc_top1: 0.5876 - acc_top2: 0.8450 - 1s/step
step 1700/12272 - loss: 0.6173 - acc_top1: 0.5882 - acc_top2: 0.8454 - 1s/step
step 1710/12272 - loss: 0.8056 - acc_top1: 0.5890 - acc_top2: 0.8458 - 1s/step
step 1720/12272 - loss: 0.7035 - acc_top1: 0.5898 - acc_top2: 0.8463 - 1s/step
step 1730/12272 - loss: 0.5892 - acc_top1: 0.5908 - acc_top2: 0.8468 - 1s/step
step 1740/12272 - loss: 0.7755 - acc_top1: 0.5915 - acc_top2: 0.8472 - 1s/step
step 1750/12272 - loss: 0.6911 - acc_top1: 0.5920 - acc_top2: 0.8474 - 1s/step
step 1760/12272 - loss: 0.6309 - acc_top1: 0.5926 - acc_top2: 0.8477 - 1s/step
step 1770/12272 - loss: 0.7506 - acc_top1: 0.5932 - acc_top2: 0.8480 - 1s/step
step 1780/12272 - loss: 0.8711 - acc_top1: 0.5939 - acc_top2: 0.8482 - 1s/step
step 1790/12272 - loss: 0.9146 - acc_top1: 0.5945 - acc_top2: 0.8484 - 1s/step
step 1800/12272 - loss: 0.6208 - acc_top1: 0.5952 - acc_top2: 0.8487 - 1s/step
step 1810/12272 - loss: 0.8506 - acc_top1: 0.5959 - acc_top2: 0.8490 - 1s/step
step 1820/12272 - loss: 0.8330 - acc_top1: 0.5965 - acc_top2: 0.8494 - 1s/step
step 1830/12272 - loss: 0.8315 - acc_top1: 0.5970 - acc_top2: 0.8497 - 1s/step
step 1840/12272 - loss: 0.6227 - acc_top1: 0.5977 - acc_top2: 0.8501 - 1s/step
step 1850/12272 - loss: 0.5972 - acc_top1: 0.5985 - acc_top2: 0.8506 - 1s/step
step 1860/12272 - loss: 0.6309 - acc_top1: 0.5992 - acc_top2: 0.8510 - 1s/step
step 1870/12272 - loss: 0.8707 - acc_top1: 0.5995 - acc_top2: 0.8512 - 1s/step
step 1880/12272 - loss: 0.6419 - acc_top1: 0.6004 - acc_top2: 0.8516 - 1s/step
step 1890/12272 - loss: 0.6015 - acc_top1: 0.6010 - acc_top2: 0.8521 - 1s/step
step 1900/12272 - loss: 0.6000 - acc_top1: 0.6015 - acc_top2: 0.8524 - 1s/step
step 1910/12272 - loss: 0.7010 - acc_top1: 0.6020 - acc_top2: 0.8527 - 1s/step
step 1920/12272 - loss: 0.8539 - acc_top1: 0.6026 - acc_top2: 0.8530 - 1s/step
step 1930/12272 - loss: 0.8381 - acc_top1: 0.6031 - acc_top2: 0.8533 - 1s/step
step 1940/12272 - loss: 0.5921 - acc_top1: 0.6039 - acc_top2: 0.8537 - 1s/step
step 1950/12272 - loss: 0.4974 - acc_top1: 0.6047 - acc_top2: 0.8541 - 1s/step
step 1960/12272 - loss: 0.8269 - acc_top1: 0.6052 - acc_top2: 0.8544 - 1s/step
step 1970/12272 - loss: 0.6157 - acc_top1: 0.6058 - acc_top2: 0.8548 - 1s/step
step 1980/12272 - loss: 1.0949 - acc_top1: 0.6064 - acc_top2: 0.8552 - 1s/step
step 1990/12272 - loss: 0.6442 - acc_top1: 0.6070 - acc_top2: 0.8555 - 1s/step
step 2000/12272 - loss: 0.8747 - acc_top1: 0.6073 - acc_top2: 0.8558 - 1s/step
step 2010/12272 - loss: 0.8101 - acc_top1: 0.6078 - acc_top2: 0.8560 - 1s/step
step 2020/12272 - loss: 0.8623 - acc_top1: 0.6082 - acc_top2: 0.8562 - 1s/step
step 2030/12272 - loss: 0.6664 - acc_top1: 0.6089 - acc_top2: 0.8567 - 1s/step
step 2040/12272 - loss: 0.7616 - acc_top1: 0.6092 - acc_top2: 0.8567 - 1s/step
step 2050/12272 - loss: 0.7282 - acc_top1: 0.6095 - acc_top2: 0.8570 - 1s/step
step 2060/12272 - loss: 0.6914 - acc_top1: 0.6099 - acc_top2: 0.8574 - 1s/step
step 2070/12272 - loss: 0.6129 - acc_top1: 0.6105 - acc_top2: 0.8577 - 1s/step
step 2080/12272 - loss: 0.5605 - acc_top1: 0.6111 - acc_top2: 0.8580 - 1s/step
step 2090/12272 - loss: 0.6432 - acc_top1: 0.6116 - acc_top2: 0.8582 - 1s/step
step 2100/12272 - loss: 0.6783 - acc_top1: 0.6121 - acc_top2: 0.8586 - 1s/step
step 2110/12272 - loss: 0.5949 - acc_top1: 0.6128 - acc_top2: 0.8589 - 1s/step
step 2120/12272 - loss: 0.7832 - acc_top1: 0.6134 - acc_top2: 0.8592 - 1s/step
step 2130/12272 - loss: 0.6633 - acc_top1: 0.6139 - acc_top2: 0.8594 - 1s/step
step 2140/12272 - loss: 0.8456 - acc_top1: 0.6143 - acc_top2: 0.8596 - 1s/step
step 2150/12272 - loss: 0.7133 - acc_top1: 0.6150 - acc_top2: 0.8599 - 1s/step
step 2160/12272 - loss: 0.4699 - acc_top1: 0.6155 - acc_top2: 0.8602 - 1s/step
step 2170/12272 - loss: 0.6013 - acc_top1: 0.6161 - acc_top2: 0.8605 - 1s/step
step 2180/12272 - loss: 0.5676 - acc_top1: 0.6165 - acc_top2: 0.8608 - 1s/step
step 2190/12272 - loss: 0.5850 - acc_top1: 0.6172 - acc_top2: 0.8611 - 1s/step
step 2200/12272 - loss: 0.6887 - acc_top1: 0.6177 - acc_top2: 0.8612 - 1s/step
step 2210/12272 - loss: 0.5706 - acc_top1: 0.6180 - acc_top2: 0.8614 - 1s/step
step 2220/12272 - loss: 0.8251 - acc_top1: 0.6184 - acc_top2: 0.8617 - 1s/step
step 2230/12272 - loss: 0.6532 - acc_top1: 0.6188 - acc_top2: 0.8620 - 1s/step
step 2240/12272 - loss: 0.5888 - acc_top1: 0.6194 - acc_top2: 0.8623 - 1s/step
step 2250/12272 - loss: 0.6360 - acc_top1: 0.6198 - acc_top2: 0.8625 - 1s/step
step 2260/12272 - loss: 1.0555 - acc_top1: 0.6202 - acc_top2: 0.8628 - 1s/step
step 2270/12272 - loss: 0.4848 - acc_top1: 0.6207 - acc_top2: 0.8629 - 1s/step
step 2280/12272 - loss: 0.7243 - acc_top1: 0.6212 - acc_top2: 0.8632 - 1s/step
step 2290/12272 - loss: 0.4358 - acc_top1: 0.6216 - acc_top2: 0.8635 - 1s/step
step 2300/12272 - loss: 0.5473 - acc_top1: 0.6221 - acc_top2: 0.8637 - 1s/step
step 2310/12272 - loss: 0.6440 - acc_top1: 0.6226 - acc_top2: 0.8640 - 1s/step
step 2320/12272 - loss: 0.5785 - acc_top1: 0.6233 - acc_top2: 0.8643 - 1s/step
step 2330/12272 - loss: 0.7199 - acc_top1: 0.6237 - acc_top2: 0.8646 - 1s/step
step 2340/12272 - loss: 0.5622 - acc_top1: 0.6241 - acc_top2: 0.8647 - 1s/step
step 2350/12272 - loss: 0.6742 - acc_top1: 0.6245 - acc_top2: 0.8650 - 1s/step
step 2360/12272 - loss: 0.8149 - acc_top1: 0.6249 - acc_top2: 0.8652 - 1s/step
step 2370/12272 - loss: 0.5900 - acc_top1: 0.6253 - acc_top2: 0.8654 - 1s/step
step 2380/12272 - loss: 0.8046 - acc_top1: 0.6256 - acc_top2: 0.8656 - 1s/step
step 2390/12272 - loss: 0.6097 - acc_top1: 0.6262 - acc_top2: 0.8659 - 1s/step
step 2400/12272 - loss: 0.5936 - acc_top1: 0.6266 - acc_top2: 0.8660 - 1s/step
step 2410/12272 - loss: 0.7245 - acc_top1: 0.6270 - acc_top2: 0.8662 - 1s/step
step 2420/12272 - loss: 0.6349 - acc_top1: 0.6274 - acc_top2: 0.8665 - 1s/step
step 2430/12272 - loss: 0.7009 - acc_top1: 0.6278 - acc_top2: 0.8668 - 1s/step
step 2440/12272 - loss: 0.3881 - acc_top1: 0.6282 - acc_top2: 0.8670 - 1s/step
step 2450/12272 - loss: 0.5226 - acc_top1: 0.6286 - acc_top2: 0.8673 - 1s/step
step 2460/12272 - loss: 0.5748 - acc_top1: 0.6292 - acc_top2: 0.8675 - 1s/step
step 2470/12272 - loss: 0.4798 - acc_top1: 0.6297 - acc_top2: 0.8678 - 1s/step
step 2480/12272 - loss: 0.5857 - acc_top1: 0.6303 - acc_top2: 0.8680 - 1s/step
step 2490/12272 - loss: 0.6729 - acc_top1: 0.6308 - acc_top2: 0.8683 - 1s/step
step 2500/12272 - loss: 0.6392 - acc_top1: 0.6312 - acc_top2: 0.8686 - 1s/step
step 2510/12272 - loss: 0.9607 - acc_top1: 0.6315 - acc_top2: 0.8687 - 1s/step
step 2520/12272 - loss: 0.6036 - acc_top1: 0.6319 - acc_top2: 0.8690 - 1s/step
step 2530/12272 - loss: 0.6505 - acc_top1: 0.6324 - acc_top2: 0.8693 - 1s/step
step 2540/12272 - loss: 0.4558 - acc_top1: 0.6329 - acc_top2: 0.8696 - 1s/step
step 2550/12272 - loss: 0.4215 - acc_top1: 0.6333 - acc_top2: 0.8699 - 1s/step
step 2560/12272 - loss: 0.6908 - acc_top1: 0.6338 - acc_top2: 0.8701 - 1s/step
step 2570/12272 - loss: 0.5833 - acc_top1: 0.6342 - acc_top2: 0.8703 - 1s/step
step 2580/12272 - loss: 0.8548 - acc_top1: 0.6346 - acc_top2: 0.8706 - 1s/step
step 2590/12272 - loss: 0.5770 - acc_top1: 0.6351 - acc_top2: 0.8708 - 1s/step
step 2600/12272 - loss: 0.4476 - acc_top1: 0.6355 - acc_top2: 0.8711 - 1s/step
step 2610/12272 - loss: 0.4145 - acc_top1: 0.6360 - acc_top2: 0.8714 - 1s/step
step 2620/12272 - loss: 0.6625 - acc_top1: 0.6365 - acc_top2: 0.8717 - 1s/step
step 2630/12272 - loss: 0.4808 - acc_top1: 0.6369 - acc_top2: 0.8719 - 1s/step
#!/bin/bash
BERT_BASE_PATH="./data/pretrained_models/uncased_L-12_H-768_A-12/"
TASK_NAME='MNLI'
DATA_PATH="./data/glue_data/MNLI/"
CKPT_PATH="./data/saved_model/mnli_models"
# start fine-tuning
python3.7 -m paddle.distributed.launch --started_port 8899 --selected_gpus=0,1,2,3 bert_classifier.py\
--use_cuda true \
--do_train true \
--do_test true \
--batch_size 64 \
--init_pretraining_params ${BERT_BASE_PATH}/dygraph_params/ \
--data_dir ${DATA_PATH} \
--vocab_path ${BERT_BASE_PATH}/vocab.txt \
--checkpoints ${CKPT_PATH} \
--save_steps 1000 \
--weight_decay 0.01 \
--warmup_proportion 0.1 \
--validation_steps 100 \
--epoch 3 \
--max_seq_len 128 \
--bert_config_path ${BERT_BASE_PATH}/bert_config.json \
--learning_rate 5e-5 \
--skip_steps 10 \
--shuffle true
......@@ -4,7 +4,7 @@ TASK_NAME='MNLI'
DATA_PATH="./data/glue_data/MNLI/"
CKPT_PATH="./data/saved_model/mnli_models"
export CUDA_VISIBLE_DEVICES=7
export CUDA_VISIBLE_DEVICES=0
# start fine-tuning
python3.7 bert_classifier.py\
......
......@@ -30,6 +30,7 @@ from hapi.distributed import DistributedBatchSampler
from hapi.text.bert.data_processor import DataProcessor, XnliProcessor, ColaProcessor, MrpcProcessor, MnliProcessor
from hapi.text.bert.batching import prepare_batch_data
import hapi.text.tokenizer.tokenization as tokenization
from paddle.fluid.dygraph.parallel import ParallelEnv, ParallelStrategy
__all__ = [
'BertInputExample', 'BertInputFeatures', 'SingleSentenceDataset',
......@@ -227,6 +228,9 @@ class SingleSentenceDataset(Dataset):
if line_processor is None:
line_processor = default_line_processor
if ParallelEnv().nranks > 1:
leveldb_file = leveldb_file + "_" + str(ParallelEnv().local_rank)
if not os.path.exists(leveldb_file):
print("putting data %s into leveldb %s" %
(input_file, leveldb_file))
......@@ -384,7 +388,12 @@ class BertDataLoader(object):
quotechar=None,
device=fluid.CPUPlace(),
num_workers=0,
return_list=True):
return_list=True,
phase="train"):
assert phase in [
"train", "predict", "test"
], "phase of BertDataLoader should be in [train, predict, test], but get %s" % phase
self.dataset = SingleSentenceDataset(tokenizer, label_list,
max_seq_length, mode)
......@@ -394,15 +403,21 @@ class BertDataLoader(object):
input_file, label_list, max_seq_length, tokenizer,
line_processor, delimiter, quotechar)
elif mode == "leveldb":
#prepare_leveldb(self, input_file, leveldb_file, label_list, max_seq_length, tokenizer, line_processor=None, delimiter="\t", quotechar=None):
self.dataset.prepare_leveldb(input_file, leveldb_file, label_list,
max_seq_length, tokenizer,
line_processor, delimiter, quotechar)
else:
raise ValueError("mode should be in [all_in_memory, leveldb]")
self.sampler = DistributedBatchSampler(
self.dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
if phase == "train":
self.sampler = DistributedBatchSampler(
self.dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
elif phase == "test" or phase == "predict":
self.sampler = BatchSampler(
dataset=self.dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last)
self.dataloader = DataLoader(
dataset=self.dataset,
......
......@@ -48,8 +48,8 @@ __all__ = [
'RNNCell', 'BasicLSTMCell', 'BasicGRUCell', 'RNN', 'DynamicDecode',
'BeamSearchDecoder', 'MultiHeadAttention', 'FFN',
'TransformerEncoderLayer', 'TransformerEncoder', 'TransformerDecoderLayer',
'TransformerDecoder', 'TransformerBeamSearchDecoder', 'BiGRU',
'Linear_chain_crf', 'Crf_decoding', 'SequenceTagging'
'TransformerDecoder', 'TransformerBeamSearchDecoder', 'Linear_chain_crf',
'Crf_decoding', 'SequenceTagging'
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册