提交 480fea7b 编写于 作者: H HypoX64

Fix training part #5

上级 9cb4eb05
...@@ -50,7 +50,7 @@ class Options(): ...@@ -50,7 +50,7 @@ class Options():
self.initialized = True self.initialized = True
def getparse(self): def getparse(self, test_flag = False):
if not self.initialized: if not self.initialized:
self.initialize() self.initialize()
self.opt = self.parser.parse_args() self.opt = self.parser.parse_args()
...@@ -65,10 +65,11 @@ class Options(): ...@@ -65,10 +65,11 @@ class Options():
else: else:
self.opt.use_gpu = -1 self.opt.use_gpu = -1
if not os.path.exists(self.opt.media_path): if test_flag:
print('Error: Bad media path!') if not os.path.exists(self.opt.media_path):
input('Please press any key to exit.\n') print('Error: Bad media path!')
exit(0) input('Please press any key to exit.\n')
exit(0)
if self.opt.mode == 'auto': if self.opt.mode == 'auto':
if 'clean' in model_name or self.opt.traditional: if 'clean' in model_name or self.opt.traditional:
......
...@@ -5,7 +5,7 @@ from cores import Options,core ...@@ -5,7 +5,7 @@ from cores import Options,core
from util import util from util import util
from models import loadmodel from models import loadmodel
opt = Options().getparse() opt = Options().getparse(test_flag = True)
util.file_init(opt) util.file_init(opt)
def main(): def main():
......
...@@ -106,7 +106,13 @@ for fold in range(opt.fold): ...@@ -106,7 +106,13 @@ for fold in range(opt.fold):
# t.start() # t.start()
saveflag = True saveflag = True
x,y,size,area = impro.boundingSquare(mask, random.uniform(1.4,1.6)) if opt.mod == ['drawn','irregular']:
x,y,size,area = impro.boundingSquare(mask_drawn, random.uniform(1.2,1.6))
elif opt.mod == ['network','irregular']:
x,y,size,area = impro.boundingSquare(mask_net, random.uniform(1.2,1.6))
else:
x,y,size,area = impro.boundingSquare(mask, random.uniform(1.2,1.6))
if area < 1000: if area < 1000:
saveflag = False saveflag = False
else: else:
......
...@@ -10,6 +10,8 @@ import datetime ...@@ -10,6 +10,8 @@ import datetime
import time import time
import numpy as np import numpy as np
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import cv2 import cv2
......
...@@ -16,6 +16,8 @@ from multiprocessing import Process, Queue ...@@ -16,6 +16,8 @@ from multiprocessing import Process, Queue
from util import mosaic,util,ffmpeg,filt,data from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro from util import image_processing as impro
from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model from models import pix2pix_model,pix2pixHD_model,video_model,unet_model,loadmodel,videoHD_model
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
......
...@@ -79,7 +79,7 @@ def ch_one2three(img): ...@@ -79,7 +79,7 @@ def ch_one2three(img):
res = cv2.merge([img, img, img]) res = cv2.merge([img, img, img])
return res return res
def color_adjust(img,alpha=1,beta=0,b=0,g=0,r=0,ran = False): def color_adjust(img,alpha=0,beta=0,b=0,g=0,r=0,ran = False):
''' '''
g(x) = (1+α)g(x)+255*β, g(x) = (1+α)g(x)+255*β,
g(x) = g(x[:+b*255,:+g*255,:+r*255]) g(x) = g(x[:+b*255,:+g*255,:+r*255])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册