Commit 8480302d authored by: Megvii Engine Team, committed by: huangxinda

fix(autograd): make higher order grad experimental

GitOrigin-RevId: 81e1eb0ebfd1f959ba5c8af9ce6e090f8104dd12
Parent 72531f2b
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import (
    set_allow_higher_order_directive as _set_allow_higher_order_directive,
)

__all__ = [
    "enable_higher_order_directive",
    "disable_higher_order_directive",
]


def enable_higher_order_directive():
    _set_allow_higher_order_directive(True)


def disable_higher_order_directive():
    _set_allow_higher_order_directive(False)
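The following is not part of this commit: a minimal usage sketch of the new switch, assuming the public GradManager API (megengine.autodiff.GradManager) and the cos/sin example used in the tests below. Gradients accumulate into Tensor.grad, so the sketch clears x.grad before taking the second-order gradient; the exact accumulation behaviour may differ.

import numpy as np

import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager
from megengine.experimental.autograd import (
    disable_higher_order_directive,
    enable_higher_order_directive,
)

enable_higher_order_directive()  # opt in to the experimental behaviour

x = mge.tensor(np.random.rand(10).astype("float32"))
outer, inner = GradManager(), GradManager()
outer.attach([x])
inner.attach([x])

with outer:
    with inner:
        y = F.cos(x)
        inner.backward(y)   # first order: x.grad ~ -sin(x)
    dx = x.grad
    x.grad = None           # clear so the next backward is easy to read
    outer.backward(dx)      # second order: x.grad ~ -cos(x)

disable_higher_order_directive()  # restore the default guard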
......@@ -271,7 +271,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
         pool.free(ptr);
     }
-    std::shared_ptr<GradFn> make() {
+    static std::shared_ptr<GradFn> make() {
         return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
     }
......@@ -316,14 +316,18 @@ public:
 apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
     // copy inputs first, or trace will make InputNodes for each usage
-    ApplyContext ctx_dup = ctx;
     SmallVector<std::shared_ptr<Tensor>> inputs_copy;
     SmallVector<Tensor*> inputs_copy_weak;
     for (size_t i = 0; i < ctx.nargs; ++i) {
-        inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
+        Tensor* input = ctx.args[i];
+        inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]);
         inputs_copy_weak.push_back(inputs_copy.back().get());
         inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
+        if (input->m_flags & Flags::GRAD) {
+            inputs_copy.back()->m_flags |= Flags::GRAD;
+        }
     }
+    ApplyContext ctx_dup = ctx;
     ctx_dup.args = inputs_copy_weak.data();

     auto outputs = apply(ctx_dup);
......@@ -332,7 +336,6 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
     if (!backward_graph) {
         return outputs;
     }
     ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx_dup, outputs);
     return outputs;
......@@ -389,6 +392,12 @@ apply_result_t apply_grad(ApplyContext& ctx) {
     if (grad_keys.empty()) {
         return apply(ctx);
-    }
+    } else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) {
+        PyErr_SetString(
+                PyExc_NotImplementedError,
+                "second order directive not enabled, please call "
+                "'megengine.experimental.enable_higher_order_directive'");
+        throw pyext17::py_err_set();
+    }

     GradFnHelper grad_fn_holder;
......
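For context (not shown in the diff): with the guard above in place and the directive left at its default, an op whose inputs are tracked by more than one GradKey now fails fast instead of silently building a higher-order graph. A hypothetical reproduction, again assuming the GradManager API used in the tests below; the exact point at which the error surfaces may differ.

import numpy as np

import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

x = mge.tensor(np.random.rand(10).astype("float32"))
outer, inner = GradManager(), GradManager()
outer.attach([x])
inner.attach([x])

try:
    with outer:
        with inner:
            y = F.cos(x)  # x is tracked by two grad keys -> guarded path
except NotImplementedError as exc:
    # expected message: "second order directive not enabled, please call
    # 'megengine.experimental.enable_higher_order_directive'"
    print(exc)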
......@@ -36,6 +36,7 @@ struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
     bool is_blocked() const {
         return priority < sm_min_priority;
     }
+    inline static bool allow_higher_order_directive = false;
 private:
     static int sm_min_priority;
 };
......
......@@ -990,6 +990,9 @@ void init_tensor(py::module m) {
m.def("set_tracing", &set_tracing);
m.def("unset_tracing", &unset_tracing);
m.def("set_allow_higher_order_directive", [](bool value){
GradKey::allow_higher_order_directive = value;
});
}
#undef MGE_PY_INTERFACE
......
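Not in this commit: the switch is global process state, so callers that want to scope it can wrap the two new public functions in a context manager, which is essentially what the autouse pytest fixture below does for marked tests. A hypothetical convenience sketch built only on the functions added above:

from contextlib import contextmanager

from megengine.experimental.autograd import (
    disable_higher_order_directive,
    enable_higher_order_directive,
)


@contextmanager
def higher_order_directive():
    # Enable the experimental directive for the duration of the block,
    # then restore the default (disabled) state.
    enable_higher_order_directive()
    try:
        yield
    finally:
        disable_higher_order_directive()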
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import platform
import sys
......@@ -9,6 +16,10 @@ import megengine.module
 from megengine import Parameter
 from megengine.core._imperative_rt.core2 import sync
 from megengine.distributed.helper import get_device_count_by_fork
+from megengine.experimental.autograd import (
+    disable_higher_order_directive,
+    enable_higher_order_directive,
+)
 from megengine.jit import trace as _trace
 from megengine.module import Linear, Module
......@@ -34,3 +45,13 @@ def skip_distributed(request):
                 platform.system()
             )
         )
+
+
+@pytest.fixture(autouse=True)
+def resolve_require_higher_order_directive(request):
+    marker = request.node.get_closest_marker("require_higher_order_directive")
+    if marker:
+        enable_higher_order_directive()
+    yield
+    if marker:
+        disable_higher_order_directive()
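Not part of this diff: unless the repository already declares the new marker (for example in its pytest configuration), pytest will emit PytestUnknownMarkWarning for tests using @pytest.mark.require_higher_order_directive. One conventional way to register it, sketched here as a hypothetical addition to the same conftest.py:

def pytest_configure(config):
    # Register the custom marker so pytest recognizes it and does not warn.
    config.addinivalue_line(
        "markers",
        "require_higher_order_directive: run the test with the experimental "
        "higher order grad directive enabled",
    )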
......@@ -281,6 +281,7 @@ def test_broadcast_grad(trace_mode):
     worker()


+@pytest.mark.require_higher_order_directive()
 def test_2nd_grad_with_manager():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
......@@ -299,6 +300,7 @@ def test_2nd_grad_with_manager():
     )


+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_group():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
......@@ -315,6 +317,7 @@ def test_grad_manager_group():
     x.grad = None


+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_group_visibility():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
......@@ -330,6 +333,7 @@ def test_grad_manager_group_visibility():
     np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)


+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_visibility_by_order():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
......
......@@ -108,6 +108,7 @@ def test_grad_2():
     np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)


+@pytest.mark.require_higher_order_directive()
 def test_2nd_grad():
     x_np = np.random.rand(10).astype("float32")
     x = as_tensor(x_np)
......