提交 4566cc21 编写于 作者: M Megvii Engine Team

fix(mge/utils): fix bug of VarNode inplace operations

GitOrigin-RevId: fa9eec7079671a117809c3da8ae7338e12f345f0
上级 cf892ec0
...@@ -6,10 +6,9 @@ ...@@ -6,10 +6,9 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import json import json
import sys import sys
from typing import Callable, Sequence from typing import Sequence
import numpy as np import numpy as np
...@@ -19,10 +18,7 @@ from ..core._trace_option import use_symbolic_shape ...@@ -19,10 +18,7 @@ from ..core._trace_option import use_symbolic_shape
from ..core._wrap import Device from ..core._wrap import Device
from ..core.ops import builtin from ..core.ops import builtin
from ..core.tensor.array_method import ArrayMethodMixin from ..core.tensor.array_method import ArrayMethodMixin
from ..core.tensor.indexing import getitem as _getitem from ..core.tensor.megbrain_graph import OutputNode
from ..core.tensor.indexing import setitem as _setitem
from ..core.tensor.megbrain_graph import InputNode, OutputNode
from ..tensor import Tensor
from .comp_graph_tools import replace_vars from .comp_graph_tools import replace_vars
from .module_stats import ( from .module_stats import (
preprocess_receptive_field, preprocess_receptive_field,
...@@ -110,18 +106,18 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): ...@@ -110,18 +106,18 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
self.graph.compile(o.outputs).execute() self.graph.compile(o.outputs).execute()
return o.get_value().numpy() return o.get_value().numpy()
def __getitem__(self, index): def _reset(self, other):
return _getitem(self, index) if not isinstance(other, VarNode):
assert self.graph, "VarNode _reset must have graph"
def __setitem__(self, index, value): node = ImmutableTensor(other, graph=self.graph)
if index is not Ellipsis: node.compile(self.graph)
value = _setitem(self, index, value) other = node.outputs[0]
if self.owner is not None: if self.owner is not None:
idx = self.owner.outputs.index(self) idx = self.owner.outputs.index(self)
self.owner.outputs[idx] = VarNode( self.owner.outputs[idx] = VarNode(
self.var, owner_opr=self.owner, name=self.var.name self.var, owner_opr=self.owner, name=self.var.name
) )
self.var = value.var self.var = other.var
self.owner = None self.owner = None
def set_owner_opr(self, owner_opr): def set_owner_opr(self, owner_opr):
......
...@@ -9,38 +9,81 @@ ...@@ -9,38 +9,81 @@
import copy import copy
import numpy as np import numpy as np
import pytest
from utils import make_tensor
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.tensor import Tensor from megengine.tensor import Tensor
from megengine.utils.network import Network
def test_basic(): @pytest.mark.parametrize("is_varnode", [True, False])
def test_basic(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = Tensor(x_np) x = make_tensor(x_np, network)
y = x * x y = x * x
y_np = y.numpy() y_np = y.numpy()
np.testing.assert_almost_equal(y_np, x_np * x_np) np.testing.assert_almost_equal(y_np, x_np * x_np)
def test_literal_arith(): @pytest.mark.parametrize("is_varnode", [True, False])
def test_literal_arith(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = Tensor(x_np) x = make_tensor(x_np, network)
y = x * 2 y = x * 2
y_np = y.numpy() y_np = y.numpy()
np.testing.assert_almost_equal(y_np, x_np * 2) np.testing.assert_almost_equal(y_np, x_np * 2)
def test_matmul(): @pytest.mark.parametrize("is_varnode", [True, False])
A = Tensor(np.random.rand(5, 7).astype("float32")) def test_matmul(is_varnode):
B = Tensor(np.random.rand(7, 10).astype("float32")) if is_varnode:
network = Network()
else:
network = None
A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
C = A @ B C = A @ B
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6)
def test_reduce(): @pytest.mark.parametrize("is_varnode", [True, False])
def test_inplace_add(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x_np = np.random.rand(10).astype("float32")
y_np = np.random.rand(10).astype("float32")
x = make_tensor(x_np, network)
y = make_tensor(y_np, network)
y += x
out_np = y.numpy()
np.testing.assert_almost_equal(out_np, x_np + y_np)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reduce(is_varnode):
if is_varnode:
network = Network()
else:
network = None
def test_x(x_np): def test_x(x_np):
for m in ["sum", "prod", "min", "max", "mean"]: for m in ["sum", "prod", "min", "max", "mean"]:
x = Tensor(x_np) x = make_tensor(x_np, network)
y = getattr(x, m)(axis=-1, keepdims=True) y = getattr(x, m)(axis=-1, keepdims=True)
np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6) np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)
...@@ -50,16 +93,28 @@ def test_reduce(): ...@@ -50,16 +93,28 @@ def test_reduce():
test_x(np.array([True, False, True])) test_x(np.array([True, False, True]))
def test_set_value(): @pytest.mark.parametrize("is_varnode", [True, False])
def test_set_value(is_varnode):
if is_varnode:
network = Network()
else:
network = None
v0 = np.random.random((2, 3)).astype(np.float32) v0 = np.random.random((2, 3)).astype(np.float32)
param = Tensor(v0) param = make_tensor(v0, network)
v1 = np.random.random((2, 3)).astype(np.float32) v1 = np.random.random((2, 3)).astype(np.float32)
param[...] = v1 param[...] = v1
np.testing.assert_allclose(param.numpy(), v1, atol=5e-6) np.testing.assert_allclose(param.numpy(), v1, atol=5e-6)
def test_set_subtensor(): @pytest.mark.parametrize("is_varnode", [True, False])
x = Tensor([1, 2, 3]) def test_set_subtensor(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = make_tensor([1, 2, 3], network)
x[:] = [1, 1, 1] x[:] = [1, 1, 1]
np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6) np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6)
x[[0, 2]] = [3, 2] x[[0, 2]] = [3, 2]
...@@ -78,14 +133,27 @@ def test_computing_with_numpy_array(): ...@@ -78,14 +133,27 @@ def test_computing_with_numpy_array():
np.testing.assert_equal(np.equal(xx, xx).numpy(), np.equal(x, x)) np.testing.assert_equal(np.equal(xx, xx).numpy(), np.equal(x, x))
def test_transpose(): @pytest.mark.parametrize("is_varnode", [True, False])
def test_transpose(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = np.random.rand(2, 5).astype("float32") x = np.random.rand(2, 5).astype("float32")
xx = Tensor(x) xx = make_tensor(x, network)
np.testing.assert_almost_equal(xx.T.numpy(), x.T) np.testing.assert_almost_equal(xx.T.numpy(), x.T)
def test_as_type(): @pytest.mark.parametrize("is_varnode", [True, False])
x = Tensor([1, 2, 3], dtype=np.float32) def test_as_type(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x_np = np.array([1, 2, 3], dtype=np.float32)
x = make_tensor(x_np, network)
y = x.astype(qint8(0.1)) y = x.astype(qint8(0.1))
np.testing.assert_almost_equal(get_scale(y.dtype), 0.1) np.testing.assert_almost_equal(get_scale(y.dtype), 0.1)
z = y.astype(qint8(0.2)) z = y.astype(qint8(0.2))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册