Commit 4aaae995 authored by Megvii Engine Team

feat(mge/functional): add python wrapper to resize opr

GitOrigin-RevId: b7cc6dd829531d750c6d61c9dc316d7999d82cfc
Parent d04b4bc0
@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines
from typing import Optional, Sequence, Tuple, Union
from typing import Iterable, Optional, Sequence, Tuple, Union
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import apply
@@ -58,6 +58,7 @@ __all__ = [
"one_hot",
"prelu",
"remap",
"resize",
"softmax",
"softplus",
"svd",
@@ -878,6 +879,41 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
    return result

def resize(
    inp: Tensor, target_shape: Iterable[int], interp_mode: str = "LINEAR"
) -> Tensor:
    r"""
    Applies resize transformation to batched 2D images.

    :param inp: `(N, C, H, W)` input tensor. Currently only the "NCHW" format is supported.
    :param target_shape: `(H, W)` target image shape.
    :param interp_mode: interpolation method. Default mode is "LINEAR"; currently only "LINEAR" is supported.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.random.randn(10, 3, 32, 32))
        out = F.resize(x, (16, 16))
        print(out.numpy().shape)

    Outputs:

    .. testoutput::

        (10, 3, 16, 16)

    """
    op = builtin.Resize(imode=interp_mode, format="NCHW")
    shape = astensor1d(target_shape, inp, dtype="int32", device=inp.device)
    (result,) = apply(op, inp, shape)
    return result

def warp_perspective(
    inp: Tensor,
    M: Tensor,
......
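Since the wrapper funnels `target_shape` through `astensor1d`, the target size can be given as any iterable of two ints. A minimal usage sketch (shapes are illustrative; acceptance of inputs other than tuples/lists, e.g. a 1-D tensor, is an assumption):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.random.randn(2, 3, 8, 8).astype("float32"))

    # target shape as a tuple ...
    up = F.resize(x, (16, 16))    # -> (2, 3, 16, 16)
    # ... or as a list; both are converted by astensor1d internally
    down = F.resize(x, [4, 4])    # -> (2, 3, 4, 4)

    print(up.shape, down.shape)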
@@ -373,6 +373,17 @@ def test_Broadcast():
    np.testing.assert_equal(np.ones((3, 3, 1), dtype=np.float32) * 10, x.grad.numpy())

def test_resize():
    x_np = np.random.rand(3, 3, 32, 32).astype("float32")
    x = mge.Tensor(x_np)

    grad = Grad().wrt(x, callback=save_to(x))
    y = F.resize(x, (16, 16))
    grad(y, F.ones_like(y))
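    # downscaling 32x32 -> 16x16 with LINEAR: each output pixel draws on a 2x2
    # input block with weight 1/4 each, so a constant upstream gradient of 1
    # leaves 1/4 at every input pixel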
    np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy())

def test_Reduce_sum():
    x_np = np.random.rand(3, 3).astype("float32")
    x = mge.Tensor(x_np)
......
@@ -328,6 +328,31 @@ def test_one_hot():
    onehot_high_dimension()

def test_resize():
    # check shape
    test_cases = [
        [(1, 1, 10, 10), (5, 5)],
        [(1, 3, 10, 10), (20, 20)],
        [(10, 1, 10, 10), (1, 1)],
        [(10, 10, 1, 1), (10, 10)],
    ]
    for inp_shape, target_shape in test_cases:
        x = tensor(np.random.randn(*inp_shape), dtype=np.float32)
        out = F.resize(x, target_shape, interp_mode="LINEAR")
        assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1]
        assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1]

    # check value
    x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32)
    out = F.resize(x, (15, 5), interp_mode="LINEAR")
    np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32))

    np_x = np.arange(32)
    x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1)
    out = F.resize(x, (1, 1), interp_mode="LINEAR")
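    # the input is a linear ramp, so the single LINEAR sample at the spatial
    # center equals the mean of the ramp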
    np.testing.assert_equal(out.item(), np_x.mean())

def test_warp_perspective():
    inp_shape = (1, 1, 4, 4)
    x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
......
/**
* \file imperative/src/impl/ops/resize.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/imgproc.h"
#include "../op_trait.h"
namespace mgb {
namespace imperative {
namespace {
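// map the imperative Resize OpDef onto the graph-level opr::Resize;
// inputs are (src, target_shape)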
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const Resize&>(def);
    mgb_assert(inputs.size() == 2);
    return opr::Resize::make(inputs[0], inputs[1], op.param());
}

OP_TRAIT_REG(Resize, Resize)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
} // anonymous namespace
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -76,6 +76,8 @@ def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>;
def Remap: MgbHashableOp<"Remap", [RemapParam]>;
def Resize: MgbHashableOp<"Resize", [ResizeParam]>;
def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]>;
def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]>;
......