Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b10238ac
MegEngine
项目概览
MegEngine 天元
/
MegEngine
9 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
b10238ac
编写于
3月 26, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/tools): add support of receptive_field stats for NetworkNode
GitOrigin-RevId: 11ef3354689d343883348d4129bc89db784e3fe0
上级
84c2a5c2
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
238 addition
and
148 deletion
+238
-148
imperative/python/megengine/tools/network_visualize.py
imperative/python/megengine/tools/network_visualize.py
+28
-25
imperative/python/megengine/utils/module_stats.py
imperative/python/megengine/utils/module_stats.py
+143
-76
imperative/python/megengine/utils/network_node.py
imperative/python/megengine/utils/network_node.py
+67
-47
未找到文件。
imperative/python/megengine/tools/network_visualize.py
浏览文件 @
b10238ac
...
...
@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
argparse
import
json
import
logging
import
numpy
as
np
...
...
@@ -14,6 +15,7 @@ import numpy as np
from
megengine.core.tensor.dtype
import
is_quantize
from
megengine.logger
import
_imperative_rt_logger
,
get_logger
,
set_mgb_log_level
from
megengine.utils.module_stats
import
(
get_flops_stats
,
get_param_stats
,
print_flops_stats
,
print_params_stats
,
...
...
@@ -89,6 +91,7 @@ def visualize(
inp_list
=
[
process_name
(
var
.
owner
.
name
)
for
var
in
node
.
inputs
]
if
log_path
:
# detail format see tensorboard/compat/proto/attr_value.proto
attr
=
{
"_output_shapes"
:
AttrValue
(
list
=
AttrValue
.
ListValue
(
...
...
@@ -101,24 +104,20 @@ def visualize(
]
)
),
"params"
:
AttrValue
(
s
=
str
(
node
.
params
).
encode
(
encoding
=
"utf-8"
)),
"dtype"
:
AttrValue
(
s
=
str
(
node_oup
.
dtype
).
encode
(
encoding
=
"utf-8"
)),
}
if
hasattr
(
node
,
"calc_flops"
):
flops_num
=
node
.
calc_flops
()
flops_stats
=
get_flops_stats
(
node
,
node
.
inputs
,
node
.
outputs
)
if
flops_stats
is
not
None
:
# add op flops attr
if
log_path
:
if
log_path
and
hasattr
(
flops_stats
,
"flops_num"
)
:
attr
[
"flops"
]
=
AttrValue
(
s
=
sizeof_fmt
(
flops_num
).
encode
(
encoding
=
"utf-8"
)
)
flops_list
.
append
(
dict
(
name
=
node
.
name
,
class_name
=
node
.
type
,
input_shapes
=
[
i
.
shape
for
i
in
node
.
inputs
],
output_shapes
=
[
o
.
shape
for
o
in
node
.
outputs
],
flops_num
=
flops_num
,
flops_cum
=
0
,
s
=
sizeof_fmt
(
flops_stats
[
"flops"
]).
encode
(
encoding
=
"utf-8"
)
)
)
flops_stats
[
"name"
]
=
node
.
name
flops_stats
[
"class_name"
]
=
node
.
type
flops_list
.
append
(
flops_stats
)
if
node
.
type
==
"ImmutableTensor"
:
param_stats
=
get_param_stats
(
node
.
numpy
())
# add tensor size attr
...
...
@@ -132,6 +131,7 @@ def visualize(
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if
not
len
(
node
.
name
.
split
(
"."
))
>
2
and
not
node
in
graph
.
input_vars
:
continue
if
log_path
:
node_list
.
append
(
NodeDef
(
...
...
@@ -141,14 +141,26 @@ def visualize(
attr
=
attr
,
)
)
# summary
extra_info
=
{
"#ops"
:
len
(
graph
.
all_oprs
),
"#params"
:
len
(
params_list
),
}
total_flops
,
total_param
s
=
None
,
None
total_flops
,
total_param
_dims
,
total_param_size
=
0
,
0
,
0
if
log_params
:
total_param_dims
,
total_param_size
=
print_params_stats
(
params_list
,
bar_length_max
)
extra_info
[
"total_param_dims"
]
=
sizeof_fmt
(
total_param_dims
)
extra_info
[
"total_param_size"
]
=
sizeof_fmt
(
total_param_size
)
if
log_flops
:
total_flops
=
print_flops_stats
(
flops_list
,
bar_length_max
)
extra_info
[
"total_flops"
]
=
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
)
if
log_params
and
log_flops
:
extra_info
[
"flops/param_size"
]
=
"{:3.3f}"
.
format
(
total_flops
/
total_param_size
)
if
log_path
:
graph_def
=
GraphDef
(
node
=
node_list
,
versions
=
VersionDef
(
producer
=
22
))
...
...
@@ -160,21 +172,12 @@ def visualize(
writer
=
SummaryWriter
(
log_path
)
writer
.
_get_file_writer
().
add_graph
((
graph_def
,
stepstats
))
# summary
extra_info
=
{
"#ops"
:
len
(
graph
.
all_oprs
),
"#params"
:
len
(
params_list
),
"total_param_dims"
:
sizeof_fmt
(
total_param_dims
),
"total_param_size"
:
sizeof_fmt
(
total_param_size
),
"total_flops"
:
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
),
"flops/param_size"
:
"{:3.3f}"
.
format
(
total_flops
/
total_param_size
),
}
print_summary
(
**
extra_info
)
# FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger
.
set_log_level
(
old_level
)
return
total_param
s
,
total_flops
return
total_param
_size
,
total_flops
def
main
():
...
...
imperative/python/megengine/utils/module_stats.py
浏览文件 @
b10238ac
...
...
@@ -26,61 +26,95 @@ logger = mge.get_logger(__name__)
logger
.
setLevel
(
"INFO"
)
CALC_FLOPS
=
{}
def
_register_modules
(
*
modules
):
_calc_flops_dict
=
{}
_calc_receptive_field_dict
=
{}
def
_receptive_field_fallback
(
module
,
inputs
,
outputs
):
assert
not
hasattr
(
module
,
"_rf"
)
assert
not
hasattr
(
module
,
"_stride"
)
if
len
(
inputs
)
==
0
:
# TODO: support other dimension
module
.
_rf
=
(
1
,
1
)
module
.
_stride
=
(
1
,
1
)
return
module
.
_rf
,
module
.
_stride
rf
,
stride
=
preprocess_receptive_field
(
module
,
inputs
,
outputs
)
module
.
_rf
=
rf
module
.
_stride
=
stride
return
rf
,
stride
# key tuple, impl_dict, fallback
_iter_list
=
[
(
"flops_num"
,
_calc_flops_dict
,
None
),
(
(
"receptive_field"
,
"stride"
),
_calc_receptive_field_dict
,
_receptive_field_fallback
,
),
]
def
_register_dict
(
*
modules
,
dict
=
None
):
def
callback
(
impl
):
for
module
in
modules
:
CALC_FLOPS
[
module
]
=
impl
dict
[
module
]
=
impl
return
impl
return
callback
@
_register_modules
(
m
.
Conv2d
,
m
.
ConvTranspose2d
,
m
.
LocalConv2d
,
qm
.
Conv2d
,
qm
.
ConvRelu2d
,
qm
.
ConvBn2d
,
qm
.
ConvBnRelu2d
,
qatm
.
Conv2d
,
qatm
.
ConvRelu2d
,
qatm
.
ConvBn2d
,
qatm
.
ConvBnRelu2d
,
def
register_flops
(
*
modules
):
return
_register_dict
(
*
modules
,
dict
=
_calc_flops_dict
)
def
register_receptive_field
(
*
modules
):
return
_register_dict
(
*
modules
,
dict
=
_calc_receptive_field_dict
)
@
register_flops
(
m
.
Conv1d
,
m
.
Conv2d
,
m
.
Conv3d
,
)
def
count_convNd
(
module
,
input
,
output
):
def
flops_convNd
(
module
:
m
.
Conv2d
,
inputs
,
outputs
):
bias
=
1
if
module
.
bias
is
not
None
else
0
group
=
module
.
groups
ic
=
input
[
0
].
shape
[
1
]
oc
=
output
[
0
].
shape
[
1
]
ic
=
input
s
[
0
].
shape
[
1
]
oc
=
output
s
[
0
].
shape
[
1
]
goc
=
oc
//
group
gic
=
ic
//
group
N
=
output
[
0
].
shape
[
0
]
HW
=
np
.
prod
(
output
[
0
].
shape
[
2
:])
N
=
output
s
[
0
].
shape
[
0
]
HW
=
np
.
prod
(
output
s
[
0
].
shape
[
2
:])
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return
N
*
HW
*
goc
*
(
gic
*
np
.
prod
(
module
.
kernel_size
)
+
bias
)
@
_register_modules
(
m
.
ConvTranspose2d
)
def
count_deconvNd
(
module
,
input
,
output
):
return
np
.
prod
(
input
[
0
].
shape
)
*
output
[
0
].
shape
[
1
]
*
np
.
prod
(
module
.
kernel_size
)
@
register_flops
(
m
.
ConvTranspose2d
)
def
flops_deconvNd
(
module
:
m
.
ConvTranspose2d
,
inputs
,
outputs
):
return
np
.
prod
(
inputs
[
0
].
shape
)
*
outputs
[
0
].
shape
[
1
]
*
np
.
prod
(
module
.
kernel_size
)
@
register_flops
(
m
.
Linear
)
def
flops_linear
(
module
:
m
.
Linear
,
inputs
,
outputs
):
bias
=
1
if
module
.
bias
is
not
None
else
0
return
np
.
prod
(
outputs
[
0
].
shape
)
*
module
.
in_features
@
_register_modules
(
m
.
Linear
,
qatm
.
Linear
,
qm
.
Linear
)
def
count_linear
(
module
,
input
,
output
):
return
np
.
prod
(
output
[
0
].
shape
)
*
module
.
in_features
@
register_flops
(
m
.
BatchMatMulActivation
)
def
flops_batchmatmul
(
module
:
m
.
BatchMatMulActivation
,
inputs
,
outputs
):
bias
=
1
if
module
.
bias
is
not
None
else
0
x
=
inputs
[
0
]
w
=
module
.
weight
batch_size
=
x
.
shape
[
0
]
n
,
p
=
x
.
shape
[
1
:]
_
,
m
=
w
.
shape
[
1
:]
return
n
*
(
p
+
bias
)
*
m
*
batch_size
# does not need import qat and quantized module since they inherit from float module.
hook_modules
=
(
m
.
Conv2d
,
m
.
ConvTranspose2d
,
m
.
LocalConv2d
,
m
.
BatchNorm2d
,
m
.
conv
.
_ConvNd
,
m
.
Linear
,
m
.
BatchMatMulActivation
,
)
...
...
@@ -106,28 +140,71 @@ def sizeof_fmt(num, suffix="B"):
return
"{}{:.1f} {}{}"
.
format
(
sign_str
,
num
,
"Yi"
,
suffix
)
def
preprocess_receptive_field
(
module
,
inputs
,
outputs
):
# TODO: support other dimensions
pre_rf
=
(
max
(
getattr
(
i
.
owner
,
"_rf"
,
(
1
,
1
))[
0
]
for
i
in
inputs
),
max
(
i
.
owner
.
_rf
[
1
]
for
i
in
inputs
),
)
pre_stride
=
(
max
(
getattr
(
i
.
owner
,
"_stride"
,
(
1
,
1
))[
0
]
for
i
in
inputs
),
max
(
i
.
owner
.
_stride
[
1
]
for
i
in
inputs
),
)
return
pre_rf
,
pre_stride
def
get_flops_stats
(
module
,
inputs
,
outputs
):
rst
=
{
"input_shapes"
:
[
i
.
shape
for
i
in
inputs
],
"output_shapes"
:
[
o
.
shape
for
o
in
outputs
],
}
valid_flag
=
False
for
key
,
_dict
,
fallback
in
_iter_list
:
for
_type
in
_dict
:
if
isinstance
(
module
,
_type
):
value
=
_dict
[
_type
](
module
,
inputs
,
outputs
)
valid_flag
=
True
break
else
:
if
fallback
is
not
None
:
value
=
fallback
(
module
,
inputs
,
outputs
)
continue
if
isinstance
(
key
,
tuple
):
assert
isinstance
(
value
,
tuple
)
for
k
,
v
in
zip
(
key
,
value
):
rst
[
k
]
=
v
else
:
rst
[
key
]
=
value
if
valid_flag
:
return
rst
else
:
return
None
return
def
print_flops_stats
(
flops
,
bar_length_max
=
20
):
flops_list
=
[
i
[
"flops_num"
]
for
i
in
flops
]
max_flops_num
=
max
(
flops_list
+
[
0
])
# calc total flops and set flops_cum
max_flops_num
=
max
([
i
[
"flops_num"
]
for
i
in
flops
]
+
[
0
])
total_flops_num
=
0
for
d
in
flops
:
total_flops_num
+=
int
(
d
[
"flops_num"
])
d
[
"flops_cum"
]
=
sizeof_fmt
(
total_flops_num
,
suffix
=
"OPs"
)
for
d
in
flops
:
f
=
d
[
"flops_num"
]
d
[
"flops"
]
=
sizeof_fmt
(
f
,
suffix
=
"OPs"
)
r
=
d
[
"ratio"
]
=
f
/
total_flops_num
d
[
"percentage"
]
=
"{:.2f}%"
.
format
(
r
*
100
)
bar_length
=
int
(
f
/
max_flops_num
*
bar_length_max
)
ratio
=
d
[
"ratio"
]
=
d
[
"flops_num"
]
/
total_flops_num
d
[
"percentage"
]
=
"{:.2f}%"
.
format
(
ratio
*
100
)
bar_length
=
int
(
d
[
"flops_num"
]
/
max_flops_num
*
bar_length_max
)
d
[
"bar"
]
=
"#"
*
bar_length
d
[
"flops"
]
=
sizeof_fmt
(
d
[
"flops_num"
],
suffix
=
"OPs"
)
header
=
[
"name"
,
"class_name"
,
"input_shapes"
,
"output_shapes"
,
"receptive_field"
,
"stride"
,
"flops"
,
"flops_cum"
,
"percentage"
,
...
...
@@ -154,8 +231,8 @@ def get_param_stats(param: np.ndarray):
param_size
=
param_dim
*
nbits
//
8
return
{
"shape"
:
shape
,
"mean"
:
param
.
mean
(
),
"std"
:
param
.
std
(
),
"mean"
:
"{:.3g}"
.
format
(
param
.
mean
()
),
"std"
:
"{:.3g}"
.
format
(
param
.
std
()
),
"param_dim"
:
param_dim
,
"nbits"
:
nbits
,
"size"
:
param_size
,
...
...
@@ -163,21 +240,20 @@ def get_param_stats(param: np.ndarray):
def
print_params_stats
(
params
,
bar_length_max
=
20
):
max_size
=
max
([
d
[
"size"
]
for
d
in
params
]
+
[
0
])
total_param_dims
,
total_param_size
=
0
,
0
for
d
in
params
:
total_param_dims
+=
int
(
d
[
"param_dim"
])
total_param_size
+=
int
(
d
[
"size"
])
ratio
=
d
[
"size"
]
/
total_param_size
d
[
"size"
]
=
sizeof_fmt
(
d
[
"size"
])
d
[
"size_cum"
]
=
sizeof_fmt
(
total_param_size
)
d
[
"ratio"
]
=
ratio
d
[
"percentage"
]
=
"{:.2f}%"
.
format
(
ratio
*
100
)
# construct bar
max_ratio
=
max
([
d
[
"ratio"
]
for
d
in
params
])
for
d
in
params
:
bar_length
=
int
(
d
[
"ratio"
]
/
max_ratio
*
bar_length_max
)
ratio
=
d
[
"size"
]
/
total_param_size
d
[
"ratio"
]
=
ratio
d
[
"percentage"
]
=
"{:.2f}%"
.
format
(
ratio
*
100
)
bar_length
=
int
(
d
[
"size"
]
/
max_size
*
bar_length_max
)
d
[
"size_bar"
]
=
"#"
*
bar_length
d
[
"size"
]
=
sizeof_fmt
(
d
[
"size"
])
param_size
=
sizeof_fmt
(
total_param_size
)
params
.
append
(
dict
(
name
=
"total"
,
param_dim
=
total_param_dims
,
size
=
param_size
,))
...
...
@@ -225,26 +301,14 @@ def module_stats(
:param log_flops: whether print and record op flops.
"""
def
module_stats_hook
(
module
,
input
,
output
,
name
=
""
):
def
module_stats_hook
(
module
,
input
s
,
outputs
,
name
=
""
):
class_name
=
str
(
module
.
__class__
).
split
(
"."
)[
-
1
].
split
(
"'"
)[
0
]
flops_fun
=
CALC_FLOPS
.
get
(
type
(
module
))
if
callable
(
flops_fun
):
flops_num
=
flops_fun
(
module
,
input
,
output
)
if
not
isinstance
(
output
,
(
list
,
tuple
)):
output
=
[
output
]
flops
.
append
(
dict
(
name
=
name
,
class_name
=
class_name
,
input_shapes
=
[
i
.
shape
for
i
in
input
],
output_shapes
=
[
o
.
shape
for
o
in
output
],
flops_num
=
flops_num
,
flops_cum
=
0
,
)
)
flops_stats
=
get_flops_stats
(
module
,
inputs
,
outputs
)
if
flops_stats
is
not
None
:
flops_stats
[
"name"
]
=
name
flops_stats
[
"class_name"
]
=
class_name
flops
.
append
(
flops_stats
)
if
hasattr
(
module
,
"weight"
)
and
module
.
weight
is
not
None
:
w
=
module
.
weight
...
...
@@ -278,19 +342,22 @@ def module_stats(
for
h
in
hooks
:
h
.
remove
()
total_flops
,
total_params
=
0
,
0
extra_info
=
{
"#params"
:
len
(
params
),
}
total_flops
,
total_param_dims
,
total_param_size
=
0
,
0
,
0
if
log_params
:
total_param_dims
,
total_param_size
=
print_params_stats
(
params
,
bar_length_max
)
extra_info
[
"total_param_dims"
]
=
sizeof_fmt
(
total_param_dims
)
extra_info
[
"total_param_size"
]
=
sizeof_fmt
(
total_param_size
)
if
log_flops
:
total_flops
=
print_flops_stats
(
flops
,
bar_length_max
)
extra_info
[
"total_flops"
]
=
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
)
if
log_params
and
log_flops
:
extra_info
[
"flops/param_size"
]
=
"{:3.3f}"
.
format
(
total_flops
/
total_param_size
)
extra_info
=
{
"#params"
:
len
(
params
),
"total_param_dims"
:
sizeof_fmt
(
total_param_dims
),
"total_param_size"
:
sizeof_fmt
(
total_param_size
),
"total_flops"
:
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
),
"flops/param_size"
:
"{:3.3f}"
.
format
(
total_flops
/
total_param_size
),
}
print_summary
(
**
extra_info
)
return
total_param
s
,
total_flops
return
total_param
_size
,
total_flops
imperative/python/megengine/utils/network_node.py
浏览文件 @
b10238ac
...
...
@@ -18,6 +18,11 @@ from ..core.ops import builtin
from
..core.tensor.megbrain_graph
import
InputNode
from
..tensor
import
Tensor
from
.comp_graph_tools
import
replace_vars
from
.module_stats
import
(
preprocess_receptive_field
,
register_flops
,
register_receptive_field
,
)
class
NetworkNode
:
...
...
@@ -225,8 +230,21 @@ class Elemwise(OpNode):
type
=
"Elemwise"
opdef
=
builtin
.
Elemwise
def
calc_flops
(
self
):
return
np
.
prod
(
self
.
outputs
[
0
].
shape
)
class
ElemwiseMultiType
(
OpNode
):
type
=
"ElemwiseMultiType"
opdef
=
builtin
.
ElemwiseMultiType
@
classmethod
def
load
(
cls
,
opr
):
obj
=
super
(
ElemwiseMultiType
,
cls
).
load
(
opr
)
obj
.
params
[
"dtype"
]
=
opr
.
outputs
[
0
].
dtype
return
obj
@
register_flops
(
Elemwise
,
ElemwiseMultiType
)
def
flops_elemwise
(
opnode
:
Elemwise
,
inputs
,
outputs
):
return
np
.
prod
(
outputs
[
0
].
shape
)
class
Reduce
(
OpNode
):
...
...
@@ -255,20 +273,24 @@ class MatrixMul(OpNode):
type
=
"MatrixMul"
opdef
=
builtin
.
MatrixMul
def
calc_flops
(
self
):
assert
len
(
self
.
inputs
[
0
].
shape
)
==
2
and
len
(
self
.
outputs
[
0
].
shape
)
==
2
mid_shape
=
self
.
inputs
[
0
].
shape
[
1
]
return
np
.
prod
(
self
.
outputs
[
0
].
shape
)
*
mid_shape
@
register_flops
(
MatrixMul
)
def
flops_matmul
(
opnode
:
MatrixMul
,
inputs
,
outputs
):
assert
len
(
inputs
[
0
].
shape
)
==
2
and
len
(
outputs
[
0
].
shape
)
==
2
mid_shape
=
inputs
[
0
].
shape
[
1
]
return
np
.
prod
(
outputs
[
0
].
shape
)
*
mid_shape
class
BatchedMatrixMul
(
OpNode
):
type
=
"BatchedMatmul"
opdef
=
builtin
.
BatchedMatrixMul
def
calc_flops
(
self
):
assert
len
(
self
.
inputs
[
0
].
shape
)
==
3
and
len
(
self
.
outputs
[
0
].
shape
)
==
3
mid_shape
=
self
.
inputs
[
0
].
shape
[
2
]
return
np
.
prod
(
self
.
outputs
[
0
].
shape
)
*
mid_shape
@
register_flops
(
BatchedMatrixMul
)
def
flops_batchmatmul
(
opnode
:
BatchedMatrixMul
,
inputs
,
outputs
):
assert
len
(
inputs
[
0
].
shape
)
==
3
and
len
(
outputs
[
0
].
shape
)
==
3
mid_shape
=
inputs
[
0
].
shape
[
2
]
return
np
.
prod
(
outputs
[
0
].
shape
)
*
mid_shape
class
Dot
(
OpNode
):
...
...
@@ -285,18 +307,6 @@ class ConvolutionForward(OpNode):
type
=
"Convolution"
opdef
=
builtin
.
Convolution
def
calc_flops
(
self
):
param_W_shape
=
self
.
inputs
[
1
].
shape
kh
=
param_W_shape
[
-
2
]
kw
=
param_W_shape
[
-
1
]
if
len
(
param_W_shape
)
==
5
:
num_input
=
param_W_shape
[
2
]
else
:
num_input
=
param_W_shape
[
1
]
NCHW
=
np
.
prod
(
self
.
outputs
[
0
].
shape
)
# N x Cout x H x W x (Cin x Kw x Kh)
return
NCHW
*
(
num_input
*
kw
*
kh
)
class
ConvolutionBackwardData
(
OpNode
):
type
=
"ConvTranspose"
...
...
@@ -343,17 +353,41 @@ class ConvBiasForward(OpNode):
obj
.
params
[
"dtype"
]
=
opr
.
outputs
[
0
].
dtype
return
obj
def
calc_flops
(
self
):
param_W_shape
=
self
.
inputs
[
1
].
shape
kh
=
param_W_shape
[
-
2
]
kw
=
param_W_shape
[
-
1
]
if
len
(
param_W_shape
)
==
5
:
num_input
=
param_W_shape
[
2
]
else
:
num_input
=
param_W_shape
[
1
]
NCHW
=
np
.
prod
(
self
.
outputs
[
0
].
shape
)
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return
NCHW
*
(
num_input
*
kw
*
kh
+
1
)
@
register_flops
(
ConvolutionForward
,
ConvBiasForward
,
)
def
flops_conv
(
opnode
:
ConvolutionForward
,
inputs
,
outputs
):
param_W_shape
=
inputs
[
1
].
shape
kh
=
param_W_shape
[
-
2
]
kw
=
param_W_shape
[
-
1
]
if
len
(
param_W_shape
)
==
5
:
num_input
=
param_W_shape
[
2
]
else
:
num_input
=
param_W_shape
[
1
]
NCHW
=
np
.
prod
(
outputs
[
0
].
shape
)
bias
=
1
if
isinstance
(
opnode
,
ConvBiasForward
)
else
0
# N x Cout x H x W x (Cin x Kw x Kh)
return
NCHW
*
(
num_input
*
kw
*
kh
+
bias
)
@
register_receptive_field
(
ConvolutionForward
,
ConvBiasForward
)
def
receptive_field
(
opnode
:
ConvolutionForward
,
inputs
,
outputs
):
pre_rf
,
pre_stride
=
preprocess_receptive_field
(
opnode
,
inputs
,
outputs
)
param_W_shape
=
inputs
[
1
].
shape
kh
=
param_W_shape
[
-
2
]
kw
=
param_W_shape
[
-
1
]
rf
=
(
kh
*
pre_stride
[
0
]
+
pre_rf
[
0
]
-
pre_stride
[
0
],
kw
*
pre_stride
[
1
]
+
pre_rf
[
1
]
-
pre_stride
[
1
],
)
stride
=
(
opnode
.
params
[
"stride_h"
]
*
pre_stride
[
0
],
opnode
.
params
[
"stride_w"
]
*
pre_stride
[
1
],
)
opnode
.
_rf
=
rf
opnode
.
_stride
=
stride
return
rf
,
stride
class
BatchConvBiasForward
(
OpNode
):
...
...
@@ -652,20 +686,6 @@ class AssertEqual(OpNode):
opdef
=
builtin
.
AssertEqual
class
ElemwiseMultiType
(
OpNode
):
type
=
"ElemwiseMultiType"
opdef
=
builtin
.
ElemwiseMultiType
@
classmethod
def
load
(
cls
,
opr
):
obj
=
super
(
ElemwiseMultiType
,
cls
).
load
(
opr
)
obj
.
params
[
"dtype"
]
=
opr
.
outputs
[
0
].
dtype
return
obj
def
calc_flops
(
self
):
return
np
.
prod
(
self
.
outputs
[
0
].
shape
)
class
CvtColorForward
(
OpNode
):
type
=
"CvtColor"
opdef
=
builtin
.
CvtColor
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录