run.py 3.1 KB
Newer Older
W
wangxiao1021 已提交
1 2 3 4 5 6 7 8
# coding=utf-8
import paddlepalm as palm
import json


if __name__ == '__main__':
    # Multi-task training script: a sequence-labeling task ("slot") and a
    # classification task ("intent") are built on one shared ERNIE backbone
    # and trained together through a MultiHeadTrainer.

    # configs
    max_seqlen = 128
    batch_size = 16
    num_epochs = 20
    print_steps = 5
    lr = 2e-5
    num_classes = 130  # class count given to the sequence-label head
    weight_decay = 0.01
    num_classes_intent = 26  # class count given to the classification head
    dropout_prob = 0.1
    random_seed = 0
    label_map = './data/atis/atis_slot/label_map.json'
    vocab_path = './pretrain/ERNIE-v2-en-base/vocab.txt'

    train_slot = './data/atis/atis_slot/train.tsv'
    train_intent = './data/atis/atis_intent/train.tsv'
    # NOTE: predict_file and pred_output are not referenced by the training
    # steps visible below; they are kept in case later code relies on them.
    predict_file = './data/atis/atis_slot/test.tsv'
    save_path = './outputs/'
    pred_output = './outputs/predict/'
    save_type = 'ckpt'

    pre_params = './pretrain/ERNIE-v2-en-base/params'
    # Load the backbone config inside a context manager so the file handle is
    # closed deterministically instead of being left to garbage collection.
    with open('./pretrain/ERNIE-v2-en-base/ernie_config.json') as config_file:
        config = json.load(config_file)
    # The task heads take the backbone's hidden size as their input dimension.
    input_dim = config['hidden_size']

    # -----------------------  for training -----------------------

    # step 1-1: create readers for training
    seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
    cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed)

    # step 1-2: load the training data
    # num_epochs=None is passed to both readers here; the epoch count is
    # supplied later, in step 7 (fit_readers_with_mixratio).
    seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size)
    cls_reader.load_data(train_intent, batch_size=batch_size, num_epochs=None)

    # step 2: create a backbone of the model to extract text features
    ernie = palm.backbone.ERNIE.from_config(config)

    # step 3: register the backbone in readers
    seq_label_reader.register_with(ernie)
    cls_reader.register_with(ernie)

    # step 4: create task output heads
    seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
    cls_head = palm.head.Classify(num_classes_intent, input_dim, dropout_prob)

    # step 5-1: create a task trainer
    trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0)
    trainer_cls = palm.Trainer("intent", mix_ratio=1.0)
    trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_cls])
    # step 5-2: build forward graph with backbone and task head
    # Each single-task trainer builds its forward graph first; the multi-head
    # trainer's build_forward() is called afterwards and its result is the
    # loss handed to the optimizer.
    loss1 = trainer_cls.build_forward(ernie, cls_head)
    loss2 = trainer_seq_label.build_forward(ernie, seq_label_head)
    loss_var = trainer.build_forward()

    # step 6-1*: use warmup
    # NOTE: because of the 1.5 factor this expression evaluates to a float
    # (floor division with a float operand yields a float).
    # NOTE(review): the 1.5 multiplier presumably accounts for the extra steps
    # contributed by the second task -- confirm against the trainer's
    # mix-ratio semantics.
    n_steps = seq_label_reader.num_examples * 1.5 * num_epochs // batch_size
    warmup_steps = int(0.1 * n_steps)  # 10% of the total steps
    sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
    # step 6-2: create an optimizer
    adam = palm.optimizer.Adam(loss_var, lr, sched)
    # step 6-3: build backward
    trainer.build_backward(optimizer=adam, weight_decay=weight_decay)

    # step 7: fit prepared reader and data
    # "slot" matches the name given to trainer_seq_label above.
    trainer.fit_readers_with_mixratio([seq_label_reader, cls_reader], "slot", num_epochs)

    # step 8-1*: load pretrained parameters
    trainer.load_pretrain(pre_params)
    # step 8-2*: set saver to save model
    # Save interval: half of (n_steps - batch_size), truncated to an int.
    save_steps = int(n_steps - batch_size) // 2
    # save_steps = 10
    trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
    # step 8-3: start training
    trainer.train(print_steps=print_steps)