提交 0c37a588 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix F.ones when input is a tensor of scalar type

GitOrigin-RevId: 6d01d6b58d0445b42cc3f3e5f137ebd590af31a4
上级 b0944dc7
......@@ -108,7 +108,7 @@ def full(shape, value, dtype="float32", device=None):
if device is None:
device = get_default_device()
(x,) = Const(value, dtype=dtype, device=device)()
if len(shape) == 0: # scalar
if shape is (): # scalar.shape
return x
return broadcast_to(x, shape)
......
......@@ -739,3 +739,10 @@ def test_cvt_color():
x = tensor(inp)
y = F.img_proc.cvt_color(x, mode="RGB2GRAY")
np.testing.assert_allclose(y.numpy(), out, atol=1e-5)
@pytest.mark.parametrize("val", [2, [2,], [2, 3]])
def test_ones(val):
shp = tensor(val)
np_shp = np.array(val)
np.testing.assert_equal(F.ones(shp), np.ones(np_shp))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册