提交 5a1f9134 编写于 作者: M Megvii Engine Team

chore(imperative): remove unnecessary function template

GitOrigin-RevId: 8dd2f8c308061fd510a6a82f09c94f2214e6f4e4
上级 2de2222e
......@@ -33,6 +33,18 @@ auto normalize_enum(const std::string& in) {
}
} // anonymous namespace
#define CATCH_ALL(RETVAL) \
catch(py::error_already_set& e) { \
e.restore(); \
return RETVAL; \
} catch(py::builtin_exception& e) { \
e.set_error(); \
return RETVAL; \
} catch(std::exception& e) { \
PyErr_SetString(PyExc_RuntimeError, e.what()); \
return RETVAL; \
} \
namespace {
#define PyOp(name) Py##name
#define PyOpType(name) PyOp(name)::py_type
......@@ -99,14 +111,6 @@ PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
#define py_get_generic(name, attr) \
py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
template<typename T>
PyObject* py_get_scope_impl(PyObject* obj, void* /* closure */) {
// T: PyOpXXX inst(): return XXX in opdef.h.inl
auto& op = reinterpret_cast<T*>(obj)->inst();
return pyobj_convert_generic<std::string>::to(op.scope());
}
#define py_get_scope(class) py_get_scope_impl<PyOp(class)>
template<typename T, typename U, U T::Ty::*attr>
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
if (value == NULL) {
......@@ -116,51 +120,46 @@ int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
auto& op = reinterpret_cast<T*>(obj)->inst();
try {
op.*attr = pyobj_convert_generic<U>::from(value);
return 0;
} catch(py::error_already_set& e) {
e.restore();
} catch(py::builtin_exception& e) {
e.set_error();
} catch(...) {
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
}
return -1;
} CATCH_ALL(-1)
return 0;
}
#define py_set_generic(name, attr) \
py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
template<typename T>
int py_set_scope_impl(PyObject* obj, PyObject* value, void* /* closure */) {
if (value == NULL) {
PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
return -1;
}
auto& op = reinterpret_cast<T*>(obj)->inst();
try {
op.set_scope(pyobj_convert_generic<std::string>::from(value));
return 0;
} catch(py::error_already_set& e) {
e.restore();
} catch(py::builtin_exception& e) {
e.set_error();
} catch(...) {
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
}
return -1;
}
#define py_set_scope(class) py_set_scope_impl<PyOp(class)>
struct PyOpDef {
PyObject_HEAD
std::shared_ptr<OpDef> op;
static PyTypeObject py_type;
static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
static PyGetSetDef py_getsetters[];
static Py_hash_t tp_hash(PyObject *obj);
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op);
};
PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;
PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
return pyobj_convert_generic<std::string>::to(
reinterpret_cast<PyOp(OpDef)*>(obj)->op->scope());
}
int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
if (value == NULL) {
PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
return -1;
}
try {
reinterpret_cast<PyOp(OpDef)*>(obj)->op
->set_scope(pyobj_convert_generic<std::string>::from(value));
} CATCH_ALL(-1)
return 0;
}
PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
{const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL},
{NULL}
};
Py_hash_t PyOp(OpDef)::tp_hash(PyObject *obj) {
return static_cast<Py_hash_t>(
reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash());
......@@ -225,6 +224,7 @@ struct pyobj_convert_generic<T,
};
void _init_py_op_def(py::module m) {
using py_op = PyOp(OpDef);
auto& py_type = PyOpType(OpDef);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.OpDef";
......@@ -234,6 +234,7 @@ void _init_py_op_def(py::module m) {
py_type.tp_base = &PyBaseObject_Type;
py_type.tp_hash = PyOp(OpDef)::tp_hash;
py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
}
......@@ -309,6 +310,8 @@ void _init_py_op_base(py::module m) {
// auto generated opdefs
#include "opdef.cpy.inl"
#undef CATCH_ALL
} // anonymous namespace
namespace PYBIND11_NAMESPACE {
......
......@@ -485,52 +485,44 @@ EnumWrapper<{0}::{1}>::type2str = {{
className, i.name));
}
getsetters.push_back(formatv(
"{{\"scope\", py_get_scope({0}), py_set_scope({0}), \"scope\", NULL},",
className));
// generate tp_init
std::string initBody;
if (!op.getMgbAttributes().empty()) {
initBody += "static const char* kwlist[] = {";
std::vector<llvm::StringRef> attr_name_list;
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr.name);
attr_name_list.push_back(attr.name);
});
attr_name_list.push_back("scope");
llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr);
});
initBody += "\"scope\", ";
initBody += "NULL};\n";
initBody += " PyObject ";
std::vector<std::string> attrs;
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
attrs.push_back(formatv("*{0} = NULL", attr.name));
std::vector<std::string> attr_init;
llvm::for_each(attr_name_list, [&](auto&& attr) {
attr_init.push_back(formatv("*{0} = NULL", attr));
});
initBody += llvm::join(attrs, ", ") + ";\n";
initBody += " PyObject *scope = NULL;\n";
initBody += llvm::join(attr_init, ", ") + ";\n";
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
// an extra slot created for name
initBody += std::string(op.getMgbAttributes().size() + 1, 'O');
initBody += std::string(attr_name_list.size(), 'O');
initBody += "\", const_cast<char**>(kwlist)";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv(", &{0}", attr.name);
llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv(", &{0}", attr);
});
initBody += ", &scope";
initBody += "))\n";
initBody += " return -1;\n";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv(R"(
if ({1}) {{
try {{
reinterpret_cast<PyOp({0})*>(self)->inst().{1} =
pyobj_convert_generic<decltype({0}::{1})>::from({1});
} catch(py::error_already_set& e) {{
e.restore();
return -1;
} catch(py::builtin_exception& e) {{
e.set_error();
return -1;
} catch(...) {{
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
return -1;
}
} CATCH_ALL(-1)
}
)", className, attr.name);
});
......@@ -538,18 +530,9 @@ EnumWrapper<{0}::{1}>::type2str = {{
initBody += formatv(R"(
if (scope) {{
try {{
reinterpret_cast<PyOp({0})*>(self)->inst().set_scope(
pyobj_convert_generic<std::string>::from(scope));
} catch(py::error_already_set& e) {{
e.restore();
return -1;
} catch(py::builtin_exception& e) {{
e.set_error();
return -1;
} catch(...) {{
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
return -1;
}
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(pyobj_convert_generic<std::string>::from(scope));
} CATCH_ALL(-1)
}
)", className);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册