diff --git a/README.md b/README.md
index e29309af4e2f152885b376f103ecb9e4d53873ea..3b7cb7f11b9e4d8f1e3f9af1f233023dfe96bc7e 100644
--- a/README.md
+++ b/README.md
@@ -32,4 +32,29 @@
 ## Project Usage
 
 ### Model Training
-* Run from the project root: python train.py
+* Run from the project root: python train.py (check the relevant parameter settings inside the script)
+
+### Model Inference
+* Run from the project root: python inference.py (check the relevant parameter settings inside the script)
+
+---
+### Suggestion
+After detecting the hand bbox, preprocess the cropped hand image as follows:
+```python
+# img is the original image; np is numpy
+x_min, y_min, x_max, y_max, score = bbox
+w_ = max(abs(x_max - x_min), abs(y_max - y_min))
+
+w_ = w_ * 1.1
+
+x_mid = (x_max + x_min) / 2
+y_mid = (y_max + y_min) / 2
+
+x1, y1, x2, y2 = int(x_mid - w_/2), int(y_mid - w_/2), int(x_mid + w_/2), int(y_mid + w_/2)
+
+x1 = np.clip(x1, 0, img.shape[1] - 1)
+x2 = np.clip(x2, 0, img.shape[1] - 1)
+
+y1 = np.clip(y1, 0, img.shape[0] - 1)
+y2 = np.clip(y2, 0, img.shape[0] - 1)
+```
diff --git a/hand_data_iter/datasets.py b/hand_data_iter/datasets.py
index a6c98d94b2ee1935d32f703a3df0a464ab617fcd..e6e2cd4f131b7a1e3b2d78fe778ae23a45643478 100644
--- a/hand_data_iter/datasets.py
+++ b/hand_data_iter/datasets.py
@@ -177,7 +177,8 @@ class LoadImagesAndLabels(Dataset):
                     "y":yh,
                     }
 
-                draw_bd_handpose(hand_rot,pts_hand,0,0)
+                draw_bd_handpose(hand_rot,pts_hand,0,0) # draw the keypoint connections
+                cv2.namedWindow("hand_rotd",0)
                 cv2.imshow("hand_rotd",hand_rot)
                 print("hand_rot shape : {}".format(hand_rot.shape))
 
diff --git a/image/image_2021-02-04_20-05-49.jpg b/image/image_2021-02-04_20-05-49.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a5b456819c9ff81e45032d2ef59217d97ddcd54a
Binary files /dev/null and b/image/image_2021-02-04_20-05-49.jpg differ
diff --git a/image/image_2021-02-04_20-05-58.jpg b/image/image_2021-02-04_20-05-58.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1e7bcd999e845ba3df3b8cf83965ffdbb5342656
Binary files /dev/null and b/image/image_2021-02-04_20-05-58.jpg differ
diff --git a/image/image_2021-02-04_20-06-03.jpg b/image/image_2021-02-04_20-06-03.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2794587338f5ceaf01bed7738dd1619f1a3ba5c1
Binary files /dev/null and b/image/image_2021-02-04_20-06-03.jpg differ
diff --git a/image/image_2021-02-04_20-06-04.jpg b/image/image_2021-02-04_20-06-04.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..14b6a0cb4695b6e9af030b3896a373547075a85a
Binary files /dev/null and b/image/image_2021-02-04_20-06-04.jpg differ
diff --git a/image/image_2021-02-04_20-06-05.jpg b/image/image_2021-02-04_20-06-05.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..91d1ecb0fc961905a2b8d5e98a553d10ae2b6b39
Binary files /dev/null and b/image/image_2021-02-04_20-06-05.jpg differ
diff --git a/image/image_2021-02-04_20-06-13.jpg b/image/image_2021-02-04_20-06-13.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ac4ae9d64cc0a2727a87c41526c88ed8b7524946
Binary files /dev/null and b/image/image_2021-02-04_20-06-13.jpg differ
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..e19943d106550ffadf6b5111b3a69c438eedc7e8
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,148 @@
+#-*-coding:utf-8-*-
+# date: 2021-04-05
+# Author: Eric.Lee
+# function: Inference
+
+import os
+import argparse
+
+import cv2
+import numpy as np
+import torch
+
+from models.resnet import resnet50, resnet34
+from utils.common_utils import *
+from hand_data_iter.datasets import draw_bd_handpose
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description=' Project Hand Pose Inference')
+
+    parser.add_argument('--model_path', type=str, default = './weights/resnet50_2021-478.pth',
+        help = 'model_path') # path to the trained model weights
+    parser.add_argument('--model', type=str, default = 'resnet_50',
+        help = 'model : resnet_50 / resnet_34') # backbone type
+    parser.add_argument('--num_classes', type=int , default = 42,
+        help = 'num_classes') # 21 hand keypoints, 21 * (x,y) = 42 outputs
+    parser.add_argument('--GPUS', type=str, default = '0',
+        help = 'GPUS') # GPU selection
+    parser.add_argument('--test_path', type=str, default = './image/',
+        help = 'test_path') # folder of test images
+    parser.add_argument('--img_size', type=tuple , default = (256,256),
+        help = 'img_size') # input image size for the model
+    parser.add_argument('--vis', type=bool , default = True,
+        help = 'vis') # whether to visualize the results
+
+    print('\n/******************* {} ******************/\n'.format(parser.description))
+    #--------------------------------------------------------------------------
+    ops = parser.parse_args() # parse the command-line arguments
+    #--------------------------------------------------------------------------
+    print('----------------------------------')
+
+    unparsed = vars(ops) # parse_args() returns a Namespace; vars() converts it to a dict
+    for key in unparsed.keys():
+        print('{} : {}'.format(key,unparsed[key]))
+
+    #---------------------------------------------------------------------------
+    os.environ['CUDA_VISIBLE_DEVICES'] = ops.GPUS
+
+    test_path = ops.test_path # folder containing the test images
+
+    #---------------------------------------------------------------- build the model
+    print('use model : %s'%(ops.model))
+
+    if ops.model == 'resnet_50':
+        model_ = resnet50(num_classes = ops.num_classes,img_size=ops.img_size[0])
+    elif ops.model == 'resnet_34':
+        model_ = resnet34(num_classes = ops.num_classes,img_size=ops.img_size[0])
+    else:
+        raise ValueError('unsupported model : {}'.format(ops.model))
+
+    use_cuda = torch.cuda.is_available()
+
+    device = torch.device("cuda:0" if use_cuda else "cpu")
+    model_ = model_.to(device)
+    model_.eval() # switch to inference (eval) mode
+
+    # print(model_) # print the model architecture
+
+    # load the trained checkpoint
+    if os.access(ops.model_path,os.F_OK): # checkpoint exists
+        chkpt = torch.load(ops.model_path, map_location=device)
+        model_.load_state_dict(chkpt)
+        print('load test model : {}'.format(ops.model_path))
+
+    #---------------------------------------------------------------- predict on images
+    '''Suggestion: after detecting the hand bbox, preprocess the cropped hand image as follows:
+    # img is the original image
+    x_min,y_min,x_max,y_max,score = bbox
+    w_ = max(abs(x_max-x_min),abs(y_max-y_min))
+
+    w_ = w_*1.1
+
+    x_mid = (x_max+x_min)/2
+    y_mid = (y_max+y_min)/2
+
+    x1,y1,x2,y2 = int(x_mid-w_/2),int(y_mid-w_/2),int(x_mid+w_/2),int(y_mid+w_/2)
+
+    x1 = np.clip(x1,0,img.shape[1]-1)
+    x2 = np.clip(x2,0,img.shape[1]-1)
+
+    y1 = np.clip(y1,0,img.shape[0]-1)
+    y2 = np.clip(y2,0,img.shape[0]-1)
+    '''
+    with torch.no_grad():
+        idx = 0
+        for file in os.listdir(ops.test_path):
+            if '.jpg' not in file:
+                continue
+            idx += 1
+            print('{}) image : {}'.format(idx,file))
+            img = cv2.imread(ops.test_path + file)
+            img_width = img.shape[1]
+            img_height = img.shape[0]
+            # preprocess the input image
+            img_ = cv2.resize(img, (ops.img_size[1],ops.img_size[0]), interpolation = cv2.INTER_CUBIC)
+            img_ = img_.astype(np.float32)
+            img_ = (img_-128.)/256. # normalize to roughly [-0.5, 0.5]
+
+            img_ = img_.transpose(2, 0, 1) # HWC -> CHW
+            img_ = torch.from_numpy(img_)
+            img_ = img_.unsqueeze_(0) # add the batch dimension
+
+            if use_cuda:
+                img_ = img_.cuda() # (bs, 3, h, w)
+
+            pre_ = model_(img_.float()) # forward pass
+            output = pre_.cpu().detach().numpy()
+            output = np.squeeze(output)
+
+            pts_hand = {} # keypoint structure used to draw the connections
+            for i in range(int(output.shape[0]/2)):
+                x = (output[i*2+0]*float(img_width))
+                y = (output[i*2+1]*float(img_height))
+
+                pts_hand[str(i)] = {
+                    "x":x,
+                    "y":y,
+                    }
+            draw_bd_handpose(img,pts_hand,0,0) # draw the keypoint connections
+
+            #------------- draw the keypoints
+            for i in range(int(output.shape[0]/2)):
+                x = (output[i*2+0]*float(img_width))
+                y = (output[i*2+1]*float(img_height))
+
+                cv2.circle(img, (int(x),int(y)), 3, (255,50,60),-1)
+                cv2.circle(img, (int(x),int(y)), 1, (255,150,180),-1)
+
+            if ops.vis:
+                cv2.namedWindow('image',0)
+                cv2.imshow('image',img)
+                if cv2.waitKey(600) == 27: # press ESC to quit
+                    break
+
+    cv2.destroyAllWindows()
+
+    print('well done')
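
For readers wiring this inference script to a separate hand detector, the sketch below combines the README's suggested bbox crop with the same preprocessing that inference.py applies before the forward pass. It is a minimal illustration, not part of the patch: the helper names `crop_hand_from_bbox` and `preprocess_for_model` are made up here, and `bbox = (x_min, y_min, x_max, y_max, score)` is assumed to come from an external hand detector.

```python
# Minimal sketch (not part of this patch). Assumes `bbox` comes from an
# external hand detector; the helper names are illustrative only.
import cv2
import numpy as np
import torch


def crop_hand_from_bbox(img, bbox, expand=1.1):
    """Crop a square patch around the detected hand bbox, clipped to the image."""
    x_min, y_min, x_max, y_max, score = bbox
    w_ = max(abs(x_max - x_min), abs(y_max - y_min)) * expand  # enlarge the longer side

    x_mid = (x_max + x_min) / 2
    y_mid = (y_max + y_min) / 2

    x1, y1 = int(x_mid - w_ / 2), int(y_mid - w_ / 2)
    x2, y2 = int(x_mid + w_ / 2), int(y_mid + w_ / 2)

    x1 = np.clip(x1, 0, img.shape[1] - 1)
    x2 = np.clip(x2, 0, img.shape[1] - 1)
    y1 = np.clip(y1, 0, img.shape[0] - 1)
    y2 = np.clip(y2, 0, img.shape[0] - 1)
    return img[y1:y2, x1:x2]


def preprocess_for_model(crop, img_size=(256, 256)):
    """Resize and normalize a crop the same way inference.py does."""
    img_ = cv2.resize(crop, (img_size[1], img_size[0]), interpolation=cv2.INTER_CUBIC)
    img_ = img_.astype(np.float32)
    img_ = (img_ - 128.) / 256.     # roughly [-0.5, 0.5]
    img_ = img_.transpose(2, 0, 1)  # HWC -> CHW
    return torch.from_numpy(img_).unsqueeze(0)  # (1, 3, h, w)


# Usage (assuming `model_` and `device` are set up as in inference.py):
#   crop   = crop_hand_from_bbox(img, bbox)
#   tensor = preprocess_for_model(crop).to(device)
#   output = model_(tensor)  # 42 values: 21 keypoints, (x, y) normalized to the crop
```

As in the loop inside inference.py, multiplying the 42 outputs by the crop's width and height maps the normalized keypoints back to pixel coordinates before drawing them.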