from nets.ssd import get_ssd
from nets.ssd_training import Generator,MultiBoxLoss
from torch.utils.data import DataLoader
from utils.dataloader import ssd_dataset_collate, SSDDataset
from utils.config import Config
from torchsummary import summary
from torch.autograd import Variable
from tqdm import tqdm
import torch.backends.cudnn as cudnn
import time
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init

def get_lr(optimizer):
    # Return the current learning rate of the first parameter group
    for param_group in optimizer.param_groups:
        return param_group['lr']
        
if __name__ == "__main__":
    # ------------------------------------#
    #   Train with part of the weights frozen first,
    #   then unfreeze everything and keep training.
    #   Use a larger learning rate first,
    #   then a smaller one.
    # ------------------------------------#
    lr = 5e-4          # learning rate for the frozen stage
    freeze_lr = 1e-4   # learning rate after unfreezing
    Cuda = True        # set False to train on the CPU

    Start_iter = 0     # starting epoch (useful when resuming from a checkpoint)
    Freeze_epoch = 25  # number of epochs with the backbone frozen
    Epoch = 50         # total number of epochs

    Batch_size = 4
    #-------------------------------#
    #   Whether to use the DataLoader
    #-------------------------------#
    Use_Data_Loader = True

    model = get_ssd("train", Config["num_classes"])

    print('Loading weights into state dict...')
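    # Partial weight loading: only pretrained tensors whose shapes match the
    # current model are kept, so a checkpoint trained with a different
    # num_classes can still initialise the shared layers.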
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_dict = model.state_dict()
    pretrained_dict = torch.load("model_data/ssd_weights.pth", map_location=device)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and np.shape(model_dict[k]) == np.shape(v)}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print('Finished!')

    net = model.train()
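    # DataParallel wraps the model for (multi-)GPU training; cudnn.benchmark
    # lets cuDNN pick the fastest convolution algorithms for fixed-size inputs.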
    if Cuda:
        net = torch.nn.DataParallel(model)
        cudnn.benchmark = True
        net = net.cuda()

    annotation_path = '2007_train.txt'
    with open(annotation_path) as f:
        lines = f.readlines()
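    # Shuffle with a fixed seed so the training order is reproducible,
    # then restore nondeterministic seeding.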
    np.random.seed(10101)
    np.random.shuffle(lines)
    np.random.seed(None)
    num_train = len(lines)   # every annotation line is used for training (no validation split)
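    # Two input pipelines are supported below: a torch DataLoader over
    # SSDDataset, or the repo's own Generator. ssd_dataset_collate is assumed
    # to batch the images and keep a per-image list of box arrays.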

    if Use_Data_Loader:
        train_dataset = SSDDataset(lines[:num_train], (Config["min_dim"], Config["min_dim"]))
        gen = DataLoader(train_dataset, batch_size=Batch_size, num_workers=8, pin_memory=True,
                                drop_last=True, collate_fn=ssd_dataset_collate)
    else:
        gen = Generator(Batch_size, lines,
                        (Config["min_dim"], Config["min_dim"]), Config["num_classes"]).generate()

    criterion = MultiBoxLoss(Config['num_classes'], 0.5, True, 0, True, 3, 0.5,
                             False, Cuda)
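    # The positional arguments above are assumed to follow the classic
    # ssd.pytorch MultiBoxLoss signature: (num_classes, overlap_thresh,
    # prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap,
    # encode_target, use_gpu).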
    epoch_size = num_train // Batch_size   # training steps per epoch

    if True:
        # ------------------------------------#
        #   Freeze the VGG backbone for the first stage
        # ------------------------------------#
        for param in model.vgg.parameters():
            param.requires_grad = False
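        # With the backbone frozen, only the layers outside model.vgg
        # (the extra feature layers and the loc/conf heads) are updated.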

        optimizer = optim.Adam(net.parameters(), lr=lr)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)   # decay the LR by 5% every epoch
        for epoch in range(Start_iter, Freeze_epoch):
            with tqdm(total=epoch_size,desc=f'Epoch {epoch + 1}/{Freeze_epoch}',postfix=dict,mininterval=0.3) as pbar:
                loc_loss = 0
                conf_loss = 0
                for iteration, batch in enumerate(gen):
                    if iteration >= epoch_size:
                        break
                    images, targets = batch[0], batch[1]
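                    # Convert the numpy batch to tensors; no_grad keeps the
                    # conversion out of the autograd graph (Variable is a
                    # no-op wrapper in PyTorch >= 0.4).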
                    with torch.no_grad():
                        if Cuda:
                            images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda()
                            targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)).cuda() for ann in targets]
                        else:
                            images = Variable(torch.from_numpy(images).type(torch.FloatTensor))
                            targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
                    # Forward pass
                    out = net(images)
                    # Zero the gradients
                    optimizer.zero_grad()
                    # Compute the localization and confidence losses
                    loss_l, loss_c = criterion(out, targets)
                    loss = loss_l + loss_c
                    # Backward pass
                    loss.backward()
                    optimizer.step()
                    # Accumulate the running losses
                    loc_loss += loss_l.item()
                    conf_loss += loss_c.item()

                    pbar.set_postfix(**{'loc_loss'  : loc_loss / (iteration + 1), 
                                        'conf_loss' : conf_loss / (iteration + 1),
                                        'lr'        : get_lr(optimizer)})
                    pbar.update(1)
                        
            lr_scheduler.step()
            print('Saving state, epoch:', str(epoch+1))
            torch.save(model.state_dict(), 'logs/Epoch%d-Loc%.4f-Conf%.4f.pth'%((epoch+1),loc_loss/(iteration+1),conf_loss/(iteration+1)))

    if True:
        # ------------------------------------#
        #   Unfreeze all weights for the second stage
        # ------------------------------------#
        for param in model.vgg.parameters():
            param.requires_grad = True

        optimizer = optim.Adam(net.parameters(), lr=freeze_lr)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
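        # The loop below mirrors the frozen stage; only the trainable
        # parameters and the learning rate differ.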
        for epoch in range(Freeze_epoch, Epoch):
            with tqdm(total=epoch_size,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
                loc_loss = 0
                conf_loss = 0
                for iteration, batch in enumerate(gen):
                    if iteration >= epoch_size:
                        break
                    images, targets = batch[0], batch[1]
                    with torch.no_grad():
                        if Cuda:
                            images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda()
                            targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)).cuda() for ann in targets]
                        else:
                            images = Variable(torch.from_numpy(images).type(torch.FloatTensor))
                            targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
                    # Forward pass
                    out = net(images)
                    # Zero the gradients
                    optimizer.zero_grad()
                    # Compute the localization and confidence losses
                    loss_l, loss_c = criterion(out, targets)
                    loss = loss_l + loss_c
                    # Backward pass
                    loss.backward()
                    optimizer.step()
                    # Accumulate the running losses
                    loc_loss += loss_l.item()
                    conf_loss += loss_c.item()

                    pbar.set_postfix(**{'loc_loss'  : loc_loss / (iteration + 1), 
                                        'conf_loss' : conf_loss / (iteration + 1),
                                        'lr'        : get_lr(optimizer)})
                    pbar.update(1)

            lr_scheduler.step()
            print('Saving state, epoch:', str(epoch+1))
            torch.save(model.state_dict(), 'logs/Epoch%d-Loc%.4f-Conf%.4f.pth'%((epoch+1),loc_loss/(iteration+1),conf_loss/(iteration+1)))
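# To resume from one of the checkpoints saved above, point the torch.load call
# near the top of the script at a file from logs/ (the name below is
# illustrative) and, for the frozen stage, set Start_iter to the epoch already
# completed:
#   pretrained_dict = torch.load("logs/Epoch25-Loc1.2345-Conf2.3456.pth", map_location=device)
#   Start_iter = 25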