diff --git a/hubconf.py b/hubconf.py
index 3f0d37d5b354983548679a2f785dd5c5301b7555..86f942b4cc9e6c5048d95e2a3b29d82f1fcf212b 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -38,3 +38,9 @@ from official.vision.segmentation.deeplabv3plus import (
     deeplabv3plus_res101,
     DeepLabV3Plus,
 )
+
+from official.vision.keypoints.models import (
+    simplebaseline_res50,
+    simplebaseline_res101,
+    simplebaseline_res152,
+)
diff --git a/official/vision/keypoints/README.md b/official/vision/keypoints/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..abbf32dfd4ad90535fd7c38e4fbb2afee5f018be
--- /dev/null
+++ b/official/vision/keypoints/README.md
@@ -0,0 +1,116 @@
+# Human Pose Estimation
+
+This directory contains MegEngine implementations of the classic [SimpleBaseline](https://arxiv.org/pdf/1804.06208.pdf) and [MSPN](https://arxiv.org/pdf/1901.00148.pdf) architectures, together with complete training and testing code on the COCO dataset.
+
+Using human detection results with a human AP of 56.4 on COCO val2017, the keypoint estimation results on COCO val2017 are:
+
+|Methods|Backbone|Input Size| AP | AP .5 | AP .75 | AP (M) | AP (L) | AR | AR .5 | AR .75 | AR (M) | AR (L) |
+|---|:---:|---|---|---|---|---|---|---|---|---|---|---|
+| SimpleBaseline |Res50 |256x192| 0.712 | 0.887 | 0.779 | 0.673 | 0.785 | 0.782 | 0.932 | 0.839 | 0.730 | 0.854 |
+| SimpleBaseline |Res101|256x192| 0.722 | 0.891 | 0.795 | 0.687 | 0.795 | 0.794 | 0.936 | 0.855 | 0.745 | 0.863 |
+| SimpleBaseline |Res152|256x192| 0.724 | 0.888 | 0.794 | 0.688 | 0.795 | 0.795 | 0.934 | 0.856 | 0.746 | 0.863 |
+
+
+## Installation and Environment Setup
+
+* Before running the code in this directory, make sure the environment is configured as described in the top-level [README](../../../../README.md).
+* Install the [COCOAPI](https://github.com/cocodataset/cocoapi):
+```bash
+# COCOAPI=/path/to/clone/cocoapi
+git clone https://github.com/cocodataset/cocoapi.git $COCOAPI
+cd $COCOAPI/PythonAPI
+# Install into global site-packages
+make install
+# Alternatively, if you do not have permissions or prefer
+# not to install the COCO API into global site-packages
+python3 setup.py install --user
+```
+
+
+## How to Train
+
+1. Before training, download the [official COCO dataset](http://cocodataset.org/#download) and extract it to a suitable directory. Download the human detection results on COCO val2017 (human detection AP 56.4) from [OneDrive](https://1drv.ms/f/s!AhIXJn_J-blWzzDXoz5BeFl8sWM-) or [GoogleDrive](https://drive.google.com/drive/folders/1fRUDNUDxe9fjqcRZ2bnF_TKMlO0nB_dk?usp=sharing).
+
+2. Arrange the prepared COCO data in the following directory structure:
+```bash
+${COCO_DATA_ROOT}
+|-- annotations
+|   |-- person_keypoints_train2017.json
+|   |-- person_keypoints_val2017.json
+|-- person_detection_results
+|   |-- COCO_val2017_detections_AP_H_56_person.json
+|-- images
+    |-- train2017
+    |   |-- 000000000009.jpg
+    |   |-- 000000000025.jpg
+    |   |-- 000000000030.jpg
+    |   |-- ...
+    |-- val2017
+        |-- 000000000139.jpg
+        |-- 000000000285.jpg
+        |-- 000000000632.jpg
+        |-- ...
+```
+
+
+3. Start training:
+
+`train.py` takes the following command-line arguments:
+- `--arch`, name of the model to train;
+- `--data_root`, path to the `images` directory of the COCO dataset;
+- `--ann_file`, path to the COCO annotation `json` file;
+- `--batch_size`, batch size used during training, 64 by default;
+- `--ngpus`, number of GPUs used for training, 8 by default; set it to 1 for single-GPU training;
+- `--continue`, checkpoint to resume training from;
+- `--epochs`, number of epochs to train;
+- `--lr`, initial learning rate;
+
+```bash
+python3 train.py --arch name/of/model \
+                 --data_root /path/to/COCO/images \
+                 --ann_file /path/to/person_keypoints.json \
+                 --batch_size 32 \
+                 --lr 0.0003 \
+                 --ngpus 8 \
+                 --epochs 200 \
+                 --continue /path/to/model
+```
+
+## How to Test
+
+Once the model is trained, its performance on the COCO val2017 validation set can be evaluated with:
+
+```bash
+python3 test.py --arch name/of/model \
+                --data_root /path/to/COCO/images \
+                --model /path/to/model.pkl \
+                --gt_path /path/to/ground/truth/annotations \
+                --dt_path /path/to/human/detection/results
+```
+
+`test.py` takes the following command-line arguments:
+- `--arch`, name of the trained model;
+- `--data_root`, path to the `images` directory of the COCO dataset;
+- `--gt_path`, ground-truth annotation file of the COCO validation set;
+- `--dt_path`, human detection results;
+- `--model`, model to be evaluated.
+
+## How to Use
+
+Once the model is trained, the following command runs inference on a single image (a pretrained RetinaNet is used first to detect person boxes) and produces a visualization of the estimated human poses:
+
+```bash
+python3 inference.py --arch name/of/model \
+                     --model /path/to/model \
+                     --image /path/to/image.jpg
+```
+
+`inference.py` takes the following command-line arguments:
+- `--arch`, name of the network;
+- `--model`, path to the trained model to load;
+- `--image`, path to the image to test.
+
+## References
+
+- [Simple Baselines for Human Pose Estimation and Tracking](https://arxiv.org/pdf/1804.06208.pdf), Bin Xiao, Haiping Wu, and Yichen Wei
+- [Rethinking on Multi-Stage Networks for Human Pose Estimation](https://arxiv.org/pdf/1901.00148.pdf), Wenbo Li, Zhicheng Wang, Binyi Yin, Qixiang Peng, Yuming Du, Tianzi Xiao, Gang Yu, Hongtao Lu, Yichen Wei and Jian Sun
+
diff --git a/official/vision/keypoints/config.py b/official/vision/keypoints/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..52edfa063820576fd19f6a5c5941d7dcca78fe87
--- /dev/null
+++ b/official/vision/keypoints/config.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
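+
+# Shared hyper-parameters for the keypoint scripts in this directory:
+# optimization schedule, input normalization, heatmap generation, data
+# augmentation, and test-time decoding settings. The values below correspond
+# to SimpleBaseline with a 256x192 input.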
+class Config:
+    ################ train ################################################
+    lr_ratio = 0.1
+    warm_epochs = 1
+    weight_decay = 1e-5
+
+    half_body_transform = True
+    extend_boxes = True
+
+    ################## data ###############################################
+    # basic
+    # normalize
+    IMG_MEAN = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+    IMG_STD = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+
+    # shape
+    input_shape = (256, 192)
+    output_shape = (64, 48)
+
+    # heat maps
+    keypoint_num = 17
+    # one Gaussian kernel per supervision scale; HeatmapCollator iterates over this list
+    heat_kernel = [1.5]
+    heat_thr = 1e-2
+    heat_range = 255
+
+    ##################### augmentation ####################################
+    # extend
+    x_ext = 0.6
+    y_ext = 0.6
+
+    # half body
+    num_keypoints_half_body = 3
+    prob_half_body = 0.3
+    upper_body_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+    lower_body_ids = [11, 12, 13, 14, 15, 16]
+
+    keypoint_flip_order = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
+
+    # scale
+    scale_prob = 1
+    scale_range = [0.7, 1.3]
+
+    # rotate
+    rotation_prob = 0.6
+    rotate_range = [-45, 45]
+
+    ############## testing settings #######################################
+    test_aug_border = 10
+    test_x_ext = 0.10
+    test_y_ext = 0.10
+    test_gaussian_kernel = 17
+    second_value_aug = True
diff --git a/official/vision/keypoints/dataset.py b/official/vision/keypoints/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9318b4d5c06fcd89b3cbba46c2152ad641451d55
--- /dev/null
+++ b/official/vision/keypoints/dataset.py
@@ -0,0 +1,238 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import megengine as mge
+from megengine.data.dataset.vision.meta_vision import VisionDataset
+from megengine.data import Collator
+
+import numpy as np
+import cv2
+import os.path as osp
+import json
+from collections import defaultdict, OrderedDict
+
+
+class COCOJoints(VisionDataset):
+    """
+    We cannot use the official COCO dataset implementation here:
+    the output of __getitem__ should be a single person instead of a single image.
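+    Each item corresponds to one annotated person box; crowd annotations and
+    annotations without any labeled keypoint are dropped when
+    ``remove_untypical_ann`` is True.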
+ """ + + supported_order = ("image", "keypoints", "boxes", "info") + + keypoint_names = ( + "nose", + "left_eye", + "right_eye", + "left_ear", + "right_ear", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "left_hip", + "right_hip", + "left_knee", + "right_knee", + "left_ankle", + "right_ankle", + ) + + min_bbox_h = 0 + min_bbox_w = 0 + min_box_area = 1500 + min_bbox_score = 1e-10 + + def __init__( + self, root, ann_file, order, image_set="train", remove_untypical_ann=True + ): + + super(COCOJoints, self).__init__( + root, order=order, supported_order=self.supported_order + ) + + self.keypoint_num = len(self.keypoint_names) + self.root = root + self.image_set = image_set + self.order = order + + if isinstance(ann_file, str): + with open(ann_file, "r") as f: + dataset = json.load(f) + else: + dataset = ann_file + + self.imgs = OrderedDict() + + for img in dataset["images"]: + # for saving memory + if "license" in img: + del img["license"] + if "coco_url" in img: + del img["coco_url"] + if "date_captured" in img: + del img["date_captured"] + if "flickr_url" in img: + del img["flickr_url"] + self.imgs[img["id"]] = img + + self.ids = list(sorted(self.imgs.keys())) + + selected_anns = [] + for ann in dataset["annotations"]: + if "image_id" in ann.keys() and ann["image_id"] not in self.ids: + continue + + if "iscrowd" in ann.keys() and ann["iscrowd"]: + continue + + if remove_untypical_ann: + if "keypoints" in ann.keys() and "keypoints" in self.order: + joints = np.array(ann["keypoints"]).reshape(self.keypoint_num, 3) + if np.sum(joints[:, -1]) == 0 or ann["num_keypoints"] == 0: + continue + + if "bbox" in ann.keys() and "bbox" in self.order: + x, y, h, w = ann["bbox"] + if ( + h < self.min_bbox_h + or w < self.min_bbox_w + or h * w < self.min_bbox_area + ): + continue + + if "score" in ann.keys() and "score" in self.order: + if ann["score"] < self.min_bbox_score: + continue + + selected_anns.append(ann) + self.anns = selected_anns + + def __len__(self): + return len(self.anns) + + def get_image_info(self, index): + img_id = self.anns[index]["image_id"] + img_info = self.imgs[img_id] + return img_info + + def __getitem__(self, index): + + ann = self.anns[index] + img_id = ann["image_id"] + + target = [] + for k in self.order: + if k == "image": + + file_name = self.imgs[img_id]["file_name"] + img_path = osp.join(self.root, self.image_set, file_name) + image = cv2.imread(img_path, cv2.IMREAD_COLOR) + target.append(image) + + elif k == "keypoints": + joints = ( + np.array(ann["keypoints"]) + .reshape(len(self.keypoint_names), 3) + .astype(np.float) + ) + joints = joints[np.newaxis] + target.append(joints) + + elif k == "boxes": + x, y, w, h = np.array(ann["bbox"]).reshape(4) + bbox = [x, y, x + w, y + h] + bbox = np.array(bbox, dtype=np.float32) + target.append(bbox[np.newaxis]) + + elif k == "info": + info = self.imgs[img_id] + info = [ + info["height"], + info["width"], + info["file_name"], + ann["image_id"], + ] + if "score" in ann.keys(): + info.append(ann["score"]) + target.append(info) + + return tuple(target) + + +class HeatmapCollator(Collator): + def __init__( + self, + image_shape, + heatmap_shape, + keypoint_num, + heat_thr, + heat_kernel, + heat_range=255, + ): + super().__init__() + self.image_shape = image_shape + self.heatmap_shape = heatmap_shape + self.keypoint_num = keypoint_num + self.heat_thr = heat_thr + self.heat_kernel = heat_kernel + self.heat_range = heat_range + + self.stride = image_shape[1] // heatmap_shape[1] 
+ + x = np.arange(0, heatmap_shape[1], 1) + y = np.arange(0, heatmap_shape[0], 1) + + grid_x, grid_y = np.meshgrid(x, y) + + self.grid_x = grid_x[None].repeat(keypoint_num, 0) + self.grid_y = grid_y[None].repeat(keypoint_num, 0) + + def apply(self, inputs): + """ + assume order = ("images, keypoints, bboxes, info") + """ + batch_data = defaultdict(list) + + for image, keypoints, _, info in inputs: + + batch_data["data"].append(image) + + joints = (keypoints[0, :, :2] + 0.5) / self.stride - 0.5 + heat_valid = np.array(keypoints[0, :, -1]).astype(np.float32) + dis = (self.grid_x - joints[:, 0, np.newaxis, np.newaxis]) ** 2 + ( + self.grid_y - joints[:, 1, np.newaxis, np.newaxis] + ) ** 2 + heatmaps = [] + for k in self.heat_kernel: + heatmap = np.exp(-dis / 2 / k ** 2) + heatmap[heatmap < self.heat_thr] = 0 + heatmap[heat_valid == 0] = 0 + sum_for_norm = heatmap.sum((1, 2)) + heatmap[sum_for_norm > 0] = ( + heatmap[sum_for_norm > 0] + / sum_for_norm[sum_for_norm > 0][:, None, None] + ) + maxi = np.max(heatmap, (1, 2)) + heatmap[maxi > 1e-5] = ( + heatmap[maxi > 1e-5] + / maxi[:, None, None][maxi > 1e-5] + * self.heat_range + ) + heatmaps.append(heatmap) + + batch_data["heatmap"].append(np.array(heatmaps)) + batch_data["heat_valid"].append(heat_valid) + batch_data["info"].append(info) + + for key, v in batch_data.items(): + if key != "info": + batch_data[key] = np.ascontiguousarray(v).astype(np.float32) + return batch_data diff --git a/official/vision/keypoints/inference.py b/official/vision/keypoints/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..006418ab06565e2af394e7881cf3e92880fa786d --- /dev/null +++ b/official/vision/keypoints/inference.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
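+
+# Single-image inference: a pretrained RetinaNet detects person boxes, each
+# box is extended and affine-warped to cfg.input_shape, and the keypoint
+# model predicts heatmaps (averaged over the original and the flipped crop)
+# that are decoded back to image coordinates and drawn as a skeleton.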
+import argparse +import json +import os + +import cv2 +import megengine as mge +import numpy as np +from megengine import jit +import math + +from official.vision.keypoints.transforms import get_affine_transform +from official.vision.keypoints.config import Config as cfg + +import official.vision.keypoints.models as M +import official.vision.detection.retinanet_res50_1x_800size as Det +from official.vision.detection.tools.test import DetEvaluator +from official.vision.keypoints.test import find_keypoints + +logger = mge.get_logger(__name__) + + +def make_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-a", + "--arch", + default="simplebaseline_res50", + type=str, + choices=[ + "simplebaseline_res50", + "simplebaseline_res101", + "simplebaseline_res152", + ], + ) + parser.add_argument( + "-det", "--detector", default="retinanet_res50_1x_800size", type=str, + ) + + parser.add_argument( + "-m", + "--model", + default="/data/models/simplebaseline_res50_256x192/epoch_199.pkl", + type=str, + ) + parser.add_argument( + "-image", "--image", default="/data/test_keyoint.jpeg", type=str + ) + return parser + + +def vis_skeleton(img, all_keypoints): + + canvas = img.copy() + for keypoints in all_keypoints: + for ind, skeleton in enumerate(cfg.vis_skeletons): + jotint1 = skeleton[0] + jotint2 = skeleton[1] + + X = np.array([keypoints[jotint1, 0], keypoints[jotint2, 0]]) + + Y = np.array([keypoints[jotint1, 1], keypoints[jotint2, 1]]) + + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + + angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1])) + polygon = cv2.ellipse2Poly( + (int(mX), int(mY)), (int(length / 2), 4), int(angle), 0, 360, 1 + ) + + cur_canvas = canvas.copy() + cv2.fillConvexPoly(cur_canvas, polygon, cfg.vis_colors[ind]) + canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) + + return canvas + + +def main(): + + parser = make_parser() + args = parser.parse_args() + + detector = getattr(Det, args.detector)(pretrained=True) + detector.eval() + logger.info("Load Model : %s completed", args.detector) + + keypoint_model = getattr(M, args.arch)() + keypoint_model.load_state_dict(mge.load(args.model)["state_dict"]) + keypoint_model.eval() + logger.info("Load Model : %s completed", args.arch) + + @jit.trace(symbolic=True) + def det_func(): + pred = detector(detector.inputs) + return pred + + @jit.trace(symbolic=True) + def keypoint_func(): + pred = keypoint_model.predict() + return pred + + ori_img = cv2.imread(args.image) + data, im_info = DetEvaluator.process_inputs( + ori_img.copy(), + detector.cfg.test_image_short_size, + detector.cfg.test_image_max_size, + ) + detector.inputs["im_info"].set_value(im_info) + detector.inputs["image"].set_value(data.astype(np.float32)) + + logger.info("Detecting Humans") + evaluator = DetEvaluator(detector) + det_res = evaluator.predict(det_func) + + normalized_img = (ori_img - np.array(cfg.IMG_MEAN).reshape(1, 1, 3)) / np.array( + cfg.IMG_STD + ).reshape(1, 1, 3) + + logger.info("Detecting Keypoints") + all_keypoints = [] + for det in det_res: + cls_id = int(det[5] + 1) + if cls_id == 1: + bbox = det[:4] + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + + center_x = (bbox[0] + bbox[2]) / 2 + center_y = (bbox[1] + bbox[3]) / 2 + + extend_w = w * (1 + cfg.test_x_ext) + extend_h = h * (1 + cfg.test_y_ext) + + w_h_ratio = cfg.input_shape[1] / cfg.input_shape[0] + if extend_w / extend_h > w_h_ratio: + extend_h = extend_w / w_h_ratio + else: + extend_w = extend_h * w_h_ratio + + trans = 
get_affine_transform( + np.array([center_x, center_y]), + np.array([extend_h, extend_w]), + 1, + 0, + cfg.input_shape, + ) + + croped_img = cv2.warpAffine( + normalized_img, + trans, + (int(cfg.input_shape[1]), int(cfg.input_shape[0])), + flags=cv2.INTER_LINEAR, + borderValue=0, + ) + + fliped_img = croped_img[:, ::-1] + keypoint_input = np.stack([croped_img, fliped_img], 0) + keypoint_input = keypoint_input.transpose(0, 3, 1, 2) + keypoint_input = np.ascontiguousarray(keypoint_input).astype(np.float32) + + keypoint_model.inputs["image"].set_value(keypoint_input) + + outs = keypoint_func() + outs = outs.numpy() + pred = outs[0] + fliped_pred = outs[1][cfg.keypoint_flip_order][:, :, ::-1] + pred = (pred + fliped_pred) / 2 + + keypoints = find_keypoints(pred, bbox) + all_keypoints.append(keypoints) + + logger.info("Visualizing") + canvas = vis_skeleton(ori_img, all_keypoints) + cv2.imwrite("vis_skeleton.jpg", canvas) + + +if __name__ == "__main__": + main() diff --git a/official/vision/keypoints/models/__init__.py b/official/vision/keypoints/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53bb5936b5c59c77bfea56da9dd3efe0cee4c117 --- /dev/null +++ b/official/vision/keypoints/models/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from .simplebaseline import ( + simplebaseline_res50, + simplebaseline_res101, + simplebaseline_res152, +) diff --git a/official/vision/keypoints/models/simplebaseline.py b/official/vision/keypoints/models/simplebaseline.py new file mode 100644 index 0000000000000000000000000000000000000000..6b9d7dcf08f993c107a48cebbbbc5fbbd19d4620 --- /dev/null +++ b/official/vision/keypoints/models/simplebaseline.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
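+
+# SimpleBaseline (Xiao et al., 2018): a ResNet backbone followed by three
+# 4x4 deconvolution layers and a final 3x3 convolution that predicts one
+# heatmap per keypoint.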
+import megengine as mge +import megengine.functional as F +import megengine.hub as hub +import megengine.module as M +import official.vision.classification.resnet.model as resnet + +import numpy as np +from functools import partial + + +class DeconvLayers(M.Module): + def __init__(self, nf1, nf2s, kernels, num_layers, bias=True, norm=M.BatchNorm2d): + super(DeconvLayers, self).__init__() + _body = [] + for i in range(num_layers): + kernel = kernels[i] + padding = ( + kernel // 3 + ) # padding=0 when kernel=2 and padding=1 when kernel=4 or kernel=3 + _body += [ + M.ConvTranspose2d(nf1, nf2s[i], kernel, 2, padding, bias=bias), + norm(nf2s[i]), + M.ReLU(), + ] + nf1 = nf2s[i] + self.body = M.Sequential(*_body) + + def forward(self, x): + return self.body(x) + + +class SimpleBaseline(M.Module): + def __init__(self, backbone, cfg, pretrained=False): + + norm = partial(M.BatchNorm2d, momentum=cfg.bn_momentum) + self.backbone = getattr(resnet, backbone)(norm=norm, pretrained=pretrained) + del self.backbone.fc + + self.cfg = cfg + + self.deconv_layers = DeconvLayers( + cfg.initial_deconv_channels, + cfg.deconv_channels, + cfg.deconv_kernel_sizes, + cfg.num_deconv_layers, + cfg.deconv_with_bias, + norm, + ) + self.last_layer = M.Conv2d(cfg.deconv_channels[-1], cfg.keypoint_num, 3, 1, 1) + + self._initialize_weights() + + self.inputs = { + "image": mge.tensor(dtype="float32"), + "heatmap": mge.tensor(dtype="float32"), + "heat_valid": mge.tensor(dtype="float32"), + } + + def calc_loss(self): + out = self.forward(self.inputs["image"]) + valid = self.inputs["heat_valid"][:, :, None, None] + label = self.inputs["heatmap"][:, 0] + loss = F.square_loss(out * valid, label * valid) + return loss + + def predict(self): + return self.forward(self.inputs["image"]) + + def _initialize_weights(self): + + for k, m in self.deconv_layers.named_modules(): + if isinstance(m, M.ConvTranspose2d): + M.init.normal_(m.weight, std=0.001) + if self.cfg.deconv_with_bias: + M.init.zeros_(m.bias) + if isinstance(m, M.BatchNorm2d): + M.init.ones_(m.weight) + M.init.zeros_(m.bias) + + M.init.normal_(self.last_layer.weight, std=0.001) + M.init.zeros_(self.last_layer.bias) + + def forward(self, x): + f = self.backbone.extract_features(x)["res5"] + f = self.deconv_layers(f) + pred = self.last_layer(f) + return pred + + +class SimpleBaseline_Config: + initial_deconv_channels = 2048 + num_deconv_layers = 3 + deconv_channels = [256, 256, 256] + deconv_kernel_sizes = [4, 4, 4] + deconv_with_bias = False + bn_momentum = 0.1 + keypoint_num = 17 + + +cfg = SimpleBaseline_Config() + + +@hub.pretrained( + "https://data.megengine.org.cn/models/weights/simplebaseline50_256x192_0_255_71_2.pkl" +) +def simplebaseline_res50(**kwargs): + + model = SimpleBaseline(backbone="resnet50", cfg=cfg, **kwargs) + return model + + +@hub.pretrained( + "https://data.megengine.org.cn/models/weights/simplebaseline101_256x192_0_255_72_2.pkl" +) +def simplebaseline_res101(**kwargs): + + model = SimpleBaseline(backbone="resnet101", cfg=cfg, **kwargs) + return model + + +@hub.pretrained( + "https://data.megengine.org.cn/models/weights/simplebaseline152_256x192_0_255_72_4.pkl" +) +def simplebaseline_res152(**kwargs): + + model = SimpleBaseline(backbone="resnet152", cfg=cfg, **kwargs) + return model diff --git a/official/vision/keypoints/test.py b/official/vision/keypoints/test.py new file mode 100644 index 0000000000000000000000000000000000000000..1871b8b8d398128f46992564670c63e69c946983 --- /dev/null +++ b/official/vision/keypoints/test.py @@ -0,0 +1,333 @@ +# -*- 
coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import argparse +import importlib +import json +import os +import random +import sys +from multiprocessing import Process, Queue + +import cv2 +import megengine as mge +import numpy as np +from megengine import jit +from megengine.data import DataLoader, SequentialSampler +from megengine.data.dataset import COCO as COCODataset +import megengine.data.transform as T +from tqdm import tqdm + +from official.vision.keypoints.dataset import COCOJoints +from official.vision.keypoints.transforms import RandomBoxAffine, ExtendBoxes +from official.vision.keypoints.config import Config as cfg +import official.vision.keypoints.models as M + + +logger = mge.get_logger(__name__) + + +def build_dataloader(rank, world_size, data_root, ann_file): + val_dataset = COCOJoints( + data_root, ann_file, image_set="val", order=("image", "boxes", "info") + ) + val_sampler = SequentialSampler(val_dataset, 1, world_size=world_size, rank=rank) + val_dataloader = DataLoader( + val_dataset, + sampler=val_sampler, + num_workers=4, + transform=T.Compose( + transforms=[ + T.Normalize(mean=cfg.IMG_MEAN, std=cfg.IMG_STD), + ExtendBoxes( + cfg.test_x_ext, + cfg.test_y_ext, + cfg.input_shape[1] / cfg.input_shape[0], + random_extend_prob=0, + ), + RandomBoxAffine( + degrees=(0, 0), + scale=(1, 1), + output_shape=cfg.input_shape, + rotate_prob=0, + scale_prob=0, + ), + T.ToMode(), + ], + order=("image", "boxes", "info"), + ), + ) + return val_dataloader + + +def find_keypoints(pred, bbox): + + heat_prob = pred.copy() + heat_prob = heat_prob / cfg.heat_range + 1 + + border = cfg.test_aug_border + pred_aug = np.zeros( + (pred.shape[0], pred.shape[1] + 2 * border, pred.shape[2] + 2 * border), + dtype=np.float32, + ) + pred_aug[:, border:-border, border:-border] = pred.copy() + for i in range(pred_aug.shape[0]): + pred_aug[i] = cv2.GaussianBlur( + pred_aug[i], (cfg.test_gaussian_kernel, cfg.test_gaussian_kernel), 0 + ) + + results = np.zeros((pred_aug.shape[0], 3), dtype=np.float32) + for i in range(pred_aug.shape[0]): + lb = pred_aug[i].argmax() + y, x = np.unravel_index(lb, pred_aug[i].shape) + if cfg.second_value_aug: + y -= border + x -= border + + pred_aug[i, y, x] = 0 + lb = pred_aug[i].argmax() + py, px = np.unravel_index(lb, pred_aug[i].shape) + pred_aug[i, py, px] = 0 + + py -= border + y + px -= border + x + ln = (px ** 2 + py ** 2) ** 0.5 + delta = 0.35 + if ln > 1e-3: + x += delta * px / ln + y += delta * py / ln + + lb = pred_aug[i].argmax() + py, px = np.unravel_index(lb, pred_aug[i].shape) + pred_aug[i, py, px] = 0 + + py -= border + y + px -= border + x + ln = (px ** 2 + py ** 2) ** 0.5 + delta = 0.15 + if ln > 1e-3: + x += delta * px / ln + y += delta * py / ln + + lb = pred_aug[i].argmax() + py, px = np.unravel_index(lb, pred_aug[i].shape) + pred_aug[i, py, px] = 0 + + py -= border + y + px -= border + x + ln = (px ** 2 + py ** 2) ** 0.5 + delta = 0.05 + if ln > 1e-3: + x += delta * px / ln + y += delta * py / ln + else: + y -= border + x -= border + x = max(0, min(x, cfg.output_shape[1] - 1)) + y = max(0, min(y, cfg.output_shape[0] - 1)) + skeleton_score = heat_prob[i, int(round(y)), int(round(x))] + + stride = cfg.input_shape[1] / 
cfg.output_shape[1] + + x = (x + 0.5) * stride - 0.5 + y = (y + 0.5) * stride - 0.5 + + bbox_top_leftx, bbox_top_lefty, bbox_bottom_rightx, bbox_bottom_righty = bbox + x = ( + x / cfg.input_shape[1] * (bbox_bottom_rightx - bbox_top_leftx) + + bbox_top_leftx + ) + y = ( + y / cfg.input_shape[0] * (bbox_bottom_righty - bbox_top_lefty) + + bbox_top_lefty + ) + + results[i, 0] = x + results[i, 1] = y + results[i, 2] = skeleton_score + + return results + + +def find_results(func, img, bbox, info): + outs = func() + outs = outs.numpy() + pred = outs[0] + fliped_pred = outs[1][cfg.keypoint_flip_order][:, :, ::-1] + pred = (pred + fliped_pred) / 2 + + results = find_keypoints(pred, bbox) + + final_score = float(results[:, -1].mean() * info[-1]) + image_id = int(info[-2]) + keypoints = results.copy() + keypoints[:, -1] = 1 + keypoints = keypoints.reshape(-1,).tolist() + instance = { + "image_id": image_id, + "category_id": 1, + "score": final_score, + "keypoints": keypoints, + } + return instance + + +def worker( + arch, model_file, data_root, ann_file, worker_id, total_worker, result_queue, +): + """ + :param net_file: network description file + :param model_file: file of dump weights + :param data_dir: the dataset directory + :param worker_id: the index of the worker + :param total_worker: number of gpu for evaluation + :param result_queue: processing queue + """ + os.environ["CUDA_VISIBLE_DEVICES"] = str(worker_id) + + @jit.trace(symbolic=True, opt_level=2) + def val_func(): + pred = model.predict() + return pred + + model = getattr(M, arch)() + model.eval() + model.load_state_dict(mge.load(model_file)["state_dict"]) + + loader = build_dataloader(worker_id, total_worker, data_root, ann_file) + for data_dict in loader: + img, bbox, info = data_dict + fliped_img = img[:, :, :, ::-1] - np.zeros_like(img) + data = np.concatenate([img, fliped_img], 0) + model.inputs["image"].set_value(np.ascontiguousarray(data).astype(np.float32)) + instance = find_results(val_func, img, bbox[0, 0], info) + + result_queue.put_nowait(instance) + + +def make_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("-n", "--ngpus", default=8, type=int) + parser.add_argument("-d", "--data_root", default="/", type=str) + parser.add_argument( + "-gt", + "--gt_path", + default="/data/coco/annotations/person_keypoints_val2017.json", + type=str, + ) + parser.add_argument( + "-dt", + "--dt_path", + default="/data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json", + type=str, + ) + parser.add_argument("-se", "--start_epoch", default=-1, type=int) + parser.add_argument("-ee", "--end_epoch", default=-1, type=int) + parser.add_argument( + "-a", + "--arch", + default="simplebaseline_res50", + type=str, + choices=[ + "simplebaseline_res50", + "Simplebaseline_res101", + "Simplebaseline_res152", + ], + ) + parser.add_argument( + "-m", + "--model", + default="/data/models/simplebaseline_res50_256x192/epoch_199.pkl", + type=str, + ) + return parser + + +def main(): + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + + parser = make_parser() + args = parser.parse_args() + + dets = json.load(open(args.dt_path, "r")) + eval_gt = COCO(args.gt_path) + gt = eval_gt.dataset + + dets = [ + i for i in dets if (i["image_id"] in eval_gt.imgs and i["category_id"] == 1) + ] + ann_file = {"images": gt["images"], "annotations": dets} + + if args.end_epoch == -1: + args.end_epoch = args.start_epoch + + for epoch_num in range(args.start_epoch, args.end_epoch + 1): + if args.model: + 
model_file = args.model + else: + model_file = "log-of-{}/epoch_{}.pkl".format( + os.path.basename(args.file).split(".")[0], epoch_num + ) + logger.info("Load Model : %s completed", model_file) + + all_results = list() + result_queue = Queue(2000) + procs = [] + for i in range(args.ngpus): + proc = Process( + target=worker, + args=( + args.arch, + model_file, + args.data_root, + ann_file, + i, + args.ngpus, + result_queue, + ), + ) + proc.start() + procs.append(proc) + + for _ in tqdm(range(len(dets))): + all_results.append(result_queue.get()) + for p in procs: + p.join() + + json_path = "log-of-{}_epoch_{}.json".format(args.arch, epoch_num) + all_results = json.dumps(all_results) + with open(json_path, "w") as fo: + fo.write(all_results) + logger.info("Save to %s finished, start evaluation!", json_path) + + eval_dt = eval_gt.loadRes(json_path) + cocoEval = COCOeval(eval_gt, eval_dt, iouType="keypoints") + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + metrics = [ + "AP", + "AP@0.5", + "AP@0.75", + "APm", + "APl", + "AR", + "AR@0.5", + "AR@0.75", + "ARm", + "ARl", + ] + logger.info("mmAP".center(32, "-")) + for i, m in enumerate(metrics): + logger.info("|\t%s\t|\t%.03f\t|", m, cocoEval.stats[i]) + logger.info("-" * 32) + + +if __name__ == "__main__": + main() diff --git a/official/vision/keypoints/train.py b/official/vision/keypoints/train.py new file mode 100644 index 0000000000000000000000000000000000000000..763a84a9e472e12864d177c9b82a918722cfdb48 --- /dev/null +++ b/official/vision/keypoints/train.py @@ -0,0 +1,241 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
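+
+# Training entry point: spawns one worker process per GPU, builds COCOJoints
+# with the augmentations from transforms.py, converts keypoints to Gaussian
+# heatmaps via HeatmapCollator, and optimizes with Adam under a linear
+# warm-up followed by linear decay, saving one checkpoint per epoch.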
+import argparse +import multiprocessing as mp +import os +import shutil +import time +import json +import numpy as np +import cv2 + +import megengine as mge +import megengine.data as data +import megengine.data.transform as T +import megengine.distributed as dist +import megengine.functional as F +import megengine.jit as jit +import megengine.optimizer as optim + +import official.vision.keypoints.models as M +from official.vision.keypoints.transforms import ( + RandomBoxAffine, + RandomHorizontalFlip, + HalfBodyTransform, + ExtendBoxes, +) +from official.vision.keypoints.dataset import COCOJoints, HeatmapCollator +from official.vision.keypoints.config import Config as cfg + +logger = mge.get_logger(__name__) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-a", + "--arch", + default="simplebaseline_res50", + type=str, + choices=[ + "simplebaseline_res50", + "simplebaseline_res101", + "simplebaseline_res152", + ], + ) + parser.add_argument("--pretrained", default=True, type=bool) + parser.add_argument("-s", "--save", default="/data/models", type=str) + parser.add_argument("--data_root", default="/data/coco/images/", type=str) + parser.add_argument( + "--ann_file", + default="/data/coco/annotations/person_keypoints_train2017.json", + type=str, + ) + parser.add_argument("--continue", default=None, type=str) + + parser.add_argument("-b", "--batch_size", default=64, type=int) + parser.add_argument("--lr", default=6e-4, type=float) + parser.add_argument("--epochs", default=200, type=int) + + parser.add_argument("--multi_scale_supervision", default=True, type=bool) + + parser.add_argument("-n", "--ngpus", default=8, type=int) + parser.add_argument("-w", "--workers", default=8, type=int) + parser.add_argument("--report-freq", default=10, type=int) + + args = parser.parse_args() + + model_name = "{}_{}x{}".format(args.arch, cfg.input_shape[0], cfg.input_shape[1]) + save_dir = os.path.join(args.save, model_name) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + mge.set_log_file(os.path.join(save_dir, "log.txt")) + + world_size = mge.get_device_count("gpu") if args.ngpus is None else args.ngpus + + if world_size > 1: + # scale learning rate by number of gpus + args.lr *= world_size + # start distributed training, dispatch sub-processes + processes = [] + for rank in range(world_size): + p = mp.Process(target=worker, args=(rank, world_size, args)) + p.start() + processes.append(p) + + for p in processes: + p.join() + else: + worker(0, 1, args) + + +def worker(rank, world_size, args): + if world_size > 1: + # Initialize distributed process group + logger.info("init distributed process group {} / {}".format(rank, world_size)) + dist.init_process_group( + master_ip="localhost", + master_port=23456, + world_size=world_size, + rank=rank, + dev=rank, + ) + + model_name = "{}_{}x{}".format(args.arch, cfg.input_shape[0], cfg.input_shape[1]) + save_dir = os.path.join(args.save, model_name) + + model = getattr(M, args.arch)(pretrained=args.pretrained) + model.train() + start_epoch = 0 + if args.c is not None: + file = mge.load(args.c) + model.load_state_dict(file["state_dict"]) + start_epoch = file["epoch"] + + optimizer = optim.Adam( + model.parameters(requires_grad=True), lr=args.lr, weight_decay=cfg.weight_decay, + ) + # Build train datasets + logger.info("preparing dataset..") + train_dataset = COCOJoints( + args.data_root, + args.ann_file, + image_set="train", + order=("image", "keypoints", "boxes", "info"), + ) + train_sampler = data.RandomSampler( + 
train_dataset, batch_size=args.batch_size, drop_last=True
+    )
+
+    transforms = [T.Normalize(mean=cfg.IMG_MEAN, std=cfg.IMG_STD)]
+    if cfg.half_body_transform:
+        transforms.append(
+            HalfBodyTransform(
+                cfg.upper_body_ids, cfg.lower_body_ids, cfg.prob_half_body
+            )
+        )
+    if cfg.extend_boxes:
+        transforms.append(
+            ExtendBoxes(cfg.x_ext, cfg.y_ext, cfg.input_shape[1] / cfg.input_shape[0])
+        )
+    transforms += [
+        RandomHorizontalFlip(0.5, keypoint_flip_order=cfg.keypoint_flip_order)
+    ]
+    transforms += [
+        RandomBoxAffine(
+            degrees=cfg.rotate_range,
+            scale=cfg.scale_range,
+            output_shape=cfg.input_shape,
+            rotate_prob=cfg.rotation_prob,
+            scale_prob=cfg.scale_prob,
+        )
+    ]
+    transforms += [T.ToMode()]
+
+    train_queue = data.DataLoader(
+        train_dataset,
+        sampler=train_sampler,
+        num_workers=args.workers,
+        transform=T.Compose(transforms=transforms, order=train_dataset.order,),
+        collator=HeatmapCollator(
+            cfg.input_shape,
+            cfg.output_shape,
+            cfg.keypoint_num,
+            cfg.heat_thr,
+            cfg.heat_kernel if args.multi_scale_supervision else cfg.heat_kernel[-1:],
+            cfg.heat_range,
+        ),
+    )
+
+    # Start training
+    for epoch in range(start_epoch, args.epochs):
+        loss = train(model, train_queue, optimizer, args, epoch=epoch)
+        logger.info("Epoch %d Train %.6f ", epoch, loss)
+
+        if rank == 0:  # save checkpoint
+            mge.save(
+                {"epoch": epoch + 1, "state_dict": model.state_dict(),},
+                os.path.join(save_dir, "epoch_{}.pkl".format(epoch)),
+            )
+
+
+def train(model, data_queue, optimizer, args, epoch=0):
+    @jit.trace(symbolic=True, opt_level=2)
+    def train_func():
+        loss = model.calc_loss()
+        optimizer.backward(loss)  # compute gradients
+        if dist.is_distributed():  # all_reduce_mean
+            loss = dist.all_reduce_sum(loss, "train_loss") / dist.get_world_size()
+        return loss
+
+    avg_loss = 0
+    total_time = 0
+
+    t = time.time()
+    for step, mini_batch in enumerate(data_queue):
+
+        for param_group in optimizer.param_groups:
+            current_step = epoch * len(data_queue) + step
+            if current_step < cfg.warm_epochs * len(data_queue):
+                lr_factor = cfg.lr_ratio + (
+                    1 - cfg.lr_ratio
+                ) * current_step / cfg.warm_epochs / len(data_queue)
+            else:
+                lr_factor = 1 - (current_step - len(data_queue) * cfg.warm_epochs) / (
+                    len(data_queue) * (args.epochs - cfg.warm_epochs)
+                )
+
+            lr = args.lr * lr_factor
+            param_group["lr"] = lr
+
+        lr = optimizer.param_groups[0]["lr"]
+        model.inputs["image"].set_value(mini_batch["data"])
+        model.inputs["heatmap"].set_value(mini_batch["heatmap"])
+        model.inputs["heat_valid"].set_value(mini_batch["heat_valid"])
+
+        optimizer.zero_grad()
+        loss = train_func()
+        optimizer.step()
+
+        avg_loss = (avg_loss * step + loss.numpy().item()) / (step + 1)
+        total_time += time.time() - t
+        t = time.time()
+
+        if step % args.report_freq == 0 and dist.get_rank() == 0:
+            logger.info(
+                "Epoch {} Step {}, LR {:.6f} Loss {:.6f} Elapsed Time {:.3f}s".format(
+                    epoch, step, lr, loss.numpy().item(), total_time
+                )
+            )
+
+    return avg_loss
+
+
+if __name__ == "__main__":
+    main()
diff --git a/official/vision/keypoints/transforms.py b/official/vision/keypoints/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..4325b96ae9347e9ba7f2ee94252bcb767b6b9c0b
--- /dev/null
+++ b/official/vision/keypoints/transforms.py
@@ -0,0 +1,307 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# --------------------------------------------------------------------- +# Part of the following code in this file refs to human-pose-estimation.pytorch +# MIT License. +# +# Copyright (c) Microsoft +# --------------------------------------------------------------------- +from megengine.data.transform import VisionTransform, RandomHorizontalFlip +from megengine.data.transform.vision import functional as F +import cv2 +import numpy as np + + +def get_dir(src_point, rot_rad): + sn, cs = np.sin(rot_rad), np.cos(rot_rad) + src_result = [0, 0] + src_result[0] = src_point[0] * cs - src_point[1] * sn + src_result[1] = src_point[0] * sn + src_point[1] * cs + return src_result + + +def get_3rd_point(a, b): + direct = a - b + return b + np.array([-direct[1], direct[0]], dtype=np.float32) + + +def get_affine_transform(center, bbox_shape, scale, rot, output_shape, inv=0): + + dst_w = output_shape[1] + dst_h = output_shape[0] + dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32) + + scale = dst_w / (bbox_shape[1] * scale) + + rot_rad = np.pi * rot / 180 + src_dir = get_dir([0, 1 * -0.5], rot_rad) + dst_dir = np.array([0, scale * -0.5], np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + dst = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + src[1, :] = center + src_dir + dst[0, :] = dst_center + dst[1, :] = dst_center + dst_dir + + src[2:, :] = get_3rd_point(src[0, :], src[1, :]) + dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) + + if inv == 0: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + else: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + + return trans + + +class HalfBodyTransform(VisionTransform): + """ + Randomly select only half of the body (upper or lower) of an annotated person. + It aims to help the model generalize better to obstructed cases. + :param upper_body_ids: id of upper body. + :param lower_body_ids: id of lower body. + :param prob: probability that this transform is performed. 
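+    :param order: The same with :class:`VisionTransform`.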
+ """ + + def __init__(self, upper_body_ids, lower_body_ids, prob=0.3, order=None): + super(HalfBodyTransform, self).__init__() + + self.prob = prob + self.upper_body_ids = upper_body_ids + self.lower_body_ids = lower_body_ids + self.order = order + + def apply(self, input: tuple): + + self.joints = input[self.order.index("keypoints")][0] + self.keypoint_num = self.joints.shape[0] + + self._is_transform = False + if np.random.random() < self.prob: + self._is_transform = True + + return super().apply(input) + + def _apply_image(self, image): + return image + + def _apply_keypoints(self, keypoints): + return keypoints + + def _apply_boxes(self, boxes): + + if self._is_transform: + upper_joints = [] + lower_joints = [] + for joint_id in range(self.keypoint_num): + if self.joints[joint_id, -1] > 0: + if joint_id in self.upper_body_ids: + upper_joints.append(self.joints[joint_id]) + else: + lower_joints.append(self.joints[joint_id]) + + # randomly keep only the upper or lower body + # ensure the selected part has at least 3 joints + if np.random.randn() < 0.5 and len(upper_joints) > 3: + selected_joints = upper_joints + else: + selected_joints = ( + lower_joints if len(lower_joints) > 3 else upper_joints + ) + + selected_joints = np.array(selected_joints, np.float32) + if len(selected_joints) < 3: + return boxes + else: + # Adjust the box to wrap only the selected part + left_top = np.amin(selected_joints[:, :2], axis=0) + right_bottom = np.amax(selected_joints[:, :2], axis=0) + + center = (left_top + right_bottom) / 2 + + w = right_bottom[0] - left_top[0] + h = right_bottom[1] - left_top[1] + + boxes[0] = np.array( + [ + center[0] - w / 2, + center[1] - h / 2, + center[0] + w / 2, + center[1] + h / 2, + ], + dtype=np.float32, + ) + return boxes + else: + return boxes + + +class ExtendBoxes(VisionTransform): + """ + Randomly extends the bounding box for each person, + and transforms the width/height ratio to fixed value. + :param extend_x: ratio that width is extended. + :param extend_y: ratio that height is extended. + :param w_h_ratio: width/height ratio. + :param random_extend_prob: probability that boxes are randomly extended, in which case extend_x and extend_y are the maximum ratios. + """ + + def __init__(self, extend_x, extend_y, w_h_ratio, random_extend_prob=1, order=None): + super(ExtendBoxes, self).__init__() + self.extend_x = extend_x + self.extend_y = extend_y + self.w_h_ratio = w_h_ratio + self.random_extend_prob = random_extend_prob + self.order = order + + def apply(self, input: tuple): + self._rand = 1 + if np.random.random() < self.random_extend_prob: + self._rand = np.random.random() + return super().apply(input) + + def _apply_image(self, image): + return image + + def _apply_keypoints(self, keypoints): + return keypoints + + def _apply_boxes(self, boxes): + for i in range(boxes.shape[0]): + x1, y1, x2, y2 = boxes[i] + center_x = (x1 + x2) / 2 + center_y = (y1 + y2) / 2 + h = y2 - y1 + w = x2 - x1 + extend_h = (1 + self._rand * self.extend_y) * h + extend_w = (1 + self._rand * self.extend_x) * w + + if extend_w > self.w_h_ratio * extend_h: + extend_h = extend_w * 1.0 / self.w_h_ratio + else: + extend_w = extend_h * 1.0 * self.w_h_ratio + + boxes[i] = np.array( + [ + center_x - extend_w / 2, + center_y - extend_h / 2, + center_x + extend_w / 2, + center_y + extend_h / 2, + ], + dtype=np.float32, + ) + return boxes + + +class RandomBoxAffine(VisionTransform): + """ + Randomly scale and rotate the image, then crop out and the person according to its bounding box. 
+ The cropped person is then resized to disired size. + This process is completed mainly by cv2.warpAffine. + :param degrees: tuple, minmum and maximum of rotated angles. + :param scale: tuple, minmum and maximum of scales. + :param ouput_shape: the final desired shape. + :param scale_prob: probability that image is scaled. + :param rotate_prob: probability that image is rotated. + :param bordervalue: value that is used to pad image. + """ + + def __init__( + self, + degrees, + scale, + output_shape, + rotate_prob=1, + scale_prob=1, + bordervalue=0, + order=None, + ): + super(RandomBoxAffine, self).__init__(order) + + self.degrees_range = degrees + self.scale_range = scale + self.output_shape = output_shape + self.rotate_prob = rotate_prob + self.scale_prob = scale_prob + self.bordervalue = bordervalue + self.order = order + + def apply(self, input: tuple): + scale = 1 + is_scale = np.random.random() < self.scale_prob + if is_scale: + scale = np.random.uniform(self.scale_range[0], self.scale_range[1]) + + degree = 0 + is_rotate = np.random.random() < self.rotate_prob + if is_rotate: + degree = np.random.uniform(self.degrees_range[0], self.degrees_range[1]) + + bbox = input[self.order.index("boxes")][0] + + center = np.array( + [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32 + ) + bbox_shape = np.array([bbox[3] - bbox[1], bbox[2] - bbox[0]], dtype=np.float32) + + self.trans = get_affine_transform( + center, bbox_shape, scale, degree, self.output_shape + ) + return super().apply(input) + + def _apply_image(self, image): + img = cv2.warpAffine( + image, + self.trans, + (int(self.output_shape[1]), int(self.output_shape[0])), + flags=cv2.INTER_LINEAR, + borderValue=self.bordervalue, + ) + return img + + def _apply_keypoints(self, keypoints): + keypoints_copy = keypoints.copy() + keypoints_copy[:, :, 2] = 1 + pt = np.matmul(keypoints_copy[:, :, None], self.trans.transpose(1, 0))[ + :, :, 0, :2 + ] + keypoints_copy[:, :, :2] = pt + keypoints_copy[:, :, 2] = keypoints[:, :, 2] + delete_pt = ( + (pt[:, :, 0] < 0) + + (pt[:, :, 0] > self.output_shape[1] - 1) + + (pt[:, :, 1] < 0) + + (pt[:, :, 1] > self.output_shape[0] - 1) + + (keypoints[:, :, 2] == 0) + ) + keypoints_copy[delete_pt] = 0 + return keypoints_copy + + def _apply_boxes(self, boxes): + return boxes + + +class RandomHorizontalFlip(RandomHorizontalFlip): + """Horizontally flip the input data randomly with a given probability. + :param p: probability of the input data being flipped. Default: 0.5 + :param order: The same with :class:`VisionTransform` + """ + + def __init__(self, prob: float = 0.5, *, keypoint_flip_order, order=None): + super().__init__(order) + self.prob = prob + self.keypoint_flip_order = keypoint_flip_order + + def _apply_keypoints(self, keypoints): + if self._flipped: + for i in range(len(keypoints)): + keypoints[i, :, 0] = self._w - keypoints[i, :, 0] - 1 + keypoints[i] = keypoints[i][self.keypoint_flip_order] + return keypoints