Commit e509e7a0 authored by Megvii Engine Team

fix(mge/network_visualize): fix warning 'span dist too large'

GitOrigin-RevId: 8db691b8ac9cac81bb4761a5b420c54bcca167e9
Parent: 946ab374
@@ -74,8 +74,6 @@ def visualize(
             exc_info=True,
         )
         return
-    # FIXME: remove this after resolving "span dist too large" warning
-    old_level = set_mgb_log_level(logging.ERROR)
     enable_receptive_field()
@@ -136,13 +134,13 @@ def visualize(
             flops_stats["class_name"] = node.type
             flops_list.append(flops_stats)

-            acts = get_activation_stats(node_oup.numpy())
+            acts = get_activation_stats(node_oup)
             acts["name"] = node.name
             acts["class_name"] = node.type
             activations_list.append(acts)

         if node.type == "ImmutableTensor":
-            param_stats = get_param_stats(node.numpy())
+            param_stats = get_param_stats(node_oup)
             # add tensor size attr
             if log_path:
                 attr["size"] = AttrValue(
@@ -209,9 +207,6 @@ def visualize(
     print_summary(**extra_info)

-    # FIXME: remove this after resolving "span dist too large" warning
-    _imperative_rt_logger.set_log_level(old_level)
-
     return (
         total_stats(
             param_size=total_param_size, flops=total_flops, act_size=total_act_size,
......
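For context, the lines removed above were a temporary workaround: the MegBrain log level was raised to ERROR before the pass so the "span dist too large" warning stayed silent, and the previous level was restored at the end of visualize(). A minimal sketch of that set-and-restore pattern follows; the megengine.logger import path is an assumption, and the try/finally wrapper is added here for safety (the removed code simply set and later restored the level).

import logging

# Assumed import path for the helpers referenced in the removed lines above.
from megengine.logger import _imperative_rt_logger, set_mgb_log_level

# Silence MegBrain warnings (e.g. "span dist too large") for the duration
# of the pass, then restore the previous log level.
old_level = set_mgb_log_level(logging.ERROR)
try:
    ...  # run the graph-tracing / visualization pass here
finally:
    _imperative_rt_logger.set_log_level(old_level)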
@@ -15,6 +15,8 @@ import megengine as mge
 import megengine.module as m
 import megengine.module.qat as qatm
 import megengine.module.quantized as qm
+from megengine import Tensor
+from megengine import functional as F
 from megengine.core.tensor.dtype import get_dtype_bit
 from megengine.functional.tensor import zeros
@@ -152,6 +154,16 @@ hook_modules = (
 )


+def _mean(inp):
+    inp = mge.tensor(inp)
+    return F.mean(inp).numpy()
+
+
+def _std(inp):
+    inp = mge.tensor(inp)
+    return F.std(inp).numpy()
+
+
 def dict2table(list_of_dict, header):
     table_data = [header]
     for d in list_of_dict:
@@ -266,16 +278,16 @@ def print_op_stats(flops):
     logger.info("flops stats: \n" + tabulate.tabulate(dict2table(flops, header=header)))


-def get_param_stats(param: np.ndarray):
-    nbits = get_dtype_bit(param.dtype.name)
+def get_param_stats(param: Tensor):
+    nbits = get_dtype_bit(np.dtype(param.dtype).name)
     shape = param.shape
     param_dim = np.prod(param.shape)
     param_size = param_dim * nbits // 8
     return {
-        "dtype": param.dtype,
+        "dtype": np.dtype(param.dtype),
         "shape": shape,
-        "mean": "{:.3g}".format(param.mean()),
-        "std": "{:.3g}".format(param.std()),
+        "mean": "{:.3g}".format(_mean(param)),
+        "std": "{:.3g}".format(_std(param)),
         "param_dim": param_dim,
         "nbits": nbits,
         "size": param_size,
@@ -323,9 +335,9 @@ def print_param_stats(params):
     )


-def get_activation_stats(output: np.ndarray):
+def get_activation_stats(output: Tensor):
     out_shape = output.shape
-    activations_dtype = output.dtype
+    activations_dtype = np.dtype(output.dtype)
     nbits = get_dtype_bit(activations_dtype.name)
     act_dim = np.prod(out_shape)
     act_size = act_dim * nbits // 8
@@ -333,8 +345,8 @@ def get_activation_stats(output: np.ndarray):
         "dtype": activations_dtype,
         "shape": out_shape,
         "act_dim": act_dim,
-        "mean": "{:.3g}".format(output.mean()),
-        "std": "{:.3g}".format(output.std()),
+        "mean": "{:.3g}".format(_mean(output)),
+        "std": "{:.3g}".format(_std(output)),
         "nbits": nbits,
         "size": act_size,
     }
@@ -418,20 +430,20 @@ def module_stats(
         if hasattr(module, "weight") and module.weight is not None:
             w = module.weight
-            param_stats = get_param_stats(w.numpy())
+            param_stats = get_param_stats(w)
             param_stats["name"] = name + "-w"
             params.append(param_stats)

         if hasattr(module, "bias") and module.bias is not None:
             b = module.bias
-            param_stats = get_param_stats(b.numpy())
+            param_stats = get_param_stats(b)
             param_stats["name"] = name + "-b"
             params.append(param_stats)

         if not isinstance(outputs, tuple) or not isinstance(outputs, list):
-            output = outputs.numpy()
+            output = outputs
         else:
-            output = outputs[0].numpy()
+            output = outputs[0]
         activation_stats = get_activation_stats(output)
         activation_stats["name"] = name
         activation_stats["class_name"] = class_name
......
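As a self-contained illustration of the new behavior (assuming only a working MegEngine install; the function bodies mirror the hunks above, while the example weight at the bottom is made up), get_param_stats now takes a Tensor and routes mean/std through the _mean/_std helpers, which first wrap their input with mge.tensor():

import numpy as np

import megengine as mge
from megengine import Tensor
from megengine import functional as F
from megengine.core.tensor.dtype import get_dtype_bit


def _mean(inp):
    # Accept a Tensor or anything mge.tensor() can wrap (e.g. an ndarray).
    inp = mge.tensor(inp)
    return F.mean(inp).numpy()


def _std(inp):
    inp = mge.tensor(inp)
    return F.std(inp).numpy()


def get_param_stats(param: Tensor):
    nbits = get_dtype_bit(np.dtype(param.dtype).name)
    shape = param.shape
    param_dim = np.prod(param.shape)
    param_size = param_dim * nbits // 8
    return {
        "dtype": np.dtype(param.dtype),
        "shape": shape,
        "mean": "{:.3g}".format(_mean(param)),
        "std": "{:.3g}".format(_std(param)),
        "param_dim": param_dim,
        "nbits": nbits,
        "size": param_size,
    }


# Example: a float32 weight of shape (16, 3, 3, 3) has 432 elements,
# so its size is 432 * 32 // 8 = 1728 bytes.
weight = mge.tensor(np.random.randn(16, 3, 3, 3).astype("float32"))
stats = get_param_stats(weight)
print(stats["shape"], stats["size"])  # (16, 3, 3, 3) 1728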