Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
de0742be
MegEngine
项目概览
MegEngine 天元
/
MegEngine
9 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
de0742be
编写于
1月 03, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge): reopen passed assertions
GitOrigin-RevId: e0276e73e31ddba1e35a1561abe3f178eedd509a
上级
a90c937d
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
38 addition
and
31 deletion
+38
-31
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+3
-0
imperative/python/megengine/distributed/functional.py
imperative/python/megengine/distributed/functional.py
+2
-0
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+25
-29
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+8
-1
imperative/python/test/unit/autodiff/test_grad_manger.py
imperative/python/test/unit/autodiff/test_grad_manger.py
+0
-1
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
de0742be
...
...
@@ -450,6 +450,9 @@ def _unwrap(x):
def
apply_normal_varnode
(
op
:
OpDef
,
*
args
:
VarNode
):
# for PyOp like RemoteSend/Recv
if
getattr
(
op
,
"op"
,
None
):
op
=
op
.
op
outputs
=
_imperative_rt
.
invoke_op
(
op
,
_unwrap
(
args
))
return
_wrap
(
outputs
)
...
...
imperative/python/megengine/distributed/functional.py
浏览文件 @
de0742be
...
...
@@ -292,6 +292,8 @@ def remote_recv(
op
=
RemoteRecv
()
op
.
key
=
key
op
.
cn
=
device
if
isinstance
(
shape
,
Tensor
):
shape
=
shape
.
numpy
()
op
.
shape
=
shape
op
.
dtype
=
dtype
op
.
addr
,
op
.
port
=
get_mm_server_addr
()
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
de0742be
...
...
@@ -234,20 +234,21 @@ class trace:
)
info
.
data_setter
.
set_value
(
x
.
_dev_tensor
())
else
:
pass
# if x.__class__ is not CompiledTensorProxy:
# if x not in self._tensor_remaps:
# raise TraceMismatchError(
# "unexpected capture: trying to use an external tensor as "
# "input, but that input was an internal tensor last time"
# )
# else:
# x = self._tensor_remaps[x]
# if x._CompiledTensorProxy__handle != h:
# raise TraceMismatchError(
# "mis-wiring: input edge to an data flow "
# "graph node is different from last time"
# )
if
x
.
mixin_handle
==
-
1
:
if
x
.
_handle
not
in
self
.
_tensor_remaps
:
raise
TraceMismatchError
(
"unexpected capture: trying to use an external tensor as "
"input, but that input was an internal tensor last time"
)
else
:
x
.
mixin_handle
=
self
.
_tensor_remaps
[
x
.
_handle
].
_CompiledTensorProxy__handle
if
x
.
mixin_handle
!=
h
:
raise
TraceMismatchError
(
"mis-wiring: input edge to an data flow "
"graph node is different from last time"
)
self
.
_pc
+=
1
outputs
=
[]
...
...
@@ -268,14 +269,11 @@ class trace:
op_
,
ihandles
,
ohandles
=
record
assert
isinstance
(
op_
,
str
)
and
op_
==
"Const"
# TODO : assert on const value
# eq = value == self._tinfo[ohandles[0]].bound_data.numpy()
# if not isinstance(eq, bool):
# eq = all(eq)
# if not eq:
# raise TraceMismatchError(
# "const tensor violated: got a different tensor this time"
# )
eq
=
np
.
all
(
np
.
atleast_1d
(
value
)
==
self
.
_tinfo
[
ohandles
[
0
]].
bound_data
.
numpy
())
if
not
eq
:
raise
TraceMismatchError
(
"const tensor violated: got a different tensor this time"
)
self
.
_pc
+=
1
(
h
,)
=
ohandles
...
...
@@ -750,7 +748,6 @@ class trace:
dtype
=
info
.
dtype
,
device
=
dumped_device
,
shape
=
info
.
shape
or
(
1
,),
name
=
k
)
set_tracing
()
for
op
,
ihandles
,
ohandles
in
self
.
_seq
:
if
isinstance
(
op
,
str
)
and
op
==
"Const"
:
assert
len
(
ihandles
)
==
0
...
...
@@ -776,7 +773,6 @@ class trace:
ovars
=
G
.
apply_normal_varnode
(
op
,
*
ivars
)
assert
len
(
ovars
)
==
len
(
ohandles
)
h2v
.
update
(
zip
(
ohandles
,
ovars
))
unset_tracing
()
dest_vars
=
[]
for
i
,
h
in
enumerate
(
self
.
_output_bindings
):
...
...
@@ -843,7 +839,7 @@ class trace:
if
x
.
device
!=
info
.
device
:
raise
TypeError
(
"args[%d].device different from last time"
%
i
)
info
.
data_setter
.
set_value
(
x
.
_dev_tensor
())
self
.
_tensor_remaps
[
x
]
=
CompiledTensorProxy
(
h
)
self
.
_tensor_remaps
[
x
.
_handle
]
=
CompiledTensorProxy
(
h
)
kwargs_tensors
=
{}
for
k
,
x
in
kwargs
.
items
():
...
...
@@ -870,7 +866,7 @@ class trace:
if
x
.
device
!=
info
.
device
:
raise
TypeError
(
"kwargs[%s].device different from last time"
%
k
)
info
.
data_setter
.
set_value
(
x
.
_dev_tensor
())
self
.
_tensor_remaps
[
x
]
=
CompiledTensorProxy
(
h
)
self
.
_tensor_remaps
[
x
.
_handle
]
=
CompiledTensorProxy
(
h
)
def
_process_outputs
(
self
,
outputs
):
output_names
=
None
...
...
@@ -1000,8 +996,8 @@ class CompiledTensorProxy:
def
__del__
(
self
):
if
self
.
__tensor
.
shape_read
and
self
.
__shape
is
not
None
:
self
.
__info
.
shape_reader
.
drop_value
()
#
if self.__tensor.value_read and self.__value is not None:
#
self.__info.value_reader.drop_value()
if
self
.
__tensor
.
value_read
and
self
.
__value
is
not
None
:
self
.
__info
.
value_reader
.
drop_value
()
if
self
.
__tensor
.
data_read
and
self
.
__data
is
not
None
:
self
.
__info
.
data_reader
.
drop_value
()
...
...
@@ -1047,7 +1043,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
outputs
=
[
RawTensor
(
o
)
for
o
in
ovars
]
if
require_links
:
active_trace
.
_lazy_eval_links
=
(
outputs
[
0
].
_varnode
,)
active_trace
.
_lazy_eval_links
=
(
G
.
VarNode
(
outputs
[
0
].
_varnode
)
,)
active_trace
.
_lazy_eval_tensors
.
update
([
TensorWeakRef
(
o
)
for
o
in
outputs
])
return
outputs
...
...
imperative/python/src/tensor.cpp
浏览文件 @
de0742be
...
...
@@ -760,7 +760,14 @@ void init_tensor(py::module m) {
m
.
attr
(
"skip_tracing"
)
=
&
skip_tracing
;
py
::
class_
<
SharedHandle
>
(
m
,
"SharedHandle"
)
.
def
(
py
::
init
<
const
SharedHandle
&>
());
.
def
(
py
::
init
<
const
SharedHandle
&>
())
.
def
(
"__eq__"
,
[](
SharedHandle
&
thish
,
SharedHandle
&
thath
)
{
return
(
thish
.
get
()
==
thath
.
get
());
})
.
def
(
"__hash__"
,
[](
SharedHandle
&
sh
)
{
return
reinterpret_cast
<
int64_t
>
(
sh
.
get
());
})
;
m
.
def
(
"set_tracing"
,
&
set_tracing
);
m
.
def
(
"unset_tracing"
,
&
unset_tracing
);
...
...
imperative/python/test/unit/autodiff/test_grad_manger.py
浏览文件 @
de0742be
...
...
@@ -141,7 +141,6 @@ def test_regression_1762():
)
@
pytest
.
mark
.
skipif
(
get_device_count_by_fork
(
"gpu"
)
<
2
,
reason
=
"need more gpu device"
)
@
pytest
.
mark
.
isolated_distributed
@
pytest
.
mark
.
skip
(
reason
=
"FIXME: remote_send/recv"
)
def
test_remote_grad
():
@
dist
.
launcher
def
worker
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录