# train.py
import sys
import os
import random
import datetime

import numpy as np
import cv2

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch import optim

import sys
sys.path.append("..")
sys.path.append("../..")
from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
from models import unet_model
from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn

# Learning rate handed to the Adam optimizer.
LR = 0.0002
# Number of epochs the training loop runs.
EPOCHS = 100
# Number of samples per batch produced by batch_generator.
BATCHSIZE = 16
# Size passed to impro.resize when images and masks are loaded.
LOADSIZE = 256
# Size passed to data.random_transform_image; masks are reshaped to
# (1, FINESIZE, FINESIZE) before being fed to the network.
FINESIZE = 224
# When True, weights are restored from dir_checkpoint + 'last.pth' before training.
CONTINUE = False
# When True, the network and the input tensors are moved to CUDA.
use_gpu = True
# An extra 'epoch<N>.pth' snapshot and preview are written every SAVE_FRE epochs.
SAVE_FRE = 5
# Keep cuDNN benchmark mode (convolution algorithm auto-tuning) switched off.
cudnn.benchmark = False
# Directory whose entries are listed to find the training images ('<name>.jpg').
dir_img = './datasets/av/origin_image/'
# Directory holding the mask for each image, stored as '<name>.png'.
dir_mask = './datasets/av/mask/'
# Output directory for weights ('last.pth', 'epoch<N>.pth') and result previews.
# File names are appended to it by plain string concatenation, so the trailing
# '/' is required.
dir_checkpoint = 'checkpoints/'


def Totensor(img,use_gpu=True):
    """Convert a NumPy array into a float32 torch tensor.

    Args:
        img: NumPy array accepted by ``torch.from_numpy``.
        use_gpu: When true, the resulting tensor is moved to CUDA.

    Returns:
        A float32 tensor with the same shape and values as ``img``.
    """
    tensor = torch.from_numpy(img).float()
    if use_gpu:
        tensor = tensor.cuda()
    return tensor


def Toinputshape(imgs,masks,finesize):
    """Augment one batch and convert it to the layout fed to the network.

    Every image/mask pair is passed through ``data.random_transform_image``
    with ``finesize``. The image axes are then reordered with
    ``transpose((2, 0, 1))`` (presumably HWC -> CHW; confirm against the
    helper) and the mask is reshaped to ``(1, finesize, finesize)``. Both are
    rescaled with ``(x / 255 - 0.5) / 0.5``, which maps 0..255 onto -1..1.

    Args:
        imgs: Sequence of images forming one batch.
        masks: Sequence of masks, indexed in step with ``imgs``.
        finesize: Target size passed to the random transform.

    Returns:
        ``(images, masks)`` as two NumPy arrays built from the per-sample
        results.
    """
    batch_imgs = []
    batch_masks = []
    for idx, raw_img in enumerate(imgs):
        img, mask = data.random_transform_image(raw_img, masks[idx], finesize)
        unit_mask = mask.reshape(1, finesize, finesize) / 255.0
        unit_img = img.transpose((2, 0, 1)) / 255.0
        batch_imgs.append((unit_img - 0.5) / 0.5)
        batch_masks.append((unit_mask - 0.5) / 0.5)
    return np.array(batch_imgs), np.array(batch_masks)

def batch_generator(images,masks,batchsize):
    """Split parallel image and mask sequences into consecutive batches.

    Batches are taken in order; when the number of images is not a multiple
    of ``batchsize`` the final batch holds the remaining samples. ``masks`` is
    sliced at the same offsets as ``images`` and is expected to have the same
    length.

    Args:
        images: Sliceable sequence of images.
        masks: Sliceable sequence of masks aligned with ``images``.
        batchsize: Positive number of samples per batch.

    Returns:
        ``(dataset_images, dataset_masks)``: two lists of slices, one entry
        per batch.

    Raises:
        ValueError: If ``batchsize`` is not positive.
    """
    if batchsize <= 0:
        raise ValueError('batchsize must be positive, got {}'.format(batchsize))

    # Stepping through the start offsets covers the full batches and the
    # trailing partial batch alike, without float division.
    starts = range(0, len(images), batchsize)
    dataset_images = [images[start:start + batchsize] for start in starts]
    dataset_masks = [masks[start:start + batchsize] for start in starts]

    return dataset_images,dataset_masks

def loadimage(dir_img,dir_mask,loadsize,eval_p):
    """Load all image/mask pairs, resize them and split into train/eval sets.

    The entries of ``dir_img`` are listed and shuffled, so the train/eval
    split is random. For every entry the image is read from
    ``<dir_img>/<stem>.jpg`` and its mask from ``<dir_mask>/<stem>.png`` (in
    'gray' mode); both are resized with ``impro.resize``.

    Args:
        dir_img: Directory containing the source images.
        dir_mask: Directory containing one mask per image, with the same stem.
        loadsize: Size passed to ``impro.resize`` for images and masks.
        eval_p: Fraction of the samples placed in the evaluation set.

    Returns:
        ``(train_images, train_masks, eval_images, eval_masks)`` as lists.
    """
    t1 = datetime.datetime.now()
    imgnames = os.listdir(dir_img)
    print('images num:',len(imgnames))
    random.shuffle(imgnames)

    images = []
    masks = []
    for filename in imgnames:
        # splitext removes whatever extension is present instead of blindly
        # cutting the last four characters off the name.
        stem = os.path.splitext(filename)[0]
        # os.path.join works whether or not the directory ends with a separator.
        img = impro.imread(os.path.join(dir_img, stem+'.jpg'))
        mask = impro.imread(os.path.join(dir_mask, stem+'.png'),mod = 'gray')
        images.append(impro.resize(img,loadsize))
        masks.append(impro.resize(mask,loadsize))

    # The first (1 - eval_p) share of the shuffled samples is used for training.
    split = int(len(masks)*(1-eval_p))
    train_images,train_masks = images[:split],masks[:split]
    eval_images,eval_masks = images[split:],masks[split:]

    t2 = datetime.datetime.now()
    # total_seconds() keeps the day component that timedelta.seconds drops.
    print('load data cost time:',int((t2 - t1).total_seconds()),'s')
    return train_images,train_masks,eval_images,eval_masks



# Read the whole dataset into memory with eval_p=0.2 (the evaluation share),
# then cut the training and evaluation subsets into BATCHSIZE-sized batches.
print('loading data......')
(train_images, train_masks,
 eval_images, eval_masks) = loadimage(dir_img, dir_mask, LOADSIZE, eval_p=0.2)
dataset_train_images, dataset_train_masks = batch_generator(
    train_images, train_masks, BATCHSIZE)
dataset_eval_images, dataset_eval_masks = batch_generator(
    eval_images, eval_masks, BATCHSIZE)


# UNet built with n_channels=3 and n_classes=1 (presumably RGB input and a
# single-channel mask output; confirm against models/unet_model).
net = unet_model.UNet(n_classes=1, n_channels=3)

# Optionally resume from the weights saved at the end of the latest epoch.
if CONTINUE:
    net.load_state_dict(
        torch.load(dir_checkpoint+'last.pth')
    )
if use_gpu:
    net.cuda()

# Adam over every parameter of the network, created after the optional move
# to CUDA.
optimizer = optim.Adam(net.parameters(), betas=(0.9, 0.999), lr=LR)

# Loss between the predicted mask and the target mask.
# NOTE(review): Toinputshape rescales masks with (x / 255 - 0.5) / 0.5, i.e. to
# [-1, 1], whereas the PyTorch documentation states that nn.BCELoss targets
# should lie in [0, 1] -- confirm the intended target range / loss.
criterion = nn.BCELoss()

# Create the output directory up front so a missing folder cannot abort the
# run when the first preview or checkpoint is written.
os.makedirs(dir_checkpoint, exist_ok=True)

print('begin training......')
for epoch in range(EPOCHS):

    starttime = datetime.datetime.now()
    print('Epoch {}/{}.'.format(epoch + 1, EPOCHS))
    net.train()
    epoch_loss = 0
    for i,(img,mask) in enumerate(zip(dataset_train_images,dataset_train_masks)):
        # Random augmentation is re-applied every time a batch is used.
        img,mask = Toinputshape(img, mask, FINESIZE)
        img = Totensor(img,use_gpu)
        mask = Totensor(mask,use_gpu)

        mask_pred = net(img)
        loss = criterion(mask_pred, mask)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodically dump a preview of the current predictions.
        if i%100 == 0:
            data.showresult(img,mask,mask_pred,os.path.join(dir_checkpoint,'result.png'))

    # Evaluation pass: gradients are disabled, but the network is left in
    # train() mode exactly as before.
    # NOTE(review): net.eval() is not called here; if the model contains
    # dropout or batch-norm layers this affects the reported eval loss --
    # confirm whether that is intended.
    epoch_loss_eval = 0
    with torch.no_grad():
        for i,(img,mask) in enumerate(zip(dataset_eval_images,dataset_eval_masks)):
            img,mask = Toinputshape(img, mask, FINESIZE)
            img = Totensor(img,use_gpu)
            mask = Totensor(mask,use_gpu)
            mask_pred = net(img)
            loss = criterion(mask_pred, mask)
            epoch_loss_eval += loss.item()

    endtime = datetime.datetime.now()
    # max(..., 1) avoids a ZeroDivisionError when a split produced no batches.
    print('--- Epoch train_loss: {0:.6f} eval_loss: {1:.6f} Cost time: {2:} s'.format(
        epoch_loss/max(len(dataset_train_images), 1),
        epoch_loss_eval/max(len(dataset_eval_images), 1),
        (endtime - starttime).seconds))

    # Save CPU copies of the weights without moving the model itself, so the
    # network stays on its training device between epochs. state_dict()
    # returns a fresh mapping, so replacing its values does not touch the model.
    cpu_state = net.state_dict()
    for key, value in cpu_state.items():
        cpu_state[key] = value.cpu()
    torch.save(cpu_state,dir_checkpoint+'last.pth')

    # Keep a numbered snapshot and a preview image every SAVE_FRE epochs.
    if (epoch+1)%SAVE_FRE == 0:
        torch.save(cpu_state,dir_checkpoint+'epoch'+str(epoch+1)+'.pth')
        data.showresult(img,mask,mask_pred,os.path.join(dir_checkpoint,'epoch_'+str(epoch+1)+'.png'))
        print('network saved.')