提交 19466046 编写于 作者: M Megvii Engine Team

feat(mge/funcitonal): add cvt_color opr python interface

GitOrigin-RevId: 29e069fb2334a20545021f37a7133f7bfdbafd0b
上级 412d1f0c
......@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
from .elemwise import *
from .img_proc import *
from .math import *
from .nn import *
from .tensor import *
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..tensor import Tensor
__all__ = [
"cvt_color",
]
def cvt_color(inp: Tensor, mode: str = ""):
r"""
Convert images from one format to another
:param inp: input images.
:param mode: format mode.
:return: convert result.
Examples:
.. testcode::
import numpy as np
import megengine as mge
import megengine.functional as F
x = mge.tensor(np.array([[[[-0.58675045, 1.7526233, 0.10702174]]]]).astype(np.float32))
y = F.img_proc.cvt_color(x, mode="RGB2GRAY")
print(y.numpy())
Outputs:
.. testoutput::
[[[[0.86555195]]]]
"""
assert mode in builtin.CvtColor.Mode.__dict__, "unspport mode for cvt_color"
mode = getattr(builtin.CvtColor.Mode, mode)
assert isinstance(mode, builtin.CvtColor.Mode)
op = builtin.CvtColor(mode=mode)
(out,) = apply(op, inp)
return out
......@@ -704,3 +704,14 @@ def test_argmxx_on_inf():
assert all(run_argmax() >= 0)
assert all(run_argmin() >= 0)
def test_cvt_color():
def rgb2gray(rgb):
return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])
inp = np.random.randn(3, 3, 3, 3).astype(np.float32)
out = np.expand_dims(rgb2gray(inp), 3).astype(np.float32)
x = tensor(inp)
y = F.img_proc.cvt_color(x, mode="RGB2GRAY")
np.testing.assert_allclose(y.numpy(), out, atol=1e-5)
/**
* \file imperative/src/impl/ops/img_proc.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/imgproc.h"
#include "../op_trait.h"
namespace mgb {
namespace imperative {
namespace {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const CvtColor&>(def);
mgb_assert(inputs.size() == 1);
return opr::CvtColor::make(inputs[0], op.param());
}
OP_TRAIT_REG(CvtColor, CvtColor)
.apply_on_var_node(apply_on_var_node)
.fallback();
}
}
}
\ No newline at end of file
......@@ -254,4 +254,6 @@ def TensorRTRuntime: MgbHashableOp<"TensorRTRuntime"> {
);
}
def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>;
#endif // MGB_OPS
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册