未验证 提交 514908dd 编写于 作者: J Jianfeng Wang 提交者: GitHub

feat(detection): support RetinaNet with Objects365 and SyncBN (#29)

* feat(detection): support objects365

* feat(detection): support retinanet with SyncBN

* feat(detection): add GroupedSampler
上级 e9286a5e
......@@ -58,7 +58,7 @@ class FrozenBatchNorm2d(M.Module):
def get_norm(norm, out_channels=None):
"""
Args:
norm (str): currently support "BN" and "FrozenBN"
norm (str): currently support "BN", "SyncBN" and "FrozenBN"
Returns:
M.Module or None: the normalization layer
......@@ -66,7 +66,11 @@ def get_norm(norm, out_channels=None):
if isinstance(norm, str):
if len(norm) == 0:
return None
norm = {"BN": M.BatchNorm2d, "FrozenBN": FrozenBatchNorm2d}[norm]
norm = {
"BN": M.BatchNorm2d,
"SyncBN": M.SyncBatchNorm,
"FrozenBN": FrozenBatchNorm2d
}[norm]
if out_channels is not None:
return norm(out_channels)
else:
......
......@@ -151,7 +151,5 @@ def get_smooth_l1_base(
abs_x = F.abs(x)
in_loss = 0.5 * x ** 2 * sigma2
out_loss = abs_x - 0.5 / sigma2
in_mask = abs_x < cond_point
out_mask = 1 - in_mask
loss = in_loss * in_mask + out_loss * out_mask
loss = F.where(abs_x < cond_point, in_loss, out_loss)
return loss
......@@ -28,7 +28,9 @@ class RetinaNetHead(M.Module):
num_classes = cfg.num_classes
num_convs = 4
prior_prob = cfg.cls_prior_prob
num_anchors = [len(cfg.anchor_ratios) * len(cfg.anchor_scales)] * 5
num_anchors = [len(cfg.anchor_ratios) * len(cfg.anchor_scales)] * len(
input_shape
)
assert (
len(set(num_anchors)) == 1
......
......@@ -53,7 +53,7 @@ class RetinaNet(M.Module):
bottom_up=bottom_up,
in_features=["res3", "res4", "res5"],
out_channels=out_channels,
norm="",
norm=self.cfg.fpn_norm,
top_block=layers.LastLevelP6P7(in_channels_p6p7, out_channels),
)
......@@ -97,7 +97,8 @@ class RetinaNet(M.Module):
]
anchors_list = [
self.anchor_gen(features[i], self.stride_list[i]) for i in range(5)
self.anchor_gen(features[i], self.stride_list[i])
for i in range(len(features))
]
all_level_box_cls = F.sigmoid(F.concat(box_cls_list, axis=1))
......@@ -196,18 +197,19 @@ class RetinaNet(M.Module):
class RetinaNetConfig:
def __init__(self):
self.resnet_norm = "FrozenBN"
self.fpn_norm = ""
self.backbone_freeze_at = 2
# ------------------------ data cfg -------------------------- #
self.train_dataset = dict(
name="coco",
root="train2017",
ann_file="instances_train2017.json"
ann_file="annotations/instances_train2017.json",
)
self.test_dataset = dict(
name="coco",
root="val2017",
ann_file="instances_val2017.json"
ann_file="annotations/instances_val2017.json",
)
self.train_image_short_size = 800
self.train_image_max_size = 1333
......
......@@ -11,33 +11,17 @@ from megengine import hub
from official.vision.detection import models
class CustomRetinaNetConfig(models.RetinaNetConfig):
def __init__(self):
super().__init__()
# ------------------------ data cfg -------------------------- #
self.train_dataset = dict(
name="coco",
root="train2017",
ann_file="annotations/instances_train2017.json"
)
self.test_dataset = dict(
name="coco",
root="val2017",
ann_file="annotations/instances_val2017.json"
)
@hub.pretrained(
"https://data.megengine.org.cn/models/weights/"
"retinanet_d3f58dce_res50_1x_800size_36dot0.pkl"
)
def retinanet_res50_coco_1x_800size(batch_size=1, **kwargs):
r"""ResNet-18 model from
r"""
RetinaNet trained from COCO dataset.
`"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
"""
return models.RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs)
return models.RetinaNet(models.RetinaNetConfig(), batch_size=batch_size, **kwargs)
Net = models.RetinaNet
Cfg = CustomRetinaNetConfig
Cfg = models.RetinaNetConfig
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import hub
from official.vision.detection import models
class CustomRetinaNetConfig(models.RetinaNetConfig):
def __init__(self):
super().__init__()
self.resnet_norm = "SyncBN"
self.fpn_norm = "SyncBN"
self.backbone_freeze_at = 0
def retinanet_res50_coco_1x_800size_syncbn(batch_size=1, **kwargs):
r"""
RetinaNet with SyncBN trained from COCO dataset.
`"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
"""
return models.RetinaNet(CustomRetinaNetConfig(), batch_size=batch_size, **kwargs)
Net = models.RetinaNet
Cfg = CustomRetinaNetConfig
......@@ -19,23 +19,25 @@ class CustomRetinaNetConfig(models.RetinaNetConfig):
self.train_dataset = dict(
name="objects365",
root="train",
ann_file="annotations/objects365_train_20190423.json"
ann_file="annotations/objects365_train_20190423.json",
)
self.test_dataset = dict(
name="objects365",
root="val",
ann_file="annotations/objects365_val_20190423.json"
ann_file="annotations/objects365_val_20190423.json",
)
self.num_classes = 365
# ------------------------ training cfg ---------------------- #
self.nr_images_epoch = 400000
def retinanet_objects365_res50_1x_800size(batch_size=1, **kwargs):
r"""ResNet-18 model from
r"""
RetinaNet trained from Objects365 dataset.
`"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
"""
return models.RetinaNet(RetinaNetConfig(), batch_size=batch_size, **kwargs)
return models.RetinaNet(CustomRetinaNetConfig(), batch_size=batch_size, **kwargs)
Net = models.RetinaNet
......
......@@ -47,7 +47,10 @@ def main():
current_network = importlib.import_module(os.path.basename(args.file).split(".")[0])
model = current_network.Net(current_network.Cfg(), batch_size=1)
model.eval()
model.load_state_dict(mge.load(args.model)["state_dict"])
state_dict = mge.load(args.model)
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict)
evaluator = DetEvaluator(model)
......
......@@ -235,7 +235,10 @@ def worker(
model = current_network.Net(current_network.Cfg(), batch_size=1)
model.eval()
evaluator = DetEvaluator(model)
model.load_state_dict(mge.load(model_file)["state_dict"])
state_dict = mge.load(model_file)
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict)
loader = build_dataloader(worker_id, total_worker, data_dir, model.cfg)
for data_dict in loader:
......
......@@ -232,14 +232,35 @@ def main():
worker(0, 1, args)
def build_sampler(train_dataset, batch_size, aspect_grouping=[1]):
def _compute_aspect_ratios(dataset):
aspect_ratios = []
for i in range(len(dataset)):
info = dataset.get_img_info(i)
aspect_ratios.append(info["height"] / info["width"])
return aspect_ratios
def _quantize(x, bins):
return list(map(lambda y: bisect.bisect_right(sorted(bins), y), x))
if len(aspect_grouping) == 0:
return Infinite(RandomSampler(train_dataset, batch_size, drop_last=True))
aspect_ratios = _compute_aspect_ratios(train_dataset)
group_ids = _quantize(aspect_ratios, aspect_grouping)
return Infinite(GroupedRandomSampler(train_dataset, batch_size, group_ids))
def build_dataloader(batch_size, data_dir, cfg):
train_dataset = data_mapper[cfg.train_dataset["name"]](
os.path.join(data_dir, cfg.train_dataset["name"], cfg.train_dataset["root"]),
os.path.join(data_dir, cfg.train_dataset["name"], cfg.train_dataset["ann_file"]),
os.path.join(
data_dir, cfg.train_dataset["name"], cfg.train_dataset["ann_file"]
),
remove_images_without_annotations=True,
order=["image", "boxes", "boxes_category", "info"],
)
train_sampler = Infinite(RandomSampler(train_dataset, batch_size, drop_last=True))
train_sampler = build_sampler(train_dataset, batch_size)
train_dataloader = DataLoader(
train_dataset,
sampler=train_sampler,
......@@ -259,6 +280,45 @@ def build_dataloader(batch_size, data_dir, cfg):
return {"train": train_dataloader}
class GroupedRandomSampler(RandomSampler):
def __init__(
self,
dataset,
batch_size,
group_ids,
indices=None,
world_size=None,
rank=None,
seed=None,
):
super().__init__(dataset, batch_size, False, indices, world_size, rank, seed)
self.group_ids = group_ids
assert len(group_ids) == len(dataset)
groups = np.unique(self.group_ids).tolist()
# buffer the indices of each group until batch size is reached
self.buffer_per_group = {k: [] for k in groups}
def batch(self):
indices = list(self.sample())
if self.world_size > 1:
indices = self.scatter(indices)
batch_index = []
for ind in indices:
group_id = self.group_ids[ind]
group_buffer = self.buffer_per_group[group_id]
group_buffer.append(ind)
if len(group_buffer) == self.batch_size:
batch_index.append(group_buffer)
self.buffer_per_group[group_id] = []
return iter(batch_index)
def __len__(self):
raise NotImplementedError("len() of GroupedRandomSampler is not well-defined.")
class DetectionPadCollator(Collator):
def __init__(self, pad_value: float = 0.0):
super().__init__()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册