未验证 提交 3f2a6961 编写于 作者: B Bubbliiiing 提交者: GitHub

Update utils.py

上级 bc720d7e
......@@ -208,16 +208,17 @@ def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):
output = [None for _ in range(len(prediction))]
for image_i, image_pred in enumerate(prediction):
# 获得种类及其置信度
class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
# 利用置信度进行第一轮筛选
conf_mask = (image_pred[:, 4] >= conf_thres).squeeze()
image_pred = image_pred[conf_mask]
conf_mask = (image_pred[:, 4]*class_conf[:, 0] >= conf_thres).squeeze()
image_pred = image_pred[conf_mask]
class_conf = class_conf[conf_mask]
class_pred = class_pred[conf_mask]
if not image_pred.size(0):
continue
# 获得种类及其置信度
class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
# 获得的内容为(x1, y1, x2, y2, obj_conf, class_conf, class_pred)
detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)
......@@ -237,13 +238,13 @@ def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):
#------------------------------------------#
keep = nms(
detections_class[:, :4],
detections_class[:, 4],
detections_class[:, 4]*detections_class[:, 5],
nms_thres
)
max_detections = detections_class[keep]
# # 按照存在物体的置信度排序
# _, conf_sort_index = torch.sort(detections_class[:, 4], descending=True)
# _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
# detections_class = detections_class[conf_sort_index]
# # 进行非极大抑制
# max_detections = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册