"""Train a binary segmentation network (UNet or BiSeNet) on image/mask pairs.

Images are read from ``<dataset>/origin_image`` and masks from
``<dataset>/mask``; pairs are matched by sorted file order. 80% of the pairs
are used for training and the remainder for evaluation. Logs, preview images,
the loss plot and checkpoints are written to ``checkpoints/<savename>``.
"""
import os
import sys
sys.path.append("..")
sys.path.append("../..")
from cores import Options
opt = Options()

import random
import datetime
import time

import numpy as np
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import cv2

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch import optim

from util import mosaic,util,ffmpeg,filt,data
from util import image_processing as impro
from models import unet_model,BiSeNet_model

# --------------------------Get options--------------------------
opt.parser.add_argument('--lr', type=float, default=0.001, help='')
opt.parser.add_argument('--finesize', type=int, default=360, help='')
opt.parser.add_argument('--loadsize', type=int, default=400, help='')
opt.parser.add_argument('--batchsize', type=int, default=8, help='')
opt.parser.add_argument('--model', type=str, default='BiSeNet', help='BiSeNet or UNet')

opt.parser.add_argument('--maxepoch', type=int, default=100, help='')
opt.parser.add_argument('--savefreq', type=int, default=5, help='')
opt.parser.add_argument('--maxload', type=int, default=1000000, help='')
opt.parser.add_argument('--continue_train', action='store_true', help='')
opt.parser.add_argument('--startepoch', type=int, default=0, help='')
opt.parser.add_argument('--dataset', type=str, default='./datasets/face/', help='')
opt.parser.add_argument('--savename', type=str, default='face', help='')

# --------------------------Init--------------------------
opt = opt.getparse()
dir_img = os.path.join(opt.dataset, 'origin_image')
dir_mask = os.path.join(opt.dataset, 'mask')
dir_checkpoint = os.path.join('checkpoints/', opt.savename)
util.makedirs(dir_checkpoint)
util.writelog(os.path.join(dir_checkpoint, 'loss.txt'),
              str(time.asctime(time.localtime(time.time()))) + '\n' + util.opt2str(opt))


def Totensor(img, use_gpu=True):
    """Convert a numpy array to a float32 torch tensor.

    Args:
        img: numpy array to convert.
        use_gpu: GPU setting as passed by callers (``opt.use_gpu``); ``-1``
            keeps the tensor on the CPU, any other value moves it to CUDA.

    Returns:
        The float32 tensor, on CUDA unless ``use_gpu == -1``.
    """
    # Honour the argument rather than reading the global ``opt``.
    tensor = torch.from_numpy(img).float()
    if use_gpu != -1:
        tensor = tensor.cuda()
    return tensor


def loadimage(imagepaths, maskpaths, opt, test_flag=False):
    """Load one batch of image/mask pairs as tensors.

    Args:
        imagepaths: paths of the images in the batch.
        maskpaths: paths of the matching masks (read as grayscale), same order.
        opt: parsed options; reads ``loadsize``, ``finesize`` and ``use_gpu``.
        test_flag: forwarded to ``data.random_transform_image``; this script
            passes True for evaluation batches.

    Returns:
        ``(images, masks)`` float tensors of shape
        ``(N, 3, finesize, finesize)`` and ``(N, 1, finesize, finesize)``,
        with pixel values divided by 255.
    """
    batchsize = len(imagepaths)
    images = np.zeros((batchsize, 3, opt.finesize, opt.finesize), dtype=np.float32)
    masks = np.zeros((batchsize, 1, opt.finesize, opt.finesize), dtype=np.float32)
    for i, (imagepath, maskpath) in enumerate(zip(imagepaths, maskpaths)):
        img = impro.resize(impro.imread(imagepath), opt.loadsize)
        mask = impro.resize(impro.imread(maskpath, mod='gray'), opt.loadsize)
        # Yields finesize x finesize outputs; presumably random augmentation
        # unless test_flag is set -- confirm in util.data.
        img, mask = data.random_transform_image(img, mask, opt.finesize, test_flag)
        images[i] = img.transpose((2, 0, 1)) / 255.0  # HWC -> CHW
        masks[i] = mask.reshape(1, 1, opt.finesize, opt.finesize) / 255.0
    return Totensor(images, opt.use_gpu), Totensor(masks, opt.use_gpu)


# --------------------------checking dataset--------------------------
print('checking dataset...')
imagepaths = sorted(util.Traversal(dir_img))[:opt.maxload]
maskpaths = sorted(util.Traversal(dir_mask))[:opt.maxload]
# Validate before shuffling: pairing is only meaningful when both lists have
# the same length. Exit non-zero so callers/scripts see the failure.
if len(imagepaths) != len(maskpaths):
    print('dataset error!')
    sys.exit(1)
# Presumably shuffles both lists in place with one shared permutation
# (its return value is unused) -- confirm in util.data.
data.shuffledata(imagepaths, maskpaths)
img_num = len(imagepaths)
print('find images:', img_num)
split = int(img_num * 0.8)
# List slicing already copies, so later in-place shuffles of the training
# lists do not touch the full lists.
imagepaths_train = imagepaths[:split]
maskpaths_train = maskpaths[:split]
imagepaths_eval = imagepaths[split:]
maskpaths_eval = maskpaths[split:]

# Whole batches per epoch; a trailing partial batch is dropped.
train_batches = len(imagepaths_train) // opt.batchsize
eval_batches = len(imagepaths_eval) // opt.batchsize
if train_batches == 0 or eval_batches == 0:
    # Fail now instead of with a ZeroDivisionError after a full epoch.
    print('dataset error! need at least one full batch for training and evaluation '
          '(batchsize: {}, train images: {}, eval images: {})'.format(
              opt.batchsize, len(imagepaths_train), len(imagepaths_eval)))
    sys.exit(1)

# --------------------------def network--------------------------
if opt.model == 'UNet':
    net = unet_model.UNet(n_channels=3, n_classes=1)
elif opt.model == 'BiSeNet':
    net = BiSeNet_model.BiSeNet(num_classes=1, context_path='resnet18')
else:
    raise ValueError("unknown --model {!r}: expected 'BiSeNet' or 'UNet'".format(opt.model))

last_checkpoint = os.path.join(dir_checkpoint, 'last.pth')
epoch_log = os.path.join(dir_checkpoint, 'epoch_log.txt')
if opt.continue_train and not os.path.isfile(last_checkpoint):
    opt.continue_train = False
    print('can not load last.pth, training on init weight.')
if opt.continue_train:
    # Checkpoints are written from CPU weights (see the save step below);
    # mapping to CPU also lets a resume work on a machine without CUDA.
    net.load_state_dict(torch.load(last_checkpoint, map_location='cpu'))
    with open(epoch_log, 'r') as f:
        opt.startepoch = int(f.read())
if opt.use_gpu != -1:
    net.cuda()
    cudnn.benchmark = True

optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
# Both models are trained with BCE.
# NOTE(review): BiSeNet_model.DiceLoss() may be an alternative for BiSeNet.
criterion = nn.BCELoss()

# --------------------------train--------------------------
loss_plot = {'train': [], 'eval': []}
print('begin training......')
for epoch in range(opt.startepoch, opt.maxepoch):
    # Batch whose prediction is saved as this epoch's preview image.
    # randint is inclusive at both ends, so the upper bound is the last index.
    random_save = random.randint(0, train_batches - 1)
    data.shuffledata(imagepaths_train, maskpaths_train)

    starttime = datetime.datetime.now()
    util.writelog(os.path.join(dir_checkpoint, 'loss.txt'),
                  'Epoch {}/{}.'.format(epoch + 1, opt.maxepoch), True)
    net.train()
    if opt.use_gpu != -1:
        net.cuda()  # weights are moved to the CPU by the save step of each epoch
    epoch_loss = 0
    for i in range(train_batches):
        batch = slice(i * opt.batchsize, (i + 1) * opt.batchsize)
        img, mask = loadimage(imagepaths_train[batch], maskpaths_train[batch], opt)

        if opt.model == 'UNet':
            mask_pred = net(img)
            loss = criterion(mask_pred, mask)
            epoch_loss += loss.item()
        else:  # BiSeNet returns three predictions; all three are supervised.
            mask_pred, mask_pred_sup1, mask_pred_sup2 = net(img)
            loss1 = criterion(mask_pred, mask)
            loss = loss1 + criterion(mask_pred_sup1, mask) + criterion(mask_pred_sup2, mask)
            # Report only the first output's loss, matching what eval measures.
            epoch_loss += loss1.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            data.showresult(img, mask, mask_pred, os.path.join(dir_checkpoint, 'result.png'), True)
        if i == random_save:
            data.showresult(img, mask, mask_pred,
                            os.path.join(dir_checkpoint, 'epoch_' + str(epoch + 1) + '.png'), True)
    epoch_loss = epoch_loss / train_batches
    loss_plot['train'].append(epoch_loss)

    # val
    epoch_loss_eval = 0
    with torch.no_grad():
        # NOTE(review): the network is left in train() mode here, so any
        # BatchNorm/Dropout layers behave as during training. The BiSeNet
        # branch unpacks three outputs, which may not hold after net.eval() --
        # confirm before switching.
        for i in range(eval_batches):
            batch = slice(i * opt.batchsize, (i + 1) * opt.batchsize)
            img, mask = loadimage(imagepaths_eval[batch], maskpaths_eval[batch], opt, test_flag=True)
            if opt.model == 'UNet':
                mask_pred = net(img)
            else:
                mask_pred, _, _ = net(img)
            loss = criterion(mask_pred, mask)
            epoch_loss_eval += loss.item()
    epoch_loss_eval = epoch_loss_eval / eval_batches
    loss_plot['eval'].append(epoch_loss_eval)

    # savelog
    endtime = datetime.datetime.now()
    # total_seconds() keeps whole days, which timedelta.seconds would drop.
    util.writelog(os.path.join(dir_checkpoint, 'loss.txt'),
                  '--- Epoch train_loss: {0:.6f} eval_loss: {1:.6f} Cost time: {2:} s'.format(
                      epoch_loss,
                      epoch_loss_eval,
                      int((endtime - starttime).total_seconds())),
                  True)

    # plot: one point per epoch trained in this run
    epochs = np.arange(opt.startepoch + 1, epoch + 2)
    plt.plot(epochs, loss_plot['train'], label='train')
    plt.plot(epochs, loss_plot['eval'], label='eval')
    plt.xlabel('Epoch')
    plt.ylabel('BCELoss')
    plt.legend(loc=1)
    plt.savefig(os.path.join(dir_checkpoint, 'loss.jpg'))
    plt.close()

    # save network (moves weights to the CPU; the next epoch moves them back)
    state_dict = net.cpu().state_dict()
    torch.save(state_dict, last_checkpoint)
    with open(epoch_log, 'w') as f:
        f.write(str(epoch + 1))

    if (epoch + 1) % opt.savefreq == 0:
        torch.save(state_dict, os.path.join(dir_checkpoint, 'epoch' + str(epoch + 1) + '.pth'))
        print('network saved.')