提交 da167cbc 编写于 作者: M Megvii Engine Team

fix(mge/utils): fix module stats calculate flops bug for group conv and remove model status change

GitOrigin-RevId: 647dc6eb66831a805a3f52e1d58beb6c72b3df01
上级 6bb9a255
......@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import contextlib
from functools import partial
import numpy as np
......@@ -87,30 +88,20 @@ def disable_receptive_field():
@register_flops(
m.Conv1d, m.Conv2d, m.Conv3d,
m.Conv1d, m.Conv2d, m.Conv3d, m.ConvTranspose2d, m.LocalConv2d, m.DeformableConv2d
)
def flops_convNd(module: m.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0
group = module.groups
ic = inputs[0].shape[1]
oc = outputs[0].shape[1]
goc = oc // group
gic = ic // group
N = outputs[0].shape[0]
HW = np.prod(outputs[0].shape[2:])
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return N * HW * goc * (gic * np.prod(module.kernel_size) + bias)
@register_flops(m.ConvTranspose2d)
def flops_deconvNd(module: m.ConvTranspose2d, inputs, outputs):
return np.prod(inputs[0].shape) * outputs[0].shape[1] * np.prod(module.kernel_size)
return np.prod(outputs[0].shape) * (
module.in_channels // module.groups * np.prod(module.kernel_size) + bias
)
@register_flops(m.Linear)
def flops_linear(module: m.Linear, inputs, outputs):
bias = 1 if module.bias is not None else 0
return np.prod(outputs[0].shape) * module.in_features
bias = module.out_features if module.bias is not None else 0
return np.prod(outputs[0].shape) * module.in_features + bias
@register_flops(m.BatchMatMulActivation)
......@@ -340,6 +331,31 @@ def module_stats(
param_stats["name"] = name + "-b"
params.append(param_stats)
@contextlib.contextmanager
def adjust_stats(module, training=False):
    """Temporarily switch a module and all of its submodules to train/eval mode.

    On entry, the current ``training`` flag of every object yielded by
    ``module.modules()`` is saved and the object is switched to the requested
    mode. On exit, the saved flags are written back, even when the body of the
    ``with`` statement raises, so the caller's model is never left in the
    forced mode.

    Args:
        module (M.Module): module whose mode is adjusted.
        training (bool): target mode. True for train mode, False for eval mode.

    Yields:
        The same ``module`` object that was passed in.
    """

    def recursive_backup_stats(module, mode):
        for sub in module.modules():
            # save prev status to _prev_training
            sub._prev_training = sub.training
            # recursive=False: modules() already visits every submodule
            sub.train(mode, recursive=False)

    def recursive_recover_stats(module):
        for sub in module.modules():
            # recover prev status and delete the temporary attribute
            sub.training = sub._prev_training
            delattr(sub, "_prev_training")

    recursive_backup_stats(module, mode=training)
    try:
        yield module
    finally:
        # Runs on both normal exit and exceptions raised inside the with-body.
        recursive_recover_stats(module)
# multiple inputs to the network
if not isinstance(input_size[0], tuple):
input_size = [input_size]
......@@ -355,8 +371,9 @@ def module_stats(
)
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]
model.eval()
model(*inputs)
with adjust_stats(model, training=False) as model:
model(*inputs)
for h in hooks:
h.remove()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册