提交 cdbb4a20 编写于 作者: M Megvii Engine Team

fix(mge/tensor): fix tensor's serialization behavior

GitOrigin-RevId: 4d74a4b46e6367ce3b17fa3688949d5b707779e8
上级 9da26407
......@@ -55,7 +55,7 @@ def _get_callable_map_location(map_location):
if map_location is None:
def callable_map_location(state):
return str(get_default_device())
return state
elif isinstance(map_location, str):
......
......@@ -28,6 +28,13 @@ logger = get_logger(__name__)
class Tensor(_Tensor, ArrayMethodMixin):
r"""
A tensor object represents a multidimensional, homogeneous array of fixed-size items.
:param data: The value of returned Tensor.
:param dtype: The dtype of returned Tensor. Uses data's dtype if not specified.
:param device: The desired device of returned Tensor. Uses :func:`get_default_device` if not specified.
:param is_const: Whether to make it an ``ImmutableTensor`` in tracing mode.
:param no_cache: Whether cache it for memory sharing.
:param name: Used to improve convenience in graph operation on dumped model.
"""
grad = None
......@@ -35,8 +42,16 @@ class Tensor(_Tensor, ArrayMethodMixin):
_qparams = None
def __new__(
cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=None
cls,
data: Union["Tensor", np.ndarray, list, "scalar"] = None,
dtype: np.dtype = None,
device: str = None,
is_const: bool = False,
no_cache: bool = False,
name: str = None,
):
if data is None:
data = []
if device is None:
cn = get_default_device()
elif isinstance(device, str):
......@@ -59,13 +74,24 @@ class Tensor(_Tensor, ArrayMethodMixin):
obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name)
return obj
def __init__(
    self,
    data: Union["Tensor", np.ndarray, list, "scalar"],
    dtype: np.dtype = None,
    device: str = None,
    is_const: bool = False,
    no_cache: bool = False,
    name: str = None,
):
    """Intentionally do nothing.

    The instance is built in ``__new__``, which receives the same
    arguments; this method only declares the annotated signature.
    """
@property
def shape(self) -> Union[tuple, "Tensor"]:
r"""
Returns a :class:`tuple` or a :class:`~.Tensor` that represents the tensor dimensions.
.. note::
The shape of a tensor is usually represented by a :class:`tuple`.
But if a tensor is treated as a symbolic placeholder with tracing,
its shape could also be a :class:`~.Tensor`. See :class:`~.trace` for more details.
......@@ -100,6 +126,9 @@ class Tensor(_Tensor, ArrayMethodMixin):
@property
def qparams(self):
r"""
Returns a :class:`~.QParams` object containing quantization params of a :class:`~.Tensor`.
"""
from .quantization.utils import create_qparams # pylint: disable=all
if self._qparams is None:
......@@ -185,18 +214,20 @@ class Tensor(_Tensor, ArrayMethodMixin):
def __getstate__(self):
    r"""Return the extra state used for pickle serialization or deep copy.

    :return: a :class:`dict` that is empty unless quantization parameters
        are attached to this tensor, in which case they are stored under
        the ``"qparams"`` key.
    """
    # Value, dtype and device are deliberately not stored here: building a
    # dict from ``self.numpy()`` only to overwrite it was wasted work.
    # NOTE(review): they are presumably carried by ``__getnewargs__`` instead
    # -- confirm against the serialization tests.
    state = {}
    if self._qparams is not None:
        state["qparams"] = self._qparams
    return state
def __setstate__(self, state):
from .quantization.utils import create_qparams # pylint: disable=all
# for compatibility with old version not using fastcore
if "data" in state:
data = state.pop("data")
device = state.pop("device")
dtype = state.pop("dtype")
self._reset(Tensor(data, dtype=dtype, device=device))
# quantize related state for deepcopy
if "qdict" in state:
qparams = state.pop("qdict")
logger.warning(
......@@ -206,7 +237,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
qparams = state.pop("qparams")
else:
qparams = None
self._reset(Tensor(state.pop("numpy"), state.pop("dtype"), state.pop("device")))
self._qparams = qparams
......
......@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import pickle
from tempfile import TemporaryFile
......@@ -18,25 +19,27 @@ from megengine import Parameter, Tensor
def test_tensor_serialization():
with TemporaryFile() as f:
data = np.random.randint(low=0, high=7, size=[233])
a = Tensor(data, device="xpux", dtype=np.int32)
pickle.dump(a, f)
a = Tensor(data, device="cpu0", dtype=np.int32)
mge.save(a, f)
f.seek(0)
b = pickle.load(f)
np.testing.assert_equal(a.numpy(), b.numpy())
b = mge.load(f)
np.testing.assert_equal(a.numpy(), data)
assert b.device.logical_name == "cpu0:0"
assert b.dtype == np.int32
with TemporaryFile() as f:
a = Parameter(np.random.random(size=(233, 2)).astype(np.float32))
pickle.dump(a, f)
mge.save(a, f)
f.seek(0)
b = pickle.load(f)
b = mge.load(f)
assert isinstance(b, Parameter)
np.testing.assert_equal(a.numpy(), b.numpy())
with TemporaryFile() as f:
a = Tensor(np.random.random(size=(2, 233)).astype(np.float32))
pickle.dump(a, f)
mge.save(a, f)
f.seek(0)
b = pickle.load(f)
b = mge.load(f)
assert type(b) is Tensor
np.testing.assert_equal(a.numpy(), b.numpy())
......@@ -66,8 +69,20 @@ def test_tensor_serialization():
with TemporaryFile() as f:
a = Tensor(0)
a.qparams.scale = Tensor(1.0)
pickle.dump(a, f)
mge.save(a, f)
f.seek(0)
b = pickle.load(f)
b = mge.load(f)
assert isinstance(b.qparams.scale, Tensor)
np.testing.assert_equal(b.qparams.scale.numpy(), 1.0)
def test_compatibility():
    """Load previously saved tensor files and check value, device and dtype."""
    # Fixture files live next to this test module.
    fixture_dir = os.path.dirname(__file__)

    def _check_old_tensor(model_name):
        # NOTE(review): fixtures were presumably dumped by older releases
        # (1.1 / 1.2, judging by the file names) -- confirm.
        loaded = mge.load(os.path.join(fixture_dir, model_name))
        assert np.all(loaded.numpy() == [1, 2, 3])
        assert loaded.device.logical_name == "cpu0:0"
        assert loaded.dtype == np.int8

    for fixture_name in ("tensor_v1_1.mge", "tensor_v1_2.mge"):
        _check_old_tensor(fixture_name)
......@@ -98,6 +98,20 @@ def test_as_type():
np.testing.assert_equal(get_zero_point(b.dtype), 128)
def test_serialization():
    """Check what ``__getnewargs__`` and ``__getstate__`` expose for a tensor."""
    tensor = Tensor([1, 2, 3], dtype=np.float32)

    # Constructor arguments carry the value, dtype and device, in that order.
    ctor_args = tensor.__getnewargs__()
    assert np.all(ctor_args[0] == tensor.numpy())
    assert ctor_args[1] == tensor.dtype
    assert ctor_args[2] == tensor.device.logical_name

    # A fresh tensor has no extra pickling state.
    state = tensor.__getstate__()
    assert not state

    # Reading the property is done for its side effect: afterwards the
    # state is expected to contain exactly the qparams entry.
    tensor.qparams
    state = tensor.__getstate__()
    assert len(state.keys()) == 1
    assert state["qparams"] == tensor.qparams
def test_qparams():
x = Tensor(1)
assert x.qparams.scale is None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册