未验证 提交 ca8b7e92 编写于 作者: J Jianfeng Wang 提交者: GitHub

feat(detection): support Objects365 (#58)

上级 31a9096c
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from official.vision.detection import models
class CustomFasterRCNNConfig(models.FasterRCNNConfig):
def __init__(self):
super().__init__()
# ------------------------ data cfg -------------------------- #
self.train_dataset = dict(
name="objects365",
root="train",
ann_file="annotations/objects365_train_20190423.json",
remove_images_without_annotations=True,
)
self.test_dataset = dict(
name="objects365",
root="val",
ann_file="annotations/objects365_val_20190423.json",
remove_images_without_annotations=False,
)
self.num_classes = 365
# ------------------------ training cfg ---------------------- #
self.nr_images_epoch = 400000
def faster_rcnn_res50_objects365_1x_800size(batch_size=1, **kwargs):
r"""
Faster-RCNN FPN trained from Objects365 dataset.
`"Faster-RCNN" <https://arxiv.org/abs/1506.01497>`_
`"FPN" <https://arxiv.org/abs/1612.03144>`_
"""
cfg = CustomFasterRCNNConfig()
cfg.backbone_pretrained = False
return models.FasterRCNN(cfg, batch_size=batch_size, **kwargs)
Net = models.FasterRCNN
Cfg = CustomFasterRCNNConfig
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from official.vision.detection import models
class CustomRetinaNetConfig(models.RetinaNetConfig):
def __init__(self):
super().__init__()
# ------------------------ data cfg -------------------------- #
self.train_dataset = dict(
name="objects365",
root="train",
ann_file="annotations/objects365_train_20190423.json",
remove_images_without_annotations=True,
)
self.test_dataset = dict(
name="objects365",
root="val",
ann_file="annotations/objects365_val_20190423.json",
remove_images_without_annotations=False,
)
self.num_classes = 365
# ------------------------ training cfg ---------------------- #
self.nr_images_epoch = 400000
def retinanet_res50_objects365_1x_800size(batch_size=1, **kwargs):
r"""
RetinaNet trained from Objects365 dataset.
`"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
`"FPN" <https://arxiv.org/abs/1612.03144>`_
"""
cfg = CustomRetinaNetConfig()
cfg.backbone_pretrained = False
return models.RetinaNet(cfg, batch_size=batch_size, **kwargs)
Net = models.RetinaNet
Cfg = CustomRetinaNetConfig
......@@ -18,6 +18,7 @@ import megengine as mge
from megengine import jit
from megengine.data.dataset import COCO
from official.vision.detection.tools.data_mapper import data_mapper
from official.vision.detection.tools.utils import DetEvaluator
logger = mge.get_logger(__name__)
......@@ -61,7 +62,10 @@ def main():
model.inputs["im_info"].set_value(im_info)
pred_res = evaluator.predict(val_func)
res_img = DetEvaluator.vis_det(
ori_img, pred_res, is_show_label=True, classes=COCO.class_names,
ori_img,
pred_res,
is_show_label=True,
classes=data_mapper[cfg.test_dataset["name"]].class_names,
)
cv2.imwrite("results.jpg", res_img)
......
......@@ -182,7 +182,7 @@ def worker(
result_queue.put_nowait(
{
"det_res": pred_res,
"image_id": int(data_dict[1][2][0].split(".")[0]),
"image_id": int(data_dict[1][2][0].split(".")[0].split("_")[-1]),
}
)
......
......@@ -242,7 +242,7 @@ class DetEvaluator:
dataset_class.class_names[int(box[5])]
]
else:
elem["category_id"] = int(box[5])
elem["category_id"] = int(box[5]) + 1
all_results.append(elem)
return all_results
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册