提交 edf5630f 编写于 作者: Z zhangxuefei

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleHub into develop

export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
python -u img_classifier.py $@
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
python -u predict.py $@
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_qa"
# Recommending hyper parameters for difference task
# ChnSentiCorp: batch_size=24, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5
......
......@@ -89,9 +89,7 @@ if __name__ == '__main__':
# Setup runing config for PaddleHub Finetune API
config = hub.RunConfig(
log_interval=10,
eval_interval=300,
save_ckpt_interval=10000,
use_pyreader=args.use_pyreader,
use_data_parallel=args.use_data_parallel,
use_cuda=args.use_gpu,
......
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
# Recommending hyper parameters for difference task
# squad: batch_size=8, weight_decay=0, num_epoch=3, max_seq_len=512, lr=5e-5
......
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_cmrc2018"
dataset=cmrc2018
......
export FLAGS_eager_delete_tensor_gb=0.0
# export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=0
# User can select chnsenticorp, nlpcc_dbqa, lcqmc and so on for different task
DATASET="STS-B"
......
......@@ -41,7 +41,7 @@ args = parser.parse_args()
if __name__ == '__main__':
# loading Paddlehub ERNIE pretrained model
module = hub.Module(name="ernie")
module = hub.Module(name="ernie_tiny")
inputs, outputs, program = module.context(max_seq_len=args.max_seq_len)
# Sentence labeling dataset reader
......@@ -49,7 +49,9 @@ if __name__ == '__main__':
reader = hub.reader.SequenceLabelReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=args.max_seq_len)
max_seq_len=args.max_seq_len,
sp_model_path=module.get_spm_path(),
word_dict_path=module.get_word_dict_path())
inv_label_map = {val: key for key, val in reader.label_map.items()}
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
......
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_sequence_label"
python -u sequence_label.py \
......
......@@ -71,9 +71,6 @@ if __name__ == '__main__':
# Setup runing config for PaddleHub Finetune API
config = hub.RunConfig(
log_interval=10,
eval_interval=300,
save_ckpt_interval=10000,
use_data_parallel=args.use_data_parallel,
use_pyreader=args.use_pyreader,
use_cuda=args.use_gpu,
......
......@@ -45,15 +45,35 @@ if __name__ == '__main__':
# Download dataset and use ClassifyReader to read dataset
if args.dataset.lower() == "chnsenticorp":
dataset = hub.dataset.ChnSentiCorp()
module = hub.Module(name="ernie")
module = hub.Module(name="ernie_tiny")
metrics_choices = ["acc"]
elif args.dataset.lower() == "tnews":
dataset = hub.dataset.TNews()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc"]
elif args.dataset.lower() == "nlpcc_dbqa":
dataset = hub.dataset.NLPCC_DBQA()
module = hub.Module(name="ernie")
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc"]
elif args.dataset.lower() == "lcqmc":
dataset = hub.dataset.LCQMC()
module = hub.Module(name="ernie")
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc"]
elif args.dataset.lower() == 'inews':
dataset = hub.dataset.INews()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc"]
elif args.dataset.lower() == 'bq':
dataset = hub.dataset.BQ()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc"]
elif args.dataset.lower() == 'thucnews':
dataset = hub.dataset.THUCNEWS()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc"]
elif args.dataset.lower() == 'iflytek':
dataset = hub.dataset.IFLYTEK()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc"]
elif args.dataset.lower() == "mrpc":
dataset = hub.dataset.GLUE("MRPC")
......@@ -90,7 +110,7 @@ if __name__ == '__main__':
metrics_choices = ["acc"]
elif args.dataset.lower().startswith("xnli"):
dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12")
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc"]
else:
raise ValueError("%s dataset is not defined" % args.dataset)
......
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
# User can select chnsenticorp, nlpcc_dbqa, lcqmc and so on for different task
DATASET="chnsenticorp"
......
......@@ -17,4 +17,4 @@ python -u predict.py --checkpoint_dir=$CKPT_DIR \
--max_seq_len=128 \
--use_gpu=True \
--dataset=${DATASET} \
--batch_size=150 \
--batch_size=32 \
......@@ -47,7 +47,7 @@ if __name__ == '__main__':
elif args.dataset.lower() == "tnews":
dataset = hub.dataset.TNews()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc", "f1"]
metrics_choices = ["acc"]
elif args.dataset.lower() == "nlpcc_dbqa":
dataset = hub.dataset.NLPCC_DBQA()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
......@@ -59,19 +59,19 @@ if __name__ == '__main__':
elif args.dataset.lower() == 'inews':
dataset = hub.dataset.INews()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc", "f1"]
metrics_choices = ["acc"]
elif args.dataset.lower() == 'bq':
dataset = hub.dataset.BQ()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc", "f1"]
metrics_choices = ["acc"]
elif args.dataset.lower() == 'thucnews':
dataset = hub.dataset.THUCNEWS()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc", "f1"]
metrics_choices = ["acc"]
elif args.dataset.lower() == 'iflytek':
dataset = hub.dataset.IFLYTEK()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc", "f1"]
metrics_choices = ["acc"]
elif args.dataset.lower() == "mrpc":
dataset = hub.dataset.GLUE("MRPC")
module = hub.Module(name="ernie_v2_eng_base")
......@@ -97,7 +97,7 @@ if __name__ == '__main__':
dataset = hub.dataset.GLUE("RTE")
module = hub.Module(name="ernie_v2_eng_base")
metrics_choices = ["acc"]
elif args.dataset.lower() == "mnli" or args.dataset.lower() == "mnli":
elif args.dataset.lower() == "mnli" or args.dataset.lower() == "mnli_m":
dataset = hub.dataset.GLUE("MNLI_m")
module = hub.Module(name="ernie_v2_eng_base")
metrics_choices = ["acc"]
......
......@@ -49,6 +49,7 @@ class ImageClassificationReader(object):
self.data_augmentation = data_augmentation
self.images_std = images_std
self.images_mean = images_mean
self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
if self.images_mean is None:
try:
......@@ -80,12 +81,15 @@ class ImageClassificationReader(object):
raise ValueError("The dataset is none and it's not allowed!")
if phase == "train":
data = self.dataset.train_data(shuffle)
self.num_examples['train'] = len(self.get_train_examples())
elif phase == "test":
shuffle = False
data = self.dataset.test_data(shuffle)
self.num_examples['test'] = len(self.get_test_examples())
elif phase == "val" or phase == "dev":
shuffle = False
data = self.dataset.validate_data(shuffle)
self.num_examples['dev'] = len(self.get_dev_examples())
elif phase == "predict":
data = data
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册