提交 6bb9a255 编写于 作者: M Megvii Engine Team

fix(mge/tools): fix module stats' receptive field bug for Module

GitOrigin-RevId: b4713638304205c94927d6802e858633343e9d27
上级 acf28603
......@@ -15,10 +15,11 @@ import numpy as np
from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import (
get_flops_stats,
enable_receptive_field,
get_op_stats,
get_param_stats,
print_flops_stats,
print_params_stats,
print_op_stats,
print_param_stats,
print_summary,
sizeof_fmt,
)
......@@ -68,6 +69,8 @@ def visualize(
# FIXME: remove this after resolving "span dist too large" warning
old_level = set_mgb_log_level(logging.ERROR)
enable_receptive_field()
graph = Network.load(model_path)
def process_name(name):
......@@ -110,7 +113,7 @@ def visualize(
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
}
flops_stats = get_flops_stats(node, node.inputs, node.outputs)
flops_stats = get_op_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr
if log_path and hasattr(flops_stats, "flops_num"):
......@@ -148,13 +151,13 @@ def visualize(
total_flops, total_param_dims, total_param_size = 0, 0, 0
if log_params:
total_param_dims, total_param_size = print_params_stats(
total_param_dims, total_param_size = print_param_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_flops:
total_flops = print_flops_stats(flops_list, bar_length_max)
total_flops = print_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(
......
......@@ -31,6 +31,8 @@ _calc_receptive_field_dict = {}
def _receptive_field_fallback(module, inputs, outputs):
if not _receptive_field_enabled:
return
assert not hasattr(module, "_rf")
assert not hasattr(module, "_stride")
if len(inputs) == 0:
......@@ -54,6 +56,8 @@ _iter_list = [
),
]
_receptive_field_enabled = False
def _register_dict(*modules, dict=None):
def callback(impl):
......@@ -72,6 +76,16 @@ def register_receptive_field(*modules):
return _register_dict(*modules, dict=_calc_receptive_field_dict)
def enable_receptive_field():
global _receptive_field_enabled
_receptive_field_enabled = True
def disable_receptive_field():
global _receptive_field_enabled
_receptive_field_enabled = False
@register_flops(
m.Conv1d, m.Conv2d, m.Conv3d,
)
......@@ -144,16 +158,16 @@ def preprocess_receptive_field(module, inputs, outputs):
# TODO: support other dimensions
pre_rf = (
max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs),
max(i.owner._rf[1] for i in inputs),
max(getattr(i.owner, "_rf", (1, 1))[1] for i in inputs),
)
pre_stride = (
max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs),
max(i.owner._stride[1] for i in inputs),
max(getattr(i.owner, "_stride", (1, 1))[1] for i in inputs),
)
return pre_rf, pre_stride
def get_flops_stats(module, inputs, outputs):
def get_op_stats(module, inputs, outputs):
rst = {
"input_shapes": [i.shape for i in inputs],
"output_shapes": [o.shape for o in outputs],
......@@ -184,7 +198,7 @@ def get_flops_stats(module, inputs, outputs):
return
def print_flops_stats(flops, bar_length_max=20):
def print_op_stats(flops, bar_length_max=20):
max_flops_num = max([i["flops_num"] for i in flops] + [0])
total_flops_num = 0
for d in flops:
......@@ -203,13 +217,14 @@ def print_flops_stats(flops, bar_length_max=20):
"class_name",
"input_shapes",
"output_shapes",
"receptive_field",
"stride",
"flops",
"flops_cum",
"percentage",
"bar",
]
if _receptive_field_enabled:
header.insert(4, "receptive_field")
header.insert(5, "stride")
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum(
......@@ -240,7 +255,7 @@ def get_param_stats(param: np.ndarray):
}
def print_params_stats(params, bar_length_max=20):
def print_param_stats(params, bar_length_max=20):
max_size = max([d["size"] for d in params] + [0])
total_param_dims, total_param_size = 0, 0
for d in params:
......@@ -302,11 +317,12 @@ def module_stats(
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
"""
disable_receptive_field()
def module_stats_hook(module, inputs, outputs, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0]
flops_stats = get_flops_stats(module, inputs, outputs)
flops_stats = get_op_stats(module, inputs, outputs)
if flops_stats is not None:
flops_stats["name"] = name
flops_stats["class_name"] = class_name
......@@ -349,11 +365,11 @@ def module_stats(
}
total_flops, total_param_dims, total_param_size = 0, 0, 0
if log_params:
total_param_dims, total_param_size = print_params_stats(params, bar_length_max)
total_param_dims, total_param_size = print_param_stats(params, bar_length_max)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_flops:
total_flops = print_flops_stats(flops, bar_length_max)
total_flops = print_op_stats(flops, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册