未验证 提交 012fdb70 编写于 作者: B Bubbliiiing 提交者: GitHub

Update utils.py

上级 86f266f3
......@@ -239,8 +239,8 @@ def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):
for image_i, image_pred in enumerate(prediction):
#----------------------------------------------------------#
# 对种类预测部分取max。
# class_conf [batch_size, num_anchors, 1] 种类置信度
# class_pred [batch_size, num_anchors, 1] 种类
# class_conf [num_anchors, 1] 种类置信度
# class_pred [num_anchors, 1] 种类
#----------------------------------------------------------#
class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
......@@ -258,7 +258,7 @@ def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):
if not image_pred.size(0):
continue
#-------------------------------------------------------------------------#
# detections [batch_size, num_anchors, 7]
# detections [num_anchors, 7]
# 7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred
#-------------------------------------------------------------------------#
detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册