train.py 15.3 KB
Newer Older
H
hypox64 已提交
1
import os
import random
import shutil
import sys
import time

sys.path.append("..")
sys.path.append("../..")
from cores import Options
opt = Options()

import cv2
import numpy as np
import torch
import torch.nn as nn

from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model
from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn

# --------------------------Get options--------------------------
# Training-specific command-line flags, registered on the shared Options
# parser in this order. Each entry is (flag, keyword arguments forwarded
# unchanged to parser.add_argument).
for _flag, _kwargs in (
    ('--N', dict(type=int, default=25, help='')),
    ('--lr', dict(type=float, default=0.0002, help='')),
    ('--beta1', dict(type=float, default=0.5, help='')),
    ('--gan', dict(action='store_true', help='if specified, use gan')),
    ('--l2', dict(action='store_true', help='if specified, use L2 loss')),
    ('--hd', dict(action='store_true', help='if specified, use HD model')),
    ('--lambda_L1', dict(type=float, default=100, help='')),
    ('--lambda_gan', dict(type=float, default=1, help='')),
    ('--finesize', dict(type=int, default=256, help='')),
    ('--loadsize', dict(type=int, default=286, help='')),
    ('--batchsize', dict(type=int, default=1, help='')),
    ('--perload_num', dict(type=int, default=64, help='number of images pool')),
    ('--norm', dict(type=str, default='instance', help='')),
    ('--num_D', dict(type=int, default=2, help='number of discriminators to use')),
    ('--n_layers_D', dict(type=int, default=3, help='only used if which_model_netD==n_layers')),
    ('--lambda_feat', dict(type=float, default=10.0, help='weight for feature matching loss')),
    ('--dataset', dict(type=str, default='./datasets/face/', help='')),
    ('--maxiter', dict(type=int, default=10000000, help='')),
    ('--savefreq', dict(type=int, default=10000, help='')),
    ('--startiter', dict(type=int, default=0, help='')),
    ('--continuetrain', dict(action='store_true', help='')),
    ('--savename', dict(type=str, default='face', help='')),
):
    opt.parser.add_argument(_flag, **_kwargs)
del _flag, _kwargs

# --------------------------Init--------------------------
opt = opt.getparse()
dir_checkpoint = os.path.join('checkpoints/',opt.savename)
util.makedirs(dir_checkpoint)
util.writelog(os.path.join(dir_checkpoint,'loss.txt'),
              str(time.asctime(time.localtime(time.time())))+'\n'+util.opt2str(opt))
# Let cuDNN auto-tune convolution algorithms.
cudnn.benchmark = True

N = opt.N
# Running loss sums over the current 1000-iteration logging window:
# [0] G L1(+L2), [1] G GAN, [2] G GAN feature matching, [3] VGG,
# [4] D fake, [5] D real.
loss_sum = [0.,0.,0.,0.,0.,0.]
loss_plot = [[],[],[],[]]
item_plot = []

# --- list video dir ---
# Sample centre frames are drawn with
#   random.randint(int(N/2)+1, length-int(N/2)-1),
# which is an empty range (ValueError on every draw) unless the video has at
# least min_frames frames. Such videos can never produce a sample, so they are
# skipped here instead of making the loader spin on errors forever.
min_frames = 2*int(N/2)+2
videonames = []
lengths = []
print('Check dataset...')
for video in sorted(os.listdir(opt.dataset)):
    if video == 'opt.txt' or not os.path.isdir(os.path.join(opt.dataset,video)):
        continue
    length = len(os.listdir(os.path.join(opt.dataset,video,'origin_image')))
    if length < min_frames:
        print('skip %s: %d frames, need at least %d' % (video, length, min_frames))
        continue
    videonames.append(video)
    lengths.append(length)
video_num = len(videonames)
if video_num == 0:
    raise RuntimeError('no usable video in %s (each needs at least %d frames)'
                       % (opt.dataset, min_frames))

# --------------------------Init network--------------------------
print('Init network...')
# Generator maps 3*N+1 input channels (N stacked RGB frames + a mask channel,
# as assembled for the ./test clips) to one 3-channel frame.
generator_module = videoHD_model if opt.hd else video_model
netG = generator_module.MosaicNet(3*N+1, 3, norm=opt.norm)
netG.cuda()
loadmodel.show_paramsnumber(netG,'netG')

if opt.gan:
    # The discriminator sees the centre input frame concatenated with a
    # real or generated output frame: 6 channels in both variants.
    if not opt.hd:
        netD = pix2pix_model.define_D(3*2, 64, 'basic', norm = opt.norm)
    else:
        netD = pix2pixHD_model.define_D(6, 64, opt.n_layers_D, norm = opt.norm,
                                        use_sigmoid=False, num_D=opt.num_D,
                                        getIntermFeat=True)
    netD.cuda()

# --------------------------continue train--------------------------
# Resume from <dir_checkpoint>/last_G.pth (plus last_D.pth with --gan); the
# iteration to resume from is stored as plain text in <dir_checkpoint>/iter.
if opt.continuetrain and not os.path.isfile(os.path.join(dir_checkpoint,'last_G.pth')):
    opt.continuetrain = False
    print('can not load last_G, training on init weight.')
if opt.continuetrain:
    netG.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_G.pth')))
    if opt.gan:
        netD.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_D.pth')))
    # Context manager so the handle is closed even if parsing fails.
    with open(os.path.join(dir_checkpoint,'iter'),'r') as f:
        opt.startiter = int(f.read())

# --------------------------optimizer & loss--------------------------
# Both optimizers share the same Adam hyper-parameters.
adam_betas = (opt.beta1, 0.999)
optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=adam_betas)
criterion_L1 = nn.L1Loss()
criterion_L2 = nn.MSELoss()
if opt.gan:
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=adam_betas)
    if not opt.hd:
        criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda()
    else:
        # HD mode additionally uses discriminator feature matching and a VGG loss.
        criterionGAN = pix2pixHD_model.GANLoss(tensor=torch.cuda.FloatTensor).cuda()
        criterionFeat = pix2pixHD_model.GAN_Feat_loss(opt)
        criterionVGG = pix2pixHD_model.VGGLoss([opt.use_gpu])

# --------------------------preload data & data pool--------------------------
# Training samples come from data.load_train_video, called by the background
# preload() loop, and are written round-robin into a fixed-size pool.
print('Preloading data, please wait...')

# The pool must hold more samples than a single batch.
if opt.perload_num <= opt.batchsize:
    opt.perload_num = opt.batchsize*2

# data pool: inputs with N*3+1 channels (presumably N RGB frames + 1 mask
# channel, matching the ./test input layout) and 3-channel targets, all
# finesize x finesize. The random values are placeholders that preload()
# overwrites before training starts.
pool_hw = (opt.finesize, opt.finesize)
input_imgs = torch.rand(opt.perload_num, N*3+1, *pool_hw)
ground_trues = torch.rand(opt.perload_num, 3, *pool_hw)
load_cnt = 0

def preload():
    """Keep refilling the data pool; intended to run forever in a thread.

    Every pass picks a random video and a random centre frame, builds one
    sample with data.load_train_video and stores it in pool slot
    load_cnt % opt.perload_num of input_imgs / ground_trues, then bumps the
    global load_cnt. Any exception is printed and the loop simply continues.
    """
    global load_cnt
    half = int(N/2)
    while True:
        try:
            pick = random.randint(0,video_num-1)
            name = videonames[pick]
            centre_frame = random.randint(half+1, lengths[pick]-half-1)
            new_input, new_truth = data.load_train_video(name, centre_frame, opt)
            slot = load_cnt % opt.perload_num
            input_imgs[slot] = new_input
            ground_trues[slot] = new_truth
            load_cnt += 1
        except Exception as e:
            print("error:",e)
import threading
# Daemon thread, so the endless preload() loop never keeps the process alive.
t = threading.Thread(target=preload, args=())
t.daemon = True
t.start()

# Wait until the pool has been filled once, then log the mean load time per sample.
time_start = time.time()
while load_cnt < opt.perload_num:
    time.sleep(0.1)
time_end = time.time()
seconds_per_sample = round((time_end-time_start)/(opt.perload_num),3)
util.writelog(os.path.join(dir_checkpoint,'loss.txt'),
              'load speed: '+str(seconds_per_sample)+' s/it',True)

# --------------------------train--------------------------
# Archive this script and the generator definition actually in use next to
# the checkpoints.
util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py'))
model_file = 'videoHD_model.py' if opt.hd else 'video_model.py'
util.copyfile(os.path.join('../../models', model_file), os.path.join(dir_checkpoint,'model.py'))
netG.train()
# netD only exists when training with --gan.
if opt.gan:
    netD.train()
time_start=time.time()
print("Begin training...")
for iter in range(opt.startiter+1,opt.maxiter):
    # Take a contiguous batch out of the preloaded pool; clone so the loader
    # thread can overwrite the pool slots while this batch is in use.
    first = random.randint(0, opt.perload_num-opt.batchsize)
    batch = slice(first, first+opt.batchsize)
    inputdata = input_imgs[batch].clone().cuda()
    target = ground_trues[batch].clone().cuda()

    if opt.gan:
        # compute fake images: G(A)
        pred = netG(inputdata)
        # Centre frame of the input stack conditions the discriminator.
        centre = int((N-1)/2)
        real_A = inputdata[:,centre*3:(centre+1)*3,:,:]

        # --------------------update D--------------------
        pix2pix_model.set_requires_grad(netD,True)
        optimizer_D.zero_grad()
        # Fake pair; detached so the D loss does not backpropagate into G.
        fake_AB = torch.cat((real_A, pred), 1)
        pred_fake = netD(fake_AB.detach())
        loss_D_fake = criterionGAN(pred_fake, False)
        # Real pair
        real_AB = torch.cat((real_A, target), 1)
        pred_real = netD(real_AB)
        loss_D_real = criterionGAN(pred_real, True)
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        loss_sum[4] += loss_D_fake.item()
        loss_sum[5] += loss_D_real.item()
        loss_D.backward()
        optimizer_D.step()

        # --------------------update G--------------------
        # Turn off D parameter gradients for the generator step.
        pix2pix_model.set_requires_grad(netD,False)
        optimizer_G.zero_grad()
        # G(A) should fool the freshly updated discriminator.
        fake_AB = torch.cat((real_A, pred), 1)
        pred_fake = netD(fake_AB)
        loss_G_GAN = criterionGAN(pred_fake, True)*opt.lambda_gan
        if opt.l2:
            loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * opt.lambda_L1
        else:
            loss_G_L1 = criterion_L1(pred, target) * opt.lambda_L1
        loss_G = loss_G_GAN + loss_G_L1
        loss_sum[0] += loss_G_L1.item()
        loss_sum[1] += loss_G_GAN.item()

        # Feature-matching and VGG terms exist only in HD mode; their running
        # sums are updated only here, so loss_sum[2]/[3] stay 0 otherwise.
        if opt.hd:
            real_AB = torch.cat((real_A, target), 1)
            pred_real = netD(real_AB)
            loss_G_GAN_Feat = criterionFeat(pred_fake,pred_real)
            loss_VGG = criterionVGG(pred, target) * opt.lambda_feat
            loss_G = loss_G + loss_G_GAN_Feat + loss_VGG
            loss_sum[2] += loss_G_GAN_Feat.item()
            loss_sum[3] += loss_VGG.item()

        loss_G.backward()
        optimizer_G.step()

    else:
        # Pure reconstruction training: L1 (optionally + L2) loss only.
        pred = netG(inputdata)
        if opt.l2:
            loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * opt.lambda_L1
        else:
            loss_G_L1 = criterion_L1(pred, target) * opt.lambda_L1
        loss_sum[0] += loss_G_L1.item()

        optimizer_G.zero_grad()
        loss_G_L1.backward()
        optimizer_G.step()

    # save eval result: every 1000 iterations, load one fresh sample, run G on
    # it without gradients, and pass centre input frame / target / prediction
    # to data.showresult, which is given result_eval.jpg as the output path.
    if (iter+1)%1000 == 0:
        eval_video = random.randint(0,video_num-1)
        eval_name = videonames[eval_video]
        eval_frame = random.randint(int(N/2)+1,lengths[eval_video]- int(N/2)-1)
        inputdata,target = data.load_train_video(eval_name, eval_frame, opt)
        inputdata,target = inputdata.cuda(),target.cuda()
        with torch.no_grad():
            pred = netG(inputdata)
        centre = int((N-1)/2)
        try:
            data.showresult(inputdata[:,centre*3:(centre+1)*3,:,:],
                target, pred, os.path.join(dir_checkpoint,'result_eval.jpg'))
        except Exception as e:
            print(e)

    # plot: every 1000 iterations write the average losses and seconds per
    # iteration to loss.txt, update the loss curve (GAN mode, after 10k
    # iterations), then reset the running sums and timer.
    if (iter+1)%1000 == 0:
        time_end = time.time()
        seconds_per_iter = (time_end-time_start)/1000
        if opt.gan:
            savestr ='iter:{0:d} L1_loss:{1:.3f} GAN_loss:{2:.3f} Feat:{3:.3f} VGG:{4:.3f} time:{5:.2f}'.format(
                iter+1,loss_sum[0]/1000,loss_sum[1]/1000,loss_sum[2]/1000,loss_sum[3]/1000,seconds_per_iter)
            util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True)
            if (iter+1)/1000 >= 10:
                for i in range(4):loss_plot[i].append(loss_sum[i]/1000)
                item_plot.append(iter+1)
                try:
                    labels = ['L1_loss','GAN_loss','GAN_Feat_loss','VGG_loss']
                    for i in range(4):plt.plot(item_plot,loss_plot[i],label=labels[i])
                    plt.xlabel('iter')
                    plt.legend(loc=1)
                    plt.savefig(os.path.join(dir_checkpoint,'loss.jpg'))
                except Exception as e:
                    print("error:",e)
                finally:
                    # Always discard the figure so a failed save does not leave
                    # stale lines that get drawn again next time.
                    plt.close()
        else:
            # Without --gan only loss_sum[0] (L1, optionally + L2) is accumulated.
            savestr = 'iter:{0:d} L1_loss:{1:.3f} time:{2:.2f}'.format(
                iter+1,loss_sum[0]/1000,seconds_per_iter)
            util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True)

        loss_sum = [0.,0.,0.,0.,0.,0.]
        time_start=time.time()

    # save network: overwrite last_G.pth (and last_D.pth with --gan) every
    # savefreq//10 iterations and at every savefreq snapshot, and record the
    # iteration in the 'iter' file so --continuetrain can resume from it.
    quick_freq = max(1, opt.savefreq//10)  # guard against savefreq < 10
    if (iter+1)%quick_freq == 0 or (iter+1)%opt.savefreq == 0:
        # state_dict() is taken after .cpu() so the checkpoint holds CPU
        # tensors. This script trains on CUDA unconditionally, so the networks
        # always go back to the GPU afterwards.
        torch.save(netG.cpu().state_dict(),os.path.join(dir_checkpoint,'last_G.pth'))
        netG.cuda()
        if opt.gan:
            torch.save(netD.cpu().state_dict(),os.path.join(dir_checkpoint,'last_D.pth'))
            netD.cuda()
        with open(os.path.join(dir_checkpoint,'iter'),'w+') as f:
            f.write(str(iter+1))

    if (iter+1)%opt.savefreq == 0:
        # Copy (not rename) into the numbered snapshot, so last_*.pth keeps
        # existing and --continuetrain can still find it if training stops
        # before the next quick save.
        shutil.copyfile(os.path.join(dir_checkpoint,'last_G.pth'),
                        os.path.join(dir_checkpoint,str(iter+1)+'G.pth'))
        if opt.gan:
            shutil.copyfile(os.path.join(dir_checkpoint,'last_D.pth'),
                            os.path.join(dir_checkpoint,str(iter+1)+'D.pth'))
        print('network saved.')

    #test: every savefreq iterations, run G on each clip ./test/<name>/ (first
    # N frames from image/, mask from mask.png) and save one strip image:
    # centre input frames in the top row, predictions in the bottom row.
    if (iter+1)%opt.savefreq == 0:
        if os.path.isdir('./test'):
            netG.eval()

            test_names = sorted(os.listdir('./test'))
            result = np.zeros((opt.finesize*2,opt.finesize*len(test_names),3), dtype='uint8')
            centre = int((N-1)/2)

            for cnt,test_name in enumerate(test_names):
                img_names = sorted(os.listdir(os.path.join('./test',test_name,'image')))
                inputdata = np.zeros((opt.finesize,opt.finesize,3*N+1), dtype='uint8')
                for i in range(N):
                    img = impro.imread(os.path.join('./test',test_name,'image',img_names[i]))
                    img = impro.resize(img,opt.finesize)
                    inputdata[:,:,i*3:(i+1)*3] = img

                mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray')
                mask = impro.resize(mask,opt.finesize)
                mask = impro.mask_threshold(mask,15,128)
                inputdata[:,:,-1] = mask
                col = slice(opt.finesize*cnt, opt.finesize*(cnt+1))
                result[0:opt.finesize,col,:] = inputdata[:,:,centre*3:(centre+1)*3]
                inputdata = data.im2tensor(inputdata,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
                # Inference only: do not build an autograd graph (same as the eval step).
                with torch.no_grad():
                    pred = netG(inputdata)

                pred = data.tensor2im(pred,rgb2bgr = False, is0_1 = False)
                result[opt.finesize:opt.finesize*2,col,:] = pred

            cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.jpg'), result)
            netG.train()