# train.py -- training script for the video MosaicNet generator (optionally with a GAN discriminator).
import os
import random
import sys
import threading
import time

import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn

# The project packages (util, cores, models) are resolved through these extra search paths.
sys.path.append("..")
sys.path.append("../..")
from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
from cores import Options
from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model
# Kept after the project imports to preserve the original import order.
from matplotlib import pyplot as plt

'''
--------------------------Get options--------------------------
'''

opt = Options()

# Script-specific command line options. They are registered in the order listed
# here and every entry is forwarded unchanged to ``opt.parser.add_argument``.
_TRAIN_ARGUMENTS = (
    # number of mosaic frames stacked into one generator input (3*N+1 channels)
    ('--N', dict(type=int, default=25, help='')),
    # Adam learning rate and first beta, shared by generator and discriminator
    ('--lr', dict(type=float, default=0.0002, help='')),
    ('--beta1', dict(type=float, default=0.5, help='')),
    ('--gan', dict(action='store_true', help='if specified, use gan')),
    # when set, an MSE term is added to the L1 term of the generator loss
    ('--l2', dict(action='store_true', help='if specified, use L2 loss')),
    ('--hd', dict(action='store_true', help='if specified, use HD model')),
    # weights of the pixel loss and of the adversarial loss
    ('--lambda_L1', dict(type=float, default=100, help='')),
    ('--lambda_gan', dict(type=float, default=1, help='')),
    # frames are resized to loadsize on load; finesize is the side of the training tensors
    ('--finesize', dict(type=int, default=256, help='')),
    ('--loadsize', dict(type=int, default=286, help='')),
    ('--batchsize', dict(type=int, default=1, help='')),
    # number of samples kept in the preloaded buffers
    ('--perload_num', dict(type=int, default=16, help='')),
    ('--norm', dict(type=str, default='instance', help='')),
    # training schedule and checkpointing
    ('--maxiter', dict(type=int, default=10000000, help='')),
    ('--savefreq', dict(type=int, default=10000, help='')),
    ('--startiter', dict(type=int, default=0, help='')),
    ('--continuetrain', dict(action='store_true', help='')),
    # sub-directory of checkpoints/ that receives logs and weights
    ('--savename', dict(type=str, default='MosaicNet', help='')),
)
for _flag, _spec in _TRAIN_ARGUMENTS:
    opt.parser.add_argument(_flag, **_spec)

'''
--------------------------Init--------------------------
'''
opt = opt.getparse()
dir_checkpoint = os.path.join('checkpoints/',opt.savename)
util.makedirs(dir_checkpoint)

# Open the loss log with a timestamp followed by the dump of all options.
log_header = time.asctime()+'\n'+util.opt2str(opt)
util.writelog(os.path.join(dir_checkpoint,'loss.txt'), log_header)

N = opt.N
# Running loss sums between two log lines: [L1, G_GAN, D_fake, D_real].
loss_sum = [0.,0.,0.,0.]
# Logged history used for the loss plot: [L1 values, G_GAN values] and their iteration numbers.
loss_plot = [[],[]]
item_plot = []

# Every entry of ./dataset is treated as one video; its length is the number
# of files found in its 'ori' sub-directory.
videos = sorted(os.listdir('./dataset'))
print('check dataset...')
lengths = [len(os.listdir('./dataset/'+video+'/ori')) for video in videos]
# Generator: built with 3*N+1 and 3 as its first two arguments (presumably
# input/output channels); loaddata() stacks N three-channel frames plus one mask channel.
if opt.hd:
    netG = videoHD_model.MosaicNet(3*N+1, 3, norm=opt.norm)
else:
    netG = video_model.MosaicNet(3*N+1, 3, norm=opt.norm)
loadmodel.show_paramsnumber(netG,'netG')

# Discriminator (GAN training only). It is fed a 6-channel tensor: the centre
# input frame concatenated with either the prediction or the ground truth.
if opt.gan:
    if opt.hd:
        netD = pix2pixHD_model.define_D(6, 64, 3, norm = opt.norm, use_sigmoid=False, num_D=2)
    else:
        netD = pix2pix_model.define_D(3*2, 64, 'basic', norm = opt.norm)
    netD.train()

# Resuming is only possible when a previous generator checkpoint exists;
# otherwise fall back to training from the initial weights.
if opt.continuetrain:
    if not os.path.isfile(os.path.join(dir_checkpoint,'last_G.pth')):
        opt.continuetrain = False
        print('can not load last_G, training on init weight.')
if opt.continuetrain:
    netG.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_G.pth')))
    if opt.gan:
        netD.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_D.pth')))
    # The 'iter' file holds the iteration number written at the last save.
    with open(os.path.join(dir_checkpoint,'iter'),'r') as f:
        opt.startiter = int(f.read())

optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999))
criterion_L1 = nn.L1Loss()
criterion_L2 = nn.MSELoss()
if opt.gan:
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999))
    # The GAN loss must live on the same device as the networks. It used to be
    # created on CUDA unconditionally, which breaks runs with use_gpu disabled;
    # the move to the GPU is done once in the use_gpu section below.
    if opt.hd:
        gan_tensor_type = torch.cuda.FloatTensor if opt.use_gpu else torch.FloatTensor
        criterionGAN = pix2pixHD_model.GANLoss(tensor=gan_tensor_type)
    else:
        criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan')

if opt.use_gpu:
    netG.cuda()
    if opt.gan:
        netD.cuda()
        criterionGAN.cuda()
    cudnn.benchmark = True

'''
--------------------------preload data--------------------------
'''
def _imread_checked(path, *flags):
    """Read an image with ``cv2.imread`` and raise if it could not be read.

    ``cv2.imread`` returns ``None`` for a missing or undecodable file instead
    of raising, which otherwise only surfaces later as an unrelated error.

    Raises:
        IOError: if ``cv2.imread`` returned ``None`` for ``path``.
    """
    img = cv2.imread(path, *flags)
    if img is None:
        raise IOError('can not read image: '+path)
    return img


def loaddata():
    """Assemble one random training sample from ``./dataset``.

    A random video and a random centre frame are chosen. The N mosaic frames
    around the centre frame are resized to ``opt.loadsize`` and stacked along
    the channel axis, followed by the thresholded mask of the centre frame as
    the last channel. The matching frame from ``ori`` is the target. Both are
    passed through ``data.random_transform_video`` and converted with
    ``data.im2tensor``.

    Returns:
        tuple: ``(input_img, ground_true)`` as returned by ``data.im2tensor``.

    Raises:
        ValueError: if the chosen video has too few frames for a window of N.
        IOError: if one of the required image files cannot be read.
    """
    video_index = random.randint(0,len(videos)-1)
    video = videos[video_index]

    # The centre frame must leave room for N//2 neighbours on both sides.
    half = N//2
    first_center = half+1
    last_center = lengths[video_index]-half-1
    if last_center < first_center:
        raise ValueError('video "%s" has only %d frames, too short for N=%d'
                         % (video, lengths[video_index], N))
    img_index = random.randint(first_center,last_center)

    input_img = np.zeros((opt.loadsize,opt.loadsize,3*N+1), dtype='uint8')
    for i in range(N):
        img = _imread_checked('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-half)+'.png')
        img = impro.resize(img,opt.loadsize)
        input_img[:,:,i*3:(i+1)*3] = img

    # Flag 0 makes cv2.imread load the mask as a single-channel grayscale image.
    mask = _imread_checked('./dataset/'+video+'/mask/output_'+'%05d'%(img_index)+'.png',0)
    mask = impro.resize(mask,opt.loadsize)
    mask = impro.mask_threshold(mask,15,128)
    input_img[:,:,-1] = mask

    ground_true = _imread_checked('./dataset/'+video+'/ori/output_'+'%05d'%(img_index)+'.png')
    ground_true = impro.resize(ground_true,opt.loadsize)

    input_img,ground_true = data.random_transform_video(input_img,ground_true,opt.finesize,N)
    input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1=False)
    ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1=False)

    return input_img,ground_true

print('preloading data, please wait 5s...')

# The training loop draws a window of batchsize consecutive slots, so the
# buffers must hold more samples than one batch.
if opt.perload_num <= opt.batchsize:
    opt.perload_num = opt.batchsize*2

# Sample buffers that the loader thread overwrites slot by slot. They start
# out as random noise and are only moved to the GPU when use_gpu is set
# (previously .cuda() was called unconditionally, which fails on CPU-only runs).
input_imgs = torch.rand(opt.perload_num,N*3+1,opt.finesize,opt.finesize)
ground_trues = torch.rand(opt.perload_num,3,opt.finesize,opt.finesize)
if opt.use_gpu:
    input_imgs = input_imgs.cuda()
    ground_trues = ground_trues.cuda()

# Number of samples successfully written into the buffers so far.
load_cnt = 0

def preload():
    """Keep refreshing the preloaded sample buffers forever.

    Intended to run in a background thread. Each successful ``loaddata()``
    call writes one sample into ``input_imgs`` / ``ground_trues`` and
    increments the global ``load_cnt``.

    The first ``opt.perload_num`` samples fill the slots in order, so that
    every slot holds real data once ``load_cnt`` reaches ``opt.perload_num``
    (picking random slots from the start leaves part of the buffer at its
    random initial content). After that a random slot is replaced each time.

    Any ``Exception`` raised while loading is printed and the loop continues.
    """
    global load_cnt
    while True:
        try:
            sample_input,sample_target = loaddata()
            if load_cnt < opt.perload_num:
                slot = load_cnt
            else:
                slot = random.randint(0, opt.perload_num-1)
            input_imgs[slot] = sample_input
            ground_trues[slot] = sample_target
            load_cnt += 1
        except Exception as e:
            print("error:",e)
# Run preload() in a background thread (t is the newly created thread).
# The daemon flag lets the interpreter exit without waiting for the endless loader loop.
t = threading.Thread(target=preload,args=())
t.daemon = True
t.start()

# Wait until the loader has completed perload_num successful loads and
# report the average time per loaded sample.
time_start=time.time()
while load_cnt < opt.perload_num:
    time.sleep(0.1)
time_end=time.time()
print('load speed:',round((time_end-time_start)/opt.perload_num,3),'s/it')

'''
--------------------------train--------------------------
'''
# Keep a copy of the training script and of a model definition next to the checkpoints.
# NOTE(review): videoHD_model.py is copied even when --hd is not set -- confirm
# whether video_model.py should be snapshotted in that case.
util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py'))
util.copyfile('../../models/videoHD_model.py', os.path.join(dir_checkpoint,'model.py'))
netG.train()
time_start=time.time()
print("Begin training...")

# Index of the centre frame inside the stack of N input frames.
center = int((N-1)/2)


def _pixel_loss(pred, target):
    """Return the weighted reconstruction loss: L1, plus MSE when --l2 is set."""
    if opt.l2:
        return (criterion_L1(pred, target)+criterion_L2(pred, target)) * opt.lambda_L1
    return criterion_L1(pred, target) * opt.lambda_L1


for iteration in range(opt.startiter+1,opt.maxiter):
    # All bookkeeping (logging, saving, file names) uses the 1-based step number.
    step = iteration+1

    # Pick a window of batchsize consecutive buffer slots. randint is inclusive
    # on both ends, so the largest valid start is perload_num-batchsize
    # (the previous bound left the last slot unused).
    ran = random.randint(0, opt.perload_num-opt.batchsize)
    inputdata = input_imgs[ran:ran+opt.batchsize].clone()
    target = ground_trues[ran:ran+opt.batchsize].clone()
    # Centre mosaic frame: conditioning image for the discriminator and the
    # input shown in the periodic preview.
    real_A = inputdata[:,center*3:(center+1)*3,:,:]

    if opt.gan:
        # compute fake images: G(A)
        pred = netG(inputdata)
        fake_AB = torch.cat((real_A, pred), 1)

        # update D
        pix2pix_model.set_requires_grad(netD,True)
        optimizer_D.zero_grad()
        # Fake; detached so that this backward pass does not reach the generator
        pred_fake = netD(fake_AB.detach())
        loss_D_fake = criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((real_A, target), 1)
        pred_real = netD(real_AB)
        loss_D_real = criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        loss_sum[2] += loss_D_fake.item()
        loss_sum[3] += loss_D_real.item()
        # update D's weights
        loss_D.backward()
        optimizer_D.step()

        # update G
        pix2pix_model.set_requires_grad(netD,False)
        optimizer_G.zero_grad()
        # First, G(A) should fake the discriminator
        pred_fake = netD(fake_AB)
        loss_G_GAN = criterionGAN(pred_fake, True)*opt.lambda_gan
        # Second, G(A) = B
        loss_G_L1 = _pixel_loss(pred, target)
        # combine loss and calculate gradients
        loss_G = loss_G_GAN + loss_G_L1
        loss_sum[0] += loss_G_L1.item()
        loss_sum[1] += loss_G_GAN.item()
        # update G's weights
        loss_G.backward()
        optimizer_G.step()

    else:
        pred = netG(inputdata)
        loss_G_L1 = _pixel_loss(pred, target)
        loss_sum[0] += loss_G_L1.item()

        optimizer_G.zero_grad()
        loss_G_L1.backward()
        optimizer_G.step()

    # Preview image every 100 steps; a failure here must not stop training.
    if step%100 == 0:
        try:
            data.showresult(real_A, target, pred,os.path.join(dir_checkpoint,'result_train.jpg'))
        except Exception as e:
            print(e)

    # log and plot every 1000 steps
    if step%1000 == 0:
        time_end = time.time()
        if opt.gan:
            savestr ='iter:{0:d} L1_loss:{1:.4f} G_loss:{2:.4f} D_f:{3:.4f} D_r:{4:.4f} time:{5:.2f}'.format(
                step,loss_sum[0]/1000,loss_sum[1]/1000,loss_sum[2]/1000,loss_sum[3]/1000,(time_end-time_start)/1000)
        else:
            savestr ='iter:{0:d}  L1_loss:{1:.4f}  time:{2:.2f}'.format(step,loss_sum[0]/1000,(time_end-time_start)/1000)
        util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True)

        # The first 9 log points are left out of the plot.
        if step/1000 >= 10:
            loss_plot[0].append(loss_sum[0]/1000)
            if opt.gan:
                loss_plot[1].append(loss_sum[1]/1000)
            item_plot.append(step)
            try:
                plt.plot(item_plot,loss_plot[0])
                if opt.gan:
                    plt.plot(item_plot,loss_plot[1])
                plt.savefig(os.path.join(dir_checkpoint,'loss.jpg'))
                plt.close()
            except Exception as e:
                print("error:",e)
        loss_sum = [0.,0.,0.,0.]
        time_start=time.time()

    # save network
    if step%opt.savefreq == 0:
        last_G_path = os.path.join(dir_checkpoint,'last_G.pth')
        last_D_path = os.path.join(dir_checkpoint,'last_D.pth')
        # The previous 'last' checkpoint is kept under the step number it was
        # saved at. Skip the rename when no such file exists (e.g. --startiter
        # was given without a checkpoint) instead of crashing.
        previous_step = str(step-opt.savefreq)

        if step != opt.savefreq and os.path.isfile(last_G_path):
            os.rename(last_G_path,os.path.join(dir_checkpoint,previous_step+'G.pth'))
        torch.save(netG.cpu().state_dict(),last_G_path)
        if opt.gan:
            if step != opt.savefreq and os.path.isfile(last_D_path):
                os.rename(last_D_path,os.path.join(dir_checkpoint,previous_step+'D.pth'))
            torch.save(netD.cpu().state_dict(),last_D_path)
        # .cpu() moved the networks off the GPU for saving; move them back.
        if opt.use_gpu:
            netG.cuda()
            if opt.gan:
                netD.cuda()
        with open(os.path.join(dir_checkpoint,'iter'),'w+') as f:
            f.write(str(step))
        print('network saved.')

        # test: run the generator on every folder of ./test and save a
        # two-row image (top: centre input frame, bottom: prediction).
        if os.path.isdir('./test'):
            netG.eval()

            test_names = os.listdir('./test')
            test_names.sort()
            result = np.zeros((opt.finesize*2,opt.finesize*len(test_names),3), dtype='uint8')

            # Inference only: no autograd graph is needed here.
            with torch.no_grad():
                for cnt,test_name in enumerate(test_names,0):
                    img_names = os.listdir(os.path.join('./test',test_name,'image'))
                    img_names.sort()
                    test_input = np.zeros((opt.finesize,opt.finesize,3*N+1), dtype='uint8')
                    for i in range(0,N):
                        img = impro.imread(os.path.join('./test',test_name,'image',img_names[i]))
                        img = impro.resize(img,opt.finesize)
                        test_input[:,:,i*3:(i+1)*3] = img

                    mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray')
                    mask = impro.resize(mask,opt.finesize)
                    mask = impro.mask_threshold(mask,15,128)
                    test_input[:,:,-1] = mask
                    result[0:opt.finesize,opt.finesize*cnt:opt.finesize*(cnt+1),:] = test_input[:,:,center*3:(center+1)*3]
                    test_tensor = data.im2tensor(test_input,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
                    test_pred = netG(test_tensor)

                    test_pred = data.tensor2im(test_pred,rgb2bgr = False, is0_1 = False)
                    result[opt.finesize:opt.finesize*2,opt.finesize*cnt:opt.finesize*(cnt+1),:] = test_pred

            cv2.imwrite(os.path.join(dir_checkpoint,str(step)+'_test.jpg'), result)
            netG.train()