From cb78b70977e7ae208d0209ed166fc057bc7b20f7 Mon Sep 17 00:00:00 2001
From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com>
Date: Mon, 28 Sep 2020 15:59:57 +0800
Subject: [PATCH] Update yolo_training.py

---
 nets/yolo_training.py | 68 ++++++++++++++++++++++++-------------------
 1 file changed, 38 insertions(+), 30 deletions(-)

diff --git a/nets/yolo_training.py b/nets/yolo_training.py
index a26feee..b58414a 100644
--- a/nets/yolo_training.py
+++ b/nets/yolo_training.py
@@ -197,32 +197,39 @@ class YOLOLoss(nn.Module):
         box_loss_scale_x = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False)
         box_loss_scale_y = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False)
         for b in range(bs):
-            for t in range(target[b].shape[0]):
-                # Compute the position of the box on the feature map
-                gx = target[b][t, 0] * in_w
-                gy = target[b][t, 1] * in_h
-
-                gw = target[b][t, 2] * in_w
-                gh = target[b][t, 3] * in_h
-
-                # Determine which grid cell it falls into
-                gi = int(gx)
-                gj = int(gy)
-
-                # Compute the position of the ground-truth box
-                gt_box = torch.FloatTensor(np.array([0, 0, gw, gh])).unsqueeze(0)
-
-                # Compute the positions of all anchor boxes
-                anchor_shapes = torch.FloatTensor(np.concatenate((np.zeros((self.num_anchors, 2)),
-                                                                  np.array(anchors)), 1))
-                # Compute the degree of overlap (IoU)
-                anch_ious = bbox_iou(gt_box, anchor_shapes)
-
-                # Find the best matching anchor box
-                best_n = np.argmax(anch_ious)
+            if len(target[b])==0:
+                continue
+            # Compute the positions of the boxes on the feature map
+            gxs = target[b][:, 0:1] * in_w
+            gys = target[b][:, 1:2] * in_h
+
+            gws = target[b][:, 2:3] * in_w
+            ghs = target[b][:, 3:4] * in_h
+
+            # Determine which grid cell each box falls into
+            gis = torch.floor(gxs)
+            gjs = torch.floor(gys)
+
+            # Compute the positions of the ground-truth boxes
+            gt_box = torch.FloatTensor(torch.cat([torch.zeros_like(gws), torch.zeros_like(ghs), gws, ghs], 1))
+
+            # Compute the positions of all anchor boxes
+            anchor_shapes = torch.FloatTensor(torch.cat((torch.zeros((self.num_anchors, 2)), torch.FloatTensor(anchors)), 1))
+            # Compute the degree of overlap (IoU)
+            anch_ious = jaccard(gt_box, anchor_shapes)
+
+            # Find the best matching anchor box
+            best_ns = torch.argmax(anch_ious,dim=-1)
+            for i, best_n in enumerate(best_ns):
                 if best_n not in anchor_index:
                     continue
                 # Masks
+                gi = gis[i].long()
+                gj = gjs[i].long()
+                gx = gxs[i]
+                gy = gys[i]
+                gw = gws[i]
+                gh = ghs[i]
                 if (gj < in_h) and (gi < in_w):
                     best_n = best_n - subtract_index
                     # Mark which anchors actually contain an object
@@ -235,12 +242,12 @@ class YOLOLoss(nn.Module):
                     tw[b, best_n, gj, gi] = gw
                     th[b, best_n, gj, gi] = gh
                     # Used to obtain the xywh scale
-                    box_loss_scale_x[b, best_n, gj, gi] = target[b][t, 2]
-                    box_loss_scale_y[b, best_n, gj, gi] = target[b][t, 3]
+                    box_loss_scale_x[b, best_n, gj, gi] = target[b][i, 2]
+                    box_loss_scale_y[b, best_n, gj, gi] = target[b][i, 3]
                     # Object confidence
                     tconf[b, best_n, gj, gi] = 1
                     # Class
-                    tcls[b, best_n, gj, gi, int(target[b][t, 4])] = 1
+                    tcls[b, best_n, gj, gi, target[b][i, 4].long()] = 1
                 else:
                     print('Step {0} out of bound'.format(b))
                     print('gj: {0}, height: {1} | gi: {2}, width: {3}'.format(gj, in_h, gi, in_w))
@@ -251,6 +258,7 @@ class YOLOLoss(nn.Module):
         t_box[...,3] = th
         return mask, noobj_mask, t_box, tconf, tcls, box_loss_scale_x, box_loss_scale_y
 
+
     def get_ignore(self,prediction,target,scaled_anchors,in_w, in_h,noobj_mask):
         bs = len(target)
         anchor_index = [[0,1,2],[3,4,5],[6,7,8]][self.feature_length.index(in_w)]
@@ -292,12 +300,12 @@ class YOLOLoss(nn.Module):
             gy = target[i][:, 1:2] * in_h
             gw = target[i][:, 2:3] * in_w
             gh = target[i][:, 3:4] * in_h
-            gt_box = torch.FloatTensor(np.concatenate([gx, gy, gw, gh],-1)).type(FloatTensor)
+            gt_box = torch.FloatTensor(torch.cat([gx, gy, gw, gh],-1)).type(FloatTensor)
 
             anch_ious = jaccard(gt_box, pred_boxes_for_ignore)
-            for t in range(target[i].shape[0]):
-                anch_iou = anch_ious[t].view(pred_boxes[i].size()[:3])
-                noobj_mask[i][anch_iou>self.ignore_threshold] = 0
+            anch_ious_max, _ = torch.max(anch_ious,dim=0)
+            anch_ious_max = anch_ious_max.view(pred_boxes[i].size()[:3])
+            noobj_mask[i][anch_ious_max>self.ignore_threshold] = 0
         return noobj_mask, pred_boxes
--
GitLab
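
Note on the change: the substance of this commit is replacing the per-target Python loops with batched tensor operations. get_target now computes the IoU between every ground-truth width/height and every anchor in a single call and picks the best anchor with torch.argmax, while get_ignore reduces over the target dimension with torch.max instead of looping. The snippet below is a minimal standalone sketch of that idea; the wh_iou helper and the example tensors are illustrative assumptions, not the repository's jaccard implementation.

    # Minimal sketch (assumed example, not the repository's code) of vectorized
    # width/height anchor matching as introduced by this patch.
    import torch

    def wh_iou(gt_wh, anchor_wh):
        # gt_wh: (N, 2) target widths/heights, anchor_wh: (M, 2) anchor widths/heights.
        # IoU of boxes that share the same center, so only w/h matter.
        inter = torch.min(gt_wh[:, None, 0], anchor_wh[None, :, 0]) * \
                torch.min(gt_wh[:, None, 1], anchor_wh[None, :, 1])
        union = gt_wh[:, 0:1] * gt_wh[:, 1:2] + \
                (anchor_wh[:, 0] * anchor_wh[:, 1])[None, :] - inter
        return inter / union                               # (N, M) pairwise IoU

    gt_wh     = torch.tensor([[3.2, 4.5], [1.0, 1.8]])     # example target sizes in grid units
    anchor_wh = torch.tensor([[1.25, 1.625], [2.0, 3.75], [4.125, 2.875]])
    ious      = wh_iou(gt_wh, anchor_wh)
    best_ns   = torch.argmax(ious, dim=-1)                  # best anchor per target, as in get_target
    ious_max, _ = torch.max(ious, dim=0)                    # max IoU over targets, as get_ignore does per prediction

Besides removing the inner loop, taking the max over all targets at once also fixes the ignore mask so that a prediction is kept as a negative only if its best overlap with any ground-truth box stays below ignore_threshold.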