From 786c36ff42061cb102e063d21e804a1f33d3e613 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 16 Mar 2021 15:28:11 +0800 Subject: [PATCH] fix(mge/tools): rename `net_stats` in function and examples to match file name GitOrigin-RevId: 82a1377d6688f915d4f4a32354a3c3f8db712f9f --- imperative/python/megengine/tools/network_visualize.py | 2 +- imperative/python/megengine/utils/module_stats.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/imperative/python/megengine/tools/network_visualize.py b/imperative/python/megengine/tools/network_visualize.py index 181a3027..a77641b0 100755 --- a/imperative/python/megengine/tools/network_visualize.py +++ b/imperative/python/megengine/tools/network_visualize.py @@ -31,7 +31,7 @@ def visualize( ): r""" Load megengine dumped model and visualize graph structure with tensorboard log files. - Can also record and print model's statistics like :func:`~.net_stats` + Can also record and print model's statistics like :func:`~.module_stats` :param model_path: dir path for megengine dumped model. :param log_path: dir path for tensorboard graph log. diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index 46753cbe..c091e321 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -187,7 +187,7 @@ def print_params_stats(params, bar_length_max=20): return total_param_size -def net_stats( +def module_stats( model: m.Module, input_size: int, bar_length_max: int = 20, @@ -212,7 +212,7 @@ def net_stats( else: return 4 - def net_stats_hook(module, input, output, name=""): + def module_stats_hook(module, input, output, name=""): class_name = str(module.__class__).split(".")[-1].split("'")[0] flops_fun = CALC_FLOPS.get(type(module)) @@ -280,7 +280,7 @@ def net_stats( for (name, module) in model.named_modules(): if isinstance(module, hook_modules): hooks.append( - module.register_forward_hook(partial(net_stats_hook, name=name)) + module.register_forward_hook(partial(module_stats_hook, name=name)) ) inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size] -- GitLab