提交 536506c3 编写于 作者: M Megvii Engine Team

feat(functional): let interpolate support more modes

GitOrigin-RevId: 9693a1ac638658ca1ee0b5d0eff507d74fcc996d
上级 d811dc54
......@@ -17,7 +17,7 @@ from ..core.tensor.utils import astensor1d
from ..tensor import Tensor
from .elemwise import floor
from .math import argsort
from .tensor import broadcast_to, concat, expand_dims, reshape
from .tensor import broadcast_to, concat, expand_dims, reshape, transpose
def cvt_color(inp: Tensor, mode: str = ""):
......@@ -474,7 +474,7 @@ def interpolate(
:param size: size of the output tensor. Default: None
:param scale_factor: scaling factor of the output tensor. Default: None
:param mode: interpolation methods, acceptable values are:
"bilinear", "linear". Default: "bilinear"
"bilinear", "linear", "bicubic" and "nearest". Default: "bilinear"
:param align_corners: This only has an effect when `mode`
is "bilinear" or "linear". Geometrically, we consider the pixels of the input
and output as squares rather than points. If set to ``True``, the input
......@@ -511,8 +511,8 @@ def interpolate(
"""
mode = mode.lower()
if mode not in ["bilinear", "linear"]:
raise ValueError("interpolate only support linear or bilinear mode")
if mode not in ["bilinear", "linear", "bicubic", "nearest"]:
raise ValueError("unsupported interpolate mode: {}".format(mode))
if mode not in ["bilinear", "linear"]:
if align_corners is not None:
raise ValueError(
......@@ -625,9 +625,21 @@ def interpolate(
weight = broadcast_to(weight, (inp.shape[0], 3, 3))
weight = weight.astype("float32")
ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
if mode == "linear":
ret = reshape(ret, ret.shape[0:3])
if mode in ["linear", "bilinear"]:
ret = warp_perspective(inp, weight, dsize, interp_mode="linear")
if mode == "linear":
ret = reshape(ret, ret.shape[0:3])
else:
# only NHWC format support "cubic" and "nearest" mode
inp = transpose(inp, (0, 2, 3, 1))
ret = warp_perspective(
inp,
weight,
dsize,
format="NHWC",
interp_mode="cubic" if mode == "bicubic" else mode,
)
ret = transpose(ret, (0, 3, 1, 2))
return ret
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册