From 84c2a5c27aa5645bdee9b3ed67cdf1c5d02872f4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 26 Mar 2021 16:21:18 +0800 Subject: [PATCH] feat(mge/tools): add summary print for module_stats and network_visualize GitOrigin-RevId: 7d85aa0ea2cc349369bb295d76db38a8748314ad --- .../python/megengine/core/tensor/dtype.py | 7 ++ .../megengine/tools/network_visualize.py | 37 +++---- .../python/megengine/utils/module_stats.py | 98 +++++++++---------- 3 files changed, 74 insertions(+), 68 deletions(-) diff --git a/imperative/python/megengine/core/tensor/dtype.py b/imperative/python/megengine/core/tensor/dtype.py index f68fde121..408b9ded1 100644 --- a/imperative/python/megengine/core/tensor/dtype.py +++ b/imperative/python/megengine/core/tensor/dtype.py @@ -5,6 +5,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import re from collections import namedtuple from typing import Union @@ -22,6 +23,12 @@ from .._imperative_rt.common import ( ) +def get_dtype_bit(dtype_name: str): + numbers = re.findall(r"\d+", dtype_name) + assert len(numbers) == 1, "Unsupport dtype name with more than one number." + return int(numbers[0]) + + # normal dtype related def is_lowbit(dtype): return (dtype is intb1) or (dtype is intb2) or (dtype is intb4) diff --git a/imperative/python/megengine/tools/network_visualize.py b/imperative/python/megengine/tools/network_visualize.py index d155755d4..c0868ed39 100755 --- a/imperative/python/megengine/tools/network_visualize.py +++ b/imperative/python/megengine/tools/network_visualize.py @@ -14,8 +14,10 @@ import numpy as np from megengine.core.tensor.dtype import is_quantize from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level from megengine.utils.module_stats import ( + get_param_stats, print_flops_stats, print_params_stats, + print_summary, sizeof_fmt, ) from megengine.utils.network import Network @@ -69,6 +71,7 @@ def visualize( def process_name(name): return name.replace(".", "/").encode(encoding="utf-8") + summary = [["item", "value"]] node_list = [] flops_list = [] params_list = [] @@ -117,26 +120,15 @@ def visualize( ) ) if node.type == "ImmutableTensor": - param_dim = np.prod(node_oup.shape) - # TODO: consider other quantize dtypes - param_bytes = 1 if is_quantize(node_oup.dtype) else 4 + param_stats = get_param_stats(node.numpy()) # add tensor size attr if log_path: attr["size"] = AttrValue( - s=sizeof_fmt(param_dim * param_bytes).encode(encoding="utf-8") + s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8") ) - params_list.append( - dict( - name=node.name, - shape=node_oup.shape, - param_dim=param_dim, - bits=param_bytes * 8, - size=param_dim * param_bytes, - size_cum=0, - mean="{:.2g}".format(node.numpy().mean()), - std="{:.2g}".format(node.numpy().std()), - ) - ) + param_stats["name"] = node.name + params_list.append(param_stats) + # FIXME(MGE-2165): nodes outside network module may lead to unknown display bug if not len(node.name.split(".")) > 2 and not node in graph.input_vars: continue @@ -152,7 +144,9 @@ def visualize( total_flops, total_params = None, None if log_params: - total_params = print_params_stats(params_list, bar_length_max) + total_param_dims, total_param_size = print_params_stats( + params_list, bar_length_max + ) if log_flops: total_flops = print_flops_stats(flops_list, bar_length_max) @@ -167,6 +161,15 @@ def visualize( writer._get_file_writer().add_graph((graph_def, stepstats)) # summary + extra_info = { + "#ops": len(graph.all_oprs), + "#params": len(params_list), + "total_param_dims": sizeof_fmt(total_param_dims), + "total_param_size": sizeof_fmt(total_param_size), + "total_flops": sizeof_fmt(total_flops, suffix="OPs"), + "flops/param_size": "{:3.3f}".format(total_flops / total_param_size), + } + print_summary(**extra_info) # FIXME: remove this after resolving "span dist too large" warning _imperative_rt_logger.set_log_level(old_level) diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index af2a1cb5e..d063b983f 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -11,10 +11,10 @@ import numpy as np import tabulate import megengine as mge -import megengine.core.tensor.dtype as dtype import megengine.module as m import megengine.module.qat as qatm import megengine.module.quantized as qm +from megengine.core.tensor.dtype import get_dtype_bit from megengine.functional.tensor import zeros try: @@ -115,13 +115,13 @@ def print_flops_stats(flops, bar_length_max=20): total_flops_num += int(d["flops_num"]) d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs") - for i in flops: - f = i["flops_num"] - i["flops"] = sizeof_fmt(f, suffix="OPs") - r = i["ratio"] = f / total_flops_num - i["percentage"] = "{:.2f}%".format(r * 100) + for d in flops: + f = d["flops_num"] + d["flops"] = sizeof_fmt(f, suffix="OPs") + r = d["ratio"] = f / total_flops_num + d["percentage"] = "{:.2f}%".format(r * 100) bar_length = int(f / max_flops_num * bar_length_max) - i["bar"] = "#" * bar_length + d["bar"] = "#" * bar_length header = [ "name", @@ -136,7 +136,7 @@ def print_flops_stats(flops, bar_length_max=20): total_flops_str = sizeof_fmt(total_flops_num, suffix="OPs") total_var_size = sum( - sum(s[1] if len(s) > 1 else 0 for s in i["output_shapes"]) for i in flops + sum(s[1] if len(s) > 1 else 0 for s in d["output_shapes"]) for d in flops ) flops.append( dict(name="total", flops=total_flops_str, output_shapes=total_var_size) @@ -147,16 +147,29 @@ def print_flops_stats(flops, bar_length_max=20): return total_flops_num +def get_param_stats(param: np.ndarray): + nbits = get_dtype_bit(param.dtype.name) + shape = param.shape + param_dim = np.prod(param.shape) + param_size = param_dim * nbits // 8 + return { + "shape": shape, + "mean": param.mean(), + "std": param.std(), + "param_dim": param_dim, + "nbits": nbits, + "size": param_size, + } + + def print_params_stats(params, bar_length_max=20): total_param_dims, total_param_size = 0, 0 for d in params: total_param_dims += int(d["param_dim"]) total_param_size += int(d["size"]) + ratio = d["size"] / total_param_size d["size"] = sizeof_fmt(d["size"]) d["size_cum"] = sizeof_fmt(total_param_size) - - for d in params: - ratio = d["param_dim"] / total_param_dims d["ratio"] = ratio d["percentage"] = "{:.2f}%".format(ratio * 100) @@ -186,7 +199,13 @@ def print_params_stats(params, bar_length_max=20): "param stats: \n" + tabulate.tabulate(dict2table(params, header=header)) ) - return total_param_size + return total_param_dims, total_param_size + + +def print_summary(**kwargs): + data = [["item", "value"]] + data.extend(list(kwargs.items())) + logger.info("summary\n" + tabulate.tabulate(data)) def module_stats( @@ -206,14 +225,6 @@ def module_stats( :param log_flops: whether print and record op flops. """ - def get_byteswidth(tensor): - if dtype.is_quantize(tensor.dtype): - return 1 - # elif dtype.is_bfloat16(tensor.dtype): - # return 2 - else: - return 4 - def module_stats_hook(module, input, output, name=""): class_name = str(module.__class__).split(".")[-1].split("'")[0] @@ -237,39 +248,15 @@ def module_stats( if hasattr(module, "weight") and module.weight is not None: w = module.weight - value = w.numpy() - param_dim = np.prod(w.shape) - param_bytes = get_byteswidth(w) - params.append( - dict( - name=name + "-w", - shape=w.shape, - param_dim=param_dim, - bits=param_bytes * 8, - size=param_dim * param_bytes, - size_cum=0, - mean="{:.2g}".format(value.mean()), - std="{:.2g}".format(value.std()), - ) - ) + param_stats = get_param_stats(w.numpy()) + param_stats["name"] = name + "-w" + params.append(param_stats) if hasattr(module, "bias") and module.bias is not None: b = module.bias - value = b.numpy() - param_dim = np.prod(b.shape) - param_bytes = get_byteswidth(b) - params.append( - dict( - name=name + "-b", - shape=b.shape, - param_dim=param_dim, - bits=param_bytes * 8, - size=param_dim * param_bytes, - size_cum=0, - mean="{:.2g}".format(value.mean()), - std="{:.2g}".format(value.std()), - ) - ) + param_stats = get_param_stats(b.numpy()) + param_stats["name"] = name + "-b" + params.append(param_stats) # multiple inputs to the network if not isinstance(input_size[0], tuple): @@ -293,8 +280,17 @@ def module_stats( total_flops, total_params = 0, 0 if log_params: - total_params = print_params_stats(params, bar_length_max) + total_param_dims, total_param_size = print_params_stats(params, bar_length_max) if log_flops: total_flops = print_flops_stats(flops, bar_length_max) + extra_info = { + "#params": len(params), + "total_param_dims": sizeof_fmt(total_param_dims), + "total_param_size": sizeof_fmt(total_param_size), + "total_flops": sizeof_fmt(total_flops, suffix="OPs"), + "flops/param_size": "{:3.3f}".format(total_flops / total_param_size), + } + print_summary(**extra_info) + return total_params, total_flops -- GitLab