Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
牧羊zove
fcos
提交
cd4be13c
F
fcos
项目概览
牧羊zove
/
fcos
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
F
fcos
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
cd4be13c
编写于
11月 19, 2019
作者:
Z
Zhi Tian
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add onnx
上级
b516eb54
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
236 addition
and
0 deletion
+236
-0
tools/export_model_to_onnx.py
tools/export_model_to_onnx.py
+95
-0
tools/test_fcos_onnx_model.py
tools/test_fcos_onnx_model.py
+141
-0
未找到文件。
tools/export_model_to_onnx.py
0 → 100644
浏览文件 @
cd4be13c
"""
A working example to export the R-50 based FCOS model:
python tools/export_model_to_onnx.py --config-file configs/fcos/fcos_imprv_R_50_FPN_1x.yaml MODEL.WEIGHT FCOS_imprv_R_50_FPN_1x.pth
"""
from
fcos_core.utils.env
import
setup_environment
# noqa F401 isort:skip
import
argparse
import
os
import
torch
from
fcos_core.config
import
cfg
from
fcos_core.data
import
make_data_loader
from
fcos_core.engine.inference
import
inference
from
fcos_core.modeling.detector
import
build_detection_model
from
fcos_core.utils.checkpoint
import
DetectronCheckpointer
from
fcos_core.utils.collect_env
import
collect_env_info
from
fcos_core.utils.comm
import
synchronize
,
get_rank
from
fcos_core.utils.logger
import
setup_logger
from
fcos_core.utils.miscellaneous
import
mkdir
from
collections
import
OrderedDict
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
"Export model to the onnx format"
)
parser
.
add_argument
(
"--config-file"
,
default
=
"/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml"
,
metavar
=
"FILE"
,
help
=
"path to config file"
,
)
parser
.
add_argument
(
"--output"
,
default
=
"fcos.onnx"
,
metavar
=
"FILE"
,
help
=
"path to the output onnx file"
,
)
parser
.
add_argument
(
"opts"
,
help
=
"Modify config options using the command-line"
,
default
=
None
,
nargs
=
argparse
.
REMAINDER
,
)
args
=
parser
.
parse_args
()
cfg
.
merge_from_file
(
args
.
config_file
)
cfg
.
merge_from_list
(
args
.
opts
)
cfg
.
freeze
()
assert
cfg
.
MODEL
.
FCOS_ON
,
"This script is only tested for the detector FCOS."
save_dir
=
""
logger
=
setup_logger
(
"fcos_core"
,
save_dir
,
get_rank
())
logger
.
info
(
cfg
)
logger
.
info
(
"Collecting env info (might take some time)"
)
logger
.
info
(
"
\n
"
+
collect_env_info
())
model
=
build_detection_model
(
cfg
)
model
.
to
(
cfg
.
MODEL
.
DEVICE
)
output_dir
=
cfg
.
OUTPUT_DIR
checkpointer
=
DetectronCheckpointer
(
cfg
,
model
,
save_dir
=
output_dir
)
_
=
checkpointer
.
load
(
cfg
.
MODEL
.
WEIGHT
)
onnx_model
=
torch
.
nn
.
Sequential
(
OrderedDict
([
(
'backbone'
,
model
.
backbone
),
(
'heads'
,
model
.
rpn
.
head
),
]))
input_names
=
[
"input_image"
]
dummy_input
=
torch
.
zeros
((
1
,
3
,
800
,
1216
)).
to
(
cfg
.
MODEL
.
DEVICE
)
output_names
=
[]
for
l
in
range
(
len
(
cfg
.
MODEL
.
FCOS
.
FPN_STRIDES
)):
fpn_name
=
"P{}/"
.
format
(
3
+
l
)
output_names
.
extend
([
fpn_name
+
"logits"
,
fpn_name
+
"bbox_reg"
,
fpn_name
+
"centerness"
])
torch
.
onnx
.
export
(
onnx_model
,
dummy_input
,
args
.
output
,
verbose
=
True
,
input_names
=
input_names
,
output_names
=
output_names
,
keep_initializers_as_inputs
=
True
)
logger
.
info
(
"Done. The onnx model is saved into {}."
.
format
(
args
.
output
))
if
__name__
==
"__main__"
:
main
()
tools/test_fcos_onnx_model.py
0 → 100644
浏览文件 @
cd4be13c
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (no not reorder)
from
fcos_core.utils.env
import
setup_environment
# noqa F401 isort:skip
import
argparse
import
os
import
torch
from
torch
import
nn
import
onnx
from
fcos_core.config
import
cfg
from
fcos_core.data
import
make_data_loader
from
fcos_core.engine.inference
import
inference
from
fcos_core.modeling.detector
import
build_detection_model
from
fcos_core.utils.checkpoint
import
DetectronCheckpointer
from
fcos_core.utils.collect_env
import
collect_env_info
from
fcos_core.utils.comm
import
synchronize
,
get_rank
from
fcos_core.utils.logger
import
setup_logger
from
fcos_core.utils.miscellaneous
import
mkdir
from
fcos_core.modeling.rpn.fcos.inference
import
make_fcos_postprocessor
import
caffe2.python.onnx.backend
as
backend
import
numpy
as
np
class
ONNX_FCOS
(
nn
.
Module
):
def
__init__
(
self
,
onnx_model_path
,
cfg
):
super
(
ONNX_FCOS
,
self
).
__init__
()
self
.
onnx_model
=
backend
.
prepare
(
onnx
.
load
(
onnx_model_path
),
device
=
cfg
.
MODEL
.
DEVICE
.
upper
()
)
self
.
postprocessing
=
make_fcos_postprocessor
(
cfg
)
self
.
cfg
=
cfg
self
.
fpn_strides
=
cfg
.
MODEL
.
FCOS
.
FPN_STRIDES
def
forward
(
self
,
images
):
outputs
=
self
.
onnx_model
.
run
(
images
.
tensors
.
cpu
().
numpy
())
outputs
=
[
torch
.
from_numpy
(
o
).
to
(
self
.
cfg
.
MODEL
.
DEVICE
)
for
o
in
outputs
]
logits
=
outputs
[::
3
]
bbox_reg
=
outputs
[
1
::
3
]
centerness
=
outputs
[
2
::
3
]
locations
=
self
.
compute_locations
(
logits
)
boxes
=
self
.
postprocessing
(
locations
,
logits
,
bbox_reg
,
centerness
,
images
.
image_sizes
)
return
boxes
def
compute_locations
(
self
,
features
):
locations
=
[]
for
level
,
feature
in
enumerate
(
features
):
h
,
w
=
feature
.
size
()[
-
2
:]
locations_per_level
=
self
.
compute_locations_per_level
(
h
,
w
,
self
.
fpn_strides
[
level
],
feature
.
device
)
locations
.
append
(
locations_per_level
)
return
locations
def
compute_locations_per_level
(
self
,
h
,
w
,
stride
,
device
):
shifts_x
=
torch
.
arange
(
0
,
w
*
stride
,
step
=
stride
,
dtype
=
torch
.
float32
,
device
=
device
)
shifts_y
=
torch
.
arange
(
0
,
h
*
stride
,
step
=
stride
,
dtype
=
torch
.
float32
,
device
=
device
)
shift_y
,
shift_x
=
torch
.
meshgrid
(
shifts_y
,
shifts_x
)
shift_x
=
shift_x
.
reshape
(
-
1
)
shift_y
=
shift_y
.
reshape
(
-
1
)
locations
=
torch
.
stack
((
shift_x
,
shift_y
),
dim
=
1
)
+
stride
//
2
return
locations
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
"Test onnx models of FCOS"
)
parser
.
add_argument
(
"--config-file"
,
default
=
"/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml"
,
metavar
=
"FILE"
,
help
=
"path to config file"
,
)
parser
.
add_argument
(
"--onnx-model"
,
default
=
"fcos.onnx"
,
metavar
=
"FILE"
,
help
=
"path to the onnx model"
,
)
parser
.
add_argument
(
"opts"
,
help
=
"Modify config options using the command-line"
,
default
=
None
,
nargs
=
argparse
.
REMAINDER
,
)
args
=
parser
.
parse_args
()
cfg
.
merge_from_file
(
args
.
config_file
)
cfg
.
merge_from_list
(
args
.
opts
)
cfg
.
freeze
()
save_dir
=
""
logger
=
setup_logger
(
"fcos_core"
,
save_dir
,
get_rank
())
logger
.
info
(
cfg
)
logger
.
info
(
"Collecting env info (might take some time)"
)
logger
.
info
(
"
\n
"
+
collect_env_info
())
model
=
ONNX_FCOS
(
args
.
onnx_model
,
cfg
)
model
.
to
(
cfg
.
MODEL
.
DEVICE
)
iou_types
=
(
"bbox"
,)
if
cfg
.
MODEL
.
MASK_ON
:
iou_types
=
iou_types
+
(
"segm"
,)
if
cfg
.
MODEL
.
KEYPOINT_ON
:
iou_types
=
iou_types
+
(
"keypoints"
,)
output_folders
=
[
None
]
*
len
(
cfg
.
DATASETS
.
TEST
)
dataset_names
=
cfg
.
DATASETS
.
TEST
if
cfg
.
OUTPUT_DIR
:
for
idx
,
dataset_name
in
enumerate
(
dataset_names
):
output_folder
=
os
.
path
.
join
(
cfg
.
OUTPUT_DIR
,
"inference"
,
dataset_name
)
mkdir
(
output_folder
)
output_folders
[
idx
]
=
output_folder
data_loaders_val
=
make_data_loader
(
cfg
,
is_train
=
False
,
is_distributed
=
False
)
for
output_folder
,
dataset_name
,
data_loader_val
in
zip
(
output_folders
,
dataset_names
,
data_loaders_val
):
inference
(
model
,
data_loader_val
,
dataset_name
=
dataset_name
,
iou_types
=
iou_types
,
box_only
=
False
if
cfg
.
MODEL
.
FCOS_ON
or
cfg
.
MODEL
.
RETINANET_ON
else
cfg
.
MODEL
.
RPN_ONLY
,
device
=
cfg
.
MODEL
.
DEVICE
,
expected_results
=
cfg
.
TEST
.
EXPECTED_RESULTS
,
expected_results_sigma_tol
=
cfg
.
TEST
.
EXPECTED_RESULTS_SIGMA_TOL
,
output_folder
=
output_folder
,
)
synchronize
()
if
__name__
==
"__main__"
:
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录