提交 282dfc62 编写于 作者: M Megvii Engine Team

refactor(imperative): alloc enum type class on heap

GitOrigin-RevId: d2b2acea229df68151f04ce17c1e73621dd7fb60
上级 1e6ef377
......@@ -170,7 +170,7 @@ struct EnumTrait;
PyObject_HEAD \
T value; \
constexpr static const char *name = EnumTrait<T>::name; \
static PyTypeObject type; \
static PyTypeObject* type; \
static const char* members[]; \
static std::unordered_map<std::string, T> mem2value; \
static PyObject* pyobj_insts[];
......@@ -196,7 +196,7 @@ struct EnumWrapper {
}
static bool load(py::handle src, T& value) {
PyObject* obj = src.ptr();
if (PyObject_TypeCheck(obj, &type)) {
if (PyObject_TypeCheck(obj, type)) {
value = reinterpret_cast<EnumWrapper*>(obj)->value;
return true;
}
......@@ -224,7 +224,6 @@ struct EnumWrapper {
template<typename T>
struct BitCombinedEnumWrapper {
PyEnumHead
static PyNumberMethods number_methods;
std::string to_string() const {
uint32_t value_int = static_cast<uint32_t>(value);
if (value_int == 0) {
......@@ -302,7 +301,7 @@ struct BitCombinedEnumWrapper {
}
static bool load(py::handle src, T& value) {
PyObject* obj = src.ptr();
if (PyObject_TypeCheck(obj, &type)) {
if (PyObject_TypeCheck(obj, type)) {
value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
return true;
}
......@@ -330,8 +329,7 @@ struct BitCombinedEnumWrapper {
auto v = static_cast<std::underlying_type_t<T>>(value);
mgb_assert(v <= EnumTrait<T>::max);
if ((!v) || (v & (v - 1))) {
PyTypeObject* pytype = &type;
PyObject* obj = pytype->tp_alloc(pytype, 0);
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
return obj;
} else {
......
......@@ -69,3 +69,16 @@ def test_raw_tensor():
np.testing.assert_allclose(x * x, yy.numpy())
(yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx)
np.testing.assert_allclose(x * x, yy.numpy())
def test_opdef_path():
from megengine.core.ops.builtin import Elemwise
assert Elemwise.__module__ == "megengine.core._imperative_rt.ops"
assert Elemwise.__name__ == "Elemwise"
assert Elemwise.__qualname__ == "Elemwise"
Mode = Elemwise.Mode
assert Mode.__module__ == "megengine.core._imperative_rt.ops"
assert Mode.__name__ == "Mode"
assert Mode.__qualname__ == "Elemwise.Mode"
......@@ -97,7 +97,7 @@ void EnumAttrEmitter::emit_tpl_spl() {
if (!firstOccur) return;
os << tgfmt(
"template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type = {};\n",
"template<> PyTypeObject* $enumTpl<$opClass::$enumClass>::type = nullptr;\n",
&ctx);
auto quote = [&](auto&& i) -> std::string {
......@@ -120,13 +120,6 @@ $enumTpl<$opClass::$enumClass>::mem2value = {$0};
"template<> PyObject* "
"$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n",
&ctx, attr->getEnumMembers().size());
if (attr->getEnumCombinedFlag()) {
os << tgfmt(
"template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods = {};\n",
&ctx);
}
}
Initproc EnumAttrEmitter::emit_initproc() {
......@@ -140,45 +133,70 @@ void $0(PyTypeObject& py_type) {
if (firstOccur) {
os << tgfmt(R"(
e_type = {PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.$opClass.$enumClass";
e_type.tp_basicsize = sizeof($enumTpl<$opClass::$enumClass>);
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
e_type.tp_doc = "$opClass.$enumClass";
e_type.tp_base = &PyBaseObject_Type;
e_type.tp_repr = $enumTpl<$opClass::$enumClass>::py_repr;
e_type.tp_richcompare = $enumTpl<$opClass::$enumClass>::tp_richcompare;
static PyType_Slot slots[] = {
{Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
{Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
)", &ctx);
if (attr->getEnumCombinedFlag()) {
// only bit combined enum could new instance because bitwise operation,
// others should always use singleton
os << tgfmt(R"(
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum;
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods;
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or;
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and;
e_type.tp_as_number = &number_method;
{Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum},
{Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or},
{Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and},
)", &ctx);
}
os << R"(
{0, NULL}
};)";
os << tgfmt(R"(
static PyType_Spec spec = {
// name
"megengine.core._imperative_rt.ops.$opClass.$enumClass",
// basicsize
sizeof($enumTpl<$opClass::$enumClass>),
// itemsize
0,
// flags
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
// slots
slots
};)", &ctx);
os << tgfmt(R"(
e_type = reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec));
)", &ctx);
os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n";
for (auto&& i : {
std::pair<std::string, std::string>{"__name__", tgfmt("$enumClass", &ctx)},
{"__module__", "megengine.core._imperative_rt.ops"},
{"__qualname__", tgfmt("$opClass.$enumClass", &ctx)}}) {
os << formatv(R"(
mgb_assert(
e_type->tp_setattro(
reinterpret_cast<PyObject*>(e_type),
py::cast("{0}").release().ptr(),
py::cast("{1}").release().ptr()) >= 0);
)", i.first, i.second);
}
auto&& members = attr->getEnumMembers();
for (size_t idx = 0; idx < members.size(); ++ idx) {
os << tgfmt(R"({
PyObject* inst = e_type.tp_alloc(&e_type, 0);
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0);
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
})", &ctx, members[idx], idx);
}
os << " PyType_Modified(&e_type);\n";
}
os << tgfmt(R"(
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(&e_type)) >= 0);
py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(e_type)) >= 0);
)", &ctx);
os << "}\n";
return initproc;
......
......@@ -11,6 +11,7 @@ endif()
# TODO: turn python binding into a static/object library
add_executable(imperative_test ${SOURCES} ${SRCS})
add_dependencies(imperative_test mgb_opdef)
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR})
# Python binding
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册