From 9de9141454b4d249b88b9de493484a4d8a43757e Mon Sep 17 00:00:00 2001 From: hypox64 Date: Sun, 12 Jan 2020 23:59:14 +0800 Subject: [PATCH] just commit, unstable --- cores/core.py | 9 +- make_datasets/csv/video_used_time.csv | 22 +- make_datasets/cut_video.py | 9 +- .../use_addmosaic_model_make_video_dataset.py | 115 +++---- models/loadmodel.py | 12 +- models/video_model.py | 203 ++++++++++++ models/video_model_unet.py | 108 +++++++ train/add/train.py | 259 +++++++++++++++ train/clean/train.py | 305 ++++++++++++++++++ train/train.py | 209 ------------ util/image_processing.py | 4 +- 11 files changed, 978 insertions(+), 277 deletions(-) create mode 100644 models/video_model.py create mode 100644 models/video_model_unet.py create mode 100644 train/add/train.py create mode 100644 train/clean/train.py delete mode 100644 train/train.py diff --git a/cores/core.py b/cores/core.py index c544a40..57bb510 100644 --- a/cores/core.py +++ b/cores/core.py @@ -104,10 +104,11 @@ def cleanmosaic_video_byframe(opt): os.path.join(opt.result_dir,os.path.splitext(os.path.basename(path))[0]+'_clean.mp4')) def cleanmosaic_video_fusion(opt): - net = loadmodel.pix2pix(opt) + net = loadmodel.video(opt) net_mosaic_pos = loadmodel.unet_clean(opt) path = opt.media_path N = 25 + INPUT_SIZE = 128 util.clean_tempfiles() fps = ffmpeg.get_video_infos(path)[0] @@ -140,15 +141,15 @@ def cleanmosaic_video_fusion(opt): if size==0: cv2.imwrite(os.path.join('./tmp/replace_mosaic',imagepath),img_origin) else: - mosaic_input = np.zeros((256,256,3*N+1), dtype='uint8') + mosaic_input = np.zeros((INPUT_SIZE,INPUT_SIZE,3*N+1), dtype='uint8') for j in range(0,N): img = impro.imread(os.path.join('./tmp/video2image',imagepaths[np.clip(i+j-12,0,len(imagepaths)-1)])) img = img[y-size:y+size,x-size:x+size] - img = impro.resize(img,256) + img = impro.resize(img,INPUT_SIZE) mosaic_input[:,:,j*3:(j+1)*3] = img mask = impro.resize(mask,np.min(img_origin.shape[:2])) mask = mask[y-size:y+size,x-size:x+size] - mask = impro.resize(mask, 256) + mask = impro.resize(mask, INPUT_SIZE) mosaic_input[:,:,-1] = mask mosaic_input = data.im2tensor(mosaic_input,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False) unmosaic_pred = net(mosaic_input) diff --git a/make_datasets/csv/video_used_time.csv b/make_datasets/csv/video_used_time.csv index 6dee6b2..3250579 100644 --- a/make_datasets/csv/video_used_time.csv +++ b/make_datasets/csv/video_used_time.csv @@ -17,4 +17,24 @@ 1pondo_070315_108_1080p.mp4,00:11:10,00:11:50,00:13:50,00:14:20,00:14:35,00:15:50,00:17:20,00:18:35,00:20:45,00:24:35,00:25:05,00:29:15,00:30:40,00:31:55,00:35:20,00:42:55,00:43:05,00:46:15,00:48:00,00:51:45,00:52:33,00:54:20,00:59:25,00:59:40,01:00:05 071114_842-1pon-whole1_hd.mp4,00:09:50,00:11:25,00:16:35,00:18:20,00:22:10,00:25:25,00:26:35,00:33:50,00:35:40,00:43:10 071715_116-1pon-1080p.mp4,00:10:50,00:11:30,00:12:50,00:15:10,00:16:45,00:17:05,00:25:20,00:26:45,00:28:30,00:30:20,00:32:55,00:34:30,00:37:40,00:38:40,00:40:20,00:41:20,00:44:10,00:47:15,00:55:00,00:59:40,00:59:50 -071815_117-1pon-1080p.mp4,00:14:50,00:15:10,00:18:05,00:14:50,00:25:55,00:26:25,00:32:45,00:33:40,00:43:15,00:45:05,00:45:45,00:48:40,00:48:50,00:55:45,10:00:20,01:00:35,01:01:00,01:01:10 \ No newline at end of file +071815_117-1pon-1080p.mp4,00:14:50,00:15:10,00:18:05,00:14:50,00:25:55,00:26:25,00:32:45,00:33:40,00:43:15,00:45:05,00:45:45,00:48:40,00:48:50,00:55:45,10:00:20,01:00:35,01:01:00,01:01:10 
+080815_130-1pon-1080p,00:14:50,00:17:15,00:17:20,00:23:55,00:25:30,00:25:55,00:28:20,00:28:30,00:30:10,00:31:00,00:33:25,00:33:35,00:33:45,00:33:50,00:39:25,00:39:50,00:40:25,00:44:05,00:45:00,00:45:40,00:45:50,00:46:55,00:49:15,00:49:25,00:46:40,00:50:10,00:50:15,00:51:25,00:51:50,00:53:14,00:53:20,00:54:15,00:56:15,00:56:25,00:56:45,00:57:45,00:57:30,00:58:00,00:56:45,00:56:55,01:00:00,01:00:05,01:00:25,01:00:30 +081514_863-1pon-whole1_hd.avi,00:10:30,00:26:00,00:30:00,00:38:21,00:40:15,00:40:30,00:49:10,00:50:05,00:57:10,00:59:00 +090614_877-1pon-whole1_hd.mp4,00:04:45,00:05:15,00:12:25,00:12:40,00:15:00,00:15:15,00:16:25,00:20:50,00:21:45,00:26:10,00:33:35,00:35:55,00:37:50,00:37:55,00:38:12,00:39:55,00:41:50,00:44:27,00:44:37,00:46:30,00:47:35,00:47:40,00:48:20,00:59:50 +091215_152-1pon-1080p.mp4,00:05:30,00:06:10,00:06:20,00:08:15,00:10:10,00:11:15,00:12:15,00:12:55,0:15:15,00:15:35,00:18:00,00:24:45,00:25:45,00:33:45,00:35:32,00:37:35,00:37:55,00:38:50,00:42:15,00:45:00,00:47:55,00:48:20,00:48:35,00:48:42,00:49:43,00:50:15,00:51:10,00:55:35,00:57:00,00:57:55,01:03:30,01:05:00 +092813_670-1pon-whole1_hd.avi,00:16:32,00:19:00,00:22:10,00:23:20,00:23:40,00:30:20,00:32:00,00:35:00,00:36:50,00:41:40,00:44:50,00:52:45,00:54:00 +103015_180-1pon-1080p.mp4,00:24:50,00:31:25,00:41:20,00:48:10,00:48:50,00:49:20,00:50:15,00:52:45,00:53:30,01:02:40,01:03:35,01:09:50,01:15:05,01:16:50 +110615_185-1pon-1080p.mp4,00:15:00,00:15:40,00:34:15,00:34:50,00:35:30,00:37:05,00:39:35,00:40:30,00:41:40,00:47:35,00:50:15,00:51:01,00:51:35,00:54:15,00:55:40,00:55:50,00:57:20,00:59:35,01:00:00,01:00:25 +120310_979-1pon-whole1_hd.avi,00:15:10,00:14:25,00:14:30,00:14:50,00:15:45,00:16:35,00:16:55,00:17:25,00:19:25,00:20:45,00:27:05,00:30:17,00:32:00,00:33:50,00:35:45,00:38:55,00:40:25,00:40:40,00:41:10,00:42:50,00:44:35,00:45:15,00:46:15,00:48:00,00:49:10,00:50:10,00:54:00,00:55:23,00:55:30,00:55:50 +021315-806-carib-1080p.mp4,00:13:30,00:15:20,00:17:40,00:21:50,00:22:25,00:24:35,00:28:50,00:28:52,00:31:00,00:37:25,00:37:35,00:38:20,00:38:45,00:43:30,00:48:35,00:51:30,00:51:50,00:52:19,00:56:20,00:58:35 +021715-809-carib-1080p.mp4,00:17:30,00:20:35,00:21:00,00:22:00,00:23:55,00:24:15,00:28:40,00:37:20,00:39:05,00:40:05,00:40:50,00:42:45,00:45:00,00:46:40,00:48:00,00:48:20,00:51:30,00:52:10,00:53:35,00:54:10,00:54:20,00:56:45,00:56:55,00:59:10,00:59:35,00:59:55 +022715-817-carib-1080p.mp4,00:57:52,00:08:50,00:10:00,00:12:50,00:14:05,00:18:25,00:20:45,00:20:57,00:22:15,00:23:30,00:23:55,00:24:18,00:24:50,00:25:25,00:26:30,00:26:55,00:28:50,00:31:55,00:34:00,00:34:35,00:42:45,00:44:33 +030914-558-carib-high_1.mp4,00:10:45,00:12:45,00:14:40,00:16:33,00:19:40,00:21:35,00:21:55,00:23:05,00:26:15,00:27:30,00:29:55,00:31:10,00:31:40,00:36:40,00:41:40,00:42:40,00:44:50,00:49:50,00:52:25,00:53:50,00:54:30,00:55:20,00:55:10,00:57:05,00:57:25,00:59:05,01:00:15,01:02:11,01:03:55,01:05:10 +031815-830-carib-1080p.mp4,00:13:15,00:13:25,00:13:55,00:14:40,00:15:40,00:17:30,00:18:20,00:19:10,00:21:00,00:22:10,00:22:25,00:23:25,00:27:10,00:28:33,00:35:05,00:35:40,00:37:50,00:38:00,00:39:35,00:41:35,00:42:40,00:47:40,00:50:33,00:55:50,01:02:10,01:05:20,01:05:30 +032016-121-carib-1080p.mp4,00:27:20,00:28:40,00:28:55,00:30:35,00:36:10,00:39:10,00:40:30,00:43:00,00:46:05,00:50:00,00:56:05,00:56:20,00:59:20 +032913-301-carib-whole_hd1.wmv,00:06:00,00:09:40,00:11:00,00:13:00,00:15:05,00:16:40,00:18:05,00:20:00,00:39:31,00:34:35,00:44:50,00:47:25,00:49:50,00:51:20,00:54:58,00:56:55,00:59:50,01:00:50 
+032914-571-carib-high_1.mp4,00:13:30,00:13:55,00:16:40,00:15:25,00:20:40,00:26:45,00:32:05,00:33:15,00:36:40,00:38:55,00:39:00,00:39:25,00:47:30,00:49:20 +042514-588-carib-high_1.mp4,00:10:30,00:11:15,00:19:15,00:20:00,00:20:30,00:22:05,00:22:45,00:22:53,00:24:15,00:30:50,00:32:25,00:34:15,00:34:45,00:34:55,0:36:05,00:37:20,00:37:40,00:38:30,00:39:35,00:41:00,00:43:30,00:43:40 +052315-884-carib-1080p.mp4,00:09:35,00:14:10,00:14:30,00:14:40,00:17:10,00:17:50,00:19:00,00:20:20,01:21:55,00:22:40,00:23:05,00:24:00,00:26:00,00:27:15,00:30:25,00:32:50,00:37:55,0:39:35,00:40:10,00:41:40,00:43:15,00:43:40,00:47:55,00:49:30,00:49:55,00:58:55,01:00:40 +053114-612-carib-high_1.mp4,00:08:35,00:13:35,00:15:25,00:16:40,00:20:35,00:22:25,00:26:10,00:29:10,00:32:55,00:34:10,00:37:05,00:37:40,00:39:40,00:40:52,00:42:08,00:42:15 +062615-908-carib-1080p.mp4,00:13:45,00:14:40,00:15:45,00:16:11,00:17:00,00:22:10,00:23:40,00:26:10,00:27:15,00:27:50,00:31:30,00:35:00,00:40:20,00:43:10,00:44:35,00:47:17,00:50:25,00:51:15,00:52:20,00:54:10,00:55:30,01:00:20 \ No newline at end of file diff --git a/make_datasets/cut_video.py b/make_datasets/cut_video.py index 7b4d317..6e3ad1a 100644 --- a/make_datasets/cut_video.py +++ b/make_datasets/cut_video.py @@ -13,21 +13,20 @@ files = util.Traversal('/media/hypo/Media/download') videos = util.is_videos(files) -video_times = [] + useable_videos = [] video_dict = {} reader = csv.reader(open('./csv/video_used_time.csv')) for line in reader: useable_videos.append(line[0]) - video_times.append(line[1:]) video_dict[line[0]]=line[1:] in_cnt = 0 -out_cnt = 502 +out_cnt = 1 for video in videos: if os.path.basename(video) in useable_videos: - # print(video) - for i in range(len(video_times[in_cnt])): + + for i in range(len(video_dict[os.path.basename(video)])): ffmpeg.cut_video(video, video_dict[os.path.basename(video)][i], '00:00:05', './video/'+'%04d'%out_cnt+'.mp4') out_cnt +=1 in_cnt += 1 diff --git a/make_datasets/use_addmosaic_model_make_video_dataset.py b/make_datasets/use_addmosaic_model_make_video_dataset.py index 1e715da..af5b3b3 100644 --- a/make_datasets/use_addmosaic_model_make_video_dataset.py +++ b/make_datasets/use_addmosaic_model_make_video_dataset.py @@ -8,72 +8,77 @@ sys.path.append("..") from models import runmodel,loadmodel from util import mosaic,util,ffmpeg,filt from util import image_processing as impro -from options import Options +from cores import options -opt = Options().getparse() +opt = options.Options().getparse() util.file_init(opt) videos = os.listdir('./video') videos.sort() opt.model_path = '../pretrained_models/add_youknow_128.pth' opt.use_gpu = True +Ex = 1.4 +Area_Type = 'normal' +suffix = '' net = loadmodel.unet(opt) for path in videos: + try: + path = os.path.join('./video',path) + util.clean_tempfiles() + ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3') + ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type) + imagepaths=os.listdir('./tmp/video2image') + imagepaths.sort() - path = os.path.join('./video',path) - util.clean_tempfiles() - ffmpeg.video2voice(path,'./tmp/voice_tmp.mp3') - ffmpeg.video2image(path,'./tmp/video2image/output_%05d.'+opt.tempimage_type) - imagepaths=os.listdir('./tmp/video2image') - imagepaths.sort() + # get position + positions = [] + img_ori_example = impro.imread(os.path.join('./tmp/video2image',imagepaths[0])) + mask_avg = np.zeros((impro.resize(img_ori_example, 128)).shape[:2]) + for imagepath in imagepaths: + imagepath = os.path.join('./tmp/video2image',imagepath) + print('Find ROI 
location:',imagepath) + img = impro.imread(imagepath) + x,y,size,mask = runmodel.get_mosaic_position(img,net,opt,threshold = 64) + cv2.imwrite(os.path.join('./tmp/ROI_mask', + os.path.basename(imagepath)),mask) + positions.append([x,y,size]) + mask_avg = mask_avg + mask + print('Optimize ROI locations...') + mask_index = filt.position_medfilt(np.array(positions), 13) - # get position - positions = [] - img_ori_example = impro.imread(os.path.join('./tmp/video2image',imagepaths[0])) - mask_avg = np.zeros((impro.resize(img_ori_example, 128)).shape[:2]) - for imagepath in imagepaths: - imagepath = os.path.join('./tmp/video2image',imagepath) - print('Find ROI location:',imagepath) - img = impro.imread(imagepath) - x,y,size,mask = runmodel.get_mosaic_position(img,net,opt,threshold = 64) - cv2.imwrite(os.path.join('./tmp/ROI_mask', - os.path.basename(imagepath)),mask) - positions.append([x,y,size]) - mask_avg = mask_avg + mask - print('Optimize ROI locations...') - mask_index = filt.position_medfilt(np.array(positions), 13) + mask = np.clip(mask_avg/len(imagepaths),0,255).astype('uint8') + mask = impro.mask_threshold(mask,20,32) + x,y,size,area = impro.boundingSquare(mask,Ex_mul=Ex) + rat = min(img_ori_example.shape[:2])/128.0 + x,y,size = int(rat*x),int(rat*y),int(rat*size) + cv2.imwrite(os.path.join('./tmp/ROI_mask_check', + 'test_show.png'),mask) + if size !=0 : + mask_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+suffix+'/mask' + ori_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+suffix+'/ori' + mosaic_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+suffix+'/mosaic' + os.makedirs('./dataset/'+os.path.splitext(os.path.basename(path))[0]+suffix) + os.makedirs(mask_path) + os.makedirs(ori_path) + os.makedirs(mosaic_path) + print('Add mosaic to images...') + mosaic_size = mosaic.get_autosize(img_ori_example,mask,area_type = Area_Type)*random.uniform(1,2) + models = ['squa_avg','rect_avg','squa_mid'] + mosaic_type = random.randint(0,len(models)-1) + rect_rat = random.uniform(1.2,1.6) + for i in range(len(imagepaths)): + mask = impro.imread(os.path.join('./tmp/ROI_mask',imagepaths[mask_index[i]]),mod = 'gray') + img_ori = impro.imread(os.path.join('./tmp/video2image',imagepaths[i])) + img_mosaic = mosaic.addmosaic_normal(img_ori,mask,mosaic_size,model = models[mosaic_type],rect_rat=rect_rat) + mask = impro.resize(mask, min(img_ori.shape[:2])) - mask = np.clip(mask_avg/len(imagepaths),0,255).astype('uint8') - mask = impro.mask_threshold(mask,20,32) - x,y,size,area = impro.boundingSquare(mask,Ex_mul=1.5) - rat = min(img_ori_example.shape[:2])/128.0 - x,y,size = int(rat*x),int(rat*y),int(rat*size) - cv2.imwrite(os.path.join('./tmp/ROI_mask_check', - 'test_show.png'),mask) - if size !=0 : - mask_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+'/mask' - ori_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+'/ori' - mosaic_path = './dataset/'+os.path.splitext(os.path.basename(path))[0]+'/mosaic' - os.makedirs('./dataset/'+os.path.splitext(os.path.basename(path))[0]+'') - os.makedirs(mask_path) - os.makedirs(ori_path) - os.makedirs(mosaic_path) - print('Add mosaic to images...') - mosaic_size = mosaic.get_autosize(img_ori_example,mask,area_type = 'bounding')*random.uniform(1,2) - models = ['squa_avg','rect_avg','squa_mid'] - mosaic_type = random.randint(0,len(models)-1) - rect_rat = random.uniform(1.2,1.6) - for i in range(len(imagepaths)): - mask = impro.imread(os.path.join('./tmp/ROI_mask',imagepaths[mask_index[i]])) - 
img_ori = impro.imread(os.path.join('./tmp/video2image',imagepaths[i])) - img_mosaic = mosaic.addmosaic_normal(img_ori,mask,mosaic_size,model = models[mosaic_type],rect_rat=rect_rat) - mask = impro.resize(mask, min(img_ori.shape[:2])) + img_ori_crop = impro.resize(img_ori[y-size:y+size,x-size:x+size],256) + img_mosaic_crop = impro.resize(img_mosaic[y-size:y+size,x-size:x+size],256) + mask_crop = impro.resize(mask[y-size:y+size,x-size:x+size],256) - img_ori_crop = impro.resize(img_ori[y-size:y+size,x-size:x+size],256) - img_mosaic_crop = impro.resize(img_mosaic[y-size:y+size,x-size:x+size],256) - mask_crop = impro.resize(mask[y-size:y+size,x-size:x+size],256) - - cv2.imwrite(os.path.join(ori_path,os.path.basename(imagepaths[i])),img_ori_crop) - cv2.imwrite(os.path.join(mosaic_path,os.path.basename(imagepaths[i])),img_mosaic_crop) - cv2.imwrite(os.path.join(mask_path,os.path.basename(imagepaths[i])),mask_crop) \ No newline at end of file + cv2.imwrite(os.path.join(ori_path,os.path.basename(imagepaths[i])),img_ori_crop) + cv2.imwrite(os.path.join(mosaic_path,os.path.basename(imagepaths[i])),img_mosaic_crop) + cv2.imwrite(os.path.join(mask_path,os.path.basename(imagepaths[i])),mask_crop) + except Exception as e: + print(e) \ No newline at end of file diff --git a/models/loadmodel.py b/models/loadmodel.py index 8126ee8..2e91a7c 100755 --- a/models/loadmodel.py +++ b/models/loadmodel.py @@ -2,13 +2,12 @@ import torch from .pix2pix_model import define_G from .pix2pixHD_model import define_G as define_G_HD from .unet_model import UNet +from .video_model import HypoNet def pix2pix(opt): # print(opt.model_path,opt.netG) if opt.netG == 'HD': netG = define_G_HD(3, 3, 64, 'global' ,4) - elif opt.netG == 'video': - netG = define_G(3*25+1, 3, 128, 'unet_128', norm='instance',use_dropout=True, init_type='normal', gpu_ids=[]) else: netG = define_G(3, 3, 64, opt.netG, norm='batch',use_dropout=True, init_type='normal', gpu_ids=[]) @@ -18,6 +17,15 @@ def pix2pix(opt): netG.cuda() return netG +def video(opt): + netG = HypoNet(3*25+1, 3) + netG.load_state_dict(torch.load(opt.model_path)) + netG.eval() + if opt.use_gpu: + netG.cuda() + return netG + + def unet_clean(opt): net = UNet(n_channels = 3, n_classes = 1) net.load_state_dict(torch.load(opt.mosaic_position_model_path)) diff --git a/models/video_model.py b/models/video_model.py new file mode 100644 index 0000000..011e8c8 --- /dev/null +++ b/models/video_model.py @@ -0,0 +1,203 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .unet_parts import * +from .pix2pix_model import * + +class encoder_2d(nn.Module): + """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. 
+ + We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): + """Construct a Resnet-based generator + + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero + """ + assert(n_blocks >= 0) + super(encoder_2d, self).__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), + norm_layer(ngf), + nn.ReLU(True)] + + n_downsampling = 2 + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), + norm_layer(ngf * mult * 2), + nn.ReLU(True)] + #torch.Size([1, 256, 32, 32]) + + # mult = 2 ** n_downsampling + # for i in range(n_blocks): # add ResNet blocks + # model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] + #torch.Size([1, 256, 32, 32]) + + # for i in range(n_downsampling): # add upsampling layers + # mult = 2 ** (n_downsampling - i) + # model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + # kernel_size=3, stride=2, + # padding=1, output_padding=1, + # bias=use_bias), + # norm_layer(int(ngf * mult / 2)), + # nn.ReLU(True)] + # model += [nn.ReflectionPad2d(3)] + # model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + # model += [nn.Tanh()] + + self.model = nn.Sequential(*model) + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class decoder_2d(nn.Module): + """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. 
+ + We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): + """Construct a Resnet-based generator + + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero + """ + super(decoder_2d, self).__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + model = [] + + n_downsampling = 2 + mult = 2 ** n_downsampling + for i in range(n_blocks): # add ResNet blocks + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] + #torch.Size([1, 256, 32, 32]) + + for i in range(n_downsampling): # add upsampling layers + mult = 2 ** (n_downsampling - i) + # model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + # kernel_size=3, stride=2, + # padding=1, output_padding=1, + # bias=use_bias), + # norm_layer(int(ngf * mult / 2)), + # nn.ReLU(True)] + #https://distill.pub/2016/deconv-checkerboard/ + #https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/190 + + model += [ nn.Upsample(scale_factor = 2, mode='nearest'), + nn.ReflectionPad2d(1), + nn.Conv2d(ngf * mult, int(ngf * mult / 2),kernel_size=3, stride=1, padding=0), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + model += [nn.ReflectionPad2d(3)] + model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + # model += [nn.Tanh()] + model += [nn.Sigmoid()] + + self.model = nn.Sequential(*model) + + def forward(self, input): + """Standard forward""" + return self.model(input) + + + +class conv_3d(nn.Module): + def __init__(self,inchannel,outchannel,kernel_size=3,stride=2,padding=1): + super(conv_3d, self).__init__() + self.conv = nn.Sequential( + nn.Conv3d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=padding, bias=False), + nn.BatchNorm3d(outchannel), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + x = self.conv(x) + return x + + +class encoder_3d(nn.Module): + def __init__(self,in_channel): + super(encoder_3d, self).__init__() + self.down1 = conv_3d(1, 64, 3, 2, 1) + self.down2 = conv_3d(64, 128, 3, 2, 1) + self.down3 = conv_3d(128, 256, 3, 1, 1) + # self.down4 = conv_3d(256, 512, 3, 2, 1) + self.conver2d = nn.Sequential( + nn.Conv2d(int(in_channel/4), 1, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(1), + nn.ReLU(inplace=True), + ) + + + def forward(self, x): + x = x.view(x.size(0),1,x.size(1),x.size(2),x.size(3)) + x = self.down1(x) + x = self.down2(x) + x = self.down3(x) + # x = self.down4(x) + + + x = x.view(x.size(1),x.size(2),x.size(3),x.size(4)) + x = self.conver2d(x) + x = x.view(x.size(1),x.size(0),x.size(2),x.size(3)) + # print(x.size()) + # x = self.avgpool(x) + return x + +# input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect' + + +class HypoNet(nn.Module): + def __init__(self, in_channel, out_channel): + super(HypoNet, self).__init__() + + self.encoder_2d = 
encoder_2d(4,-1,64,n_blocks=9) + self.encoder_3d = encoder_3d(in_channel) + self.decoder_2d = decoder_2d(4,3,64,n_blocks=9) + self.merge = nn.Sequential( + nn.Conv2d(256, 256, 1, 1, 0, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + + N = int((x.size()[1])/3) + x_2d = torch.cat((x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], x[:,N-1:N,:,:]), 1) + + x_2d = self.encoder_2d(x_2d) + x_3d = self.encoder_3d(x) + x = x_2d + x_3d + x = self.merge(x) + # print(x.size()) + x = self.decoder_2d(x) + + + return x + diff --git a/models/video_model_unet.py b/models/video_model_unet.py new file mode 100644 index 0000000..8e338b8 --- /dev/null +++ b/models/video_model_unet.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .unet_parts import * + + +class conv_3d(nn.Module): + def __init__(self,inchannel,outchannel,kernel_size=3,stride=2,padding=1): + super(conv_3d, self).__init__() + self.conv = nn.Sequential( + nn.Conv3d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=padding, bias=False), + nn.BatchNorm3d(outchannel), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + x = self.conv(x) + return x + + +class encoder_3d(nn.Module): + def __init__(self,in_channel): + super(encoder_3d, self).__init__() + self.down1 = conv_3d(1, 64, 3, 2, 1) + self.down2 = conv_3d(64, 128, 3, 2, 1) + self.down3 = conv_3d(128, 256, 3, 2, 1) + self.down4 = conv_3d(256, 512, 3, 2, 1) + self.conver2d = nn.Sequential( + nn.Conv2d(int(in_channel/16)+1, 1, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(1), + nn.ReLU(inplace=True), + ) + + + def forward(self, x): + x = x.view(x.size(0),1,x.size(1),x.size(2),x.size(3)) + x = self.down1(x) + x = self.down2(x) + x = self.down3(x) + x = self.down4(x) + x = x.view(x.size(1),x.size(2),x.size(3),x.size(4)) + x = self.conver2d(x) + x = x.view(x.size(1),x.size(0),x.size(2),x.size(3)) + # print(x.size()) + # x = self.avgpool(x) + return x + + + + +class encoder_2d(nn.Module): + def __init__(self, in_channel): + super(encoder_2d, self).__init__() + self.inc = inconv(in_channel, 64) + self.down1 = down(64, 128) + self.down2 = down(128, 256) + self.down3 = down(256, 512) + self.down4 = down(512, 512) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + + return x1,x2,x3,x4,x5 + +class decoder_2d(nn.Module): + def __init__(self, out_channel): + super(decoder_2d, self).__init__() + self.up1 = up(1024, 256,bilinear=False) + self.up2 = up(512, 128,bilinear=False) + self.up3 = up(256, 64,bilinear=False) + self.up4 = up(128, 64,bilinear=False) + self.outc = outconv(64, out_channel) + + def forward(self,x5,x4,x3,x2,x1): + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + x = self.outc(x) + + return x + + +class HypoNet(nn.Module): + def __init__(self, in_channel, out_channel): + super(HypoNet, self).__init__() + + self.encoder_2d = encoder_2d(4) + self.encoder_3d = encoder_3d(in_channel) + self.decoder_2d = decoder_2d(out_channel) + + def forward(self, x): + + N = int((x.size()[1])/3) + x_2d = torch.cat((x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], x[:,N-1:N,:,:]), 1) + # print(x_2d.size()) + x_3d = self.encoder_3d(x) + + x1,x2,x3,x4,x5 = self.encoder_2d(x_2d) + x5 = x5 + x_3d + x_2d = self.decoder_2d(x5,x4,x3,x2,x1) + + return x_2d + diff --git a/train/add/train.py b/train/add/train.py new file mode 100644 index 0000000..c1845cf --- /dev/null +++ 
b/train/add/train.py @@ -0,0 +1,259 @@ +import sys +import os +import random +import datetime + +import numpy as np +import cv2 + +import torch +import torch.backends.cudnn as cudnn +import torch.nn as nn +from torch import optim + +from unet import UNet + +def resize(img,size): + h, w = img.shape[:2] + if w >= h: + res = cv2.resize(img,(int(size*w/h), size)) + else: + res = cv2.resize(img,(size, int(size*h/w))) + return res + + +def Totensor(img,use_gpu=True): + size=img.shape[0] + img = torch.from_numpy(img).float() + if use_gpu: + img = img.cuda() + return img + +def random_color(img,random_num): + for i in range(3): img[:,:,i]=np.clip(img[:,:,i].astype('int')+random.randint(-random_num,random_num),0,255).astype('uint8') + bright = random.randint(-random_num*2,random_num*2) + for i in range(3): img[:,:,i]=np.clip(img[:,:,i].astype('int')+bright,0,255).astype('uint8') + return img + +def Toinputshape(imgs,masks,finesize): + batchsize = len(imgs) + result_imgs=[];result_masks=[] + for i in range(batchsize): + # print(imgs[i].shape,masks[i].shape) + img,mask = random_transform(imgs[i], masks[i], finesize) + # print(img.shape,mask.shape) + mask = mask[:,:,0].reshape(1,finesize,finesize)/255.0 + img = img.transpose((2, 0, 1))/255.0 + result_imgs.append(img) + result_masks.append(mask) + result_imgs = np.array(result_imgs) + result_masks = np.array(result_masks) + return result_imgs,result_masks + + + +def random_transform(img,mask,finesize): + + + # randomsize = int(finesize*(1.2+0.2*random.random())+2) + + h,w = img.shape[:2] + loadsize = min((h,w)) + a = (float(h)/float(w))*random.uniform(0.9, 1.1) + + if h10: + del(input_imgs[0]) + del(ground_trues[0]) + # time.sleep(0.1) + except Exception as e: + print("error:",e) + +import threading +t = threading.Thread(target=preload,args=()) # t is the newly created preloading thread +t.daemon = True +t.start() +time.sleep(5) # wait for the first batch to be preloaded + + +netG.train() +time_start=time.time() +print("Begin training...") +for iter in range(start_iter+1,ITER): + + # input_img,ground_true = loaddata() + ran = random.randint(1, 8) + input_img = input_imgs[ran] + ground_true = ground_trues[ran] + + pred = netG(input_img) + + if use_gan: + netD.train() + # print(input_img[0,3*N,:,:].size()) + # print((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:]).size()) + real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,SIZE,SIZE)), 1) + fake_AB = torch.cat((real_A, pred), 1) + pred_fake = netD(fake_AB.detach()) + loss_D_fake = criterionGAN(pred_fake, False) + + real_AB = torch.cat((real_A, ground_true), 1) + pred_real = netD(real_AB) + loss_D_real = criterionGAN(pred_real, True) + loss_D = (loss_D_fake + loss_D_real) * 0.5 + loss_sum[2] += loss_D_fake.item() + loss_sum[3] += loss_D_real.item() + + optimizer_D.zero_grad() + loss_D.backward() + optimizer_D.step() + netD.eval() + + # fake_AB = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], pred), 1) + real_A = torch.cat((input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], input_img[:,-1,:,:].reshape(-1,1,SIZE,SIZE)), 1) + fake_AB = torch.cat((real_A, pred), 1) + pred_fake = netD(fake_AB) + loss_G_GAN = criterionGAN(pred_fake, True)*lambda_gan + # Second, G(A) = B + if use_L2: + loss_G_L1 = (criterion_L1(pred, ground_true)+criterion_L2(pred, ground_true)) * lambda_L1 + else: + loss_G_L1 = criterion_L1(pred, ground_true) * lambda_L1 + # combine loss and calculate gradients + loss_G = loss_G_GAN + loss_G_L1 + loss_sum[0] += loss_G_L1.item() + loss_sum[1] += loss_G_GAN.item() + 
optimizer_G.zero_grad() + loss_G.backward() + optimizer_G.step() + + else: + if use_L2: + loss_G_L1 = (criterion_L1(pred, ground_true)+criterion_L2(pred, ground_true)) * lambda_L1 + else: + loss_G_L1 = criterion_L1(pred, ground_true) * lambda_L1 + loss_sum[0] += loss_G_L1.item() + + optimizer_G.zero_grad() + loss_G_L1.backward() + optimizer_G.step() + + + if (iter+1)%100 == 0: + try: + showresult(input_img[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], ground_true, pred,'result_train.png') + except Exception as e: + print(e) + + if (iter+1)%1000 == 0: + time_end = time.time() + if use_gan: + print('iter:',iter+1,' L1_loss:', round(loss_sum[0]/1000,3),' G_loss:', round(loss_sum[1]/1000,3), + ' D_f:',round(loss_sum[2]/1000,3),' D_r:',round(loss_sum[3]/1000,3),' time:',round((time_end-time_start)/1000,2)) + if (iter+1)/1000 >= 10: + loss_plot[0].append(loss_sum[0]/1000) + loss_plot[1].append(loss_sum[1]/1000) + item_plot.append(iter+1) + try: + plt.plot(item_plot,loss_plot[0]) + plt.plot(item_plot,loss_plot[1]) + plt.savefig(os.path.join(dir_checkpoint,'loss.png')) + plt.close() + except Exception as e: + print("error:",e) + else: + print('iter:',iter+1,' L1_loss:',round(loss_sum[0]/1000,3),' time:',round((time_end-time_start)/1000,2)) + if (iter+1)/1000 >= 10: + loss_plot[0].append(loss_sum[0]/1000) + item_plot.append(iter+1) + try: + plt.plot(item_plot,loss_plot[0]) + plt.savefig(os.path.join(dir_checkpoint,'loss.png')) + plt.close() + except Exception as e: + print("error:",e) + loss_sum = [0.,0.,0.,0.] + time_start=time.time() + + + + if (iter+1)%SAVE_FRE == 0: + if iter+1 != SAVE_FRE: + os.rename(os.path.join(dir_checkpoint,'last_G.pth'),os.path.join(dir_checkpoint,str(iter+1-SAVE_FRE)+'G.pth')) + torch.save(netG.cpu().state_dict(),os.path.join(dir_checkpoint,'last_G.pth')) + if use_gan: + if iter+1 != SAVE_FRE: + os.rename(os.path.join(dir_checkpoint,'last_D.pth'),os.path.join(dir_checkpoint,str(iter+1-SAVE_FRE)+'D.pth')) + torch.save(netD.cpu().state_dict(),os.path.join(dir_checkpoint,'last_D.pth')) + if use_gpu: + netG.cuda() + if use_gan: + netD.cuda() + f = open(os.path.join(dir_checkpoint,'iter'),'w+') + f.write(str(iter+1)) + f.close() + # torch.save(netG.cpu().state_dict(),dir_checkpoint+'iter'+str(iter+1)+'.pth') + print('network saved.') + + #test + netG.eval() + result = np.zeros((SIZE*2,SIZE*4,3), dtype='uint8') + test_names = os.listdir('./test') + + for cnt,test_name in enumerate(test_names,0): + img_names = os.listdir(os.path.join('./test',test_name,'image')) + input_img = np.zeros((SIZE,SIZE,3*N+1), dtype='uint8') + img_names.sort() + for i in range(0,N): + img = impro.imread(os.path.join('./test',test_name,'image',img_names[i])) + img = impro.resize(img,SIZE) + input_img[:,:,i*3:(i+1)*3] = img + + mask = impro.imread(os.path.join('./test',test_name,'mask.png'),'gray') + mask = impro.resize(mask,SIZE) + mask = impro.mask_threshold(mask,15,128) + input_img[:,:,-1] = mask + result[0:SIZE,SIZE*cnt:SIZE*(cnt+1),:] = input_img[:,:,int((N-1)/2)*3:(int((N-1)/2)+1)*3] + input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False) + pred = netG(input_img) + + pred = (pred.cpu().detach().numpy()*255)[0].transpose((1, 2, 0)) + result[SIZE:SIZE*2,SIZE*cnt:SIZE*(cnt+1),:] = pred + + cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.png'), result) + netG.train() \ No newline at end of file diff --git a/train/train.py b/train/train.py deleted file mode 100644 index a88c5ae..0000000 --- a/train/train.py +++ /dev/null @@ -1,209 +0,0 @@ -import os
-import numpy as np -import cv2 -import random -import torch -import torch.nn as nn -import time - -import sys -sys.path.append("..") -from models import runmodel,loadmodel -from util import mosaic,util,ffmpeg,filt,data -from util import image_processing as impro -from cores import Options -from models import pix2pix_model -from matplotlib import pyplot as plt -import torch.backends.cudnn as cudnn - -N = 25 -ITER = 1000000 -LR = 0.0002 -use_gpu = True -CONTINUE = True -# BATCHSIZE = 4 -dir_checkpoint = 'checkpoints/' -SAVE_FRE = 5000 -start_iter = 0 -SIZE = 256 -lambda_L1 = 100.0 -opt = Options().getparse() -opt.use_gpu=True -videos = os.listdir('./dataset') -videos.sort() -lengths = [] -for video in videos: - video_images = os.listdir('./dataset/'+video+'/ori') - lengths.append(len(video_images)) - - -netG = pix2pix_model.define_G(3*N+1, 3, 128, 'resnet_9blocks', norm='instance',use_dropout=True, init_type='normal', gpu_ids=[]) -netD = pix2pix_model.define_D(3*2, 64, 'basic', n_layers_D=3, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=[]) - -if CONTINUE: - netG.load_state_dict(torch.load(dir_checkpoint+'last_G.pth')) - netD.load_state_dict(torch.load(dir_checkpoint+'last_D.pth')) - f = open('./iter','r') - start_iter = int(f.read()) - f.close() -if use_gpu: - netG.cuda() - netD.cuda() - cudnn.benchmark = True -optimizer_G = torch.optim.Adam(netG.parameters(), lr=LR) -optimizer_D = torch.optim.Adam(netG.parameters(), lr=LR) -criterion_L1 = nn.L1Loss() -criterion_L2 = nn.MSELoss() -criterionGAN = pix2pix_model.GANLoss('lsgan').cuda() - -def showresult(img1,img2,img3,name): - img1 = (img1.cpu().detach().numpy()*255) - img2 = (img2.cpu().detach().numpy()*255) - img3 = (img3.cpu().detach().numpy()*255) - batchsize = img1.shape[0] - size = img1.shape[3] - ran =int(batchsize*random.random()) - showimg=np.zeros((size,size*3,3)) - showimg[0:size,0:size] =img1[ran].transpose((1, 2, 0)) - showimg[0:size,size:size*2] = img2[ran].transpose((1, 2, 0)) - showimg[0:size,size*2:size*3] = img3[ran].transpose((1, 2, 0)) - cv2.imwrite(name, showimg) - - -def loaddata(): - video_index = random.randint(0,len(videos)-1) - video = videos[video_index] - img_index = random.randint(N,lengths[video_index]- N) - input_img = np.zeros((SIZE,SIZE,3*N+1), dtype='uint8') - for i in range(0,N): - # print('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png') - img = cv2.imread('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png') - img = impro.resize(img,SIZE) - input_img[:,:,i*3:(i+1)*3] = img - mask = cv2.imread('./dataset/'+video+'/mask/output_'+'%05d'%(img_index)+'.png',0) - mask = impro.resize(mask,256) - mask = impro.mask_threshold(mask,15,128) - input_img[:,:,-1] = mask - input_img = data.im2tensor(input_img,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False) - - ground_true = cv2.imread('./dataset/'+video+'/ori/output_'+'%05d'%(img_index)+'.png') - ground_true = impro.resize(ground_true,SIZE) - # ground_true = im2tensor(ground_true,use_gpu) - ground_true = data.im2tensor(ground_true,bgr2rgb=False,use_gpu=opt.use_gpu,use_transform = False) - return input_img,ground_true - -input_imgs=[] -ground_trues=[] -def preload(): - while 1: - input_img,ground_true = loaddata() - input_imgs.append(input_img) - ground_trues.append(ground_true) - if len(input_imgs)>10: - del(input_imgs[0]) - del(ground_trues[0]) -import threading -t=threading.Thread(target=preload,args=()) #t为新创建的线程 -t.start() -time.sleep(3) #wait frist load - - -netG.train() -loss_sum = [0.,0.] 
-loss_plot = [[],[]] -item_plot = [] -time_start=time.time() -print("Begin training...") -for iter in range(start_iter+1,ITER): - - # input_img,ground_true = loaddata() - ran = random.randint(0, 9) - input_img = input_imgs[ran] - ground_true = ground_trues[ran] - - pred = netG(input_img) - - fake_AB = torch.cat((input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], pred), 1) - pred_fake = netD(fake_AB.detach()) - loss_D_fake = criterionGAN(pred_fake, False) - - real_AB = torch.cat((input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], ground_true), 1) - pred_real = netD(real_AB) - loss_D_real = criterionGAN(pred_real, True) - loss_D = (loss_D_fake + loss_D_real) * 0.5 - - optimizer_D.zero_grad() - loss_D.backward() - optimizer_D.step() - - fake_AB = torch.cat((input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], pred), 1) - pred_fake = netD(fake_AB) - loss_G_GAN = criterionGAN(pred_fake, True) - # Second, G(A) = B - loss_G_L1 = criterion_L1(pred, ground_true) * lambda_L1 - # combine loss and calculate gradients - loss_G = loss_G_GAN + loss_G_L1 - loss_sum[0] += loss_G_L1.item() - loss_sum[1] += loss_G.item() - - optimizer_G.zero_grad() - loss_G.backward() - optimizer_G.step() - - - - # a = netD(ground_true) - # print(a.size()) - # loss = criterion_L1(pred, ground_true)+criterion_L2(pred, ground_true) - # # loss = criterion_L2(pred, ground_true) - # loss_sum += loss.item() - - # optimizer_G.zero_grad() - # loss.backward() - # optimizer_G.step() - - if (iter+1)%100 == 0: - showresult(input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:], ground_true, pred,'./result_train.png') - if (iter+1)%100 == 0: - time_end=time.time() - print('iter:',iter+1,' L1_loss:', round(loss_sum[0]/100,4),'G_loss:', round(loss_sum[1]/100,4),'time:',round((time_end-time_start)/100,4)) - if (iter+1)/100 >= 10: - loss_plot[0].append(loss_sum[0]/100) - loss_plot[1].append(loss_sum[1]/100) - item_plot.append(iter+1) - plt.plot(item_plot,loss_plot[0]) - plt.plot(item_plot,loss_plot[1]) - plt.savefig('./loss.png') - plt.close() - loss_sum = [0.,0.] - - #show test result - # netG.eval() - # input_img = np.zeros((SIZE,SIZE,3*N), dtype='uint8') - # imgs = os.listdir('./test') - # for i in range(0,N): - # # print('./dataset/'+video+'/mosaic/output_'+'%05d'%(img_index+i-int(N/2))+'.png') - # img = cv2.imread('./test/'+imgs[i]) - # img = impro.resize(img,SIZE) - # input_img[:,:,i*3:(i+1)*3] = img - # input_img = im2tensor(input_img,use_gpu) - # ground_true = cv2.imread('./test/output_'+'%05d'%13+'.png') - # ground_true = impro.resize(ground_true,SIZE) - # ground_true = im2tensor(ground_true,use_gpu) - # pred = netG(input_img) - # showresult(input_img[:,int((N+1)/2)*3:(int((N+1)/2)+1)*3,:,:],pred,pred,'./result_test.png') - - netG.train() - time_start=time.time() - - if (iter+1)%SAVE_FRE == 0: - torch.save(netG.cpu().state_dict(),dir_checkpoint+'last_G.pth') - torch.save(netD.cpu().state_dict(),dir_checkpoint+'last_D.pth') - if use_gpu: - netG.cuda() - netD.cuda() - f = open('./iter','w+') - f.write(str(iter+1)) - f.close() - # torch.save(netG.cpu().state_dict(),dir_checkpoint+'iter'+str(iter+1)+'.pth') - print('network saved.') diff --git a/util/image_processing.py b/util/image_processing.py index eb7952c..f945f7f 100755 --- a/util/image_processing.py +++ b/util/image_processing.py @@ -2,7 +2,9 @@ import cv2 import numpy as np def imread(file_path,mod = 'normal'): - + ''' + mod = 'normal' | 'gray' | 'all' + ''' if mod == 'normal': cv_img = cv2.imread(file_path) elif mod == 'gray': -- GitLab
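
Usage note on the new video model: models/video_model.py consumes the same stacked-frame tensor that cores/core.py assembles in cleanmosaic_video_fusion, namely N=25 consecutive frames plus one mosaic-mask channel, i.e. 3*N+1 = 76 channels at INPUT_SIZE=128. Below is a minimal sketch, not part of the patch itself, of building that input and running HypoNet; the zero-filled array stands in for real cropped frames, and the manual tensor conversion only approximates what util.data.im2tensor does with bgr2rgb=False and use_transform=False.

import numpy as np
import torch
from models.video_model import HypoNet

N, SIZE = 25, 128                          # matches N and INPUT_SIZE in cores/core.py
net = HypoNet(3*N+1, 3)                    # 76 channels in, 3 channels out
net.eval()

# channels 0..3N-1 hold N consecutive BGR frames; the last channel holds the mask
stack = np.zeros((SIZE, SIZE, 3*N+1), dtype='uint8')
# fill stack[:,:,j*3:(j+1)*3] with frame j and stack[:,:,-1] with the mask here

x = torch.from_numpy(stack.transpose(2, 0, 1)).float().unsqueeze(0) / 255.0
with torch.no_grad():
    y = net(x)                             # (1, 3, 128, 128); Sigmoid keeps it in [0,1]

Note the batch size of 1: as written, the view() calls in encoder_3d fold the batch axis into the channel axis, so the network only accepts single-sample batches without further reshaping changes.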
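The training scripts keep the GPU fed through a small in-memory buffer that a daemon thread refills while the main loop samples from it. A condensed sketch of that pattern follows, with a dummy loaddata() standing in for the dataset-specific loader defined in train.py:

import random
import threading
import time

buffer = []                        # shared FIFO of (input, ground_truth) samples
BUFFER_LEN = 10

def loaddata():
    # stand-in for train.py's loader, which returns one random clip as tensors
    return random.random(), random.random()

def preload():
    # producer: append fresh samples forever, dropping the oldest when full
    while True:
        try:
            buffer.append(loaddata())
            if len(buffer) > BUFFER_LEN:
                del buffer[0]
        except Exception as e:
            print("error:", e)

t = threading.Thread(target=preload)
t.daemon = True                    # thread dies with the main process
t.start()
time.sleep(5)                      # give the producer a head start

for it in range(1000):             # stands in for the real training loop
    input_img, ground_true = buffer[random.randint(1, 8)]

The scripts sample indices 1..8 of the 10-slot buffer, presumably to stay clear of slot 0 (about to be deleted) and the newest slot (possibly still being written), and they rely on the GIL serializing list append/delete rather than taking an explicit lock.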