From 033906aa95cd5eae0468208fed4f34f1d46abb41 Mon Sep 17 00:00:00 2001 From: xixiaoyao Date: Thu, 28 May 2020 17:28:15 +0800 Subject: [PATCH] fix fine grain train --- examples/train_with_eval/README.md | 81 ++++++++++++++++++++++++++++ examples/train_with_eval/download.py | 42 +++++++++++++++ examples/train_with_eval/evaluate.py | 57 ++++++++++++++++++++ examples/train_with_eval/run.py | 78 +++++++++++++++++++++++++++ paddlepalm/multihead_trainer.py | 9 ++++ paddlepalm/trainer.py | 8 ++- 6 files changed, 273 insertions(+), 2 deletions(-) create mode 100644 examples/train_with_eval/README.md create mode 100755 examples/train_with_eval/download.py create mode 100644 examples/train_with_eval/evaluate.py create mode 100644 examples/train_with_eval/run.py diff --git a/examples/train_with_eval/README.md b/examples/train_with_eval/README.md new file mode 100644 index 0000000..c77316d --- /dev/null +++ b/examples/train_with_eval/README.md @@ -0,0 +1,81 @@ +## Train with Evaluation version of Example 1: Classification +This task is a sentiment analysis task. The following sections detail model preparation, dataset preparation, and how to run the task. Here to demonstrate how to do evaluation during training in PaddlePALM. + +### Step 1: Prepare Pre-trained Model & Dataset + +#### Pre-trained Model + +The pre-training model of this mission is: [ERNIE-v1-zh-base](https://github.com/PaddlePaddle/PALM/tree/r0.3-api). + +Make sure you have downloaded the required pre-training model in the current folder. + + +#### Dataset + +This example demonstrates with [ChnSentiCorp](https://github.com/SophonPlus/ChineseNlpCorpus/tree/master/datasets/ChnSentiCorp_htl_all), a Chinese sentiment analysis dataset. + +Download dataset: +```shell +python download.py +``` + +If everything goes well, there will be a folder named `data/` created with all the data files in it. + +The dataset file (for training) should have 2 fields, `text_a` and `label`, stored with [tsv](https://en.wikipedia.org/wiki/Tab-separated_values) format. Here shows an example: + +``` +label text_a +0 当当网名不符实,订货多日不见送货,询问客服只会推托,只会要求用户再下订单。如此服务留不住顾客的。去别的网站买书服务更好。 +0 XP的驱动不好找!我的17号提的货,现在就降价了100元,而且还送杀毒软件! +1 <荐书> 推荐所有喜欢<红楼>的红迷们一定要收藏这本书,要知道当年我听说这本书的时候花很长时间去图书馆找和借都没能如愿,所以这次一看到当当有,马上买了,红迷们也要记得备货哦! +``` + +### Step 2: Train & Predict + +The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run: + +```shell +python run.py +``` + +If you want to specify a specific gpu or use multiple gpus for training, please use **`CUDA_VISIBLE_DEVICES`**, for example: + +```shell +CUDA_VISIBLE_DEVICES=0,1 python run.py +``` + +Note: On multi-gpu mode, PaddlePALM will automatically split each batch onto the available cards. For example, if the `batch_size` is set 64, and there are 4 cards visible for PaddlePALM, then the batch_size in each card is actually 64/4=16. If you want to change the `batch_size` or the number of gpus used in the example, **you need to ensure that the set batch_size can be divided by the number of cards.** + + +Some logs will be shown below: + +``` +step 1/154 (epoch 0), loss: 5.512, speed: 0.51 steps/s +step 2/154 (epoch 0), loss: 2.595, speed: 3.36 steps/s +step 3/154 (epoch 0), loss: 1.798, speed: 3.48 steps/s +``` + + +After the run, you can view the saved models in the `outputs/` folder and the predictions in the `outputs/predict` folder. Here are some examples of predictions: + + +``` +{"index": 0, "logits": [-0.2014336884021759, 0.6799028515815735], "probs": [0.29290086030960083, 0.7070990800857544], "label": 1} +{"index": 1, "logits": [0.8593899011611938, -0.29743513464927673], "probs": [0.7607553601264954, 0.23924466967582703], "label": 0} +{"index": 2, "logits": [0.7462944388389587, -0.7083730101585388], "probs": [0.8107157349586487, 0.18928426504135132], "label": 0} +``` + +### Step 3: Evaluate + +Once you have the prediction, you can run the evaluation script to evaluate the model: + +```shell +python evaluate.py +``` + +The evaluation results are as follows: + +``` +data num: 1200 +accuracy: 0.9575, precision: 0.9634, recall: 0.9523, f1: 0.9578 +``` diff --git a/examples/train_with_eval/download.py b/examples/train_with_eval/download.py new file mode 100755 index 0000000..72435bb --- /dev/null +++ b/examples/train_with_eval/download.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function +import os +import tarfile +import shutil +import sys +import urllib +URLLIB=urllib +if sys.version_info >= (3, 0): + import urllib.request + URLLIB=urllib.request + +def download(src, url): + def _reporthook(count, chunk_size, total_size): + bytes_so_far = count * chunk_size + percent = float(bytes_so_far) / float(total_size) + if percent > 1: + percent = 1 + print('\r>> Downloading... {:.1%}'.format(percent), end="") + + URLLIB.urlretrieve(url, src, reporthook=_reporthook) + +abs_path = os.path.abspath(__file__) +download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz" +downlaod_path = os.path.join(os.path.dirname(abs_path), "task_data_zh.tgz") +target_dir = os.path.dirname(abs_path) +download(downlaod_path, download_url) + +tar = tarfile.open(downlaod_path) +tar.extractall(target_dir) +os.remove(downlaod_path) + +abs_path = os.path.abspath(__file__) +dst_dir = os.path.join(os.path.dirname(abs_path), "data") +if not os.path.exists(dst_dir) or not os.path.isdir(dst_dir): + os.makedirs(dst_dir) + +for file in os.listdir(os.path.join(target_dir, 'task_data', 'chnsenticorp')): + shutil.move(os.path.join(target_dir, 'task_data', 'chnsenticorp', file), dst_dir) + +shutil.rmtree(os.path.join(target_dir, 'task_data')) +print(" done!") diff --git a/examples/train_with_eval/evaluate.py b/examples/train_with_eval/evaluate.py new file mode 100644 index 0000000..4b1b0d3 --- /dev/null +++ b/examples/train_with_eval/evaluate.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- + +import json +import numpy as np + +def accuracy(preds, labels): + preds = np.array(preds) + labels = np.array(labels) + return (preds == labels).mean() + +def pre_recall_f1(preds, labels): + preds = np.array(preds) + labels = np.array(labels) + # recall=TP/(TP+FN) + tp = np.sum((labels == '1') & (preds == '1')) + fp = np.sum((labels == '0') & (preds == '1')) + fn = np.sum((labels == '1') & (preds == '0')) + r = tp * 1.0 / (tp + fn) + # Precision=TP/(TP+FP) + p = tp * 1.0 / (tp + fp) + epsilon = 1e-31 + f1 = 2 * p * r / (p+r+epsilon) + return p, r, f1 + + +def res_evaluate(res_dir="./outputs/predict/predictions.json", eval_phase='test'): + if eval_phase == 'test': + data_dir="./data/test.tsv" + elif eval_phase == 'dev': + data_dir="./data/dev.tsv" + else: + assert eval_phase in ['dev', 'test'], 'eval_phase should be dev or test' + + labels = [] + with open(data_dir, "r") as file: + first_flag = True + for line in file: + line = line.split("\t") + label = line[0] + if label=='label': + continue + labels.append(str(label)) + file.close() + + preds = [] + with open(res_dir, "r") as file: + for line in file.readlines(): + line = json.loads(line) + pred = line['label'] + preds.append(str(pred)) + file.close() + assert len(labels) == len(preds), "prediction result doesn't match to labels" + print('data num: {}'.format(len(labels))) + p, r, f1 = pre_recall_f1(preds, labels) + print("accuracy: {:.4f}, precision: {:.4f}, recall: {:.4f}, f1: {:.4f}".format(accuracy(preds, labels), p, r, f1)) + +res_evaluate() diff --git a/examples/train_with_eval/run.py b/examples/train_with_eval/run.py new file mode 100644 index 0000000..cc0b3c9 --- /dev/null +++ b/examples/train_with_eval/run.py @@ -0,0 +1,78 @@ +# coding=utf-8 +import paddlepalm as palm +import json + + +if __name__ == '__main__': + + # configs + max_seqlen = 256 + batch_size = 8 + num_epochs = 10 + lr = 5e-5 + weight_decay = 0.01 + vocab_path = './pretrain/ERNIE-v1-zh-base/vocab.txt' + + train_file = './data/train.tsv' + predict_file = './data/test.tsv' + config = json.load(open('./pretrain/ERNIE-v1-zh-base/ernie_config.json')) + input_dim = config['hidden_size'] + num_classes = 2 + dropout_prob = 0.1 + random_seed = 1 + task_name = 'chnsenticorp' + save_path = './outputs/' + pred_output = './outputs/predict/' + save_type = 'ckpt' + print_steps = 20 + pre_params = './pretrain/ERNIE-v1-zh-base/params' + + # ----------------------- for training ----------------------- + + # step 1-1: create readers for training + cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed) + # step 1-2: load the training data + cls_reader.load_data(train_file, batch_size, num_epochs=num_epochs) + + # step 2: create a backbone of the model to extract text features + ernie = palm.backbone.ERNIE.from_config(config) + + # step 3: register the backbone in reader + cls_reader.register_with(ernie) + + # step 4: create the task output head + cls_head = palm.head.Classify(num_classes, input_dim, dropout_prob) + + # step 5-1: create a task trainer + trainer = palm.Trainer(task_name) + # step 5-2: build forward graph with backbone and task head + loss_var = trainer.build_forward(ernie, cls_head) + + # step 6-1*: use warmup + n_steps = cls_reader.num_examples * num_epochs // batch_size + warmup_steps = int(0.1 * n_steps) + sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps) + # step 6-2: create a optimizer + adam = palm.optimizer.Adam(loss_var, lr, sched) + # step 6-3: build backward + trainer.build_backward(optimizer=adam, weight_decay=weight_decay) + + # step 7: fit prepared reader and data + iterator = trainer.fit_reader(cls_reader) + + # step 8-1*: load pretrained parameters + trainer.load_pretrain(pre_params) + # step 8-2*: set saver to save model + # save_steps = n_steps + save_steps = 2396 + trainer.set_saver(save_steps=save_steps, save_path=save_path, save_type=save_type) + + # step 8-3: start training + # you can repeatly get one train batch with trainer.get_one_batch() + # batch = trainer.get_one_batch() + for step, batch in enumerate(iterator, start=1): + trainer.train_one_step(batch) + if step % 100 == 0: + print('do evaluation.') + # insert evaluation code here + diff --git a/paddlepalm/multihead_trainer.py b/paddlepalm/multihead_trainer.py index 1c886f9..2a7f2c5 100644 --- a/paddlepalm/multihead_trainer.py +++ b/paddlepalm/multihead_trainer.py @@ -32,6 +32,7 @@ class MultiHeadTrainer(Trainer): self._name_pads = {i.name: name_maxlen-len(i.name) for i in self._trainers} self._train_init = False + self._dist_train_init = False self._predict_init = False self._feeded_var_names = None self._cur_train_step = 0 @@ -274,6 +275,7 @@ class MultiHeadTrainer(Trainer): elif phase == 'predict': self._predict_reader = distribute_feeder_fn self._pred_feed_batch_process_fn = feed_batch_process_fn + return distribute_feeder_fn def _check_finish(self, task_name, silent=False): trainers = {t.name:t for t in self._trainers} @@ -327,6 +329,13 @@ class MultiHeadTrainer(Trainer): break def train_one_step(self, batch): + if not self._dist_train_init: + self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name) + for t in self._trainers: + t._set_exe(self._exe) + t._set_dist_train(self._distribute_train_prog) + t._set_fetch_list(self._fetch_list) + self._dist_train_init = True if dev_count > 1: assert isinstance(batch, tuple) diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py index fbc0c2a..b6a056d 100644 --- a/paddlepalm/trainer.py +++ b/paddlepalm/trainer.py @@ -49,6 +49,7 @@ class Trainer(object): self._pred_head = None self._train_reader = None + self._dist_train_init = False self._predict_reader = None self._train_iterator = None self._predict_iterator = None @@ -389,8 +390,7 @@ class Trainer(object): elif phase == 'predict': self._predict_iterator = distribute_feeder_fn self._pred_feed_batch_process_fn = feed_batch_process_fn - # return distribute_feeder_fn() - + return distribute_feeder_fn def load_ckpt(self, model_path): """ @@ -646,6 +646,10 @@ class Trainer(object): def train_one_step(self, batch): + if not self._dist_train_init: + self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name) + self._dist_train_init = True + exe = self._exe distribute_train_prog = self._distribute_train_prog fetch_list = self._fetch_list -- GitLab