# -*- coding:utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine as mge
import megengine.functional as F
import megengine.module as M

import official.vision.classification.resnet.model as resnet
from official.vision.detection import layers

class FasterRCNN(M.Module):
    """Two-stage Faster R-CNN detector with an FPN backbone.

    Pipeline: normalize the input images, extract multi-scale features with
    a ResNet + FPN backbone, propose regions with the RPN, then classify and
    refine boxes with the RCNN head.

    Args:
        cfg: configuration object (see ``FasterRCNNConfig``); also mutated
            here to record ``batch_per_gpu``.
        batch_size: per-GPU batch size.
    """

    def __init__(self, cfg, batch_size):
        super().__init__()
        self.cfg = cfg
        cfg.batch_per_gpu = batch_size
        self.batch_size = batch_size

        # ----------------------- build the backbone ------------------------ #
        bottom_up = getattr(resnet, cfg.backbone)(
            norm=layers.get_norm(cfg.resnet_norm), pretrained=cfg.backbone_pretrained
        )

        # ------------ freeze the weights of resnet stage1 and stage 2 ------ #
        if self.cfg.backbone_freeze_at >= 1:
            for p in bottom_up.conv1.parameters():
                p.requires_grad = False
        if self.cfg.backbone_freeze_at >= 2:
            for p in bottom_up.layer1.parameters():
                p.requires_grad = False

        # ----------------------- build the FPN ----------------------------- #
        out_channels = 256
        self.backbone = layers.FPN(
            bottom_up=bottom_up,
            in_features=["res2", "res3", "res4", "res5"],
            out_channels=out_channels,
            norm=cfg.fpn_norm,
            top_block=layers.FPNP6(),
            strides=[4, 8, 16, 32],
            channels=[256, 512, 1024, 2048],
        )

        # ----------------------- build the RPN ----------------------------- #
        self.RPN = layers.RPN(cfg)

        # ----------------------- build the RCNN head ----------------------- #
        self.RCNN = layers.RCNN(cfg)

        # -------------------------- input Tensor --------------------------- #
        # Placeholder tensors describing the expected input shapes;
        # filled with random values (presumably used for tracing/compilation
        # of the static graph — TODO confirm against the trainer).
        self.inputs = {
            "image": mge.tensor(
                np.random.random([2, 3, 224, 224]).astype(np.float32), dtype="float32",
            ),
            "im_info": mge.tensor(
                np.random.random([2, 5]).astype(np.float32), dtype="float32",
            ),
            "gt_boxes": mge.tensor(
                np.random.random([2, 100, 5]).astype(np.float32), dtype="float32",
            ),
        }

    def preprocess_image(self, image):
        """Pad ``image`` to a multiple of 32 and normalize per channel.

        Uses ``cfg.img_mean`` / ``cfg.img_std`` (BGR order per the config);
        mean/std are broadcast over the (N, C, H, W) layout.
        """
        padded_image = layers.get_padded_tensor(image, 32, 0.0)
        normed_image = (
            padded_image
            - np.array(self.cfg.img_mean, dtype=np.float32)[None, :, None, None]
        ) / np.array(self.cfg.img_std, dtype=np.float32)[None, :, None, None]
        return normed_image

    def forward(self, inputs):
        """Run training or inference depending on ``self.training``.

        Args:
            inputs: dict with keys ``"image"``, ``"im_info"`` and
                ``"gt_boxes"`` (``gt_boxes`` only consumed when training).

        Returns:
            A loss dict when training, else ``(pred_score, clipped_boxes)``.
        """
        images = inputs["image"]
        im_info = inputs["im_info"]
        gt_boxes = inputs["gt_boxes"]
        # process the images
        normed_images = self.preprocess_image(images)
        fpn_features = self.backbone(normed_images)

        if self.training:
            return self._forward_train(fpn_features, im_info, gt_boxes)
        else:
            return self.inference(fpn_features, im_info)

    def _forward_train(self, fpn_features, im_info, gt_boxes):
        """Compute RPN and RCNN losses and return them as a dict.

        Side effect: records the ordered loss keys on
        ``self.cfg.losses_keys`` for downstream logging.
        """
        rpn_rois, rpn_losses = self.RPN(fpn_features, im_info, gt_boxes)
        rcnn_losses = self.RCNN(fpn_features, rpn_rois, im_info, gt_boxes)

        loss_rpn_cls = rpn_losses["loss_rpn_cls"]
        loss_rpn_loc = rpn_losses["loss_rpn_loc"]
        loss_rcnn_cls = rcnn_losses["loss_rcnn_cls"]
        loss_rcnn_loc = rcnn_losses["loss_rcnn_loc"]
        total_loss = loss_rpn_cls + loss_rpn_loc + loss_rcnn_cls + loss_rcnn_loc

        loss_dict = {
            "total_loss": total_loss,
            "rpn_cls": loss_rpn_cls,
            "rpn_loc": loss_rpn_loc,
            "rcnn_cls": loss_rcnn_cls,
            "rcnn_loc": loss_rcnn_loc,
        }
        self.cfg.losses_keys = list(loss_dict.keys())
        return loss_dict

    def inference(self, fpn_features, im_info):
        """Predict boxes/scores and map them back to original image scale.

        ``im_info`` row 0 appears to hold (resized_h, resized_w, orig_h,
        orig_w, ...) — the ratios below rescale boxes from the resized image
        to the original one; TODO confirm against the data loader.

        Returns:
            (pred_score, clipped_boxes) where boxes are reshaped to
            (-1, num_classes, 4) and clipped to the original image extent.
        """
        rpn_rois = self.RPN(fpn_features, im_info)
        pred_boxes, pred_score = self.RCNN(fpn_features, rpn_rois)
        pred_boxes = pred_boxes.reshape(-1, 4)
        scale_w = im_info[0, 1] / im_info[0, 3]
        scale_h = im_info[0, 0] / im_info[0, 2]
        pred_boxes = pred_boxes / F.concat([scale_w, scale_h, scale_w, scale_h], axis=0)

        clipped_boxes = layers.get_clipped_box(pred_boxes, im_info[0, 2:4]).reshape(
            -1, self.cfg.num_classes, 4
        )
        return pred_score, clipped_boxes


class FasterRCNNConfig:
    """Default hyper-parameters for the Faster R-CNN model.

    Plain attribute bag consumed by ``FasterRCNN`` and the training /
    testing scripts; instantiate and override fields as needed.
    """

    def __init__(self):
        # ----------------------- backbone cfg ------------------------ #
        self.backbone = "resnet50"
        self.backbone_pretrained = True
        self.resnet_norm = "FrozenBN"
        self.fpn_norm = None
        # freeze conv1 (stage >= 1) and layer1 (stage >= 2)
        self.backbone_freeze_at = 2

        # ------------------------ data cfg -------------------------- #
        self.train_dataset = dict(
            name="coco",
            root="train2017",
            ann_file="annotations/instances_train2017.json",
            remove_images_without_annotations=True,
        )
        self.test_dataset = dict(
            name="coco",
            root="val2017",
            ann_file="annotations/instances_val2017.json",
            remove_images_without_annotations=False,
        )
        self.num_classes = 80
        self.img_mean = [103.530, 116.280, 123.675]  # BGR
        self.img_std = [57.375, 57.120, 58.395]

        # ----------------------- rpn cfg ------------------------- #
        self.anchor_base_size = 16
        self.anchor_scales = [0.5]
        self.anchor_ratios = [0.5, 1, 2]
        self.anchor_offset = -0.5

        self.rpn_stride = [4, 8, 16, 32, 64]
        self.rpn_reg_mean = [0.0, 0.0, 0.0, 0.0]
        self.rpn_reg_std = [1.0, 1.0, 1.0, 1.0]
        self.rpn_in_features = ["p2", "p3", "p4", "p5", "p6"]
        self.rpn_channel = 256

        self.rpn_nms_threshold = 0.7
        self.allow_low_quality = True
        self.num_sample_anchors = 256
        self.positive_anchor_ratio = 0.5
        self.rpn_positive_overlap = 0.7
        self.rpn_negative_overlap = 0.3
        self.ignore_label = -1

        # ----------------------- rcnn cfg ------------------------- #
        self.pooling_method = "roi_align"
        self.pooling_size = (7, 7)

        self.num_rois = 512
        self.fg_ratio = 0.5
        self.fg_threshold = 0.5
        self.bg_threshold_high = 0.5
        self.bg_threshold_low = 0.0

        self.rcnn_reg_mean = [0.0, 0.0, 0.0, 0.0]
        self.rcnn_reg_std = [0.1, 0.1, 0.2, 0.2]
        self.rcnn_in_features = ["p2", "p3", "p4", "p5"]
        self.rcnn_stride = [4, 8, 16, 32]

        # ------------------------ loss cfg -------------------------- #
        self.rpn_smooth_l1_beta = 0  # use L1 loss
        self.rcnn_smooth_l1_beta = 0  # use L1 loss
        self.num_losses = 5

        # ------------------------ training cfg ---------------------- #
        self.train_image_short_size = (640, 672, 704, 736, 768, 800)
        self.train_image_max_size = 1333
        self.train_prev_nms_top_n = 2000
        self.train_post_nms_top_n = 1000

        self.basic_lr = 0.02 / 16.0  # The basic learning rate for single-image
        self.momentum = 0.9
        self.weight_decay = 1e-4
        self.log_interval = 20
        self.nr_images_epoch = 80000
        self.max_epoch = 18
        self.warm_iters = 500
        self.lr_decay_rate = 0.1
        self.lr_decay_stages = [12, 16, 17]

        # ------------------------ testing cfg ----------------------- #
        self.test_image_short_size = 800
        self.test_image_max_size = 1333
        self.test_prev_nms_top_n = 1000
        self.test_post_nms_top_n = 1000
        self.test_max_boxes_per_image = 100
        self.test_vis_threshold = 0.3
        self.test_cls_threshold = 0.05
        self.test_nms = 0.5
        self.class_aware_box = True