提交 5c7d48cd 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix tensor split

GitOrigin-RevId: 0a112ab0bdaa82202c50f7f7b9fe05248b22e415
上级 a240d558
......@@ -158,7 +158,7 @@ def div(x, y):
def floor_div(x, y):
"""Element-wise `floor(x / y)`."""
return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIVIDE)
return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIV)
def neg(x):
......
......@@ -28,7 +28,7 @@ from ..core.tensor.utils import (
)
from ..device import get_default_device
from ..tensor import Tensor
from .elemwise import ceil
from .elemwise import ceil, floor_div
__all__ = [
"arange",
......@@ -324,52 +324,73 @@ def split(inp, nsplits_or_sections, axis=0):
.. testcode::
import os
import numpy as np
from megengine import tensor
import megengine.functional as F
x = tensor(np.random.random((2,3,4,5)), dtype=np.float32)
out = F.split(x, 2, axis=3)
print(out[0].numpy().shape, out[1].numpy().shape)
x = tensor(np.random.random((10, 20)), dtype=np.float32)
y = F.split(x, 3)
z = F.split(x, [6, 17], axis=1)
if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"):
print([tuple(i.shape.numpy().tolist()) for i in y])
print([tuple(i.shape.numpy().tolist()) for i in z])
else:
print([i.shape for i in y])
print([i.shape for i in z])
Outputs:
.. testoutput::
(2, 3, 4, 3) (2, 3, 4, 2)
[(4, 20), (3, 20), (3, 20)]
[(10, 6), (10, 11), (10, 3)]
"""
sub_tensors = []
sections = []
ndim = len(inp.shape)
if axis >= ndim:
raise ValueError("Invalid axis {}".format(axis))
def swapaxis(inp, src, dst):
if src == dst:
return inp
shape = [i for i in range(inp.ndim)]
shape[src] = dst
shape[dst] = src
return inp.transpose(shape)
inp = swapaxis(inp, 0, axis)
if isinstance(nsplits_or_sections, int):
incr_step = ceil(inp.shape[0] / nsplits_or_sections)
nsplits = nsplits_or_sections
while nsplits > 0:
nsplits -= 1
sections.append(incr_step.astype("int32"))
incr_step += nsplits_or_sections
else:
sections = nsplits_or_sections
st = 0
for se in sections:
sub_tensors.append(swapaxis(inp[st:se], axis, 0))
st = se
Ntotal = inp.shape[axis]
if st < inp.shape[0]:
sub_tensors.append(swapaxis(inp[st:], axis, 0))
try:
Nsections = len(nsplits_or_sections) + 1
is_array = True
except TypeError:
Nsections = int(nsplits_or_sections)
is_array = False
if is_array:
div_points = [0] + list(nsplits_or_sections) + [Ntotal]
for i in range(1, len(div_points)):
if div_points[i - 1] >= div_points[i]:
raise ValueError(
"Invalid nsplits_or_secions: {}".format(nsplits_or_sections)
)
else: # scalar
if Nsections <= 0:
raise ValueError("Number sections must be larger than 0")
if Nsections > Ntotal:
raise ValueError(
"The size {} at dim {} cannot be split into {} sections".format(
Ntotal, axis, Nsections
)
)
div_points = [0] + [
floor_div(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
]
for i in range(2, Nsections + 1):
div_points[i] = div_points[i - 1] + div_points[i]
sub_tensors = []
for i in range(Nsections):
l = div_points[i]
r = div_points[i + 1]
slices = tuple(
[slice(None)] * axis + [slice(l, r)] + [slice(None)] * (ndim - axis - 1)
)
sub_tensors.append(inp[slices])
return sub_tensors
......
......@@ -77,14 +77,34 @@ def test_stack():
def test_split():
data = np.random.random((2, 3, 4, 5)).astype(np.float32)
mge_out1 = F.split(tensor(data), 2, axis=3)
mge_out2 = F.split(tensor(data), [3, 5], axis=3)
inp = tensor(data)
mge_out0 = F.split(inp, 2, axis=3)
mge_out1 = F.split(inp, [3], axis=3)
np_out = np.split(data, [3, 5], axis=3)
np.testing.assert_equal(mge_out1[0].numpy(), mge_out2[0].numpy())
assert len(mge_out0) == 2
assert len(mge_out1) == 2
np.testing.assert_equal(mge_out0[0].numpy(), np_out[0])
np.testing.assert_equal(mge_out1[0].numpy(), np_out[0])
np.testing.assert_equal(mge_out0[1].numpy(), np_out[1])
np.testing.assert_equal(mge_out1[1].numpy(), np_out[1])
try:
F.split(inp, 4)
assert False
except ValueError as e:
pass
try:
F.split(inp, [3, 3, 5], axis=3)
assert False
except ValueError as e:
assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"
def test_reshape():
x = np.arange(6, dtype="float32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册