未验证 提交 a8c182a4 编写于 作者: C Cheng Li 提交者: GitHub

fix interpolate flops compute (#3782)

上级 d81a6ad6
......@@ -682,15 +682,23 @@ def _instance_norm_flops_compute(
return input.numel() * (5 if has_affine else 4), 0
def _upsample_flops_compute(input, **kwargs):
def _upsample_flops_compute(*args, **kwargs):
input = args[0]
size = kwargs.get('size', None)
if size is None and len(args) > 1:
size = args[1]
if size is not None:
if isinstance(size, tuple) or isinstance(size, list):
return int(_prod(size)), 0
else:
return int(size), 0
scale_factor = kwargs.get('scale_factor', None)
if scale_factor is None and len(args) > 2:
scale_factor = args[2]
assert scale_factor is not None, "either size or scale_factor should be defined"
flops = input.numel()
if isinstance(scale_factor, tuple) and len(scale_factor) == len(input):
flops * int(_prod(scale_factor))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册