未验证 提交 ec841321 编写于 作者: B Bubbliiiing 提交者: GitHub

Update utils.py

上级 2e4c5b14
...@@ -5,10 +5,11 @@ import time ...@@ -5,10 +5,11 @@ import time
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torch.autograd import Variable
from PIL import Image, ImageDraw, ImageFont
from torchvision.ops import nms
class DecodeBox(nn.Module): class DecodeBox(nn.Module):
def __init__(self, anchors, num_classes, img_size): def __init__(self, anchors, num_classes, img_size):
...@@ -225,24 +226,37 @@ def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4): ...@@ -225,24 +226,37 @@ def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):
if prediction.is_cuda: if prediction.is_cuda:
unique_labels = unique_labels.cuda() unique_labels = unique_labels.cuda()
detections = detections.cuda()
for c in unique_labels: for c in unique_labels:
# 获得某一类初步筛选后全部的预测结果 # 获得某一类初步筛选后全部的预测结果
detections_class = detections[detections[:, -1] == c] detections_class = detections[detections[:, -1] == c]
# 按照存在物体的置信度排序
_, conf_sort_index = torch.sort(detections_class[:, 4], descending=True) #------------------------------------------#
detections_class = detections_class[conf_sort_index] # 使用官方自带的非极大抑制会速度更快一些!
# 进行非极大抑制 #------------------------------------------#
max_detections = [] keep = nms(
while detections_class.size(0): detections_class[:, :4],
# 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉 detections_class[:, 4],
max_detections.append(detections_class[0].unsqueeze(0)) nms_thres
if len(detections_class) == 1: )
break max_detections = detections_class[keep]
ious = bbox_iou(max_detections[-1], detections_class[1:])
detections_class = detections_class[1:][ious < nms_thres] # # 按照存在物体的置信度排序
# 堆叠 # _, conf_sort_index = torch.sort(detections_class[:, 4], descending=True)
max_detections = torch.cat(max_detections).data # detections_class = detections_class[conf_sort_index]
# # 进行非极大抑制
# max_detections = []
# while detections_class.size(0):
# # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉
# max_detections.append(detections_class[0].unsqueeze(0))
# if len(detections_class) == 1:
# break
# ious = bbox_iou(max_detections[-1], detections_class[1:])
# detections_class = detections_class[1:][ious < nms_thres]
# # 堆叠
# max_detections = torch.cat(max_detections).data
# Add max detections to outputs # Add max detections to outputs
output[image_i] = max_detections if output[image_i] is None else torch.cat( output[image_i] = max_detections if output[image_i] is None else torch.cat(
(output[image_i], max_detections)) (output[image_i], max_detections))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册