From 5608758a394881e4ce6c1fe707f5d6dfe64e3b5c Mon Sep 17 00:00:00 2001 From: Wang Feng Date: Wed, 1 Jul 2020 15:13:06 +0800 Subject: [PATCH] feat(detection): Add Faster-RCNN in detection (#28) --- .github/workflows/ci.yml | 4 - README.md | 7 +- hubconf.py | 7 +- official/vision/detection/README.md | 42 ++- .../faster_rcnn_fpn_res50_coco_1x_800size.py | 29 ++ ...r_rcnn_fpn_res50_coco_1x_800size_syncbn.py | 38 +++ official/vision/detection/layers/basic/nn.py | 2 +- .../vision/detection/layers/basic/norm.py | 2 +- .../vision/detection/layers/det/__init__.py | 3 + .../vision/detection/layers/det/anchor.py | 21 +- .../vision/detection/layers/det/box_utils.py | 32 +- official/vision/detection/layers/det/fpn.py | 39 ++- official/vision/detection/layers/det/loss.py | 23 +- .../vision/detection/layers/det/pooler.py | 63 ++++ official/vision/detection/layers/det/rcnn.py | 175 +++++++++++ official/vision/detection/layers/det/rpn.py | 290 ++++++++++++++++++ official/vision/detection/models/__init__.py | 1 + .../detection/models/faster_rcnn_fpn.py | 212 +++++++++++++ official/vision/detection/models/retinanet.py | 9 +- .../retinanet_res50_coco_1x_800size.py | 2 + .../retinanet_res50_coco_1x_800size_syncbn.py | 4 +- .../retinanet_res50_objects365_1x_800size.py | 2 - official/vision/detection/tools/gpu_nms.py | 98 ++++++ .../vision/detection/tools/gpu_nms/nms.cu | 201 ++++++++++++ official/vision/detection/tools/train.py | 16 +- 25 files changed, 1245 insertions(+), 77 deletions(-) create mode 100644 official/vision/detection/faster_rcnn_fpn_res50_coco_1x_800size.py create mode 100644 official/vision/detection/faster_rcnn_fpn_res50_coco_1x_800size_syncbn.py create mode 100644 official/vision/detection/layers/det/pooler.py create mode 100644 official/vision/detection/layers/det/rcnn.py create mode 100644 official/vision/detection/layers/det/rpn.py create mode 100644 official/vision/detection/models/faster_rcnn_fpn.py create mode 100644 official/vision/detection/tools/gpu_nms.py create mode 100644 official/vision/detection/tools/gpu_nms/nms.cu diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e4da0dd..04d6200 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,7 +47,3 @@ jobs: exit $pylint_ret fi echo "All lint steps passed!" - - - name: Import hubconf check - run: | - python -c "import hubconf" diff --git a/README.md b/README.md index 09b9cbc..92bfcae 100644 --- a/README.md +++ b/README.md @@ -75,9 +75,10 @@ export PYTHONPATH=/path/to/models:$PYTHONPATH 目标检测同样是计算机视觉中的常见任务,我们提供了一个经典的目标检测模型[retinanet](./official/vision/detection),这个模型在**COCO验证集**上的测试结果如下: -| 模型 | mAP
@5-95 | -| :---: | :---: | -| retinanet-res50-1x-800size | 36.0 | +| 模型 | mAP
@5-95 | +| :---: | :---: | +| retinanet-res50-1x-800size | 36.0 | +| faster-rcnn-fpn-res50-1x-800size | 37.3 | ### 图像分割 diff --git a/hubconf.py b/hubconf.py index ab6e6c7..fc21310 100644 --- a/hubconf.py +++ b/hubconf.py @@ -28,10 +28,15 @@ from official.nlp.bert.model import ( wwm_cased_L_24_H_1024_A_16, ) +from official.vision.detection.faster_rcnn_fpn_res50_coco_1x_800size import ( + faster_rcnn_fpn_res50_coco_1x_800size, +) + from official.vision.detection.retinanet_res50_coco_1x_800size import ( retinanet_res50_coco_1x_800size, ) -from official.vision.detection.models import RetinaNet + +from official.vision.detection.models import FasterRCNN, RetinaNet from official.vision.detection.tools.test import DetEvaluator from official.vision.segmentation.deeplabv3plus import ( diff --git a/official/vision/detection/README.md b/official/vision/detection/README.md index 4b0481d..ea78f84 100644 --- a/official/vision/detection/README.md +++ b/official/vision/detection/README.md @@ -1,20 +1,21 @@ -# Megengine RetinaNet +# Megengine Detection Models ## 介绍 -本目录包含了采用MegEngine实现的经典[RetinaNet](https://arxiv.org/pdf/1708.02002>)网络结构,同时提供了在COCO2017数据集上的完整训练和测试代码。 +本目录包含了采用MegEngine实现的经典网络结构,包括[RetinaNet](https://arxiv.org/pdf/1708.02002>)、[Faster R-CNN with FPN](https://arxiv.org/pdf/1612.03144.pdf)等,同时提供了在COCO2017数据集上的完整训练和测试代码。 -网络的性能在COCO2017验证集上的测试结果如下: +网络的性能在COCO2017数据集上的测试结果如下: -| 模型 | mAP
@5-95 | batch
/gpu | gpu | speed
(8gpu) | speed
(1gpu) | -| --- | --- | --- | --- | --- | --- | -| retinanet-res50-coco-1x-800size | 36.0 | 2 | 2080ti | 2.27(it/s) | 3.7(it/s) | +| 模型 | mAP
@5-95 | batch
/gpu | gpu | trainging speed
(8gpu) | training speed
(1gpu) | +| --- | --- | --- | --- | --- | --- | +| retinanet-res50-coco-1x-800size | 36.0 | 2 | 2080Ti | 2.27(it/s) | 3.7(it/s) | +| faster-rcnn-fpn-res50-coco-1x-800size | 37.3 | 2 | 2080Ti | 1.9(it/s) | 3.1(it/s) | * MegEngine v0.4.0 ## 如何使用 -模型训练好之后,可以通过如下命令测试单张图片: +以RetinaNet为例,模型训练好之后,可以通过如下命令测试单张图片: ```bash python3 tools/inference.py -f retinanet_res50_coco_1x_800size.py \ @@ -60,17 +61,33 @@ python3 tools/train.py -f retinanet_res50_coco_1x_800size.py \ `tools/train.py`提供了灵活的命令行选项,包括: -- `-f`, 所需要训练的网络结构描述文件。 +- `-f`, 所需要训练的网络结构描述文件。可以是RetinaNet、Faster R-CNN等. - `-n`, 用于训练的devices(gpu)数量,默认使用所有可用的gpu. - `-w`, 预训练的backbone网络权重的路径。 - `--batch_size`,训练时采用的`batch size`, 默认2,表示每张卡训2张图。 - `--dataset-dir`, COCO2017数据集的上级目录,默认`/data/datasets`。 -默认情况下模型会存在 `log-of-retinanet_res50_1x_800size`目录下。 +默认情况下模型会存在 `log-of-模型名`目录下。 + +5. 编译可能需要的lib + +GPU NMS位于tools下的GPU NMS文件夹下面,我们需要进入tools文件夹下进行编译. + +首先需要找到MegEngine编译的头文件所在路径,可以通过命令 + +```bash +python3 -c "import megengine as mge; print(mge.__file__)" +``` +将输出结果中__init__.py之前的部分复制(以MegEngine结尾),将其赋值给shell变量MGE,接下来,运行如下命令进行编译。 + +```bash +cd tools +nvcc -I $MGE/_internal/include -shared -o lib_nms.so -Xcompiler "-fno-strict-aliasing -fPIC" gpu_nms/nms.cu +``` ## 如何测试 -在训练的过程中,可以通过如下命令测试模型在`COCO2017`验证集的性能: +在得到训练完保存的模型之后,可以通过tools下的test.py文件测试模型在`COCO2017`验证集的性能: ```bash python3 tools/test.py -f retinanet_res50_coco_1x_800size.py \ @@ -89,5 +106,6 @@ python3 tools/test.py -f retinanet_res50_coco_1x_800size.py \ ## 参考文献 - [Focal Loss for Dense Object Detection](https://arxiv.org/pdf/1708.02002) Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár. Proceedings of the IEEE international conference on computer vision. 2017: 2980-2988. -- [Microsoft COCO: Common Objects in Context](https://arxiv.org/pdf/1405.0312.pdf) Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Dollár, Piotr and Zitnick, C Lawrence -Lin T Y, Maire M, Belongie S, et al. European conference on computer vision. Springer, Cham, 2014: 740-755. +- [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks](https://arxiv.org/pdf/1506.01497.pdf) S. Ren, K. He, R. Girshick, and J. Sun. In: Neural Information Processing Systems(NIPS)(2015). +- [Feature Pyramid Networks for Object Detection](https://arxiv.org/pdf/1612.03144.pdf) T. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan and S. Belongie. 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Honolulu, HI, 2017, pp. 936-944, doi: 10.1109/CVPR.2017.106. +- [Microsoft COCO: Common Objects in Context](https://arxiv.org/pdf/1405.0312.pdf) Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Dollár, Piotr and Zitnick, C Lawrence, Lin T Y, Maire M, Belongie S, et al. European conference on computer vision. Springer, Cham, 2014: 740-755. diff --git a/official/vision/detection/faster_rcnn_fpn_res50_coco_1x_800size.py b/official/vision/detection/faster_rcnn_fpn_res50_coco_1x_800size.py new file mode 100644 index 0000000..c085b07 --- /dev/null +++ b/official/vision/detection/faster_rcnn_fpn_res50_coco_1x_800size.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from megengine import hub + +from official.vision.detection import models + + +@hub.pretrained( + "https://data.megengine.org.cn/models/weights/" + "faster_rcnn_fpn_ec2e80b9_res50_1x_800size_37dot3.pkl" +) +def faster_rcnn_fpn_res50_coco_1x_800size(batch_size=1, **kwargs): + r""" + Faster-RCNN FPN trained from COCO dataset. + `"Faster-RCNN" `_ + `"FPN" `_ + `"COCO" `_ + """ + return models.FasterRCNN(models.FasterRCNNConfig(), batch_size=batch_size, **kwargs) + + +Net = models.FasterRCNN +Cfg = models.FasterRCNNConfig diff --git a/official/vision/detection/faster_rcnn_fpn_res50_coco_1x_800size_syncbn.py b/official/vision/detection/faster_rcnn_fpn_res50_coco_1x_800size_syncbn.py new file mode 100644 index 0000000..c557d46 --- /dev/null +++ b/official/vision/detection/faster_rcnn_fpn_res50_coco_1x_800size_syncbn.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from megengine import hub + +from official.vision.detection import models + + +class CustomFasterRCNNFPNConfig(models.FasterRCNNConfig): + def __init__(self): + super().__init__() + + self.resnet_norm = "SyncBN" + self.fpn_norm = "SyncBN" + + +@hub.pretrained( + "https://data.megengine.org.cn/models/weights/" + "faster_rcnn_fpn_cf5c020b_res50_1x_800size_syncbn_37dot6.pkl" +) +def faster_rcnn_fpn_res50_coco_1x_800size_syncbn(batch_size=1, **kwargs): + r""" + Faster-RCNN FPN trained from COCO dataset. + `"Faster-RCNN" `_ + `"FPN" `_ + `"COCO" `_ + `"SyncBN" `_ + """ + return models.FasterRCNN(CustomFasterRCNNFPNConfig(), batch_size=batch_size, **kwargs) + + +Net = models.FasterRCNN +Cfg = CustomFasterRCNNFPNConfig diff --git a/official/vision/detection/layers/basic/nn.py b/official/vision/detection/layers/basic/nn.py index bc2d98c..349df52 100644 --- a/official/vision/detection/layers/basic/nn.py +++ b/official/vision/detection/layers/basic/nn.py @@ -22,7 +22,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # # This file has been modified by Megvii ("Megvii Modifications"). -# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. +# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. # --------------------------------------------------------------------- from collections import namedtuple diff --git a/official/vision/detection/layers/basic/norm.py b/official/vision/detection/layers/basic/norm.py index 5d0463a..43a0917 100644 --- a/official/vision/detection/layers/basic/norm.py +++ b/official/vision/detection/layers/basic/norm.py @@ -22,7 +22,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # # This file has been modified by Megvii ("Megvii Modifications"). -# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. +# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. # --------------------------------------------------------------------- import megengine.module as M import numpy as np diff --git a/official/vision/detection/layers/det/__init__.py b/official/vision/detection/layers/det/__init__.py index 33d29ce..9dfcc16 100644 --- a/official/vision/detection/layers/det/__init__.py +++ b/official/vision/detection/layers/det/__init__.py @@ -10,7 +10,10 @@ from .anchor import * from .box_utils import * from .fpn import * from .loss import * +from .pooler import * +from .rcnn import * from .retinanet import * +from .rpn import * _EXCLUDE = {} __all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")] diff --git a/official/vision/detection/layers/det/anchor.py b/official/vision/detection/layers/det/anchor.py index 4ceb756..505553f 100644 --- a/official/vision/detection/layers/det/anchor.py +++ b/official/vision/detection/layers/det/anchor.py @@ -1,18 +1,4 @@ # -*- coding: utf-8 -*- -# Copyright 2018-2019 Open-MMLab. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# --------------------------------------------------------------------- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -20,10 +6,6 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# -# This file has been modified by Megvii ("Megvii Modifications"). -# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. -# --------------------------------------------------------------------- from abc import ABCMeta, abstractmethod import megengine.functional as F @@ -132,8 +114,7 @@ class DefaultAnchorGenerator(BaseAnchorGenerator): [flatten_shift_x, flatten_shift_y, flatten_shift_x, flatten_shift_y, ], axis=1, ) - if self.offset > 0: - centers = centers + self.offset * stride + centers = centers + self.offset * self.base_size return centers def get_anchors_by_feature(self, featmap, stride): diff --git a/official/vision/detection/layers/det/box_utils.py b/official/vision/detection/layers/det/box_utils.py index aa0ec10..4d03f5e 100644 --- a/official/vision/detection/layers/det/box_utils.py +++ b/official/vision/detection/layers/det/box_utils.py @@ -112,12 +112,12 @@ class BoxCoder(BoxCoderBase, metaclass=ABCMeta): pred_y2 = pred_ctr_y + 0.5 * pred_height pred_box = self._concat_new_axis(pred_x1, pred_y1, pred_x2, pred_y2, 2) - pred_box = pred_box.reshape(pred_box.shape[0], -1) + pred_box = pred_box.reshape(pred_box.shapeof(0), -1) return pred_box -def get_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: +def get_iou(boxes1: Tensor, boxes2: Tensor, return_ignore=False) -> Tensor: """ Given two lists of boxes of size N and M, compute the IoU (intersection over union) @@ -132,10 +132,10 @@ def get_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: """ box = boxes1 gt = boxes2 - target_shape = (boxes1.shape[0], boxes2.shapeof()[0], 4) + target_shape = (boxes1.shapeof(0), boxes2.shapeof(0), 4) b_box = F.add_axis(boxes1, 1).broadcast(*target_shape) - b_gt = F.add_axis(boxes2, 0).broadcast(*target_shape) + b_gt = F.add_axis(boxes2[:, :4], 0).broadcast(*target_shape) iw = F.minimum(b_box[:, :, 2], b_gt[:, :, 2]) - F.maximum( b_box[:, :, 0], b_gt[:, :, 0] @@ -148,7 +148,7 @@ def get_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: area_box = (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]) area_gt = (gt[:, 2] - gt[:, 0]) * (gt[:, 3] - gt[:, 1]) - area_target_shape = (box.shape[0], gt.shapeof()[0]) + area_target_shape = (box.shapeof(0), gt.shapeof(0)) b_area_box = F.add_axis(area_box, 1).broadcast(*area_target_shape) b_area_gt = F.add_axis(area_gt, 0).broadcast(*area_target_shape) @@ -156,20 +156,34 @@ def get_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: union = b_area_box + b_area_gt - inter overlaps = F.maximum(inter / union, 0) + if return_ignore: + overlaps_ignore = F.maximum(inter / b_area_box, 0) + gt_ignore_mask = F.add_axis((gt[:, 4] == -1), 0).broadcast(*area_target_shape) + overlaps *= (1 - gt_ignore_mask) + overlaps_ignore *= gt_ignore_mask + return overlaps, overlaps_ignore + return overlaps def get_clipped_box(boxes, hw): """ Clip the boxes into the image region.""" # x1 >=0 - box_x1 = F.maximum(F.minimum(boxes[:, 0::4], hw[1]), 0) + box_x1 = F.clamp(boxes[:, 0::4], lower=0, upper=hw[1]) # y1 >=0 - box_y1 = F.maximum(F.minimum(boxes[:, 1::4], hw[0]), 0) + box_y1 = F.clamp(boxes[:, 1::4], lower=0, upper=hw[0]) # x2 < im_info[1] - box_x2 = F.maximum(F.minimum(boxes[:, 2::4], hw[1]), 0) + box_x2 = F.clamp(boxes[:, 2::4], lower=0, upper=hw[1]) # y2 < im_info[0] - box_y2 = F.maximum(F.minimum(boxes[:, 3::4], hw[0]), 0) + box_y2 = F.clamp(boxes[:, 3::4], lower=0, upper=hw[0]) clip_box = F.concat([box_x1, box_y1, box_x2, box_y2], axis=1) return clip_box + + +def filter_boxes(boxes, size=0): + width = boxes[:, 2] - boxes[:, 0] + height = boxes[:, 3] - boxes[:, 1] + keep = (width > size) * (height > size) + return keep diff --git a/official/vision/detection/layers/det/fpn.py b/official/vision/detection/layers/det/fpn.py index 3978eac..b77f35c 100644 --- a/official/vision/detection/layers/det/fpn.py +++ b/official/vision/detection/layers/det/fpn.py @@ -22,7 +22,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # # This file has been modified by Megvii ("Megvii Modifications"). -# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. +# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. # --------------------------------------------------------------------- import math from typing import List @@ -47,6 +47,8 @@ class FPN(M.Module): out_channels: int = 256, norm: str = "", top_block: M.Module = None, + strides=[8, 16, 32], + channels=[512, 1024, 2048], ): """ Args: @@ -63,8 +65,8 @@ class FPN(M.Module): """ super(FPN, self).__init__() - in_strides = [8, 16, 32] - in_channels = [512, 1024, 2048] + in_strides = strides + in_channels = channels use_bias = norm == "" self.lateral_convs = list() @@ -148,33 +150,50 @@ class FPN(M.Module): top_block_in_feature = results[ self._out_features.index(self.top_block.in_feature) ] - results.extend(self.top_block(top_block_in_feature, results[-1])) + results.extend(self.top_block(top_block_in_feature)) return dict(zip(self._out_features, results)) def output_shape(self): return { - name: layers.ShapeSpec(channels=self._out_feature_channels[name],) + name: layers.ShapeSpec( + channels=self._out_feature_channels[name], + stride=self._out_feature_strides[name], + ) for name in self._out_features } +class FPNP6(M.Module): + """ + used in FPN, generate a downsampled P6 feature from P5. + """ + + def __init__(self, in_feature="p5"): + super().__init__() + self.num_levels = 1 + self.in_feature = in_feature + + def forward(self, x): + return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)] + + class LastLevelP6P7(M.Module): """ This module is used in RetinaNet to generate extra layers, P6 and P7 from C5 feature. """ - def __init__(self, in_channels: int, out_channels: int): + def __init__(self, in_channels: int, out_channels: int, in_feature="res5"): super().__init__() self.num_levels = 2 - self.in_feature = "res5" + if in_feature == "p5": + assert in_channels == out_channels + self.in_feature = in_feature self.p6 = M.Conv2d(in_channels, out_channels, 3, 2, 1) self.p7 = M.Conv2d(out_channels, out_channels, 3, 2, 1) - self.use_P5 = in_channels == out_channels - def forward(self, c5, p5=None): - x = p5 if self.use_P5 else c5 + def forward(self, x): p6 = self.p6(x) p7 = self.p7(F.relu(p6)) return [p6, p7] diff --git a/official/vision/detection/layers/det/loss.py b/official/vision/detection/layers/det/loss.py index 1d48adb..3fda1e4 100644 --- a/official/vision/detection/layers/det/loss.py +++ b/official/vision/detection/layers/det/loss.py @@ -6,11 +6,9 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import megengine as mge import megengine.functional as F -import numpy as np -from megengine.core import tensor, Tensor +from megengine.core import Tensor def get_focal_loss( @@ -112,7 +110,8 @@ def get_smooth_l1_loss( if norm_type == "fg": loss = (losses.sum(axis=1) * fg_mask).sum() / F.maximum(fg_mask.sum(), 1) elif norm_type == "all": - raise NotImplementedError + all_mask = (label != ignore_label) + loss = (losses.sum(axis=1) * fg_mask).sum() / F.maximum(all_mask.sum(), 1) else: raise NotImplementedError @@ -151,5 +150,19 @@ def get_smooth_l1_base( abs_x = F.abs(x) in_loss = 0.5 * x ** 2 * sigma2 out_loss = abs_x - 0.5 / sigma2 - loss = F.where(abs_x < cond_point, in_loss, out_loss) + + in_mask = abs_x < cond_point + out_mask = 1 - in_mask + loss = in_loss * in_mask + out_loss * out_mask + return loss + + +def softmax_loss(score, label, ignore_label=-1): + max_score = F.zero_grad(score.max(axis=1, keepdims=True)) + score -= max_score + log_prob = score - F.log(F.exp(score).sum(axis=1, keepdims=True)) + mask = (label != ignore_label) + vlabel = label * mask + loss = -(F.indexing_one_hot(log_prob, vlabel.astype("int32"), 1) * mask).sum() + loss = loss / F.maximum(mask.sum(), 1) return loss diff --git a/official/vision/detection/layers/det/pooler.py b/official/vision/detection/layers/det/pooler.py new file mode 100644 index 0000000..d415b52 --- /dev/null +++ b/official/vision/detection/layers/det/pooler.py @@ -0,0 +1,63 @@ +# -*- coding:utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import math + +import numpy as np +import megengine as mge +import megengine.functional as F + + +def roi_pool( + rpn_fms, rois, stride, pool_shape, roi_type='roi_align', +): + assert len(stride) == len(rpn_fms) + canonical_level = 4 + canonical_box_size = 224 + min_level = math.log2(stride[0]) + max_level = math.log2(stride[-1]) + + num_fms = len(rpn_fms) + box_area = (rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]) + level_assignments = F.floor( + canonical_level + F.log(box_area.sqrt() / canonical_box_size) / np.log(2) + ) + level_assignments = F.minimum(level_assignments, max_level) + level_assignments = F.maximum(level_assignments, min_level) + level_assignments = level_assignments - min_level + + # avoid empty assignment + level_assignments = F.concat( + [level_assignments, mge.tensor(np.arange(num_fms, dtype=np.int32))], + ) + rois = F.concat([rois, mge.zeros((num_fms, rois.shapeof(-1)))]) + + pool_list, inds_list = [], [] + for i in range(num_fms): + mask = (level_assignments == i) + _, inds = F.cond_take(mask == 1, mask) + level_rois = rois.ai[inds] + if roi_type == 'roi_pool': + pool_fm = F.roi_pooling( + rpn_fms[i], level_rois, pool_shape, + mode='max', scale=1.0/stride[i] + ) + elif roi_type == 'roi_align': + pool_fm = F.roi_align( + rpn_fms[i], level_rois, pool_shape, mode='average', + spatial_scale=1.0/stride[i], sample_points=2, aligned=True + ) + pool_list.append(pool_fm) + inds_list.append(inds) + + fm_order = F.concat(inds_list, axis=0) + fm_order = F.argsort(fm_order.reshape(1, -1))[1].reshape(-1) + pool_feature = F.concat(pool_list, axis=0) + pool_feature = pool_feature.ai[fm_order][:-num_fms] + + return pool_feature diff --git a/official/vision/detection/layers/det/rcnn.py b/official/vision/detection/layers/det/rcnn.py new file mode 100644 index 0000000..a52edcf --- /dev/null +++ b/official/vision/detection/layers/det/rcnn.py @@ -0,0 +1,175 @@ +# -*- coding:utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import megengine as mge +import megengine.functional as F +import megengine.module as M + +from official.vision.detection import layers + + +class RCNN(M.Module): + + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.box_coder = layers.BoxCoder( + reg_mean=cfg.bbox_normalize_means, + reg_std=cfg.bbox_normalize_stds + ) + + # roi head + self.in_features = cfg.rcnn_in_features + self.stride = cfg.rcnn_stride + self.pooling_method = cfg.pooling_method + self.pooling_size = cfg.pooling_size + + self.fc1 = M.Linear(256 * self.pooling_size[0] * self.pooling_size[1], 1024) + self.fc2 = M.Linear(1024, 1024) + for l in [self.fc1, self.fc2]: + M.init.normal_(l.weight, std=0.01) + M.init.fill_(l.bias, 0) + + # box predictor + self.pred_cls = M.Linear(1024, cfg.num_classes + 1) + self.pred_delta = M.Linear(1024, (cfg.num_classes + 1) * 4) + M.init.normal_(self.pred_cls.weight, std=0.01) + M.init.normal_(self.pred_delta.weight, std=0.001) + for l in [self.pred_cls, self.pred_delta]: + M.init.fill_(l.bias, 0) + + def forward(self, fpn_fms, rcnn_rois, im_info=None, gt_boxes=None): + rcnn_rois, labels, bbox_targets = self.get_ground_truth(rcnn_rois, im_info, gt_boxes) + + fpn_fms = [fpn_fms[x] for x in self.in_features] + pool_features = layers.roi_pool( + fpn_fms, rcnn_rois, self.stride, + self.pooling_size, self.pooling_method, + ) + flatten_feature = F.flatten(pool_features, start_axis=1) + roi_feature = F.relu(self.fc1(flatten_feature)) + roi_feature = F.relu(self.fc2(roi_feature)) + pred_cls = self.pred_cls(roi_feature) + pred_delta = self.pred_delta(roi_feature) + + if self.training: + # loss for classification + loss_rcnn_cls = layers.softmax_loss(pred_cls, labels) + # loss for regression + pred_delta = pred_delta.reshape(-1, self.cfg.num_classes + 1, 4) + + vlabels = labels.reshape(-1, 1).broadcast((labels.shapeof(0), 4)) + pred_delta = F.indexing_one_hot(pred_delta, vlabels, axis=1) + + loss_rcnn_loc = layers.get_smooth_l1_loss( + pred_delta, bbox_targets, labels, + self.cfg.rcnn_smooth_l1_beta, + norm_type="all", + ) + loss_dict = { + 'loss_rcnn_cls': loss_rcnn_cls, + 'loss_rcnn_loc': loss_rcnn_loc + } + return loss_dict + else: + # slice 1 for removing background + pred_scores = F.softmax(pred_cls, axis=1)[:, 1:] + pred_delta = pred_delta[:, 4:].reshape(-1, 4) + target_shape = (rcnn_rois.shapeof(0), self.cfg.num_classes, 4) + # rois (N, 4) -> (N, 1, 4) -> (N, 80, 4) -> (N * 80, 4) + base_rois = F.add_axis(rcnn_rois[:, 1:5], 1).broadcast(target_shape).reshape(-1, 4) + pred_bbox = self.box_coder.decode(base_rois, pred_delta) + return pred_bbox, pred_scores + + def get_ground_truth(self, rpn_rois, im_info, gt_boxes): + if not self.training: + return rpn_rois, None, None + + return_rois = [] + return_labels = [] + return_bbox_targets = [] + + # get per image proposals and gt_boxes + for bid in range(self.cfg.batch_per_gpu): + num_valid_boxes = im_info[bid, 4] + gt_boxes_per_img = gt_boxes[bid, :num_valid_boxes, :] + batch_inds = mge.ones((gt_boxes_per_img.shapeof(0), 1)) * bid + # if config.proposal_append_gt: + gt_rois = F.concat([batch_inds, gt_boxes_per_img[:, :4]], axis=1) + batch_roi_mask = (rpn_rois[:, 0] == bid) + _, batch_roi_inds = F.cond_take(batch_roi_mask == 1, batch_roi_mask) + # all_rois : [batch_id, x1, y1, x2, y2] + all_rois = F.concat([rpn_rois.ai[batch_roi_inds], gt_rois]) + + overlaps_normal, overlaps_ignore = layers.get_iou( + all_rois[:, 1:5], gt_boxes_per_img, return_ignore=True, + ) + + max_overlaps_normal = overlaps_normal.max(axis=1) + gt_assignment_normal = F.argmax(overlaps_normal, axis=1) + + max_overlaps_ignore = overlaps_ignore.max(axis=1) + gt_assignment_ignore = F.argmax(overlaps_ignore, axis=1) + + ignore_assign_mask = (max_overlaps_normal < self.cfg.fg_threshold) * ( + max_overlaps_ignore > max_overlaps_normal) + max_overlaps = ( + max_overlaps_normal * (1 - ignore_assign_mask) + + max_overlaps_ignore * ignore_assign_mask + ) + gt_assignment = ( + gt_assignment_normal * (1 - ignore_assign_mask) + + gt_assignment_ignore * ignore_assign_mask + ) + gt_assignment = gt_assignment.astype("int32") + labels = gt_boxes_per_img.ai[gt_assignment, 4] + + # ---------------- get the fg/bg labels for each roi ---------------# + fg_mask = (max_overlaps >= self.cfg.fg_threshold) * (labels != self.cfg.ignore_label) + bg_mask = (max_overlaps < self.cfg.bg_threshold_high) * ( + max_overlaps >= self.cfg.bg_threshold_low) + + num_fg_rois = self.cfg.num_rois * self.cfg.fg_ratio + + fg_inds_mask = self._bernoulli_sample_masks(fg_mask, num_fg_rois, 1) + num_bg_rois = self.cfg.num_rois - fg_inds_mask.sum() + bg_inds_mask = self._bernoulli_sample_masks(bg_mask, num_bg_rois, 1) + + labels = labels * fg_inds_mask + + keep_mask = fg_inds_mask + bg_inds_mask + _, keep_inds = F.cond_take(keep_mask == 1, keep_mask) + # Add next line to avoid memory exceed + keep_inds = keep_inds[:F.minimum(self.cfg.num_rois, keep_inds.shapeof(0))] + # labels + labels = labels.ai[keep_inds].astype("int32") + rois = all_rois.ai[keep_inds] + target_boxes = gt_boxes_per_img.ai[gt_assignment.ai[keep_inds], :4] + bbox_targets = self.box_coder.encode(rois[:, 1:5], target_boxes) + bbox_targets = bbox_targets.reshape(-1, 4) + + return_rois.append(rois) + return_labels.append(labels) + return_bbox_targets.append(bbox_targets) + + return ( + F.zero_grad(F.concat(return_rois, axis=0)), + F.zero_grad(F.concat(return_labels, axis=0)), + F.zero_grad(F.concat(return_bbox_targets, axis=0)) + ) + + def _bernoulli_sample_masks(self, masks, num_samples, sample_value): + """ Using the bernoulli sampling method""" + sample_mask = (masks == sample_value) + num_mask = sample_mask.sum() + num_final_samples = F.minimum(num_mask, num_samples) + # here, we use the bernoulli probability to sample the anchors + sample_prob = num_final_samples / num_mask + uniform_rng = mge.random.uniform(sample_mask.shapeof(0)) + after_sampled_mask = (uniform_rng <= sample_prob) * sample_mask + return after_sampled_mask diff --git a/official/vision/detection/layers/det/rpn.py b/official/vision/detection/layers/det/rpn.py new file mode 100644 index 0000000..dfce347 --- /dev/null +++ b/official/vision/detection/layers/det/rpn.py @@ -0,0 +1,290 @@ +# -*- coding:utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import megengine as mge +import megengine.random as rand +import megengine.functional as F +import megengine.module as M +from official.vision.detection import layers +from official.vision.detection.tools.gpu_nms import batched_nms + + +class RPN(M.Module): + + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.box_coder = layers.BoxCoder() + + self.stride_list = cfg.rpn_stride + rpn_channel = cfg.rpn_channel + self.in_features = cfg.rpn_in_features + self.anchors_generator = layers.DefaultAnchorGenerator( + cfg.anchor_base_size, + cfg.anchor_scales, + cfg.anchor_aspect_ratios, + cfg.anchor_offset, + ) + self.rpn_conv = M.Conv2d(256, rpn_channel, kernel_size=3, stride=1, padding=1) + self.rpn_cls_score = M.Conv2d( + rpn_channel, cfg.num_cell_anchors * 2, + kernel_size=1, stride=1 + ) + self.rpn_bbox_offsets = M.Conv2d( + rpn_channel, cfg.num_cell_anchors * 4, + kernel_size=1, stride=1 + ) + + for l in [self.rpn_conv, self.rpn_cls_score, self.rpn_bbox_offsets]: + M.init.normal_(l.weight, std=0.01) + M.init.fill_(l.bias, 0) + + def forward(self, features, im_info, boxes=None): + # prediction + features = [features[x] for x in self.in_features] + + # get anchors + all_anchors_list = [ + self.anchors_generator(fm, stride) + for fm, stride in zip(features, self.stride_list) + ] + + pred_cls_score_list = [] + pred_bbox_offsets_list = [] + for x in features: + t = F.relu(self.rpn_conv(x)) + scores = self.rpn_cls_score(t) + pred_cls_score_list.append( + scores.reshape( + scores.shape[0], 2, self.cfg.num_cell_anchors, + scores.shape[2], scores.shape[3] + ) + ) + bbox_offsets = self.rpn_bbox_offsets(t) + pred_bbox_offsets_list.append( + bbox_offsets.reshape( + bbox_offsets.shape[0], self.cfg.num_cell_anchors, 4, + bbox_offsets.shape[2], bbox_offsets.shape[3] + ) + ) + # sample from the predictions + rpn_rois = self.find_top_rpn_proposals( + pred_bbox_offsets_list, pred_cls_score_list, + all_anchors_list, im_info + ) + + if self.training: + rpn_labels, rpn_bbox_targets = self.get_ground_truth( + boxes, im_info, all_anchors_list) + pred_cls_score, pred_bbox_offsets = self.merge_rpn_score_box( + pred_cls_score_list, pred_bbox_offsets_list + ) + + # rpn loss + loss_rpn_cls = layers.softmax_loss(pred_cls_score, rpn_labels) + loss_rpn_loc = layers.get_smooth_l1_loss( + pred_bbox_offsets, + rpn_bbox_targets, + rpn_labels, + self.cfg.rpn_smooth_l1_beta, + norm_type="all" + ) + loss_dict = { + "loss_rpn_cls": loss_rpn_cls, + "loss_rpn_loc": loss_rpn_loc + } + return rpn_rois, loss_dict + else: + return rpn_rois + + def find_top_rpn_proposals( + self, rpn_bbox_offsets_list, rpn_cls_prob_list, + all_anchors_list, im_info + ): + prev_nms_top_n = self.cfg.train_prev_nms_top_n \ + if self.training else self.cfg.test_prev_nms_top_n + post_nms_top_n = self.cfg.train_post_nms_top_n \ + if self.training else self.cfg.test_post_nms_top_n + + batch_per_gpu = self.cfg.batch_per_gpu if self.training else 1 + nms_threshold = self.cfg.rpn_nms_threshold + + list_size = len(rpn_bbox_offsets_list) + + return_rois = [] + + for bid in range(batch_per_gpu): + batch_proposals_list = [] + batch_probs_list = [] + batch_level_list = [] + for l in range(list_size): + # get proposals and probs + offsets = rpn_bbox_offsets_list[l][bid].dimshuffle(2, 3, 0, 1).reshape(-1, 4) + all_anchors = all_anchors_list[l] + proposals = self.box_coder.decode(all_anchors, offsets) + + probs = rpn_cls_prob_list[l][bid, 1].dimshuffle(1, 2, 0).reshape(1, -1) + # prev nms top n + probs, order = F.argsort(probs, descending=True) + num_proposals = F.minimum(probs.shapeof(1), prev_nms_top_n) + probs = probs.reshape(-1)[:num_proposals] + order = order.reshape(-1)[:num_proposals] + proposals = proposals.ai[order, :] + + batch_proposals_list.append(proposals) + batch_probs_list.append(probs) + batch_level_list.append(mge.ones(probs.shapeof(0)) * l) + + proposals = F.concat(batch_proposals_list, axis=0) + scores = F.concat(batch_probs_list, axis=0) + level = F.concat(batch_level_list, axis=0) + + proposals = layers.get_clipped_box(proposals, im_info[bid, :]) + # filter empty + keep_mask = layers.filter_boxes(proposals) + _, keep_inds = F.cond_take(keep_mask == 1, keep_mask) + proposals = proposals.ai[keep_inds, :] + scores = scores.ai[keep_inds] + level = level.ai[keep_inds] + + # gather the proposals and probs + # sort nms by scores + scores, order = F.argsort(scores.reshape(1, -1), descending=True) + order = order.reshape(-1) + proposals = proposals.ai[order, :] + level = level.ai[order] + + # apply total level nms + rois = F.concat([proposals, scores.reshape(-1, 1)], axis=1) + keep_inds = batched_nms(proposals, scores, level, nms_threshold, post_nms_top_n) + rois = rois.ai[keep_inds] + + # rois shape (N, 5), info [batch_id, x1, y1, x2, y2] + batch_inds = mge.ones((rois.shapeof(0), 1)) * bid + batch_rois = F.concat([batch_inds, rois[:, :4]], axis=1) + return_rois.append(batch_rois) + + return F.zero_grad(F.concat(return_rois, axis=0)) + + def merge_rpn_score_box(self, rpn_cls_score_list, rpn_bbox_offsets_list): + final_rpn_cls_score_list = [] + final_rpn_bbox_offsets_list = [] + + for bid in range(self.cfg.batch_per_gpu): + batch_rpn_cls_score_list = [] + batch_rpn_bbox_offsets_list = [] + + for i in range(len(self.in_features)): + rpn_cls_score = rpn_cls_score_list[i][bid] \ + .dimshuffle(2, 3, 1, 0).reshape(-1, 2) + rpn_bbox_offsets = rpn_bbox_offsets_list[i][bid] \ + .dimshuffle(2, 3, 0, 1).reshape(-1, 4) + + batch_rpn_cls_score_list.append(rpn_cls_score) + batch_rpn_bbox_offsets_list.append(rpn_bbox_offsets) + + batch_rpn_cls_score = F.concat(batch_rpn_cls_score_list, axis=0) + batch_rpn_bbox_offsets = F.concat(batch_rpn_bbox_offsets_list, axis=0) + + final_rpn_cls_score_list.append(batch_rpn_cls_score) + final_rpn_bbox_offsets_list.append(batch_rpn_bbox_offsets) + + final_rpn_cls_score = F.concat(final_rpn_cls_score_list, axis=0) + final_rpn_bbox_offsets = F.concat(final_rpn_bbox_offsets_list, axis=0) + return final_rpn_cls_score, final_rpn_bbox_offsets + + def per_level_gt( + self, gt_boxes, im_info, anchors, allow_low_quality_matches=True + ): + ignore_label = self.cfg.ignore_label + # get the gt boxes + valid_gt_boxes = gt_boxes[:im_info[4], :] + # compute the iou matrix + overlaps = layers.get_iou(anchors, valid_gt_boxes[:, :4]) + # match the dtboxes + a_shp0 = anchors.shape[0] + max_overlaps = F.max(overlaps, axis=1) + argmax_overlaps = F.argmax(overlaps, axis=1) + # all ignore + labels = mge.ones(a_shp0).astype("int32") * ignore_label + # set negative ones + labels = labels * (max_overlaps >= self.cfg.rpn_negative_overlap) + # set positive ones + fg_mask = (max_overlaps >= self.cfg.rpn_positive_overlap) + const_one = mge.tensor(1.0) + if allow_low_quality_matches: + # make sure that max iou of gt matched + gt_argmax_overlaps = F.argmax(overlaps, axis=0) + num_valid_boxes = valid_gt_boxes.shapeof(0) + gt_id = F.linspace(0, num_valid_boxes - 1, num_valid_boxes).astype("int32") + argmax_overlaps = argmax_overlaps.set_ai(gt_id)[gt_argmax_overlaps] + max_overlaps = max_overlaps.set_ai( + const_one.broadcast(num_valid_boxes) + )[gt_argmax_overlaps] + fg_mask = (max_overlaps >= self.cfg.rpn_positive_overlap) + # set positive ones + _, fg_mask_ind = F.cond_take(fg_mask == 1, fg_mask) + labels = labels.set_ai(const_one.broadcast(fg_mask_ind.shapeof(0)))[fg_mask_ind] + # compute the targets + bbox_targets = self.box_coder.encode( + anchors, valid_gt_boxes.ai[argmax_overlaps, :4] + ) + return labels, bbox_targets + + def get_ground_truth(self, gt_boxes, im_info, all_anchors_list): + final_labels_list = [] + final_bbox_targets_list = [] + + for bid in range(self.cfg.batch_per_gpu): + batch_labels_list = [] + batch_bbox_targets_list = [] + for anchors in all_anchors_list: + rpn_labels_perlvl, rpn_bbox_targets_perlvl = self.per_level_gt( + gt_boxes[bid], im_info[bid], anchors, + ) + batch_labels_list.append(rpn_labels_perlvl) + batch_bbox_targets_list.append(rpn_bbox_targets_perlvl) + + concated_batch_labels = F.concat(batch_labels_list, axis=0) + concated_batch_bbox_targets = F.concat(batch_bbox_targets_list, axis=0) + + # sample labels + num_positive = self.cfg.num_sample_anchors * self.cfg.positive_anchor_ratio + # sample positive + concated_batch_labels = self._bernoulli_sample_labels( + concated_batch_labels, + num_positive, 1, self.cfg.ignore_label + ) + # sample negative + num_positive = (concated_batch_labels == 1).sum() + num_negative = self.cfg.num_sample_anchors - num_positive + concated_batch_labels = self._bernoulli_sample_labels( + concated_batch_labels, + num_negative, 0, self.cfg.ignore_label + ) + + final_labels_list.append(concated_batch_labels) + final_bbox_targets_list.append(concated_batch_bbox_targets) + final_labels = F.concat(final_labels_list, axis=0) + final_bbox_targets = F.concat(final_bbox_targets_list, axis=0) + return F.zero_grad(final_labels), F.zero_grad(final_bbox_targets) + + def _bernoulli_sample_labels( + self, labels, num_samples, sample_value, ignore_label=-1 + ): + """ Using the bernoulli sampling method""" + sample_label_mask = (labels == sample_value) + num_mask = sample_label_mask.sum() + num_final_samples = F.minimum(num_mask, num_samples) + # here, we use the bernoulli probability to sample the anchors + sample_prob = num_final_samples / num_mask + uniform_rng = rand.uniform(sample_label_mask.shapeof(0)) + to_ignore_mask = (uniform_rng >= sample_prob) * sample_label_mask + labels = labels * (1 - to_ignore_mask) + to_ignore_mask * ignore_label + + return labels diff --git a/official/vision/detection/models/__init__.py b/official/vision/detection/models/__init__.py index ac9cdf6..be97d0d 100644 --- a/official/vision/detection/models/__init__.py +++ b/official/vision/detection/models/__init__.py @@ -6,6 +6,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from .faster_rcnn_fpn import * from .retinanet import * _EXCLUDE = {} diff --git a/official/vision/detection/models/faster_rcnn_fpn.py b/official/vision/detection/models/faster_rcnn_fpn.py new file mode 100644 index 0000000..7122e3d --- /dev/null +++ b/official/vision/detection/models/faster_rcnn_fpn.py @@ -0,0 +1,212 @@ +# -*- coding:utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import numpy as np +import megengine as mge +import megengine.functional as F +import megengine.module as M + +from official.vision.classification.resnet.model import resnet50 +from official.vision.detection import layers + + +class FasterRCNN(M.Module): + + def __init__(self, cfg, batch_size): + super().__init__() + self.cfg = cfg + cfg.batch_per_gpu = batch_size + self.batch_size = batch_size + # ----------------------- build the backbone ------------------------ # + bottom_up = resnet50(norm=layers.get_norm(cfg.resnet_norm)) + + # ------------ freeze the weights of resnet stage1 and stage 2 ------ # + if self.cfg.backbone_freeze_at >= 1: + for p in bottom_up.conv1.parameters(): + p.requires_grad = False + if self.cfg.backbone_freeze_at >= 2: + for p in bottom_up.layer1.parameters(): + p.requires_grad = False + + # -------------------------- build the FPN -------------------------- # + out_channels = 256 + self.backbone = layers.FPN( + bottom_up=bottom_up, + in_features=["res2", "res3", "res4", "res5"], + out_channels=out_channels, + norm="", + top_block=layers.FPNP6(), + strides=[4, 8, 16, 32], + channels=[256, 512, 1024, 2048], + ) + + # -------------------------- build the RPN -------------------------- # + self.RPN = layers.RPN(cfg) + + # ----------------------- build the RCNN head ----------------------- # + self.RCNN = layers.RCNN(cfg) + + # -------------------------- input Tensor --------------------------- # + self.inputs = { + "image": mge.tensor( + np.random.random([2, 3, 224, 224]).astype(np.float32), dtype="float32", + ), + "im_info": mge.tensor( + np.random.random([2, 5]).astype(np.float32), dtype="float32", + ), + "gt_boxes": mge.tensor( + np.random.random([2, 100, 5]).astype(np.float32), dtype="float32", + ), + } + + def preprocess_image(self, image): + normed_image = ( + image - self.cfg.img_mean[None, :, None, None] + ) / self.cfg.img_std[None, :, None, None] + return layers.get_padded_tensor(normed_image, 32, 0.0) + + def forward(self, inputs): + images = inputs['image'] + im_info = inputs['im_info'] + gt_boxes = inputs['gt_boxes'] + # process the images + normed_images = self.preprocess_image(images) + # normed_images = images + fpn_features = self.backbone(normed_images) + + if self.training: + return self._forward_train(fpn_features, im_info, gt_boxes) + else: + return self.inference(fpn_features, im_info) + + def _forward_train(self, fpn_features, im_info, gt_boxes): + rpn_rois, rpn_losses = self.RPN(fpn_features, im_info, gt_boxes) + rcnn_losses = self.RCNN(fpn_features, rpn_rois, im_info, gt_boxes) + + loss_rpn_cls = rpn_losses['loss_rpn_cls'] + loss_rpn_loc = rpn_losses['loss_rpn_loc'] + loss_rcnn_cls = rcnn_losses['loss_rcnn_cls'] + loss_rcnn_loc = rcnn_losses['loss_rcnn_loc'] + total_loss = loss_rpn_cls + loss_rpn_loc + loss_rcnn_cls + loss_rcnn_loc + + loss_dict = { + "total_loss": total_loss, + "rpn_cls": loss_rpn_cls, + "rpn_loc": loss_rpn_loc, + "rcnn_cls": loss_rcnn_cls, + "rcnn_loc": loss_rcnn_loc + } + self.cfg.losses_keys = list(loss_dict.keys()) + return loss_dict + + def inference(self, fpn_features, im_info): + rpn_rois = self.RPN(fpn_features, im_info) + pred_boxes, pred_score = self.RCNN(fpn_features, rpn_rois) + # pred_score = pred_score[:, None] + pred_boxes = pred_boxes.reshape(-1, 4) + scale_w = im_info[0, 1] / im_info[0, 3] + scale_h = im_info[0, 0] / im_info[0, 2] + pred_boxes = pred_boxes / F.concat( + [scale_w, scale_h, scale_w, scale_h], axis=0 + ) + + clipped_boxes = layers.get_clipped_box( + pred_boxes, im_info[0, 2:4] + ).reshape(-1, self.cfg.num_classes, 4) + return pred_score, clipped_boxes + + +class FasterRCNNConfig: + + def __init__(self): + self.resnet_norm = "FrozenBN" + self.backbone_freeze_at = 2 + + # ------------------------ data cfg --------------------------- # + self.train_dataset = dict( + name="coco", + root="train2017", + ann_file="annotations/instances_train2017.json", + ) + self.test_dataset = dict( + name="coco", + root="val2017", + ann_file="annotations/instances_val2017.json", + ) + self.num_classes = 80 + + self.img_mean = np.array([103.530, 116.280, 123.675]) # BGR + self.img_std = np.array([57.375, 57.120, 58.395]) + + # ----------------------- rpn cfg ------------------------- # + self.anchor_base_size = 16 + self.anchor_scales = np.array([0.5]) + self.anchor_aspect_ratios = [0.5, 1, 2] + self.anchor_offset = -0.5 + self.num_cell_anchors = len(self.anchor_aspect_ratios) + + self.bbox_normalize_means = None + self.bbox_normalize_stds = np.array([0.1, 0.1, 0.2, 0.2]) + + self.rpn_stride = np.array([4, 8, 16, 32, 64]).astype(np.float32) + self.rpn_in_features = ["p2", "p3", "p4", "p5", "p6"] + self.rpn_channel = 256 + + self.rpn_nms_threshold = 0.7 + self.allow_low_quality = True + self.num_sample_anchors = 256 + self.positive_anchor_ratio = 0.5 + self.rpn_positive_overlap = 0.7 + self.rpn_negative_overlap = 0.3 + self.ignore_label = -1 + + # ----------------------- rcnn cfg ------------------------- # + self.pooling_method = 'roi_align' + self.pooling_size = (7, 7) + + self.num_rois = 512 + self.fg_ratio = 0.5 + self.fg_threshold = 0.5 + self.bg_threshold_high = 0.5 + self.bg_threshold_low = 0.0 + + self.rcnn_in_features = ["p2", "p3", "p4", "p5"] + self.rcnn_stride = [4, 8, 16, 32] + + # ------------------------ loss cfg -------------------------- # + self.rpn_smooth_l1_beta = 3 + self.rcnn_smooth_l1_beta = 1 + + # ------------------------ training cfg ---------------------- # + self.train_image_short_size = 800 + self.train_image_max_size = 1333 + self.train_prev_nms_top_n = 2000 + self.train_post_nms_top_n = 1000 + + self.num_losses = 5 + self.basic_lr = 0.02 / 16.0 # The basic learning rate for single-image + self.momentum = 0.9 + self.weight_decay = 1e-4 + self.log_interval = 20 + self.nr_images_epoch = 80000 + self.max_epoch = 18 + self.warm_iters = 500 + self.lr_decay_rate = 0.1 + self.lr_decay_sates = [12, 16, 17] + + # ------------------------ testing cfg ------------------------- # + self.test_image_short_size = 800 + self.test_image_max_size = 1333 + self.test_prev_nms_top_n = 1000 + self.test_post_nms_top_n = 1000 + self.test_max_boxes_per_image = 100 + + self.test_vis_threshold = 0.3 + self.test_cls_threshold = 0.05 + self.test_nms = 0.5 + self.class_aware_box = True diff --git a/official/vision/detection/models/retinanet.py b/official/vision/detection/models/retinanet.py index 1013819..eb5f07e 100644 --- a/official/vision/detection/models/retinanet.py +++ b/official/vision/detection/models/retinanet.py @@ -123,7 +123,13 @@ class RetinaNet(M.Module): ) total = rpn_cls_loss + rpn_bbox_loss - return total, rpn_cls_loss, rpn_bbox_loss + loss_dict = { + "total_loss": total, + "loss_cls": rpn_cls_loss, + "loss_loc": rpn_bbox_loss + } + self.cfg.losses_keys = list(loss_dict.keys()) + return loss_dict else: # currently not support multi-batch testing assert self.batch_size == 1 @@ -231,6 +237,7 @@ class RetinaNetConfig: self.focal_loss_alpha = 0.25 self.focal_loss_gamma = 2 self.reg_loss_weight = 1.0 / 4.0 + self.num_losses = 3 # ------------------------ training cfg ---------------------- # self.basic_lr = 0.01 / 16.0 # The basic learning rate for single-image diff --git a/official/vision/detection/retinanet_res50_coco_1x_800size.py b/official/vision/detection/retinanet_res50_coco_1x_800size.py index 8324290..9c06d55 100644 --- a/official/vision/detection/retinanet_res50_coco_1x_800size.py +++ b/official/vision/detection/retinanet_res50_coco_1x_800size.py @@ -19,6 +19,8 @@ def retinanet_res50_coco_1x_800size(batch_size=1, **kwargs): r""" RetinaNet trained from COCO dataset. `"RetinaNet" `_ + `"FPN" `_ + `"COCO" `_ """ return models.RetinaNet(models.RetinaNetConfig(), batch_size=batch_size, **kwargs) diff --git a/official/vision/detection/retinanet_res50_coco_1x_800size_syncbn.py b/official/vision/detection/retinanet_res50_coco_1x_800size_syncbn.py index 363a542..5944395 100644 --- a/official/vision/detection/retinanet_res50_coco_1x_800size_syncbn.py +++ b/official/vision/detection/retinanet_res50_coco_1x_800size_syncbn.py @@ -6,7 +6,6 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from megengine import hub from official.vision.detection import models @@ -24,6 +23,9 @@ def retinanet_res50_coco_1x_800size_syncbn(batch_size=1, **kwargs): r""" RetinaNet with SyncBN trained from COCO dataset. `"RetinaNet" `_ + `"FPN" `_ + `"COCO" `_ + `"SyncBN" `_ """ return models.RetinaNet(CustomRetinaNetConfig(), batch_size=batch_size, **kwargs) diff --git a/official/vision/detection/retinanet_res50_objects365_1x_800size.py b/official/vision/detection/retinanet_res50_objects365_1x_800size.py index 951d09a..2b53978 100644 --- a/official/vision/detection/retinanet_res50_objects365_1x_800size.py +++ b/official/vision/detection/retinanet_res50_objects365_1x_800size.py @@ -6,8 +6,6 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from megengine import hub - from official.vision.detection import models diff --git a/official/vision/detection/tools/gpu_nms.py b/official/vision/detection/tools/gpu_nms.py new file mode 100644 index 0000000..cf84229 --- /dev/null +++ b/official/vision/detection/tools/gpu_nms.py @@ -0,0 +1,98 @@ +#!/usr/bin/env mdl +# This file will seal the nms opr within a better way than lib_nms +import ctypes +import os +import struct + +import numpy as np +import megengine as mge +import megengine.functional as F +from megengine._internal.craniotome import CraniotomeBase +from megengine.core.tensor import wrap_io_tensor + +_so_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lib_nms.so') +_so_lib = ctypes.CDLL(_so_path) + +_TYPE_POINTER = ctypes.c_void_p +_TYPE_POINTER = ctypes.c_void_p +_TYPE_INT = ctypes.c_int32 +_TYPE_FLOAT = ctypes.c_float + +_so_lib.NMSForwardGpu.argtypes = [ + _TYPE_POINTER, + _TYPE_POINTER, + _TYPE_POINTER, + _TYPE_POINTER, + _TYPE_FLOAT, + _TYPE_INT, + _TYPE_POINTER, +] +_so_lib.NMSForwardGpu.restype = _TYPE_INT + +_so_lib.CreateHostDevice.restype = _TYPE_POINTER + + +class NMSCran(CraniotomeBase): + __nr_inputs__ = 1 + __nr_outputs__ = 3 + + def setup(self, iou_threshold, max_output): + self._iou_threshold = iou_threshold + self._max_output = max_output + # Load the necessary host device + self._host_device = _so_lib.CreateHostDevice() + + def execute(self, inputs, outputs): + box_tensor_ptr = inputs[0].pubapi_dev_tensor_ptr + output_tensor_ptr = outputs[0].pubapi_dev_tensor_ptr + output_num_tensor_ptr = outputs[1].pubapi_dev_tensor_ptr + mask_tensor_ptr = outputs[2].pubapi_dev_tensor_ptr + + _so_lib.NMSForwardGpu( + box_tensor_ptr, mask_tensor_ptr, + output_tensor_ptr, output_num_tensor_ptr, + self._iou_threshold, self._max_output, + self._host_device + ) + + def grad(self, wrt_idx, inputs, outputs, out_grad): + return 0 + + def init_output_dtype(self, input_dtypes): + return [np.int32, np.int32, np.int32] + + def get_serialize_params(self): + return ('nms', struct.pack('fi', self._iou_threshold, self._max_output)) + + def infer_shape(self, inp_shapes): + nr_box = inp_shapes[0][0] + threadsPerBlock = 64 + output_size = nr_box + # here we compute the number of int32 used in mask_outputs. + # In original version, we compute the bytes only. + mask_size = int( + nr_box * ( + nr_box // threadsPerBlock + int((nr_box % threadsPerBlock) > 0) + ) * 8 / 4 + ) + return [[output_size], [1], [mask_size]] + + +@wrap_io_tensor +def gpu_nms(box, iou_threshold, max_output): + keep, num, _ = NMSCran.make(box, iou_threshold=iou_threshold, max_output=max_output) + return keep[:num] + + +def batched_nms(boxes, scores, idxs, iou_threshold, num_keep, use_offset=False): + if use_offset: + boxes_offset = mge.tensor( + [0, 0, 1, 1], device=boxes.device + ).reshape(1, 4).broadcast(boxes.shapeof(0), 4) + boxes = boxes - boxes_offset + max_coordinate = boxes.max() + offsets = idxs * (max_coordinate + 1) + boxes_for_nms = boxes + offsets.reshape(-1, 1).broadcast(boxes.shapeof(0), 4) + boxes_with_scores = F.concat([boxes_for_nms, scores.reshape(-1, 1)], axis=1) + keep_inds = gpu_nms(boxes_with_scores, iou_threshold, num_keep) + return keep_inds diff --git a/official/vision/detection/tools/gpu_nms/nms.cu b/official/vision/detection/tools/gpu_nms/nms.cu new file mode 100644 index 0000000..c89e2c8 --- /dev/null +++ b/official/vision/detection/tools/gpu_nms/nms.cu @@ -0,0 +1,201 @@ +#include "megbrain_pubapi.h" +#include +#include +#include + +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) +#define CUDA_CHECK(condition) \ + /* Code block avoids redefinition of cudaError_t error */ \ + do { \ + cudaError_t error = condition; \ + if (error != cudaSuccess) { \ + std::cout << " " << cudaGetErrorString(error); \ + } \ + } while (0) +#define CUDA_POST_KERNEL_CHECK CUDA_CHECK(cudaPeekAtLastError()) + +int const threadsPerBlock = sizeof(unsigned long long) * 8; // 64 + +__device__ inline float devIoU(float const * const a, float const * const b) { + float left = max(a[0], b[0]), right = min(a[2], b[2]); + float top = max(a[1], b[1]), bottom = min(a[3], b[3]); + float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f); + float interS = width * height; + float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); + float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); + return interS / (Sa + Sb - interS); +} + +__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, + const float *dev_boxes, unsigned long long *dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + if (row_start > col_start) return; + + const int row_size = + min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + __shared__ float block_boxes[threadsPerBlock * 5]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 5 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; + block_boxes[threadIdx.x * 5 + 1] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; + block_boxes[threadIdx.x * 5 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; + block_boxes[threadIdx.x * 5 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; + block_boxes[threadIdx.x * 5 + 4] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const float *cur_box = dev_boxes + cur_box_idx * 5; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = DIVUP(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +template +static inline void cpu_unroll_for(unsigned long long *dst, const unsigned long long *src, int n) { + int nr_out = (n - n % unroll) / unroll; + for (int i = 0; i < nr_out; ++i) { +#pragma unroll + for (int j = 0; j < unroll; ++j) { + *(dst++) |= *(src++); + } + } + for (int j = 0; j < n % unroll; ++j) { + *(dst++) |= *(src++); + } +} + +using std::vector; +// const int nr_init_box = 8000; +// vector _mask_host(nr_init_box * (nr_init_box / threadsPerBlock)); +// vector _remv(nr_init_box / threadsPerBlock); +// vector _keep_out(nr_init_box); + +// NOTE: If we directly use this lib in nmp.py, we will meet the same _mask_host and other +// objects, which is not safe for multi-processing programs. + +class HostDevice{ +protected: + static const int nr_init_box = 8000; +public: + vector mask_host; + vector remv; + vector keep_out; + + HostDevice(): mask_host(nr_init_box * (nr_init_box / threadsPerBlock)), remv(nr_init_box / threadsPerBlock), keep_out(nr_init_box){} +}; + +extern "C"{ + using MGBDevTensor = mgb::pubapi::DeviceTensor; + using std::cout; + + void * CreateHostDevice(){ + return new HostDevice(); + } + + int NMSForwardGpu(void* box_ptr, void* mask_ptr, void* output_ptr, void* output_num_ptr, float iou_threshold, int max_output, void* host_device_ptr){ + auto box_tensor = mgb::pubapi::as_versioned_obj(box_ptr); + auto mask_tensor= mgb::pubapi::as_versioned_obj(mask_ptr); + auto output_tensor = mgb::pubapi::as_versioned_obj(output_ptr); + auto output_num_tensor = mgb::pubapi::as_versioned_obj(output_num_ptr); + + // auto cuda_stream = static_cast (box_tensor->desc.cuda_ctx.stream); + auto cuda_stream = static_cast (output_tensor->desc.cuda_ctx.stream); + // assert(box_tensor->desc.shape[0] == output_tensor->desc.shape[0]); + + // cout << "box_tensor.ndim: " << box_tensor->desc.ndim << "\n"; + // cout << "box_tensor.shape_0: " << box_tensor->desc.shape[0] << "\n"; + // cout << "box_tensor.shape_1: " << box_tensor->desc.shape[1] << "\n"; + int box_num = box_tensor->desc.shape[0]; + int box_dim = box_tensor->desc.shape[1]; + assert(box_dim == 5); + + const int col_blocks = DIVUP(box_num, threadsPerBlock); + // cout << "mask_dev size: " << box_num * col_blocks * sizeof(unsigned long long) << "\n"; + // cout << "mask_ptr size: " << mask_tensor->desc.shape[0] * sizeof(int) << "\n"; + // cout << "mask shape : " << mask_tensor->desc.shape[0] << "\n"; + + dim3 blocks(DIVUP(box_num, threadsPerBlock), DIVUP(box_num, threadsPerBlock)); + // dim3 blocks(col_blocks, col_blocks); + dim3 threads(threadsPerBlock); + // cout << "sizeof unsigned long long " << sizeof(unsigned long long) << "\n"; + float* dev_box = static_cast (box_tensor->desc.dev_ptr); + unsigned long long* dev_mask = static_cast (mask_tensor->desc.dev_ptr); + int * dev_output = static_cast (output_tensor->desc.dev_ptr); + + CUDA_CHECK(cudaMemsetAsync(dev_mask, 0, mask_tensor->desc.shape[0] * sizeof(int), cuda_stream)); + // CUDA_CHECK(cudaMemsetAsync(dev_output, 0, output_tensor->desc.shape[0] * sizeof(int), cuda_stream)); + nms_kernel<<>>(box_num, iou_threshold, dev_box, dev_mask); + // cudaDeviceSynchronize(); + + // get the host device vectors + HostDevice* host_device = static_cast(host_device_ptr); + vector& _mask_host = host_device->mask_host; + vector& _remv = host_device->remv; + vector& _keep_out = host_device->keep_out; + + + int current_mask_host_size = box_num * col_blocks; + if(_mask_host.capacity() < current_mask_host_size){ + _mask_host.reserve(current_mask_host_size); + } + CUDA_CHECK(cudaMemcpyAsync(&_mask_host[0], dev_mask, sizeof(unsigned long long) * box_num * col_blocks, cudaMemcpyDeviceToHost, cuda_stream)); + // cout << "\n m_host site: " << static_cast (&_mask_host[0]) << "\n"; + + if(_remv.capacity() < col_blocks){ + _remv.reserve(col_blocks); + } + if(_keep_out.capacity() < box_num){ + _keep_out.reserve(box_num); + } + if(max_output < 0){ + max_output = box_num; + } + memset(&_remv[0], 0, sizeof(unsigned long long) * col_blocks); + CUDA_CHECK(cudaStreamSynchronize(cuda_stream)); + + // do the cpu reduce + int num_to_keep = 0; + for (int i = 0; i < box_num; i++) { + int nblock = i / threadsPerBlock; + int inblock = i % threadsPerBlock; + + if (!(_remv[nblock] & (1ULL << inblock))) { + _keep_out[num_to_keep++] = i; + if(num_to_keep == max_output){ + break; + } + // NOTE: here we need add nblock to pointer p + unsigned long long *p = &_mask_host[0] + i * col_blocks + nblock; + unsigned long long *q = &_remv[0] + nblock; + cpu_unroll_for(q, p, col_blocks - nblock); + } + } + CUDA_CHECK(cudaMemcpyAsync(dev_output, &_keep_out[0], num_to_keep * sizeof(int), cudaMemcpyHostToDevice, cuda_stream)); + int* dev_output_num = static_cast(output_num_tensor->desc.dev_ptr); + CUDA_CHECK(cudaMemcpyAsync(dev_output_num, &num_to_keep, sizeof(int), cudaMemcpyHostToDevice, cuda_stream)); + // CUDA_CHECK(cudaStreamSynchronize(cuda_stream)); + return num_to_keep; + } +} diff --git a/official/vision/detection/tools/train.py b/official/vision/detection/tools/train.py index f93093d..a65bdbc 100644 --- a/official/vision/detection/tools/train.py +++ b/official/vision/detection/tools/train.py @@ -128,11 +128,12 @@ def adjust_learning_rate(optimizer, epoch_id, step, model, world_size): def train_one_epoch(model, data_queue, opt, tot_steps, rank, epoch_id, world_size): @jit.trace(symbolic=True, opt_level=2) def propagate(): - loss_list = model(model.inputs) - opt.backward(loss_list[0]) - return loss_list + loss_dict = model(model.inputs) + opt.backward(loss_dict["total_loss"]) + losses = list(loss_dict.values()) + return losses - meter = AverageMeter(record_len=3) + meter = AverageMeter(record_len=model.cfg.num_losses) log_interval = model.cfg.log_interval for step in range(tot_steps): adjust_learning_rate(opt, epoch_id, step, model, world_size) @@ -146,17 +147,18 @@ def train_one_epoch(model, data_queue, opt, tot_steps, rank, epoch_id, world_siz opt.step() if rank == 0: + loss_str = ", ".join(["{}:%f".format(loss) for loss in model.cfg.losses_keys]) + log_info_str = "e%d, %d/%d, lr:%f, " + loss_str meter.update([loss.numpy() for loss in loss_list]) if step % log_interval == 0: average_loss = meter.average() logger.info( - "e%d, %d/%d, lr:%f, cls:%f, loc:%f", + log_info_str, epoch_id, step, tot_steps, opt.param_groups[0]["lr"], - average_loss[1], - average_loss[2], + *average_loss, ) meter.reset() -- GitLab