From 19384f63dea08f32ecedd8e9b56bef10b77fa142 Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Wed, 25 Nov 2020 17:18:55 +0800 Subject: [PATCH] Update utils.py --- utils/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index d92e705..e67dcc3 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -49,10 +49,10 @@ class DecodeBox(nn.Module): FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor - # 生成网格,先验框中心,网格左上角 - grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_width, 1).repeat( + # 生成网格,先验框中心,网格左上角 batch_size,3,13,13 + grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_height, 1).repeat( batch_size * self.num_anchors, 1, 1).view(x.shape).type(FloatTensor) - grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_height, 1).t().repeat( + grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_width, 1).t().repeat( batch_size * self.num_anchors, 1, 1).view(y.shape).type(FloatTensor) # 生成先验框的宽高 -- GitLab