From 87d6ff228df9d99acf224c5aeca89d2f804a9db5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 6 May 2021 15:37:26 +0800 Subject: [PATCH] fix(mge): fix dumping backward graph GitOrigin-RevId: 430f110053911dbb7719badb6463a8280376ed42 --- .../megengine/core/tensor/megbrain_graph.py | 2 +- imperative/python/megengine/jit/tracing.py | 5 +++- .../python/test/unit/jit/test_tracing.py | 28 +++++++++++++++++++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index eaa3d6b2a..1bff175b2 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -489,7 +489,7 @@ def apply_backward_varnode(op: BackwardGraph, *args: VarNode): graph._make_const_for_backward, args, ) - return _unwrap(outputs) + return outputs set_cpp_apply_backward_varnode(apply_backward_varnode) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 203efe0aa..c4b70d6cd 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -830,7 +830,10 @@ class trace: name=info.name, ) ivars.append(h2v[h]) - ovars = G.apply_normal_varnode(op, *ivars) + if isinstance(op, BackwardGraph): + ovars = G.apply_backward_varnode(op, *ivars) + else: + ovars = G.apply_normal_varnode(op, *ivars) AutoNaming.record_opnode(ovars[0].op) diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index 44ac4f044..480587a92 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -247,6 +247,34 @@ def test_dump_volatile(): ) +def test_dump_backward_graph(): + x0 = tensor(np.random.randn(3, 4)) + x1 = tensor(np.random.randn(3, 4)) + + gm = GradManager().attach(x0) + + @trace(symbolic=True, capture_as_const=True) + def f(x0, x1): + with gm: + y = x0 * x1 + gm.backward(y, F.ones_like(y)) + dx0 = x0.grad + return y, dx0 + + y, dx0 = f(x0, x1) + np.testing.assert_equal(dx0.numpy(), x1) + + file = io.BytesIO() + f.dump(file, optimize_for_inference=False) + file.seek(0) + + infer_cg = cgtools.GraphInference(file) + results = list((infer_cg.run(x0, x1)).values()) + + np.testing.assert_equal(results[0], y) + np.testing.assert_equal(results[1], dx0) + + @pytest.mark.parametrize("trace_mode", [False, True]) def test_trace_profiler(trace_mode): @trace(symbolic=trace_mode, profiling=True) -- GitLab