提交 6bb9a255 编写于 作者: M Megvii Engine Team

fix(mge/tools): fix module stats' receptive field bug for Module

GitOrigin-RevId: b4713638304205c94927d6802e858633343e9d27
上级 acf28603
...@@ -15,10 +15,11 @@ import numpy as np ...@@ -15,10 +15,11 @@ import numpy as np
from megengine.core.tensor.dtype import is_quantize from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import ( from megengine.utils.module_stats import (
get_flops_stats, enable_receptive_field,
get_op_stats,
get_param_stats, get_param_stats,
print_flops_stats, print_op_stats,
print_params_stats, print_param_stats,
print_summary, print_summary,
sizeof_fmt, sizeof_fmt,
) )
...@@ -68,6 +69,8 @@ def visualize( ...@@ -68,6 +69,8 @@ def visualize(
# FIXME: remove this after resolving "span dist too large" warning # FIXME: remove this after resolving "span dist too large" warning
old_level = set_mgb_log_level(logging.ERROR) old_level = set_mgb_log_level(logging.ERROR)
enable_receptive_field()
graph = Network.load(model_path) graph = Network.load(model_path)
def process_name(name): def process_name(name):
...@@ -110,7 +113,7 @@ def visualize( ...@@ -110,7 +113,7 @@ def visualize(
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")), "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
} }
flops_stats = get_flops_stats(node, node.inputs, node.outputs) flops_stats = get_op_stats(node, node.inputs, node.outputs)
if flops_stats is not None: if flops_stats is not None:
# add op flops attr # add op flops attr
if log_path and hasattr(flops_stats, "flops_num"): if log_path and hasattr(flops_stats, "flops_num"):
...@@ -148,13 +151,13 @@ def visualize( ...@@ -148,13 +151,13 @@ def visualize(
total_flops, total_param_dims, total_param_size = 0, 0, 0 total_flops, total_param_dims, total_param_size = 0, 0, 0
if log_params: if log_params:
total_param_dims, total_param_size = print_params_stats( total_param_dims, total_param_size = print_param_stats(
params_list, bar_length_max params_list, bar_length_max
) )
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims) extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size) extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_flops: if log_flops:
total_flops = print_flops_stats(flops_list, bar_length_max) total_flops = print_op_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops: if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format( extra_info["flops/param_size"] = "{:3.3f}".format(
......
...@@ -31,6 +31,8 @@ _calc_receptive_field_dict = {} ...@@ -31,6 +31,8 @@ _calc_receptive_field_dict = {}
def _receptive_field_fallback(module, inputs, outputs): def _receptive_field_fallback(module, inputs, outputs):
if not _receptive_field_enabled:
return
assert not hasattr(module, "_rf") assert not hasattr(module, "_rf")
assert not hasattr(module, "_stride") assert not hasattr(module, "_stride")
if len(inputs) == 0: if len(inputs) == 0:
...@@ -54,6 +56,8 @@ _iter_list = [ ...@@ -54,6 +56,8 @@ _iter_list = [
), ),
] ]
_receptive_field_enabled = False
def _register_dict(*modules, dict=None): def _register_dict(*modules, dict=None):
def callback(impl): def callback(impl):
...@@ -72,6 +76,16 @@ def register_receptive_field(*modules): ...@@ -72,6 +76,16 @@ def register_receptive_field(*modules):
return _register_dict(*modules, dict=_calc_receptive_field_dict) return _register_dict(*modules, dict=_calc_receptive_field_dict)
def enable_receptive_field():
global _receptive_field_enabled
_receptive_field_enabled = True
def disable_receptive_field():
global _receptive_field_enabled
_receptive_field_enabled = False
@register_flops( @register_flops(
m.Conv1d, m.Conv2d, m.Conv3d, m.Conv1d, m.Conv2d, m.Conv3d,
) )
...@@ -144,16 +158,16 @@ def preprocess_receptive_field(module, inputs, outputs): ...@@ -144,16 +158,16 @@ def preprocess_receptive_field(module, inputs, outputs):
# TODO: support other dimensions # TODO: support other dimensions
pre_rf = ( pre_rf = (
max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs), max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs),
max(i.owner._rf[1] for i in inputs), max(getattr(i.owner, "_rf", (1, 1))[1] for i in inputs),
) )
pre_stride = ( pre_stride = (
max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs), max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs),
max(i.owner._stride[1] for i in inputs), max(getattr(i.owner, "_stride", (1, 1))[1] for i in inputs),
) )
return pre_rf, pre_stride return pre_rf, pre_stride
def get_flops_stats(module, inputs, outputs): def get_op_stats(module, inputs, outputs):
rst = { rst = {
"input_shapes": [i.shape for i in inputs], "input_shapes": [i.shape for i in inputs],
"output_shapes": [o.shape for o in outputs], "output_shapes": [o.shape for o in outputs],
...@@ -184,7 +198,7 @@ def get_flops_stats(module, inputs, outputs): ...@@ -184,7 +198,7 @@ def get_flops_stats(module, inputs, outputs):
return return
def print_flops_stats(flops, bar_length_max=20): def print_op_stats(flops, bar_length_max=20):
max_flops_num = max([i["flops_num"] for i in flops] + [0]) max_flops_num = max([i["flops_num"] for i in flops] + [0])
total_flops_num = 0 total_flops_num = 0
for d in flops: for d in flops:
...@@ -203,13 +217,14 @@ def print_flops_stats(flops, bar_length_max=20): ...@@ -203,13 +217,14 @@ def print_flops_stats(flops, bar_length_max=20):
"class_name", "class_name",
"input_shapes", "input_shapes",
"output_shapes", "output_shapes",
"receptive_field",
"stride",
"flops", "flops",
"flops_cum", "flops_cum",
"percentage", "percentage",
"bar", "bar",
] ]
if _receptive_field_enabled:
header.insert(4, "receptive_field")
header.insert(5, "stride")
total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs")
total_var_size = sum( total_var_size = sum(
...@@ -240,7 +255,7 @@ def get_param_stats(param: np.ndarray): ...@@ -240,7 +255,7 @@ def get_param_stats(param: np.ndarray):
} }
def print_params_stats(params, bar_length_max=20): def print_param_stats(params, bar_length_max=20):
max_size = max([d["size"] for d in params] + [0]) max_size = max([d["size"] for d in params] + [0])
total_param_dims, total_param_size = 0, 0 total_param_dims, total_param_size = 0, 0
for d in params: for d in params:
...@@ -302,11 +317,12 @@ def module_stats( ...@@ -302,11 +317,12 @@ def module_stats(
:param log_params: whether print and record params size. :param log_params: whether print and record params size.
:param log_flops: whether print and record op flops. :param log_flops: whether print and record op flops.
""" """
disable_receptive_field()
def module_stats_hook(module, inputs, outputs, name=""): def module_stats_hook(module, inputs, outputs, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0] class_name = str(module.__class__).split(".")[-1].split("'")[0]
flops_stats = get_flops_stats(module, inputs, outputs) flops_stats = get_op_stats(module, inputs, outputs)
if flops_stats is not None: if flops_stats is not None:
flops_stats["name"] = name flops_stats["name"] = name
flops_stats["class_name"] = class_name flops_stats["class_name"] = class_name
...@@ -349,11 +365,11 @@ def module_stats( ...@@ -349,11 +365,11 @@ def module_stats(
} }
total_flops, total_param_dims, total_param_size = 0, 0, 0 total_flops, total_param_dims, total_param_size = 0, 0, 0
if log_params: if log_params:
total_param_dims, total_param_size = print_params_stats(params, bar_length_max) total_param_dims, total_param_size = print_param_stats(params, bar_length_max)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims) extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size) extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_flops: if log_flops:
total_flops = print_flops_stats(flops, bar_length_max) total_flops = print_op_stats(flops, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops: if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format( extra_info["flops/param_size"] = "{:3.3f}".format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册