Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
陈小光丶
yolov4-keras
提交
3bb5d7b1
Y
yolov4-keras
项目概览
陈小光丶
/
yolov4-keras
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Y
yolov4-keras
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
3bb5d7b1
编写于
2月 07, 2021
作者:
B
Bubbliiiing
提交者:
GitHub
2月 07, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add files via upload
上级
b4284fd8
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
64 addition
and
85 deletion
+64
-85
FPS_test.py
FPS_test.py
+8
-18
get_dr_txt.py
get_dr_txt.py
+5
-9
kmeans_for_anchors.py
kmeans_for_anchors.py
+1
-1
nets/ious.py
nets/ious.py
+2
-5
nets/loss.py
nets/loss.py
+8
-7
predict.py
predict.py
+1
-0
test.py
test.py
+1
-1
train.py
train.py
+28
-27
yolo.py
yolo.py
+10
-17
未找到文件。
FPS_test.py
浏览文件 @
3bb5d7b1
...
...
@@ -17,31 +17,20 @@ video.py里面测试的FPS会低于该FPS,因为摄像头的读取频率有限
'''
class
FPS_YOLO
(
YOLO
):
def
get_FPS
(
self
,
image
,
test_interval
):
#---------------------------------------------------------#
# 给图像增加灰条,实现不失真的resize
# 也可以直接resize进行识别
#---------------------------------------------------------#
if
self
.
letterbox_image
:
boxed_image
=
letterbox_image
(
image
,
(
self
.
model_image_size
[
1
],
self
.
model_image_size
[
0
]))
else
:
boxed_image
=
image
.
convert
(
'RGB'
)
boxed_image
=
boxed_image
.
resize
((
self
.
model_image_size
[
1
],
self
.
model_image_size
[
0
]),
Image
.
BICUBIC
)
# 调整图片使其符合输入要求
new_image_size
=
(
self
.
model_image_size
[
1
],
self
.
model_image_size
[
0
])
boxed_image
=
letterbox_image
(
image
,
new_image_size
)
image_data
=
np
.
array
(
boxed_image
,
dtype
=
'float32'
)
image_data
/=
255.
#---------------------------------------------------------#
# 添加上batch_size维度
#---------------------------------------------------------#
image_data
=
np
.
expand_dims
(
image_data
,
0
)
image_data
=
np
.
expand_dims
(
image_data
,
0
)
#---------------------------------------------------------#
# 将图像输入网络当中进行预测!
#---------------------------------------------------------#
out_boxes
,
out_scores
,
out_classes
=
self
.
sess
.
run
(
[
self
.
boxes
,
self
.
scores
,
self
.
classes
],
feed_dict
=
{
self
.
yolo_model
.
input
:
image_data
,
self
.
input_image_shape
:
[
image
.
size
[
1
],
image
.
size
[
0
]],
K
.
learning_phase
():
0
})
K
.
learning_phase
():
0
})
t1
=
time
.
time
()
for
_
in
range
(
test_interval
):
...
...
@@ -50,7 +39,8 @@ class FPS_YOLO(YOLO):
feed_dict
=
{
self
.
yolo_model
.
input
:
image_data
,
self
.
input_image_shape
:
[
image
.
size
[
1
],
image
.
size
[
0
]],
K
.
learning_phase
():
0
})
K
.
learning_phase
():
0
})
t2
=
time
.
time
()
tact_time
=
(
t2
-
t1
)
/
test_interval
return
tact_time
...
...
get_dr_txt.py
浏览文件 @
3bb5d7b1
...
...
@@ -13,7 +13,7 @@ from keras.layers import Input
from
PIL
import
Image
from
tqdm
import
tqdm
from
nets.yolo4
_tiny
import
yolo_body
,
yolo_eval
from
nets.yolo4
import
yolo_body
,
yolo_eval
from
utils.utils
import
letterbox_image
from
yolo
import
YOLO
...
...
@@ -41,7 +41,7 @@ class mAP_YOLO(YOLO):
try
:
self
.
yolo_model
=
load_model
(
model_path
,
compile
=
False
)
except
:
self
.
yolo_model
=
yolo_body
(
Input
(
shape
=
(
None
,
None
,
3
)),
num_anchors
//
2
,
num_classes
)
self
.
yolo_model
=
yolo_body
(
Input
(
shape
=
(
None
,
None
,
3
)),
num_anchors
//
3
,
num_classes
)
self
.
yolo_model
.
load_weights
(
self
.
model_path
)
else
:
assert
self
.
yolo_model
.
layers
[
-
1
].
output_shape
[
-
1
]
==
\
...
...
@@ -71,7 +71,7 @@ class mAP_YOLO(YOLO):
#---------------------------------------------------------#
boxes
,
scores
,
classes
=
yolo_eval
(
self
.
yolo_model
.
output
,
self
.
anchors
,
num_classes
,
self
.
input_image_shape
,
max_boxes
=
self
.
max_boxes
,
score_threshold
=
self
.
score
,
iou_threshold
=
self
.
iou
,
letterbox_image
=
self
.
letterbox_image
)
score_threshold
=
self
.
score
,
iou_threshold
=
self
.
iou
)
return
boxes
,
scores
,
classes
#---------------------------------------------------#
...
...
@@ -81,13 +81,9 @@ class mAP_YOLO(YOLO):
f
=
open
(
"./input/detection-results/"
+
image_id
+
".txt"
,
"w"
)
#---------------------------------------------------------#
# 给图像增加灰条,实现不失真的resize
# 也可以直接resize进行识别
#---------------------------------------------------------#
if
self
.
letterbox_image
:
boxed_image
=
letterbox_image
(
image
,
(
self
.
model_image_size
[
1
],
self
.
model_image_size
[
0
]))
else
:
boxed_image
=
image
.
convert
(
'RGB'
)
boxed_image
=
boxed_image
.
resize
((
self
.
model_image_size
[
1
],
self
.
model_image_size
[
0
]),
Image
.
BICUBIC
)
new_image_size
=
(
self
.
model_image_size
[
1
],
self
.
model_image_size
[
0
])
boxed_image
=
letterbox_image
(
image
,
new_image_size
)
image_data
=
np
.
array
(
boxed_image
,
dtype
=
'float32'
)
image_data
/=
255.
#---------------------------------------------------------#
...
...
kmeans_for_anchors.py
浏览文件 @
3bb5d7b1
...
...
@@ -86,7 +86,7 @@ if __name__ == '__main__':
# 运行该程序会计算'./VOCdevkit/VOC2007/Annotations'的xml
# 会生成yolo_anchors.txt
SIZE
=
416
anchors_num
=
6
anchors_num
=
9
# 载入数据集,可以使用VOC的xml
path
=
r
'./VOCdevkit/VOC2007/Annotations'
...
...
nets/ious.py
浏览文件 @
3bb5d7b1
import
math
import
tensorflow
as
tf
from
keras
import
backend
as
K
import
tensorflow
as
tf
import
math
def
box_ciou
(
b1
,
b2
):
"""
输入为:
...
...
nets/loss.py
浏览文件 @
3bb5d7b1
...
...
@@ -99,20 +99,21 @@ def box_iou(b1, b2):
# loss值计算
#---------------------------------------------------#
def
yolo_loss
(
args
,
anchors
,
num_classes
,
ignore_thresh
=
.
5
,
label_smoothing
=
0.1
,
print_loss
=
False
,
normalize
=
True
):
# 一共有
两
层
# 一共有
三
层
num_layers
=
len
(
anchors
)
//
3
#---------------------------------------------------------------------------------------------------#
# 将预测结果和实际ground truth分开,args是[*model_body.output, *y_true]
# y_true是一个列表,包含
两个特征层,shape分别为(m,13,13,3,85),(m,26,26,3,85)
# yolo_outputs是一个列表,包含
两个特征层,shape分别为(m,13,13,3,85),(m,26,26,3,85)
# y_true是一个列表,包含
三个特征层,shape分别为(m,13,13,3,85),(m,26,26,3,85),(m,52,52,3,85)。
# yolo_outputs是一个列表,包含
三个特征层,shape分别为(m,13,13,3,85),(m,26,26,3,85),(m,52,52,3,85)。
#---------------------------------------------------------------------------------------------------#
y_true
=
args
[
num_layers
:]
yolo_outputs
=
args
[:
num_layers
]
#-----------------------------------------------------------#
# 13x13的特征层对应的anchor是[81,82], [135,169], [344,319]
# 26x26的特征层对应的anchor是[23,27], [37,58], [81,82]
# 13x13的特征层对应的anchor是[142, 110], [192, 243], [459, 401]
# 26x26的特征层对应的anchor是[36, 75], [76, 55], [72, 146]
# 52x52的特征层对应的anchor是[12, 16], [19, 36], [40, 28]
#-----------------------------------------------------------#
anchor_mask
=
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
if
num_layers
==
3
else
[[
3
,
4
,
5
],
[
1
,
2
,
3
]]
...
...
@@ -129,8 +130,8 @@ def yolo_loss(args, anchors, num_classes, ignore_thresh=.5, label_smoothing=0.1,
mf
=
K
.
cast
(
m
,
K
.
dtype
(
yolo_outputs
[
0
]))
#---------------------------------------------------------------------------------------------------#
# y_true是一个列表,包含
两个特征层,shape分别为(m,13,13,3,85),(m,26,26,3,85)
# yolo_outputs是一个列表,包含
两个特征层,shape分别为(m,13,13,3,85),(m,26,26,3,85)
# y_true是一个列表,包含
三个特征层,shape分别为(m,13,13,3,85),(m,26,26,3,85),(m,52,52,3,85)。
# yolo_outputs是一个列表,包含
三个特征层,shape分别为(m,13,13,3,85),(m,26,26,3,85),(m,52,52,3,85)。
#---------------------------------------------------------------------------------------------------#
for
l
in
range
(
num_layers
):
#-----------------------------------------------------------#
...
...
predict.py
浏览文件 @
3bb5d7b1
...
...
@@ -8,6 +8,7 @@ predict.py有几个注意点
from
keras.layers
import
Input
from
PIL
import
Image
from
nets.yolo4
import
yolo_body
from
yolo
import
YOLO
yolo
=
YOLO
()
...
...
test.py
浏览文件 @
3bb5d7b1
...
...
@@ -5,7 +5,7 @@
#--------------------------------------------#
from
keras.layers
import
Input
from
nets.yolo4
_tiny
import
yolo_body
from
nets.yolo4
import
yolo_body
if
__name__
==
"__main__"
:
inputs
=
Input
([
416
,
416
,
3
])
...
...
train.py
浏览文件 @
3bb5d7b1
...
...
@@ -9,7 +9,7 @@ from keras.models import Model
from
keras.optimizers
import
Adam
from
nets.loss
import
yolo_loss
from
nets.yolo4
_tiny
import
yolo_body
from
nets.yolo4
import
yolo_body
from
utils.utils
import
(
WarmUpCosineDecayScheduler
,
get_random_data
,
get_random_data_with_Mosaic
,
rand
)
...
...
@@ -67,19 +67,20 @@ def data_generator(annotation_lines, batch_size, input_shape, anchors, num_class
#---------------------------------------------------#
def
preprocess_true_boxes
(
true_boxes
,
input_shape
,
anchors
,
num_classes
):
assert
(
true_boxes
[...,
4
]
<
num_classes
).
all
(),
'class id must be less than num_classes'
# 一共有
两
个特征层数
# 一共有
三
个特征层数
num_layers
=
len
(
anchors
)
//
3
#-----------------------------------------------------------#
# 13x13的特征层对应的anchor是[81,82], [135,169], [344,319]
# 26x26的特征层对应的anchor是[23,27], [37,58], [81,82]
# 13x13的特征层对应的anchor是[142, 110], [192, 243], [459, 401]
# 26x26的特征层对应的anchor是[36, 75], [76, 55], [72, 146]
# 52x52的特征层对应的anchor是[12, 16], [19, 36], [40, 28]
#-----------------------------------------------------------#
anchor_mask
=
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
if
num_layers
==
3
else
[[
3
,
4
,
5
],
[
1
,
2
,
3
]]
anchor_mask
=
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
#-----------------------------------------------------------#
# 获得框的坐标和图片的大小
#-----------------------------------------------------------#
true_boxes
=
np
.
array
(
true_boxes
,
dtype
=
'float32'
)
input_shape
=
np
.
array
(
input_shape
,
dtype
=
'int32'
)
input_shape
=
np
.
array
(
input_shape
,
dtype
=
'int32'
)
#-----------------------------------------------------------#
# 通过计算获得真实框的中心和宽高
# 中心点(m,n,2) 宽高(m,n,2)
...
...
@@ -102,7 +103,7 @@ def preprocess_true_boxes(true_boxes, input_shape, anchors, num_classes):
dtype
=
'float32'
)
for
l
in
range
(
num_layers
)]
#-----------------------------------------------------------#
# [
6,2] -> [1,6
,2]
# [
9,2] -> [1,9
,2]
#-----------------------------------------------------------#
anchors
=
np
.
expand_dims
(
anchors
,
0
)
anchor_maxes
=
anchors
/
2.
...
...
@@ -126,10 +127,10 @@ def preprocess_true_boxes(true_boxes, input_shape, anchors, num_classes):
#-----------------------------------------------------------#
# 计算所有真实框和先验框的交并比
# intersect_area [n,
6
]
# intersect_area [n,
9
]
# box_area [n,1]
# anchor_area [1,
6
]
# iou [n,
6
]
# anchor_area [1,
9
]
# iou [n,
9
]
#-----------------------------------------------------------#
intersect_mins
=
np
.
maximum
(
box_mins
,
anchor_mins
)
intersect_maxes
=
np
.
minimum
(
box_maxes
,
anchor_maxes
)
...
...
@@ -199,7 +200,7 @@ if __name__ == "__main__":
# 训练自己的数据集时提示维度不匹配正常
# 预测的东西都不一样了自然维度不匹配
#------------------------------------------------------#
weights_path
=
'model_data/yolo
v4_tiny_weights_coco
.h5'
weights_path
=
'model_data/yolo
4_weight
.h5'
#------------------------------------------------------#
# 训练用图片大小
# 一般在416x416和608x608选择
...
...
@@ -238,8 +239,8 @@ if __name__ == "__main__":
#------------------------------------------------------#
image_input
=
Input
(
shape
=
(
None
,
None
,
3
))
h
,
w
=
input_shape
print
(
'Create YOLOv4
-Tiny
model with {} anchors and {} classes.'
.
format
(
num_anchors
,
num_classes
))
model_body
=
yolo_body
(
image_input
,
num_anchors
//
2
,
num_classes
)
print
(
'Create YOLOv4 model with {} anchors and {} classes.'
.
format
(
num_anchors
,
num_classes
))
model_body
=
yolo_body
(
image_input
,
num_anchors
//
3
,
num_classes
)
#------------------------------------------------------#
# 载入预训练权重
...
...
@@ -251,10 +252,12 @@ if __name__ == "__main__":
# 在这个地方设置损失,将网络的输出结果传入loss函数
# 把整个模型的输出作为loss
#------------------------------------------------------#
y_true
=
[
Input
(
shape
=
(
h
//
{
0
:
32
,
1
:
16
}[
l
],
w
//
{
0
:
32
,
1
:
16
}[
l
],
num_anchors
//
2
,
num_classes
+
5
))
for
l
in
range
(
2
)]
y_true
=
[
Input
(
shape
=
(
h
//
{
0
:
32
,
1
:
16
,
2
:
8
}[
l
],
w
//
{
0
:
32
,
1
:
16
,
2
:
8
}[
l
],
\
num_anchors
//
3
,
num_classes
+
5
))
for
l
in
range
(
3
)]
loss_input
=
[
*
model_body
.
output
,
*
y_true
]
model_loss
=
Lambda
(
yolo_loss
,
output_shape
=
(
1
,),
name
=
'yolo_loss'
,
arguments
=
{
'anchors'
:
anchors
,
'num_classes'
:
num_classes
,
'ignore_thresh'
:
0.5
,
'label_smoothing'
:
label_smoothing
,
'normalize'
:
normalize
})(
loss_input
)
arguments
=
{
'anchors'
:
anchors
,
'num_classes'
:
num_classes
,
'ignore_thresh'
:
0.5
,
'label_smoothing'
:
label_smoothing
,
'normalize'
:
normalize
})(
loss_input
)
model
=
Model
([
model_body
.
input
,
*
y_true
],
model_loss
)
...
...
@@ -284,6 +287,10 @@ if __name__ == "__main__":
num_val
=
int
(
len
(
lines
)
*
val_split
)
num_train
=
len
(
lines
)
-
num_val
freeze_layers
=
249
for
i
in
range
(
freeze_layers
):
model_body
.
layers
[
i
].
trainable
=
False
print
(
'Freeze the first {} layers of total {} layers.'
.
format
(
freeze_layers
,
len
(
model_body
.
layers
)))
#------------------------------------------------------#
# 主干特征提取网络特征通用,冻结训练可以加快训练速度
# 也可以在训练初期防止权值被破坏。
...
...
@@ -292,15 +299,10 @@ if __name__ == "__main__":
# Epoch总训练世代
# 提示OOM或者显存不足请调小Batch_size
#------------------------------------------------------#
freeze_layers
=
60
for
i
in
range
(
freeze_layers
):
model_body
.
layers
[
i
].
trainable
=
False
print
(
'Freeze the first {} layers of total {} layers.'
.
format
(
freeze_layers
,
len
(
model_body
.
layers
)))
# 调整非主干模型first
if
True
:
Init_epoch
=
0
Freeze_epoch
=
50
batch_size
=
32
batch_size
=
16
learning_rate_base
=
1e-3
if
Cosine_scheduler
:
...
...
@@ -324,9 +326,9 @@ if __name__ == "__main__":
model
.
compile
(
optimizer
=
Adam
(
learning_rate_base
),
loss
=
{
'yolo_loss'
:
lambda
y_true
,
y_pred
:
y_pred
})
print
(
'Train on {} samples, val on {} samples, with batch size {}.'
.
format
(
num_train
,
num_val
,
batch_size
))
model
.
fit_generator
(
data_generator
(
lines
[:
num_train
],
batch_size
,
input_shape
,
anchors
,
num_classes
,
mosaic
=
mosaic
),
model
.
fit_generator
(
data_generator
(
lines
[:
num_train
],
batch_size
,
input_shape
,
anchors
,
num_classes
,
mosaic
=
mosaic
,
random
=
True
),
steps_per_epoch
=
max
(
1
,
num_train
//
batch_size
),
validation_data
=
data_generator
(
lines
[
num_train
:],
batch_size
,
input_shape
,
anchors
,
num_classes
,
mosaic
=
False
),
validation_data
=
data_generator
(
lines
[
num_train
:],
batch_size
,
input_shape
,
anchors
,
num_classes
,
mosaic
=
False
,
random
=
False
),
validation_steps
=
max
(
1
,
num_val
//
batch_size
),
epochs
=
Freeze_epoch
,
initial_epoch
=
Init_epoch
,
...
...
@@ -335,11 +337,10 @@ if __name__ == "__main__":
for
i
in
range
(
freeze_layers
):
model_body
.
layers
[
i
].
trainable
=
True
# 解冻后训练
if
True
:
Freeze_epoch
=
50
Epoch
=
100
batch_size
=
16
batch_size
=
2
learning_rate_base
=
1e-4
if
Cosine_scheduler
:
...
...
@@ -363,9 +364,9 @@ if __name__ == "__main__":
model
.
compile
(
optimizer
=
Adam
(
learning_rate_base
),
loss
=
{
'yolo_loss'
:
lambda
y_true
,
y_pred
:
y_pred
})
print
(
'Train on {} samples, val on {} samples, with batch size {}.'
.
format
(
num_train
,
num_val
,
batch_size
))
model
.
fit_generator
(
data_generator
(
lines
[:
num_train
],
batch_size
,
input_shape
,
anchors
,
num_classes
,
mosaic
=
mosaic
),
model
.
fit_generator
(
data_generator
(
lines
[:
num_train
],
batch_size
,
input_shape
,
anchors
,
num_classes
,
mosaic
=
mosaic
,
random
=
True
),
steps_per_epoch
=
max
(
1
,
num_train
//
batch_size
),
validation_data
=
data_generator
(
lines
[
num_train
:],
batch_size
,
input_shape
,
anchors
,
num_classes
,
mosaic
=
False
),
validation_data
=
data_generator
(
lines
[
num_train
:],
batch_size
,
input_shape
,
anchors
,
num_classes
,
mosaic
=
False
,
random
=
False
),
validation_steps
=
max
(
1
,
num_val
//
batch_size
),
epochs
=
Epoch
,
initial_epoch
=
Freeze_epoch
,
...
...
yolo.py
浏览文件 @
3bb5d7b1
import
collections
import
colorsys
import
copy
import
os
...
...
@@ -10,7 +9,7 @@ from keras.layers import Input
from
keras.models
import
load_model
from
PIL
import
Image
,
ImageDraw
,
ImageFont
from
nets.yolo4
_tiny
import
yolo_body
,
yolo_eval
from
nets.yolo4
import
yolo_body
,
yolo_eval
from
utils.utils
import
letterbox_image
...
...
@@ -22,7 +21,7 @@ from utils.utils import letterbox_image
#--------------------------------------------#
class
YOLO
(
object
):
_defaults
=
{
"model_path"
:
'model_data/yolo
v4_tiny_weights_coco
.h5'
,
"model_path"
:
'model_data/yolo
4_weight
.h5'
,
"anchors_path"
:
'model_data/yolo_anchors.txt'
,
"classes_path"
:
'model_data/coco_classes.txt'
,
"score"
:
0.5
,
...
...
@@ -30,12 +29,7 @@ class YOLO(object):
"max_boxes"
:
100
,
# 显存比较小可以使用416x416
# 显存比较大可以使用608x608
"model_image_size"
:
(
416
,
416
),
#---------------------------------------------------------------------#
# 该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize,
# 在多次测试后,发现关闭letterbox_image直接resize的效果更好
#---------------------------------------------------------------------#
"letterbox_image"
:
False
,
"model_image_size"
:
(
416
,
416
)
}
@
classmethod
...
...
@@ -95,7 +89,7 @@ class YOLO(object):
try
:
self
.
yolo_model
=
load_model
(
model_path
,
compile
=
False
)
except
:
self
.
yolo_model
=
yolo_body
(
Input
(
shape
=
(
None
,
None
,
3
)),
num_anchors
//
2
,
num_classes
)
self
.
yolo_model
=
yolo_body
(
Input
(
shape
=
(
None
,
None
,
3
)),
num_anchors
//
3
,
num_classes
)
self
.
yolo_model
.
load_weights
(
self
.
model_path
)
else
:
assert
self
.
yolo_model
.
layers
[
-
1
].
output_shape
[
-
1
]
==
\
...
...
@@ -125,22 +119,19 @@ class YOLO(object):
#---------------------------------------------------------#
boxes
,
scores
,
classes
=
yolo_eval
(
self
.
yolo_model
.
output
,
self
.
anchors
,
num_classes
,
self
.
input_image_shape
,
max_boxes
=
self
.
max_boxes
,
score_threshold
=
self
.
score
,
iou_threshold
=
self
.
iou
,
letterbox_image
=
self
.
letterbox_image
)
score_threshold
=
self
.
score
,
iou_threshold
=
self
.
iou
)
return
boxes
,
scores
,
classes
#---------------------------------------------------#
# 检测图片
#---------------------------------------------------#
def
detect_image
(
self
,
image
):
start
=
timer
()
#---------------------------------------------------------#
# 给图像增加灰条,实现不失真的resize
# 也可以直接resize进行识别
#---------------------------------------------------------#
if
self
.
letterbox_image
:
boxed_image
=
letterbox_image
(
image
,
(
self
.
model_image_size
[
1
],
self
.
model_image_size
[
0
]))
else
:
boxed_image
=
image
.
convert
(
'RGB'
)
boxed_image
=
boxed_image
.
resize
((
self
.
model_image_size
[
1
],
self
.
model_image_size
[
0
]),
Image
.
BICUBIC
)
new_image_size
=
(
self
.
model_image_size
[
1
],
self
.
model_image_size
[
0
])
boxed_image
=
letterbox_image
(
image
,
new_image_size
)
image_data
=
np
.
array
(
boxed_image
,
dtype
=
'float32'
)
image_data
/=
255.
#---------------------------------------------------------#
...
...
@@ -206,6 +197,8 @@ class YOLO(object):
draw
.
text
(
text_origin
,
str
(
label
,
'UTF-8'
),
fill
=
(
0
,
0
,
0
),
font
=
font
)
del
draw
end
=
timer
()
print
(
end
-
start
)
return
image
def
close_session
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录