Commit b10238ac authored by Megvii Engine Team

feat(mge/tools): add support of receptive_field stats for NetworkNode

GitOrigin-RevId: 11ef3354689d343883348d4129bc89db784e3fe0
Parent 84c2a5c2
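For context, a minimal sketch of how the new stats surface (the model, input shape, and the `input_size` argument are assumptions for illustration, not part of this diff):

import megengine.module as m
from megengine.utils.module_stats import module_stats

net = m.Sequential(m.Conv2d(3, 8, 3, stride=2), m.Conv2d(8, 16, 3, stride=2))
# module_stats installs forward hooks and runs the model once; with this
# change the printed flops table gains receptive_field and stride columns
# (populated for ops with registered rules, e.g. convolutions).
total_param_size, total_flops = module_stats(net, input_size=(1, 3, 224, 224))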
@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import json
import logging
import numpy as np
@@ -14,6 +15,7 @@ import numpy as np
from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import (
get_flops_stats,
get_param_stats,
print_flops_stats,
print_params_stats,
@@ -89,6 +91,7 @@ def visualize(
inp_list = [process_name(var.owner.name) for var in node.inputs]
if log_path:
# detail format see tensorboard/compat/proto/attr_value.proto
attr = {
"_output_shapes": AttrValue(
list=AttrValue.ListValue(
@@ -101,24 +104,20 @@
]
)
),
"params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
"dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
}
if hasattr(node, "calc_flops"):
flops_num = node.calc_flops()
flops_stats = get_flops_stats(node, node.inputs, node.outputs)
if flops_stats is not None:
# add op flops attr
if log_path:
if log_path and "flops_num" in flops_stats:
attr["flops"] = AttrValue(
s=sizeof_fmt(flops_num).encode(encoding="utf-8")
)
flops_list.append(
dict(
name=node.name,
class_name=node.type,
input_shapes=[i.shape for i in node.inputs],
output_shapes=[o.shape for o in node.outputs],
flops_num=flops_num,
flops_cum=0,
s=sizeof_fmt(flops_stats["flops_num"]).encode(encoding="utf-8")
)
)
flops_stats["name"] = node.name
flops_stats["class_name"] = node.type
flops_list.append(flops_stats)
if node.type == "ImmutableTensor":
param_stats = get_param_stats(node.numpy())
# add tensor size attr
@@ -132,6 +131,7 @@ def visualize(
# FIXME(MGE-2165): nodes outside network module may lead to unknown display bug
if len(node.name.split(".")) <= 2 and node not in graph.input_vars:
continue
if log_path:
node_list.append(
NodeDef(
@@ -141,14 +141,26 @@
attr=attr,
)
)
# summary
extra_info = {
"#ops": len(graph.all_oprs),
"#params": len(params_list),
}
total_flops, total_params = None, None
total_flops, total_param_dims, total_param_size = 0, 0, 0
if log_params:
total_param_dims, total_param_size = print_params_stats(
params_list, bar_length_max
)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_flops:
total_flops = print_flops_stats(flops_list, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)
if log_path:
graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
@@ -160,21 +172,12 @@ def visualize(
writer = SummaryWriter(log_path)
writer._get_file_writer().add_graph((graph_def, stepstats))
# summary
extra_info = {
"#ops": len(graph.all_oprs),
"#params": len(params_list),
"total_param_dims": sizeof_fmt(total_param_dims),
"total_param_size": sizeof_fmt(total_param_size),
"total_flops": sizeof_fmt(total_flops, suffix="OPs"),
"flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
}
print_summary(**extra_info)
# FIXME: remove this after resolving "span dist too large" warning
_imperative_rt_logger.set_log_level(old_level)
return total_params, total_flops
return total_param_size, total_flops
def main():
......
@@ -26,61 +26,95 @@ logger = mge.get_logger(__name__)
logger.setLevel("INFO")
CALC_FLOPS = {}
def _register_modules(*modules):
_calc_flops_dict = {}
_calc_receptive_field_dict = {}
def _receptive_field_fallback(module, inputs, outputs):
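# Fallback for module types without a registered receptive-field rule:
# derive rf/stride from the inputs' producers and cache them on the module.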
assert not hasattr(module, "_rf")
assert not hasattr(module, "_stride")
if len(inputs) == 0:
# TODO: support other dimensions
module._rf = (1, 1)
module._stride = (1, 1)
return module._rf, module._stride
rf, stride = preprocess_receptive_field(module, inputs, outputs)
module._rf = rf
module._stride = stride
return rf, stride
# each entry: (key, impl_dict, fallback); key is a str or a tuple of strs
_iter_list = [
("flops_num", _calc_flops_dict, None),
(
("receptive_field", "stride"),
_calc_receptive_field_dict,
_receptive_field_fallback,
),
]
def _register_dict(*modules, dict=None):
def callback(impl):
for module in modules:
CALC_FLOPS[module] = impl
dict[module] = impl
return impl
return callback
@_register_modules(
m.Conv2d,
m.ConvTranspose2d,
m.LocalConv2d,
qm.Conv2d,
qm.ConvRelu2d,
qm.ConvBn2d,
qm.ConvBnRelu2d,
qatm.Conv2d,
qatm.ConvRelu2d,
qatm.ConvBn2d,
qatm.ConvBnRelu2d,
def register_flops(*modules):
return _register_dict(*modules, dict=_calc_flops_dict)
def register_receptive_field(*modules):
return _register_dict(*modules, dict=_calc_receptive_field_dict)
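As a usage sketch (`MyUpsample2x` is a hypothetical module, not part of this diff), downstream code can register its own cost model through these helpers; `np` is the module-level numpy import:

@register_flops(MyUpsample2x)  # hypothetical custom module
def flops_upsample(module, inputs, outputs):
    # assumed cost model: one multiply per output element
    return np.prod(outputs[0].shape)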
@register_flops(
m.Conv1d, m.Conv2d, m.Conv3d,
)
def count_convNd(module, input, output):
def flops_convNd(module: m.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0
group = module.groups
ic = input[0].shape[1]
oc = output[0].shape[1]
ic = inputs[0].shape[1]
oc = outputs[0].shape[1]
goc = oc // group
gic = ic // group
N = output[0].shape[0]
HW = np.prod(output[0].shape[2:])
N = outputs[0].shape[0]
HW = np.prod(outputs[0].shape[2:])
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return N * HW * goc * (gic * np.prod(module.kernel_size) + bias)
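As a worked instance of this formula (shapes assumed for illustration): a Conv2d with ic=16, oc=32, groups=1, a 3x3 kernel, bias, and output shape (1, 32, 56, 56) costs 1 * (56*56) * 32 * (16*9 + 1) = 14,551,040 FLOPs.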
@_register_modules(m.ConvTranspose2d)
def count_deconvNd(module, input, output):
return np.prod(input[0].shape) * output[0].shape[1] * np.prod(module.kernel_size)
@register_flops(m.ConvTranspose2d)
def flops_deconvNd(module: m.ConvTranspose2d, inputs, outputs):
return np.prod(inputs[0].shape) * outputs[0].shape[1] * np.prod(module.kernel_size)
@register_flops(m.Linear)
def flops_linear(module: m.Linear, inputs, outputs):
bias = 1 if module.bias is not None else 0
return np.prod(outputs[0].shape) * (module.in_features + bias)  # count the bias add per output element
@_register_modules(m.Linear, qatm.Linear, qm.Linear)
def count_linear(module, input, output):
return np.prod(output[0].shape) * module.in_features
@register_flops(m.BatchMatMulActivation)
def flops_batchmatmul(module: m.BatchMatMulActivation, inputs, outputs):
bias = 1 if module.bias is not None else 0
x = inputs[0]
w = module.weight
batch_size = x.shape[0]
n, p = x.shape[1:]
_, m = w.shape[1:]
return n * (p + bias) * m * batch_size
# no need to import qat and quantized modules since they inherit from the float modules.
hook_modules = (
m.Conv2d,
m.ConvTranspose2d,
m.LocalConv2d,
m.BatchNorm2d,
m.conv._ConvNd,
m.Linear,
m.BatchMatMulActivation,
)
@@ -106,28 +140,71 @@ def sizeof_fmt(num, suffix="B"):
return "{}{:.1f} {}{}".format(sign_str, num, "Yi", suffix)
def preprocess_receptive_field(module, inputs, outputs):
# TODO: support other dimensions
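# Aggregate over all input producers, taking the max per spatial dim; a
# producer without cached stats (e.g. a graph input) contributes (1, 1).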
pre_rf = (
max(getattr(i.owner, "_rf", (1, 1))[0] for i in inputs),
max(getattr(i.owner, "_rf", (1, 1))[1] for i in inputs),
)
pre_stride = (
max(getattr(i.owner, "_stride", (1, 1))[0] for i in inputs),
max(getattr(i.owner, "_stride", (1, 1))[1] for i in inputs),
)
return pre_rf, pre_stride
def get_flops_stats(module, inputs, outputs):
rst = {
"input_shapes": [i.shape for i in inputs],
"output_shapes": [o.shape for o in outputs],
}
valid_flag = False
for key, _dict, fallback in _iter_list:
for _type in _dict:
if isinstance(module, _type):
value = _dict[_type](module, inputs, outputs)
valid_flag = True
break
else:
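# for-else: reached only when no registered type matched above; the
# fallback (if any) caches stats on the module but records nothing in rst.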
if fallback is not None:
value = fallback(module, inputs, outputs)
continue
if isinstance(key, tuple):
assert isinstance(value, tuple)
for k, v in zip(key, value):
rst[k] = v
else:
rst[key] = value
if valid_flag:
return rst
else:
return None
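For reference, on an op with a registered rule the returned dict looks roughly like the following (values computed under assumed shapes: a 3x3, stride-2 conv with bias mapping (1, 3, 32, 32) to (1, 8, 15, 15)):

# {'input_shapes': [(1, 3, 32, 32)], 'output_shapes': [(1, 8, 15, 15)],
#  'flops_num': 50400,  # = (1*8*15*15) * (3*3*3 + 1)
#  'receptive_field': (3, 3), 'stride': (2, 2)}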
def print_flops_stats(flops, bar_length_max=20):
flops_list = [i["flops_num"] for i in flops]
max_flops_num = max(flops_list + [0])
# calc total flops and set flops_cum
max_flops_num = max([i["flops_num"] for i in flops] + [0])
total_flops_num = 0
for d in flops:
total_flops_num += int(d["flops_num"])
d["flops_cum"] = sizeof_fmt(total_flops_num, suffix="OPs")
for d in flops:
f = d["flops_num"]
d["flops"] = sizeof_fmt(f, suffix="OPs")
r = d["ratio"] = f / total_flops_num
d["percentage"] = "{:.2f}%".format(r * 100)
bar_length = int(f / max_flops_num * bar_length_max)
ratio = d["ratio"] = d["flops_num"] / total_flops_num
d["percentage"] = "{:.2f}%".format(ratio * 100)
bar_length = int(d["flops_num"] / max_flops_num * bar_length_max)
d["bar"] = "#" * bar_length
d["flops"] = sizeof_fmt(d["flops_num"], suffix="OPs")
header = [
"name",
"class_name",
"input_shapes",
"output_shapes",
"receptive_field",
"stride",
"flops",
"flops_cum",
"percentage",
@@ -154,8 +231,8 @@ def get_param_stats(param: np.ndarray):
param_size = param_dim * nbits // 8
return {
"shape": shape,
"mean": param.mean(),
"std": param.std(),
"mean": "{:.3g}".format(param.mean()),
"std": "{:.3g}".format(param.std()),
"param_dim": param_dim,
"nbits": nbits,
"size": param_size,
@@ -163,21 +240,20 @@
def print_params_stats(params, bar_length_max=20):
max_size = max([d["size"] for d in params] + [0])
total_param_dims, total_param_size = 0, 0
for d in params:
total_param_dims += int(d["param_dim"])
total_param_size += int(d["size"])
ratio = d["size"] / total_param_size
d["size"] = sizeof_fmt(d["size"])
d["size_cum"] = sizeof_fmt(total_param_size)
d["ratio"] = ratio
d["percentage"] = "{:.2f}%".format(ratio * 100)
# construct bar
max_ratio = max([d["ratio"] for d in params])
for d in params:
bar_length = int(d["ratio"] / max_ratio * bar_length_max)
ratio = d["size"] / total_param_size
d["ratio"] = ratio
d["percentage"] = "{:.2f}%".format(ratio * 100)
bar_length = int(d["size"] / max_size * bar_length_max)
d["size_bar"] = "#" * bar_length
d["size"] = sizeof_fmt(d["size"])
param_size = sizeof_fmt(total_param_size)
params.append(dict(name="total", param_dim=total_param_dims, size=param_size,))
@@ -225,26 +301,14 @@
:param log_flops: whether to print and record op flops.
"""
def module_stats_hook(module, input, output, name=""):
def module_stats_hook(module, inputs, outputs, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0]
flops_fun = CALC_FLOPS.get(type(module))
if callable(flops_fun):
flops_num = flops_fun(module, input, output)
if not isinstance(output, (list, tuple)):
output = [output]
flops.append(
dict(
name=name,
class_name=class_name,
input_shapes=[i.shape for i in input],
output_shapes=[o.shape for o in output],
flops_num=flops_num,
flops_cum=0,
)
)
flops_stats = get_flops_stats(module, inputs, outputs)
if flops_stats is not None:
flops_stats["name"] = name
flops_stats["class_name"] = class_name
flops.append(flops_stats)
if hasattr(module, "weight") and module.weight is not None:
w = module.weight
@@ -278,19 +342,22 @@ def module_stats(
for h in hooks:
h.remove()
total_flops, total_params = 0, 0
extra_info = {
"#params": len(params),
}
total_flops, total_param_dims, total_param_size = 0, 0, 0
if log_params:
total_param_dims, total_param_size = print_params_stats(params, bar_length_max)
extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
extra_info["total_param_size"] = sizeof_fmt(total_param_size)
if log_flops:
total_flops = print_flops_stats(flops, bar_length_max)
extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
if log_params and log_flops:
extra_info["flops/param_size"] = "{:3.3f}".format(
total_flops / total_param_size
)
extra_info = {
"#params": len(params),
"total_param_dims": sizeof_fmt(total_param_dims),
"total_param_size": sizeof_fmt(total_param_size),
"total_flops": sizeof_fmt(total_flops, suffix="OPs"),
"flops/param_size": "{:3.3f}".format(total_flops / total_param_size),
}
print_summary(**extra_info)
return total_params, total_flops
return total_param_size, total_flops
@@ -18,6 +18,11 @@ from ..core.ops import builtin
from ..core.tensor.megbrain_graph import InputNode
from ..tensor import Tensor
from .comp_graph_tools import replace_vars
from .module_stats import (
preprocess_receptive_field,
register_flops,
register_receptive_field,
)
class NetworkNode:
@@ -225,8 +230,21 @@ class Elemwise(OpNode):
type = "Elemwise"
opdef = builtin.Elemwise
def calc_flops(self):
return np.prod(self.outputs[0].shape)
class ElemwiseMultiType(OpNode):
type = "ElemwiseMultiType"
opdef = builtin.ElemwiseMultiType
@classmethod
def load(cls, opr):
obj = super(ElemwiseMultiType, cls).load(opr)
obj.params["dtype"] = opr.outputs[0].dtype
return obj
@register_flops(Elemwise, ElemwiseMultiType)
def flops_elemwise(opnode: Elemwise, inputs, outputs):
return np.prod(outputs[0].shape)
class Reduce(OpNode):
@@ -255,20 +273,24 @@ class MatrixMul(OpNode):
type = "MatrixMul"
opdef = builtin.MatrixMul
def calc_flops(self):
assert len(self.inputs[0].shape) == 2 and len(self.outputs[0].shape) == 2
mid_shape = self.inputs[0].shape[1]
return np.prod(self.outputs[0].shape) * mid_shape
@register_flops(MatrixMul)
def flops_matmul(opnode: MatrixMul, inputs, outputs):
assert len(inputs[0].shape) == 2 and len(outputs[0].shape) == 2
mid_shape = inputs[0].shape[1]
return np.prod(outputs[0].shape) * mid_shape
class BatchedMatrixMul(OpNode):
type = "BatchedMatmul"
opdef = builtin.BatchedMatrixMul
def calc_flops(self):
assert len(self.inputs[0].shape) == 3 and len(self.outputs[0].shape) == 3
mid_shape = self.inputs[0].shape[2]
return np.prod(self.outputs[0].shape) * mid_shape
@register_flops(BatchedMatrixMul)
def flops_batchmatmul(opnode: BatchedMatrixMul, inputs, outputs):
assert len(inputs[0].shape) == 3 and len(outputs[0].shape) == 3
mid_shape = inputs[0].shape[2]
return np.prod(outputs[0].shape) * mid_shape
class Dot(OpNode):
@@ -285,18 +307,6 @@ class ConvolutionForward(OpNode):
type = "Convolution"
opdef = builtin.Convolution
def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh)
return NCHW * (num_input * kw * kh)
class ConvolutionBackwardData(OpNode):
type = "ConvTranspose"
@@ -343,17 +353,41 @@ class ConvBiasForward(OpNode):
obj.params["dtype"] = opr.outputs[0].dtype
return obj
def calc_flops(self):
param_W_shape = self.inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(self.outputs[0].shape)
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return NCHW * (num_input * kw * kh + 1)
@register_flops(
ConvolutionForward, ConvBiasForward,
)
def flops_conv(opnode: ConvolutionForward, inputs, outputs):
param_W_shape = inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
if len(param_W_shape) == 5:
num_input = param_W_shape[2]
else:
num_input = param_W_shape[1]
NCHW = np.prod(outputs[0].shape)
bias = 1 if isinstance(opnode, ConvBiasForward) else 0
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return NCHW * (num_input * kw * kh + bias)
@register_receptive_field(ConvolutionForward, ConvBiasForward)
def receptive_field(opnode: ConvolutionForward, inputs, outputs):
pre_rf, pre_stride = preprocess_receptive_field(opnode, inputs, outputs)
param_W_shape = inputs[1].shape
kh = param_W_shape[-2]
kw = param_W_shape[-1]
rf = (
kh * pre_stride[0] + pre_rf[0] - pre_stride[0],
kw * pre_stride[1] + pre_rf[1] - pre_stride[1],
)
stride = (
opnode.params["stride_h"] * pre_stride[0],
opnode.params["stride_w"] * pre_stride[1],
)
opnode._rf = rf
opnode._stride = stride
return rf, stride
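To make the recurrence concrete: for two stacked 3x3, stride-2 convolutions, the first sees pre_rf = (1, 1), pre_stride = (1, 1) and yields rf = (3, 3), stride = (2, 2); the second then yields rf = (3*2 + 3 - 2, 3*2 + 3 - 2) = (7, 7) and stride = (4, 4), matching the usual rf_new = rf_prev + (k - 1) * stride_prev identity.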
class BatchConvBiasForward(OpNode):
@@ -652,20 +686,6 @@ class AssertEqual(OpNode):
opdef = builtin.AssertEqual
class ElemwiseMultiType(OpNode):
type = "ElemwiseMultiType"
opdef = builtin.ElemwiseMultiType
@classmethod
def load(cls, opr):
obj = super(ElemwiseMultiType, cls).load(opr)
obj.params["dtype"] = opr.outputs[0].dtype
return obj
def calc_flops(self):
return np.prod(self.outputs[0].shape)
class CvtColorForward(OpNode):
type = "CvtColor"
opdef = builtin.CvtColor