提交 87d6ff22 编写于 作者: M Megvii Engine Team

fix(mge): fix dumping backward graph

GitOrigin-RevId: 430f110053911dbb7719badb6463a8280376ed42
上级 f31752d5
......@@ -489,7 +489,7 @@ def apply_backward_varnode(op: BackwardGraph, *args: VarNode):
graph._make_const_for_backward,
args,
)
return _unwrap(outputs)
return outputs
set_cpp_apply_backward_varnode(apply_backward_varnode)
......
......@@ -830,7 +830,10 @@ class trace:
name=info.name,
)
ivars.append(h2v[h])
ovars = G.apply_normal_varnode(op, *ivars)
if isinstance(op, BackwardGraph):
ovars = G.apply_backward_varnode(op, *ivars)
else:
ovars = G.apply_normal_varnode(op, *ivars)
AutoNaming.record_opnode(ovars[0].op)
......
......@@ -247,6 +247,34 @@ def test_dump_volatile():
)
def test_dump_backward_graph():
x0 = tensor(np.random.randn(3, 4))
x1 = tensor(np.random.randn(3, 4))
gm = GradManager().attach(x0)
@trace(symbolic=True, capture_as_const=True)
def f(x0, x1):
with gm:
y = x0 * x1
gm.backward(y, F.ones_like(y))
dx0 = x0.grad
return y, dx0
y, dx0 = f(x0, x1)
np.testing.assert_equal(dx0.numpy(), x1)
file = io.BytesIO()
f.dump(file, optimize_for_inference=False)
file.seek(0)
infer_cg = cgtools.GraphInference(file)
results = list((infer_cg.run(x0, x1)).values())
np.testing.assert_equal(results[0], y)
np.testing.assert_equal(results[1], dx0)
@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_profiler(trace_mode):
@trace(symbolic=trace_mode, profiling=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册