Commit 09af925f authored by Megvii Engine Team

fix(mge): fix cpp trace function release

GitOrigin-RevId: 924f945c211bc17596710410e616ab4b1e2e612e
Parent 3975a54a
@@ -72,7 +72,6 @@ if sys.platform == "win32":
     kernel32.SetErrorMode(old_error_mode)
 from .core._imperative_rt.core2 import full_sync as _full_sync
-from .core._imperative_rt.core2 import release_trace_apply_func
 from .core._imperative_rt.core2 import sync as _sync
 from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
 from .device import *
@@ -92,9 +91,7 @@ _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
 _persistent_cache_impl_ins.reg()
 atexit.register(_full_sync)
-atexit.register(release_trace_apply_func)
-del release_trace_apply_func
 del _set_fork_exec_path_for_timed_func
 del _persistent_cache_impl_ins
......
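Note on the removed atexit hook: release_trace_apply_func existed only so the C++ side's global py::object handles were dropped before interpreter shutdown. Since those globals now hold raw PyObject* (see the tensor.cpp hunks below), no pybind11 destructor can run after Py_Finalize(), and the hook plus its import become unnecessary. A minimal sketch of the hazard and the fix, with a hypothetical g_callback that is not part of this patch:

    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    // Hazard: a global py::object is destroyed during static destruction,
    // possibly after Py_Finalize(); its Py_DECREF then touches a finalized
    // interpreter and can crash the process at exit.
    // py::object g_callback;

    // Approach taken by this commit: keep a raw pointer. No destructor runs,
    // and the reference is deliberately left unreleased at process exit.
    PyObject* g_callback = nullptr;

    void set_callback(py::object f) {
        g_callback = f.release().ptr();  // take ownership; never decref
    }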
@@ -34,22 +34,15 @@ namespace mgb::imperative::python {
 std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;
-py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing,
-        cpp_apply_compiled_mode, cpp_apply_const_compiled_mode;
-py::object cpp_apply_backward_varnode;
+PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing,
+        *cpp_apply_compiled_mode, *cpp_apply_const_compiled_mode;
+PyObject *cpp_apply_backward_varnode;
-void release_trace_apply_func(){
-    cpp_apply_with_tracing.release();
-    cpp_apply_const_with_tracing.release();
-    cpp_apply_compiled_mode.release();
-    cpp_apply_const_compiled_mode.release();
-    cpp_apply_backward_varnode.release();
-}
 #define REGISTE_APPLY_FUNC(mode) \
     void set_##mode(py::object pyf) { \
-        mode = pybind11::reinterpret_steal<py::object>(pyf); \
+        mode = pyf.ptr(); \
     }
 REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
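For reference, REGISTE_APPLY_FUNC(cpp_apply_with_tracing) now expands to roughly the following (whitespace added). The setter keeps only the raw pointer: pyf drops its own reference when it goes out of scope, so the global effectively holds a borrowed reference, and the Python side is expected to keep the function object alive.

    void set_cpp_apply_with_tracing(py::object pyf) {
        cpp_apply_with_tracing = pyf.ptr();  // raw pointer only; no RAII
    }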
@@ -242,14 +235,15 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
             // const op
             if (is_const && is_tracing) {
-                py::object pyf;
+                PyObject *pyf;
                 if (is_compiled) {
                     pyf = cpp_apply_const_compiled_mode;
                 } else {
                     pyf = cpp_apply_const_with_tracing;
                 }
-                auto ret = pyf(*tup);
+                auto ret = py::reinterpret_steal<py::object>(
+                        PyObject_Call(pyf, tup.ptr(), nullptr));
                 auto py_ret = py::reinterpret_borrow<py::list>(ret);
                 if (auto* t = try_cast(py_ret[0].ptr())) {
                     m_tensor = t->m_tensor;
@@ -744,8 +738,6 @@ void init_tensor(py::module m) {
             },
             py::call_guard<py::gil_scoped_release>());
-    m.def("release_trace_apply_func", &release_trace_apply_func);
     py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
             .def<&GradKeyWrapper::attach>("attach")
             .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
......
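The raw PyObject* globals can no longer be invoked through pybind11's operator(), which is why the call sites above switch to the CPython API. The recurring pattern, as a standalone sketch (call_raw is a hypothetical helper, not part of the patch):

    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    // PyObject_Call returns a new reference, or nullptr on error;
    // reinterpret_steal adopts it so the result is RAII-managed again.
    py::object call_raw(PyObject* f, const py::tuple& args) {
        PyObject* r = PyObject_Call(f, args.ptr(), /*kwargs=*/nullptr);
        if (!r) {
            // trace.cpp below throws py::value_error at this point instead
            throw py::error_already_set();
        }
        return py::reinterpret_steal<py::object>(r);
    }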
@@ -253,8 +253,8 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors)
 void init_tensor(pybind11::module);
-extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode;
-extern pybind11::object cpp_apply_backward_varnode;
+extern PyObject *cpp_apply_with_tracing, *cpp_apply_compiled_mode;
+extern PyObject *cpp_apply_backward_varnode;
 } // namespace mgb::imperative::python
......
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "./trace.h"
@@ -23,12 +24,13 @@ apply_result_t apply_trace(ApplyContext& ctx) {
     if (ctx.backward) {
         // reach here when symbolic=True or compiled=True
         // call megbrain_graph.py apply(BackwardGraph, *args)
-        auto args = py::tuple(ctx.nargs);
+        auto args = py::tuple(ctx.nargs + 1);
+        args[0] = py::cast(ctx.op);
         for (size_t i = 0; i < ctx.nargs; i++) {
-            args[i] = py::cast(ctx.args[i]->m_var);
+            args[i + 1] = py::cast(ctx.args[i]->m_var);
         }
-        py::object ret = cpp_apply_backward_varnode(py::cast(ctx.op), *args);
+        py::object ret = py::reinterpret_steal<py::object>(
+                PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr));
         if (!ret) {
             throw py::value_error("invalid py object call");
         }
@@ -36,13 +38,13 @@ apply_result_t apply_trace(ApplyContext& ctx) {
         // assumption: python function always returns PyList
         auto tup = py::reinterpret_borrow<py::list>(ret);
         for (auto i = 0; i < tup.size(); i++) {
-            auto pitem = tup[i].cast<cg::VarNode *>();
+            auto pitem = tup[i].cast<cg::VarNode*>();
             outputs.emplace_back(std::make_shared<Tensor>(pitem));
         }
         return outputs;
     }
-    py::object pyf;
+    PyObject* pyf;
     if (is_compiled) {
         // run apply in compiled mode, step 2, 3, etc
         pyf = cpp_apply_compiled_mode;
@@ -51,11 +53,15 @@ apply_result_t apply_trace(ApplyContext& ctx) {
         pyf = cpp_apply_with_tracing;
     }
-    auto args = py::tuple(ctx.nargs);
+    auto args = py::tuple(ctx.nargs + 1);
+    args[0] = py::cast(ctx.op);
     for (size_t i = 0; i < ctx.nargs; i++) {
-        args[i] = TensorWrapper::make(std::move(std::shared_ptr<Tensor>(ctx.args[i]))).release();
+        args[i + 1] = TensorWrapper::make(
+                              std::move(std::shared_ptr<Tensor>(ctx.args[i])))
+                              .release();
     }
-    auto ret = pyf(py::cast(ctx.op), *args);
+    auto ret = py::reinterpret_steal<py::object>(
+            PyObject_Call(pyf, args.ptr(), nullptr));
     // assumption: python function always returns PyList
     auto tup = py::reinterpret_borrow<py::list>(ret);
......
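With pyf now a plain PyObject*, the old variadic call pyf(py::cast(ctx.op), *args) is unavailable, so the op is packed into slot 0 of the argument tuple and the Python hook still receives (op, *tensors). A hypothetical end-to-end sketch of this convention (apply_hook is illustrative, not part of the patch):

    #include <vector>
    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    // Build f(op, x0, x1, ...) through the C API: one tuple, op in front.
    py::object apply_hook(PyObject* f, const py::object& op,
                          const std::vector<py::object>& xs) {
        py::tuple args(xs.size() + 1);
        args[0] = op;  // op used to be passed as a separate first argument
        for (size_t i = 0; i < xs.size(); ++i) {
            args[i + 1] = xs[i];
        }
        return py::reinterpret_steal<py::object>(
                PyObject_Call(f, args.ptr(), nullptr));
    }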