From 3f2a696162154caa84dbc0068be0e3e1098f55ca Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Fri, 11 Sep 2020 15:08:01 +0800 Subject: [PATCH] Update utils.py --- utils/utils.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index d4d8443..495bb1b 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -208,16 +208,17 @@ def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4): output = [None for _ in range(len(prediction))] for image_i, image_pred in enumerate(prediction): + # 获得种类及其置信度 + class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True) + # 利用置信度进行第一轮筛选 - conf_mask = (image_pred[:, 4] >= conf_thres).squeeze() - image_pred = image_pred[conf_mask] + conf_mask = (image_pred[:, 4]*class_conf[:, 0] >= conf_thres).squeeze() + image_pred = image_pred[conf_mask] + class_conf = class_conf[conf_mask] + class_pred = class_pred[conf_mask] if not image_pred.size(0): continue - - # 获得种类及其置信度 - class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True) - # 获得的内容为(x1, y1, x2, y2, obj_conf, class_conf, class_pred) detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1) @@ -237,13 +238,13 @@ def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4): #------------------------------------------# keep = nms( detections_class[:, :4], - detections_class[:, 4], + detections_class[:, 4]*detections_class[:, 5], nms_thres ) max_detections = detections_class[keep] # # 按照存在物体的置信度排序 - # _, conf_sort_index = torch.sort(detections_class[:, 4], descending=True) + # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True) # detections_class = detections_class[conf_sort_index] # # 进行非极大抑制 # max_detections = [] -- GitLab