From 249382899bfa4244fe69fdf7cf0faf54a1799e03 Mon Sep 17 00:00:00 2001 From: Wei Luning Date: Wed, 9 Sep 2020 14:31:44 +0800 Subject: [PATCH] fix bug in cell pickle and copy --- mindspore/ccsrc/pybind_api/ir/cell_py.cc | 15 ++++++++++++++- tests/ut/python/nn/test_cell.py | 6 ++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/pybind_api/ir/cell_py.cc b/mindspore/ccsrc/pybind_api/ir/cell_py.cc index efd80209a..6ec01e902 100644 --- a/mindspore/ccsrc/pybind_api/ir/cell_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/cell_py.cc @@ -45,6 +45,19 @@ REGISTER_PYBIND_DEFINE(Cell, ([](const py::module *m) { .def("_del_attr", &Cell::DelAttr, "Delete Cell attr.") .def( "construct", []() { MS_LOG(EXCEPTION) << "we should define `construct` for all `cell`."; }, - "construct"); + "construct") + .def(py::pickle( + [](const Cell &cell) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::str(cell.name())); + }, + [](const py::tuple &tup) { // __setstate__ + if (tup.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + Cell data(tup[0].cast()); + return data; + })); })); } // namespace mindspore diff --git a/tests/ut/python/nn/test_cell.py b/tests/ut/python/nn/test_cell.py index 30066ee85..0c4668403 100644 --- a/tests/ut/python/nn/test_cell.py +++ b/tests/ut/python/nn/test_cell.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """ test cell """ +import copy import numpy as np import pytest @@ -200,6 +201,11 @@ def test_exceptions(): m.construct() +def test_cell_copy(): + net = ConvNet() + copy.deepcopy(net) + + def test_del(): """ test_del """ ta = Tensor(np.ones([2, 3])) -- GitLab