未验证 提交 2eec48c8 编写于 作者: B Bubbliiiing 提交者: GitHub

Update yolo_training.py

上级 e90017f0
...@@ -155,51 +155,59 @@ class YOLOLoss(nn.Module): ...@@ -155,51 +155,59 @@ class YOLOLoss(nn.Module):
box_loss_scale_x = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False) box_loss_scale_x = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False)
box_loss_scale_y = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False) box_loss_scale_y = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False)
for b in range(bs): for b in range(bs):
for t in range(target[b].shape[0]): if len(target[b])==0:
# 计算出在特征层上的点位 continue
gx = target[b][t, 0] * in_w # 计算出在特征层上的点位
gy = target[b][t, 1] * in_h gxs = target[b][:, 0:1] * in_w
gys = target[b][:, 1:2] * in_h
gw = target[b][t, 2] * in_w
gh = target[b][t, 3] * in_h gws = target[b][:, 2:3] * in_w
ghs = target[b][:, 3:4] * in_h
# 计算出属于哪个网格
gi = int(gx) # 计算出属于哪个网格
gj = int(gy) gis = torch.floor(gxs)
gjs = torch.floor(gys)
# 计算真实框的位置
gt_box = torch.FloatTensor(np.array([0, 0, gw, gh])).unsqueeze(0) # 计算真实框的位置
gt_box = torch.FloatTensor(torch.cat([torch.zeros_like(gws), torch.zeros_like(ghs), gws, ghs], 1))
# 计算出所有先验框的位置
anchor_shapes = torch.FloatTensor(np.concatenate((np.zeros((self.num_anchors, 2)), # 计算出所有先验框的位置
np.array(anchors)), 1)) anchor_shapes = torch.FloatTensor(torch.cat((torch.zeros((self.num_anchors, 2)), torch.FloatTensor(anchors)), 1))
# 计算重合程度 # 计算重合程度
anch_ious = bbox_iou(gt_box, anchor_shapes) anch_ious = jaccard(gt_box, anchor_shapes)
# Find the best matching anchor box # Find the best matching anchor box
best_n = np.argmax(anch_ious) best_ns = torch.argmax(anch_ious,dim=-1)
for i, best_n in enumerate(best_ns):
if best_n not in anchor_index: if best_n not in anchor_index:
continue continue
# Masks # Masks
gi = gis[i].long()
gj = gjs[i].long()
gx = gxs[i]
gy = gys[i]
gw = gws[i]
gh = ghs[i]
# Masks
if (gj < in_h) and (gi < in_w): if (gj < in_h) and (gi < in_w):
best_n = best_n - subtract_index best_n = best_n - subtract_index
# 判定哪些先验框内部真实的存在物体 # 判定哪些先验框内部真实的存在物体
noobj_mask[b, best_n, gj, gi] = 0 noobj_mask[b, best_n, gj, gi] = 0
mask[b, best_n, gj, gi] = 1 mask[b, best_n, gj, gi] = 1
# 计算先验框中心调整参数 # 计算先验框中心调整参数
tx[b, best_n, gj, gi] = gx - gi tx[b, best_n, gj, gi] = gx - gi.float()
ty[b, best_n, gj, gi] = gy - gj ty[b, best_n, gj, gi] = gy - gj.float()
# 计算先验框宽高调整参数 # 计算先验框宽高调整参数
tw[b, best_n, gj, gi] = math.log(gw / anchors[best_n+subtract_index][0]) tw[b, best_n, gj, gi] = math.log(gw / anchors[best_n+subtract_index][0])
th[b, best_n, gj, gi] = math.log(gh / anchors[best_n+subtract_index][1]) th[b, best_n, gj, gi] = math.log(gh / anchors[best_n+subtract_index][1])
# 用于获得xywh的比例 # 用于获得xywh的比例
box_loss_scale_x[b, best_n, gj, gi] = target[b][t, 2] box_loss_scale_x[b, best_n, gj, gi] = target[b][i, 2]
box_loss_scale_y[b, best_n, gj, gi] = target[b][t, 3] box_loss_scale_y[b, best_n, gj, gi] = target[b][i, 3]
# 物体置信度 # 物体置信度
tconf[b, best_n, gj, gi] = 1 tconf[b, best_n, gj, gi] = 1
# 种类 # 种类
tcls[b, best_n, gj, gi, int(target[b][t, 4])] = 1 tcls[b, best_n, gj, gi, int(target[b][i, 4])] = 1
else: else:
print('Step {0} out of bound'.format(b)) print('Step {0} out of bound'.format(b))
print('gj: {0}, height: {1} | gi: {2}, width: {3}'.format(gj, in_h, gi, in_w)) print('gj: {0}, height: {1} | gi: {2}, width: {3}'.format(gj, in_h, gi, in_w))
...@@ -245,18 +253,17 @@ class YOLOLoss(nn.Module): ...@@ -245,18 +253,17 @@ class YOLOLoss(nn.Module):
for i in range(bs): for i in range(bs):
pred_boxes_for_ignore = pred_boxes[i] pred_boxes_for_ignore = pred_boxes[i]
pred_boxes_for_ignore = pred_boxes_for_ignore.view(-1, 4) pred_boxes_for_ignore = pred_boxes_for_ignore.view(-1, 4)
if len(target[i]) > 0: if len(target[i]) > 0:
gx = target[i][:, 0:1] * in_w gx = target[i][:, 0:1] * in_w
gy = target[i][:, 1:2] * in_h gy = target[i][:, 1:2] * in_h
gw = target[i][:, 2:3] * in_w gw = target[i][:, 2:3] * in_w
gh = target[i][:, 3:4] * in_h gh = target[i][:, 3:4] * in_h
gt_box = torch.FloatTensor(np.concatenate([gx, gy, gw, gh],-1)).type(FloatTensor) gt_box = torch.FloatTensor(torch.cat([gx, gy, gw, gh],-1)).type(FloatTensor)
anch_ious = jaccard(gt_box, pred_boxes_for_ignore) anch_ious = jaccard(gt_box, pred_boxes_for_ignore)
for t in range(target[i].shape[0]): anch_ious_max, _ = torch.max(anch_ious,dim=0)
anch_iou = anch_ious[t].view(pred_boxes[i].size()[:3]) anch_ious_max = anch_ious_max.view(pred_boxes[i].size()[:3])
noobj_mask[i][anch_iou>self.ignore_threshold] = 0 noobj_mask[i][anch_ious_max>self.ignore_threshold] = 0
# print(torch.max(anch_ious)) # print(torch.max(anch_ious))
return noobj_mask return noobj_mask
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册