Commit 84c2a5c2 authored by Megvii Engine Team

feat(mge/tools): add summary print for module_stats and network_visualize

GitOrigin-RevId: 7d85aa0ea2cc349369bb295d76db38a8748314ad
Parent edea528b
......@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import re
from collections import namedtuple
from typing import Union
......@@ -22,6 +23,12 @@ from .._imperative_rt.common import (
)
def get_dtype_bit(dtype_name: str):
numbers = re.findall(r"\d+", dtype_name)
assert len(numbers) == 1, "Unsupport dtype name with more than one number."
return int(numbers[0])
# normal dtype related
def is_lowbit(dtype):
return (dtype is intb1) or (dtype is intb2) or (dtype is intb4)
......
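A quick illustration (not part of this commit) of how the new get_dtype_bit helper behaves: it pulls the single number out of a dtype name and returns it as the bit width, so parameter sizes can be computed for low-bit and quantized dtypes as well as float32. The import path is the one used later in this diff; the dtype names are ordinary examples.

    # illustration only: bit widths recovered from common dtype names
    from megengine.core.tensor.dtype import get_dtype_bit

    assert get_dtype_bit("float32") == 32   # 4 bytes per element
    assert get_dtype_bit("float16") == 16   # 2 bytes per element
    assert get_dtype_bit("int8") == 8       # 1 byte per element
    # names containing more than one number trip the assert inside the helper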
......@@ -14,8 +14,10 @@ import numpy as np
from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import (
get_param_stats,
print_flops_stats,
print_params_stats,
print_summary,
sizeof_fmt,
)
from megengine.utils.network import Network
......@@ -69,6 +71,7 @@ def visualize(
def process_name(name):
return name.replace(".", "/").encode(encoding="utf-8")
summary = [["item", "value"]]
node_list = []
flops_list = []
params_list = []
......@@ -117,26 +120,15 @@ def visualize(
)
)
if node.type == "ImmutableTensor":
param_dim = np.prod(node_oup.shape)
# TODO: consider other quantize dtypes
param_bytes = 1 if is_quantize(node_oup.dtype) else 4
param_stats = get_param_stats(node.numpy())
# add tensor size attr
if log_path:
attr["size"] = AttrValue(
s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8")
s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
)
params_list.append(
dict(
name=node.name,
shape=node_oup.shape,
param_dim=param_dim,
bits=param_bytes * 8,
size=param_dim * param_bytes,
size_cum=0,
mean="{:.2g}".format(node.numpy().mean()),
std="{:.2g}".format(node.numpy().std()),
)
)
param_stats["name"] = node.name
params_list.append(param_stats)
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if not len(node.name.split(".")) > 2 and not node in graph.input_vars:
continue
......@@ -152,7 +144,9 @@ def visualize(
total_flops, total_params = None, None
if log_params:
total_params = print_params_stats(params_list, bar_length_max)
total_param_dims, total_param_size = print_params_stats(
params_list, bar_length_max
)
if log_flops:
total_flops = print_flops_stats(flops_list, bar_length_max)
......@@ -167,6 +161,15 @@ def visualize(
writer._get_file_writer().add_graph((graph_def, stepstats))
# summary
extra_info = {
"#ops": len(graph.all_oprs),
"#params": len(params_list),
"total_param_dims": sizeof_fmt(total_param_dims),
"total_param_size": sizeof_fmt(total_param_size),
"total_flops": sizeof_fmt(total_flops, suffix="OPs"),
"flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
}
print_summary(**extra_info)
# FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger.set_log_level(old_level)
......
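For reference, a minimal sketch (not part of the commit; all numbers below are hypothetical) of what the print_summary call added above reduces to: it logs a two-column item/value table, where "#ops" counts graph.all_oprs, "#params" counts entries in params_list, and flops/param_size is the raw FLOPs count divided by the total parameter size in bytes.

    # hypothetical values, purely to show the item/value summary layout
    from megengine.utils.module_stats import print_summary, sizeof_fmt

    total_flops, total_param_size = 1_820_000_000, 46_760_000
    print_summary(
        **{
            "#ops": 187,
            "#params": 54,
            "total_param_dims": sizeof_fmt(11_690_000),
            "total_param_size": sizeof_fmt(total_param_size),
            "total_flops": sizeof_fmt(total_flops, suffix="OPs"),
            "flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
        }
    )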
......@@ -11,10 +11,10 @@ import numpy as np
import tabulate
import megengine as mge
import megengine.core.tensor.dtype as dtype
import megengine.module as m
import megengine.module.qat as qatm
import megengine.module.quantized as qm
from megengine.core.tensor.dtype import get_dtype_bit
from megengine.functional.tensor import zeros
try:
......@@ -115,13 +115,13 @@ def print_flops_stats(flops, bar_length_max=20):
total_flops_num += int(d["flops_num"])
d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")
for i in flops:
f = i["flops_num"]
i["flops"] = sizeof_fmt(f, suffix="OPs")
r = i["ratio"] = f / total_flops_num
i["percentage"] = "{:.2f}%".format(r * 100)
for d in flops:
f = d["flops_num"]
d["flops"] = sizeof_fmt(f, suffix="OPs")
r = d["ratio"] = f / total_flops_num
d["percentage"] = "{:.2f}%".format(r * 100)
bar_length = int(f / max_flops_num * bar_length_max)
i["bar"] = "#" * bar_length
d["bar"] = "#" * bar_length
header = [
"name",
......@@ -136,7 +136,7 @@ def print_flops_stats(flops, bar_length_max=20):
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum(
sum(s[1] if len(s) > 1 else 0 for s in i["output_shapes"]) for i in flops
sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops
)
flops.append(
dict(name="total", flops=total_flops_str, output_shapes=total_var_size)
......@@ -147,16 +147,29 @@ def print_flops_stats(flops, bar_length_max=20):
return total_flops_num
def get_param_stats(param: np.ndarray):
nbits = get_dtype_bit(param.dtype.name)
shape = param.shape
param_dim = np.prod(param.shape)
param_size = param_dim * nbits // 8
return {
"shape": shape,
"mean": param.mean(),
"std": param.std(),
"param_dim": param_dim,
"nbits": nbits,
"size": param_size,
}
def print_params_stats(params, bar_length_max=20):
total_param_dims, total_param_size = 0, 0
for d in params:
total_param_dims += int(d["param_dim"])
total_param_size += int(d["size"])
ratio = d["size"] / total_param_size
d["size"] = sizeof_fmt(d["size"])
d["size_cum"] = sizeof_fmt(total_param_size)
for d in params:
ratio = d["param_dim"] / total_param_dims
d["ratio"] = ratio
d["percentage"] = "{:.2f}%".format(ratio * 100)
......@@ -186,7 +199,13 @@ def print_params_stats(params, bar_length_max=20):
"param stats: \n" + tabulate.tabulate(dict2table(params, header=header))
)
return total_param_size
return total_param_dims, total_param_size
def print_summary(**kwargs):
data = [["item", "value"]]
data.extend(list(kwargs.items()))
logger.info("summary\n" + tabulate.tabulate(data))
def module_stats(
......@@ -206,14 +225,6 @@ def module_stats(
:param log_flops: whether print and record op flops.
"""
def get_byteswidth(tensor):
if dtype.is_quantize(tensor.dtype):
return 1
# elif dtype.is_bfloat16(tensor.dtype):
# return 2
else:
return 4
def module_stats_hook(module, input, output, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0]
......@@ -237,39 +248,15 @@ def module_stats(
if hasattr(module, "weight") and module.weight is not None:
w = module.weight
value = w.numpy()
param_dim = np.prod(w.shape)
param_bytes = get_byteswidth(w)
params.append(
dict(
name=name + "-w",
shape=w.shape,
param_dim=param_dim,
bits=param_bytes * 8,
size=param_dim * param_bytes,
size_cum=0,
mean="{:.2g}".format(value.mean()),
std="{:.2g}".format(value.std()),
)
)
param_stats = get_param_stats(w.numpy())
param_stats["name"] = name + "-w"
params.append(param_stats)
if hasattr(module, "bias") and module.bias is not None:
b = module.bias
value = b.numpy()
param_dim = np.prod(b.shape)
param_bytes = get_byteswidth(b)
params.append(
dict(
name=name + "-b",
shape=b.shape,
param_dim=param_dim,
bits=param_bytes * 8,
size=param_dim * param_bytes,
size_cum=0,
mean="{:.2g}".format(value.mean()),
std="{:.2g}".format(value.std()),
)
)
param_stats = get_param_stats(b.numpy())
param_stats["name"] = name + "-b"
params.append(param_stats)
# multiple inputs to the network
if not isinstance(input_size[0], tuple):
......@@ -293,8 +280,17 @@ def module_stats(
total_flops, total_params = 0, 0
if log_params:
total_params = print_params_stats(params, bar_length_max)
total_param_dims, total_param_size = print_params_stats(params, bar_length_max)
if log_flops:
total_flops = print_flops_stats(flops, bar_length_max)
extra_info = {
"#params": len(params),
"total_param_dims": sizeof_fmt(total_param_dims),
"total_param_size": sizeof_fmt(total_param_size),
"total_flops": sizeof_fmt(total_flops, suffix="OPs"),
"flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
}
print_summary(**extra_info)
return total_params, total_flops
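To make the refactor concrete, a small sketch (assuming this commit's megengine is installed; the weight shape is made up) of what get_param_stats now returns, replacing the hand-rolled param_dim/param_bytes bookkeeping that module_stats and network_visualize used to duplicate:

    # sketch: what get_param_stats reports for a conv-like float32 weight
    import numpy as np
    from megengine.utils.module_stats import get_param_stats

    w = np.zeros((64, 3, 3, 3), dtype=np.float32)
    stats = get_param_stats(w)
    assert stats["nbits"] == 32                  # from get_dtype_bit("float32")
    assert stats["param_dim"] == 64 * 3 * 3 * 3  # 1728 elements
    assert stats["size"] == 1728 * 32 // 8       # 6912 bytes
    # module_stats/visualize then set stats["name"] and append the dict to params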