train.py 12.2 KB
Newer Older
H
hypox64 已提交
1
import os
H
hypox64 已提交
2 3 4 5 6 7
import sys
sys.path.append("..")
sys.path.append("../..")
from cores import Options
opt = Options()

H
hypox64 已提交
8 9 10 11 12 13
import numpy as np
import cv2
import random
import torch
import torch.nn as nn
import time
14
from multiprocessing import Process, Queue
H
hypox64 已提交
15 16 17

from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
18
from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model
H
HypoX64 已提交
19 20
import matplotlib
matplotlib.use('Agg')
H
hypox64 已提交
21 22 23
from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn

H
HypoX64 已提交
24 25 26
'''
--------------------------Get options--------------------------
'''
def _register_train_options(parser):
    """Register every training flag on *parser*, in a fixed order.

    Each entry of the table is ``(flag, keyword arguments for parser.add_argument)``.
    """
    option_specs = (
        # Generator input has 3*N+1 channels: N three-channel frames plus one mask channel.
        ('--N', dict(type=int, default=25, help='')),
        # Adam settings shared by the generator and discriminator optimizers.
        ('--lr', dict(type=float, default=0.0002, help='')),
        ('--beta1', dict(type=float, default=0.5, help='')),
        # Model / loss switches.
        ('--gan', dict(action='store_true', help='if specified, use gan')),
        ('--l2', dict(action='store_true', help='if specified, use L2 loss')),
        ('--hd', dict(action='store_true', help='if specified, use HD model')),
        # Scales the reconstruction (L1, optionally +MSE) loss.
        ('--lambda_L1', dict(type=float, default=100, help='')),
        # Scales the generator's adversarial loss.
        ('--lambda_gan', dict(type=float, default=1, help='')),
        # Spatial size of the training/test tensors.
        ('--finesize', dict(type=int, default=256, help='')),
        # Not read in this script; presumably consumed by the data loader via opt.
        ('--loadsize', dict(type=int, default=286, help='')),
        ('--batchsize', dict(type=int, default=1, help='')),
        # Normalization type passed to both G and D.
        ('--norm', dict(type=str, default='instance', help='')),
        # Options for the multi-scale (HD) discriminator.
        ('--num_D', dict(type=int, default=2, help='number of discriminators to use')),
        ('--n_layers_D', dict(type=int, default=3, help='only used if which_model_netD==n_layers')),
        # NOTE(review): in this script lambda_feat multiplies the VGG loss.
        ('--lambda_feat', dict(type=float, default=10.0, help='weight for feature matching loss')),
        # Maximum number of preloaded batches, and number of loader processes.
        ('--image_pool', dict(type=int, default=8, help='number of image load pool')),
        ('--load_process', dict(type=int, default=4, help='number of process for loading data')),
        # Root directory holding one sub-directory per training video.
        ('--dataset', dict(type=str, default='./datasets/face/', help='')),
        # Iteration range and checkpoint frequency.
        ('--maxiter', dict(type=int, default=10000000, help='')),
        ('--savefreq', dict(type=int, default=10000, help='')),
        ('--startiter', dict(type=int, default=0, help='')),
        # Resume from checkpoints/<savename>/last_G.pth when present.
        ('--continue_train', dict(action='store_true', help='')),
        ('--savename', dict(type=str, default='face', help='')),
    )
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)

_register_train_options(opt.parser)
51

H
HypoX64 已提交
52 53 54 55

'''
--------------------------Init--------------------------
'''
opt = opt.getparse()
dir_checkpoint = os.path.join('checkpoints/',opt.savename)
util.makedirs(dir_checkpoint)
# Start the log with a timestamp and the full option set for this run.
util.writelog(os.path.join(dir_checkpoint,'loss.txt'),
              time.asctime()+'\n'+util.opt2str(opt))
cudnn.benchmark = True

N = opt.N
# Running loss sums over each 1000-iteration logging window:
# [0] reconstruction, [1] G adversarial, [2] GAN feature matching, [3] VGG,
# [4] D on fake, [5] D on real. All floats, matching the per-window reset
# done in the training loop.
loss_sum = [0.] * 6
# Window averages of loss_sum[0..3] and the matching iteration numbers (for loss.jpg).
loss_plot = [[] for _ in range(4)]
item_plot = []

H
hypox64 已提交
68 69 70 71 72 73 74 75 76 77 78 79
# list video dir
# Every sub-directory of opt.dataset is one video whose frames live in
# <video>/origin_image. Record each usable video and its frame count.
videonames = sorted(os.listdir(opt.dataset))
lengths = []
usable_videos = []
# preload() draws img_index = random.randint(int(N/2)+1, length-int(N/2)-1);
# that range is empty (randint raises) for shorter videos, so skip them here.
min_video_length = 2*int(N/2)+2
print('Check dataset...')
for video in videonames:
    video_dir = os.path.join(opt.dataset, video)
    # opt.txt (and any other plain file) is not a video directory.
    if video == 'opt.txt' or not os.path.isdir(video_dir):
        continue
    frame_count = len(os.listdir(os.path.join(video_dir,'origin_image')))
    if frame_count < min_video_length:
        print('Skip video', video, ': only', frame_count, 'frames, need at least', min_video_length)
        continue
    lengths.append(frame_count)
    usable_videos.append(video)
videonames = usable_videos
video_num = len(videonames)
if video_num == 0:
    # Without this check every loader process would fail and pool.get() would block forever.
    raise RuntimeError('No usable training videos found in ' + opt.dataset)
H
hypox64 已提交
80 81

#--------------------------Init network--------------------------
print('Init network...')
# Generator: 3*N+1 input channels (N three-channel frames + 1 mask channel), 3 output channels.
generator_module = videoHD_model if opt.hd else video_model
netG = generator_module.MosaicNet(3*N+1, 3, norm=opt.norm)
netG.cuda()
loadmodel.show_paramsnumber(netG,'netG')

if opt.gan:
    # The discriminator sees (center input frame, image) concatenated on channels: 6 channels.
    netD = (pix2pixHD_model.define_D(6, 64, opt.n_layers_D, norm = opt.norm, use_sigmoid=False,
                                     num_D=opt.num_D, getIntermFeat=True)
            if opt.hd else
            pix2pix_model.define_D(3*2, 64, 'basic', norm = opt.norm))
    netD.cuda()
    netD.train()
H
hypox64 已提交
97

H
hypox64 已提交
98
#--------------------------continue train--------------------------
# With --continue_train, restore weights from dir_checkpoint/last_G.pth (and
# last_D.pth with --gan) and the iteration counter from the 'iter' file written
# by the training loop. If last_G.pth is missing, the flag is cleared and
# training starts from the freshly initialised weights.
if opt.continue_train:
    last_G_path = os.path.join(dir_checkpoint,'last_G.pth')
    if not os.path.isfile(last_G_path):
        opt.continue_train = False
        print('can not load last_G, training on init weight.')
    else:
        netG.load_state_dict(torch.load(last_G_path))
        if opt.gan:
            netD.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last_D.pth')))
        # 'with' closes the file even if its contents are not a valid integer.
        with open(os.path.join(dir_checkpoint,'iter'),'r') as f:
            opt.startiter = int(f.read())

H
hypox64 已提交
111
#--------------------------optimizer & loss--------------------------
# Both optimizers use Adam with the same learning rate and betas.
adam_betas = (opt.beta1, 0.999)
optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=adam_betas)
# Reconstruction losses (MSE is added on top of L1 only with --l2).
criterion_L1 = nn.L1Loss()
criterion_L2 = nn.MSELoss()
if opt.gan:
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=adam_betas)
    if not opt.hd:
        criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda()
    else:
        # HD mode adds feature-matching and VGG losses to the adversarial loss.
        criterionGAN = pix2pixHD_model.GANLoss(tensor=torch.cuda.FloatTensor).cuda()
        criterionFeat = pix2pixHD_model.GAN_Feat_loss(opt)
        criterionVGG = pix2pixHD_model.VGGLoss([opt.use_gpu])
H
hypox64 已提交
123

H
HypoX64 已提交
124
'''
--------------------------preload data & data pool--------------------------
'''
print('Preloading data, please wait...')
def preload(pool):
    """Loader-process main loop: keep putting random training batches on *pool*.

    Each batch is ``[input_imgs, ground_trues]`` with shapes
    (batchsize, 3*N+1, finesize, finesize) and (batchsize, 3, finesize, finesize).
    Each batch item is a random video and a random center frame far enough from both
    ends for data.load_train_video to read its neighbours. Errors are printed and the
    batch is dropped; this function never returns.
    """
    # Forked workers inherit the parent's NumPy RNG state; reseed so workers do not
    # repeat each other's random draws if the data loader uses np.random
    # (NOTE(review): confirm whether data.load_train_video does).
    np.random.seed()
    while True:
        videoname = None  # defined before any failure so the error message cannot raise
        try:
            # Fresh tensors for every batch: multiprocessing.Queue serialises objects in a
            # background thread, and torch may move tensor storage into shared memory when
            # pickling, so reusing one buffer would let later batches overwrite queued ones.
            input_imgs = torch.empty(opt.batchsize,N*3+1,opt.finesize,opt.finesize)
            ground_trues = torch.empty(opt.batchsize,3,opt.finesize,opt.finesize)
            for i in range(opt.batchsize):
                video_index = random.randint(0,video_num-1)
                videoname = videonames[video_index]
                img_index = random.randint(int(N/2)+1,lengths[video_index]- int(N/2)-1)
                input_imgs[i],ground_trues[i] = data.load_train_video(videoname,img_index,opt)
            pool.put([input_imgs,ground_trues])
        except Exception as e:
            print("Error:",videoname,e)

# Bounded queue shared by all loader processes; the training loop consumes it.
pool = Queue(opt.image_pool)
for _ in range(opt.load_process):
    p = Process(target=preload,args=(pool,))
    p.daemon = True  # loaders exit together with the training process
    p.start()
H
hypox64 已提交
148

H
HypoX64 已提交
149 150 151
'''
--------------------------train--------------------------
'''
# Archive the training script and the generator definition next to the checkpoints.
util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py'))
# Copy the model file that was actually instantiated above (HD or standard video model).
model_source = 'videoHD_model.py' if opt.hd else 'video_model.py'
util.copyfile('../../models/'+model_source, os.path.join(dir_checkpoint,'model.py'))
netG.train()
time_start=time.time()
print("Begin training...")
157
import shutil

# last_*.pth is refreshed every savefreq//10 iterations (every iteration when
# savefreq < 10, instead of dividing by zero).
save_every = max(1, opt.savefreq//10)
# Channel offset of the center frame inside the 3*N+1-channel input tensor.
center = int((N-1)/2)*3


def _reconstruction_loss(pred, target):
    """Return the L1 loss (plus MSE when --l2 is set) scaled by opt.lambda_L1."""
    loss = criterion_L1(pred, target)
    if opt.l2:
        loss = loss + criterion_L2(pred, target)
    return loss * opt.lambda_L1


def _save_last_checkpoint(step):
    """Write last_G.pth (and last_D.pth with --gan) plus the 'iter' file read by --continue_train.

    The networks are moved to the CPU so the saved tensors are CPU tensors, then moved
    back with .cuda() unconditionally: this loop always trains on CUDA (every batch is
    moved with .cuda()), so the move back must not depend on opt.use_gpu.
    """
    torch.save(netG.cpu().state_dict(),os.path.join(dir_checkpoint,'last_G.pth'))
    if opt.gan:
        torch.save(netD.cpu().state_dict(),os.path.join(dir_checkpoint,'last_D.pth'))
    netG.cuda()
    if opt.gan:
        netD.cuda()
    with open(os.path.join(dir_checkpoint,'iter'),'w') as f:
        f.write(str(step))


def _keep_snapshot(step):
    """Copy last_G.pth / last_D.pth to '<step>G.pth' / '<step>D.pth'.

    Copying (rather than renaming) leaves last_*.pth in place so --continue_train
    can still resume from this point.
    """
    shutil.copyfile(os.path.join(dir_checkpoint,'last_G.pth'),
                    os.path.join(dir_checkpoint,str(step)+'G.pth'))
    if opt.gan:
        shutil.copyfile(os.path.join(dir_checkpoint,'last_D.pth'),
                        os.path.join(dir_checkpoint,str(step)+'D.pth'))


def _test_on_folder(step):
    """Run netG on every sample in ./test and write '<step>_test.jpg' to dir_checkpoint.

    Each ./test/<name>/ needs an 'image' folder with at least N frames and a 'mask.png'.
    Column k of the output shows sample k: top row the center input frame, bottom row
    the prediction. netG is put back in train mode even if a sample fails.
    """
    test_root = './test'
    size = opt.finesize
    test_names = sorted(os.listdir(test_root))
    result = np.zeros((size*2,size*len(test_names),3), dtype='uint8')
    netG.eval()
    try:
        with torch.no_grad():  # inference only: do not build an autograd graph
            for cnt,test_name in enumerate(test_names):
                img_names = sorted(os.listdir(os.path.join(test_root,test_name,'image')))
                frames = np.zeros((size,size,3*N+1), dtype='uint8')
                for i in range(N):
                    img = impro.imread(os.path.join(test_root,test_name,'image',img_names[i]))
                    frames[:,:,i*3:(i+1)*3] = impro.resize(img,size)
                mask = impro.imread(os.path.join(test_root,test_name,'mask.png'),'gray')
                mask = impro.resize(mask,size)
                frames[:,:,-1] = impro.mask_threshold(mask,15,128)
                column = slice(size*cnt,size*(cnt+1))
                result[0:size,column,:] = frames[:,:,center:center+3]
                test_input = data.im2tensor(frames,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False,is0_1 = False)
                test_pred = netG(test_input)
                result[size:size*2,column,:] = data.tensor2im(test_pred,rgb2bgr = False, is0_1 = False)
    finally:
        netG.train()
    cv2.imwrite(os.path.join(dir_checkpoint,str(step)+'_test.jpg'), result)


for iteration in range(opt.startiter+1,opt.maxiter):
    step = iteration+1  # 'step' is the number used in logs, checkpoints and file names

    inputdata,target = pool.get()
    inputdata,target = inputdata.cuda(),target.cuda()

    if opt.gan:
        # compute fake images: G(A)
        pred = netG(inputdata)
        real_A = inputdata[:,center:center+3,:,:]

        # --------------------update D--------------------
        pix2pix_model.set_requires_grad(netD,True)
        optimizer_D.zero_grad()
        # Fake: detached so D's loss does not backpropagate into G
        fake_AB = torch.cat((real_A, pred), 1)
        pred_fake = netD(fake_AB.detach())
        loss_D_fake = criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((real_A, target), 1)
        pred_real = netD(real_AB)
        loss_D_real = criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        loss_sum[4] += loss_D_fake.item()
        loss_sum[5] += loss_D_real.item()
        # update D's weights
        loss_D.backward()
        optimizer_D.step()

        # --------------------update G--------------------
        # D's parameters are excluded from gradients while G is updated (pix2pix helper).
        pix2pix_model.set_requires_grad(netD,False)
        optimizer_G.zero_grad()

        # First, G(A) should fool the discriminator (fake_AB still carries G's graph)
        pred_fake = netD(fake_AB)
        loss_G_GAN = criterionGAN(pred_fake, True)*opt.lambda_gan
        loss_G_L1 = _reconstruction_loss(pred, target)

        if opt.hd:
            # D was just updated, so real features are recomputed for feature matching.
            pred_real = netD(real_AB)
            loss_G_GAN_Feat = criterionFeat(pred_fake,pred_real)
            loss_VGG = criterionVGG(pred, target) * opt.lambda_feat
            loss_G = loss_G_GAN + loss_G_L1 + loss_G_GAN_Feat + loss_VGG
            # These terms only exist in HD mode; accumulating them unconditionally
            # raised NameError for --gan without --hd.
            loss_sum[2] += loss_G_GAN_Feat.item()
            loss_sum[3] += loss_VGG.item()
        else:
            loss_G = loss_G_GAN + loss_G_L1

        loss_sum[0] += loss_G_L1.item()
        loss_sum[1] += loss_G_GAN.item()

        # update G's weights
        loss_G.backward()
        optimizer_G.step()

    else:
        pred = netG(inputdata)
        loss_G_L1 = _reconstruction_loss(pred, target)
        loss_sum[0] += loss_G_L1.item()

        optimizer_G.zero_grad()
        loss_G_L1.backward()
        optimizer_G.step()

    if step % 1000 == 0:
        # save train result
        try:
            data.showresult(inputdata[:,center:center+3,:,:],
                target, pred, os.path.join(dir_checkpoint,'result_train.jpg'))
        except Exception as e:
            print(e)

        # log and plot per-iteration averages over the last 1000 iterations
        time_end = time.time()
        savestr ='iter:{0:d} L1_loss:{1:.3f} GAN_loss:{2:.3f} Feat:{3:.3f} VGG:{4:.3f} time:{5:.2f}'.format(
            step,loss_sum[0]/1000,loss_sum[1]/1000,loss_sum[2]/1000,loss_sum[3]/1000,(time_end-time_start)/1000)
        util.writelog(os.path.join(dir_checkpoint,'loss.txt'), savestr,True)
        if step >= 10000:
            for window_sum, curve in zip(loss_sum, loss_plot):
                curve.append(window_sum/1000)
            item_plot.append(step)
            try:
                labels = ['L1_loss','GAN_loss','GAN_Feat_loss','VGG_loss']
                for curve, label in zip(loss_plot, labels):
                    plt.plot(item_plot,curve,label=label)
                plt.xlabel('iter')
                plt.legend(loc=1)
                plt.savefig(os.path.join(dir_checkpoint,'loss.jpg'))
                plt.close()
            except Exception as e:
                print("error:",e)

        loss_sum = [0.,0.,0.,0.,0.,0.]
        time_start=time.time()

    # save network; also at every savefreq boundary so the snapshot below is current
    if step % save_every == 0 or step % opt.savefreq == 0:
        _save_last_checkpoint(step)

    if step % opt.savefreq == 0:
        _keep_snapshot(step)
        print('network saved.')
        #test
        if os.path.isdir('./test'):
            _test_on_folder(step)