# Training script for the video MosaicNet generator (models.video_model /
# models.videoHD_model), optionally trained adversarially against a pix2pix or
# pix2pixHD discriminator (--gan / --hd).
import os
import sys
# The project packages (cores, util, models) are resolved from parent
# directories. These paths are relative to the current working directory, so the
# script presumably has to be launched from its own folder (it later copies
# './train.py' and '../../models/...' with the same assumption).
sys.path.append("..")
sys.path.append("../..")
from cores import Options
opt = Options()

import numpy as np
import cv2
import random
import torch
import torch.nn as nn
import time
from multiprocessing import Process, Queue

from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model
from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn

'''
--------------------------Get options--------------------------
'''
# Number of consecutive frames stacked into one generator input
# (the generator takes 3*N image channels + 1 mask channel).
opt.parser.add_argument('--N',type=int,default=25, help='')
# Adam learning rate / beta1, shared by the G and D optimizers.
opt.parser.add_argument('--lr',type=float,default=0.0002, help='')
opt.parser.add_argument('--beta1',type=float,default=0.5, help='')
opt.parser.add_argument('--gan', action='store_true', help='if specified, use gan')
opt.parser.add_argument('--l2', action='store_true', help='if specified, use L2 loss')
opt.parser.add_argument('--hd', action='store_true', help='if specified, use HD model')
# Weight of the reconstruction term (L1, or L1+L2 with --l2).
opt.parser.add_argument('--lambda_L1',type=float,default=100, help='')
# Weight of the generator's adversarial term.
opt.parser.add_argument('--lambda_gan',type=float,default=1, help='')
# Spatial size of the training / test tensors.
opt.parser.add_argument('--finesize',type=int,default=256, help='')
# Not read directly in this script; presumably consumed by data.load_train_video (TODO confirm).
opt.parser.add_argument('--loadsize',type=int,default=286, help='')
opt.parser.add_argument('--batchsize',type=int,default=1, help='')
opt.parser.add_argument('--norm',type=str,default='instance', help='')
# pix2pixHD discriminator settings (only used with --gan --hd).
opt.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
opt.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
# NOTE(review): in this script the value scales the VGG loss (criterionVGG),
# not the GAN feature-matching loss the help text mentions.
opt.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
# Capacity of the batch queue and number of loader processes filling it.
opt.parser.add_argument('--image_pool',type=int,default=8, help='number of image load pool')
opt.parser.add_argument('--load_process',type=int,default=4, help='number of process for loading data')

opt.parser.add_argument('--dataset',type=str,default='./datasets/face/', help='')
# Training stops before this iteration count; a numbered snapshot is written every
# savefreq iterations and a rolling 'last' checkpoint every savefreq//10.
opt.parser.add_argument('--maxiter',type=int,default=10000000, help='')
opt.parser.add_argument('--savefreq',type=int,default=10000, help='')
# Overwritten from the checkpoint's 'iter' file when --continue_train succeeds.
opt.parser.add_argument('--startiter',type=int,default=0, help='')
opt.parser.add_argument('--continue_train', action='store_true', help='')
# Checkpoints go to checkpoints/<savename>.
opt.parser.add_argument('--savename',type=str,default='face', help='')

'''
--------------------------Init--------------------------
'''
opt = opt.getparse()
dir_checkpoint = os.path.join('checkpoints/',opt.savename)
util.makedirs(dir_checkpoint)
util.writelog(os.path.join(dir_checkpoint,'loss.txt'),
              str(time.asctime(time.localtime(time.time())))+'\n'+util.opt2str(opt))
cudnn.benchmark = True

N = opt.N
# Running loss sums, averaged and reset every 1000 iterations:
# [0] L1/L2 reconstruction, [1] G adversarial, [2] GAN feature matching,
# [3] VGG, [4] D on fakes, [5] D on reals.
loss_sum = [0.,0.,0.,0.,0.,0.]
loss_plot = [[],[],[],[]]
item_plot = []

# List the clip directories of the dataset; each clip keeps its frames in
# <clip>/origin_image.
videonames = os.listdir(opt.dataset)
videonames.sort()
# preload() draws the centre frame with
# random.randint(int(N/2)+1, length-int(N/2)-1), which needs
# length >= 2*int(N/2)+2. Shorter clips can never produce a sample, so drop
# them here instead of letting every loader process fail on them repeatedly.
min_frames = 2*int(N/2)+2
lengths = []
usable_videos = []
print('Check dataset...')
for video in videonames:
    if video == 'opt.txt':
        continue
    video_images = os.listdir(os.path.join(opt.dataset,video,'origin_image'))
    if len(video_images) < min_frames:
        print('Skip {0}: {1:d} frames, need at least {2:d}.'.format(video,len(video_images),min_frames))
        continue
    lengths.append(len(video_images))
    usable_videos.append(video)
videonames = usable_videos
video_num = len(videonames)
if video_num == 0:
    # Without any usable clip the loader processes could never produce a batch
    # and the training loop would block forever on the queue.
    raise SystemExit('No usable clips found in {0}'.format(opt.dataset))

#--------------------------Init network--------------------------
print('Init network...')
# Generator input: N stacked RGB frames (3*N channels) plus one mask channel;
# output: a single RGB frame.
if opt.hd:
    netG = videoHD_model.MosaicNet(3*N+1, 3, norm=opt.norm)
else:
    netG = video_model.MosaicNet(3*N+1, 3, norm=opt.norm)
netG.cuda()
loadmodel.show_paramsnumber(netG,'netG')

if opt.gan:
    # Conditional discriminator: it sees the centre input frame concatenated
    # with either the target or the generated frame (3 + 3 = 6 channels).
    if opt.hd:
        netD = pix2pixHD_model.define_D(6, 64, opt.n_layers_D, norm = opt.norm, use_sigmoid=False, num_D=opt.num_D,getIntermFeat=True)
    else:
        netD = pix2pix_model.define_D(3*2, 64, 'basic', norm = opt.norm)
    netD.cuda()
    netD.train()

#--------------------------continue train--------------------------
# Resume from the rolling checkpoint written by the training loop
# (last_G.pth, last_D.pth and the 'iter' file holding the completed iteration count).
if opt.continue_train:
    if not os.path.isfile(os.path.join(dir_checkpoint,'last_G.pth')):
        opt.continue_train = False
        print('can not load last_G, training on init weight.')
if opt.continue_train:
    netG.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_G.pth')))
    if opt.gan:
        path_D = os.path.join(dir_checkpoint,'last_D.pth')
        # A run previously trained without --gan has no D checkpoint; start D fresh
        # instead of crashing.
        if os.path.isfile(path_D):
            netD.load_state_dict(torch.load(path_D))
        else:
            print('can not load last_D, training D on init weight.')
    with open(os.path.join(dir_checkpoint,'iter'),'r') as f:
        opt.startiter = int(f.read())

#--------------------------optimizer & loss--------------------------
optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999))
criterion_L1 = nn.L1Loss()
criterion_L2 = nn.MSELoss()
if opt.gan:
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr,betas=(opt.beta1, 0.999))
    if opt.hd:
        # pix2pixHD losses: adversarial, discriminator feature matching and VGG perceptual.
        criterionGAN = pix2pixHD_model.GANLoss(tensor=torch.cuda.FloatTensor).cuda()
        criterionFeat = pix2pixHD_model.GAN_Feat_loss(opt)
        criterionVGG = pix2pixHD_model.VGGLoss([opt.use_gpu])
    else:
        criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda()

'''
--------------------------preload data & data pool--------------------------
'''
print('Preloading data, please wait...')
def preload(pool):
    """Loader-process loop: build random training batches forever and push them to *pool*.

    Each queued item is ``[input_imgs, ground_trues]``. These are CPU tensors of
    shape ``(batchsize, 3*N+1, finesize, finesize)`` and
    ``(batchsize, 3, finesize, finesize)``. Every sample picks a random clip and
    a random centre-frame index, then calls ``data.load_train_video``. If
    anything in a batch raises, the error is printed, the batch is dropped and
    the loop continues.

    Args:
        pool: the queue shared with the training loop.
    """
    while True:
        # Allocate fresh buffers for every batch. Queue.put() only hands the
        # object to a feeder thread that pickles it later. Torch also sends
        # tensors through multiprocessing via shared memory. Refilling the same
        # tensors in place (as before) could therefore overwrite batches that
        # are already queued.
        input_imgs = torch.empty(opt.batchsize,N*3+1,opt.finesize,opt.finesize)
        ground_trues = torch.empty(opt.batchsize,3,opt.finesize,opt.finesize)
        # Bound before the try so the error message below can always reference it.
        videoname = None
        try:
            for i in range(opt.batchsize):
                video_index = random.randint(0,video_num-1)
                videoname = videonames[video_index]
                # Keep the N-frame window centred on img_index inside the clip.
                img_index = random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1)
                input_imgs[i],ground_trues[i] = data.load_train_video(videoname,img_index,opt)
            pool.put([input_imgs,ground_trues])
        except Exception as e:
            print("Error:",videoname,e)
# Bounded queue of ready batches, filled by opt.load_process daemon loader processes.
# This script has no `if __name__ == '__main__':` guard and the loaders read the
# module globals built above. They therefore only work when children are
# forked. Request the fork context explicitly rather than relying on the
# platform default (spawn on macOS, forkserver on Linux from Python 3.14),
# which would re-run this whole script in every child.
import multiprocessing
mp_context = multiprocessing.get_context('fork')
pool = mp_context.Queue(opt.image_pool)
for _ in range(opt.load_process):
    p = mp_context.Process(target=preload,args=(pool,))
    p.daemon = True
    p.start()

'''
--------------------------train--------------------------
'''
# Keep a copy of this script and of the generator definition actually in use
# next to the checkpoints, so the run can be reproduced.
util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py'))
if opt.hd:
    util.copyfile('../../models/videoHD_model.py', os.path.join(dir_checkpoint,'model.py'))
else:
    util.copyfile('../../models/video_model.py', os.path.join(dir_checkpoint,'model.py'))
netG.train()
time_start=time.time()
print("Begin training...")

# Index of the centre frame inside the N-frame input stack.
center = int((N-1)/2)
# Rolling 'last' checkpoint interval; max() avoids a modulo by zero when savefreq < 10.
save_interval = max(1, opt.savefreq//10)
# opt.startiter is the number of iterations already completed (the value written
# to the 'iter' file below), so resume at exactly that index. step is the
# 1-based count of completed iterations after this one.
for iteration in range(opt.startiter,opt.maxiter):
    step = iteration+1

    inputdata,target = pool.get()
    inputdata,target = inputdata.cuda(),target.cuda()

    if opt.gan:
        # compute fake images: G(A)
        pred = netG(inputdata)
        # Condition for D: the centre RGB frame of the input stack.
        real_A = inputdata[:,center*3:(center+1)*3,:,:]

        # --------------------update D--------------------
        pix2pix_model.set_requires_grad(netD,True)
        optimizer_D.zero_grad()
        # Fake (detached so D's loss does not backpropagate into G)
        fake_AB = torch.cat((real_A, pred), 1)
        pred_fake = netD(fake_AB.detach())
        loss_D_fake = criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((real_A, target), 1)
        pred_real = netD(real_AB)
        loss_D_real = criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        loss_sum[4] += loss_D_fake.item()
        loss_sum[5] += loss_D_real.item()
        # update D's weights
        loss_D.backward()
        optimizer_D.step()

        # --------------------update G--------------------
        pix2pix_model.set_requires_grad(netD,False)
        optimizer_G.zero_grad()

        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((real_A, pred), 1)
        pred_fake = netD(fake_AB)
        loss_G_GAN = criterionGAN(pred_fake, True)*opt.lambda_gan

        # combine loss and calculate gradients
        if opt.l2:
            loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * opt.lambda_L1
        else:
            loss_G_L1 = criterion_L1(pred, target) * opt.lambda_L1

        if opt.hd:
            real_AB = torch.cat((real_A, target), 1)
            pred_real = netD(real_AB)
            loss_G_GAN_Feat = criterionFeat(pred_fake,pred_real)
            loss_VGG = criterionVGG(pred, target) * opt.lambda_feat
            loss_G = loss_G_GAN + loss_G_L1 + loss_G_GAN_Feat + loss_VGG
            # These two terms only exist for the HD model; accumulating them
            # outside this branch raised NameError for --gan without --hd.
            loss_sum[2] += loss_G_GAN_Feat.item()
            loss_sum[3] += loss_VGG.item()
        else:
            loss_G = loss_G_GAN + loss_G_L1
        loss_sum[0] += loss_G_L1.item()
        loss_sum[1] += loss_G_GAN.item()

        # update G's weights
        loss_G.backward()
        optimizer_G.step()

    else:
        pred = netG(inputdata)
        if opt.l2:
            loss_G_L1 = (criterion_L1(pred, target)+criterion_L2(pred, target)) * opt.lambda_L1
        else:
            loss_G_L1 = criterion_L1(pred, target) * opt.lambda_L1
        loss_sum[0] += loss_G_L1.item()

        optimizer_G.zero_grad()
        loss_G_L1.backward()
        optimizer_G.step()

    # Every 1000 iterations: save a sample result, then log and plot the mean losses.
    if step%1000 == 0:
        try:
            data.showresult(inputdata[:,center*3:(center+1)*3,:,:],
                target, pred, os.path.join(dir_checkpoint,'result_train.jpg'))
        except Exception as e:
            print(e)

        time_end = time.time()
        savestr ='iter:{0:d} L1_loss:{1:.3f} GAN_loss:{2:.3f} Feat:{3:.3f} VGG:{4:.3f} time:{5:.2f}'.format(
            step,loss_sum[0]/1000,loss_sum[1]/1000,loss_sum[2]/1000,loss_sum[3]/1000,(time_end-time_start)/1000)
        util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True)
        if step/1000 >= 10:
            for i in range(4):
                loss_plot[i].append(loss_sum[i]/1000)
            item_plot.append(step)
            try:
                labels = ['L1_loss','GAN_loss','GAN_Feat_loss','VGG_loss']
                for i in range(4):
                    plt.plot(item_plot,loss_plot[i],label=labels[i])
                plt.xlabel('iter')
                plt.legend(loc=1)
                plt.savefig(os.path.join(dir_checkpoint,'loss.jpg'))
                plt.close()
            except Exception as e:
                print("error:",e)

        loss_sum = [0.,0.,0.,0.,0.,0.]
        time_start=time.time()

    # Save network: a rolling 'last' checkpoint every save_interval iterations
    # plus a numbered snapshot every savefreq iterations. Snapshots are written
    # as extra files instead of renaming last_*.pth. Renaming could hit a
    # stale/missing file and left no checkpoint for --continue_train.
    is_snapshot = step%opt.savefreq == 0
    if step%save_interval == 0 or is_snapshot:
        state_G = netG.cpu().state_dict()
        torch.save(state_G,os.path.join(dir_checkpoint,'last_G.pth'))
        if is_snapshot:
            torch.save(state_G,os.path.join(dir_checkpoint,str(step)+'G.pth'))
        if opt.gan:
            state_D = netD.cpu().state_dict()
            torch.save(state_D,os.path.join(dir_checkpoint,'last_D.pth'))
            if is_snapshot:
                torch.save(state_D,os.path.join(dir_checkpoint,str(step)+'D.pth'))
        # Training always runs on CUDA (inputs are moved with .cuda() above), so
        # the networks must go back unconditionally.
        netG.cuda()
        if opt.gan:
            netD.cuda()
        with open(os.path.join(dir_checkpoint,'iter'),'w') as f:
            f.write(str(step))
        if is_snapshot:
            print('network saved.')

    # Test: at every snapshot, run G on each clip under ./test and save a grid.
    # Top row: centre input frame; bottom row: prediction.
    if is_snapshot and os.path.isdir('./test'):
        netG.eval()
        test_names = sorted(os.listdir('./test'))
        result = np.zeros((opt.finesize*2,opt.finesize*len(test_names),3), dtype='uint8')
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            for cnt,test_name in enumerate(test_names):
                img_names = sorted(os.listdir(os.path.join('./test',test_name,'image')))
                test_input = np.zeros((opt.finesize,opt.finesize,3*N+1), dtype='uint8')
                for i in range(N):
                    img = impro.imread(os.path.join('./test',test_name,'image',img_names[i]))
                    img = impro.resize(img,opt.finesize)
                    test_input[:,:,i*3:(i+1)*3] = img

                mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray')
                mask = impro.resize(mask,opt.finesize)
                mask = impro.mask_threshold(mask,15,128)
                test_input[:,:,-1] = mask
                col = slice(opt.finesize*cnt,opt.finesize*(cnt+1))
                result[0:opt.finesize,col,:] = test_input[:,:,center*3:(center+1)*3]
                test_tensor = data.im2tensor(test_input,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
                test_pred = netG(test_tensor)
                test_pred = data.tensor2im(test_pred,rgb2bgr = False, is0_1 = False)
                result[opt.finesize:opt.finesize*2,col,:] = test_pred

        cv2.imwrite(os.path.join(dir_checkpoint,str(step)+'_test.jpg'), result)
        netG.train()