# train.py - training script for the video mosaic-removal network (MosaicNet).
# Original author: hypox64.
import os
import numpy as np
import cv2
import random
import torch
import torch.nn as nn
import time

import sys
sys.path.append("..")
sys.path.append("../..")
from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
from cores import Options
from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model
from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn

# ---- command-line options -------------------------------------------------
opt = Options()

# (flag, add_argument keyword arguments), registered in the original order so
# the parser and its --help listing are unchanged.
_train_args = (
    ('--N', dict(type=int, default=25, help='')),
    ('--lr', dict(type=float, default=0.0002, help='')),
    ('--beta1', dict(type=float, default=0.5, help='')),
    ('--gan', dict(action='store_true', help='if input it, use gan')),
    ('--l2', dict(action='store_true', help='if input it, use L2 loss')),
    ('--lambda_L1', dict(type=float, default=100, help='')),
    ('--lambda_gan', dict(type=float, default=1, help='')),
    ('--finesize', dict(type=int, default=256, help='')),
    ('--loadsize', dict(type=int, default=286, help='')),
    ('--batchsize', dict(type=int, default=1, help='')),
    ('--perload_num', dict(type=int, default=16, help='')),
    ('--norm', dict(type=str, default='instance', help='')),
    ('--maxiter', dict(type=int, default=10000000, help='')),
    ('--savefreq', dict(type=int, default=10000, help='')),
    ('--startiter', dict(type=int, default=0, help='')),
    ('--continuetrain', dict(action='store_true', help='')),
    ('--savename', dict(type=str, default='MosaicNet', help='')),
)
for _flag, _kwargs in _train_args:
    opt.parser.add_argument(_flag, **_kwargs)

opt = opt.getparse()

# Every run writes its checkpoints, logs and previews into checkpoints/<savename>/.
dir_checkpoint = os.path.join('checkpoints/', opt.savename)
util.makedirs(dir_checkpoint)
# Start the loss log with a timestamp followed by the full option dump.
_run_header = str(time.asctime(time.localtime(time.time()))) + '\n' + util.opt2str(opt)
util.writelog(os.path.join(dir_checkpoint, 'loss.txt'), _run_header)

# Number of consecutive frames stacked into one generator input.
N = opt.N

# Running sums [L1, G_GAN, D_fake, D_real]; averaged and reset every 1000 iterations.
loss_sum = [0., 0., 0., 0.]
# Histories for loss.jpg: loss_plot[0] holds L1 averages, loss_plot[1] G_GAN averages,
# item_plot the matching iteration numbers.
loss_plot = [[], []]
item_plot = []

# Each ./dataset/<video>/ folder provides ori/, mosaic/ and mask/ frame folders;
# the clip length is taken from the number of files in ori/.
videos = sorted(os.listdir('./dataset'))
print('check dataset...')
lengths = [len(os.listdir('./dataset/' + name + '/ori')) for name in videos]

# Generator: 3 channels per stacked frame plus one mask channel in, 3 channels out.
# Alternatives tried earlier: unet_model.UNet(3*N+1, 3) and
# pix2pix_model.define_G(3*N+1, 3, 128, 'resnet_6blocks' | 'resnet_9blocks' | 'unet_128',
#                        norm='instance', use_dropout=True, init_type='normal', gpu_ids=[]).
netG = videoHD_model.MosaicNet(3 * N + 1, 3, norm=opt.norm)
loadmodel.show_paramsnumber(netG, 'netG')

if opt.gan:
    # Discriminator input is the centre mosaic frame (3 channels) concatenated
    # with a real or generated image (3 channels), hence 6 input channels.
    # Earlier variants used pix2pix_model.define_D with 'pixel', 'basic' or
    # 'n_layers' discriminators.
    netD = pix2pixHD_model.define_D(6, 64, 3, norm=opt.norm, use_sigmoid=False, num_D=2)

# Resuming needs checkpoints/<savename>/last_G.pth; without it, fall back to fresh weights.
if opt.continuetrain and not os.path.isfile(os.path.join(dir_checkpoint, 'last_G.pth')):
    opt.continuetrain = False
    print('can not load last_G, training on init weight.')

if opt.continuetrain:
    netG.load_state_dict(torch.load(os.path.join(dir_checkpoint, 'last_G.pth')))
    if opt.gan:
        netD.load_state_dict(torch.load(os.path.join(dir_checkpoint, 'last_D.pth')))
    # The 'iter' file holds the last saved iteration number. Use a context
    # manager so the handle is closed even if the contents fail to parse.
    with open(os.path.join(dir_checkpoint, 'iter'), 'r') as f:
        opt.startiter = int(f.read())

criterion_L1 = nn.L1Loss()
criterion_L2 = nn.MSELoss()

if opt.gan:
    # GANLoss builds its target tensors with the given tensor type, so it must
    # match the device the discriminator output lives on.
    criterionGAN = pix2pixHD_model.GANLoss(
        tensor=torch.cuda.FloatTensor if opt.use_gpu else torch.FloatTensor)
    netD.train()

# Move the networks to the GPU before constructing their optimizers, as the
# torch.optim documentation recommends.
if opt.use_gpu:
    netG.cuda()
    if opt.gan:
        netD.cuda()
        criterionGAN.cuda()
    cudnn.benchmark = True

optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
if opt.gan:
    # Must optimise netD's own parameters: building this over netG.parameters()
    # (as before) meant optimizer_D.step() never updated the discriminator.
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

def loaddata():
    """Load one random training sample from ./dataset.

    Picks a random video and a random centre frame index, then builds a
    (loadsize, loadsize, 3*N+1) uint8 array:
    - channels 0..3*N-1: N consecutive frames from mosaic/ around the centre frame;
    - last channel: the thresholded mask of the centre frame.
    The matching clean frame from ori/ is the target. Both go through
    data.random_transform_video (presumably random crop/augment down to
    opt.finesize - TODO confirm) and are converted with data.im2tensor.

    Returns:
        (input_img, ground_true) as produced by data.im2tensor.

    Raises:
        FileNotFoundError: if a frame or mask image cannot be read.
    """
    video_index = random.randint(0, len(videos) - 1)
    video_dir = os.path.join('./dataset', videos[video_index])
    half = int(N / 2)
    # Choose the centre so the N-frame window [centre-half, centre+N-1-half]
    # stays within [1, lengths-1]; presumably frames are numbered from 1 - TODO confirm.
    img_index = random.randint(half + 1, lengths[video_index] - half - 1)

    def _read(subdir, index, *flags):
        # cv2.imread signals failure by returning None instead of raising, which
        # used to surface later as an obscure error inside impro.resize.
        path = os.path.join(video_dir, subdir, 'output_' + '%05d' % index + '.png')
        img = cv2.imread(path, *flags)
        if img is None:
            raise FileNotFoundError('cannot read image: ' + path)
        return img

    input_img = np.zeros((opt.loadsize, opt.loadsize, 3 * N + 1), dtype='uint8')
    for i in range(N):
        img = impro.resize(_read('mosaic', img_index + i - half), opt.loadsize)
        input_img[:, :, i * 3:(i + 1) * 3] = img

    mask = impro.resize(_read('mask', img_index, 0), opt.loadsize)  # 0 = grayscale
    input_img[:, :, -1] = impro.mask_threshold(mask, 15, 128)

    ground_true = impro.resize(_read('ori', img_index), opt.loadsize)

    input_img, ground_true = data.random_transform_video(input_img, ground_true, opt.finesize, N)
    input_img = data.im2tensor(input_img, bgr2rgb=False, use_gpu=opt.use_gpu,
                               use_transform=False, is0_1=False)
    ground_true = data.im2tensor(ground_true, bgr2rgb=False, use_gpu=opt.use_gpu,
                                 use_transform=False, is0_1=False)
    return input_img, ground_true

print('preloading data, please wait 5s...')

# The training loop samples a start index from [0, perload_num - batchsize - 1],
# so the buffer needs strictly more slots than one batch.
if opt.perload_num <= opt.batchsize:
    opt.perload_num = opt.batchsize * 2

# Ring of preloaded samples, filled asynchronously by preload(); the random
# contents are placeholders until overwritten. Only move to the GPU when it is
# actually in use (the old unconditional .cuda() crashed CPU-only runs).
input_imgs = torch.rand(opt.perload_num, N * 3 + 1, opt.finesize, opt.finesize)
ground_trues = torch.rand(opt.perload_num, 3, opt.finesize, opt.finesize)
if opt.use_gpu:
    input_imgs = input_imgs.cuda()
    ground_trues = ground_trues.cuda()
# Number of samples loaded so far by the preload thread.
load_cnt = 0

def preload():
    """Background worker that keeps refreshing random buffer slots.

    Runs forever: each pass overwrites one randomly chosen slot of
    input_imgs / ground_trues with a fresh sample from loaddata() and bumps
    the global load_cnt. Any exception is printed and the loop carries on.
    """
    global load_cnt
    while True:
        try:
            slot = random.randint(0, opt.perload_num - 1)
            input_imgs[slot], ground_trues[slot] = loaddata()
            load_cnt += 1
        except Exception as e:
            print("error:", e)


import threading
# t is the newly created loader thread; as a daemon it does not block interpreter exit.
t = threading.Thread(target=preload, args=(), daemon=True)
t.start()

# Block until the loader thread has produced at least perload_num samples.
# Slots are picked at random, so a few may still hold placeholder noise.
time_start = time.time()
while load_cnt < opt.perload_num:
    time.sleep(0.1)
time_end = time.time()
print('load speed:', round((time_end - time_start) / opt.perload_num, 3), 's/it')

# Snapshot this script and the model definition next to the checkpoints.
for _src, _dst in (('./train.py', 'train.py'),
                   ('../../models/videoHD_model.py', 'model.py')):
    util.copyfile(_src, os.path.join(dir_checkpoint, _dst))

netG.train()
time_start = time.time()
print("Begin training...")
# Channel slice of the centre frame inside the stacked (3*N+1)-channel input.
_mid_frame = int((N - 1) / 2)
center_slice = slice(_mid_frame * 3, (_mid_frame + 1) * 3)


def _reconstruction_loss(pred, target):
    """Return the weighted L1 (or L1 + L2 when --l2 is set) loss of pred vs target."""
    loss = criterion_L1(pred, target)
    if opt.l2:
        loss = loss + criterion_L2(pred, target)
    return loss * opt.lambda_L1


def _save_loss_plot(curves):
    """Plot each curve against item_plot into loss.jpg; failures are only printed."""
    try:
        for curve in curves:
            plt.plot(item_plot, curve)
        plt.savefig(os.path.join(dir_checkpoint, 'loss.jpg'))
        plt.close()
    except Exception as e:
        print("error:", e)


def _rotate_and_save(net, tag, step):
    """Archive last_<tag>.pth as <step-savefreq><tag>.pth, then save net as last_<tag>.pth.

    The weights are saved from CPU (net.cpu() moves the module in place), so
    the caller must move it back to the GPU afterwards. A missing previous
    checkpoint, e.g. after resuming with a different --savefreq, is skipped
    instead of crashing training with FileNotFoundError.
    """
    last_path = os.path.join(dir_checkpoint, 'last_' + tag + '.pth')
    if step != opt.savefreq and os.path.isfile(last_path):
        os.rename(last_path, os.path.join(dir_checkpoint, str(step - opt.savefreq) + tag + '.pth'))
    torch.save(net.cpu().state_dict(), last_path)


def _run_test(step):
    """Run netG on every sample in ./test and save <step>_test.jpg.

    Each ./test/<name>/ folder needs an image/ folder with at least N frames
    and a mask.png. In the output image the top row shows the centre input
    frame and the bottom row netG's prediction, one column per sample.
    Inference runs under torch.no_grad() so no autograd graph is kept.
    """
    if not os.path.isdir('./test'):
        print('./test not found, skip test.')
        return
    test_names = sorted(os.listdir('./test'))
    if not test_names:
        # cv2.imwrite cannot write a zero-width image.
        print('./test is empty, skip test.')
        return

    fs = opt.finesize
    result = np.zeros((fs * 2, fs * len(test_names), 3), dtype='uint8')
    netG.eval()
    with torch.no_grad():
        for cnt, test_name in enumerate(test_names):
            image_dir = os.path.join('./test', test_name, 'image')
            img_names = sorted(os.listdir(image_dir))
            sample = np.zeros((fs, fs, 3 * N + 1), dtype='uint8')
            for i in range(N):
                img = impro.imread(os.path.join(image_dir, img_names[i]))
                sample[:, :, i * 3:(i + 1) * 3] = impro.resize(img, fs)
            mask = impro.imread(os.path.join('./test', test_name, 'mask.png'), 'gray')
            mask = impro.resize(mask, fs)
            sample[:, :, -1] = impro.mask_threshold(mask, 15, 128)
            result[0:fs, fs * cnt:fs * (cnt + 1), :] = sample[:, :, center_slice]

            sample = data.im2tensor(sample, bgr2rgb=False, use_gpu=opt.use_gpu,
                                    use_transform=False, is0_1=False)
            pred = data.tensor2im(netG(sample), rgb2bgr=False, is0_1=False)
            result[fs:fs * 2, fs * cnt:fs * (cnt + 1), :] = pred
    cv2.imwrite(os.path.join(dir_checkpoint, str(step) + '_test.jpg'), result)
    netG.train()


for iteration in range(opt.startiter + 1, opt.maxiter):
    step = iteration + 1  # 1-based count used for logging/saving schedules

    # Take a contiguous batch from the preload buffer. clone() so the loader
    # thread can keep overwriting slots while this batch is in use.
    ran = random.randint(0, opt.perload_num - opt.batchsize - 1)
    inputdata = input_imgs[ran:ran + opt.batchsize].clone()
    target = ground_trues[ran:ran + opt.batchsize].clone()

    if opt.gan:
        # compute fake images: G(A)
        pred = netG(inputdata)
        real_A = inputdata[:, center_slice, :, :]
        fake_AB = torch.cat((real_A, pred), 1)

        # update D. The fake pair is detached so no gradient reaches netG here.
        pix2pix_model.set_requires_grad(netD, True)
        optimizer_D.zero_grad()
        loss_D_fake = criterionGAN(netD(fake_AB.detach()), False)
        real_AB = torch.cat((real_A, target), 1)
        loss_D_real = criterionGAN(netD(real_AB), True)
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        loss_sum[2] += loss_D_fake.item()
        loss_sum[3] += loss_D_real.item()
        loss_D.backward()
        optimizer_D.step()

        # update G. D is frozen; G(A) should fool it and match the target.
        pix2pix_model.set_requires_grad(netD, False)
        optimizer_G.zero_grad()
        loss_G_GAN = criterionGAN(netD(fake_AB), True) * opt.lambda_gan
        loss_G_L1 = _reconstruction_loss(pred, target)
        loss_G = loss_G_GAN + loss_G_L1
        loss_sum[0] += loss_G_L1.item()
        loss_sum[1] += loss_G_GAN.item()
        loss_G.backward()
        optimizer_G.step()
    else:
        pred = netG(inputdata)
        loss_G_L1 = _reconstruction_loss(pred, target)
        loss_sum[0] += loss_G_L1.item()
        optimizer_G.zero_grad()
        loss_G_L1.backward()
        optimizer_G.step()

    if step % 100 == 0:
        try:
            data.showresult(inputdata[:, center_slice, :, :], target, pred,
                            os.path.join(dir_checkpoint, 'result_train.jpg'))
        except Exception as e:
            print(e)

    if step % 1000 == 0:
        time_end = time.time()
        avg = [x / 1000 for x in loss_sum]
        sec_per_iter = (time_end - time_start) / 1000
        if opt.gan:
            savestr = 'iter:{0:d} L1_loss:{1:.4f} G_loss:{2:.4f} D_f:{3:.4f} D_r:{4:.4f} time:{5:.2f}'.format(
                step, avg[0], avg[1], avg[2], avg[3], sec_per_iter)
        else:
            savestr = 'iter:{0:d}  L1_loss:{1:.4f}  time:{2:.2f}'.format(step, avg[0], sec_per_iter)
        util.writelog(os.path.join(dir_checkpoint, 'loss.txt'), savestr, True)

        # Start plotting after the first 10k iterations.
        if step / 1000 >= 10:
            loss_plot[0].append(avg[0])
            curves = [loss_plot[0]]
            if opt.gan:
                loss_plot[1].append(avg[1])
                curves.append(loss_plot[1])
            item_plot.append(step)
            _save_loss_plot(curves)
        loss_sum = [0., 0., 0., 0.]
        time_start = time.time()

    if step % opt.savefreq == 0:
        _rotate_and_save(netG, 'G', step)
        if opt.gan:
            _rotate_and_save(netD, 'D', step)
        if opt.use_gpu:
            netG.cuda()
            if opt.gan:
                netD.cuda()
        with open(os.path.join(dir_checkpoint, 'iter'), 'w+') as f:
            f.write(str(step))
        print('network saved.')

        _run_test(step)