提交 43fb2c4a 编写于 作者: M Megvii Engine Team

feat(opr): let roll support empty IO

GitOrigin-RevId: b9a59b623a8b16ca0a6af1340cfc226b73128321
上级 b2827cb1
...@@ -1352,10 +1352,11 @@ def roll( ...@@ -1352,10 +1352,11 @@ def roll(
if shift_ == 0: if shift_ == 0:
continue continue
size = shp[axis_normalized_] size = shp[axis_normalized_]
if shift_ > 0: shift_normalized_ = 0 if size == 0 else shift_ % size
a, b = split(out, [size - shift_,], axis=axis_normalized_) if shift_normalized_ > 0:
a, b = split(out, [size - shift_normalized_,], axis=axis_normalized_)
else: else:
a, b = split(out, [-shift_,], axis=axis_normalized_) a, b = split(out, [-shift_normalized_,], axis=axis_normalized_)
out = concat((b, a), axis=axis_normalized_) out = concat((b, a), axis=axis_normalized_)
if shp_bak is not None: if shp_bak is not None:
out = out.reshape(shp_bak) out = out.reshape(shp_bak)
......
...@@ -806,6 +806,8 @@ def test_tile(shape, reps, is_varnode): ...@@ -806,6 +806,8 @@ def test_tile(shape, reps, is_varnode):
[ [
((2, 3), 0, None), ((2, 3), 0, None),
((2, 3), 1, 0), ((2, 3), 1, 0),
((2, 3), 100, 0),
((2, 3), -100, 0),
((2, 3, 4, 5), (-1, 1), (0, 1)), ((2, 3, 4, 5), (-1, 1), (0, 1)),
((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)), ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
], ],
...@@ -829,3 +831,24 @@ def test_roll(shape, shifts, axis, is_varnode): ...@@ -829,3 +831,24 @@ def test_roll(shape, shifts, axis, is_varnode):
opr_test( opr_test(
cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network
) )
@pytest.mark.parametrize(
"shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1),],
)
@pytest.mark.parametrize("is_symbolic", [None, True, False])
def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
inp = tensor(np.random.randn(*shape).astype("float32"))
def func(inp):
return F.roll(inp, shifts, axis)
if is_symbolic is not None:
func = trace(symbolic=is_symbolic)(func)
out_ref = np.roll(inp.numpy(), shifts, axis)
for _ in range(3):
out = F.roll(inp, shifts, axis)
np.testing.assert_equal(out.numpy(), out_ref)
if is_symbolic is None:
break
...@@ -1339,8 +1339,10 @@ void Concat::scn_do_execute() { ...@@ -1339,8 +1339,10 @@ void Concat::scn_do_execute() {
if (real_axis < 0) if (real_axis < 0)
real_axis += in.shape().ndim; real_axis += in.shape().ndim;
end = begin + in.shape().shape[real_axis]; end = begin + in.shape().shape[real_axis];
out.sub(Slice(begin, end).apply(out.layout(), real_axis)). if (!in.layout().is_empty()) {
copy_from_fixlayout(in); out.sub(Slice(begin, end).apply(out.layout(), real_axis)).
copy_from_fixlayout(in);
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册