Unverified commit 38d8e41a authored by Bubbliiiing, committed by GitHub

Add files via upload

Parent 216ea699
@@ -99,8 +99,12 @@ class SSD(object):
     #   Detect images
     #---------------------------------------------------#
     def detect_image(self, image):
-        image_shape = np.array(np.shape(image)[0:2])
+        #---------------------------------------------------------#
+        #   Convert the image to RGB here to prevent grayscale images from raising errors during prediction.
+        #---------------------------------------------------------#
+        image = image.convert('RGB')
+        image_shape = np.array(np.shape(image)[0:2])
         #---------------------------------------------------------#
         #   Add gray bars to the image for a distortion-free resize
         #   A plain resize could also be used for detection
@@ -108,8 +112,7 @@ class SSD(object):
         if self.letterbox_image:
             crop_img = np.array(letterbox_image(image, (self.input_shape[1],self.input_shape[0])))
         else:
-            crop_img = image.convert('RGB')
-            crop_img = crop_img.resize((self.input_shape[1],self.input_shape[0]), Image.BICUBIC)
+            crop_img = image.resize((self.input_shape[1],self.input_shape[0]), Image.BICUBIC)
         photo = np.array(crop_img,dtype = np.float64)
         with torch.no_grad():
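The letterbox_image helper imported from utils is not shown in this commit. A minimal sketch of the aspect-preserving, gray-padded resize that the comments above describe, assuming a PIL image and a (width, height) size tuple, could look like the following; the function name, padding color, and details are illustrative rather than the repo's actual implementation.

from PIL import Image

def letterbox_resize(image, size):
    # Illustrative stand-in for the repo's letterbox_image helper.
    # Scale the image to fit inside `size` without changing its aspect
    # ratio, then pad the remaining border with gray pixels.
    iw, ih = image.size
    w, h = size
    scale = min(w / iw, h / ih)
    nw, nh = int(iw * scale), int(ih * scale)
    resized = image.resize((nw, nh), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128, 128, 128))
    canvas.paste(resized, ((w - nw) // 2, (h - nh) // 2))
    return canvas

Padding instead of stretching keeps object geometry consistent between training and inference, at the cost of a few uninformative gray pixels.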
@@ -13,7 +13,7 @@ from torchsummary import summary
 from tqdm import tqdm
 from nets.ssd import get_ssd
-from nets.ssd_training import Generator, MultiBoxLoss
+from nets.ssd_training import Generator, LossHistory, MultiBoxLoss
 from utils.config import Config
 from utils.dataloader import SSDDataset, ssd_dataset_collate
@@ -106,6 +106,8 @@ def fit_one_epoch(net,criterion,epoch,epoch_size,epoch_size_val,gen,genval,Epoch
     total_loss = loc_loss + conf_loss
     val_loss = loc_loss_val + conf_loss_val
+    loss_history.append_loss(total_loss/(epoch_size+1), val_loss/(epoch_size_val+1))
     print('Finish Validation')
     print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))
     print('Total Loss: %.4f || Val Loss: %.4f ' % (total_loss/(epoch_size+1),val_loss/(epoch_size_val+1)))
@@ -159,6 +161,7 @@ if __name__ == "__main__":
     num_train = len(lines) - num_val
     criterion = MultiBoxLoss(Config['num_classes'], 0.5, True, 0, True, 3, 0.5, False, Cuda)
+    loss_history = LossHistory("logs/")
     net = model.train()
     if Cuda:
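LossHistory itself lives in nets/ssd_training.py and is not part of this diff. Judging only from its use here, it is constructed with a log directory and receives the averaged train and validation losses once per epoch, so a minimal stand-in might look like this; the class name, file names, and internals are assumptions.

import os

class EpochLossRecorder:
    # Hypothetical stand-in for LossHistory: keeps per-epoch train/val
    # losses in memory and appends them to text files under log_dir.
    def __init__(self, log_dir):
        self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)
        self.losses = []
        self.val_losses = []

    def append_loss(self, loss, val_loss):
        self.losses.append(loss)
        self.val_losses.append(val_loss)
        with open(os.path.join(self.log_dir, "epoch_loss.txt"), "a") as f:
            f.write(str(loss) + "\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), "a") as f:
            f.write(str(val_loss) + "\n")

Creating the recorder once before the training loop and calling append_loss once per epoch is enough to reconstruct the full loss curves after training.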
@@ -201,6 +204,9 @@ if __name__ == "__main__":
     epoch_size = num_train // Batch_size
     epoch_size_val = num_val // Batch_size
+    if epoch_size == 0 or epoch_size_val == 0:
+        raise ValueError("The dataset is too small to train on; please expand the dataset.")
     for epoch in range(Init_Epoch,Freeze_Epoch):
         fit_one_epoch(net,criterion,epoch,epoch_size,epoch_size_val,gen,gen_val,Freeze_Epoch,Cuda)
         lr_scheduler.step()
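The new guard covers the case where integer division by Batch_size produces zero batches per epoch, which would otherwise leave the training or validation loop with nothing to iterate over. With hypothetical numbers:

num_train, num_val, Batch_size = 20, 5, 32    # illustrative values
epoch_size = num_train // Batch_size          # 0: not even one full training batch
epoch_size_val = num_val // Batch_size        # 0: not even one full validation batch
# Either zero means the dataset cannot fill a single batch, so training aborts with a clear error.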
@@ -232,6 +238,9 @@ if __name__ == "__main__":
     epoch_size = num_train // Batch_size
     epoch_size_val = num_val // Batch_size
+    if epoch_size == 0 or epoch_size_val == 0:
+        raise ValueError("The dataset is too small to train on; please expand the dataset.")
     for epoch in range(Freeze_Epoch,Unfreeze_Epoch):
         fit_one_epoch(net,criterion,epoch,epoch_size,epoch_size_val,gen,gen_val,Unfreeze_Epoch,Cuda)
         lr_scheduler.step()
@@ -7,7 +7,9 @@ import xml.etree.ElementTree as ET
 from os import getcwd
 sets=[('2007', 'train'), ('2007', 'val'), ('2007', 'test')]
+#-----------------------------------------------------#
+#   The class order set here must match the txt file under model_data
+#-----------------------------------------------------#
 classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
 def convert_annotation(year, image_id, list_file):
@@ -24,14 +26,14 @@ def convert_annotation(year, image_id, list_file):
             continue
         cls_id = classes.index(cls)
         xmlbox = obj.find('bndbox')
-        b = (int(xmlbox.find('xmin').text), int(xmlbox.find('ymin').text), int(xmlbox.find('xmax').text), int(xmlbox.find('ymax').text))
+        b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
         list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
 wd = getcwd()
 for year, image_set in sets:
-    image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt'%(year, image_set)).read().strip().split()
-    list_file = open('%s_%s.txt'%(year, image_set), 'w')
+    image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt'%(year, image_set), encoding='utf-8').read().strip().split()
+    list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8')
     for image_id in image_ids:
         list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg'%(wd, year, image_id))
         convert_annotation(year, image_id, list_file)
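Wrapping each bounding-box coordinate in float() before int() matters because some annotation tools store VOC coordinates as decimal strings, which int() refuses to parse directly. With illustrative values:

int("156")             # 156
int(float("156.0"))    # 156
int("156.0")           # raises ValueError: invalid literal for int() with base 10: '156.0'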
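For reference, each line these loops write to 2007_train.txt (and the val/test lists) is an absolute image path followed by space-separated boxes in xmin,ymin,xmax,ymax,class_id form, for example (path and numbers are made up):

/path/to/repo/VOCdevkit/VOC2007/JPEGImages/000005.jpg 263,211,324,339,8 165,264,253,372,8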