From 2eec48c847d3240f36b71ef97a8300495daaf719 Mon Sep 17 00:00:00 2001
From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com>
Date: Tue, 29 Sep 2020 09:59:46 +0800
Subject: [PATCH] Update yolo_training.py

---
 nets/yolo_training.py | 75 +++++++++++++++++++++++--------------------
 1 file changed, 41 insertions(+), 34 deletions(-)

diff --git a/nets/yolo_training.py b/nets/yolo_training.py
index 263c45b..dc068b4 100644
--- a/nets/yolo_training.py
+++ b/nets/yolo_training.py
@@ -155,51 +155,59 @@ class YOLOLoss(nn.Module):
         box_loss_scale_x = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False)
         box_loss_scale_y = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False)

-        for b in range(bs):
-            for t in range(target[b].shape[0]):
-                # Compute the coordinates on this feature layer
-                gx = target[b][t, 0] * in_w
-                gy = target[b][t, 1] * in_h
-
-                gw = target[b][t, 2] * in_w
-                gh = target[b][t, 3] * in_h
-
-                # Determine which grid cell the box falls in
-                gi = int(gx)
-                gj = int(gy)
-
-                # Build the ground-truth box
-                gt_box = torch.FloatTensor(np.array([0, 0, gw, gh])).unsqueeze(0)
-
-                # Build all the anchor boxes
-                anchor_shapes = torch.FloatTensor(np.concatenate((np.zeros((self.num_anchors, 2)),
-                                                                  np.array(anchors)), 1))
-                # Compute the degree of overlap (IoU)
-                anch_ious = bbox_iou(gt_box, anchor_shapes)
-
-                # Find the best matching anchor box
-                best_n = np.argmax(anch_ious)
+        for b in range(bs):
+            if len(target[b])==0:
+                continue
+            # Compute the coordinates on this feature layer
+            gxs = target[b][:, 0:1] * in_w
+            gys = target[b][:, 1:2] * in_h
+
+            gws = target[b][:, 2:3] * in_w
+            ghs = target[b][:, 3:4] * in_h
+
+            # Determine which grid cell each box falls in
+            gis = torch.floor(gxs)
+            gjs = torch.floor(gys)
+
+            # Build the ground-truth boxes
+            gt_box = torch.FloatTensor(torch.cat([torch.zeros_like(gws), torch.zeros_like(ghs), gws, ghs], 1))
+
+            # Build all the anchor boxes
+            anchor_shapes = torch.FloatTensor(torch.cat((torch.zeros((self.num_anchors, 2)), torch.FloatTensor(anchors)), 1))
+            # Compute the degree of overlap (IoU)
+            anch_ious = jaccard(gt_box, anchor_shapes)
+
+            # Find the best matching anchor box
+            best_ns = torch.argmax(anch_ious,dim=-1)
+            for i, best_n in enumerate(best_ns):
                 if best_n not in anchor_index:
                     continue
                 # Masks
+                gi = gis[i].long()
+                gj = gjs[i].long()
+                gx = gxs[i]
+                gy = gys[i]
+                gw = gws[i]
+                gh = ghs[i]
+                # Masks
                 if (gj < in_h) and (gi < in_w):
                     best_n = best_n - subtract_index
                     # Mark which anchors really contain an object
                     noobj_mask[b, best_n, gj, gi] = 0
                     mask[b, best_n, gj, gi] = 1
                     # Anchor center offset targets
-                    tx[b, best_n, gj, gi] = gx - gi
-                    ty[b, best_n, gj, gi] = gy - gj
+                    tx[b, best_n, gj, gi] = gx - gi.float()
+                    ty[b, best_n, gj, gi] = gy - gj.float()
                     # Anchor width/height adjustment targets
                     tw[b, best_n, gj, gi] = math.log(gw / anchors[best_n+subtract_index][0])
                     th[b, best_n, gj, gi] = math.log(gh / anchors[best_n+subtract_index][1])
                     # Used to obtain the xywh scale
-                    box_loss_scale_x[b, best_n, gj, gi] = target[b][t, 2]
-                    box_loss_scale_y[b, best_n, gj, gi] = target[b][t, 3]
+                    box_loss_scale_x[b, best_n, gj, gi] = target[b][i, 2]
+                    box_loss_scale_y[b, best_n, gj, gi] = target[b][i, 3]
                     # Object confidence
                     tconf[b, best_n, gj, gi] = 1
                     # Class
-                    tcls[b, best_n, gj, gi, int(target[b][t, 4])] = 1
+                    tcls[b, best_n, gj, gi, int(target[b][i, 4])] = 1
                 else:
                     print('Step {0} out of bound'.format(b))
                     print('gj: {0}, height: {1} | gi: {2}, width: {3}'.format(gj, in_h, gi, in_w))
@@ -245,18 +253,17 @@ class YOLOLoss(nn.Module):
         for i in range(bs):
             pred_boxes_for_ignore = pred_boxes[i]
             pred_boxes_for_ignore = pred_boxes_for_ignore.view(-1, 4)
-
             if len(target[i]) > 0:
                 gx = target[i][:, 0:1] * in_w
                 gy = target[i][:, 1:2] * in_h
                 gw = target[i][:, 2:3] * in_w
                 gh = target[i][:, 3:4] * in_h
-                gt_box = torch.FloatTensor(np.concatenate([gx, gy, gw, gh],-1)).type(FloatTensor)
+                gt_box = torch.FloatTensor(torch.cat([gx, gy, gw, gh],-1)).type(FloatTensor)

                 anch_ious = jaccard(gt_box, pred_boxes_for_ignore)
-                for t in range(target[i].shape[0]):
-                    anch_iou = anch_ious[t].view(pred_boxes[i].size()[:3])
-                    noobj_mask[i][anch_iou>self.ignore_threshold] = 0
+                anch_ious_max, _ = torch.max(anch_ious,dim=0)
+                anch_ious_max = anch_ious_max.view(pred_boxes[i].size()[:3])
+                noobj_mask[i][anch_ious_max>self.ignore_threshold] = 0
                 # print(torch.max(anch_ious))
         return noobj_mask

--
GitLab
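Note on the first hunk: the per-target Python loop that matched each ground-truth box to one anchor at a time is replaced by a single IoU matrix followed by torch.argmax over the anchor dimension. Below is a minimal standalone sketch of that matching step, not code taken from this repository; the helper name best_anchor_per_box and the sample tensors are illustrative assumptions.

import torch

def best_anchor_per_box(gt_wh: torch.Tensor, anchor_wh: torch.Tensor) -> torch.Tensor:
    """gt_wh: [N, 2] box widths/heights, anchor_wh: [A, 2] anchor widths/heights (grid units)."""
    # With both boxes anchored at the origin, the intersection area is min(w) * min(h).
    inter = torch.min(gt_wh[:, None, :], anchor_wh[None, :, :]).prod(-1)   # [N, A]
    union = gt_wh.prod(-1)[:, None] + anchor_wh.prod(-1)[None, :] - inter  # [N, A]
    iou = inter / union
    # Index of the best-matching anchor for every ground-truth box.
    return torch.argmax(iou, dim=-1)  # [N]

# Example: three ground-truth boxes matched against three anchors.
gt = torch.tensor([[1.2, 2.0], [4.0, 3.5], [9.0, 8.0]])
anchors = torch.tensor([[1.0, 2.0], [4.0, 4.0], [10.0, 9.0]])
print(best_anchor_per_box(gt, anchors))  # tensor([0, 1, 2])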
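Note on the second hunk: instead of looping over ground-truth boxes and thresholding each IoU row separately, the patch reduces the [num_gt, num_pred] IoU matrix with torch.max over the ground-truth dimension and applies ignore_threshold once. A minimal sketch under assumed shapes follows; the random anch_ious simply stands in for the output of jaccard(gt_box, pred_boxes_for_ignore).

import torch

num_gt, num_anchors, in_h, in_w = 4, 3, 13, 13
ignore_threshold = 0.5

# Stand-in IoU matrix between every ground truth and every predicted box.
anch_ious = torch.rand(num_gt, num_anchors * in_h * in_w)

noobj_mask = torch.ones(num_anchors, in_h, in_w)
best_iou_per_pred, _ = torch.max(anch_ious, dim=0)                   # [A*H*W]
best_iou_per_pred = best_iou_per_pred.view(num_anchors, in_h, in_w)  # [A, H, W]
# Predictions overlapping any ground truth above the threshold are excluded
# from the no-object loss in one vectorized step.
noobj_mask[best_iou_per_pred > ignore_threshold] = 0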