Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_47816946
simple-faster-rcnn-pytorch
提交
14d409ee
S
simple-faster-rcnn-pytorch
项目概览
weixin_47816946
/
simple-faster-rcnn-pytorch
与 Fork 源项目一致
从无法访问的项目Fork
通知
8
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
S
simple-faster-rcnn-pytorch
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
14d409ee
编写于
12月 21, 2017
作者:
C
chenyuntc
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add training for fast
上级
214c5553
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
112 addition
and
0 deletion
+112
-0
train_fast.py
train_fast.py
+112
-0
未找到文件。
train_fast.py
0 → 100644
浏览文件 @
14d409ee
import
os
import
ipdb
import
matplotlib
from
tqdm
import
tqdm
import
torch
as
t
from
config
import
opt
from
data.dataset
import
Dataset
,
TestDataset
from
model
import
FasterRCNNVGG16
from
torch.autograd
import
Variable
from
torch.utils
import
data
as
data_
from
trainer
import
FasterRCNNTrainer
from
util
import
array_tool
as
at
from
util.vis_tool
import
visdom_bbox
from
util.eval_tool
import
eval_detection_voc
matplotlib
.
use
(
'agg'
)
def
eval
(
dataloader
,
faster_rcnn
,
test_num
=
10000
):
pred_bboxes
,
pred_labels
,
pred_scores
=
list
(),
list
(),
list
()
gt_bboxes
,
gt_labels
,
gt_difficults
=
list
(),
list
(),
list
()
for
ii
,
(
imgs
,
sizes
,
gt_bboxes_
,
gt_labels_
,
gt_difficults_
)
in
tqdm
(
enumerate
(
dataloader
)):
sizes
=
[
sizes
[
0
][
0
],
sizes
[
1
][
0
]]
pred_bboxes_
,
pred_labels_
,
pred_scores_
=
faster_rcnn
.
predict2
(
imgs
,
[
sizes
])
gt_bboxes
+=
list
(
gt_bboxes_
.
numpy
())
gt_labels
+=
list
(
gt_labels_
.
numpy
())
gt_difficults
+=
list
(
gt_difficults_
.
numpy
())
pred_bboxes
+=
pred_bboxes_
pred_labels
+=
pred_labels_
pred_scores
+=
pred_scores_
if
ii
==
test_num
:
break
result
=
eval_detection_voc
(
pred_bboxes
,
pred_labels
,
pred_scores
,
gt_bboxes
,
gt_labels
,
gt_difficults
,
use_07_metric
=
True
)
return
result
def
train
(
**
kwargs
):
opt
.
_parse
(
kwargs
)
dataset
=
Dataset
(
opt
)
print
(
'load data'
)
dataloader
=
data_
.
DataLoader
(
dataset
,
\
batch_size
=
1
,
\
shuffle
=
True
,
\
# pin_memory=True,
num_workers
=
opt
.
num_workers
)
testset
=
TestDataset
(
opt
)
test_dataloader
=
data_
.
DataLoader
(
testset
,
batch_size
=
1
,
num_workers
=
2
,
shuffle
=
False
,
\
# pin_memory=True
)
faster_rcnn
=
FasterRCNNVGG16
()
print
(
'model construct completed'
)
trainer
=
FasterRCNNTrainer
(
faster_rcnn
).
cuda
()
if
opt
.
load_path
:
trainer
.
load
(
opt
.
load_path
)
print
(
'load pretrained model from %s'
%
opt
.
load_path
)
# trainer.optimizer = trainer.faster_rcnn.get_great_optimizer()
trainer
.
vis
.
text
(
dataset
.
db
.
label_names
,
win
=
'labels'
)
best_map
=
0
for
epoch
in
range
(
7
):
trainer
.
reset_meters
()
for
ii
,
(
img
,
bbox_
,
label_
,
scale
,
ori_img
)
in
tqdm
(
enumerate
(
dataloader
)):
scale
=
at
.
scalar
(
scale
)
img
,
bbox
,
label
=
img
.
cuda
().
float
(),
bbox_
.
cuda
(),
label_
.
cuda
()
img
,
bbox
,
label
=
Variable
(
img
),
Variable
(
bbox
),
Variable
(
label
)
losses
=
trainer
.
train_step
(
img
,
bbox
,
label
,
scale
)
if
(
ii
+
1
)
%
opt
.
plot_every
==
0
:
if
os
.
path
.
exists
(
opt
.
debug_file
):
ipdb
.
set_trace
()
# plot loss
trainer
.
vis
.
plot_many
(
trainer
.
get_meter_data
())
# plot groud truth bboxes
ori_img_
=
(
img
*
0.225
+
0.45
).
clamp
(
min
=
0
,
max
=
1
)
*
255
gt_img
=
visdom_bbox
(
at
.
tonumpy
(
ori_img_
)[
0
],
at
.
tonumpy
(
bbox_
)[
0
],
label_
[
0
].
numpy
())
trainer
.
vis
.
img
(
'gt_img'
,
gt_img
)
# plot predicti bboxes
_bboxes
,
_labels
,
_scores
=
trainer
.
faster_rcnn
.
predict
(
ori_img
,
visualize
=
True
)
pred_img
=
visdom_bbox
(
at
.
tonumpy
(
ori_img
[
0
]),
at
.
tonumpy
(
_bboxes
[
0
]),
at
.
tonumpy
(
_labels
[
0
]).
reshape
(
-
1
),
at
.
tonumpy
(
_scores
[
0
]))
trainer
.
vis
.
img
(
'pred_img'
,
pred_img
)
# rpn confusion matrix(meter)
trainer
.
vis
.
text
(
str
(
trainer
.
rpn_cm
.
value
().
tolist
()),
win
=
'rpn_cm'
)
# roi confusion matrix
trainer
.
vis
.
img
(
'roi_cm'
,
at
.
totensor
(
trainer
.
roi_cm
.
conf
,
False
).
float
())
if
epoch
==
4
:
trainer
.
faster_rcnn
.
scale_lr
(
opt
.
lr_decay
)
eval_result
=
eval
(
test_dataloader
,
faster_rcnn
,
test_num
=
1e100
)
print
(
'eval_result'
)
trainer
.
save
(
mAP
=
eval_result
[
'map'
])
if
__name__
==
'__main__'
:
import
fire
fire
.
Fire
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录