提交 a6994b52 编写于 作者: H hypox64

add InstanceNorm

上级 29458f1b
......@@ -141,6 +141,7 @@ test*/
video_tmp/
result/
#./
/pix2pix
/pix2pixHD
/tmp
/to_make_show
......
......@@ -97,7 +97,7 @@ def init_weights(net, init_type='normal', init_gain=0.02):
init.normal_(m.weight.data, 1.0, init_gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
#print('initialize network with %s' % init_type)
net.apply(init_func) # apply the initialization function <init_func>
......
......@@ -31,4 +31,4 @@ class UNet(nn.Module):
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return torch.sigmoid(x)
return torch.Tanh(x)
\ No newline at end of file
......@@ -4,13 +4,23 @@ import torch.nn.functional as F
from .unet_parts import *
from .pix2pix_model import *
# Normalization selection shared by the 2D/3D building blocks in this module.
# Set Norm to 'instance' to use InstanceNorm layers; any other value selects BatchNorm.
Norm = 'batch'
if Norm == 'instance':
NormLayer_2d = nn.InstanceNorm2d
NormLayer_3d = nn.InstanceNorm3d
# Convolutions placed directly in front of the norm layer are built without a bias term.
use_bias = False
else:
NormLayer_2d = nn.BatchNorm2d
NormLayer_3d = nn.BatchNorm3d
# NOTE(review): bias is enabled with batch norm and disabled with instance norm;
# pix2pix-style code usually chooses the opposite polarity — confirm this is intended.
use_bias = True
class encoder_2d(nn.Module):
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
"""
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=NormLayer_2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
"""Construct a Resnet-based generator
Parameters:
......@@ -55,7 +65,7 @@ class decoder_2d(nn.Module):
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
"""
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=NormLayer_2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
"""Construct a Resnet-based generator
Parameters:
......@@ -114,8 +124,8 @@ class conv_3d(nn.Module):
def __init__(self,inchannel,outchannel,kernel_size=3,stride=2,padding=1):
super(conv_3d, self).__init__()
self.conv = nn.Sequential(
nn.Conv3d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
nn.BatchNorm3d(outchannel),
nn.Conv3d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=padding, bias=use_bias),
NormLayer_3d(outchannel),
nn.ReLU(inplace=True),
)
......@@ -128,8 +138,8 @@ class conv_2d(nn.Module):
super(conv_2d, self).__init__()
self.conv = nn.Sequential(
nn.ReflectionPad2d(padding),
nn.Conv2d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=0, bias=False),
nn.BatchNorm2d(outchannel),
nn.Conv2d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=0, bias=use_bias),
NormLayer_2d(outchannel),
nn.ReLU(inplace=True),
)
......@@ -145,8 +155,8 @@ class encoder_3d(nn.Module):
self.down2 = conv_3d(64, 128, 3, 2, 1)
self.down3 = conv_3d(128, 256, 3, 1, 1)
self.conver2d = nn.Sequential(
nn.Conv2d(256*int(in_channel/4), 256, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.Conv2d(256*int(in_channel/4), 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
NormLayer_2d(256),
nn.ReLU(inplace=True),
)
......
......@@ -50,8 +50,8 @@ def Toinputshape(imgs,masks,finesize):
# print(imgs[i].shape,masks[i].shape)
img,mask = data.random_transform_image(imgs[i], masks[i], finesize)
# print(img.shape,mask.shape)
mask = mask.reshape(1,finesize,finesize)/255.0
img = img.transpose((2, 0, 1))/255.0
mask = (mask.reshape(1,finesize,finesize)/255.0-0.5)/0.5
img = (img.transpose((2, 0, 1))/255.0-0.5)/0.5
result_imgs.append(img)
result_masks.append(mask)
result_imgs = np.array(result_imgs)
......
......@@ -18,21 +18,21 @@ import torch.backends.cudnn as cudnn
N = 25
ITER = 10000000
LR = 0.001
LR = 0.0002
beta1 = 0.5
use_gpu = True
use_gan = False
use_L2 = True
CONTINUE = True
lambda_L1 = 1.0#100.0
lambda_gan = 1.0
CONTINUE = False
lambda_L1 = 100.0
lambda_gan = 1
SAVE_FRE = 10000
start_iter = 0
finesize = 128
loadsize = int(finesize*1.1)
batchsize = 8
perload_num = 32
batchsize = 1
perload_num = 16
savename = 'MosaicNet_test'
dir_checkpoint = 'checkpoints/'+savename
util.makedirs(dir_checkpoint)
......@@ -45,6 +45,7 @@ opt = Options().getparse()
videos = os.listdir('./dataset')
videos.sort()
lengths = []
print('check dataset...')
for video in videos:
video_images = os.listdir('./dataset/'+video+'/ori')
lengths.append(len(video_images))
......@@ -55,7 +56,8 @@ netG = video_model.MosaicNet(3*N+1, 3)
loadmodel.show_paramsnumber(netG,'netG')
# netG = unet_model.UNet(3*N+1, 3)
if use_gan:
netD = pix2pix_model.define_D(3*2+1, 64, 'basic', n_layers_D=3, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=[])
#netD = pix2pix_model.define_D(3*2+1, 64, 'pixel', norm='instance')
netD = pix2pix_model.define_D(3*2+1, 64, 'basic', norm='instance')
#netD = pix2pix_model.define_D(3*2+1, 64, 'n_layers', n_layers_D=5, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=[])
if CONTINUE:
......@@ -104,26 +106,19 @@ def loaddata():
return input_img,ground_true
print('preloading data, please wait 5s...')
# input_imgs=[]
# ground_trues=[]
input_imgs = torch.rand(batchsize,N*3+1,finesize,finesize).cuda()
ground_trues = torch.rand(batchsize,3,finesize,finesize).cuda()
if perload_num <= batchsize:
perload_num = batchsize*2
input_imgs = torch.rand(perload_num,N*3+1,finesize,finesize).cuda()
ground_trues = torch.rand(perload_num,3,finesize,finesize).cuda()
load_cnt = 0
def preload():
global load_cnt
while 1:
try:
# input_img,ground_true = loaddata()
# input_imgs.append(input_img)
# ground_trues.append(ground_true)
ran = random.randint(0, batchsize-1)
ran = random.randint(0, perload_num-1)
input_imgs[ran],ground_trues[ran] = loaddata()
# if len(input_imgs)>perload_num:
# del(input_imgs[0])
# del(ground_trues[0])
load_cnt += 1
# time.sleep(0.1)
except Exception as e:
......@@ -133,21 +128,24 @@ import threading
t = threading.Thread(target=preload,args=()) #t为新创建的线程
t.daemon = True
t.start()
while load_cnt < batchsize*2:
time_start=time.time()
while load_cnt < perload_num:
time.sleep(0.1)
time_end=time.time()
print('load speed:',round((time_end-time_start)/perload_num,3),'s/it')
util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py'))
util.copyfile('../../models/video_model.py', os.path.join(dir_checkpoint,'model.py'))
netG.train()
time_start=time.time()
print("Begin training...")
for iter in range(start_iter+1,ITER):
# inputdata,target = loaddata()
# ran = random.randint(1, perload_num-2)
# inputdata = inputdatas[ran]
# target = targets[ran]
inputdata = input_imgs.clone()
target = ground_trues.clone()
ran = random.randint(0, perload_num-batchsize-1)
inputdata = input_imgs[ran:ran+batchsize].clone()
target = ground_trues[ran:ran+batchsize].clone()
pred = netG(inputdata)
......@@ -262,13 +260,13 @@ for iter in range(start_iter+1,ITER):
netG.eval()
test_names = os.listdir('./test')
test_names.sort()
result = np.zeros((finesize*2,finesize*len(test_names),3), dtype='uint8')
for cnt,test_name in enumerate(test_names,0):
img_names = os.listdir(os.path.join('./test',test_name,'image'))
img_names.sort()
inputdata = np.zeros((finesize,finesize,3*N+1), dtype='uint8')
img_names.sort()
for i in range(0,N):
img = impro.imread(os.path.join('./test',test_name,'image',img_names[i]))
img = impro.resize(img,finesize)
......@@ -286,4 +284,4 @@ for iter in range(start_iter+1,ITER):
result[finesize:finesize*2,finesize*cnt:finesize*(cnt+1),:] = pred
cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.png'), result)
netG.train()
\ No newline at end of file
netG.train()
......@@ -74,11 +74,11 @@ def random_transform_video(src,target,finesize,N):
target = target[:,::-1,:]
#random color
alpha = random.uniform(-0.2,0.2)
alpha = random.uniform(-0.3,0.3)
beta = random.uniform(-0.2,0.2)
b = random.uniform(-0.1,0.1)
g = random.uniform(-0.1,0.1)
r = random.uniform(-0.1,0.1)
b = random.uniform(-0.05,0.05)
g = random.uniform(-0.05,0.05)
r = random.uniform(-0.05,0.05)
for i in range(N):
src[:,:,i*3:(i+1)*3] = color_adjust(src[:,:,i*3:(i+1)*3],alpha,beta,b,g,r)
target = color_adjust(target,alpha,beta,b,g,r)
......
......@@ -79,4 +79,10 @@ def get_bar(percent,num = 25):
else:
bar += '-'
bar += ']'
return bar+' '+str(round(percent,2))+'%'
\ No newline at end of file
return bar+' '+str(round(percent,2))+'%'
def copyfile(scr, dst):
    """Copy the file at ``scr`` to ``dst`` on a best-effort basis.

    Parameters:
        scr -- path of the file to copy (the historical spelling of the
               parameter name is kept so keyword callers keep working)
        dst -- path of the destination file

    Returns None. Any exception raised while copying is printed rather than
    propagated, so a failed copy never aborts the caller.
    """
    # Imported locally so the function does not depend on a module-level import.
    import shutil

    try:
        # The body previously referenced an undefined name ``src``; every call
        # raised NameError, which was swallowed below, so nothing was copied.
        shutil.copyfile(scr, dst)
    except Exception as e:
        # Deliberately broad: callers treat the copy as optional bookkeeping.
        print(e)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册