diff --git a/imperative/python/megengine/tools/network_visualize.py b/imperative/python/megengine/tools/network_visualize.py index 181a302727e46c3d7a4d4d71369cd69c93eea5eb..a77641b00a4524a83603e7d6b67c52f626852fa4 100755 --- a/imperative/python/megengine/tools/network_visualize.py +++ b/imperative/python/megengine/tools/network_visualize.py @@ -31,7 +31,7 @@ def visualize( ): r""" Load megengine dumped model and visualize graph structure with tensorboard log files. - Can also record and print model's statistics like :func:`~.net_stats` + Can also record and print model's statistics like :func:`~.module_stats` :param model_path: dir path for megengine dumped model. :param log_path: dir path for tensorboard graph log. diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index 46753cbe92bd8a32d824a21509533e663e6fe203..c091e3215554e7579f3db741e9605be86aa3a6a0 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -187,7 +187,7 @@ def print_params_stats(params, bar_length_max=20): return total_param_size -def net_stats( +def module_stats( model: m.Module, input_size: int, bar_length_max: int = 20, @@ -212,7 +212,7 @@ def net_stats( else: return 4 - def net_stats_hook(module, input, output, name=""): + def module_stats_hook(module, input, output, name=""): class_name = str(module.__class__).split(".")[-1].split("'")[0] flops_fun = CALC_FLOPS.get(type(module)) @@ -280,7 +280,7 @@ def net_stats( for (name, module) in model.named_modules(): if isinstance(module, hook_modules): hooks.append( - module.register_forward_hook(partial(net_stats_hook, name=name)) + module.register_forward_hook(partial(module_stats_hook, name=name)) ) inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]