From 9f4bffbd00590bcf09655c114b97f17d21b7353e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 12 Oct 2020 17:50:32 +0800 Subject: [PATCH] fix(mge/tensor): fix valid_broadcast GitOrigin-RevId: 562b7664e23cd336d942568203df03958b67a4b7 --- .../python/megengine/core/tensor/indexing.py | 2 +- .../megengine/core/tensor/tensor_wrapper.py | 6 +++--- imperative/python/test/unit/test_tracing.py | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/imperative/python/megengine/core/tensor/indexing.py b/imperative/python/megengine/core/tensor/indexing.py index e4bd8377d..40a6f1aba 100644 --- a/imperative/python/megengine/core/tensor/indexing.py +++ b/imperative/python/megengine/core/tensor/indexing.py @@ -173,7 +173,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): item.append(True) v = get_index(v) assert np.issubdtype(v.dtype, np.integer) or np.issubdtype( - v.dtype, np.bool + v.dtype, np.bool_ ), "var type in the subscript must be int or bool" tensors.append(v) diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index 47c54441d..6c8a32771 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -65,10 +65,10 @@ def _broadcast(inp, shape): ) ) - if isinstance(src, (Tensor, TensorWrapperBase)): + if isinstance(src, (TensorBase, TensorWrapperBase)): src = src.numpy() - if isinstance(tar, (Tensor, TensorWrapperBase)): + if isinstance(tar, (TensorBase, TensorWrapperBase)): tar = tar.numpy() if len(src) > len(tar): @@ -78,8 +78,8 @@ def _broadcast(inp, shape): if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]: failed() - valid_broadcast(inp.shape, shape) shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) + valid_broadcast(inp.shape, shape) (result,) = apply(builtin.Broadcast(), inp, shape) return result diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index 805d01216..bca796a32 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -379,3 +379,18 @@ def test_trace_nms(): f(*make_inputs(10)) f(*make_inputs(20)) f(*make_inputs(30)) + + +def test_trace_valid_broadcast(): + set_tensor_shape(True) + x1 = tensor(np.random.randn(1, 1)) + x2 = tensor(np.random.randn(1, 2)) + shape = (tensor([2]), tensor([2])) + + @trace(symbolic=False) + def f(x, shape): + y = F.broadcast_to(x, shape) + return y + + f(x1, shape) + f(x2, shape) -- GitLab