From 980bc5e014658d510d621098094df336491330b4 Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Thu, 21 Jan 2021 20:42:26 +0800 Subject: [PATCH] Update utils.py --- utils/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index 2055a8f..8a909aa 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -189,8 +189,8 @@ def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4): for image_i, image_pred in enumerate(prediction): #----------------------------------------------------------# # 对种类预测部分取max。 - # class_conf [batch_size, num_anchors, 1] 种类置信度 - # class_pred [batch_size, num_anchors, 1] 种类 + # class_conf [num_anchors, 1] 种类置信度 + # class_pred [num_anchors, 1] 种类 #----------------------------------------------------------# class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True) @@ -208,7 +208,7 @@ def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4): if not image_pred.size(0): continue #-------------------------------------------------------------------------# - # detections [batch_size, num_anchors, 7] + # detections [num_anchors, 7] # 7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred #-------------------------------------------------------------------------# detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1) -- GitLab