未验证 提交 74410ff9 编写于 作者: G Guanghua Yu 提交者: GitHub

support depoly infer for fcos (#1330)

上级 419671a8
......@@ -74,7 +74,7 @@ class Resize(object):
self.arch = arch
self.use_cv2 = use_cv2
self.interp = interp
self.scale_set = {'RCNN', 'RetinaNet'}
self.scale_set = {'RCNN', 'RetinaNet', 'FCOS'}
def __call__(self, im, im_info):
"""
......@@ -259,7 +259,7 @@ def create_inputs(im, im_info, model_arch='YOLO'):
scale = scale_x
im_info = np.array([resize_shape + [scale]]).astype('float32')
inputs['im_info'] = im_info
elif 'RCNN' in model_arch:
elif ('RCNN' in model_arch) or ('FCOS' in model_arch):
scale = scale_x
im_info = np.array([resize_shape + [scale]]).astype('float32')
im_shape = np.array([origin_shape + [1.]]).astype('float32')
......@@ -276,7 +276,15 @@ class Config():
Args:
model_dir (str): root path of model.yml
"""
support_models = ['YOLO', 'SSD', 'RetinaNet', 'RCNN', 'Face', 'TTF']
support_models = [
'YOLO',
'SSD',
'RetinaNet',
'RCNN',
'Face',
'TTF',
'FCOS',
]
def __init__(self, model_dir):
# parsing Yaml config for Preprocess
......@@ -566,15 +574,19 @@ def predict_image():
output_dir=FLAGS.output_dir)
def predict_video():
def predict_video(camera_id):
detector = Detector(
FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode)
capture = cv2.VideoCapture(FLAGS.video_file)
if camera_id != -1:
capture = cv2.VideoCapture(camera_id)
video_name = 'output.mp4'
else:
capture = cv2.VideoCapture(FLAGS.video_file)
video_name = os.path.split(FLAGS.video_file)[-1]
fps = 30
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_name = os.path.split(FLAGS.video_file)[-1]
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
......@@ -594,6 +606,10 @@ def predict_video():
mask_resolution=detector.config.mask_resolution)
im = np.array(im)
writer.write(im)
if camera_id != -1:
cv2.imshow('Mask Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
writer.release()
......@@ -617,6 +633,11 @@ if __name__ == '__main__':
"--image_file", type=str, default='', help="Path of image file.")
parser.add_argument(
"--video_file", type=str, default='', help="Path of video file.")
parser.add_argument(
"--camera_id",
type=int,
default=-1,
help="device id of camera to predict.")
parser.add_argument(
"--run_mode",
type=str,
......@@ -647,5 +668,5 @@ if __name__ == '__main__':
assert "Cannot predict image and video at the same time"
if FLAGS.image_file != '':
predict_image()
if FLAGS.video_file != '':
predict_video()
if FLAGS.video_file != '' or FLAGS.camera_id != -1:
predict_video(FLAGS.camera_id)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册