提交 1040b778 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(mge/functional): fix F.topk(kth_only=True)

GitOrigin-RevId: ddecd1d14b62b43b2934fa2671562b04495ff5b4
上级 551cc701
......@@ -673,7 +673,7 @@ def topk(
:param descending: if True, return the largest elements instead. Default: False
:param kth_only: if True, only the k-th element will be returned. Default: False
:param no_sort: if True, the returned elements can be unordered. Default: False
:return: tuple of two tensors `(topk_tensor, indices_of_int32)`.
:return: tuple of two tensors ``(topk_tensor, indices_of_int32)``
Examples:
......@@ -695,7 +695,7 @@ def topk(
"""
if descending:
inp = -inp
k = -k
if kth_only:
mode = "kth_only"
......@@ -709,21 +709,25 @@ def topk(
(k,) = Const(k, dtype="int32", device=inp.device)()
if len(inp.shape) == 1:
inp = inp.reshape(1, -1)
res = apply(op, inp, k)
if kth_only:
tns = res[0]
(tns,) = apply(op, expand_dims(inp, 0), k)
# FIXME:
# could use a dedicated kernel
# gradient may be routed to other indices if k-th value is not unique
ind = argmax((tns == inp).astype("int8"))
tns = squeeze(tns, 0)
else:
tns, ind = res[0][0], res[1][0]
tns, ind = apply(op, expand_dims(inp, 0), k)
tns = squeeze(tns, 0)
ind = squeeze(ind, 0)
else:
res = apply(op, inp, k)
if kth_only:
tns = res
(tns,) = apply(op, inp, k)
# FIXME: same as above
ind = argmax((expand_dims(tns, 1) == inp).astype("int8"), 1)
else:
tns, ind = res[0], res[1]
tns, ind = apply(op, inp, k)
if descending:
tns = -tns
return tns, ind
......
......@@ -168,3 +168,39 @@ def test_has_inf():
data[0][0][0][0] = float("inf")
rst = F.math._has_inf(tensor(data))
np.testing.assert_equal(rst.numpy(), [1])
@pytest.mark.parametrize("descending", [True, False])
@pytest.mark.parametrize("sorted", [True, False])
@pytest.mark.parametrize("inp1d", [True, False])
@pytest.mark.parametrize("kth_only", [True, False])
def test_topk(descending, sorted, inp1d, kth_only):
k = 3
if inp1d:
data = np.random.permutation(7)
else:
data = np.random.permutation(5 * 7).reshape(5, 7)
data = data.astype(np.int32)
def np_sort(x):
if descending:
return np.sort(x)[..., ::-1]
return np.sort(x)
res = F.topk(
tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only
)
values, indices = res
values = values.numpy()
indices = indices.numpy()
if kth_only:
np.testing.assert_equal(
values, np.take_along_axis(data, indices[..., None], -1).squeeze(-1)
)
np.testing.assert_equal(values, np_sort(data)[..., k - 1])
else:
np.testing.assert_equal(values, np.take_along_axis(data, indices, -1))
if not sorted:
values = np_sort(values)
np.testing.assert_equal(values, np_sort(data)[..., :k])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册