提交 938152af 编写于 作者: M Megvii Engine Team

fix(mge/functional): convert input type to float32 for more elemwise op

GitOrigin-RevId: cf3bf8cb805a3229700dd2939393a3994bc59f35
上级 19466046
......@@ -27,9 +27,31 @@ from .utils import setscalar
_ElwMod = Elemwise.Mode
def _elwise(*args, mode):
def _elwise_apply(args, mode):
op = builtin.Elemwise(mode)
if mode in (_ElwMod.TRUE_DIV, _ElwMod.POW):
_isscalar = True
for i in args:
if isscalar(i) == False:
_isscalar = False
break
(result,) = apply(op, *args)
if _isscalar:
setscalar(result)
return result
def _elwise(*args, mode):
if mode in (
_ElwMod.TRUE_DIV,
_ElwMod.POW,
_ElwMod.CEIL,
_ElwMod.FLOOR,
_ElwMod.ROUND,
):
if mode in (_ElwMod.CEIL, _ElwMod.FLOOR, _ElwMod.ROUND) and np.issubdtype(
args[0].dtype, np.integer
):
return args[0]
args = tuple(
map(
lambda x: x.astype("float32")
......@@ -39,16 +61,7 @@ def _elwise(*args, mode):
)
)
args = utils.convert_inputs(*args)
(result,) = apply(op, *args)
_isscalar = True
for i in args:
if isscalar(i) == False:
_isscalar = False
break
if _isscalar:
setscalar(result)
return result
return _elwise_apply(args, mode)
def _matmul(inp1, inp2):
......
......@@ -9,10 +9,13 @@
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
import functools
import numpy as np
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.ops.builtin import Elemwise
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..jit.tracing import is_tracing
......@@ -74,7 +77,6 @@ __all__ = [
def _elwise(*args, mode):
op = builtin.Elemwise(mode)
tensor_args = list(
filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args)
)
......@@ -84,17 +86,33 @@ def _elwise(*args, mode):
args = utils.convert_inputs(first_arg, *args[1:])
else:
args = utils.convert_inputs(*args)
if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"):
if mode in (
Elemwise.Mode.TRUE_DIV,
Elemwise.Mode.EXP,
Elemwise.Mode.POW,
Elemwise.Mode.LOG,
Elemwise.Mode.EXPM1,
Elemwise.Mode.LOG1P,
Elemwise.Mode.TANH,
Elemwise.Mode.ACOS,
Elemwise.Mode.ASIN,
Elemwise.Mode.ATAN2,
Elemwise.Mode.CEIL,
Elemwise.Mode.COS,
Elemwise.Mode.FLOOR,
Elemwise.Mode.H_SWISH,
Elemwise.Mode.ROUND,
Elemwise.Mode.SIGMOID,
Elemwise.Mode.SIN,
):
if mode in (
Elemwise.Mode.CEIL,
Elemwise.Mode.FLOOR,
Elemwise.Mode.ROUND,
) and np.issubdtype(args[0].dtype, np.integer):
return args[0]
args = tuple(map(lambda x: x.astype("float32"), args))
_isscalar = True
for i in args:
if isscalar(i) == False:
_isscalar = False
break
(result,) = apply(op, *args)
if _isscalar:
setscalar(result)
return result
return _elwise_apply(args, mode)
def _elemwise_multi_type(*args, mode, **kwargs):
......
......@@ -9,6 +9,7 @@
import numpy as np
import megengine.functional as F
import megengine.functional.elemwise as elemwise
from megengine import tensor
from megengine.core.tensor import dtype
from megengine.functional.elemwise import _elwise
......@@ -166,3 +167,20 @@ def test_qadd():
result_mge = result_mge.astype("float32").numpy()
result_expect = x.astype("float32").numpy() + y.astype("float32").numpy()
np.testing.assert_almost_equal(result_mge, result_expect, decimal=6)
def test_int32_input():
x = tensor(np.array([1, 2, 3, 4, 5]), dtype="int32")
for op_name in elemwise.__all__:
op = getattr(elemwise, op_name)
nargs = op.__code__.co_argcount
if op_name == "clip":
inp = (x, 0, 1)
elif op_name.endswith("_shift"):
inp = (x, 1)
elif op_name.startswith("logical_"):
continue
else:
inp = (x,) * nargs
y = op(*inp)
y.numpy()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册