From 717b88e68486e6601383ddf5c547fbb1c2078c3c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 12 May 2021 11:51:19 +0800 Subject: [PATCH] fix(mge/elemwise): fix problem that elemwise.mode is not comparable with string mode GitOrigin-RevId: 82e39be0a975cc72dfe1fe7c206be218c2ada131 --- imperative/python/src/ops.cpp | 18 ++++++++++++------ .../test/unit/functional/test_elemwise.py | 14 +++++--------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index 414edb85..c9d50775 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -191,10 +191,13 @@ struct EnumWrapper { .release().ptr(); } static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) { - T lhs = reinterpret_cast(self)->value, - rhs = reinterpret_cast(other)->value; if (op == Py_EQ || op == Py_NE) { - RETURN_RICHCOMPARE(lhs, rhs, op); + T lhs, rhs; + if (load(other, rhs) && load(self, lhs)) { + RETURN_RICHCOMPARE(lhs, rhs, op); + } else { + RETURN_RICHCOMPARE(0, 1, op); + } } Py_RETURN_NOTIMPLEMENTED; } @@ -296,10 +299,13 @@ struct BitCombinedEnumWrapper { return cast(lhs & rhs); } static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) { - T lhs = reinterpret_cast(self)->value, - rhs = reinterpret_cast(other)->value; if (op == Py_EQ || op == Py_NE) { - RETURN_RICHCOMPARE(lhs, rhs, op); + T lhs, rhs; + if (load(other, rhs) && load(self, lhs)) { + RETURN_RICHCOMPARE(lhs, rhs, op); + } else { + RETURN_RICHCOMPARE(0, 1, op); + } } Py_RETURN_NOTIMPLEMENTED; } diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index 3c6d43a5..09f1a6b9 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -12,7 +12,7 @@ import megengine.functional as F import megengine.functional.elemwise as elemwise from megengine import tensor from megengine.core.tensor import dtype -from megengine.functional.elemwise import _elwise +from megengine.functional.elemwise import Elemwise, _elwise def test_abs(): @@ -25,14 +25,10 @@ def test_abs(): def test_elemwise_mode_string(): - np.testing.assert_allclose( - _elwise(tensor([-3.0, -4.0, -5.0]), mode="ABS").numpy(), - np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)), - ) - - np.testing.assert_allclose( - _elwise(-3.0, mode="ABS").numpy(), np.abs(np.float32(-3.0)) - ) + for key, mode in vars(Elemwise.Mode).items(): + if isinstance(mode, Elemwise.Mode): + assert key == mode + assert Elemwise(mode=key) == Elemwise(mode=mode) def test_multiply(): -- GitLab