提交 4f3c266d 编写于 作者: Eric.Lee2021's avatar Eric.Lee2021 🚴🏻

add inference

上级 1195c15b
......@@ -32,4 +32,29 @@
## 项目使用方法
### 模型训练
* 根目录下运行命令: python train.py
* 根目录下运行命令: python train.py (注意脚本内相关参数配置 )
### 模型推理
* 根目录下运行命令: python inference.py (注意脚本内相关参数配置 )
*-------------------------------------------------------------------------
···
*建议
检测手bbox后,crop手图片size预处理方式:
# img 为原图 ,np为numpy
x_min,y_min,x_max,y_max,score = bbox
w_ = max(abs(x_max-x_min),abs(y_max-y_min))
w_ = w_*1.1
x_mid = (x_max+x_min)/2
y_mid = (y_max+y_min)/2
x1,y1,x2,y2 = int(x_mid-w_/2),int(y_mid-w_/2),int(x_mid+w_/2),int(y_mid+w_/2)
x1 = np.clip(x1,0,img.shape[1]-1)
x2 = np.clip(x2,0,img.shape[1]-1)
y1 = np.clip(y1,0,img.shape[0]-1)
y2 = np.clip(y2,0,img.shape[0]-1)
···
......@@ -177,7 +177,8 @@ class LoadImagesAndLabels(Dataset):
"y":yh,
}
draw_bd_handpose(hand_rot,pts_hand,0,0)
draw_bd_handpose(hand_rot,pts_hand,0,0)# 绘制关键点 连线
cv2.namedWindow("hand_rotd",0)
cv2.imshow("hand_rotd",hand_rot)
print("hand_rot shape : {}".format(hand_rot.shape))
......
#-*-coding:utf-8-*-
# date:2021-04-5
# Author: Eric.Lee
# function: Inference
import os
import argparse
import torch
import torch.nn as nn
import numpy as np
import time
import datetime
import os
import math
from datetime import datetime
import cv2
import torch.nn.functional as F
from models.resnet import resnet50, resnet34
from utils.common_utils import *
import copy
from hand_data_iter.datasets import draw_bd_handpose
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=' Project Hand Pose Inference')
parser.add_argument('--model_path', type=str, default = './weights/resnet50_2021-478.pth',
help = 'model_path') # 模型路径
parser.add_argument('--model', type=str, default = 'resnet_50',
help = 'model : resnet_x,') # 模型类型
parser.add_argument('--num_classes', type=int , default = 42,
help = 'num_classes') # 手部21关键点, (x,y)*2 = 42
parser.add_argument('--GPUS', type=str, default = '0',
help = 'GPUS') # GPU选择
parser.add_argument('--test_path', type=str, default = './image/',
help = 'test_path') # 测试图片路径
parser.add_argument('--img_size', type=tuple , default = (256,256),
help = 'img_size') # 输入模型图片尺寸
parser.add_argument('--vis', type=bool , default = True,
help = 'vis') # 是否可视化图片
print('\n/******************* {} ******************/\n'.format(parser.description))
#--------------------------------------------------------------------------
ops = parser.parse_args()# 解析添加参数
#--------------------------------------------------------------------------
print('----------------------------------')
unparsed = vars(ops) # parse_args()方法的返回值为namespace,用vars()内建函数化为字典
for key in unparsed.keys():
print('{} : {}'.format(key,unparsed[key]))
#---------------------------------------------------------------------------
os.environ['CUDA_VISIBLE_DEVICES'] = ops.GPUS
test_path = ops.test_path # 测试图片文件夹路径
#---------------------------------------------------------------- 构建模型
print('use model : %s'%(ops.model))
if ops.model == 'resnet_50':
model_ = resnet50(num_classes = ops.num_classes,img_size=ops.img_size[0])
elif ops.model == 'resnet_34':
model_ = resnet34(num_classes = ops.num_classes,img_size=ops.img_size[0])
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
model_ = model_.to(device)
model_.eval() # 设置为前向推断模式
# print(model_)# 打印模型结构
# 加载测试模型
if os.access(ops.model_path,os.F_OK):# checkpoint
chkpt = torch.load(ops.model_path, map_location=device)
model_.load_state_dict(chkpt)
print('load test model : {}'.format(ops.model_path))
#---------------------------------------------------------------- 预测图片
'''建议 检测手bbox后,crop手图片的预处理方式:
# img 为原图
x_min,y_min,x_max,y_max,score = bbox
w_ = max(abs(x_max-x_min),abs(y_max-y_min))
w_ = w_*1.1
x_mid = (x_max+x_min)/2
y_mid = (y_max+y_min)/2
x1,y1,x2,y2 = int(x_mid-w_/2),int(y_mid-w_/2),int(x_mid+w_/2),int(y_mid+w_/2)
x1 = np.clip(x1,0,img.shape[1]-1)
x2 = np.clip(x2,0,img.shape[1]-1)
y1 = np.clip(y1,0,img.shape[0]-1)
y2 = np.clip(y2,0,img.shape[0]-1)
'''
with torch.no_grad():
idx = 0
for file in os.listdir(ops.test_path):
if '.jpg' not in file:
continue
idx += 1
print('{}) image : {}'.format(idx,file))
img = cv2.imread(ops.test_path + file)
img_width = img.shape[1]
img_height = img.shape[0]
# 输入图片预处理
img_ = cv2.resize(img, (ops.img_size[1],ops.img_size[0]), interpolation = cv2.INTER_CUBIC)
img_ = img_.astype(np.float32)
img_ = (img_-128.)/256.
img_ = img_.transpose(2, 0, 1)
img_ = torch.from_numpy(img_)
img_ = img_.unsqueeze_(0)
if use_cuda:
img_ = img_.cuda() # (bs, 3, h, w)
pre_ = model_(img_.float()) # 模型推理
output = pre_.cpu().detach().numpy()
output = np.squeeze(output)
pts_hand = {} #构建关键点连线可视化结构
for i in range(int(output.shape[0]/2)):
x = (output[i*2+0]*float(img_width))
y = (output[i*2+1]*float(img_height))
pts_hand[str(i)] = {}
pts_hand[str(i)] = {
"x":x,
"y":y,
}
draw_bd_handpose(img,pts_hand,0,0) # 绘制关键点连线
#------------- 绘制关键点
for i in range(int(output.shape[0]/2)):
x = (output[i*2+0]*float(img_width))
y = (output[i*2+1]*float(img_height))
cv2.circle(img, (int(x),int(y)), 3, (255,50,60),-1)
cv2.circle(img, (int(x),int(y)), 1, (255,150,180),-1)
if ops.vis:
cv2.namedWindow('image',0)
cv2.imshow('image',img)
if cv2.waitKey(600) == 27 :
break
cv2.destroyAllWindows()
print('well done ')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册