提交 5431929e 编写于 作者: M Megvii Engine Team

feat(functional): let advance indexing support empty tensor and add more tests

GitOrigin-RevId: 49e1492934813caf4e491a901610b95439bac236
上级 703b783c
......@@ -176,6 +176,8 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
def is_bool_list(x):
if not isinstance(x, list):
return False
if len(x) == 0:
return False
for i in x:
if not isinstance(i, bool):
return False
......@@ -246,17 +248,6 @@ def getitem(tensor, index):
if len(try_result) == 2:
return try_result[0]
tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
for v in tensors:
if v.shape is None:
break
if isinstance(v.shape, v.__class__):
break
if len(v.shape) > 0 and v.shape[0] == 0:
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)(
tensor
)
return empty_tensor
if use_subtensor:
op = builtin.Subtensor(items=items)
else:
......
......@@ -610,6 +610,25 @@ def test_subtensor_on_empty_tensor(symbolic):
run_test(lambda x: x[100:200, 300:400, 500:600])
@pytest.mark.parametrize("symbolic", [True, False, None])
def test_indexingMultiAxisVec_on_empty_tensor(symbolic):
np_x = np.array([], dtype=np.float32).reshape(10, 10, 0)
mge_x = megengine.tensor(np_x)
def run_test(fn):
out_ref = fn(np_x)
if symbolic is not None:
fn = jit.trace(symbolic=symbolic)(fn)
for i in range(3):
out = fn(mge_x)
np.testing.assert_equal(out.numpy(), out_ref)
run_test(lambda x: x[[1, 2, 3]])
run_test(lambda x: x[[1, 2, 3], [4, 5, 6]])
run_test(lambda x: x[[]])
run_test(lambda x: x[[], [], []])
@pytest.mark.parametrize("symbolic", [True, False, None])
def test_setsubtensor_on_empty_tensor(symbolic):
def run_test(inp_shp, fn):
......@@ -655,3 +674,39 @@ def test_setsubtensor_on_empty_tensor(symbolic):
run_test((10, 10, 10), test4)
run_test((10, 10, 10), test5)
run_test((10, 10, 10), test6)
@pytest.mark.parametrize("symbolic", [True, False, None])
def test_indexingSetMultiAxisVec_on_empty_tensor(symbolic):
def run_test(inp_shp, fn):
np_x = np.random.randn(*inp_shp).astype(np.float32)
mge_x = megengine.tensor(np_x)
out_ref = fn(np_x)
if symbolic is not None:
fn = jit.trace(symbolic=symbolic)(fn)
for i in range(3):
out = fn(mge_x)
np.testing.assert_equal(out.numpy(), out_ref)
def test1(x):
x[[1, 2, 3]] = x[[1, 2, 3]]
return x
def test2(x):
x[[1, 2, 3], [1, 2, 3]] = x[[1, 2, 3], [1, 2, 3]]
return x
def test3(x):
x[[]] = x[[]]
return x
def test4(x):
x[[], [], []] = x[[], [], []]
return x
run_test((10, 10, 0), test1)
run_test((10, 10, 0), test2)
run_test((10, 10, 0), test3)
run_test((10, 10, 0), test4)
run_test((10, 10, 10), test3)
run_test((10, 10, 10), test4)
......@@ -860,8 +860,8 @@ def test_condtake():
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
# @pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_condtake(is_symbolic=None):
@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_condtake(is_symbolic):
shapes = [
(3, 3, 3),
(0,),
......
......@@ -292,8 +292,6 @@ cg::OperatorNodeBase::NodeProp*
IndexingMultiAxisVecBase<Opr>::do_make_node_prop() const {
auto prop = Super::do_make_node_prop();
using DT = NodeProp::DepType;
// TODO: should also allow input shape is empty if any
// indexer's shape is empty
prop->add_dep_type_existing_var(input(0), DT::VALUE_ALLOW_EMPTY);
for (auto i: m_input2idxonly_axis_indexer) {
if (i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册