未验证 提交 075a02d2 编写于 作者: W Weilong Wu 提交者: GitHub

Fix _numel func logic and add test (#37810)

上级 a3b3ec68
......@@ -1963,10 +1963,6 @@ void BindImperative(py::module *m_ptr) {
.def("_numel",
[](std::shared_ptr<imperative::VarBase> &self) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
t->IsInitialized(), true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self->Name()));
return t->numel();
})
.def_property("name", &imperative::VarBase::Name,
......
......@@ -1279,7 +1279,7 @@ class TestVarBaseInitVarBaseFromTensorWithDevice(unittest.TestCase):
class TestVarBaseNumel(unittest.TestCase):
def test_numel(self):
def test_numel_normal(self):
paddle.disable_static()
np_x = np.random.random((3, 8, 8))
x = paddle.to_tensor(np_x, dtype="float64")
......@@ -1287,6 +1287,12 @@ class TestVarBaseNumel(unittest.TestCase):
x_expected_numel = np.product((3, 8, 8))
self.assertEqual(x_actual_numel, x_expected_numel)
def test_numel_without_holder(self):
paddle.disable_static()
x_without_holder = core.VarBase()
x_actual_numel = x_without_holder._numel()
self.assertEqual(x_actual_numel, 0)
class TestVarBaseCopyGradientFrom(unittest.TestCase):
def test_copy_gradient_from(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册