未验证 提交 cb78b709 编写于 作者: B Bubbliiiing 提交者: GitHub

Update yolo_training.py

上级 9c6099ba
......@@ -197,32 +197,39 @@ class YOLOLoss(nn.Module):
box_loss_scale_x = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False)
box_loss_scale_y = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False)
for b in range(bs):
for t in range(target[b].shape[0]):
# 计算出在特征层上的点位
gx = target[b][t, 0] * in_w
gy = target[b][t, 1] * in_h
gw = target[b][t, 2] * in_w
gh = target[b][t, 3] * in_h
# 计算出属于哪个网格
gi = int(gx)
gj = int(gy)
# 计算真实框的位置
gt_box = torch.FloatTensor(np.array([0, 0, gw, gh])).unsqueeze(0)
# 计算出所有先验框的位置
anchor_shapes = torch.FloatTensor(np.concatenate((np.zeros((self.num_anchors, 2)),
np.array(anchors)), 1))
# 计算重合程度
anch_ious = bbox_iou(gt_box, anchor_shapes)
# Find the best matching anchor box
best_n = np.argmax(anch_ious)
if len(target[b])==0:
continue
# 计算出在特征层上的点位
gxs = target[b][:, 0:1] * in_w
gys = target[b][:, 1:2] * in_h
gws = target[b][:, 2:3] * in_w
ghs = target[b][:, 3:4] * in_h
# 计算出属于哪个网格
gis = torch.floor(gxs)
gjs = torch.floor(gys)
# 计算真实框的位置
gt_box = torch.FloatTensor(torch.cat([torch.zeros_like(gws), torch.zeros_like(ghs), gws, ghs], 1))
# 计算出所有先验框的位置
anchor_shapes = torch.FloatTensor(torch.cat((torch.zeros((self.num_anchors, 2)), torch.FloatTensor(anchors)), 1))
# 计算重合程度
anch_ious = jaccard(gt_box, anchor_shapes)
# Find the best matching anchor box
best_ns = torch.argmax(anch_ious,dim=-1)
for i, best_n in enumerate(best_ns):
if best_n not in anchor_index:
continue
# Masks
gi = gis[i].long()
gj = gjs[i].long()
gx = gxs[i]
gy = gys[i]
gw = gws[i]
gh = ghs[i]
if (gj < in_h) and (gi < in_w):
best_n = best_n - subtract_index
# 判定哪些先验框内部真实的存在物体
......@@ -235,12 +242,12 @@ class YOLOLoss(nn.Module):
tw[b, best_n, gj, gi] = gw
th[b, best_n, gj, gi] = gh
# 用于获得xywh的比例
box_loss_scale_x[b, best_n, gj, gi] = target[b][t, 2]
box_loss_scale_y[b, best_n, gj, gi] = target[b][t, 3]
box_loss_scale_x[b, best_n, gj, gi] = target[b][i, 2]
box_loss_scale_y[b, best_n, gj, gi] = target[b][i, 3]
# 物体置信度
tconf[b, best_n, gj, gi] = 1
# 种类
tcls[b, best_n, gj, gi, int(target[b][t, 4])] = 1
tcls[b, best_n, gj, gi, target[b][i, 4].long()] = 1
else:
print('Step {0} out of bound'.format(b))
print('gj: {0}, height: {1} | gi: {2}, width: {3}'.format(gj, in_h, gi, in_w))
......@@ -251,6 +258,7 @@ class YOLOLoss(nn.Module):
t_box[...,3] = th
return mask, noobj_mask, t_box, tconf, tcls, box_loss_scale_x, box_loss_scale_y
def get_ignore(self,prediction,target,scaled_anchors,in_w, in_h,noobj_mask):
bs = len(target)
anchor_index = [[0,1,2],[3,4,5],[6,7,8]][self.feature_length.index(in_w)]
......@@ -292,12 +300,12 @@ class YOLOLoss(nn.Module):
gy = target[i][:, 1:2] * in_h
gw = target[i][:, 2:3] * in_w
gh = target[i][:, 3:4] * in_h
gt_box = torch.FloatTensor(np.concatenate([gx, gy, gw, gh],-1)).type(FloatTensor)
gt_box = torch.FloatTensor(torch.cat([gx, gy, gw, gh],-1)).type(FloatTensor)
anch_ious = jaccard(gt_box, pred_boxes_for_ignore)
for t in range(target[i].shape[0]):
anch_iou = anch_ious[t].view(pred_boxes[i].size()[:3])
noobj_mask[i][anch_iou>self.ignore_threshold] = 0
anch_ious_max, _ = torch.max(anch_ious,dim=0)
anch_ious_max = anch_ious_max.view(pred_boxes[i].size()[:3])
noobj_mask[i][anch_ious_max>self.ignore_threshold] = 0
return noobj_mask, pred_boxes
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册