diff --git a/nets/yolo_training.py b/nets/yolo_training.py index 656a49f078ad42bff43998470003dda113a7f505..b7b03cad7440583001627aacb7d2f5667f1036d2 100644 --- a/nets/yolo_training.py +++ b/nets/yolo_training.py @@ -274,9 +274,9 @@ class YOLOLoss(nn.Module): LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor # 生成网格,先验框中心,网格左上角 - grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_w, 1).repeat( + grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat( int(bs*self.num_anchors/3), 1, 1).view(x.shape).type(FloatTensor) - grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_h, 1).t().repeat( + grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat( int(bs*self.num_anchors/3), 1, 1).view(y.shape).type(FloatTensor) # 生成先验框的宽高