# FPS_test.py (author: Bubbliiiing)
import os
import time

import numpy as np
from keras import backend as K
from PIL import Image

from utils.utils import letterbox_image
from yolo import YOLO

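'''
This FPS test excludes preprocessing (normalization and resizing) and drawing.
It covers network inference, score-threshold filtering, and non-maximum suppression.
The test uses the image 'img/street.jpg'; the test methodology follows
https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch

The FPS measured in video.py will be lower than this figure, because the camera's
read rate is limited and that pipeline also includes preprocessing and drawing.
'''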
class FPS_YOLO(YOLO):
    """YOLO subclass that measures the average inference latency on one image.

    Only the ``self.sess.run`` call that evaluates ``self.boxes``,
    ``self.scores`` and ``self.classes`` is timed. Resizing, normalization
    and drawing happen outside the timed region.
    """

    def get_FPS(self, image, test_interval):
        """Return the mean seconds per inference on ``image``.

        Args:
            image: PIL image to run detection on.
            test_interval: number of timed ``sess.run`` calls to average
                over; must be positive.

        Returns:
            Average seconds per call (``1 / result`` is frames per second).

        Raises:
            ValueError: if ``test_interval`` is not positive.
        """
        # Validate before doing any work; 0 would otherwise divide by zero
        # only after the warm-up run, and a negative value would silently
        # yield a meaningless latency.
        if test_interval <= 0:
            raise ValueError(
                'test_interval must be positive, got {!r}'.format(test_interval))

        # PIL sizes are (width, height); model_image_size is indexed with
        # [1] as width and [0] as height, so swap into PIL order.
        target_size = (self.model_image_size[1], self.model_image_size[0])
        if self.letterbox_image:
            boxed_image = letterbox_image(image, target_size)
        else:
            boxed_image = image.convert('RGB').resize(target_size, Image.BICUBIC)

        # Divide by 255 (assumes 8-bit pixel values) and add a batch axis of 1.
        image_data = np.array(boxed_image, dtype='float32')
        image_data /= 255.
        image_data = np.expand_dims(image_data, 0)

        # The fetches and feed are identical for every call, so build them
        # once. input_image_shape is fed as [height, width] of the original
        # image (PIL's image.size is (width, height)).
        fetches = [self.boxes, self.scores, self.classes]
        feed_dict = {
            self.yolo_model.input: image_data,
            self.input_image_shape: [image.size[1], image.size[0]],
            K.learning_phase(): 0,
        }

        # Untimed warm-up call so any one-off cost of the first session run
        # is excluded from the average.
        self.sess.run(fetches, feed_dict=feed_dict)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(test_interval):
            self.sess.run(fetches, feed_dict=feed_dict)
        elapsed = time.perf_counter() - start
        return elapsed / test_interval
        
# Benchmark: average latency over `test_interval` timed inferences on a
# sample image, reported as seconds per frame and frames per second.
test_interval = 100
yolo = FPS_YOLO()
img = Image.open('img/street.jpg')
tact_time = yolo.get_FPS(img, test_interval)
print('{} seconds, {}FPS, @batch_size 1'.format(tact_time, 1 / tact_time))