From f64a4ccbd2e1de0839451fcd36e553f798a7411a Mon Sep 17 00:00:00 2001 From: greatlog Date: Sun, 28 Jun 2020 14:51:55 +0800 Subject: [PATCH] fix(keypoints) fix bug in test (#36) --- README.md | 12 ++++++++ official/vision/keypoints/README.md | 30 ++++++++++++------- official/vision/keypoints/models/__init__.py | 2 +- official/vision/keypoints/models/mspn.py | 2 +- .../vision/keypoints/models/simplebaseline.py | 6 ++-- official/vision/keypoints/test.py | 9 ++++-- 6 files changed, 42 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index cfcfc3f..09b9cbc 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,19 @@ export PYTHONPATH=/path/to/models:$PYTHONPATH | :--: |:--: |:--: |:--: | | Deeplabv3plus | Resnet101 | 79.0 | 79.8 | +<<<<<<< HEAD +<<<<<<< HEAD +<<<<<<< HEAD ### 人体关节点检测 +======= +### 人体关节点 +>>>>>>> update readme +======= +### 人体关节点 +>>>>>>> update readme +======= +### 人体关节点检测 +>>>>>>> 3fdaf98eee3169f70ace463d54cd177ee1fcf68e 我们提供了人体关节点检测的经典模型[SimpleBaseline](https://arxiv.org/pdf/1804.06208.pdf)和高精度模型[MSPN](https://arxiv.org/pdf/1901.00148.pdf),使用在COCO val2017上人体检测AP为56的检测结果,提供的模型在COCO val2017上的关节点检测结果为: diff --git a/official/vision/keypoints/README.md b/official/vision/keypoints/README.md index ba74c0c..e209b78 100644 --- a/official/vision/keypoints/README.md +++ b/official/vision/keypoints/README.md @@ -38,12 +38,12 @@ ${COCO_DATA_ROOT} | |-- person_keypoints_val2017.json |-- person_detection_results | |-- COCO_val2017_detections_AP_H_56_person.json -|-- |-- train2017 - | |-- 000000000009.jpg - | |-- 000000000025.jpg - | |-- 000000000030.jpg - | |-- ... - |-- val2017 +|-- train2017 +| | |-- 000000000009.jpg +| | |-- 000000000025.jpg +| | |-- 000000000030.jpg +| | |-- ... +|-- val2017 |-- 000000000139.jpg |-- 000000000285.jpg |-- 000000000632.jpg @@ -79,19 +79,27 @@ python3 train.py --arch mspn_4stage \ ## 如何测试 -模型训练好之后,可以通过如下命令测试模型在COCOval2017验证集的性能: - +模型训练好之后,可以通过如下命令测试指定模型在COCOval2017验证集的性能: ```bash python3 test.py --arch name/of/network \ --model /path/to/model.pkl \ --dt_file /name/human/detection/results ``` - `test.py`的命令行参数如下: - `--arch`, 网络的名字; - `--model`, 待检测的模; - `--dt_path`,人体检测结果. +也可以连续验证多个模型的性能: + +```bash +python3 test.py --arch name/of/network \ + --model_dir path/of/saved/models \ + --start_epoch num/of/start/epoch \ + --end_epoch num/of/end/epoch \ + --test_freq test/frequence +``` + ## 如何使用 模型训练好之后,可以通过如下命令测试单张图片(先使用预训练的RetainNet检测出人的框),得到人体姿态可视化结果: @@ -111,5 +119,5 @@ python3 inference.py --arch /name/of/tested/network \ ## 参考文献 -- [Simple Baselines for Human Pose Estimation and Tracking](https://arxiv.org/pdf/1804.06208.pdf), Bin Xiao, Haiping Wu, and Yichen Wei -- [Rethinking on Multi-Stage Networks for Human Pose Estimation](https://arxiv.org/pdf/1901.00148.pdf) Wenbo Li1, Zhicheng Wang, Binyi Yin, Qixiang Peng, Yuming Du, Tianzi Xiao, Gang Yu, Hongtao Lu, Yichen Wei and Jian Sun \ No newline at end of file +- [Simple Baselines for Human Pose Estimation and Tracking](https://arxiv.org/pdf/1804.06208.pdf) Bin Xiao, Haiping Wu, and Yichen Wei +- [Rethinking on Multi-Stage Networks for Human Pose Estimation](https://arxiv.org/pdf/1901.00148.pdf) Wenbo Li1, Zhicheng Wang, Binyi Yin, Qixiang Peng, Yuming Du, Tianzi Xiao, Gang Yu, Hongtao Lu, Yichen Wei and Jian Sun diff --git a/official/vision/keypoints/models/__init__.py b/official/vision/keypoints/models/__init__.py index 3fd8456..76c885d 100644 --- a/official/vision/keypoints/models/__init__.py +++ b/official/vision/keypoints/models/__init__.py @@ -11,5 +11,5 @@ from .simplebaseline import ( simplebaseline_res101, simplebaseline_res152, ) - from .mspn import mspn_4stage + diff --git a/official/vision/keypoints/models/mspn.py b/official/vision/keypoints/models/mspn.py index 4c9ed0f..6deae88 100644 --- a/official/vision/keypoints/models/mspn.py +++ b/official/vision/keypoints/models/mspn.py @@ -243,7 +243,7 @@ class MSPN(M.Module): @hub.pretrained( - "https://data.megengine.org.cn/models/weights/mspn_4stage_256x192_0_255_75_2.pkl" + "https://data.megengine.org.cn/models/weights/keypoint_models/mspn_4stage_0_255_75_2.pkl" ) def mspn_4stage(**kwargs): model = MSPN( diff --git a/official/vision/keypoints/models/simplebaseline.py b/official/vision/keypoints/models/simplebaseline.py index 4220698..1856102 100644 --- a/official/vision/keypoints/models/simplebaseline.py +++ b/official/vision/keypoints/models/simplebaseline.py @@ -110,7 +110,7 @@ cfg = SimpleBaseline_Config() @hub.pretrained( - "https://data.megengine.org.cn/models/weights/simplebaseline50_256x192_0_255_71_2.pkl" + "https://data.megengine.org.cn/models/weights/keypoint_models/simplebaseline50_256x192_0_255_71_2.pkl" ) def simplebaseline_res50(**kwargs): @@ -119,7 +119,7 @@ def simplebaseline_res50(**kwargs): @hub.pretrained( - "https://data.megengine.org.cn/models/weights/simplebaseline101_256x192_0_255_72_2.pkl" + "https://data.megengine.org.cn/models/weights/keypoint_models/simplebaseline101_256x192_0_255_72_2.pkl" ) def simplebaseline_res101(**kwargs): @@ -128,7 +128,7 @@ def simplebaseline_res101(**kwargs): @hub.pretrained( - "https://data.megengine.org.cn/models/weights/simplebaseline152_256x192_0_255_72_4.pkl" + "https://data.megengine.org.cn/models/weights/keypoint_models/simplebaseline152_256x192_0_255_72_4.pkl" ) def simplebaseline_res152(**kwargs): diff --git a/official/vision/keypoints/test.py b/official/vision/keypoints/test.py index 460ab88..bc0f036 100644 --- a/official/vision/keypoints/test.py +++ b/official/vision/keypoints/test.py @@ -221,6 +221,9 @@ def make_parser(): ) parser.add_argument("-se", "--start_epoch", default=-1, type=int) parser.add_argument("-ee", "--end_epoch", default=-1, type=int) + parser.add_argument("-md", "--model_dir", default="/data/models/simplebaseline_res50_256x192/", type=str) + parser.add_argument("-tf", "--test_freq", default=1, type=int) + parser.add_argument( "-a", "--arch", @@ -266,12 +269,12 @@ def main(): if args.end_epoch == -1: args.end_epoch = args.start_epoch - for epoch_num in range(args.start_epoch, args.end_epoch + 1): + for epoch_num in range(args.start_epoch, args.end_epoch + 1, args.test_freq): if args.model: model_file = args.model else: - model_file = "log-of-{}/epoch_{}.pkl".format( - os.path.basename(args.file).split(".")[0], epoch_num + model_file = "{}/epoch_{}.pkl".format( + args.model_dir, epoch_num ) logger.info("Load Model : %s completed", model_file) -- GitLab