Commit 9279104b authored by Megvii Engine Team

feat(mge): add opdef serialization and apply_module_trace

GitOrigin-RevId: 5b45bded1de8e1fb36447d4469423ef68ff627e8
Parent aa204040
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Dict

from ...core._imperative_rt import OpDef
from ...core.ops import builtin
from ...version import __version__

# Optional per-type hooks that rewrite a pickled state dict (e.g. to migrate
# parameters saved by an older MegEngine) before the OpDef is reconstructed.
OPDEF_PARAM_LOADER = {}


def get_opdef_state(obj: OpDef) -> Dict:
    state = obj.__getstate__()
    state["type"] = type(obj)
    state["version"] = __version__
    return state


def load_opdef_from_state(state: Dict) -> OpDef:
    assert "type" in state and issubclass(state["type"], OpDef)
    assert "version" in state
    opdef_type = state.pop("type")
    if opdef_type in OPDEF_PARAM_LOADER:
        loader = OPDEF_PARAM_LOADER[opdef_type]
        state = loader(state)
    state.pop("version")
    opdef_obj = opdef_type()
    opdef_obj.__setstate__(state)
    return opdef_obj
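A minimal round-trip sketch of these helpers (it mirrors the unit test added later in this commit; builtin.Elemwise stands in for any OpDef):

import pickle

from megengine.core.ops import builtin
from megengine.experimental.traced_module.serialization import (
    get_opdef_state,
    load_opdef_from_state,
)

op = builtin.Elemwise(mode="Add")
blob = pickle.dumps(get_opdef_state(op))  # the state dict also records type and version
restored = load_opdef_from_state(pickle.loads(blob))
assert op == restored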
/**
* \file imperative/python/src/module_trace.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./module_trace.h"
#include "./helper.h" // include op pybind11 caster
namespace py = pybind11;
namespace mgb::imperative::python {
apply_result_t apply_module_trace(ApplyContext& ctx) {
    apply_result_t outputs;

    // pack the call as (op, *input_tensors) for the python callback
    auto args = py::tuple(ctx.nargs + 1);
    args[0] = py::cast(ctx.op);
    for (size_t i = 0; i < ctx.nargs; i++) {
        args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
    }
    auto pyout = PyObject_Call(cpp_apply_module_trace, args.ptr(), nullptr);
    if (!pyout) throw py::error_already_set();
    auto ret = py::reinterpret_steal<py::object>(pyout);

    // assumption: the python callback always returns a list of tensors
    auto tup = py::reinterpret_borrow<py::list>(ret);
    for (size_t i = 0; i < tup.size(); i++) {
        auto* tw = TensorWrapper::try_cast(tup[i].ptr());
        mgb_assert(tw, "module trace callback returned a non-Tensor output");
        outputs.emplace_back(tw->m_tensor);
    }
    return outputs;
}
} // namespace mgb::imperative::python
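For orientation, a hypothetical sketch of the Python-side contract this function assumes: the registered callable receives the OpDef followed by the input tensors and must return a list of output tensors. The import path and hook body below are assumptions for illustration, not the actual traced_module implementation:

from megengine.core._imperative_rt.core2 import (  # assumed binding module
    apply,
    set_cpp_apply_module_trace,
    set_module_tracing,
    unset_module_tracing,
)

def module_trace_hook(opdef, *inputs):
    # a real hook records (opdef, inputs) into the traced graph, and must
    # avoid re-entering itself while computing the actual outputs
    unset_module_tracing()
    outputs = apply(opdef, *inputs)
    set_module_tracing()
    return list(outputs)  # the C++ caller expects a list of Tensors

set_cpp_apply_module_trace(module_trace_hook)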
/**
* \file imperative/python/src/module_trace.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "./tensor.h"
namespace mgb::imperative::python {
apply_result_t apply_module_trace(ApplyContext& ctx);
} // namespace mgb::imperative::python
@@ -88,6 +88,19 @@ PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
return obj;
}
template<typename T, typename SFINAE=void>
struct serialization {
    static T load(py::object obj) {
        return py::cast<T>(obj);
    }
    template<typename U,
             typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>>
    static py::object dump(U&& t) {
        return py::cast(std::forward<U>(t));
    }
};
template<typename T>
void py_dealloc_generic(PyObject* obj) {
reinterpret_cast<T*>(obj)->op.reset();
@@ -127,6 +140,13 @@ struct PyOpDef {
static PyGetSetDef py_getsetters[];
static Py_hash_t tp_hash(PyObject *obj);
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op);
static PyObject* py_repr(PyObject* self) {
return py::cast(
reinterpret_cast<PyOpDef*>(self)->op->make_name())
.release()
.ptr();
}
};
PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;
@@ -191,6 +211,13 @@ struct EnumWrapper {
std::string(name) + "." + reinterpret_cast<EnumWrapper*>(self)->to_string())
.release().ptr();
}
static PyObject* py_dump(PyObject* self) {
return py::cast(reinterpret_cast<EnumWrapper*>(self)->to_string())
.release()
.ptr();
}
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) {
if (op == Py_EQ || op == Py_NE) {
T lhs, rhs;
@@ -279,6 +306,19 @@ struct BitCombinedEnumWrapper {
reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string())
.release().ptr();
}
static PyObject* py_dump(PyObject* self) {
std::vector<std::string> result;
auto value = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value;
uint32_t value_int = static_cast<uint32_t>(value);
for (uint32_t i = 0; i < 32; i++) {
if (value_int >> i & 1) {
result.push_back(members[i]);
}
}
return py::tuple(py::cast(result)).release().ptr();
}
static PyObject* py_or(PyObject* self, PyObject* other) {
if(!(self->ob_type == other->ob_type)){
return PyErr_Format(
@@ -326,6 +366,24 @@ struct BitCombinedEnumWrapper {
return false;
}
}
if (py::isinstance<py::tuple>(src)) {
auto params = py::cast<std::vector<std::string>>(src);
bool first = true;
for (auto s : params){
auto&& iter = mem2value.find(normalize_enum(s));
if (iter != mem2value.end()) {
if (first) {
value = iter->second;
first = false;
} else {
value |= iter->second;
}
} else {
return false;
}
}
return true;
}
if (py::isinstance<py::int_>(obj)) {
auto v = py::cast<std::underlying_type_t<T>>(src);
if(v > EnumTrait<T>::max) {
@@ -351,6 +409,25 @@
}
};
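Together with the tuple branch added to the loader above, py_dump means a bit-combined enum serializes as a tuple of member names, roughly (a sketch with illustrative values):

from megengine.core.ops import builtin

s = builtin.Convolution.Strategy.PROFILE | builtin.Convolution.Strategy.HEURISTIC
print(s.dump())  # e.g. ("PROFILE", "HEURISTIC"): one name per set bit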
template<typename T>
struct serialization<T,
        std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> {
    static T load(py::object obj) {
        auto caster = pybind11::detail::type_caster<T>();
        if (caster.load(obj, true)) {
            return caster;
        } else {
            PyErr_SetString(PyExc_RuntimeError, "load failed");
            return caster;
        }
    }

    static py::object dump(T t) {
        return py::cast(t).attr("dump")();
    }
};
void _init_py_op_def(py::module m) {
using py_op = PyOp(OpDef);
auto& py_type = PyOpType(OpDef);
@@ -363,6 +440,7 @@ void _init_py_op_def(py::module m) {
py_type.tp_hash = PyOp(OpDef)::tp_hash;
py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_repr = py_op::py_repr;
mgb_assert(PyType_Ready(&py_type) >= 0);
m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
}
@@ -451,18 +451,11 @@ public:
template<typename... Args>
static PyObject* cnew(Args&&... args) {
auto* pytype = type().operator->();
auto* self = pytype->tp_alloc(pytype, 0);
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
if constexpr (has_vectorcall && tp_vectorcall::valid) {
reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
}
new(inst) T(std::forward<Args>(args)...);
return self;
return cnew_with_type(pytype, std::forward<Args>(args)...);
}
template<typename... Args>
static PyObject* cnew_with_type(PyTypeObject* pytype, Args&&... args) {
auto* self = pytype->tp_alloc(pytype, 0);
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
if constexpr (has_vectorcall && tp_vectorcall::valid) {
@@ -20,6 +20,7 @@
#include "./tensor.h"
#include "./grad.h"
#include "./trace.h"
#include "./module_trace.h"
#include "./common.h"
#include "./numpy_dtypes.h"
#include "./graph_rt.h"
@@ -41,6 +42,7 @@ interpreter::Interpreter::Channel* interpreter_for_py;
PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing;
PyObject *cpp_apply_backward_varnode;
PyObject *cpp_apply_module_trace;
std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) {
@@ -70,6 +72,7 @@ std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)
REGISTE_APPLY_FUNC(cpp_apply_module_trace)
#undef REGISTE_APPLY_FUNC
@@ -79,6 +82,14 @@ Tensor::flags_t ApplyContext::global_enable = 0;
void set_tracing() { ApplyContext::global_enable |= Tensor::Flags::TRACE; }
void unset_tracing() { ApplyContext::global_enable &= ~Tensor::Flags::TRACE; }
void set_module_tracing() { ApplyContext::global_enable |= Tensor::Flags::MODULE_TRACE; }
void unset_module_tracing() { ApplyContext::global_enable &= ~Tensor::Flags::MODULE_TRACE; }
bool is_tracing_module() {
return ApplyContext::global_enable & Tensor::Flags::MODULE_TRACE;
}
bool skip_tracing = false;
apply_result_t apply(ApplyContext& ctx) {
@@ -117,6 +128,11 @@ apply_result_t apply(ApplyContext& ctx) {
return ret;
}
if (flags & Tensor::Flags::MODULE_TRACE) {
return apply_module_trace(ctx);
}
if (flags & Tensor::Flags::TRACE) {
return apply_trace(ctx);
} else {
@@ -310,6 +326,21 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording)
#undef REGISTE_TENSORWRAPPER_FUNC
PyObject* TensorWrapper::module_trace_info() {
if (!m_tensor->m_module_trace_info.ptr()) {
PyErr_SetString(PyExc_AttributeError,
        "tensor has no attribute '_NodeMixin__node'; "
        "please set it first");
return nullptr;
}
return m_tensor->m_module_trace_info.inc_ref().ptr();
}
void TensorWrapper::set_module_trace_info(PyObject* obj) {
m_tensor->m_module_trace_info = py::reinterpret_borrow<py::object>(obj);
}
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
PyObject* TensorWrapper::member() { \
@@ -495,7 +526,9 @@ void TensorWrapper::reset(PyObject* tensor) {
}
std::string user_custom_name = m_tensor->user_custom_name;
std::string automatic_name = m_tensor->automatic_name;
auto module_trace_info = m_tensor->m_module_trace_info;
m_tensor = t->m_tensor;
m_tensor->m_module_trace_info = module_trace_info;
m_tensor->user_custom_name = user_custom_name;
m_tensor->automatic_name = automatic_name;
}
@@ -856,6 +889,7 @@ void init_tensor(py::module m) {
.def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
.def_getset<&TensorWrapper::user_custom_name, &TensorWrapper::set_user_custom_name>("c_name")
.def_getset<&TensorWrapper::automatic_name, &TensorWrapper::set_automatic_name>("_name")
.def_getset<&TensorWrapper::module_trace_info, &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
.finalize();
if (!tensor_type) throw py::error_already_set();
py::setattr(m, "Tensor", tensor_type);
@@ -998,7 +1032,7 @@ void init_tensor(py::module m) {
m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode);
m.def("set_cpp_apply_module_trace", &set_cpp_apply_module_trace);
m.attr("skip_tracing") = &skip_tracing;
py::class_<SharedHandle>(m, "SharedHandle")
@@ -1016,6 +1050,9 @@ void init_tensor(py::module m) {
m.def("set_allow_higher_order_directive", [](bool value){
GradKey::allow_higher_order_directive = value;
});
m.def("set_module_tracing", &set_module_tracing);
m.def("unset_module_tracing", &unset_module_tracing);
m.def("is_tracing_module", &is_tracing_module);
}
#undef MGE_PY_INTERFACE
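From Python, the new toggles behave as a process-global flag (a sketch; the import path is an assumption based on the surrounding m.def registrations):

from megengine.core._imperative_rt.core2 import (  # assumed binding module
    is_tracing_module,
    set_module_tracing,
    unset_module_tracing,
)

set_module_tracing()        # sets Tensor::Flags::MODULE_TRACE globally
assert is_tracing_module()  # subsequent applies are routed to the module-trace hook
try:
    pass  # run the module being traced here
finally:
    unset_module_tracing()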
@@ -96,6 +96,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
static constexpr flags_t SCALAR = 1;
static constexpr flags_t GRAD = 1 << 1;
static constexpr flags_t TRACE = 1 << 2;
static constexpr flags_t MODULE_TRACE = 1 << 3;
};
flags_t m_flags = 0;
@@ -106,6 +107,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
std::string user_custom_name;
std::string automatic_name;
cg::VarNode* m_var;
pybind11::object m_module_trace_info;
using Handle = interpreter::Interpreter::Handle;
@@ -158,10 +160,10 @@ struct TensorWrapper {
using wrap_t = pyext17::wrap<TensorWrapper>;
friend wrap_t;
inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();}
inline static TensorWrapper* try_cast(PyObject* op) {
if (!wrap_t::type().isinstance(op)) return nullptr;
return cast(op);
inline static TensorWrapper* cast(PyObject* obj) {return reinterpret_cast<wrap_t*>(obj)->inst();}
inline static TensorWrapper* try_cast(PyObject* obj) {
if (!wrap_t::type().isinstance(obj)) return nullptr;
return cast(obj);
}
inline ObjectPtr<TensorWrapper, pybind11::handle> self() {return wrap_t::pycast(this);}
@@ -206,6 +208,8 @@ struct TensorWrapper {
void set_compiled_info(PyObject *);
PyObject* trace_mixin_info();
void set_trace_mixin_info(PyObject *);
PyObject* module_trace_info();
void set_module_trace_info(PyObject *);
PyObject* user_custom_name();
void set_user_custom_name(PyObject *);
PyObject* automatic_name();
@@ -331,6 +335,7 @@ void init_tensor(pybind11::module);
extern PyObject *cpp_apply_with_tracing;
extern PyObject *cpp_apply_backward_varnode;
extern PyObject *cpp_apply_module_trace;
} // namespace mgb::imperative::python
@@ -14,6 +14,11 @@ import numpy as np
import megengine as mge
from megengine import Parameter, Tensor
from megengine.core.ops import builtin
from megengine.experimental.traced_module.serialization import (
get_opdef_state,
load_opdef_from_state,
)
def test_tensor_serialization():
@@ -86,3 +91,25 @@ def test_compatibility():
test_old_tensor("tensor_v1_1.mge")
test_old_tensor("tensor_v1_2.mge")
def test_opdef_serialization():
    with TemporaryFile() as f:
        x = builtin.Elemwise(mode="Add")
        pickle.dump(get_opdef_state(x), f)
        f.seek(0)
        load_x = load_opdef_from_state(pickle.load(f))
        assert x == load_x

    with TemporaryFile() as f:
        x = builtin.Convolution(stride_h=9, compute_mode="float32")
        x.strategy = (
            builtin.Convolution.Strategy.PROFILE
            | builtin.Convolution.Strategy.HEURISTIC
            | builtin.Convolution.Strategy.REPRODUCIBLE
        )
        pickle.dump(get_opdef_state(x), f)
        f.seek(0)
        load_x = load_opdef_from_state(pickle.load(f))
        assert x.strategy == load_x.strategy
        assert x == load_x
@@ -34,6 +34,7 @@ private:
void emit_class();
void emit_py_init();
void emit_py_getsetters();
void emit_py_methods();
Initproc emit_initproc();
MgbOp& op;
@@ -133,9 +134,16 @@ void $0(PyTypeObject& py_type) {
if (firstOccur) {
os << tgfmt(R"(
static PyMethodDef tp_methods[] = {
{const_cast<char*>("dump"), (PyCFunction)$enumTpl<$opClass::$enumClass>::py_dump, METH_NOARGS, NULL},
{NULL} /* Sentinel */
};
)", &ctx);
os << tgfmt(R"(
static PyType_Slot slots[] = {
{Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
{Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
{Py_tp_methods, tp_methods},
)", &ctx);
if (attr->getEnumCombinedFlag()) {
// only bit combined enum could new instance because bitwise operation,
@@ -212,17 +220,62 @@ Initproc OpDefEmitter::emit() {
emit_class();
emit_py_init();
emit_py_getsetters();
emit_py_methods();
return emit_initproc();
}
void OpDefEmitter::emit_class() {
auto&& className = op.getCppClassName();
std::string method_defs;
std::vector<std::string> body;
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
body.push_back(formatv(R"(
{{"{0}", serialization<decltype(opdef.{0})>::dump(opdef.{0})})"
, attr.name));
});
method_defs += formatv(R"(
static PyObject* getstate(PyObject* self, PyObject*) {{
auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {{
{1}
};
return py::cast(state).release().ptr();
})", className, llvm::join(body, ","));
body.clear();
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
body.push_back(formatv(R"(
{{
auto&& iter = state.find("{0}");
if (iter != state.end()) {{
opdef.{0} = serialization<decltype(opdef.{0})>::load(iter->second);
}
})", attr.name));
});
method_defs += formatv(R"(
static PyObject* setstate(PyObject* self, PyObject* args) {{
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
static_cast<void>(opdef);
{1}
Py_RETURN_NONE;
})", className, llvm::join(body, "\n"));
os << tgfmt(R"(
PyOpDefBegin($_self) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
$0
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd($_self)
)", &ctx);
)", &ctx, method_defs);
}
void OpDefEmitter::emit_py_init() {
@@ -302,6 +355,33 @@ PyGetSetDef PyOp($_self)::py_getsetters[] = {
)", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n "));
}
void OpDefEmitter::emit_py_methods(){
// generate methods
std::string method_defs;
std::vector<std::string> method_items;
{
auto&& className = op.getCppClassName();
// generate getstate
method_items.push_back(formatv(
"{{const_cast<char*>(\"__getstate__\"), PyOp({0})::getstate, METH_NOARGS, \"{0} getstate\"},",
className));
// generate setstate
method_items.push_back(formatv(
"{{const_cast<char*>(\"__setstate__\"), PyOp({0})::setstate, METH_VARARGS, \"{0} setstate\"},",
className));
}
os << tgfmt(R"(
PyMethodDef PyOp($_self)::tp_methods[] = {
$0
{NULL} /* Sentinel */
};
)", &ctx, llvm::join(method_items, "\n "));
}
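Seen from Python, the generated methods give every OpDef a dict-based pickle protocol, roughly as below (a sketch; attribute names and dumped values are illustrative):

from megengine.core.ops import builtin

op = builtin.Convolution(stride_h=9, compute_mode="float32")
state = op.__getstate__()   # e.g. {"stride_h": 9, "compute_mode": "FLOAT32", ...}
clone = builtin.Convolution()
clone.__setstate__(state)   # the METH_VARARGS wrapper receives (state,)
assert op == clone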
Initproc OpDefEmitter::emit_initproc() {
std::string initproc = formatv("_init_py_{0}", op.getCppClassName());
std::string subclass_init_call;
@@ -321,6 +401,7 @@ void $0(py::module m) {
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
$1