retinanet.py 11.0 KB
Newer Older
M
MegEngine Team 已提交
1 2 3 4 5 6 7 8
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 10
import numpy as np

M
MegEngine Team 已提交
11 12 13 14
import megengine as mge
import megengine.functional as F
import megengine.module as M

15
import official.vision.classification.resnet.model as resnet
M
MegEngine Team 已提交
16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
from official.vision.detection import layers


class RetinaNet(M.Module):
    """
    Implement RetinaNet (https://arxiv.org/abs/1708.02002).
    """

    def __init__(self, cfg, batch_size):
        """
        Assemble the anchor generator, box coder, ResNet+FPN backbone and
        the RetinaNet head from ``cfg``.

        Args:
            cfg: configuration object; this constructor reads its
                ``anchor_scales``, ``anchor_ratios``, ``reg_mean``, ``reg_std``,
                ``stride``, ``backbone``, ``resnet_norm``, ``backbone_pretrained``,
                ``backbone_freeze_at`` and ``fpn_norm`` attributes, and hands the
                whole object to ``layers.RetinaNetHead``.
            batch_size: number of images per forward pass; stored as
                ``self.batch_size`` for later reshapes.
        """
        super().__init__()
        self.cfg = cfg
        self.batch_size = batch_size

        self.anchor_gen = layers.DefaultAnchorGenerator(
            base_size=4,
            anchor_scales=cfg.anchor_scales,
            anchor_ratios=cfg.anchor_ratios,
        )
        self.box_coder = layers.BoxCoder(cfg.reg_mean, cfg.reg_std)

        self.stride_list = np.array(cfg.stride, dtype=np.float32)
        self.in_features = ["p3", "p4", "p5", "p6", "p7"]

        # ----------------------- build the backbone ------------------------ #
        resnet_factory = getattr(resnet, cfg.backbone)
        bottom_up = resnet_factory(
            norm=layers.get_norm(cfg.resnet_norm),
            pretrained=cfg.backbone_pretrained,
        )

        # ------------ freeze the weights of resnet stage1 and stage 2 ------ #
        # Each entry is (minimum ``backbone_freeze_at`` value, sub-module name);
        # the sub-module is only looked up when its threshold is reached.
        for min_freeze_at, stage_name in ((1, "conv1"), (2, "layer1")):
            if cfg.backbone_freeze_at >= min_freeze_at:
                for param in getattr(bottom_up, stage_name).parameters():
                    param.requires_grad = False

        # ----------------------- build the FPN ----------------------------- #
        in_channels_p6p7 = 2048
        out_channels = 256
        top_block = layers.LastLevelP6P7(in_channels_p6p7, out_channels)
        self.backbone = layers.FPN(
            bottom_up=bottom_up,
            in_features=["res3", "res4", "res5"],
            out_channels=out_channels,
            norm=cfg.fpn_norm,
            top_block=top_block,
        )

        backbone_shape = self.backbone.output_shape()
        feature_shapes = [backbone_shape[name] for name in self.in_features]

        # ----------------------- build the RetinaNet Head ------------------ #
        self.head = layers.RetinaNetHead(cfg, feature_shapes)

        # Random float32 placeholder tensors, one per input key, created in
        # the order image -> im_info -> gt_boxes.
        placeholder_shapes = (
            ("image", [2, 3, 224, 224]),
            ("im_info", [2, 5]),
            ("gt_boxes", [2, 100, 5]),
        )
        self.inputs = {
            key: mge.tensor(
                np.random.random(shape).astype(np.float32), dtype="float32",
            )
            for key, shape in placeholder_shapes
        }

        # Running normalizer that ``forward`` updates in training mode.
        self.loss_normalizer = mge.tensor(100.0)

M
MegEngine Team 已提交
83
    def preprocess_image(self, image):
        """
        Pad ``image`` with ``layers.get_padded_tensor(image, 32, 0.0)`` and
        normalize it with ``cfg.img_mean`` / ``cfg.img_std``.

        The mean and std vectors are broadcast along axis 1 of the padded
        tensor (they are indexed as ``[None, :, None, None]``).
        """
        padded = layers.get_padded_tensor(image, 32, 0.0)
        mean = np.array(self.cfg.img_mean, dtype=np.float32)[None, :, None, None]
        centered = padded - mean
        std = np.array(self.cfg.img_std, dtype=np.float32)[None, :, None, None]
        return centered / std
M
MegEngine Team 已提交
90 91 92 93 94 95

    def forward(self, inputs):
        """
        Run the network on one batch.

        Args:
            inputs: mapping with an ``"image"`` entry and an ``"im_info"``
                entry; in training mode a ``"gt_boxes"`` entry is read as well.
                Column 4 of ``im_info`` is cast to int32 and used as the number
                of valid gt boxes per image. At inference, columns 0-3 of row 0
                are used to rescale and clip the boxes (presumably resized
                height/width followed by original height/width -- TODO confirm
                against the data pipeline).

        Returns:
            In training mode, a dict with keys ``"total_loss"``, ``"loss_cls"``
            and ``"loss_loc"``; the key list is also written to
            ``self.cfg.losses_keys``. Otherwise a pair
            ``(scores, boxes)`` for the single image in the batch, where
            ``scores`` is the sigmoid of the classification logits and
            ``boxes`` is reshaped to ``(-1, 4)``.
        """
        batch = self.batch_size
        normed = self.preprocess_image(inputs["image"])
        pyramid = self.backbone(normed)
        features = [pyramid[level] for level in self.in_features]

        box_logits, box_offsets = self.head(features)

        # Move axis 1 to the end, then flatten every level to
        # (batch, -1, num_classes) / (batch, -1, 4) so levels can be
        # concatenated along axis 1.
        box_logits_list = []
        for level_logits in box_logits:
            channels_last = level_logits.dimshuffle(0, 2, 3, 1)
            box_logits_list.append(
                channels_last.reshape(batch, -1, self.cfg.num_classes)
            )

        box_offsets_list = []
        for level_offsets in box_offsets:
            channels_last = level_offsets.dimshuffle(0, 2, 3, 1)
            box_offsets_list.append(channels_last.reshape(batch, -1, 4))

        # One anchor set per feature level, generated with that level's stride.
        anchors_list = []
        for level_idx, feature in enumerate(features):
            anchors_list.append(
                self.anchor_gen(feature, self.stride_list[level_idx])
            )

        all_level_box_logits = F.concat(box_logits_list, axis=1)
        all_level_box_offsets = F.concat(box_offsets_list, axis=1)
        all_level_anchors = F.concat(anchors_list, axis=0)

        if self.training:
            box_gt_scores, box_gt_offsets = self.get_ground_truth(
                all_level_anchors,
                inputs["gt_boxes"],
                inputs["im_info"][:, 4].astype(np.int32),
            )

            # With a positive momentum the losses are left un-normalized here
            # and divided by the running normalizer below; otherwise the loss
            # helpers are asked for "fg" normalization.
            momentum = self.cfg.loss_normalizer_momentum
            use_running_normalizer = momentum > 0.0
            norm_type = "none" if use_running_normalizer else "fg"

            rpn_cls_loss = layers.get_focal_loss(
                all_level_box_logits,
                box_gt_scores,
                alpha=self.cfg.focal_loss_alpha,
                gamma=self.cfg.focal_loss_gamma,
                norm_type=norm_type,
            )
            unweighted_bbox_loss = layers.get_smooth_l1_loss(
                all_level_box_offsets,
                box_gt_offsets,
                box_gt_scores,
                self.cfg.smooth_l1_beta,
                norm_type=norm_type,
            )
            rpn_bbox_loss = unweighted_bbox_loss * self.cfg.reg_loss_weight

            if use_running_normalizer:
                # Blend the current count of entries with a positive gt score
                # into the running normalizer.
                num_positive = (box_gt_scores > 0).sum()
                F.add_update(
                    self.loss_normalizer,
                    num_positive,
                    alpha=self.cfg.loss_normalizer_momentum,
                    beta=1 - self.cfg.loss_normalizer_momentum,
                )
                rpn_cls_loss = rpn_cls_loss / F.maximum(self.loss_normalizer, 1)
                rpn_bbox_loss = rpn_bbox_loss / F.maximum(self.loss_normalizer, 1)

            total = rpn_cls_loss + rpn_bbox_loss
            loss_dict = {
                "total_loss": total,
                "loss_cls": rpn_cls_loss,
                "loss_loc": rpn_bbox_loss,
            }
            self.cfg.losses_keys = list(loss_dict)
            return loss_dict

        # ------------------------------ inference -------------------------- #
        # currently not support multi-batch testing
        assert self.batch_size == 1

        decoded = self.box_coder.decode(
            all_level_anchors, all_level_box_offsets[0],
        )
        decoded = decoded.reshape(-1, 4)

        # Divide box coordinates by the per-axis ratio taken from im_info.
        scale_w = inputs["im_info"][0, 1] / inputs["im_info"][0, 3]
        scale_h = inputs["im_info"][0, 0] / inputs["im_info"][0, 2]
        scale = F.concat([scale_w, scale_h, scale_w, scale_h], axis=0)
        rescaled = decoded / scale

        clipped_box = layers.get_clipped_box(rescaled, inputs["im_info"][0, 2:4])
        clipped_box = clipped_box.reshape(-1, 4)

        all_level_box_scores = F.sigmoid(all_level_box_logits)
        return all_level_box_scores[0], clipped_box
M
MegEngine Team 已提交
178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231

    def get_ground_truth(self, anchors, batched_gt_boxes, batched_valid_gt_box_number):
        """
        Build per-anchor classification labels and box-regression targets
        for every image in the batch.

        Args:
            anchors: anchor boxes for all levels; ``anchors.shape[0]`` is the
                number of anchors.
            batched_gt_boxes: ground-truth boxes indexed as
                ``[image, box, column]``; columns ``:4`` are passed to the IoU
                and box-coder helpers and column ``4`` is used as the class
                label.
            batched_valid_gt_box_number: per-image count of leading rows of
                ``batched_gt_boxes`` that are valid.

        Returns:
            A pair ``(labels, bbox_targets)``. Each is the per-image result
            with a new leading axis, concatenated along axis 0 and wrapped in
            ``F.zero_grad``. In ``labels``, 0 marks anchors below
            ``cfg.negative_thresh``, -1 marks anchors between the two
            thresholds, and other anchors carry the class of their matched
            gt box.

        NOTE(review): an image with zero valid gt boxes is not handled
        explicitly here -- confirm the data pipeline never produces one.
        """
        total_anchors = anchors.shape[0]
        labels_cat_list = []
        bbox_targets_list = []

        for b_id in range(self.batch_size):
            # Keep only the valid leading gt rows of this image.
            gt_boxes = batched_gt_boxes[b_id, : batched_valid_gt_box_number[b_id]]

            # IoU between every anchor and every gt box; argmax over axis 1
            # selects, per anchor, the index of its best-overlapping gt box.
            overlaps = layers.get_iou(anchors, gt_boxes[:, :4])
            argmax_overlaps = F.argmax(overlaps, axis=1)

            # Gather overlaps[i, argmax_overlaps[i]] for every anchor i, i.e.
            # the best IoU of each anchor.
            max_overlaps = overlaps.ai[
                F.linspace(0, total_anchors - 1, total_anchors).astype(np.int32),
                argmax_overlaps,
            ]

            # Mask arithmetic (assuming comparisons yield 0/1 masks) that ends
            # with labels ==  0 where max_overlaps <  negative_thresh,
            #             == -1 where negative_thresh <= max_overlaps < positive_thresh,
            #             == +1 where max_overlaps >= positive_thresh.
            labels = mge.tensor([-1]).broadcast(total_anchors)
            labels = labels * (max_overlaps >= self.cfg.negative_thresh)
            labels = labels * (max_overlaps < self.cfg.positive_thresh) + (
                max_overlaps >= self.cfg.positive_thresh
            )

            # Regression target of each anchor w.r.t. its best-matching gt box.
            bbox_targets = self.box_coder.encode(
                anchors, gt_boxes.ai[argmax_overlaps, :4]
            )

            # Start from the class of the matched gt box, zero it where
            # labels is (numerically) 0, then force -1 where labels is
            # (numerically) -1.
            labels_cat = gt_boxes.ai[argmax_overlaps, 4]
            labels_cat = labels_cat * (1.0 - F.less_equal(F.abs(labels), 1e-5))
            ignore_mask = F.less_equal(F.abs(labels + 1), 1e-5)
            labels_cat = labels_cat * (1 - ignore_mask) - ignore_mask

            # assign low_quality boxes
            if self.cfg.allow_low_quality:
                # argmax over axis 0 gives, per gt box, the anchor with the
                # highest IoU; that anchor's label and regression target are
                # overwritten with this gt box regardless of the thresholds.
                # NOTE(review): `set_ai(value)[index]` looks like the legacy
                # MegEngine advanced-index assignment -- confirm against the
                # MegEngine version in use.
                gt_argmax_overlaps = F.argmax(overlaps, axis=0)
                labels_cat = labels_cat.set_ai(gt_boxes[:, 4])[gt_argmax_overlaps]
                matched_low_bbox_targets = self.box_coder.encode(
                    anchors.ai[gt_argmax_overlaps, :], gt_boxes[:, :4]
                )
                bbox_targets = bbox_targets.set_ai(matched_low_bbox_targets)[
                    gt_argmax_overlaps, :
                ]

            # Add a leading axis so the per-image results concatenate on axis 0.
            labels_cat_list.append(F.add_axis(labels_cat, 0))
            bbox_targets_list.append(F.add_axis(bbox_targets, 0))

        # Targets are wrapped in zero_grad before being returned.
        return (
            F.zero_grad(F.concat(labels_cat_list, axis=0)),
            F.zero_grad(F.concat(bbox_targets_list, axis=0)),
        )


class RetinaNetConfig:
    """
    Default hyper-parameters for RetinaNet: a ResNet-50 backbone with the
    COCO train2017 / val2017 datasets.

    Every setting is a plain instance attribute assigned in ``__init__``.
    """

    def __init__(self):
        # ------------------------ model cfg ------------------------- #
        self.backbone = "resnet50"
        self.backbone_pretrained = True
        self.resnet_norm = "FrozenBN"
        self.fpn_norm = None
        self.backbone_freeze_at = 2

        # ------------------------ data cfg -------------------------- #
        self.train_dataset = {
            "name": "coco",
            "root": "train2017",
            "ann_file": "annotations/instances_train2017.json",
            "remove_images_without_annotations": True,
        }
        self.test_dataset = {
            "name": "coco",
            "root": "val2017",
            "ann_file": "annotations/instances_val2017.json",
            "remove_images_without_annotations": False,
        }
        self.num_classes = 80
        self.img_mean = [103.530, 116.280, 123.675]  # BGR
        self.img_std = [57.375, 57.120, 58.395]
        self.stride = [8, 16, 32, 64, 128]
        self.reg_mean = [0.0, 0.0, 0.0, 0.0]
        self.reg_std = [1.0, 1.0, 1.0, 1.0]

        # ------------------------ anchor cfg ------------------------ #
        self.anchor_scales = [2 ** 0, 2 ** (1 / 3), 2 ** (2 / 3)]
        self.anchor_ratios = [0.5, 1, 2]
        self.negative_thresh = 0.4
        self.positive_thresh = 0.5
        self.allow_low_quality = True
        self.class_aware_box = False
        self.cls_prior_prob = 0.01

        # ------------------------ loss cfg -------------------------- #
        # A momentum of 0.0 disables the EMA loss normalizer.
        self.loss_normalizer_momentum = 0.9
        self.focal_loss_alpha = 0.25
        self.focal_loss_gamma = 2
        # A beta of 0 selects plain L1 loss.
        self.smooth_l1_beta = 0
        self.reg_loss_weight = 1.0
        self.num_losses = 3

        # ------------------------ training cfg ---------------------- #
        self.train_image_short_size = (640, 672, 704, 736, 768, 800)
        self.train_image_max_size = 1333

        # The basic learning rate for single-image
        self.basic_lr = 0.01 / 16.0
        self.momentum = 0.9
        self.weight_decay = 1e-4
        self.log_interval = 20
        self.nr_images_epoch = 80000
        self.max_epoch = 18
        self.warm_iters = 500
        self.lr_decay_rate = 0.1
        self.lr_decay_stages = [12, 16, 17]

        # ------------------------ testing cfg ----------------------- #
        self.test_image_short_size = 800
        self.test_image_max_size = 1333
        self.test_max_boxes_per_image = 100
        self.test_vis_threshold = 0.3
        self.test_cls_threshold = 0.05
        self.test_nms = 0.5