# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
import official.vision.classification.resnet.model as resnet

import numpy as np


class DeconvLayers(M.Module):
    """Sequential stack of transposed-convolution stages.

    Every stage is ``M.ConvTranspose2d`` -> ``norm`` -> ``M.ReLU``. The output
    channel count of one stage becomes the input channel count of the next.

    Args:
        nf1: input channel count of the first stage.
        nf2s: per-stage output channel counts, indexed ``0 .. num_layers - 1``.
        kernels: per-stage kernel sizes, indexed ``0 .. num_layers - 1``.
        num_layers: number of stages to build.
        bias: forwarded as ``bias`` to every ``M.ConvTranspose2d``.
        norm: callable taking a channel count and returning the
            normalization module placed after each transposed convolution.
    """

    def __init__(self, nf1, nf2s, kernels, num_layers, bias=True, norm=M.BatchNorm2d):
        super().__init__()
        stages = []
        in_channels = nf1
        for idx in range(num_layers):
            kernel_size = kernels[idx]
            out_channels = nf2s[idx]
            # Integer division by 3 yields padding 0 for a kernel of 2 and
            # padding 1 for kernels of 3 or 4.
            pad = kernel_size // 3
            stages.extend(
                (
                    M.ConvTranspose2d(
                        in_channels, out_channels, kernel_size, 2, pad, bias=bias
                    ),
                    norm(out_channels),
                    M.ReLU(),
                )
            )
            in_channels = out_channels
        self.body = M.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through all stages in order and return the result."""
        return self.body(x)


class SimpleBaseline(M.Module):
    """ResNet backbone followed by a deconvolution head and a final 3x3 conv.

    The final convolution outputs ``cfg.keypoint_num`` channels.

    Args:
        backbone: name of a constructor in the imported ``resnet`` module
            (looked up with ``getattr``), e.g. ``"resnet50"``.
        cfg: configuration object providing ``backbone_pretrained``,
            ``initial_deconv_channels``, ``deconv_channels``,
            ``deconv_kernel_sizes``, ``num_deconv_layers``,
            ``deconv_with_bias`` and ``keypoint_num``.
    """

    def __init__(self, backbone, cfg):
        super(SimpleBaseline, self).__init__()
        norm = M.BatchNorm2d
        self.backbone = getattr(resnet, backbone)(
            norm=norm, pretrained=cfg.backbone_pretrained
        )
        # forward() only consumes the "res5" feature map, so the backbone's
        # ``fc`` attribute is removed.
        del self.backbone.fc

        # Must be set before _initialize_weights(), which reads
        # self.cfg.deconv_with_bias.
        self.cfg = cfg

        self.deconv_layers = DeconvLayers(
            cfg.initial_deconv_channels,
            cfg.deconv_channels,
            cfg.deconv_kernel_sizes,
            cfg.num_deconv_layers,
            cfg.deconv_with_bias,
            norm,
        )
        self.last_layer = M.Conv2d(cfg.deconv_channels[-1], cfg.keypoint_num, 3, 1, 1)

        self._initialize_weights()

        # Tensors read by calc_loss() and predict().
        self.inputs = {
            "image": mge.tensor(dtype="float32"),
            "heatmap": mge.tensor(dtype="float32"),
            "heat_valid": mge.tensor(dtype="float32"),
        }

    def calc_loss(self):
        """Return the square loss between the masked prediction and label.

        Both the network output for ``inputs["image"]`` and the label are
        multiplied by ``inputs["heat_valid"]`` before the loss is computed.
        """
        out = self.forward(self.inputs["image"])
        # Two trailing axes are appended so the mask can be multiplied with
        # the output and label (presumably broadcasting over spatial dims --
        # TODO confirm shapes against the data pipeline).
        valid = self.inputs["heat_valid"][:, :, None, None]
        # Only the last entry along axis 1 of "heatmap" is used as the label.
        label = self.inputs["heatmap"][:, -1]
        loss = F.square_loss(out * valid, label * valid)
        return loss

    def predict(self):
        """Return the network output for ``inputs["image"]``."""
        return self.forward(self.inputs["image"])

    def _initialize_weights(self):
        """Initialize the deconvolution head and the final convolution.

        Transposed-conv and final-conv weights are drawn from a normal
        distribution with std 0.001; biases are zeroed; batch-norm weights
        are set to one and their biases to zero.
        """
        for _name, m in self.deconv_layers.named_modules():
            if isinstance(m, M.ConvTranspose2d):
                M.init.normal_(m.weight, std=0.001)
                # Bias is only touched when the config enabled it.
                if self.cfg.deconv_with_bias:
                    M.init.zeros_(m.bias)
            if isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)

        M.init.normal_(self.last_layer.weight, std=0.001)
        M.init.zeros_(self.last_layer.bias)

    def forward(self, x):
        """Map ``x`` to the final-conv output via the ``"res5"`` features."""
        f = self.backbone.extract_features(x)["res5"]
        f = self.deconv_layers(f)
        pred = self.last_layer(f)
        return pred


class SimpleBaseline_Config:
    """Default hyper-parameters consumed by ``SimpleBaseline.__init__``."""

    # Input channel count of the first deconvolution stage.
    initial_deconv_channels = 2048
    # Number of deconvolution stages to build.
    num_deconv_layers = 3
    # Output channel count of each deconvolution stage.
    deconv_channels = [256, 256, 256]
    # Kernel size of each deconvolution stage.
    deconv_kernel_sizes = [4, 4, 4]
    # Whether the transposed convolutions carry a bias term.
    deconv_with_bias = False
    # Output channel count of the final convolution.
    keypoint_num = 17
    # Forwarded as ``pretrained`` to the backbone constructor.
    backbone_pretrained = True



# Module-level configuration instance passed to SimpleBaseline by the
# simplebaseline_res* factory functions in this file.
cfg = SimpleBaseline_Config()


@hub.pretrained(
    "https://data.megengine.org.cn/models/weights/keypoint_models/simplebaseline50_256x192_0_255_71_2.pkl"
)
def simplebaseline_res50(**kwargs):
    """Build a ``SimpleBaseline`` with a ``resnet50`` backbone.

    The module-level ``cfg`` is used as configuration; any keyword arguments
    are forwarded to the ``SimpleBaseline`` constructor. The function is
    wrapped by ``hub.pretrained`` with the weights URL above.
    """
    model = SimpleBaseline(backbone="resnet50", cfg=cfg, **kwargs)
    return model


@hub.pretrained(
    "https://data.megengine.org.cn/models/weights/keypoint_models/simplebaseline101_256x192_0_255_72_2.pkl"
)
def simplebaseline_res101(**kwargs):
    """Build a ``SimpleBaseline`` with a ``resnet101`` backbone.

    The module-level ``cfg`` is used as configuration; any keyword arguments
    are forwarded to the ``SimpleBaseline`` constructor. The function is
    wrapped by ``hub.pretrained`` with the weights URL above.
    """
    model = SimpleBaseline(backbone="resnet101", cfg=cfg, **kwargs)
    return model


@hub.pretrained(
    "https://data.megengine.org.cn/models/weights/keypoint_models/simplebaseline152_256x192_0_255_72_4.pkl"
)
def simplebaseline_res152(**kwargs):
    """Build a ``SimpleBaseline`` with a ``resnet152`` backbone.

    The module-level ``cfg`` is used as configuration; any keyword arguments
    are forwarded to the ``SimpleBaseline`` constructor. The function is
    wrapped by ``hub.pretrained`` with the weights URL above.
    """
    model = SimpleBaseline(backbone="resnet152", cfg=cfg, **kwargs)
    return model