train.py 10.6 KB
Newer Older
B
Bubbliiiing 已提交
1 2
import warnings

J
JiaQi Xu 已提交
3
import numpy as np
B
Bubbliiiing 已提交
4 5 6 7 8 9 10
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from nets.ssd import get_ssd
B
Bubbliiiing 已提交
11
from nets.ssd_training import LossHistory, MultiBoxLoss, weights_init
B
Bubbliiiing 已提交
12 13 14 15
from utils.config import Config
from utils.dataloader import SSDDataset, ssd_dataset_collate

warnings.filterwarnings("ignore")
J
JiaQi Xu 已提交
16

B
Bubbliiiing 已提交
17 18 19 20 21 22 23
#------------------------------------------------------------------------#
#   这里看到的train.py和视频上不太一样
#   我重构了一下train.py,添加了验证集
#   这样训练的时候可以有个参考。
#   训练前注意在config.py里面修改num_classes
#   训练世代、学习率、批处理大小等参数在本代码靠下的if True:内进行修改。
#-------------------------------------------------------------------------#
B
Bubbliiiing 已提交
24 25 26
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Yields None when the optimizer has no parameter groups at all.
    """
    first_group_lrs = (group['lr'] for group in optimizer.param_groups)
    return next(first_group_lrs, None)
B
Bubbliiiing 已提交
27 28

def _to_tensors(images, targets, cuda):
    """Convert one collated numpy batch into float32 torch tensors.

    ``images`` is converted as one array; ``targets`` is a sequence of
    per-image annotation arrays, each converted separately. When ``cuda`` is
    true every tensor is moved to the current CUDA device.
    """
    images  = torch.from_numpy(images).type(torch.FloatTensor)
    targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
    if cuda:
        images  = images.cuda()
        targets = [ann.cuda() for ann in targets]
    return images, targets


def fit_one_epoch(net,criterion,epoch,epoch_size,epoch_size_val,gen,genval,Epoch,cuda):
    """Train ``net`` for one epoch, evaluate it on the validation loader, and save a checkpoint.

    This function reads two module-level globals created in the ``__main__``
    section of this script: ``optimizer`` (zero_grad/step and the learning rate
    shown in the progress bar) and ``loss_history`` (receives the per-epoch
    mean train/val losses).

    Args:
        net: the model to train; may be wrapped in ``torch.nn.DataParallel``.
        criterion: callable mapping ``(output, targets)`` to
            ``(localization_loss, confidence_loss)``.
        epoch: zero-based index of the current epoch.
        epoch_size: maximum number of training batches drawn from ``gen``.
        epoch_size_val: maximum number of validation batches drawn from ``genval``.
        gen: training loader yielding ``(images, targets)`` numpy batches.
        genval: validation loader with the same batch format.
        Epoch: total number of epochs of the current stage (display only).
        cuda: move batches to the GPU when true.

    Returns:
        The mean validation loss (localization + confidence) per processed batch.

    Side effects:
        Writes ``logs/Epoch{n}-Total_Loss{x}-Val_Loss{y}.pth`` with the weights
        of the unwrapped model.
    """
    loc_loss        = 0
    conf_loss       = 0
    loc_loss_val    = 0
    conf_loss_val   = 0
    train_batches   = 0
    val_batches     = 0

    net.train()
    print('Start Train')
    with tqdm(total=epoch_size,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen):
            if iteration >= epoch_size:
                break
            images, targets = _to_tensors(batch[0], batch[1], cuda)
            #----------------------#
            #   Forward pass
            #----------------------#
            out = net(images)
            #----------------------#
            #   Clear gradients
            #----------------------#
            optimizer.zero_grad()
            #----------------------#
            #   Compute the loss
            #----------------------#
            loss_l, loss_c  = criterion(out, targets)
            loss            = loss_l + loss_c
            #----------------------#
            #   Backward pass and update
            #----------------------#
            loss.backward()
            optimizer.step()

            loc_loss        += loss_l.item()
            conf_loss       += loss_c.item()
            train_batches   += 1

            pbar.set_postfix(**{'loc_loss'  : loc_loss / train_batches,
                                'conf_loss' : conf_loss / train_batches,
                                'lr'        : get_lr(optimizer)})
            pbar.update(1)

    net.eval()
    print('Start Validation')
    with tqdm(total=epoch_size_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
        for iteration, batch in enumerate(genval):
            if iteration >= epoch_size_val:
                break
            # Evaluation only: no autograd graph and no optimizer interaction.
            with torch.no_grad():
                images, targets = _to_tensors(batch[0], batch[1], cuda)
                out = net(images)
                loss_l, loss_c = criterion(out, targets)

            loc_loss_val    += loss_l.item()
            conf_loss_val   += loss_c.item()
            val_batches     += 1

            pbar.set_postfix(**{'loc_loss'  : loc_loss_val / val_batches,
                                'conf_loss' : conf_loss_val / val_batches,
                                'lr'        : get_lr(optimizer)})
            pbar.update(1)

    # Average over the batches that actually ran. At most epoch_size /
    # epoch_size_val batches are processed, so dividing by size + 1 would bias
    # the means (and the value driving the LR scheduler) low. max(..., 1)
    # guards against an empty loader.
    total_loss  = (loc_loss + conf_loss) / max(train_batches, 1)
    val_loss    = (loc_loss_val + conf_loss_val) / max(val_batches, 1)

    loss_history.append_loss(total_loss, val_loss)
    print('Finish Validation')
    print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))
    print('Total Loss: %.4f || Val Loss: %.4f ' % (total_loss, val_loss))
    print('Saving state, iter:', str(epoch+1))

    # Save the unwrapped module so checkpoint keys carry no "module." prefix
    # when net is a (Distributed)DataParallel wrapper.
    parallel_types  = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    save_model      = net.module if isinstance(net, parallel_types) else net
    torch.save(save_model.state_dict(), 'logs/Epoch%d-Total_Loss%.4f-Val_Loss%.4f.pth'%((epoch+1),total_loss,val_loss))
    return val_loss
B
Bubbliiiing 已提交
113 114 115 116 117

#----------------------------------------------------#
#   检测精度mAP和pr曲线计算参考视频
#   https://www.bilibili.com/video/BV1zE411u7Vw
#----------------------------------------------------#
J
JiaQi Xu 已提交
118
if __name__ == "__main__":
    #-------------------------------#
    #   Whether to use CUDA.
    #   Set to False on a machine without a GPU.
    #-------------------------------#
    Cuda = True
    #--------------------------------------------#
    #   Backbone selection (not shown in the video):
    #   SSD is implemented on both mobilenetv2 and vgg.
    #   Set backbone to "vgg" or "mobilenet".
    #---------------------------------------------#
    backbone = "vgg"

    model = get_ssd("train", Config["num_classes"], backbone)
    weights_init(model)
    #------------------------------------------------------#
    #   For the weight file see the README (Baidu Netdisk download).
    #------------------------------------------------------#
    model_path = "model_data/ssd_weights.pth"
    print('Loading weights into state dict...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_path, map_location=device)
    # Copy only tensors whose name exists in this model with an identical shape.
    # The membership test avoids a KeyError when the checkpoint carries layers
    # this model lacks (e.g. weights saved for the other backbone); shape
    # mismatches (presumably layers sized by num_classes) are skipped as well.
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict and np.shape(model_dict[k]) == np.shape(v)}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print('Finished!')

    annotation_path = '2007_train.txt'
    #----------------------------------------------------------------------#
    #   The train/validation split is done here in train.py.
    #   It is normal for 2007_test.txt and 2007_val.txt to be empty;
    #   training does not use them.
    #   With val_split = 0.1 the validation:train ratio is 1:9.
    #----------------------------------------------------------------------#
    val_split = 0.1
    with open(annotation_path) as f:
        lines = f.readlines()
    # A fixed seed makes the split reproducible across runs; reseeding with
    # None afterwards keeps any later np.random usage non-deterministic.
    np.random.seed(10101)
    np.random.shuffle(lines)
    np.random.seed(None)
    num_val = int(len(lines)*val_split)
    num_train = len(lines) - num_val

    criterion = MultiBoxLoss(Config['num_classes'], 0.5, True, 0, True, 3, 0.5, False, Cuda)
    # Module-level name read by fit_one_epoch.
    loss_history = LossHistory("logs/")

    net = model.train()
    if Cuda:
        net = torch.nn.DataParallel(model)
        cudnn.benchmark = True
        net = net.cuda()

    def _make_loaders(batch_size):
        """Build the train/val DataLoaders and per-epoch batch counts for ``batch_size``.

        Raises ValueError when either split is too small to fill one batch.
        """
        input_shape     = (Config["min_dim"], Config["min_dim"])
        train_dataset   = SSDDataset(lines[:num_train], input_shape, True)
        val_dataset     = SSDDataset(lines[num_train:], input_shape, False)

        gen             = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=4, pin_memory=True,
                                drop_last=True, collate_fn=ssd_dataset_collate)
        # Validation is not shuffled: with drop_last and the epoch_size_val cap,
        # shuffling would evaluate a different subset each epoch, making the
        # val loss fed to ReduceLROnPlateau noisier than necessary.
        gen_val         = DataLoader(val_dataset, shuffle=False, batch_size=batch_size, num_workers=4, pin_memory=True,
                                drop_last=True, collate_fn=ssd_dataset_collate)

        epoch_size      = num_train // batch_size
        epoch_size_val  = num_val // batch_size
        if epoch_size == 0 or epoch_size_val == 0:
            raise ValueError("数据集过小,无法进行训练,请扩充数据集。")
        return gen, gen_val, epoch_size, epoch_size_val

    def _set_backbone_trainable(trainable):
        """Enable or disable gradients for every parameter of the selected backbone."""
        backbone_net = model.vgg if backbone == "vgg" else model.mobilenet
        for param in backbone_net.parameters():
            param.requires_grad = trainable

    #------------------------------------------------------#
    #   Backbone features are generic, so training with the backbone frozen
    #   speeds up training and protects its weights early on.
    #   Init_Epoch is the starting epoch.
    #   Freeze_Epoch is the number of epochs trained with the backbone frozen.
    #   Unfreeze_Epoch is the total number of epochs.
    #   On OOM / insufficient GPU memory, reduce Batch_size.
    #------------------------------------------------------#
    if True:
        lr              = 5e-4
        Batch_size      = 32
        Init_Epoch      = 0
        Freeze_Epoch    = 50

        # Module-level name read by fit_one_epoch.
        optimizer       = optim.Adam(net.parameters(), lr=lr)
        lr_scheduler    = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

        gen, gen_val, epoch_size, epoch_size_val = _make_loaders(Batch_size)
        _set_backbone_trainable(False)

        for epoch in range(Init_Epoch,Freeze_Epoch):
            val_loss = fit_one_epoch(net,criterion,epoch,epoch_size,epoch_size_val,gen,gen_val,Freeze_Epoch,Cuda)
            lr_scheduler.step(val_loss)

    if True:
        lr              = 1e-4
        Batch_size      = 16
        # Redefined here so this stage still runs when the first stage is
        # disabled (e.g. resuming); keep it equal to the first stage's value.
        Freeze_Epoch    = 50
        Unfreeze_Epoch  = 100

        optimizer       = optim.Adam(net.parameters(), lr=lr)
        lr_scheduler    = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

        gen, gen_val, epoch_size, epoch_size_val = _make_loaders(Batch_size)
        _set_backbone_trainable(True)

        for epoch in range(Freeze_Epoch,Unfreeze_Epoch):
            val_loss = fit_one_epoch(net,criterion,epoch,epoch_size,epoch_size_val,gen,gen_val,Unfreeze_Epoch,Cuda)
            lr_scheduler.step(val_loss)