Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
84c2a5c2
MegEngine
项目概览
MegEngine 天元
/
MegEngine
9 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
84c2a5c2
编写于
3月 26, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/tools): add summary print for module_stats and network_visualize
GitOrigin-RevId: 7d85aa0ea2cc349369bb295d76db38a8748314ad
上级
edea528b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
74 addition
and
68 deletion
+74
-68
imperative/python/megengine/core/tensor/dtype.py
imperative/python/megengine/core/tensor/dtype.py
+7
-0
imperative/python/megengine/tools/network_visualize.py
imperative/python/megengine/tools/network_visualize.py
+20
-17
imperative/python/megengine/utils/module_stats.py
imperative/python/megengine/utils/module_stats.py
+47
-51
未找到文件。
imperative/python/megengine/core/tensor/dtype.py
浏览文件 @
84c2a5c2
...
...
@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
re
from
collections
import
namedtuple
from
typing
import
Union
...
...
@@ -22,6 +23,12 @@ from .._imperative_rt.common import (
)
def
get_dtype_bit
(
dtype_name
:
str
):
numbers
=
re
.
findall
(
r
"\d+"
,
dtype_name
)
assert
len
(
numbers
)
==
1
,
"Unsupport dtype name with more than one number."
return
int
(
numbers
[
0
])
# normal dtype related
def
is_lowbit
(
dtype
):
return
(
dtype
is
intb1
)
or
(
dtype
is
intb2
)
or
(
dtype
is
intb4
)
...
...
imperative/python/megengine/tools/network_visualize.py
浏览文件 @
84c2a5c2
...
...
@@ -14,8 +14,10 @@ import numpy as np
from
megengine.core.tensor.dtype
import
is_quantize
from
megengine.logger
import
_imperative_rt_logger
,
get_logger
,
set_mgb_log_level
from
megengine.utils.module_stats
import
(
get_param_stats
,
print_flops_stats
,
print_params_stats
,
print_summary
,
sizeof_fmt
,
)
from
megengine.utils.network
import
Network
...
...
@@ -69,6 +71,7 @@ def visualize(
def
process_name
(
name
):
return
name
.
replace
(
"."
,
"/"
).
encode
(
encoding
=
"utf-8"
)
summary
=
[[
"item"
,
"value"
]]
node_list
=
[]
flops_list
=
[]
params_list
=
[]
...
...
@@ -117,26 +120,15 @@ def visualize(
)
)
if
node
.
type
==
"ImmutableTensor"
:
param_dim
=
np
.
prod
(
node_oup
.
shape
)
# TODO: consider other quantize dtypes
param_bytes
=
1
if
is_quantize
(
node_oup
.
dtype
)
else
4
param_stats
=
get_param_stats
(
node
.
numpy
())
# add tensor size attr
if
log_path
:
attr
[
"size"
]
=
AttrValue
(
s
=
sizeof_fmt
(
param_
dim
*
param_bytes
).
encode
(
encoding
=
"utf-8"
)
s
=
sizeof_fmt
(
param_
stats
[
"size"
]
).
encode
(
encoding
=
"utf-8"
)
)
params_list
.
append
(
dict
(
name
=
node
.
name
,
shape
=
node_oup
.
shape
,
param_dim
=
param_dim
,
bits
=
param_bytes
*
8
,
size
=
param_dim
*
param_bytes
,
size_cum
=
0
,
mean
=
"{:.2g}"
.
format
(
node
.
numpy
().
mean
()),
std
=
"{:.2g}"
.
format
(
node
.
numpy
().
std
()),
)
)
param_stats
[
"name"
]
=
node
.
name
params_list
.
append
(
param_stats
)
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if
not
len
(
node
.
name
.
split
(
"."
))
>
2
and
not
node
in
graph
.
input_vars
:
continue
...
...
@@ -152,7 +144,9 @@ def visualize(
total_flops
,
total_params
=
None
,
None
if
log_params
:
total_params
=
print_params_stats
(
params_list
,
bar_length_max
)
total_param_dims
,
total_param_size
=
print_params_stats
(
params_list
,
bar_length_max
)
if
log_flops
:
total_flops
=
print_flops_stats
(
flops_list
,
bar_length_max
)
...
...
@@ -167,6 +161,15 @@ def visualize(
writer
.
_get_file_writer
().
add_graph
((
graph_def
,
stepstats
))
# summary
extra_info
=
{
"#ops"
:
len
(
graph
.
all_oprs
),
"#params"
:
len
(
params_list
),
"total_param_dims"
:
sizeof_fmt
(
total_param_dims
),
"total_param_size"
:
sizeof_fmt
(
total_param_size
),
"total_flops"
:
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
),
"flops/param_size"
:
"{:3.3f}"
.
format
(
total_flops
/
total_param_size
),
}
print_summary
(
**
extra_info
)
# FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger
.
set_log_level
(
old_level
)
...
...
imperative/python/megengine/utils/module_stats.py
浏览文件 @
84c2a5c2
...
...
@@ -11,10 +11,10 @@ import numpy as np
import
tabulate
import
megengine
as
mge
import
megengine.core.tensor.dtype
as
dtype
import
megengine.module
as
m
import
megengine.module.qat
as
qatm
import
megengine.module.quantized
as
qm
from
megengine.core.tensor.dtype
import
get_dtype_bit
from
megengine.functional.tensor
import
zeros
try
:
...
...
@@ -115,13 +115,13 @@ def print_flops_stats(flops, bar_length_max=20):
total_flops_num
+=
int
(
d
[
"flops_num"
])
d
[
"flops_cum"
]
=
sizeof_fmt
(
total_flops_num
,
suffix
=
"OPs"
)
for
i
in
flops
:
f
=
i
[
"flops_num"
]
i
[
"flops"
]
=
sizeof_fmt
(
f
,
suffix
=
"OPs"
)
r
=
i
[
"ratio"
]
=
f
/
total_flops_num
i
[
"percentage"
]
=
"{:.2f}%"
.
format
(
r
*
100
)
for
d
in
flops
:
f
=
d
[
"flops_num"
]
d
[
"flops"
]
=
sizeof_fmt
(
f
,
suffix
=
"OPs"
)
r
=
d
[
"ratio"
]
=
f
/
total_flops_num
d
[
"percentage"
]
=
"{:.2f}%"
.
format
(
r
*
100
)
bar_length
=
int
(
f
/
max_flops_num
*
bar_length_max
)
i
[
"bar"
]
=
"#"
*
bar_length
d
[
"bar"
]
=
"#"
*
bar_length
header
=
[
"name"
,
...
...
@@ -136,7 +136,7 @@ def print_flops_stats(flops, bar_length_max=20):
total_flops_str
=
sizeof_fmt
(
total_flops_num
,
suffix
=
"OPs"
)
total_var_size
=
sum
(
sum
(
s
[
1
]
if
len
(
s
)
>
1
else
0
for
s
in
i
[
"output_shapes"
])
for
i
in
flops
sum
(
s
[
1
]
if
len
(
s
)
>
1
else
0
for
s
in
d
[
"output_shapes"
])
for
d
in
flops
)
flops
.
append
(
dict
(
name
=
"total"
,
flops
=
total_flops_str
,
output_shapes
=
total_var_size
)
...
...
@@ -147,16 +147,29 @@ def print_flops_stats(flops, bar_length_max=20):
return
total_flops_num
def
get_param_stats
(
param
:
np
.
ndarray
):
nbits
=
get_dtype_bit
(
param
.
dtype
.
name
)
shape
=
param
.
shape
param_dim
=
np
.
prod
(
param
.
shape
)
param_size
=
param_dim
*
nbits
//
8
return
{
"shape"
:
shape
,
"mean"
:
param
.
mean
(),
"std"
:
param
.
std
(),
"param_dim"
:
param_dim
,
"nbits"
:
nbits
,
"size"
:
param_size
,
}
def
print_params_stats
(
params
,
bar_length_max
=
20
):
total_param_dims
,
total_param_size
=
0
,
0
for
d
in
params
:
total_param_dims
+=
int
(
d
[
"param_dim"
])
total_param_size
+=
int
(
d
[
"size"
])
ratio
=
d
[
"size"
]
/
total_param_size
d
[
"size"
]
=
sizeof_fmt
(
d
[
"size"
])
d
[
"size_cum"
]
=
sizeof_fmt
(
total_param_size
)
for
d
in
params
:
ratio
=
d
[
"param_dim"
]
/
total_param_dims
d
[
"ratio"
]
=
ratio
d
[
"percentage"
]
=
"{:.2f}%"
.
format
(
ratio
*
100
)
...
...
@@ -186,7 +199,13 @@ def print_params_stats(params, bar_length_max=20):
"param stats:
\n
"
+
tabulate
.
tabulate
(
dict2table
(
params
,
header
=
header
))
)
return
total_param_size
return
total_param_dims
,
total_param_size
def
print_summary
(
**
kwargs
):
data
=
[[
"item"
,
"value"
]]
data
.
extend
(
list
(
kwargs
.
items
()))
logger
.
info
(
"summary
\n
"
+
tabulate
.
tabulate
(
data
))
def
module_stats
(
...
...
@@ -206,14 +225,6 @@ def module_stats(
:param log_flops: whether print and record op flops.
"""
def
get_byteswidth
(
tensor
):
if
dtype
.
is_quantize
(
tensor
.
dtype
):
return
1
# elif dtype.is_bfloat16(tensor.dtype):
# return 2
else
:
return
4
def
module_stats_hook
(
module
,
input
,
output
,
name
=
""
):
class_name
=
str
(
module
.
__class__
).
split
(
"."
)[
-
1
].
split
(
"'"
)[
0
]
...
...
@@ -237,39 +248,15 @@ def module_stats(
if
hasattr
(
module
,
"weight"
)
and
module
.
weight
is
not
None
:
w
=
module
.
weight
value
=
w
.
numpy
()
param_dim
=
np
.
prod
(
w
.
shape
)
param_bytes
=
get_byteswidth
(
w
)
params
.
append
(
dict
(
name
=
name
+
"-w"
,
shape
=
w
.
shape
,
param_dim
=
param_dim
,
bits
=
param_bytes
*
8
,
size
=
param_dim
*
param_bytes
,
size_cum
=
0
,
mean
=
"{:.2g}"
.
format
(
value
.
mean
()),
std
=
"{:.2g}"
.
format
(
value
.
std
()),
)
)
param_stats
=
get_param_stats
(
w
.
numpy
())
param_stats
[
"name"
]
=
name
+
"-w"
params
.
append
(
param_stats
)
if
hasattr
(
module
,
"bias"
)
and
module
.
bias
is
not
None
:
b
=
module
.
bias
value
=
b
.
numpy
()
param_dim
=
np
.
prod
(
b
.
shape
)
param_bytes
=
get_byteswidth
(
b
)
params
.
append
(
dict
(
name
=
name
+
"-b"
,
shape
=
b
.
shape
,
param_dim
=
param_dim
,
bits
=
param_bytes
*
8
,
size
=
param_dim
*
param_bytes
,
size_cum
=
0
,
mean
=
"{:.2g}"
.
format
(
value
.
mean
()),
std
=
"{:.2g}"
.
format
(
value
.
std
()),
)
)
param_stats
=
get_param_stats
(
b
.
numpy
())
param_stats
[
"name"
]
=
name
+
"-b"
params
.
append
(
param_stats
)
# multiple inputs to the network
if
not
isinstance
(
input_size
[
0
],
tuple
):
...
...
@@ -293,8 +280,17 @@ def module_stats(
total_flops
,
total_params
=
0
,
0
if
log_params
:
total_param
s
=
print_params_stats
(
params
,
bar_length_max
)
total_param
_dims
,
total_param_size
=
print_params_stats
(
params
,
bar_length_max
)
if
log_flops
:
total_flops
=
print_flops_stats
(
flops
,
bar_length_max
)
extra_info
=
{
"#params"
:
len
(
params
),
"total_param_dims"
:
sizeof_fmt
(
total_param_dims
),
"total_param_size"
:
sizeof_fmt
(
total_param_size
),
"total_flops"
:
sizeof_fmt
(
total_flops
,
suffix
=
"OPs"
),
"flops/param_size"
:
"{:3.3f}"
.
format
(
total_flops
/
total_param_size
),
}
print_summary
(
**
extra_info
)
return
total_params
,
total_flops
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录