from nets.ssd import get_ssd
from nets.ssd_training import Generator, MultiBoxLoss
from torch.utils.data import DataLoader
from utils.dataloader import ssd_dataset_collate, SSDDataset
from utils.config import Config
import torch.backends.cudnn as cudnn
import torch
import numpy as np
import torch.optim as optim


def adjust_learning_rate(optimizer, lr, gamma, step):
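    # Exponential decay: the new rate is lr * gamma**step. For example,
    # with lr=1e-4 and gamma=0.9 this yields 1e-4 at step 0, ~3.5e-5 at
    # step 10 and ~1.2e-5 at step 20 (the loops below pass step=epoch,
    # calling this every 10 epochs).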
    lr = lr * (gamma ** (step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

#----------------------------------------------------#
#   Reference video for computing detection mAP and
#   PR curves:
#   https://www.bilibili.com/video/BV1zE411u7Vw
#----------------------------------------------------#
if __name__ == "__main__":
    # ------------------------------------#
    #   Train with part of the weights frozen
    #   first, then unfreeze all weights:
    #   a larger learning rate to start,
    #   a smaller one afterwards.
    # ------------------------------------#
    lr = 1e-4
    freeze_lr = 1e-5
    Cuda = True

    Start_iter = 0
    Freeze_epoch = 25
    Epoch = 50

    Batch_size = 4
    #-------------------------------#
    #   Whether to use PyTorch's DataLoader
    #-------------------------------#
    Use_Data_Loader = True

    model = get_ssd("train", Config["num_classes"])
    
    #-------------------------------------------#
    #   See the README for how to download the weights file
    #-------------------------------------------#
    print('Loading weights into state dict...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_dict = model.state_dict()
    pretrained_dict = torch.load("model_data/ssd_weights.pth", map_location=device)
    # Keep only the pretrained weights whose keys exist in the model and
    # whose shapes match, so a partially compatible checkpoint still loads.
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict and np.shape(model_dict[k]) == np.shape(v)}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print('Finished!')

    net = model.train()
    if Cuda:
        net = torch.nn.DataParallel(model)
        cudnn.benchmark = True
        net = net.cuda()

    annotation_path = '2007_train.txt'
    with open(annotation_path) as f:
        lines = f.readlines()
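    # Shuffle the annotation lines with a fixed seed so the ordering is
    # reproducible across runs, then reset the seed so later randomness
    # is unaffected.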
    np.random.seed(10101)
    np.random.shuffle(lines)
    np.random.seed(None)
    num_train = len(lines)

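    # Each image can carry a different number of ground-truth boxes, so
    # ssd_dataset_collate keeps the targets as a list of per-image arrays
    # rather than stacking them into a single tensor.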
    if Use_Data_Loader:
        train_dataset = SSDDataset(lines[:num_train], (Config["min_dim"], Config["min_dim"]))
        gen = DataLoader(train_dataset, batch_size=Batch_size, num_workers=8, pin_memory=True,
                         drop_last=True, collate_fn=ssd_dataset_collate)
    else:
        gen = Generator(Batch_size, lines,
                        (Config["min_dim"], Config["min_dim"]), Config["num_classes"]).generate()

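    # The positional arguments presumably follow the ssd.pytorch reference
    # MultiBoxLoss signature: (num_classes, overlap_thresh, prior_for_matching,
    # bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, use_gpu),
    # i.e. IoU 0.5 for prior matching and 3:1 hard negative mining.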
    criterion = MultiBoxLoss(Config['num_classes'], 0.5, True, 0, True, 3, 0.5,
                             False, Cuda)
    epoch_size = num_train // Batch_size


    #------------------------------------------------------#
    #   Backbone features are generic, so freezing the
    #   backbone speeds up training and also keeps the
    #   pretrained weights from being destroyed early on.
    #   Start_iter is the starting epoch,
    #   Freeze_epoch is the last freeze-training epoch,
    #   Epoch is the total number of training epochs.
    #   If you hit OOM / insufficient GPU memory, reduce Batch_size.
    #------------------------------------------------------#
    if True:
        # ------------------------------------#
        #   Freeze the VGG backbone for training
        # ------------------------------------#
        for param in model.vgg.parameters():
            param.requires_grad = False
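        # Optional sanity check: confirm how many parameters remain
        # trainable after freezing the backbone.
        num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('Trainable parameters after freezing: %d' % num_trainable)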

        optimizer = optim.Adam(net.parameters(), lr=lr)
        for epoch in range(Start_iter, Freeze_epoch):
            # Decay the learning rate to lr * 0.9**epoch every 10 epochs
            if epoch % 10 == 0:
                adjust_learning_rate(optimizer, lr, 0.9, epoch)
            loc_loss = 0
            conf_loss = 0
            for iteration, batch in enumerate(gen):
                if iteration >= epoch_size:
                    break
                images, targets = batch[0], batch[1]
                with torch.no_grad():
                    if Cuda:
                        images = torch.from_numpy(images).type(torch.FloatTensor).cuda()
                        targets = [torch.from_numpy(ann).type(torch.FloatTensor).cuda() for ann in targets]
                    else:
                        images = torch.from_numpy(images).type(torch.FloatTensor)
                        targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
                # Forward pass
                out = net(images)
                # Zero the gradients
                optimizer.zero_grad()
                # Compute localization and confidence losses
                loss_l, loss_c = criterion(out, targets)
                loss = loss_l + loss_c
                # Backward pass and weight update
                loss.backward()
                optimizer.step()
                # Accumulate the running losses for logging
                loc_loss += loss_l.item()
                conf_loss += loss_c.item()

                print('\nEpoch:' + str(epoch+1) + '/' + str(Freeze_epoch))
                print('iter:' + str(iteration) + '/' + str(epoch_size) +
                      ' || Loc_Loss: %.4f || Conf_Loss: %.4f ||' % (loc_loss/(iteration+1), conf_loss/(iteration+1)), end=' ')
                
                
            print('Saving state, epoch:', str(epoch+1))
            torch.save(model.state_dict(), 'logs/Epoch%d-Loc%.4f-Conf%.4f.pth' % ((epoch+1), loc_loss/(iteration+1), conf_loss/(iteration+1)))

    if True:
        # ------------------------------------#
        #   Unfreeze the whole network for training
        # ------------------------------------#
        for param in model.vgg.parameters():
            param.requires_grad = True

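        # A fresh Adam optimizer is built below so the newly unfrozen
        # backbone parameters are registered, now at the smaller freeze_lr.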
        optimizer = optim.Adam(net.parameters(), lr=freeze_lr)
        for epoch in range(Freeze_epoch, Epoch):
            # Decay the learning rate to freeze_lr * 0.9**epoch every 10 epochs
            if epoch % 10 == 0:
                adjust_learning_rate(optimizer, freeze_lr, 0.9, epoch)
            loc_loss = 0
            conf_loss = 0
            for iteration, batch in enumerate(gen):
                if iteration >= epoch_size:
                    break
                images, targets = batch[0], batch[1]
                with torch.no_grad():
                    if Cuda:
                        images = torch.from_numpy(images).type(torch.FloatTensor).cuda()
                        targets = [torch.from_numpy(ann).type(torch.FloatTensor).cuda() for ann in targets]
                    else:
                        images = torch.from_numpy(images).type(torch.FloatTensor)
                        targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
                # Forward pass
                out = net(images)
                # Zero the gradients
                optimizer.zero_grad()
                # Compute localization and confidence losses
                loss_l, loss_c = criterion(out, targets)
                loss = loss_l + loss_c
                # Backward pass and weight update
                loss.backward()
                optimizer.step()
                # Accumulate the running losses for logging
                loc_loss += loss_l.item()
                conf_loss += loss_c.item()

                print('\nEpoch:' + str(epoch+1) + '/' + str(Epoch))
                print('iter:' + str(iteration) + '/' + str(epoch_size) +
                      ' || Loc_Loss: %.4f || Conf_Loss: %.4f ||' % (loc_loss/(iteration+1), conf_loss/(iteration+1)), end=' ')

            print('Saving state, epoch:', str(epoch+1))
            torch.save(model.state_dict(), 'logs/Epoch%d-Loc%.4f-Conf%.4f.pth' % ((epoch+1), loc_loss/(iteration+1), conf_loss/(iteration+1)))