Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
cdbb4a20
MegEngine
项目概览
MegEngine 天元
/
MegEngine
9 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
cdbb4a20
编写于
3月 12, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/tensor): fix tensor's serialization behavior
GitOrigin-RevId: 4d74a4b46e6367ce3b17fa3688949d5b707779e8
上级
9da26407
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
80 addition
and
21 deletion
+80
-21
imperative/python/megengine/serialization.py
imperative/python/megengine/serialization.py
+1
-1
imperative/python/megengine/tensor.py
imperative/python/megengine/tensor.py
+40
-10
imperative/python/test/unit/core/tensor_v1_1.mge
imperative/python/test/unit/core/tensor_v1_1.mge
+0
-0
imperative/python/test/unit/core/tensor_v1_2.mge
imperative/python/test/unit/core/tensor_v1_2.mge
+0
-0
imperative/python/test/unit/core/test_serialization.py
imperative/python/test/unit/core/test_serialization.py
+25
-10
imperative/python/test/unit/core/test_tensor_wrapper.py
imperative/python/test/unit/core/test_tensor_wrapper.py
+14
-0
未找到文件。
imperative/python/megengine/serialization.py
浏览文件 @
cdbb4a20
...
...
@@ -55,7 +55,7 @@ def _get_callable_map_location(map_location):
if
map_location
is
None
:
def
callable_map_location
(
state
):
return
st
r
(
get_default_device
())
return
st
ate
elif
isinstance
(
map_location
,
str
):
...
...
imperative/python/megengine/tensor.py
浏览文件 @
cdbb4a20
...
...
@@ -28,6 +28,13 @@ logger = get_logger(__name__)
class
Tensor
(
_Tensor
,
ArrayMethodMixin
):
r
"""
A tensor object represents a multidimensional, homogeneous array of fixed-size items.
:param data: The value of returned Tensor.
:param dtype: The dtype of returned Tensor. Uses data's dtype if not specified.
:param device: The desired device of returned Tensor. Uses :func:`get_default_device` if not specified.
:param is_const: Whether make it a ``ImutableTensor`` in tracing mode.
:param no_cache: Whether cache it for memory sharing.
:param name: Used to improve convenience in graph operation on dumped model.
"""
grad
=
None
...
...
@@ -35,8 +42,16 @@ class Tensor(_Tensor, ArrayMethodMixin):
_qparams
=
None
def
__new__
(
cls
,
data
,
dtype
=
None
,
device
=
None
,
is_const
=
False
,
no_cache
=
False
,
name
=
None
cls
,
data
:
Union
[
"Tensor"
,
np
.
ndarray
,
list
,
"scalar"
]
=
None
,
dtype
:
np
.
dtype
=
None
,
device
:
str
=
None
,
is_const
:
bool
=
False
,
no_cache
:
bool
=
False
,
name
:
str
=
None
,
):
if
data
is
None
:
data
=
[]
if
device
is
None
:
cn
=
get_default_device
()
elif
isinstance
(
device
,
str
):
...
...
@@ -59,13 +74,24 @@ class Tensor(_Tensor, ArrayMethodMixin):
obj
=
_Tensor
.
__new__
(
cls
,
data
,
dtype
,
cn
,
is_const
,
no_cache
,
name
)
return
obj
def
__init__
(
self
,
data
:
Union
[
"Tensor"
,
np
.
ndarray
,
list
,
"scalar"
],
dtype
:
np
.
dtype
=
None
,
device
:
str
=
None
,
is_const
:
bool
=
False
,
no_cache
:
bool
=
False
,
name
:
str
=
None
,
):
pass
@
property
def
shape
(
self
)
->
Union
[
tuple
,
"Tensor"
]:
r
"""
Returns a :class:`tuple` or a :class:`~.Tensor` represents tensor dimensions.
.. note::
The shape of a tensor was usually represented by a :class:`tuple`.
But if a tensor was treated as symbolic placeholder with tracing,
it's shape could also be a :class:`~.Tensor`. See :class:`~.trace` for more details.
...
...
@@ -100,6 +126,9 @@ class Tensor(_Tensor, ArrayMethodMixin):
@
property
def
qparams
(
self
):
r
"""
Returns a :class:`~.QParams` object containing quantization params of a :class:`~.Tensor`.
"""
from
.quantization.utils
import
create_qparams
# pylint: disable=all
if
self
.
_qparams
is
None
:
...
...
@@ -185,18 +214,20 @@ class Tensor(_Tensor, ArrayMethodMixin):
def
__getstate__
(
self
):
r
""" __getstate__ will be called for pickle serialization or deep copy
"""
state
=
{
"numpy"
:
self
.
numpy
(),
"dtype"
:
self
.
dtype
,
"device"
:
self
.
device
.
logical_name
,
}
state
=
{}
if
self
.
_qparams
is
not
None
:
state
[
"qparams"
]
=
self
.
_qparams
return
state
def
__setstate__
(
self
,
state
):
from
.quantization.utils
import
create_qparams
# pylint: disable=all
# for compatibility with old version not using fastcore
if
"data"
in
state
:
data
=
state
.
pop
(
"data"
)
device
=
state
.
pop
(
"device"
)
dtype
=
state
.
pop
(
"dtype"
)
self
.
_reset
(
Tensor
(
data
,
dtype
=
dtype
,
device
=
device
))
# quantize related state for deepcopy
if
"qdict"
in
state
:
qparams
=
state
.
pop
(
"qdict"
)
logger
.
warning
(
...
...
@@ -206,7 +237,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
qparams
=
state
.
pop
(
"qparams"
)
else
:
qparams
=
None
self
.
_reset
(
Tensor
(
state
.
pop
(
"numpy"
),
state
.
pop
(
"dtype"
),
state
.
pop
(
"device"
)))
self
.
_qparams
=
qparams
...
...
imperative/python/test/unit/core/tensor_v1_1.mge
0 → 100644
浏览文件 @
cdbb4a20
文件已添加
imperative/python/test/unit/core/tensor_v1_2.mge
0 → 100644
浏览文件 @
cdbb4a20
文件已添加
imperative/python/test/unit/core/test_serialization.py
浏览文件 @
cdbb4a20
...
...
@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
os
import
pickle
from
tempfile
import
TemporaryFile
...
...
@@ -18,25 +19,27 @@ from megengine import Parameter, Tensor
def
test_tensor_serialization
():
with
TemporaryFile
()
as
f
:
data
=
np
.
random
.
randint
(
low
=
0
,
high
=
7
,
size
=
[
233
])
a
=
Tensor
(
data
,
device
=
"
xpux
"
,
dtype
=
np
.
int32
)
pickle
.
dump
(
a
,
f
)
a
=
Tensor
(
data
,
device
=
"
cpu0
"
,
dtype
=
np
.
int32
)
mge
.
save
(
a
,
f
)
f
.
seek
(
0
)
b
=
pickle
.
load
(
f
)
np
.
testing
.
assert_equal
(
a
.
numpy
(),
b
.
numpy
())
b
=
mge
.
load
(
f
)
np
.
testing
.
assert_equal
(
a
.
numpy
(),
data
)
assert
b
.
device
.
logical_name
==
"cpu0:0"
assert
b
.
dtype
==
np
.
int32
with
TemporaryFile
()
as
f
:
a
=
Parameter
(
np
.
random
.
random
(
size
=
(
233
,
2
)).
astype
(
np
.
float32
))
pickle
.
dump
(
a
,
f
)
mge
.
save
(
a
,
f
)
f
.
seek
(
0
)
b
=
pickl
e
.
load
(
f
)
b
=
mg
e
.
load
(
f
)
assert
isinstance
(
b
,
Parameter
)
np
.
testing
.
assert_equal
(
a
.
numpy
(),
b
.
numpy
())
with
TemporaryFile
()
as
f
:
a
=
Tensor
(
np
.
random
.
random
(
size
=
(
2
,
233
)).
astype
(
np
.
float32
))
pickle
.
dump
(
a
,
f
)
mge
.
save
(
a
,
f
)
f
.
seek
(
0
)
b
=
pickl
e
.
load
(
f
)
b
=
mg
e
.
load
(
f
)
assert
type
(
b
)
is
Tensor
np
.
testing
.
assert_equal
(
a
.
numpy
(),
b
.
numpy
())
...
...
@@ -66,8 +69,20 @@ def test_tensor_serialization():
with
TemporaryFile
()
as
f
:
a
=
Tensor
(
0
)
a
.
qparams
.
scale
=
Tensor
(
1.0
)
pickle
.
dump
(
a
,
f
)
mge
.
save
(
a
,
f
)
f
.
seek
(
0
)
b
=
pickl
e
.
load
(
f
)
b
=
mg
e
.
load
(
f
)
assert
isinstance
(
b
.
qparams
.
scale
,
Tensor
)
np
.
testing
.
assert_equal
(
b
.
qparams
.
scale
.
numpy
(),
1.0
)
def
test_compatibility
():
def
test_old_tensor
(
model_name
):
path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
model_name
)
old_tensor
=
mge
.
load
(
path
)
assert
np
.
all
(
old_tensor
.
numpy
()
==
[
1
,
2
,
3
])
assert
old_tensor
.
device
.
logical_name
==
"cpu0:0"
assert
old_tensor
.
dtype
==
np
.
int8
test_old_tensor
(
"tensor_v1_1.mge"
)
test_old_tensor
(
"tensor_v1_2.mge"
)
imperative/python/test/unit/core/test_tensor_wrapper.py
浏览文件 @
cdbb4a20
...
...
@@ -98,6 +98,20 @@ def test_as_type():
np
.
testing
.
assert_equal
(
get_zero_point
(
b
.
dtype
),
128
)
def
test_serialization
():
x
=
Tensor
([
1
,
2
,
3
],
dtype
=
np
.
float32
)
newargs
=
x
.
__getnewargs__
()
states
=
x
.
__getstate__
()
assert
np
.
all
(
newargs
[
0
]
==
x
.
numpy
())
assert
newargs
[
1
]
==
x
.
dtype
assert
newargs
[
2
]
==
x
.
device
.
logical_name
assert
not
states
x
.
qparams
states
=
x
.
__getstate__
()
assert
len
(
states
.
keys
())
==
1
assert
states
[
"qparams"
]
==
x
.
qparams
def
test_qparams
():
x
=
Tensor
(
1
)
assert
x
.
qparams
.
scale
is
None
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录