train.py 13.4 KB
Newer Older
H
hypox64 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
import os
import numpy as np
import cv2
import random
import torch
import torch.nn as nn
import time

import sys
sys.path.append("..")
sys.path.append("../..")
from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
from cores import Options
15
from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model
H
hypox64 已提交
16 17 18
from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn

H
HypoX64 已提交
19 20 21 22
'''
--------------------------Get options--------------------------
'''

# Options() (from cores) provides the shared base parser; the arguments below
# are specific to this training script. Parsing happens later via opt.getparse().
opt = Options()
opt.parser.add_argument('--gpu_id',type=int,default=0, help='')
# number of frames stacked into one generator input (3*N mosaic channels + 1 mask channel)
opt.parser.add_argument('--N',type=int,default=25, help='')
opt.parser.add_argument('--lr',type=float,default=0.0002, help='')
opt.parser.add_argument('--beta1',type=float,default=0.5, help='')
opt.parser.add_argument('--gan', action='store_true', help='if specified, use gan')
opt.parser.add_argument('--l2', action='store_true', help='if specified, use L2 loss')
opt.parser.add_argument('--hd', action='store_true', help='if specified, use HD model')
# weight of the reconstruction (L1, or L1+L2 with --l2) term
opt.parser.add_argument('--lambda_L1',type=float,default=100, help='')
opt.parser.add_argument('--lambda_gan',type=float,default=1, help='')
# finesize: size fed to the network; loadsize: size images are read at before random transforms
opt.parser.add_argument('--finesize',type=int,default=256, help='')
opt.parser.add_argument('--loadsize',type=int,default=286, help='')
opt.parser.add_argument('--batchsize',type=int,default=1, help='')
# size of the in-memory sample pool refreshed by the background loader
opt.parser.add_argument('--perload_num',type=int,default=16, help='number of images pool')
opt.parser.add_argument('--norm',type=str,default='instance', help='')

opt.parser.add_argument('--dataset',type=str,default='./datasets/face/', help='')
opt.parser.add_argument('--maxiter',type=int,default=10000000, help='')
opt.parser.add_argument('--savefreq',type=int,default=10000, help='')
opt.parser.add_argument('--startiter',type=int,default=0, help='')
opt.parser.add_argument('--continuetrain', action='store_true', help='')
opt.parser.add_argument('--savename',type=str,default='face', help='')

'''
--------------------------Init--------------------------
'''
opt = opt.getparse()
dir_checkpoint = os.path.join('checkpoints/',opt.savename)
util.makedirs(dir_checkpoint)
# start the loss log with a timestamp and the full option set of this run
util.writelog(os.path.join(dir_checkpoint,'loss.txt'),
              str(time.asctime(time.localtime(time.time())))+'\n'+util.opt2str(opt))
torch.cuda.set_device(opt.gpu_id)

N = opt.N
# running sums over the current 1000-iteration logging window:
# [0] reconstruction loss, [1] G adversarial loss, [2] D fake loss, [3] D real loss
loss_sum = [0.,0.,0.,0.]
loss_plot = [[],[]]
item_plot = []

# list video dirs: each sub-directory of the dataset is one video whose frames
# live in '<video>/origin_image' (masks in '<video>/mask', read by loaddata)
print('Check dataset...')
videonames = []
lengths = []
for video in sorted(os.listdir(opt.dataset)):
    # 'opt.txt' (and any other plain file) is metadata, not a video
    if video == 'opt.txt' or not os.path.isdir(os.path.join(opt.dataset,video)):
        continue
    videonames.append(video)
    lengths.append(len(os.listdir(os.path.join(opt.dataset,video,'origin_image'))))
video_num = len(videonames)
if video_num == 0:
    # without this the preload thread would fail forever and training would never start
    raise RuntimeError('no video directories found in dataset: '+opt.dataset)
#def network
print('Init network...')
# generator input: N mosaiced frames (3 channels each) + 1 mask channel -> 3-channel output
if opt.hd:
    netG = videoHD_model.MosaicNet(3*N+1, 3, norm=opt.norm)
else:
    netG = video_model.MosaicNet(3*N+1, 3, norm=opt.norm)
loadmodel.show_paramsnumber(netG,'netG')

if opt.gan:
    # discriminator input: centre mosaic frame concatenated with a real/fake output (6 channels)
    if opt.hd:
        #netD = pix2pixHD_model.define_D(6, 64, 3, norm = opt.norm, use_sigmoid=False, num_D=1)
        netD = pix2pixHD_model.define_D(6, 64, 3, norm = opt.norm, use_sigmoid=False, num_D=2,getIntermFeat=True)
    else:
        netD = pix2pix_model.define_D(3*2, 64, 'basic', norm = opt.norm)
    netD.train()

# resume from the last checkpoint when requested; fall back to fresh weights if none exists
if opt.continuetrain and not os.path.isfile(os.path.join(dir_checkpoint,'last_G.pth')):
    opt.continuetrain = False
    print('can not load last_G, training on init weight.')
if opt.continuetrain:
    netG.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_G.pth')))
    if opt.gan:
        netD.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_D.pth')))
    # the 'iter' file holds the number of iterations completed at the last save
    with open(os.path.join(dir_checkpoint,'iter'),'r') as f:
        opt.startiter = int(f.read())

optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999))
criterion_L1 = nn.L1Loss()
criterion_L2 = nn.MSELoss()
if opt.gan:
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999))
    # the HD discriminator uses the pix2pixHD loss; otherwise least-squares GAN loss
    if opt.hd:
        criterionGAN = pix2pixHD_model.GANLoss(tensor=torch.cuda.FloatTensor)
    else:
        criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda()

# move networks/losses to the GPU selected earlier with torch.cuda.set_device
if opt.use_gpu:
    netG.cuda()
    if opt.gan:
        netD.cuda()
        criterionGAN.cuda()
    cudnn.benchmark = True

'''
--------------------------preload data & data pool--------------------------
'''
def loaddata(video_index):
    """Build one random training sample from video number `video_index`.

    A centre frame is picked at random. The N frames around it are read and
    mosaiced with one shared, randomly drawn set of mosaic parameters. They are
    stacked along the channel axis, and the centre frame's mask is appended as
    the last channel. The un-mosaiced centre frame is the target.

    Returns:
        (input_img, ground_true) as converted by data.im2tensor after
        data.random_transform_video has cropped/augmented them to opt.finesize.
    """
    videoname = videonames[video_index]
    half = N // 2
    # bounds keep img_index-half >= 1 and img_index+half <= length-1;
    # NOTE(review): presumably frame files are numbered from 1 -- confirm
    img_index = random.randint(half+1, lengths[video_index]-half-1)

    def _frame_path(folder, index, ext):
        # '<dataset>/<video>/<folder>/%05d<ext>'
        return os.path.join(opt.dataset,videoname,folder,'%05d'%(index)+ext)

    input_img = np.zeros((opt.loadsize,opt.loadsize,3*N+1), dtype='uint8')
    # this frame: its mask goes into the last channel, the clean image is the target
    this_mask = impro.imread(_frame_path('mask',img_index,'.png'),'gray',loadsize=opt.loadsize)
    input_img[:,:,-1] = this_mask
    ground_true = impro.imread(_frame_path('origin_image',img_index,'.jpg'),loadsize=opt.loadsize)
    # one set of mosaic parameters is reused for every frame of the window
    mosaic_size,mod,rect_rat,father = mosaic.get_random_parameter(ground_true,this_mask)
    # merge other frames: frame i of the window occupies channels [3*i, 3*i+3)
    for i in range(N):
        frame_index = img_index+i-half
        img = impro.imread(_frame_path('origin_image',frame_index,'.jpg'),loadsize=opt.loadsize)
        mask = impro.imread(_frame_path('mask',frame_index,'.png'),'gray',loadsize=opt.loadsize)
        img_mosaic = mosaic.addmosaic_base(img, mask, mosaic_size,model = mod,rect_rat=rect_rat,father=father)
        input_img[:,:,i*3:(i+1)*3] = img_mosaic
    # to tensor
    input_img,ground_true = data.random_transform_video(input_img,ground_true,opt.finesize,N)
    input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1=False)
    ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1=False)
    return input_img,ground_true

print('Preloading data, please wait...')

# the pool must be larger than one batch, because batches are sliced out of it
if opt.perload_num <= opt.batchsize:
    opt.perload_num = opt.batchsize*2
#data pool, written by the background preload thread and read by the training loop.
# Place it on the same device that loaddata (im2tensor use_gpu) and netG use.
pool_device = 'cuda' if opt.use_gpu else 'cpu'
input_imgs = torch.rand(opt.perload_num,N*3+1,opt.finesize,opt.finesize).to(pool_device)
ground_trues = torch.rand(opt.perload_num,3,opt.finesize,opt.finesize).to(pool_device)
load_cnt = 0  # number of samples successfully written into the pool so far

def preload():
    """Background loader: keep overwriting pool slots with fresh samples.

    The first opt.perload_num successful loads fill slots 0..perload_num-1 in
    order. Once load_cnt reaches opt.perload_num, no slot still holds its
    initial placeholder. After that, a randomly chosen slot is replaced each
    time. Errors raised while loading are printed and the loop continues.

    NOTE: the training loop reads the pool while this thread writes to it;
    there is no locking.
    """
    global load_cnt
    while True:
        try:
            video_index = random.randint(0,video_num-1)
            if load_cnt < opt.perload_num:
                # initial pass: fill every slot exactly once
                slot = load_cnt
            else:
                slot = random.randint(0, opt.perload_num-1)
            input_imgs[slot],ground_trues[slot] = loaddata(video_index)
            load_cnt += 1
            # time.sleep(0.1)
        except Exception as e:
            print("error:",e)
import threading
t = threading.Thread(target=preload,args=())
t.daemon = True
t.start()
# wait until the pool has been filled once before training starts
time_start=time.time()
while load_cnt < opt.perload_num:
    time.sleep(0.1)
time_end=time.time()
print('load speed:',round((time_end-time_start)/opt.perload_num,3),'s/it')

'''
--------------------------train--------------------------
'''
# snapshot this script and the generator definition actually in use next to the checkpoints
util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py'))
model_file = 'videoHD_model.py' if opt.hd else 'video_model.py'
util.copyfile('../../models/'+model_file, os.path.join(dir_checkpoint,'model.py'))
netG.train()
time_start=time.time()
print("Begin training...")
# channel slice of the centre (mosaiced) frame inside the stacked 3*N+1-channel input
center_channels = slice(((N-1)//2)*3, ((N-1)//2+1)*3)


def _reconstruction_loss(pred, target):
    """L1 distance (plus L2 when --l2 is set) to the target, scaled by lambda_L1."""
    loss = criterion_L1(pred, target)
    if opt.l2:
        loss = loss + criterion_L2(pred, target)
    return loss * opt.lambda_L1


def _save_net(net, tag, done_iters):
    """Save net as last_<tag>.pth, first archiving the previous one.

    The previous last_<tag>.pth is renamed to '<done_iters-savefreq><tag>.pth',
    except on the very first save of a run. A missing file is skipped.
    Moves net to the CPU; the caller moves it back.
    """
    last_path = os.path.join(dir_checkpoint,'last_'+tag+'.pth')
    if done_iters != opt.savefreq and os.path.isfile(last_path):
        # os.replace also overwrites an existing destination on Windows
        os.replace(last_path, os.path.join(dir_checkpoint,str(done_iters-opt.savefreq)+tag+'.pth'))
    torch.save(net.cpu().state_dict(), last_path)


for step in range(opt.startiter+1,opt.maxiter):
    # sample a contiguous batch from the pool; randint is inclusive, so the
    # largest start that still fits a full batch is perload_num-batchsize
    ran = random.randint(0, opt.perload_num-opt.batchsize)
    inputdata = input_imgs[ran:ran+opt.batchsize].clone()
    target = ground_trues[ran:ran+opt.batchsize].clone()

    if opt.gan:
        # compute fake images: G(A)
        pred = netG(inputdata)
        real_A = inputdata[:,center_channels,:,:]
        fake_AB = torch.cat((real_A, pred), 1)

        # update D
        pix2pix_model.set_requires_grad(netD,True)
        optimizer_D.zero_grad()
        # Fake: detached so D's loss does not backpropagate into G
        pred_fake = netD(fake_AB.detach())
        loss_D_fake = criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((real_A, target), 1)
        pred_real = netD(real_AB)
        loss_D_real = criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        loss_sum[2] += loss_D_fake.item()
        loss_sum[3] += loss_D_real.item()
        loss_D.backward()
        optimizer_D.step()

        # update G with D frozen
        pix2pix_model.set_requires_grad(netD,False)
        optimizer_G.zero_grad()
        # First, G(A) should fake the discriminator
        pred_fake = netD(fake_AB)
        loss_G_GAN = criterionGAN(pred_fake, True)*opt.lambda_gan
        # Second, G(A) = B
        loss_G_L1 = _reconstruction_loss(pred, target)
        loss_G = loss_G_GAN + loss_G_L1
        loss_sum[0] += loss_G_L1.item()
        loss_sum[1] += loss_G_GAN.item()
        loss_G.backward()
        optimizer_G.step()

    else:
        pred = netG(inputdata)
        loss_G_L1 = _reconstruction_loss(pred, target)
        loss_sum[0] += loss_G_L1.item()

        optimizer_G.zero_grad()
        loss_G_L1.backward()
        optimizer_G.step()

    if (step+1)%100 == 0:
        try:
            data.showresult(inputdata[:,center_channels,:,:],
             target, pred,os.path.join(dir_checkpoint,'result_train.jpg'))
        except Exception as e:
            print(e)

    # log averages over the last 1000 iterations and plot
    if (step+1)%1000 == 0:
        time_end = time.time()
        if opt.gan:
            savestr ='iter:{0:d} L1_loss:{1:.4f} G_loss:{2:.4f} D_f:{3:.4f} D_r:{4:.4f} time:{5:.2f}'.format(
                step+1,loss_sum[0]/1000,loss_sum[1]/1000,loss_sum[2]/1000,loss_sum[3]/1000,(time_end-time_start)/1000)
        else:
            savestr ='iter:{0:d}  L1_loss:{1:.4f}  time:{2:.2f}'.format(step+1,loss_sum[0]/1000,(time_end-time_start)/1000)
        util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True)
        # only plot from iteration 10000 on (early losses dominate the scale)
        if (step+1)/1000 >= 10:
            loss_plot[0].append(loss_sum[0]/1000)
            if opt.gan:
                loss_plot[1].append(loss_sum[1]/1000)
            item_plot.append(step+1)
            try:
                plt.plot(item_plot,loss_plot[0])
                if opt.gan:
                    plt.plot(item_plot,loss_plot[1])
                plt.savefig(os.path.join(dir_checkpoint,'loss.jpg'))
            except Exception as e:
                print("error:",e)
            finally:
                # always release the figure, even if plotting failed
                plt.close()
        loss_sum = [0.,0.,0.,0.]
        time_start=time.time()

    # save network
    if (step+1)%opt.savefreq == 0:
        _save_net(netG, 'G', step+1)
        if opt.gan:
            _save_net(netD, 'D', step+1)
        if opt.use_gpu:
            netG.cuda()
            if opt.gan:
                netD.cuda()
        with open(os.path.join(dir_checkpoint,'iter'),'w+') as f:
            f.write(str(step+1))
        print('network saved.')

        #test: each './test/<name>' holds an 'image' folder (first N images used) and 'mask.png';
        # the output grid has the centre input frame on top and the prediction below
        if os.path.isdir('./test'):
            netG.eval()

            test_names = sorted(os.listdir('./test'))
            result = np.zeros((opt.finesize*2,opt.finesize*len(test_names),3), dtype='uint8')

            for cnt,test_name in enumerate(test_names):
                img_names = sorted(os.listdir(os.path.join('./test',test_name,'image')))
                test_input = np.zeros((opt.finesize,opt.finesize,3*N+1), dtype='uint8')
                for i in range(N):
                    img = impro.imread(os.path.join('./test',test_name,'image',img_names[i]))
                    test_input[:,:,i*3:(i+1)*3] = impro.resize(img,opt.finesize)

                mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray')
                mask = impro.resize(mask,opt.finesize)
                test_input[:,:,-1] = impro.mask_threshold(mask,15,128)
                cols = slice(opt.finesize*cnt, opt.finesize*(cnt+1))
                result[0:opt.finesize,cols,:] = test_input[:,:,center_channels]
                test_input = data.im2tensor(test_input,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
                # inference only: no autograd graph needed
                with torch.no_grad():
                    test_pred = netG(test_input)
                result[opt.finesize:opt.finesize*2,cols,:] = data.tensor2im(test_pred,rgb2bgr = False, is0_1 = False)

            cv2.imwrite(os.path.join(dir_checkpoint,str(step+1)+'_test.jpg'), result)
            netG.train()