From d2beea9ba64872244b8b961be91190e54b13266e Mon Sep 17 00:00:00 2001 From: hypox64 Date: Mon, 13 Jan 2020 20:09:56 +0800 Subject: [PATCH] update video_model --- models/video_model.py | 47 ++++++++++++---------------------- train/add/train.py | 6 ----- train/clean/train.py | 59 +++++++++++++++++++++++++------------------ 3 files changed, 50 insertions(+), 62 deletions(-) diff --git a/models/video_model.py b/models/video_model.py index 011e8c8..0cbcab4 100644 --- a/models/video_model.py +++ b/models/video_model.py @@ -42,23 +42,6 @@ class encoder_2d(nn.Module): nn.ReLU(True)] #torch.Size([1, 256, 32, 32]) - # mult = 2 ** n_downsampling - # for i in range(n_blocks): # add ResNet blocks - # model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] - #torch.Size([1, 256, 32, 32]) - - # for i in range(n_downsampling): # add upsampling layers - # mult = 2 ** (n_downsampling - i) - # model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), - # kernel_size=3, stride=2, - # padding=1, output_padding=1, - # bias=use_bias), - # norm_layer(int(ngf * mult / 2)), - # nn.ReLU(True)] - # model += [nn.ReflectionPad2d(3)] - # model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] - # model += [nn.Tanh()] - self.model = nn.Sequential(*model) def forward(self, input): @@ -117,7 +100,7 @@ class decoder_2d(nn.Module): model += [nn.ReflectionPad2d(3)] model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] # model += [nn.Tanh()] - model += [nn.Sigmoid()] + # model += [nn.Sigmoid()] self.model = nn.Sequential(*model) @@ -147,7 +130,6 @@ class encoder_3d(nn.Module): self.down1 = conv_3d(1, 64, 3, 2, 1) self.down2 = conv_3d(64, 128, 3, 2, 1) self.down3 = conv_3d(128, 256, 3, 1, 1) - # self.down4 = conv_3d(256, 512, 3, 2, 1) self.conver2d = nn.Sequential( nn.Conv2d(int(in_channel/4), 1, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(1), @@ -160,17 +142,13 @@ class encoder_3d(nn.Module): x = self.down1(x) x = self.down2(x) x = self.down3(x) - # x = self.down4(x) - x = x.view(x.size(1),x.size(2),x.size(3),x.size(4)) x = self.conver2d(x) x = x.view(x.size(1),x.size(0),x.size(2),x.size(3)) - # print(x.size()) - # x = self.avgpool(x) + return x -# input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect' class HypoNet(nn.Module): @@ -180,24 +158,31 @@ class HypoNet(nn.Module): self.encoder_2d = encoder_2d(4,-1,64,n_blocks=9) self.encoder_3d = encoder_3d(in_channel) self.decoder_2d = decoder_2d(4,3,64,n_blocks=9) - self.merge = nn.Sequential( - nn.Conv2d(256, 256, 1, 1, 0, bias=False), + self.merge1 = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(512, 256, 3, 1, 0, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ) + self.merge2 = nn.Sequential( + nn.ReflectionPad2d(3), + nn.Conv2d(6, out_channel, kernel_size=7, padding=0), + nn.Sigmoid() + ) def forward(self, x): N = int((x.size()[1])/3) x_2d = torch.cat((x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], x[:,N-1:N,:,:]), 1) - + shortcat_2d = x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:] + x_2d = self.encoder_2d(x_2d) x_3d = self.encoder_3d(x) - x = x_2d + x_3d - x = self.merge(x) - # print(x.size()) + x = torch.cat((x_2d,x_3d),1) + x = self.merge1(x) x = self.decoder_2d(x) - + x = torch.cat((x,shortcat_2d),1) + x = self.merge2(x) return x diff --git a/train/add/train.py b/train/add/train.py index c1845cf..49e0888 100644 --- a/train/add/train.py +++ b/train/add/train.py @@ -192,12 +192,6 @@ if use_gpu: net.cuda() - -# optimizer = optim.SGD(net.parameters(), -# lr=LR, -# momentum=0.9, -# weight_decay=0.0005) - optimizer = torch.optim.Adam(net.parameters(), lr=LR, betas=(0.9, 0.99)) criterion = nn.BCELoss() diff --git a/train/clean/train.py b/train/clean/train.py index cd5cb3f..b754aba 100644 --- a/train/clean/train.py +++ b/train/clean/train.py @@ -25,14 +25,15 @@ use_gpu = True use_gan = False use_L2 = False CONTINUE = False -lambda_L1 = 100.0 +lambda_L1 = 1.0#100.0 lambda_gan = 1.0 SAVE_FRE = 10000 start_iter = 0 -SIZE = 128 +finesize = 128 +loadsize = int(finesize*1.1) -savename = 'MosaicNet' +savename = 'MosaicNet_test' dir_checkpoint = 'checkpoints/'+savename util.makedirs(dir_checkpoint) @@ -76,7 +77,15 @@ if use_gan: optimizer_D = torch.optim.Adam(netG.parameters(), lr=LR,betas=(beta1, 0.999)) criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda() -def random_transform(src,target): +def random_transform(src,target,finesize): + + #random crop + h,w = target.shape[:2] + h_move = int((h-finesize)*random.random()) + w_move = int((w-finesize)*random.random()) + # print(h,w,h_move,w_move) + target = target[h_move:h_move+finesize,w_move:w_move+finesize,:] + src = src[h_move:h_move+finesize,w_move:w_move+finesize,:] #random flip if random.random()<0.5: @@ -110,21 +119,21 @@ def loaddata(): video_index = random.randint(0,len(videos)-1) video = videos[video_index] img_index = random.randint(N,lengths[video_index]- N) - input_img = np.zeros((SIZE,SIZE,3*N+1), dtype='uint8') + input_img = np.zeros((loadsize,loadsize,3*N+1), dtype='uint8') for i in range(0,N): # print('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png') img = cv2.imread('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png') - img = impro.resize(img,SIZE) + img = impro.resize(img,loadsize) input_img[:,:,i*3:(i+1)*3] = img mask = cv2.imread('./dataset/'+video+'/mask/output_'+'%05d'%(img_index)+'.png',0) - mask = impro.resize(mask,SIZE) + mask = impro.resize(mask,loadsize) mask = impro.mask_threshold(mask,15,128) input_img[:,:,-1] = mask ground_true = cv2.imread('./dataset/'+video+'/ori/output_'+'%05d'%(img_index)+'.png') - ground_true = impro.resize(ground_true,SIZE) + ground_true = impro.resize(ground_true,loadsize) - input_img,ground_true = random_transform(input_img,ground_true) + input_img,ground_true = random_transform(input_img,ground_true,finesize) input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False) ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False) @@ -135,8 +144,7 @@ input_imgs=[] ground_trues=[] load_cnt = 0 def preload(): - global load_cnt - load_cnt += 1 + global load_cnt while 1: try: input_img,ground_true = loaddata() @@ -145,6 +153,7 @@ def preload(): if len(input_imgs)>10: del(input_imgs[0]) del(ground_trues[0]) + load_cnt += 1 # time.sleep(0.1) except Exception as e: print("error:",e) @@ -153,8 +162,8 @@ import threading t = threading.Thread(target=preload,args=()) #t为新创建的线程 t.daemon = True t.start() -time.sleep(5) #wait frist load - +while load_cnt < 10: + time.sleep(0.1) netG.train() time_start=time.time() @@ -172,7 +181,7 @@ for iter in range(start_iter+1,ITER): netD.train() # print(input_img[0,3*N,:,:].size()) # print((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]).size()) - real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,SIZE,SIZE)), 1) + real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1) fake_AB = torch.cat((real_A, pred), 1) pred_fake = netD(fake_AB.detach()) loss_D_fake = criterionGAN(pred_fake, False) @@ -190,7 +199,7 @@ for iter in range(start_iter+1,ITER): netD.eval() # fake_AB = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], pred), 1) - real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,SIZE,SIZE)), 1) + real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1) fake_AB = torch.cat((real_A, pred), 1) pred_fake = netD(fake_AB) loss_G_GAN = criterionGAN(pred_fake, True)*lambda_gan @@ -229,8 +238,8 @@ for iter in range(start_iter+1,ITER): if (iter+1)%1000 == 0: time_end = time.time() if use_gan: - print('iter:',iter+1,' L1_loss:', round(loss_sum[0]/1000,3),' G_loss:', round(loss_sum[1]/1000,3), - ' D_f:',round(loss_sum[2]/1000,3),' D_r:',round(loss_sum[3]/1000,3),' time:',round((time_end-time_start)/1000,2)) + print('iter:',iter+1,' L1_loss:', round(loss_sum[0]/1000,4),' G_loss:', round(loss_sum[1]/1000,4), + ' D_f:',round(loss_sum[2]/1000,4),' D_r:',round(loss_sum[3]/1000,4),' time:',round((time_end-time_start)/1000,2)) if (iter+1)/1000 >= 10: loss_plot[0].append(loss_sum[0]/1000) loss_plot[1].append(loss_sum[1]/1000) @@ -243,7 +252,7 @@ for iter in range(start_iter+1,ITER): except Exception as e: print("error:",e) else: - print('iter:',iter+1,' L1_loss:',round(loss_sum[0]/1000,3),' time:',round((time_end-time_start)/1000,2)) + print('iter:',iter+1,' L1_loss:',round(loss_sum[0]/1000,4),' time:',round((time_end-time_start)/1000,2)) if (iter+1)/1000 >= 10: loss_plot[0].append(loss_sum[0]/1000) item_plot.append(iter+1) @@ -278,28 +287,28 @@ for iter in range(start_iter+1,ITER): #test netG.eval() - result = np.zeros((SIZE*2,SIZE*4,3), dtype='uint8') + result = np.zeros((finesize*2,finesize*4,3), dtype='uint8') test_names = os.listdir('./test') for cnt,test_name in enumerate(test_names,0): img_names = os.listdir(os.path.join('./test',test_name,'image')) - input_img = np.zeros((SIZE,SIZE,3*N+1), dtype='uint8') + input_img = np.zeros((finesize,finesize,3*N+1), dtype='uint8') img_names.sort() for i in range(0,N): img = impro.imread(os.path.join('./test',test_name,'image',img_names[i])) - img = impro.resize(img,SIZE) + img = impro.resize(img,finesize) input_img[:,:,i*3:(i+1)*3] = img mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray') - mask = impro.resize(mask,SIZE) + mask = impro.resize(mask,finesize) mask = impro.mask_threshold(mask,15,128) input_img[:,:,-1] = mask - result[0:SIZE,SIZE*cnt:SIZE*(cnt+1),:] = input_img[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3] + result[0:finesize,finesize*cnt:finesize*(cnt+1),:] = input_img[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3] input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False) pred = netG(input_img) pred = (pred.cpu().detach().numpy()*255)[0].transpose((1, 2, 0)) - result[SIZE:SIZE*2,SIZE*cnt:SIZE*(cnt+1),:] = pred + result[finesize:finesize*2,finesize*cnt:finesize*(cnt+1),:] = pred cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.png'), result) - netG.eval() \ No newline at end of file + netG.train() \ No newline at end of file -- GitLab