Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_47816946
simple-faster-rcnn-pytorch
提交
f649bd72
S
simple-faster-rcnn-pytorch
项目概览
weixin_47816946
/
simple-faster-rcnn-pytorch
与 Fork 源项目一致
从无法访问的项目Fork
通知
8
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
S
simple-faster-rcnn-pytorch
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
f649bd72
编写于
12月 22, 2017
作者:
C
chenyuntc
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add support for caffe-pretrain model
上级
20ed4dce
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
78 addition
and
115 deletion
+78
-115
config.py
config.py
+3
-0
data/dataset.py
data/dataset.py
+36
-9
model/faster_rcnn.py
model/faster_rcnn.py
+8
-83
model/faster_rcnn_vgg16.py
model/faster_rcnn_vgg16.py
+12
-4
train.py
train.py
+17
-17
train_as_chainer.py
train_as_chainer.py
+1
-1
train_fast.py
train_fast.py
+1
-1
未找到文件。
config.py
浏览文件 @
f649bd72
...
...
@@ -50,6 +50,9 @@ class Config:
# model
load_path
=
None
# '/mnt/3/rpn.pth'
caffe_pretrain
=
False
caffe_pretrain_path
=
'checkpoints/vgg16-caffe.pth'
def
_parse
(
self
,
kwargs
):
state_dict
=
self
.
_state_dict
()
for
k
,
v
in
kwargs
.
items
():
...
...
data/dataset.py
浏览文件 @
f649bd72
...
...
@@ -4,10 +4,37 @@ from skimage import transform as sktsf
from
torchvision
import
transforms
as
tvtsf
from
.
import
util
import
numpy
as
np
from
config
import
opt
from
util
import
array_tool
as
at
def inverse_normalize(img):
    """Approximately undo the input normalization so an image can be visualized.

    Which inverse is applied depends on ``opt.caffe_pretrain``:

    * truthy: the per-channel constants are added back and the first axis is
      reversed (presumably BGR -> RGB, mirroring ``caffe_normalize`` --
      confirm).  NOTE(review): unlike the default branch, this result is
      not clipped.
    * falsy: a single scalar scale/offset is used for every channel, the
      result is clipped to [0, 1] and multiplied by 255.

    Args:
        img: normalized image array; assumed to have 3 entries on its first
            axis (the constants are reshaped to ``(3, 1, 1)``) -- TODO confirm.

    Returns:
        The de-normalized array.
    """
    if not opt.caffe_pretrain:
        # approximate un-normalize for visualization
        unit_range = (img * 0.225 + 0.45).clip(min=0, max=1)
        return unit_range * 255

    channel_offset = np.array([122.7717, 115.9465, 102.9801]).reshape(3, 1, 1)
    restored = img + channel_offset
    # Reverse the order of the first axis.
    return restored[::-1, :, :]
def pytorch_normalze(img):
    """Normalize an image with the torchvision mean/std constants.

    See https://github.com/pytorch/vision/issues/223

    Args:
        img: numpy array accepted by ``t.from_numpy``; assumed to be a
            3-channel, channel-first image already scaled to 0..1 -- TODO
            confirm against the caller.

    Returns:
        numpy array, approximately in the range -1..1, RGB order.
    """
    as_tensor = t.from_numpy(img)
    transform = tvtsf.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
    return transform(as_tensor).numpy()
def caffe_normalize(img):
    """Convert an image to the caffe-style input representation.

    The first axis is reordered with ``[2, 1, 0]`` (RGB -> BGR), values are
    multiplied by 255 and a fixed per-channel constant is subtracted.

    NOTE(review): the constant is subtracted *after* the channel swap;
    confirm that the value order (122.77, 115.95, 102.98) is the one the
    pretrained weights expect.

    Args:
        img: array with 3 entries on its first axis; presumably scaled to
            0..1 -- verify against the caller.

    Returns:
        A new ``float32`` array, approximately in the range -125..125, BGR.
    """
    channel_mean = np.array([122.7717, 115.9465, 102.9801]).reshape(3, 1, 1)
    swapped = img[[2, 1, 0], :, :]  # RGB-BGR
    rescaled = swapped * 255
    return (rescaled - channel_mean).astype(np.float32, copy=True)
def
preprocess
(
img
,
min_size
=
600
,
max_size
=
1000
):
"""Preprocess an image for feature extraction.
...
...
@@ -32,15 +59,15 @@ def preprocess(img, min_size=600, max_size=1000):
scale1
=
min_size
/
min
(
H
,
W
)
scale2
=
max_size
/
max
(
H
,
W
)
scale
=
min
(
scale1
,
scale2
)
# both the longer and shorter should be less than
# max_size and min_size
img
=
img
/
255.
img
=
img
[[
2
,
1
,
0
],:,:]
#RGB-BGR
img
=
sktsf
.
resize
(
img
,
(
C
,
H
*
scale
,
W
*
scale
),
mode
=
'reflect'
)
img
=
img
*
255
mean
=
np
.
array
([
122.7717
,
115.9465
,
102.9801
]).
reshape
(
3
,
1
,
1
)
img
=
(
img
-
mean
).
astype
(
np
.
float32
,
copy
=
True
)
return
img
# both the longer and shorter should be less than
# max_size and min_size
if
opt
.
caffe_pretrain
:
normalize
=
caffe_normalize
else
:
normalize
=
pytorch_normalze
return
normalize
(
img
)
class
Transform
(
object
):
...
...
@@ -77,7 +104,7 @@ class Dataset():
img
,
bbox
,
label
,
scale
=
self
.
tsf
((
ori_img
,
bbox
,
label
))
# TODO: find out which array has a negative stride and fix that one instead of copying everything.
# Some of the strides of a given numpy array are negative.
return
img
.
copy
(),
bbox
.
copy
(),
label
.
copy
(),
scale
,
ori_img
return
img
.
copy
(),
bbox
.
copy
(),
label
.
copy
(),
scale
def
__len__
(
self
):
return
len
(
self
.
db
)
...
...
model/faster_rcnn.py
浏览文件 @
f649bd72
...
...
@@ -195,7 +195,7 @@ class FasterRCNN(nn.Module):
score
=
np
.
concatenate
(
score
,
axis
=
0
).
astype
(
np
.
float32
)
return
bbox
,
label
,
score
def
predict
(
self
,
imgs
,
visualize
=
False
):
def
predict
(
self
,
imgs
,
sizes
=
None
,
visualize
=
False
):
"""Detect objects from images.
This method predicts objects for each image.
...
...
@@ -226,13 +226,13 @@ class FasterRCNN(nn.Module):
self
.
eval
()
if
visualize
:
self
.
use_preset
(
'visualize'
)
prepared_imgs
=
list
()
sizes
=
list
()
for
img
in
imgs
:
size
=
img
.
shape
[
1
:]
img
=
preprocess
(
img
.
numpy
(
))
prepared_imgs
.
append
(
img
)
sizes
.
append
(
size
)
prepared_imgs
=
list
()
sizes
=
list
()
for
img
in
imgs
:
size
=
img
.
shape
[
1
:]
img
=
preprocess
(
at
.
tonumpy
(
img
))
prepared_imgs
.
append
(
img
)
sizes
.
append
(
size
)
bboxes
=
list
()
labels
=
list
()
...
...
@@ -278,81 +278,6 @@ class FasterRCNN(nn.Module):
self
.
use_preset
(
'evaluate'
)
self
.
train
()
return
bboxes
,
labels
,
scores
def predict2(self, prepared_imgs, sizes):
    """Detect objects in images that have already been preprocessed.

    The model is switched to eval mode and the 'evaluate' preset for the
    duration of the call, and switched back to train mode before returning.
    The location mean/std tensors are moved to the GPU with ``.cuda()``,
    so a CUDA device is required.

    Args:
        prepared_imgs: iterable of images, one per entry of ``sizes``.
            Each is converted with ``at.totensor`` and given a leading
            batch axis; presumably CHW arrays that already went through
            preprocessing -- TODO confirm against the caller.
        sizes: iterable of per-image sizes.  ``size[0]`` bounds the
            even-indexed box coordinates and ``size[1]`` the odd-indexed
            ones, and ``size[1]`` is used to derive the resize scale
            (presumably ``(height, width)`` of the original image --
            verify against the caller).

    Returns:
        tuple of lists ``(bboxes, labels, scores)`` with one entry per
        image, each entry being whatever ``self._suppress`` returns:

        * **bboxes**: presumably float arrays of shape (R, 4) holding
          (y_min, x_min, y_max, x_max), where R is the number of
          bounding boxes kept for the image.
        * **labels**: presumably integer arrays of shape (R,) with class
          indices in [0, L - 1], L being the number of foreground classes.
        * **scores**: presumably float arrays of shape (R,) with the
          confidence of each prediction.
    """
    self.eval()
    # self.use_preset('visualize')
    self.use_preset('evaluate')
    bboxes = list()
    labels = list()
    scores = list()
    for img, size in zip(prepared_imgs, sizes):
        # [None] adds a leading batch axis; volatile=True is presumably the
        # legacy (pre-0.4 PyTorch) way of disabling autograd -- confirm.
        img = t.autograd.Variable(at.totensor(img).float()[None], volatile=True)
        # Ratio between the last axis of the network input and size[1].
        scale = img.shape[3] / size[1]
        roi_cls_loc, roi_scores, rois, _ = self(img, scale=scale)
        # We are assuming that batch size is 1.
        # roi_cls_loc = at.tonumpy(roi_cls_locs)#.data.numpy()
        roi_score = roi_scores.data
        roi_cls_loc = roi_cls_loc.data
        # Bring the proposals back to the coordinate scale of ``size``.
        roi = at.totensor(rois) / scale

        # Convert predictions to bounding boxes in image coordinates.
        # Bounding boxes are scaled to the scale of the input images.
        mean = t.Tensor(self.loc_normalize_mean).cuda(). \
            repeat(self.n_class)[None]
        std = t.Tensor(self.loc_normalize_std).cuda(). \
            repeat(self.n_class)[None]

        # Undo the normalization applied to the location outputs.
        roi_cls_loc = (roi_cls_loc * std + mean)
        roi_cls_loc = roi_cls_loc.view(-1, self.n_class, 4)
        # Pair every proposal with each of its n_class location offsets.
        roi = roi.view(-1, 1, 4).expand_as(roi_cls_loc)
        cls_bbox = loc2bbox(at.tonumpy(roi).reshape((-1, 4)),
                            at.tonumpy(roi_cls_loc).reshape((-1, 4)))
        cls_bbox = at.totensor(cls_bbox)
        cls_bbox = cls_bbox.view(-1, self.n_class * 4)
        # clip bounding box
        cls_bbox[:, 0::2] = (cls_bbox[:, 0::2]).clamp(min=0, max=size[0])
        cls_bbox[:, 1::2] = (cls_bbox[:, 1::2]).clamp(min=0, max=size[1])

        # Softmax over axis 1 of the raw scores.
        prob = at.tonumpy(F.softmax(at.tovariable(roi_score), dim=1))

        raw_cls_bbox = at.tonumpy(cls_bbox)
        raw_prob = at.tonumpy(prob)

        bbox, label, score = self._suppress(raw_cls_bbox, raw_prob)
        bboxes.append(bbox)
        labels.append(label)
        scores.append(score)

    # self.use_preset('evaluate')
    # Unconditionally return the module to training mode.
    self.train()
    return bboxes, labels, scores
def
get_optimizer_group
(
self
):
self
.
lr1
,
self
.
lr2
,
self
.
lr3
=
opt
.
lr1
,
opt
.
lr2
,
opt
.
lr3
...
...
model/faster_rcnn_vgg16.py
浏览文件 @
f649bd72
...
...
@@ -12,8 +12,14 @@ from config import opt
def
decom_vgg16
(
pretrained
=
True
):
# the 30th layer of features is relu of conv5_3
model
=
vgg16
(
pretrained
=
False
)
model
.
load_state_dict
(
t
.
load
(
'/home/a/code/pytorch/faster-rcnn/pytorch-faster-rcnn/data/imagenet_weights/vgg16.pth'
))
if
opt
.
caffe_pretrain
:
model
=
vgg16
(
pretrained
=
False
)
if
not
opt
.
load_path
:
model
.
load_state_dict
(
t
.
load
(
opt
.
caffe_pretrain_path
))
else
:
model
=
vgg16
(
not
opt
.
load_path
)
features
=
list
(
model
.
features
)[:
30
]
classifier
=
model
.
classifier
...
...
@@ -139,8 +145,10 @@ class FasterRCNNVGG16(FasterRCNN):
ratios
=
[
0.5
,
1
,
2
],
anchor_scales
=
[
8
,
16
,
32
]
):
if
opt
.
use_chainer
:
decom
=
decom_vgg16_chainer
else
:
decom
=
decom_vgg16
if
opt
.
use_chainer
:
decom
=
decom_vgg16_chainer
else
:
decom
=
decom_vgg16
extractor
,
classifier
=
decom
(
not
opt
.
load_path
)
rpn
=
RegionProposalNetwork
(
...
...
train.py
浏览文件 @
f649bd72
...
...
@@ -6,7 +6,7 @@ from tqdm import tqdm
import
torch
as
t
from
config
import
opt
from
data.dataset
import
Dataset
,
TestDataset
from
data.dataset
import
Dataset
,
TestDataset
,
inverse_normalize
from
model
import
FasterRCNNVGG16
from
torch.autograd
import
Variable
from
torch.utils
import
data
as
data_
...
...
@@ -22,7 +22,7 @@ def eval(dataloader, faster_rcnn, test_num=10000):
gt_bboxes
,
gt_labels
,
gt_difficults
=
list
(),
list
(),
list
()
for
ii
,
(
imgs
,
sizes
,
gt_bboxes_
,
gt_labels_
,
gt_difficults_
)
in
tqdm
(
enumerate
(
dataloader
)):
sizes
=
[
sizes
[
0
][
0
],
sizes
[
1
][
0
]]
pred_bboxes_
,
pred_labels_
,
pred_scores_
=
faster_rcnn
.
predict
2
(
imgs
,
[
sizes
])
pred_bboxes_
,
pred_labels_
,
pred_scores_
=
faster_rcnn
.
predict
(
imgs
,
[
sizes
])
gt_bboxes
+=
list
(
gt_bboxes_
.
numpy
())
gt_labels
+=
list
(
gt_labels_
.
numpy
())
gt_difficults
+=
list
(
gt_difficults_
.
numpy
())
...
...
@@ -67,7 +67,7 @@ def train(**kwargs):
best_map
=
0
for
epoch
in
range
(
opt
.
epoch
):
trainer
.
reset_meters
()
for
ii
,
(
img
,
bbox_
,
label_
,
scale
,
ori_img
)
in
tqdm
(
enumerate
(
dataloader
)):
for
ii
,
(
img
,
bbox_
,
label_
,
scale
)
in
tqdm
(
enumerate
(
dataloader
)):
scale
=
at
.
scalar
(
scale
)
img
,
bbox
,
label
=
img
.
cuda
().
float
(),
bbox_
.
cuda
(),
label_
.
cuda
()
img
,
bbox
,
label
=
Variable
(
img
),
Variable
(
bbox
),
Variable
(
label
)
...
...
@@ -80,20 +80,20 @@ def train(**kwargs):
# plot loss
trainer
.
vis
.
plot_many
(
trainer
.
get_meter_data
())
#
#
plot groud truth bboxes
# ori_img_ = (img * 0.225 + 0.45).clamp(min=0, max=1) * 255
# gt_img = visdom_bbox(at.tonumpy(ori_img_)[0]
,
# at.tonumpy(bbox_)[0]
,
# label_[0].numpy(
))
#
trainer.vis.img('gt_img', gt_img)
#
#
plot predicti bboxes
# _bboxes, _labels, _scores = trainer.faster_rcnn.predict(ori_img
,visualize=True)
# pred_img = visdom_bbox( at.tonumpy(ori_img[0])
,
#
at.tonumpy(_bboxes[0]),
#
at.tonumpy(_labels[0]).reshape(-1),
#
at.tonumpy(_scores[0]))
#
trainer.vis.img('pred_img', pred_img)
# plot ground truth bboxes
ori_img_
=
inverse_normalize
(
at
.
tonumpy
(
img
[
0
]))
gt_img
=
visdom_bbox
(
ori_img_
,
at
.
tonumpy
(
bbox_
[
0
])
,
at
.
tonumpy
(
label_
[
0
]
))
trainer
.
vis
.
img
(
'gt_img'
,
gt_img
)
# plot predicted bboxes
_bboxes
,
_labels
,
_scores
=
trainer
.
faster_rcnn
.
predict
([
ori_img_
]
,
visualize
=
True
)
pred_img
=
visdom_bbox
(
ori_img_
,
at
.
tonumpy
(
_bboxes
[
0
]),
at
.
tonumpy
(
_labels
[
0
]).
reshape
(
-
1
),
at
.
tonumpy
(
_scores
[
0
]))
trainer
.
vis
.
img
(
'pred_img'
,
pred_img
)
# rpn confusion matrix(meter)
trainer
.
vis
.
text
(
str
(
trainer
.
rpn_cm
.
value
().
tolist
()),
win
=
'rpn_cm'
)
...
...
train_as_chainer.py
浏览文件 @
f649bd72
...
...
@@ -23,7 +23,7 @@ def eval(dataloader, faster_rcnn, test_num=10000):
gt_bboxes
,
gt_labels
,
gt_difficults
=
list
(),
list
(),
list
()
for
ii
,
(
imgs
,
sizes
,
gt_bboxes_
,
gt_labels_
,
gt_difficults_
)
in
tqdm
(
enumerate
(
dataloader
)):
sizes
=
[
sizes
[
0
][
0
],
sizes
[
1
][
0
]]
pred_bboxes_
,
pred_labels_
,
pred_scores_
=
faster_rcnn
.
predict
2
(
imgs
,
[
sizes
])
pred_bboxes_
,
pred_labels_
,
pred_scores_
=
faster_rcnn
.
predict
(
imgs
,
[
sizes
])
gt_bboxes
+=
list
(
gt_bboxes_
.
numpy
())
gt_labels
+=
list
(
gt_labels_
.
numpy
())
gt_difficults
+=
list
(
gt_difficults_
.
numpy
())
...
...
train_fast.py
浏览文件 @
f649bd72
...
...
@@ -22,7 +22,7 @@ def eval(dataloader, faster_rcnn, test_num=10000):
gt_bboxes
,
gt_labels
,
gt_difficults
=
list
(),
list
(),
list
()
for
ii
,
(
imgs
,
sizes
,
gt_bboxes_
,
gt_labels_
,
gt_difficults_
)
in
tqdm
(
enumerate
(
dataloader
)):
sizes
=
[
sizes
[
0
][
0
],
sizes
[
1
][
0
]]
pred_bboxes_
,
pred_labels_
,
pred_scores_
=
faster_rcnn
.
predict
2
(
imgs
,
[
sizes
])
pred_bboxes_
,
pred_labels_
,
pred_scores_
=
faster_rcnn
.
predict
(
imgs
,
[
sizes
])
gt_bboxes
+=
list
(
gt_bboxes_
.
numpy
())
gt_labels
+=
list
(
gt_labels_
.
numpy
())
gt_difficults
+=
list
(
gt_difficults_
.
numpy
())
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录