anchor.py 4.3 KB
Newer Older
M
MegEngine Team 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod

import megengine.functional as F
import numpy as np

from megengine.core import tensor, Tensor


class BaseAnchorGenerator(metaclass=ABCMeta):
    """base class for anchor generator.
    """

    def __init__(self):
        pass

    @abstractmethod
    def get_anchors_by_feature(self) -> Tensor:
        pass


class DefaultAnchorGenerator(BaseAnchorGenerator):
    """default retinanet anchor generator.
    This class generate anchors by feature map in level.

    Args:
        base_size (int): anchor base size.
        anchor_scales (np.ndarray): anchor scales based on stride.
            The practical anchor scale is anchor_scale * stride
        anchor_ratios(np.ndarray): anchor aspect ratios.
        offset (float): center point offset.default is 0.

    """

    def __init__(
        self,
        base_size=8,
        anchor_scales: np.ndarray = np.array([2, 3, 4]),
        anchor_ratios: np.ndarray = np.array([0.5, 1, 2]),
        offset: float = 0,
    ):
        super().__init__()
        self.base_size = base_size
        self.anchor_scales = anchor_scales
        self.anchor_ratios = anchor_ratios
        self.offset = offset

    def _whctrs(self, anchor):
        """convert anchor box into (w, h, ctr_x, ctr_y)
        """
        w = anchor[:, 2] - anchor[:, 0] + 1
        h = anchor[:, 3] - anchor[:, 1] + 1
        x_ctr = anchor[:, 0] + 0.5 * (w - 1)
        y_ctr = anchor[:, 1] + 0.5 * (h - 1)

        return w, h, x_ctr, y_ctr

    def get_plane_anchors(self, anchor_scales: np.ndarray):
        """get anchors per location on feature map.
        The anchor number is anchor_scales x anchor_ratios
        """
        base_anchor = tensor([0, 0, self.base_size - 1, self.base_size - 1])
        base_anchor = F.add_axis(base_anchor, 0)
        w, h, x_ctr, y_ctr = self._whctrs(base_anchor)
        # ratio enumerate
        size = w * h
        size_ratios = size / self.anchor_ratios

        ws = size_ratios.sqrt().round()
        hs = (ws * self.anchor_ratios).round()

        # scale enumerate
        anchor_scales = anchor_scales[None, ...]
        ws = F.add_axis(ws, 1)
        hs = F.add_axis(hs, 1)
        ws = (ws * anchor_scales).reshape(-1, 1)
        hs = (hs * anchor_scales).reshape(-1, 1)

        anchors = F.concat(
            [
                x_ctr - 0.5 * (ws - 1),
                y_ctr - 0.5 * (hs - 1),
                x_ctr + 0.5 * (ws - 1),
                y_ctr + 0.5 * (hs - 1),
            ],
            axis=1,
        )

        return anchors.astype(np.float32)

    def get_center_offsets(self, featmap, stride):
        f_shp = featmap.shape
        fm_height, fm_width = f_shp[-2], f_shp[-1]

        shift_x = F.linspace(0, fm_width - 1, fm_width) * stride
        shift_y = F.linspace(0, fm_height - 1, fm_height) * stride

        # make the mesh grid of shift_x and shift_y
        mesh_shape = (fm_height, fm_width)
        broad_shift_x = shift_x.reshape(-1, shift_x.shape[0]).broadcast(*mesh_shape)
        broad_shift_y = shift_y.reshape(shift_y.shape[0], -1).broadcast(*mesh_shape)

        flatten_shift_x = F.add_axis(broad_shift_x.reshape(-1), 1)
        flatten_shift_y = F.add_axis(broad_shift_y.reshape(-1), 1)

        centers = F.concat(
114
            [flatten_shift_x, flatten_shift_y, flatten_shift_x, flatten_shift_y, ],
M
MegEngine Team 已提交
115 116
            axis=1,
        )
117
        centers = centers + self.offset * self.base_size
M
MegEngine Team 已提交
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
        return centers

    def get_anchors_by_feature(self, featmap, stride):
        # shifts shape: [A, 4]
        shifts = self.get_center_offsets(featmap, stride)
        # plane_anchors shape: [B, 4], e.g. B=9
        plane_anchors = self.get_plane_anchors(self.anchor_scales * stride)

        all_anchors = F.add_axis(plane_anchors, 0) + F.add_axis(shifts, 1)
        all_anchors = all_anchors.reshape(-1, 4)

        return all_anchors

    def __call__(self, featmap, stride):
        return self.get_anchors_by_feature(featmap, stride)