diff --git a/official/vision/detection/configs/faster_rcnn_res50_objects365_1x_800size.py b/official/vision/detection/configs/faster_rcnn_res50_objects365_1x_800size.py
new file mode 100644
index 0000000000000000000000000000000000000000..b44671871246e9d8bfaa0e11f5a13bfc0725012d
--- /dev/null
+++ b/official/vision/detection/configs/faster_rcnn_res50_objects365_1x_800size.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from official.vision.detection import models
+
+
+class CustomFasterRCNNConfig(models.FasterRCNNConfig):
+    def __init__(self):
+        super().__init__()
+
+        # ------------------------ data cfg -------------------------- #
+        self.train_dataset = dict(
+            name="objects365",
+            root="train",
+            ann_file="annotations/objects365_train_20190423.json",
+            remove_images_without_annotations=True,
+        )
+        self.test_dataset = dict(
+            name="objects365",
+            root="val",
+            ann_file="annotations/objects365_val_20190423.json",
+            remove_images_without_annotations=False,
+        )
+        self.num_classes = 365
+
+        # ------------------------ training cfg ---------------------- #
+        self.nr_images_epoch = 400000
+
+
+def faster_rcnn_res50_objects365_1x_800size(batch_size=1, **kwargs):
+    r"""
+    Faster-RCNN FPN trained on the Objects365 dataset.
+    `"Faster-RCNN" <https://arxiv.org/abs/1506.01497>`_
+    `"FPN" <https://arxiv.org/abs/1612.03144>`_
+    """
+    cfg = CustomFasterRCNNConfig()
+    cfg.backbone_pretrained = False
+    return models.FasterRCNN(cfg, batch_size=batch_size, **kwargs)
+
+
+Net = models.FasterRCNN
+Cfg = CustomFasterRCNNConfig
diff --git a/official/vision/detection/configs/retinanet_res50_objects365_1x_800size.py b/official/vision/detection/configs/retinanet_res50_objects365_1x_800size.py
new file mode 100644
index 0000000000000000000000000000000000000000..dca382c6bc28e67e6f28c00af30232e5f4800c1a
--- /dev/null
+++ b/official/vision/detection/configs/retinanet_res50_objects365_1x_800size.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from official.vision.detection import models
+
+
+class CustomRetinaNetConfig(models.RetinaNetConfig):
+    def __init__(self):
+        super().__init__()
+
+        # ------------------------ data cfg -------------------------- #
+        self.train_dataset = dict(
+            name="objects365",
+            root="train",
+            ann_file="annotations/objects365_train_20190423.json",
+            remove_images_without_annotations=True,
+        )
+        self.test_dataset = dict(
+            name="objects365",
+            root="val",
+            ann_file="annotations/objects365_val_20190423.json",
+            remove_images_without_annotations=False,
+        )
+        self.num_classes = 365
+
+        # ------------------------ training cfg ---------------------- #
+        self.nr_images_epoch = 400000
+
+
+def retinanet_res50_objects365_1x_800size(batch_size=1, **kwargs):
+    r"""
+    RetinaNet trained on the Objects365 dataset.
+    `"RetinaNet" <https://arxiv.org/abs/1708.02002>`_
+    `"FPN" <https://arxiv.org/abs/1612.03144>`_
+    """
+    cfg = CustomRetinaNetConfig()
+    cfg.backbone_pretrained = False
+    return models.RetinaNet(cfg, batch_size=batch_size, **kwargs)
+
+
+Net = models.RetinaNet
+Cfg = CustomRetinaNetConfig
diff --git a/official/vision/detection/tools/inference.py b/official/vision/detection/tools/inference.py
index f49f4cc6cd4ab43e8eadf7f3b8abed42c9533dd3..136f2b9be8278026ae8ae72fe65493e5d5eaf38d 100644
--- a/official/vision/detection/tools/inference.py
+++ b/official/vision/detection/tools/inference.py
@@ -18,6 +18,7 @@ import megengine as mge
 from megengine import jit
 from megengine.data.dataset import COCO
 
+from official.vision.detection.tools.data_mapper import data_mapper
 from official.vision.detection.tools.utils import DetEvaluator
 
 logger = mge.get_logger(__name__)
@@ -61,7 +62,10 @@ def main():
     model.inputs["im_info"].set_value(im_info)
     pred_res = evaluator.predict(val_func)
     res_img = DetEvaluator.vis_det(
-        ori_img, pred_res, is_show_label=True, classes=COCO.class_names,
+        ori_img,
+        pred_res,
+        is_show_label=True,
+        classes=data_mapper[cfg.test_dataset["name"]].class_names,
     )
     cv2.imwrite("results.jpg", res_img)
 
diff --git a/official/vision/detection/tools/test.py b/official/vision/detection/tools/test.py
index 25b4da88949590cebfcd6374f965f0d7c3e4d89f..a7139937334316b9e08738e290272161a56f83ca 100644
--- a/official/vision/detection/tools/test.py
+++ b/official/vision/detection/tools/test.py
@@ -182,7 +182,7 @@ def worker(
         result_queue.put_nowait(
             {
                 "det_res": pred_res,
-                "image_id": int(data_dict[1][2][0].split(".")[0]),
+                "image_id": int(data_dict[1][2][0].split(".")[0].split("_")[-1]),
             }
         )
 
diff --git a/official/vision/detection/tools/utils.py b/official/vision/detection/tools/utils.py
index d48748a4c5f96959630f437c58564ce0850e13ba..3c6bc9130cff9068ed93b1254668bb88632ab99e 100644
--- a/official/vision/detection/tools/utils.py
+++ b/official/vision/detection/tools/utils.py
@@ -242,7 +242,7 @@ class DetEvaluator:
                        dataset_class.class_names[int(box[5])]
                    ]
                else:
-                    elem["category_id"] = int(box[5])
+                    elem["category_id"] = int(box[5]) + 1
                all_results.append(elem)
        return all_results
 
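
For reference, a minimal usage sketch of the entry points added above. This is illustrative only and not part of the patch: it assumes the repository layout shown in the diff and a working MegEngine install, and the batch size is an arbitrary example value.

# Usage sketch (illustrative, not part of the patch): build the detector from
# the Objects365 config added in this diff. Cfg and the factory function are
# exported by the new config file.
from official.vision.detection.configs.faster_rcnn_res50_objects365_1x_800size import (
    Cfg,
    faster_rcnn_res50_objects365_1x_800size,
)

cfg = Cfg()
assert cfg.num_classes == 365  # Objects365 has 365 categories

# The factory disables ImageNet backbone pretraining before constructing the
# network, exactly as the entry point in the new file does.
model = faster_rcnn_res50_objects365_1x_800size(batch_size=2)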