提交 d2beea9b 编写于 作者: H hypox64

update video_model

上级 0a634a51
......@@ -42,23 +42,6 @@ class encoder_2d(nn.Module):
nn.ReLU(True)]
#torch.Size([1, 256, 32, 32])
# mult = 2 ** n_downsampling
# for i in range(n_blocks): # add ResNet blocks
# model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
#torch.Size([1, 256, 32, 32])
# for i in range(n_downsampling): # add upsampling layers
# mult = 2 ** (n_downsampling - i)
# model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
# kernel_size=3, stride=2,
# padding=1, output_padding=1,
# bias=use_bias),
# norm_layer(int(ngf * mult / 2)),
# nn.ReLU(True)]
# model += [nn.ReflectionPad2d(3)]
# model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
# model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input):
......@@ -117,7 +100,7 @@ class decoder_2d(nn.Module):
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
# model += [nn.Tanh()]
model += [nn.Sigmoid()]
# model += [nn.Sigmoid()]
self.model = nn.Sequential(*model)
......@@ -147,7 +130,6 @@ class encoder_3d(nn.Module):
self.down1 = conv_3d(1, 64, 3, 2, 1)
self.down2 = conv_3d(64, 128, 3, 2, 1)
self.down3 = conv_3d(128, 256, 3, 1, 1)
# self.down4 = conv_3d(256, 512, 3, 2, 1)
self.conver2d = nn.Sequential(
nn.Conv2d(int(in_channel/4), 1, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(1),
......@@ -160,17 +142,13 @@ class encoder_3d(nn.Module):
x = self.down1(x)
x = self.down2(x)
x = self.down3(x)
# x = self.down4(x)
x = x.view(x.size(1),x.size(2),x.size(3),x.size(4))
x = self.conver2d(x)
x = x.view(x.size(1),x.size(0),x.size(2),x.size(3))
# print(x.size())
# x = self.avgpool(x)
return x
# input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'
class HypoNet(nn.Module):
......@@ -180,24 +158,31 @@ class HypoNet(nn.Module):
self.encoder_2d = encoder_2d(4,-1,64,n_blocks=9)
self.encoder_3d = encoder_3d(in_channel)
self.decoder_2d = decoder_2d(4,3,64,n_blocks=9)
self.merge = nn.Sequential(
nn.Conv2d(256, 256, 1, 1, 0, bias=False),
self.merge1 = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(512, 256, 3, 1, 0, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
)
self.merge2 = nn.Sequential(
nn.ReflectionPad2d(3),
nn.Conv2d(6, out_channel, kernel_size=7, padding=0),
nn.Sigmoid()
)
def forward(self, x):
N = int((x.size()[1])/3)
x_2d = torch.cat((x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], x[:,N-1:N,:,:]), 1)
shortcat_2d = x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]
x_2d = self.encoder_2d(x_2d)
x_3d = self.encoder_3d(x)
x = x_2d + x_3d
x = self.merge(x)
# print(x.size())
x = torch.cat((x_2d,x_3d),1)
x = self.merge1(x)
x = self.decoder_2d(x)
x = torch.cat((x,shortcat_2d),1)
x = self.merge2(x)
return x
......@@ -192,12 +192,6 @@ if use_gpu:
net.cuda()
# optimizer = optim.SGD(net.parameters(),
# lr=LR,
# momentum=0.9,
# weight_decay=0.0005)
optimizer = torch.optim.Adam(net.parameters(), lr=LR, betas=(0.9, 0.99))
criterion = nn.BCELoss()
......
......@@ -25,14 +25,15 @@ use_gpu = True
use_gan = False
use_L2 = False
CONTINUE = False
lambda_L1 = 100.0
lambda_L1 = 1.0#100.0
lambda_gan = 1.0
SAVE_FRE = 10000
start_iter = 0
SIZE = 128
finesize = 128
loadsize = int(finesize*1.1)
savename = 'MosaicNet'
savename = 'MosaicNet_test'
dir_checkpoint = 'checkpoints/'+savename
util.makedirs(dir_checkpoint)
......@@ -76,7 +77,15 @@ if use_gan:
optimizer_D = torch.optim.Adam(netG.parameters(), lr=LR,betas=(beta1, 0.999))
criterionGAN = pix2pix_model.GANLoss(gan_mode='lsgan').cuda()
def random_transform(src,target):
def random_transform(src,target,finesize):
#random crop
h,w = target.shape[:2]
h_move = int((h-finesize)*random.random())
w_move = int((w-finesize)*random.random())
# print(h,w,h_move,w_move)
target = target[h_move:h_move+finesize,w_move:w_move+finesize,:]
src = src[h_move:h_move+finesize,w_move:w_move+finesize,:]
#random flip
if random.random()<0.5:
......@@ -110,21 +119,21 @@ def loaddata():
video_index = random.randint(0,len(videos)-1)
video = videos[video_index]
img_index = random.randint(N,lengths[video_index]- N)
input_img = np.zeros((SIZE,SIZE,3*N+1), dtype='uint8')
input_img = np.zeros((loadsize,loadsize,3*N+1), dtype='uint8')
for i in range(0,N):
# print('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png')
img = cv2.imread('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png')
img = impro.resize(img,SIZE)
img = impro.resize(img,loadsize)
input_img[:,:,i*3:(i+1)*3] = img
mask = cv2.imread('./dataset/'+video+'/mask/output_'+'%05d'%(img_index)+'.png',0)
mask = impro.resize(mask,SIZE)
mask = impro.resize(mask,loadsize)
mask = impro.mask_threshold(mask,15,128)
input_img[:,:,-1] = mask
ground_true = cv2.imread('./dataset/'+video+'/ori/output_'+'%05d'%(img_index)+'.png')
ground_true = impro.resize(ground_true,SIZE)
ground_true = impro.resize(ground_true,loadsize)
input_img,ground_true = random_transform(input_img,ground_true)
input_img,ground_true = random_transform(input_img,ground_true,finesize)
input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
......@@ -135,8 +144,7 @@ input_imgs=[]
ground_trues=[]
load_cnt = 0
def preload():
global load_cnt
load_cnt += 1
global load_cnt
while 1:
try:
input_img,ground_true = loaddata()
......@@ -145,6 +153,7 @@ def preload():
if len(input_imgs)>10:
del(input_imgs[0])
del(ground_trues[0])
load_cnt += 1
# time.sleep(0.1)
except Exception as e:
print("error:",e)
......@@ -153,8 +162,8 @@ import threading
t = threading.Thread(target=preload,args=()) #t为新创建的线程
t.daemon = True
t.start()
time.sleep(5) #wait frist load
while load_cnt < 10:
time.sleep(0.1)
netG.train()
time_start=time.time()
......@@ -172,7 +181,7 @@ for iter in range(start_iter+1,ITER):
netD.train()
# print(input_img[0,3*N,:,:].size())
# print((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]).size())
real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,SIZE,SIZE)), 1)
real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1)
fake_AB = torch.cat((real_A, pred), 1)
pred_fake = netD(fake_AB.detach())
loss_D_fake = criterionGAN(pred_fake, False)
......@@ -190,7 +199,7 @@ for iter in range(start_iter+1,ITER):
netD.eval()
# fake_AB = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], pred), 1)
real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,SIZE,SIZE)), 1)
real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,finesize,finesize)), 1)
fake_AB = torch.cat((real_A, pred), 1)
pred_fake = netD(fake_AB)
loss_G_GAN = criterionGAN(pred_fake, True)*lambda_gan
......@@ -229,8 +238,8 @@ for iter in range(start_iter+1,ITER):
if (iter+1)%1000 == 0:
time_end = time.time()
if use_gan:
print('iter:',iter+1,' L1_loss:', round(loss_sum[0]/1000,3),' G_loss:', round(loss_sum[1]/1000,3),
' D_f:',round(loss_sum[2]/1000,3),' D_r:',round(loss_sum[3]/1000,3),' time:',round((time_end-time_start)/1000,2))
print('iter:',iter+1,' L1_loss:', round(loss_sum[0]/1000,4),' G_loss:', round(loss_sum[1]/1000,4),
' D_f:',round(loss_sum[2]/1000,4),' D_r:',round(loss_sum[3]/1000,4),' time:',round((time_end-time_start)/1000,2))
if (iter+1)/1000 >= 10:
loss_plot[0].append(loss_sum[0]/1000)
loss_plot[1].append(loss_sum[1]/1000)
......@@ -243,7 +252,7 @@ for iter in range(start_iter+1,ITER):
except Exception as e:
print("error:",e)
else:
print('iter:',iter+1,' L1_loss:',round(loss_sum[0]/1000,3),' time:',round((time_end-time_start)/1000,2))
print('iter:',iter+1,' L1_loss:',round(loss_sum[0]/1000,4),' time:',round((time_end-time_start)/1000,2))
if (iter+1)/1000 >= 10:
loss_plot[0].append(loss_sum[0]/1000)
item_plot.append(iter+1)
......@@ -278,28 +287,28 @@ for iter in range(start_iter+1,ITER):
#test
netG.eval()
result = np.zeros((SIZE*2,SIZE*4,3), dtype='uint8')
result = np.zeros((finesize*2,finesize*4,3), dtype='uint8')
test_names = os.listdir('./test')
for cnt,test_name in enumerate(test_names,0):
img_names = os.listdir(os.path.join('./test',test_name,'image'))
input_img = np.zeros((SIZE,SIZE,3*N+1), dtype='uint8')
input_img = np.zeros((finesize,finesize,3*N+1), dtype='uint8')
img_names.sort()
for i in range(0,N):
img = impro.imread(os.path.join('./test',test_name,'image',img_names[i]))
img = impro.resize(img,SIZE)
img = impro.resize(img,finesize)
input_img[:,:,i*3:(i+1)*3] = img
mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray')
mask = impro.resize(mask,SIZE)
mask = impro.resize(mask,finesize)
mask = impro.mask_threshold(mask,15,128)
input_img[:,:,-1] = mask
result[0:SIZE,SIZE*cnt:SIZE*(cnt+1),:] = input_img[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3]
result[0:finesize,finesize*cnt:finesize*(cnt+1),:] = input_img[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3]
input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False)
pred = netG(input_img)
pred = (pred.cpu().detach().numpy()*255)[0].transpose((1, 2, 0))
result[SIZE:SIZE*2,SIZE*cnt:SIZE*(cnt+1),:] = pred
result[finesize:finesize*2,finesize*cnt:finesize*(cnt+1),:] = pred
cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.png'), result)
netG.eval()
\ No newline at end of file
netG.train()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册