未验证 提交 93e5188c 编写于 作者: B Bubbliiiing 提交者: GitHub

Update yolo_training.py

上级 daded78f
......@@ -172,16 +172,16 @@ class YOLOLoss(nn.Module):
box_loss_scale = 2 - box_loss_scale_x * box_loss_scale_y
# 计算中心偏移情况的loss,使用BCELoss效果好一些
loss_x = torch.sum(BCELoss(x, tx) / bs * box_loss_scale * mask)
loss_y = torch.sum(BCELoss(y, ty) / bs * box_loss_scale * mask)
loss_x = torch.sum(BCELoss(x, tx) * box_loss_scale * mask)
loss_y = torch.sum(BCELoss(y, ty) * box_loss_scale * mask)
# 计算宽高调整值的loss
loss_w = torch.sum(MSELoss(w, tw) / bs * 0.5 * box_loss_scale * mask)
loss_h = torch.sum(MSELoss(h, th) / bs * 0.5 * box_loss_scale * mask)
loss_w = torch.sum(MSELoss(w, tw) * 0.5 * box_loss_scale * mask)
loss_h = torch.sum(MSELoss(h, th) * 0.5 * box_loss_scale * mask)
# 计算置信度的loss
loss_conf = torch.sum(BCELoss(conf, mask) * mask / bs) + \
torch.sum(BCELoss(conf, mask) * noobj_mask / bs)
loss_conf = torch.sum(BCELoss(conf, mask) * mask) + \
torch.sum(BCELoss(conf, mask) * noobj_mask)
loss_cls = torch.sum(BCELoss(pred_cls[mask == 1], tcls[mask == 1])/bs)
loss_cls = torch.sum(BCELoss(pred_cls[mask == 1], tcls[mask == 1]))
loss = loss_x * self.lambda_xy + loss_y * self.lambda_xy + \
loss_w * self.lambda_wh + loss_h * self.lambda_wh + \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册