From 6bb9a2556ff73f881510b9e0b09d14d9f18b2950 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 8 Apr 2021 15:54:54 +0800 Subject: [PATCH] fix(mge/tools): fix module stats' receptive field bug for Module GitOrigin-RevId: b4713638304205c94927d6802e858633343e9d27 --- .../megengine/tools/network_visualize.py | 15 ++++---- .../python/megengine/utils/module_stats.py | 36 +++++++++++++------ 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/imperative/python/megengine/tools/network_visualize.py b/imperative/python/megengine/tools/network_visualize.py index c3bc8b4a..fc67d753 100755 --- a/imperative/python/megengine/tools/network_visualize.py +++ b/imperative/python/megengine/tools/network_visualize.py @@ -15,10 +15,11 @@ import numpy as np from megengine.core.tensor.dtype import is_quantize from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level from megengine.utils.module_stats import ( - get_flops_stats, + enable_receptive_field, + get_op_stats, get_param_stats, - print_flops_stats, - print_params_stats, + print_op_stats, + print_param_stats, print_summary, sizeof_fmt, ) @@ -68,6 +69,8 @@ def visualize( # FIXME: remove this after resolving "span dist too large" warning old_level = set_mgb_log_level(logging.ERROR) + enable_receptive_field() + graph = Network.load(model_path) def process_name(name): @@ -110,7 +113,7 @@ def visualize( "params": AttrValue(s=str(node.params).encode(encoding="utf-8")), "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")), } - flops_stats = get_flops_stats(node, node.inputs, node.outputs) + flops_stats = get_op_stats(node, node.inputs, node.outputs) if flops_stats is not None: # add op flops attr if log_path and hasattr(flops_stats, "flops_num"): @@ -148,13 +151,13 @@ def visualize( total_flops, total_param_dims, total_param_size = 0, 0, 0 if log_params: - total_param_dims, total_param_size = print_params_stats( + total_param_dims, total_param_size = print_param_stats( params_list, bar_length_max ) extra_info["total_param_dims"] = sizeof_fmt(total_param_dims) extra_info["total_param_size"] = sizeof_fmt(total_param_size) if log_flops: - total_flops = print_flops_stats(flops_list, bar_length_max) + total_flops = print_op_stats(flops_list, bar_length_max) extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") if log_params and log_flops: extra_info["flops/param_size"] = "{:3.3f}".format( diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index fa7813ad..c7b2a37c 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -31,6 +31,8 @@ _calc_receptive_field_dict = {} def _receptive_field_fallback(module, inputs, outputs): + if not _receptive_field_enabled: + return assert not hasattr(module, "_rf") assert not hasattr(module, "_stride") if len(inputs) == 0: @@ -54,6 +56,8 @@ _iter_list = [ ), ] +_receptive_field_enabled = False + def _register_dict(*modules, dict=None): def callback(impl): @@ -72,6 +76,16 @@ def register_receptive_field(*modules): return _register_dict(*modules, dict=_calc_receptive_field_dict) +def enable_receptive_field(): + global _receptive_field_enabled + _receptive_field_enabled = True + + +def disable_receptive_field(): + global _receptive_field_enabled + _receptive_field_enabled = False + + @register_flops( m.Conv1d, m.Conv2d, m.Conv3d, ) @@ -144,16 +158,16 @@ def preprocess_receptive_field(module, inputs, outputs): # TODO: support other dimensions pre_rf = ( max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs), - max(i.owner._rf[1] for i in inputs), + max(getattr(i.owner, "_rf", (1, 1))[1] for i in inputs), ) pre_stride = ( max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs), - max(i.owner._stride[1] for i in inputs), + max(getattr(i.owner, "_stride", (1, 1))[1] for i in inputs), ) return pre_rf, pre_stride -def get_flops_stats(module, inputs, outputs): +def get_op_stats(module, inputs, outputs): rst = { "input_shapes": [i.shape for i in inputs], "output_shapes": [o.shape for o in outputs], @@ -184,7 +198,7 @@ def get_flops_stats(module, inputs, outputs): return -def print_flops_stats(flops, bar_length_max=20): +def print_op_stats(flops, bar_length_max=20): max_flops_num = max([i["flops_num"] for i in flops] + [0]) total_flops_num = 0 for d in flops: @@ -203,13 +217,14 @@ def print_flops_stats(flops, bar_length_max=20): "class_name", "input_shapes", "output_shapes", - "receptive_field", - "stride", "flops", "flops_cum", "percentage", "bar", ] + if _receptive_field_enabled: + header.insert(4, "receptive_field") + header.insert(5, "stride") total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") total_var_size = sum( @@ -240,7 +255,7 @@ def get_param_stats(param: np.ndarray): } -def print_params_stats(params, bar_length_max=20): +def print_param_stats(params, bar_length_max=20): max_size = max([d["size"] for d in params] + [0]) total_param_dims, total_param_size = 0, 0 for d in params: @@ -302,11 +317,12 @@ def module_stats( :param log_params: whether print and record params size. :param log_flops: whether print and record op flops. """ + disable_receptive_field() def module_stats_hook(module, inputs, outputs, name=""): class_name = str(module.__class__).split(".")[-1].split("'")[0] - flops_stats = get_flops_stats(module, inputs, outputs) + flops_stats = get_op_stats(module, inputs, outputs) if flops_stats is not None: flops_stats["name"] = name flops_stats["class_name"] = class_name @@ -349,11 +365,11 @@ def module_stats( } total_flops, total_param_dims, total_param_size = 0, 0, 0 if log_params: - total_param_dims, total_param_size = print_params_stats(params, bar_length_max) + total_param_dims, total_param_size = print_param_stats(params, bar_length_max) extra_info["total_param_dims"] = sizeof_fmt(total_param_dims) extra_info["total_param_size"] = sizeof_fmt(total_param_size) if log_flops: - total_flops = print_flops_stats(flops, bar_length_max) + total_flops = print_op_stats(flops, bar_length_max) extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs") if log_params and log_flops: extra_info["flops/param_size"] = "{:3.3f}".format( -- GitLab