Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_46462485
yolo3-pytorch
提交
48ad276b
Y
yolo3-pytorch
项目概览
weixin_46462485
/
yolo3-pytorch
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Y
yolo3-pytorch
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
48ad276b
编写于
5月 21, 2021
作者:
B
Bubbliiiing
提交者:
GitHub
5月 21, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add files via upload
上级
0259c0a9
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
176 additions
and
21 deletions
+176
-21
predict.py
predict.py
+93
-21
yolo.py
yolo.py
+83
-0
未找到文件。
predict.py
浏览文件 @
48ad276b
'''
Notes on predict.py:
1. This script does not do batched prediction directly. For batch prediction,
   walk a folder with os.listdir() and open each file with Image.open();
   see get_dr_txt.py for a full example that also saves detection results.
2. To save a detected image, call r_image.save("img.jpg") right here.
3. To obtain box coordinates, read top/left/bottom/right in the drawing part
   of yolo.detect_image.
4. To crop detected objects, use those same top/left/bottom/right values to
   slice the original image array inside yolo.detect_image.
5. To draw extra text (e.g. a per-class count), test predicted_class inside
   yolo.detect_image (e.g. if predicted_class == 'car':) and use draw.text.
'''
#----------------------------------------------------#
#   predict.py merges single-image prediction,
#   camera/video detection and FPS benchmarking into
#   one file; select the behaviour via `mode`.
#----------------------------------------------------#
import time

import cv2
import numpy as np
from PIL import Image

from yolo import YOLO

if __name__ == "__main__":
    # Build the detector once, only when run as a script (the original code
    # also instantiated YOLO() at import time, loading the model twice).
    yolo = YOLO()
    #-------------------------------------------------------------------------#
    #   mode selects what to run:
    #   'predict' : single-image prediction (interactive filename prompt)
    #   'video'   : video/camera detection
    #   'fps'     : FPS benchmark
    #-------------------------------------------------------------------------#
    mode = "fps"
    #-------------------------------------------------------------------------#
    #   video_path: path of the input video; 0 means the default camera.
    #   video_save_path: where to write the annotated video; "" disables saving.
    #   video_fps: fps of the saved video.
    #   These three are only used when mode == 'video'.
    #   When saving, press Esc (or let the stream end) so the writer is
    #   released and the file is finalized properly.
    #-------------------------------------------------------------------------#
    video_path      = 0
    video_save_path = ""
    video_fps       = 25.0

    if mode == "predict":
        '''
        1. For batch prediction walk a folder with os.listdir() and open each
           image with Image.open(); see get_dr_txt.py for a full example.
        2. To save a detected image call r_image.save("img.jpg").
        3. To obtain box coordinates read top/left/bottom/right in the drawing
           part of yolo.detect_image.
        4. To crop detected objects slice the original image with those values.
        5. To draw extra text, test predicted_class inside yolo.detect_image
           and use draw.text.
        '''
        while True:
            img = input('Input image filename:')
            try:
                image = Image.open(img)
            except Exception:
                # Bad path / unreadable file: ask again instead of crashing.
                print('Open Error! Try again!')
                continue
            else:
                r_image = yolo.detect_image(image)
                r_image.show()
    elif mode == "video":
        capture = cv2.VideoCapture(video_path)
        out = None
        if video_save_path != "":
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

        fps = 0.0
        while True:
            t1 = time.time()
            # Grab one frame; stop cleanly when the stream ends or the camera
            # fails (the original crashed inside cvtColor on a None frame).
            ref, frame = capture.read()
            if not ref:
                break
            # BGR (OpenCV) -> RGB (PIL/model)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # ndarray -> PIL.Image for the detector
            frame = Image.fromarray(np.uint8(frame))
            # Run detection; result comes back as an annotated PIL image.
            frame = np.array(yolo.detect_image(frame))
            # RGB -> BGR so OpenCV displays correct colors
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            # Exponential moving average of the instantaneous frame rate.
            fps = (fps + (1. / (time.time() - t1))) / 2
            print("fps= %.2f" % (fps))
            frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            cv2.imshow("video", frame)
            c = cv2.waitKey(1) & 0xff
            if out is not None:
                out.write(frame)
            if c == 27:  # Esc
                break

        capture.release()
        # Only release the writer if one was actually created (the original
        # raised NameError here when video_save_path was "").
        if out is not None:
            out.release()
        cv2.destroyAllWindows()
    elif mode == "fps":
        test_interval = 100
        img = Image.open('img/street.jpg')
        tact_time = yolo.get_FPS(img, test_interval)
        print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1')
    else:
        raise AssertionError("Please specify the correct mode: 'predict', 'video' or 'fps'.")
yolo.py
浏览文件 @
48ad276b
...
...
@@ -3,6 +3,7 @@
#-------------------------------------#
import
colorsys
import
os
import
time
import
numpy
as
np
import
torch
...
...
@@ -229,3 +230,85 @@ class YOLO(object):
del
draw
return
image
def get_FPS(self, image, test_interval):
    """Benchmark the average end-to-end inference time for one image.

    Runs the full pipeline (network forward, YOLO decode, NMS, box
    correction) once as a warm-up, then `test_interval` more times under
    timing, all on the same preprocessed tensor.

    Args:
        image: input PIL image.
        test_interval: number of timed iterations to average over.

    Returns:
        Average seconds per iteration (float).
    """
    image_shape = np.array(np.shape(image)[0:2])

    #---------------------------------------------------------#
    #   Preprocess once: letterbox (aspect-preserving resize
    #   with gray padding) or a plain distorting resize.
    #---------------------------------------------------------#
    if self.letterbox_image:
        crop_img = np.array(letterbox_image(image, (self.model_image_size[1], self.model_image_size[0])))
    else:
        crop_img = image.convert('RGB')
        crop_img = crop_img.resize((self.model_image_size[1], self.model_image_size[0]), Image.BICUBIC)
    photo = np.array(crop_img, dtype=np.float32) / 255.0
    photo = np.transpose(photo, (2, 0, 1))

    with torch.no_grad():
        # Add the batch dimension and move to GPU once; the loop below
        # reuses this tensor so only inference is timed, not H2D copies.
        images = torch.from_numpy(np.asarray([photo]))
        if self.cuda:
            images = images.cuda()

    def _run_once():
        # One full forward + decode + NMS + box-correction pass.
        # (The original duplicated these ~60 lines verbatim for the
        # warm-up and for the timing loop.)
        with torch.no_grad():
            outputs = self.net(images)
            output_list = [self.yolo_decodes[i](outputs[i]) for i in range(3)]
            output = torch.cat(output_list, 1)
            batch_detections = non_max_suppression(output, len(self.class_names),
                                                   conf_thres=self.confidence,
                                                   nms_thres=self.iou)
            try:
                batch_detections = batch_detections[0].cpu().numpy()
                top_index = batch_detections[:, 4] * batch_detections[:, 5] > self.confidence
                top_conf = batch_detections[top_index, 4] * batch_detections[top_index, 5]
                top_label = np.array(batch_detections[top_index, -1], np.int32)
                top_bboxes = np.array(batch_detections[top_index, :4])
                top_xmin, top_ymin, top_xmax, top_ymax = (
                    np.expand_dims(top_bboxes[:, 0], -1),
                    np.expand_dims(top_bboxes[:, 1], -1),
                    np.expand_dims(top_bboxes[:, 2], -1),
                    np.expand_dims(top_bboxes[:, 3], -1))
                if self.letterbox_image:
                    # Undo the letterbox padding to map boxes back onto
                    # the original image.
                    boxes = yolo_correct_boxes(
                        top_ymin, top_xmin, top_ymax, top_xmax,
                        np.array([self.model_image_size[0], self.model_image_size[1]]),
                        image_shape)
                else:
                    # Plain rescale from model coordinates to image coordinates.
                    top_xmin = top_xmin / self.model_image_size[1] * image_shape[1]
                    top_ymin = top_ymin / self.model_image_size[0] * image_shape[0]
                    top_xmax = top_xmax / self.model_image_size[1] * image_shape[1]
                    top_ymax = top_ymax / self.model_image_size[0] * image_shape[0]
                    boxes = np.concatenate([top_ymin, top_xmin, top_ymax, top_xmax], axis=-1)
            except Exception:
                # No detections: non_max_suppression yields None for this
                # image and the indexing above fails. Best-effort by design
                # (only timing matters here), but no longer a bare `except:`
                # that would also swallow KeyboardInterrupt.
                pass

    # Warm-up pass (outside the timed region), as in the original.
    _run_once()

    t1 = time.time()
    for _ in range(test_interval):
        _run_once()
    t2 = time.time()
    tact_time = (t2 - t1) / test_interval
    return tact_time
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录