未验证 提交 48ad276b 编写于 作者: B Bubbliiiing 提交者: GitHub

Add files via upload

上级 0259c0a9
'''
predict.py有几个注意点
1、该代码无法直接进行批量预测,如果想要批量预测,可以利用os.listdir()遍历文件夹,利用Image.open打开图片文件进行预测。
具体流程可以参考get_dr_txt.py,在get_dr_txt.py即实现了遍历还实现了目标信息的保存。
2、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。
3、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。
4、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值
在原图上利用矩阵的方式进行截取。
5、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断,
比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。
'''
#----------------------------------------------------#
# 对视频中的predict.py进行了修改,
# 将单张图片预测、摄像头检测和FPS测试功能
# 整合到了一个py文件中,通过指定mode进行模式的修改。
#----------------------------------------------------#
import time
import cv2
import numpy as np
from PIL import Image
from yolo import YOLO
yolo = YOLO()
if __name__ == "__main__":
yolo = YOLO()
#-------------------------------------------------------------------------#
# mode用于指定测试的模式:
# 'predict'表示单张图片预测
# 'video'表示视频检测
# 'fps'表示测试fps
#-------------------------------------------------------------------------#
mode = "fps"
#-------------------------------------------------------------------------#
# video_path用于指定视频的路径,当video_path=0时表示检测摄像头
# video_save_path表示视频保存的路径,当video_save_path=""时表示不保存
# video_fps用于保存的视频的fps
# video_path、video_save_path和video_fps仅在mode='video'时有效
# 保存视频时需要ctrl+c退出才会完成完整的保存步骤,不可直接结束程序。
#-------------------------------------------------------------------------#
video_path = 0
video_save_path = ""
video_fps = 25.0
if mode == "predict":
'''
1、该代码无法直接进行批量预测,如果想要批量预测,可以利用os.listdir()遍历文件夹,利用Image.open打开图片文件进行预测。
具体流程可以参考get_dr_txt.py,在get_dr_txt.py即实现了遍历还实现了目标信息的保存。
2、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。
3、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。
4、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值
在原图上利用矩阵的方式进行截取。
5、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断,
比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。
'''
while True:
img = input('Input image filename:')
try:
image = Image.open(img)
except:
print('Open Error! Try again!')
continue
else:
r_image = yolo.detect_image(image)
r_image.show()
elif mode == "video":
capture=cv2.VideoCapture(video_path)
if video_save_path!="":
fourcc = cv2.VideoWriter_fourcc(*'XVID')
size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
fps = 0.0
while(True):
t1 = time.time()
# 读取某一帧
ref,frame=capture.read()
# 格式转变,BGRtoRGB
frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
# 转变成Image
frame = Image.fromarray(np.uint8(frame))
# 进行检测
frame = np.array(yolo.detect_image(frame))
# RGBtoBGR满足opencv显示格式
frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
fps = ( fps + (1./(time.time()-t1)) ) / 2
print("fps= %.2f"%(fps))
frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
cv2.imshow("video",frame)
c= cv2.waitKey(1) & 0xff
if video_save_path!="":
out.write(frame)
if c==27:
capture.release()
break
capture.release()
out.release()
cv2.destroyAllWindows()
while True:
img = input('Input image filename:')
try:
image = Image.open(img)
except:
print('Open Error! Try again!')
continue
elif mode == "fps":
test_interval = 100
img = Image.open('img/street.jpg')
tact_time = yolo.get_FPS(img, test_interval)
print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')
else:
r_image = yolo.detect_image(image)
r_image.show()
raise AssertionError("Please specify the correct mode: 'predict', 'video' or 'fps'.")
......@@ -3,6 +3,7 @@
#-------------------------------------#
import colorsys
import os
import time
import numpy as np
import torch
......@@ -229,3 +230,85 @@ class YOLO(object):
del draw
return image
def get_FPS(self, image, test_interval):
image_shape = np.array(np.shape(image)[0:2])
#---------------------------------------------------------#
# 给图像增加灰条,实现不失真的resize
# 也可以直接resize进行识别
#---------------------------------------------------------#
if self.letterbox_image:
crop_img = np.array(letterbox_image(image, (self.model_image_size[1],self.model_image_size[0])))
else:
crop_img = image.convert('RGB')
crop_img = crop_img.resize((self.model_image_size[1],self.model_image_size[0]), Image.BICUBIC)
photo = np.array(crop_img,dtype = np.float32) / 255.0
photo = np.transpose(photo, (2, 0, 1))
#---------------------------------------------------------#
# 添加上batch_size维度
#---------------------------------------------------------#
images = [photo]
with torch.no_grad():
images = torch.from_numpy(np.asarray(images))
if self.cuda:
images = images.cuda()
outputs = self.net(images)
output_list = []
for i in range(3):
output_list.append(self.yolo_decodes[i](outputs[i]))
output = torch.cat(output_list, 1)
batch_detections = non_max_suppression(output, len(self.class_names),
conf_thres=self.confidence,
nms_thres=self.iou)
try:
batch_detections = batch_detections[0].cpu().numpy()
top_index = batch_detections[:,4]*batch_detections[:,5] > self.confidence
top_conf = batch_detections[top_index,4]*batch_detections[top_index,5]
top_label = np.array(batch_detections[top_index,-1],np.int32)
top_bboxes = np.array(batch_detections[top_index,:4])
top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(top_bboxes[:,0],-1),np.expand_dims(top_bboxes[:,1],-1),np.expand_dims(top_bboxes[:,2],-1),np.expand_dims(top_bboxes[:,3],-1)
if self.letterbox_image:
boxes = yolo_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.model_image_size[0],self.model_image_size[1]]),image_shape)
else:
top_xmin = top_xmin / self.model_image_size[1] * image_shape[1]
top_ymin = top_ymin / self.model_image_size[0] * image_shape[0]
top_xmax = top_xmax / self.model_image_size[1] * image_shape[1]
top_ymax = top_ymax / self.model_image_size[0] * image_shape[0]
boxes = np.concatenate([top_ymin,top_xmin,top_ymax,top_xmax], axis=-1)
except:
pass
t1 = time.time()
for _ in range(test_interval):
with torch.no_grad():
outputs = self.net(images)
output_list = []
for i in range(3):
output_list.append(self.yolo_decodes[i](outputs[i]))
output = torch.cat(output_list, 1)
batch_detections = non_max_suppression(output, len(self.class_names),
conf_thres=self.confidence,
nms_thres=self.iou)
try:
batch_detections = batch_detections[0].cpu().numpy()
top_index = batch_detections[:,4]*batch_detections[:,5] > self.confidence
top_conf = batch_detections[top_index,4]*batch_detections[top_index,5]
top_label = np.array(batch_detections[top_index,-1],np.int32)
top_bboxes = np.array(batch_detections[top_index,:4])
top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(top_bboxes[:,0],-1),np.expand_dims(top_bboxes[:,1],-1),np.expand_dims(top_bboxes[:,2],-1),np.expand_dims(top_bboxes[:,3],-1)
if self.letterbox_image:
boxes = yolo_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.model_image_size[0],self.model_image_size[1]]),image_shape)
else:
top_xmin = top_xmin / self.model_image_size[1] * image_shape[1]
top_ymin = top_ymin / self.model_image_size[0] * image_shape[0]
top_xmax = top_xmax / self.model_image_size[1] * image_shape[1]
top_ymax = top_ymax / self.model_image_size[0] * image_shape[0]
boxes = np.concatenate([top_ymin,top_xmin,top_ymax,top_xmax], axis=-1)
except:
pass
t2 = time.time()
tact_time = (t2 - t1) / test_interval
return tact_time
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册