From dbd1a08a6dfd6c2c3c15b848b6720470375ab09d Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Wed, 25 Nov 2020 13:55:29 +0800 Subject: [PATCH] Update yolo_training.py --- nets/yolo_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nets/yolo_training.py b/nets/yolo_training.py index 656a49f..b7b03ca 100644 --- a/nets/yolo_training.py +++ b/nets/yolo_training.py @@ -274,9 +274,9 @@ class YOLOLoss(nn.Module): LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor # 生成网格,先验框中心,网格左上角 - grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_w, 1).repeat( + grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat( int(bs*self.num_anchors/3), 1, 1).view(x.shape).type(FloatTensor) - grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_h, 1).t().repeat( + grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat( int(bs*self.num_anchors/3), 1, 1).view(y.shape).type(FloatTensor) # 生成先验框的宽高 -- GitLab