diff --git a/FPS_test.py b/FPS_test.py
index 5a827c7861599b737717bf40afafb13a2aa2f60b..124232403d1d0158a6f39903dca4268f4a373cd0 100644
--- a/FPS_test.py
+++ b/FPS_test.py
@@ -17,20 +17,21 @@ The FPS measured in video.py will be lower than this FPS, because the camera read rate is limited.
 '''
 class FPS_YOLO(YOLO):
     def get_FPS(self, image, test_interval):
-        # Resize the image to match the input size
-        new_image_size = (self.model_image_size[1],self.model_image_size[0])
-        boxed_image = letterbox_image(image, new_image_size)
+        if self.letterbox_image:
+            boxed_image = letterbox_image(image, (self.model_image_size[1],self.model_image_size[0]))
+        else:
+            boxed_image = image.convert('RGB')
+            boxed_image = boxed_image.resize((self.model_image_size[1],self.model_image_size[0]), Image.BICUBIC)
         image_data = np.array(boxed_image, dtype='float32')
         image_data /= 255.
-        image_data = np.expand_dims(image_data, 0)
+        image_data = np.expand_dims(image_data, 0)

         out_boxes, out_scores, out_classes = self.sess.run(
             [self.boxes, self.scores, self.classes],
             feed_dict={
                 self.yolo_model.input: image_data,
                 self.input_image_shape: [image.size[1], image.size[0]],
-                K.learning_phase(): 0
-            })
+                K.learning_phase(): 0})

         t1 = time.time()
         for _ in range(test_interval):
@@ -39,8 +40,7 @@ class FPS_YOLO(YOLO):
                 feed_dict={
                     self.yolo_model.input: image_data,
                     self.input_image_shape: [image.size[1], image.size[0]],
-                    K.learning_phase(): 0
-                })
+                    K.learning_phase(): 0})
         t2 = time.time()
         tact_time = (t2 - t1) / test_interval
         return tact_time
diff --git a/get_dr_txt.py b/get_dr_txt.py
index d7e666a0c664507576aa3d9f05321d406a6e00a3..ade3398508d66b7b7b4aff145ec0d7cb2b2f5dfc 100644
--- a/get_dr_txt.py
+++ b/get_dr_txt.py
@@ -71,7 +71,7 @@ class mAP_YOLO(YOLO):
             #---------------------------------------------------------#
             boxes, scores, classes = yolo_eval(self.yolo_model.output, self.anchors,
                     num_classes, self.input_image_shape, max_boxes = self.max_boxes,
-                    score_threshold = self.score, iou_threshold = self.iou)
+                    score_threshold = self.score, iou_threshold = self.iou, letterbox_image = self.letterbox_image)
         return boxes, scores, classes

     #---------------------------------------------------#
@@ -79,11 +79,11 @@ class mAP_YOLO(YOLO):
     #---------------------------------------------------#
     def detect_image(self, image_id, image):
         f = open("./input/detection-results/"+image_id+".txt","w")
-        #---------------------------------------------------------#
-        #   Add gray bars to the image for a distortion-free resize
-        #---------------------------------------------------------#
-        new_image_size = (self.model_image_size[1],self.model_image_size[0])
-        boxed_image = letterbox_image(image, new_image_size)
+        if self.letterbox_image:
+            boxed_image = letterbox_image(image, (self.model_image_size[1],self.model_image_size[0]))
+        else:
+            boxed_image = image.convert('RGB')
+            boxed_image = boxed_image.resize((self.model_image_size[1],self.model_image_size[0]), Image.BICUBIC)
         image_data = np.array(boxed_image, dtype='float32')
         image_data /= 255.
         #---------------------------------------------------------#
@@ -100,7 +100,6 @@ class mAP_YOLO(YOLO):
                 self.yolo_model.input: image_data,
                 self.input_image_shape: [image.size[1], image.size[0]],
                 K.learning_phase(): 0})
-
         for i, c in enumerate(out_classes):
             predicted_class = self.class_names[int(c)]
             score = str(out_scores[i])
diff --git a/nets/yolo4.py b/nets/yolo4.py
index 39d117fd267ebbc7ed6d6a1ca6b41fb4692aec52..e62d1e4d11b7758b225e31e93aa392b8be004c61 100644
--- a/nets/yolo4.py
+++ b/nets/yolo4.py
@@ -225,7 +225,7 @@ def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape):
 #---------------------------------------------------#
 #   Get every box and its score
 #---------------------------------------------------#
-def yolo_boxes_and_scores(feats, anchors, num_classes, input_shape, image_shape):
+def yolo_boxes_and_scores(feats, anchors, num_classes, input_shape, image_shape, letterbox_image):
     #-----------------------------------------------------------------#
     #   Convert the predictions into real values
     #   box_xy : -1,13,13,3,2;
@@ -240,7 +240,23 @@ def yolo_boxes_and_scores(feats, anchors, num_classes, input_shape, image_shape)
     #   We need to adjust them to remove the gray-bar area.
     #   Convert box_xy and box_wh into y_min, y_max, x_min, x_max
     #-----------------------------------------------------------------#
-    boxes = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape)
+    if letterbox_image:
+        boxes = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape)
+    else:
+        box_yx = box_xy[..., ::-1]
+        box_hw = box_wh[..., ::-1]
+        box_mins = box_yx - (box_hw / 2.)
+        box_maxes = box_yx + (box_hw / 2.)
+
+        input_shape = K.cast(input_shape, K.dtype(box_yx))
+        image_shape = K.cast(image_shape, K.dtype(box_yx))
+
+        boxes = K.concatenate([
+            box_mins[..., 0:1] * image_shape[0],   # y_min
+            box_mins[..., 1:2] * image_shape[1],   # x_min
+            box_maxes[..., 0:1] * image_shape[0],  # y_max
+            box_maxes[..., 1:2] * image_shape[1]   # x_max
+        ])
     #-----------------------------------------------------------------#
     #   Get the final scores and box positions
     #-----------------------------------------------------------------#
@@ -258,7 +274,8 @@ def yolo_eval(yolo_outputs,
               image_shape,
               max_boxes=20,
               score_threshold=.6,
-              iou_threshold=.5):
+              iou_threshold=.5,
+              letterbox_image=True):
     #---------------------------------------------------#
     #   Get the number of feature layers; there are 3 effective feature layers
     #---------------------------------------------------#
@@ -280,7 +297,7 @@ def yolo_eval(yolo_outputs,
     #   Process each feature layer
     #-----------------------------------------------------------#
     for l in range(num_layers):
-        _boxes, _box_scores = yolo_boxes_and_scores(yolo_outputs[l], anchors[anchor_mask[l]], num_classes, input_shape, image_shape)
+        _boxes, _box_scores = yolo_boxes_and_scores(yolo_outputs[l], anchors[anchor_mask[l]], num_classes, input_shape, image_shape, letterbox_image)
         boxes.append(_boxes)
         box_scores.append(_box_scores)
     #-----------------------------------------------------------#
diff --git a/yolo.py b/yolo.py
index f6885473f9f41a749465d92f1414fd63e2764a3d..734b0aea5f2df9c26d4089cd1784443b7a9d0faa 100644
--- a/yolo.py
+++ b/yolo.py
@@ -21,15 +21,20 @@ from utils.utils import letterbox_image
 #--------------------------------------------#
 class YOLO(object):
     _defaults = {
-        "model_path" : 'model_data/yolo4_weight.h5',
+        "model_path" : 'model_data/yolo4_voc_weights.h5',
         "anchors_path" : 'model_data/yolo_anchors.txt',
-        "classes_path" : 'model_data/coco_classes.txt',
+        "classes_path" : 'model_data/voc_classes.txt',
         "score" : 0.5,
         "iou" : 0.3,
         "max_boxes" : 100,
         # Use 416x416 if GPU memory is small
         # Use 608x608 if GPU memory is large
-        "model_image_size" : (416, 416)
+        "model_image_size" : (416, 416),
+        #---------------------------------------------------------------------#
+        #   This variable controls whether letterbox_image is used for a distortion-free resize of the input image.
+        #   After repeated testing, disabling letterbox_image and resizing directly gives better results.
+        #---------------------------------------------------------------------#
+        "letterbox_image" : False,
     }

     @classmethod
@@ -119,19 +124,22 @@ class YOLO(object):
         #---------------------------------------------------------#
         boxes, scores, classes = yolo_eval(self.yolo_model.output, self.anchors,
                 num_classes, self.input_image_shape, max_boxes = self.max_boxes,
-                score_threshold = self.score, iou_threshold = self.iou)
+                score_threshold = self.score, iou_threshold = self.iou, letterbox_image = self.letterbox_image)
         return boxes, scores, classes

     #---------------------------------------------------#
     #   Detect an image
     #---------------------------------------------------#
     def detect_image(self, image):
-        start = timer()
         #---------------------------------------------------------#
         #   Add gray bars to the image for a distortion-free resize
+        #   A plain resize can also be used for detection
         #---------------------------------------------------------#
-        new_image_size = (self.model_image_size[1],self.model_image_size[0])
-        boxed_image = letterbox_image(image, new_image_size)
+        if self.letterbox_image:
+            boxed_image = letterbox_image(image, (self.model_image_size[1],self.model_image_size[0]))
+        else:
+            boxed_image = image.convert('RGB')
+            boxed_image = boxed_image.resize((self.model_image_size[1],self.model_image_size[0]), Image.BICUBIC)
         image_data = np.array(boxed_image, dtype='float32')
         image_data /= 255.
         #---------------------------------------------------------#
@@ -197,8 +205,6 @@ class YOLO(object):
             draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
             del draw

-        end = timer()
-        print(end - start)
         return image

     def close_session(self):
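
Taken together, the patch gives FPS_test.py, get_dr_txt.py, and yolo.py the same preprocessing switch: letterbox resize (gray padding, no distortion) when "letterbox_image" is True, plain BICUBIC resize when it is False. The standalone sketch below reproduces that branch with PIL and NumPy so it can be tried outside the project. The names letterbox_like and preprocess are illustrative, and letterbox_like only approximates what the repo's utils.utils.letterbox_image does; treat it as a sketch under those assumptions, not the project's exact helper.

import numpy as np
from PIL import Image

def letterbox_like(image, size):
    # Illustrative approximation of a letterbox resize (not the repo's helper):
    # scale the image to fit inside `size` without distortion, then paste it
    # centered on a gray canvas (the "gray bars").
    iw, ih = image.size
    w, h = size
    scale = min(w / iw, h / ih)
    nw, nh = int(iw * scale), int(ih * scale)
    resized = image.resize((nw, nh), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128, 128, 128))
    canvas.paste(resized, ((w - nw) // 2, (h - nh) // 2))
    return canvas

def preprocess(image, model_image_size=(416, 416), use_letterbox=False):
    # Mirrors the branch added in get_FPS()/detect_image():
    # letterbox resize vs. plain BICUBIC resize, then normalize to [0, 1]
    # and add the batch dimension.
    target = (model_image_size[1], model_image_size[0])
    if use_letterbox:
        boxed = letterbox_like(image, target)
    else:
        boxed = image.convert('RGB').resize(target, Image.BICUBIC)
    data = np.array(boxed, dtype='float32') / 255.
    return np.expand_dims(data, 0)

if __name__ == '__main__':
    img = Image.new('RGB', (1280, 720), (255, 0, 0))
    print(preprocess(img, use_letterbox=True).shape)   # (1, 416, 416, 3)
    print(preprocess(img, use_letterbox=False).shape)  # (1, 416, 416, 3)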
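The only non-trivial new logic is the else branch added to yolo_boxes_and_scores in nets/yolo4.py: with a plain resize there are no gray bars to strip, so the normalized (x, y)/(w, h) predictions are just flipped to (y, x) order and scaled by the original image height and width. The NumPy sketch below mirrors that branch; scale_boxes_no_letterbox is a hypothetical helper name, not part of the repo.

import numpy as np

def scale_boxes_no_letterbox(box_xy, box_wh, image_shape):
    # NumPy version of the non-letterbox branch: flip normalized centers and
    # sizes to (y, x) order, convert to corner coordinates, then scale back
    # to the original image height/width.
    box_yx = box_xy[..., ::-1]
    box_hw = box_wh[..., ::-1]
    box_mins = box_yx - box_hw / 2.
    box_maxes = box_yx + box_hw / 2.
    h, w = float(image_shape[0]), float(image_shape[1])
    return np.concatenate([
        box_mins[..., 0:1] * h,   # y_min
        box_mins[..., 1:2] * w,   # x_min
        box_maxes[..., 0:1] * h,  # y_max
        box_maxes[..., 1:2] * w,  # x_max
    ], axis=-1)

if __name__ == '__main__':
    # One box centered in the image, covering half of each dimension.
    xy = np.array([[0.5, 0.5]], dtype='float32')
    wh = np.array([[0.5, 0.5]], dtype='float32')
    print(scale_boxes_no_letterbox(xy, wh, (720, 1280)))
    # -> [[180. 320. 540. 960.]]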