提交 e9982d61 编写于 作者: M Megvii Engine Team

chore(mge/functional): add compatible code for functional api

GitOrigin-RevId: 3b2f829cc5c01e017d2ef76823b75a3058611152
上级 8928c77c
......@@ -12,6 +12,7 @@ from .elemwise import *
from .math import *
from .nn import *
from .tensor import *
from .utils import *
from . import distributed # isort:skip
......
......@@ -19,6 +19,7 @@ from ..core.tensor.utils import astype
from ..device import get_default_device
from ..jit.tracing import is_tracing
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
__all__ = [
"abs",
......@@ -567,3 +568,10 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor:
return maximum(x, lower)
else:
return minimum(x, upper)
sigmoid = deprecated_func("1.3", "megengine.functional.nn", "sigmoid", True)
hsigmoid = deprecated_func("1.3", "megengine.functional.nn", "hsigmoid", True)
relu = deprecated_func("1.3", "megengine.functional.nn", "relu", True)
relu6 = deprecated_func("1.3", "megengine.functional.nn", "relu6", True)
hswish = deprecated_func("1.3", "megengine.functional.nn", "hswish", True)
......@@ -22,10 +22,11 @@ from ..device import get_default_device
from ..distributed import WORLD, is_distributed
from ..random import uniform
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
from .debug_param import get_execution_strategy
from .distributed import all_reduce_sum
from .elemwise import exp, floor, log, log1p, maximum, minimum
from .elemwise import _elwise, exp, floor, log, log1p, maximum, minimum
from .math import argsort, matmul, max, prod, sum
from .tensor import (
broadcast_to,
......@@ -70,6 +71,10 @@ __all__ = [
"relu",
"relu6",
"hswish",
"resize",
"remap",
"warp_affine",
"warp_perspective",
]
......@@ -1434,43 +1439,6 @@ def nvof(src: Tensor, precision: int = 1) -> Tensor:
return apply(op, src)[0]
def _elwise(*args, mode):
tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args))
if len(tensor_args) == 0:
dtype = utils.dtype_promotion(args)
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
args = utils.convert_inputs(first_arg, *args[1:])
else:
args = utils.convert_inputs(*args)
if mode in (
Elemwise.Mode.TRUE_DIV,
Elemwise.Mode.EXP,
Elemwise.Mode.POW,
Elemwise.Mode.LOG,
Elemwise.Mode.EXPM1,
Elemwise.Mode.LOG1P,
Elemwise.Mode.TANH,
Elemwise.Mode.ACOS,
Elemwise.Mode.ASIN,
Elemwise.Mode.ATAN2,
Elemwise.Mode.CEIL,
Elemwise.Mode.COS,
Elemwise.Mode.FLOOR,
Elemwise.Mode.H_SWISH,
Elemwise.Mode.ROUND,
Elemwise.Mode.SIGMOID,
Elemwise.Mode.SIN,
):
if mode in (
Elemwise.Mode.CEIL,
Elemwise.Mode.FLOOR,
Elemwise.Mode.ROUND,
) and np.issubdtype(args[0].dtype, np.integer):
return args[0]
args = tuple(map(lambda x: astype(x, "float32"), args))
return _elwise_apply(args, mode)
def hswish(x):
"""
Element-wise `x * relu6(x + 3) / 6`.
......@@ -1518,5 +1486,16 @@ def relu6(x):
return minimum(maximum(x, 0), 6)
interpolate = deprecated_func("1.3", "megengine.functional.vision", "interpolate", True)
roi_pooling = deprecated_func("1.3", "megengine.functional.vision", "roi_pooling", True)
roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", True)
nms = deprecated_func("1.3", "megengine.functional.vision", "nms", True)
resize = deprecated_func("1.3", "megengine.functional.vision", "resize", True)
remap = deprecated_func("1.3", "megengine.functional.vision", "remap", True)
warp_affine = deprecated_func("1.3", "megengine.functional.vision", "warp_affine", True)
warp_perspective = deprecated_func(
"1.3", "megengine.functional.vision", "warp_perspective", True
)
from .loss import * # isort:skip
from .quantized import conv_bias_activation # isort:skip
......@@ -10,8 +10,11 @@ from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.core2 import sync as _sync
from ..core.ops.builtin import AssertEqual
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
from .elemwise import abs, maximum, minimum
__all__ = ["topk_accuracy"]
def _assert_equal(
expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False
......@@ -55,3 +58,9 @@ def _assert_equal(
result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0]
_sync() # sync interpreter to get exception
return result
topk_accuracy = deprecated_func(
"1.3", "megengine.functional.metric", "topk_accuracy", True
)
copy = deprecated_func("1.3", "megengine.functional.tensor", "copy", True)
......@@ -5,4 +5,36 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import importlib
import warnings
from deprecated.sphinx import deprecated
def deprecated_func(version, origin, name, tbd):
"""
:param version: version to deprecate this function
:param origin: origin module path
:param name: function name
:param tbd: to be discussed, if true, ignore warnings
"""
should_warning = not tbd
def wrapper(*args, **kwargs):
nonlocal should_warning
module = importlib.import_module(origin)
func = module.__getattribute__(name)
if should_warning:
with warnings.catch_warnings():
warnings.simplefilter(action="always")
warnings.warn(
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format(
name, origin, name, version
),
category=DeprecationWarning,
stacklevel=2,
)
should_warning = False
return func(*args, **kwargs)
return wrapper
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册