diff --git a/nets/yolo_training.py b/nets/yolo_training.py index dc068b486f517ba4f3a78b6139282e40f28630f3..25b14405e6109b02b8b1886c9ae9507e8e9c6951 100644 --- a/nets/yolo_training.py +++ b/nets/yolo_training.py @@ -231,9 +231,9 @@ class YOLOLoss(nn.Module): LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor # 生成网格,先验框中心,网格左上角 - grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_w, 1).repeat( + grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat( int(bs*self.num_anchors/3), 1, 1).view(x.shape).type(FloatTensor) - grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_h, 1).t().repeat( + grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat( int(bs*self.num_anchors/3), 1, 1).view(y.shape).type(FloatTensor) # 生成先验框的宽高