From eeeddbbcd12fffe9709a34212e90815d35346437 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 29 Mar 2021 17:59:53 +0800 Subject: [PATCH] refactor(imperative): refactor tablegen code generator GitOrigin-RevId: b81b085762c47da9a901ec3e25778f3c75f21395 --- imperative/tablegen/CMakeLists.txt | 3 +- imperative/tablegen/autogen.cpp | 745 +----------------- imperative/tablegen/emitter.h | 40 + imperative/tablegen/helper.h | 36 + imperative/tablegen/targets/cpp_class.cpp | 309 ++++++++ imperative/tablegen/targets/cpp_class.h | 21 + imperative/tablegen/targets/pybind11.cpp | 142 ++++ imperative/tablegen/targets/pybind11.h | 19 + .../tablegen/targets/python_c_extension.cpp | 313 ++++++++ .../tablegen/targets/python_c_extension.h | 19 + 10 files changed, 916 insertions(+), 731 deletions(-) create mode 100644 imperative/tablegen/emitter.h create mode 100644 imperative/tablegen/targets/cpp_class.cpp create mode 100644 imperative/tablegen/targets/cpp_class.h create mode 100644 imperative/tablegen/targets/pybind11.cpp create mode 100644 imperative/tablegen/targets/pybind11.h create mode 100644 imperative/tablegen/targets/python_c_extension.cpp create mode 100644 imperative/tablegen/targets/python_c_extension.h diff --git a/imperative/tablegen/CMakeLists.txt b/imperative/tablegen/CMakeLists.txt index 31d3c5e87..1a5466ef9 100644 --- a/imperative/tablegen/CMakeLists.txt +++ b/imperative/tablegen/CMakeLists.txt @@ -1,6 +1,7 @@ # mgb tablegen executable set(TABLE_TARGET mgb-mlir-autogen) -add_executable(${TABLE_TARGET} autogen.cpp) +file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.h ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) +add_executable(${TABLE_TARGET} ${SRCS}) target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR}) target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport) set(MGB_TABLEGEN_EXE ${TABLE_TARGET}) diff --git a/imperative/tablegen/autogen.cpp b/imperative/tablegen/autogen.cpp index 44b3dabfb..83a861111 100644 --- a/imperative/tablegen/autogen.cpp +++ b/imperative/tablegen/autogen.cpp @@ -1,8 +1,17 @@ -#include -#include -#include - -#include "./helper.h" +/** + * \file imperative/tablegen/autogen.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "./targets/cpp_class.h" +#include "./targets/pybind11.h" +#include "./targets/python_c_extension.h" using llvm::raw_ostream; using llvm::RecordKeeper; @@ -27,731 +36,7 @@ llvm::cl::opt action( clEnumValN(CPython, "gen-python-c-extension", "Generate python c extensions"))); -using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; -using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; -using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin; -using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin; -using MgbOp = mlir::tblgen::MgbOpBase; -using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin; - -llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) { - // Note: we have already registered the corresponding attr wrappers - // for following basic ctypes so we needn't handle them here - /* auto&& attr_type_name = attr.getAttrDefName(); - if (attr_type_name == "UI32Attr") { - return "uint32_t"; - } - if (attr_type_name == "UI64Attr") { - return "uint64_t"; - } - if (attr_type_name == "I32Attr") { - return "int32_t"; - } - if (attr_type_name == "F32Attr") { - return "float"; - } - if (attr_type_name == "F64Attr") { - return "double"; - } - if (attr_type_name == "StrAttr") { - return "std::string"; - } - if (attr_type_name == "BoolAttr") { - return "bool"; - }*/ - - auto&& attr = llvm::cast(attr_); - if (auto e = llvm::dyn_cast(&attr)) { - return e->getEnumName(); - } - return attr.getUnderlyingType(); -} - -static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { - os << formatv( - "class {0} : public OpDefImplBase<{0}> {{\n" - " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n" - "public:\n", - op.getCppClassName() - ); - // handle enum alias - for (auto &&i : op.getMgbAttributes()) { - if (auto attr = llvm::dyn_cast(&i.attr)) { - os << formatv( - " using {0} = {1};\n", - attr->getEnumName(), attr->getUnderlyingType() - ); - } - } - for (auto &&i : op.getMgbAttributes()) { - auto defaultValue = i.attr.getDefaultValue().str(); - if (!defaultValue.empty()) { - defaultValue = formatv(" = {0}", defaultValue); - } - os << formatv( - " {0} {1}{2};\n", - attr_to_ctype(i.attr), i.name, defaultValue - ); - } - - auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) { - os << formatv( - " {0}({1}){2}{3}\n", - op.getCppClassName(), paramList, memInitList, body - ); - }; - - gen_ctor("", "", " = default;"); - - if (!op.getMgbAttributes().empty()) { - std::vector paramList, initList; - for (auto &&i : op.getMgbAttributes()) { - paramList.push_back(formatv( - "{0} {1}_", attr_to_ctype(i.attr), i.name - )); - initList.push_back(formatv( - "{0}({0}_)", i.name - )); - } - paramList.push_back("std::string scope_ = {}"); - gen_ctor(llvm::join(paramList, ", "), - ": " + llvm::join(initList, ", "), - " { set_scope(scope_); }"); - } - - auto packedParams = op.getPackedParams(); - if (!packedParams.empty()) { - std::vector paramList, initList; - for (auto &&p : packedParams) { - auto&& paramFields = p.getFields(); - auto&& paramType = p.getFullName(); - auto&& paramName = formatv("packed_param_{0}", paramList.size()); - paramList.push_back( - paramFields.empty() ? paramType.str() - : formatv("{0} {1}", paramType, paramName) - ); - for (auto&& i : paramFields) { - initList.push_back(formatv( - "{0}({1}.{0})", i.name, paramName - )); - } - } - for (auto&& i : op.getExtraArguments()) { - paramList.push_back(formatv( - "{0} {1}_", attr_to_ctype(i.attr), i.name - )); - initList.push_back(formatv( - "{0}({0}_)", i.name - )); - } - gen_ctor(llvm::join(paramList, ", "), - initList.empty() ? "" : ": " + llvm::join(initList, ", "), - " {}"); - } - - if (!packedParams.empty()) { - for (auto&& p : packedParams) { - auto accessor = p.getAccessor(); - if (!accessor.empty()) { - os << formatv( - " {0} {1}() const {{\n", - p.getFullName(), accessor - ); - std::vector fields; - for (auto&& i : p.getFields()) { - fields.push_back(i.name); - } - os << formatv( - " return {{{0}};\n", - llvm::join(fields, ", ") - ); - os << " }\n"; - } - } - } - - if (auto decl = op.getExtraOpdefDecl()) { - os << decl.getValue(); - } - - os << formatv( - "};\n\n" - ); -} - -static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) { - for (auto &&i : op.getMgbAttributes()) { - if (auto attr = llvm::dyn_cast(&i.attr)) { - if (attr->supportToString()) { - std::vector case_body; - std::string ename = formatv("{0}::{1}", - op.getCppClassName(), attr->getEnumName()); - llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ - case_body.push_back(formatv( - "case {0}::{1}: return \"{1}\";", ename, v)); - }); - os << formatv(R"( -template <> -struct ToStringTrait<{0}> { - std::string operator()({0} e) const { - switch (e) { - {1} - default: - return "{0}::Unknown"; - } - } -}; -)", ename, llvm::join(case_body, "\n")); - } - } - } -} - -static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { - auto&& className = op.getCppClassName(); - os << formatv( - "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className - ); - auto formatMethImpl = [&](auto&& meth) { - return formatv( - "{0}_{1}_impl", className, meth - ); - }; - std::vector methods; - if (auto hashable = llvm::dyn_cast(&op)) { - os << "namespace {\n"; - - // generate hash() - mlir::tblgen::FmtContext ctx; - os << formatv( - "size_t {0}(const OpDef& def_) {{\n", - formatMethImpl("hash") - ); - os << formatv( - " auto&& op_ = def_.cast_final_safe<{0}>();\n" - " static_cast(op_);\n", - className - ); - ctx.withSelf("op_"); - os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx); - os << "}\n"; - - // generate is_same_st() - os << formatv( - "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n", - formatMethImpl("is_same_st") - ); - os << formatv( - " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n" - " &&b_ = rhs_.cast_final_safe<{0}>();\n" - " static_cast(a_);\n" - " static_cast(b_);\n", - className - ); - os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); - os << "}\n"; - - // generate props() - os << formatv( - "std::vector> {0}(const OpDef& def_) {{\n", - formatMethImpl("props") - ); - os << formatv( - " auto&& op_ = def_.cast_final_safe<{0}>();\n" - " static_cast(op_);\n", - className - ); - ctx.withSelf("op_"); - os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); - os << "}\n"; - - // generate make_name() - os << formatv( - "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") - ); - os << formatv( - " auto&& op_ = def_.cast_final_safe<{0}>();\n" - " static_cast(op_);\n", - className - ); - ctx.withSelf("op_"); - os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx); - os << "}\n"; - - os << "} // anonymous namespace\n"; - - methods.push_back("hash"); - methods.push_back("is_same_st"); - methods.push_back("props"); - methods.push_back("make_name"); - } - if (!methods.empty()) { - os << formatv( - "OP_TRAIT_REG({0}, {0})", op.getCppClassName() - ); - for (auto&& i : methods) { - os << formatv( - "\n .{0}({1})", i, formatMethImpl(i) - ); - } - os << ";\n\n"; - } -} - -struct EnumContext { - std::unordered_map> enumAlias; -}; - -static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { - auto className = op.getCppClassName(); - os << formatv( - "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", - className - ); - for (auto&& i : op.getMgbAttributes()) { - if (auto attr = llvm::dyn_cast(&i.attr)) { - unsigned int enumID; - if (auto alias = llvm::dyn_cast(attr)) { - auto&& aliasBase = alias->getAliasBase(); - enumID = - llvm::cast(aliasBase) - .getBaseRecord()->getID(); - } else { - enumID = attr->getBaseRecord()->getID(); - } - auto&& enumAlias = ctx.enumAlias; - auto&& iter = enumAlias.find(enumID); - if (iter == enumAlias.end()) { - os << formatv( - "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", - className, attr->getEnumName() - ); - std::vector body; - for (auto&& i: attr->getEnumMembers()) { - os << formatv( - "\n .value(\"{2}\", {0}::{1}::{2})", - className, attr->getEnumName(), i - ); - body.push_back(formatv( - "if (str == \"{2}\") return {0}::{1}::{2};", - className, attr->getEnumName(), i - )); - } - if (attr->getEnumCombinedFlag()) { - //! define operator | - os << formatv( - "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ " - "\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));" - "\n })", - className, attr->getEnumName()); - //! define operator & - os << formatv( - "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{" - "\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));" - "\n })", - className, attr->getEnumName()); - } - os << formatv( - "\n .def(py::init([](const std::string& in) {" - "\n auto&& str = normalize_enum(in);" - "\n {0}" - "\n throw py::cast_error(\"invalid enum value \" + in);" - "\n }));\n", - llvm::join(body, "\n ") - ); - os << formatv( - "py::implicitly_convertible();\n\n", - className, attr->getEnumName() - ); - enumAlias.emplace(enumID, - std::make_pair(className, attr->getEnumName())); - } else { - os << formatv( - "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", - className, attr->getEnumName(), - iter->second.first, iter->second.second - ); - } - } - } - // generate op class binding - os << formatv("{0}Inst", className); - bool hasDefaultCtor = op.getMgbAttributes().empty(); - if (!hasDefaultCtor) { - os << "\n .def(py::init<"; - std::vector targs; - for (auto &&i : op.getMgbAttributes()) { - targs.push_back(i.attr.getReturnType()); - } - os << llvm::join(targs, ", "); - os << ", std::string>()"; - for (auto &&i : op.getMgbAttributes()) { - os << formatv(", py::arg(\"{0}\")", i.name); - auto defaultValue = i.attr.getDefaultValue(); - if (!defaultValue.empty()) { - os << formatv(" = {0}", defaultValue); - } else { - hasDefaultCtor = true; - } - } - os << ", py::arg(\"scope\") = {})"; - } - if (hasDefaultCtor) { - os << "\n .def(py::init<>())"; - } - for (auto &&i : op.getMgbAttributes()) { - os << formatv( - "\n .def_readwrite(\"{0}\", &{1}::{0})", - i.name, className - ); - } - os << ";\n\n"; -} - -static std::string gen_op_def_python_c_extension_enum( - raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, - llvm::StringRef className) { - std::string body; - unsigned int enumID; - if (auto alias = llvm::dyn_cast(attr)) { - auto&& aliasBase = alias->getAliasBase(); - enumID = llvm::cast(aliasBase).getBaseRecord()->getID(); - } else { - enumID = attr->getBaseRecord()->getID(); - } - auto&& enumAlias = ctx.enumAlias; - auto&& iter = enumAlias.find(enumID); - auto enumName = attr->getEnumName(); - body += "{\n"; - body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className, - enumName); - if (iter == enumAlias.end()) { - os << formatv( - "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", - className, enumName); - os << formatv( - "template<> const char* EnumWrapper<{0}::{1}>::name = " - "\"{0}.{1}\";\n", - className, enumName); - std::vector pairStr; - for (auto&& i : attr->getEnumMembers()) { - pairStr.push_back( - formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", - className, enumName, i)); - } - os << formatv(R"( -template<> std::unordered_map -EnumWrapper<{0}::{1}>::str2type = {{ - {2} -}; -)", - className, enumName, llvm::join(pairStr, ", ")); - pairStr.clear(); - for (auto&& i : attr->getEnumMembers()) { - pairStr.push_back( - formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", - className, enumName, i)); - } - os << formatv(R"( -template<> std::unordered_map<{0}::{1}, std::string> -EnumWrapper<{0}::{1}>::type2str = {{ - {2} -}; -)", - className, enumName, llvm::join(pairStr, ", ")); - body += formatv(R"( - e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; - e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; - e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); - e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; - e_type.tp_doc = "{0}.{1}"; - e_type.tp_base = &PyBaseObject_Type; - e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; - e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; - mgb_assert(PyType_Ready(&e_type) >= 0); -)", - className, enumName); - for (auto&& i : attr->getEnumMembers()) { - body += formatv(R"({{ - PyObject* inst = e_type.tp_alloc(&e_type, 0); - reinterpret_cast*>(inst)->value = {0}::{1}::{2}; - mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); -})", - className, enumName, i); - } - enumAlias.emplace(enumID, std::make_pair(className, enumName)); - } - body += formatv(R"( - PyType_Modified(&e_type); - mgb_assert(PyDict_SetItemString( - py_type.tp_dict, "{0}", reinterpret_cast(&e_type)) >= 0); -)", - enumName); - body += "}\n"; - return body; -} - -static std::string gen_op_def_python_c_extension_bit_combined_enum( - raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, - llvm::StringRef className) { - std::string body; - unsigned int enumID; - if (auto alias = llvm::dyn_cast(attr)) { - auto&& aliasBase = alias->getAliasBase(); - enumID = llvm::cast(aliasBase).getBaseRecord()->getID(); - } else { - enumID = attr->getBaseRecord()->getID(); - } - auto&& enumAlias = ctx.enumAlias; - auto&& iter = enumAlias.find(enumID); - auto enumName = attr->getEnumName(); - body += "{\n"; - body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;", - className, enumName); - if (iter == enumAlias.end()) { - os << formatv( - "template<> PyTypeObject " - "BitCombinedEnumWrapper<{0}::{1}>::type={{};\n", - className, enumName); - os << formatv( - "template<> PyNumberMethods " - "BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n", - className, enumName); - os << formatv( - "template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name " - "= \"{0}.{1}\";\n", - className, enumName); - os << formatv( - "template<> struct EnumTrait<{0}::{1}> {{ static constexpr " - "bool is_bit_combined = true;};\n", - className, enumName); - std::vector pairStr; - for (auto&& i : attr->getEnumMembers()) { - pairStr.push_back( - formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", - className, enumName, i)); - } - os << formatv(R"( -template<> std::unordered_map -BitCombinedEnumWrapper<{0}::{1}>::str2type = {{ - {2} -}; -)", - className, enumName, llvm::join(pairStr, ", ")); - pairStr.clear(); - for (auto&& i : attr->getEnumMembers()) { - pairStr.push_back( - formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", - className, enumName, i)); - } - os << formatv(R"( -template<> std::unordered_map<{0}::{1}, std::string> -BitCombinedEnumWrapper<{0}::{1}>::type2str = {{ - {2} -}; -)", - className, enumName, llvm::join(pairStr, ", ")); - body += formatv(R"( - e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; - e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; - e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>); - e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; - e_type.tp_doc = "{0}.{1}"; - e_type.tp_base = &PyBaseObject_Type; - e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum; - e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init; - e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr; - e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare; - auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods; - number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or; - number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and; - e_type.tp_as_number = &number_method; - mgb_assert(PyType_Ready(&e_type) >= 0); -)", - className, enumName); - for (auto&& i : attr->getEnumMembers()) { - body += formatv(R"({{ - PyObject* inst = e_type.tp_alloc(&e_type, 0); - reinterpret_cast*>(inst)->value = {0}::{1}::{2}; - mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); -})", - className, enumName, i); - } - enumAlias.emplace(enumID, std::make_pair(className, enumName)); - } - body += formatv(R"( - PyType_Modified(&e_type); - mgb_assert(PyDict_SetItemString( - py_type.tp_dict, "{0}", reinterpret_cast(&e_type)) >= 0); -)", - enumName); - body += "}\n"; - return body; -} - -static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { - auto className = op.getCppClassName(); - std::string body; - - // generate PyType for enum class member - for (auto&& i : op.getMgbAttributes()) { - if (auto attr = llvm::dyn_cast(&i.attr)) { - if (attr->getEnumCombinedFlag()) { - body += gen_op_def_python_c_extension_bit_combined_enum( - os, ctx, attr, className); - } else { - body += gen_op_def_python_c_extension_enum(os, ctx, attr, - className); - } - } - } - - // generate getsetters - std::vector getsetters; - for (auto &&i : op.getMgbAttributes()) { - getsetters.push_back(formatv( - "{{const_cast(\"{1}\"), py_get_generic({0}, {1}), py_set_generic({0}, {1}), const_cast(\"{1}\"), NULL},", - className, i.name)); - } - - // generate tp_init - std::string initBody; - if (!op.getMgbAttributes().empty()) { - initBody += "static const char* kwlist[] = {"; - - std::vector attr_name_list; - llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { - attr_name_list.push_back(attr.name); - }); - attr_name_list.push_back("scope"); - - llvm::for_each(attr_name_list, [&](auto&& attr) { - initBody += formatv("\"{0}\", ", attr); - }); - initBody += "NULL};\n"; - initBody += " PyObject "; - std::vector attr_init; - llvm::for_each(attr_name_list, [&](auto&& attr) { - attr_init.push_back(formatv("*{0} = NULL", attr)); - }); - initBody += llvm::join(attr_init, ", ") + ";\n"; - initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; - // an extra slot created for name - initBody += std::string(attr_name_list.size(), 'O'); - initBody += "\", const_cast(kwlist)"; - llvm::for_each(attr_name_list, [&](auto&& attr) { - initBody += formatv(", &{0}", attr); - }); - initBody += "))\n"; - initBody += " return -1;\n"; - - llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { - initBody += formatv(R"( - if ({1}) {{ - try {{ - reinterpret_cast(self)->inst().{1} = - pyobj_convert_generic::from({1}); - } CATCH_ALL(-1) - } -)", className, attr.name); - }); - - initBody += formatv(R"( - if (scope) {{ - try {{ - reinterpret_cast(self)->op - ->set_scope(pyobj_convert_generic::from(scope)); - } CATCH_ALL(-1) - } -)", className); - - } - initBody += "\n return 0;"; - - os << formatv(R"( -PyOpDefBegin({0}) // {{ - static PyGetSetDef py_getsetters[]; - static int py_init(PyObject *self, PyObject *args, PyObject *kwds); -// }; -PyOpDefEnd({0}) -PyGetSetDef PyOp({0})::py_getsetters[] = {{ - {1} - {{NULL} /* Sentinel */ -}; -int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{ - {2} -} - -void _init_py_{0}(py::module m) {{ - using py_op = PyOp({0}); - auto& py_type = PyOpType({0}); - py_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; - py_type.tp_name = "megengine.core._imperative_rt.ops.{0}"; - py_type.tp_basicsize = sizeof(PyOp({0})); - py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; - py_type.tp_doc = "{0}"; - py_type.tp_base = &PyOpType(OpDef); - py_type.tp_dealloc = py_dealloc_generic; - py_type.tp_new = py_new_generic; - py_type.tp_init = py_op::py_init; - py_type.tp_getset = py_op::py_getsetters; - mgb_assert(PyType_Ready(&py_type) >= 0); - {3} - PyType_Modified(&py_type); - m.add_object("{0}", reinterpret_cast(&py_type)); - mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second); -} -)", - op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body); -} - -static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, - std::function callback) { - auto op_base_class = keeper.getClass("Op"); - ASSERT(op_base_class, "could not find base class Op"); - for (auto&& i: keeper.getDefs()) { - auto&& r = i.second; - if (r->isSubClassOf(op_base_class)) { - auto op = mlir::tblgen::Operator(r.get()); - if (op.getDialectName().str() == "mgb") { - std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl; - callback(os, llvm::cast(op)); - } - } - } -} - -static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { - for_each_operator(os, keeper, gen_op_def_c_header_single); - for_each_operator(os, keeper, gen_to_string_trait_for_enum); - return false; -} - -static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) { - for_each_operator(os, keeper, gen_op_def_c_body_single); - return false; -} - -static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) { - EnumContext ctx; - using namespace std::placeholders; - for_each_operator(os, keeper, - std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx))); - return false; -} - -static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) { - EnumContext ctx; - using namespace std::placeholders; - for_each_operator(os, keeper, - std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx))); - os << "#define INIT_ALL_OP(m)"; - for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) { - os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName()); - }); - os << "\n"; - return false; -} +using namespace mlir::tblgen; int main(int argc, char **argv) { llvm::InitLLVM y(argc, argv); diff --git a/imperative/tablegen/emitter.h b/imperative/tablegen/emitter.h new file mode 100644 index 000000000..256da9701 --- /dev/null +++ b/imperative/tablegen/emitter.h @@ -0,0 +1,40 @@ +/** + * \file imperative/tablegen/emitter.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir::tblgen { + +struct Environment { + std::unordered_map> enumAlias; +}; + +struct EmitterBase { + EmitterBase(raw_ostream& os_): os(os_) {} + EmitterBase(raw_ostream& os_, Environment& env): os(os_), env_p(&env) {} +protected: + void newline() { os << "\n"; } + Environment& env() { + if (env_p) { + return *env_p; + } + throw std::runtime_error("access global environment via non-environment emitter"); + } + raw_ostream& os; + Environment* env_p = nullptr; +}; + +} // namespace mlir::tblgen \ No newline at end of file diff --git a/imperative/tablegen/helper.h b/imperative/tablegen/helper.h index c0fa56fb0..c881618e5 100644 --- a/imperative/tablegen/helper.h +++ b/imperative/tablegen/helper.h @@ -1,3 +1,16 @@ +/** + * \file imperative/tablegen/helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include #include #include @@ -278,5 +291,28 @@ public: } }; +using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; +using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; +using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin; +using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin; +using MgbOp = mlir::tblgen::MgbOpBase; +using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin; + +static inline void foreach_operator(llvm::RecordKeeper &keeper, + std::function callback) { + auto op_base_class = keeper.getClass("Op"); + ASSERT(op_base_class, "could not find base class Op"); + for (auto&& i: keeper.getDefs()) { + auto&& r = i.second; + if (r->isSubClassOf(op_base_class)) { + auto op = mlir::tblgen::Operator(r.get()); + if (op.getDialectName().str() == "mgb") { + std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl; + callback(llvm::cast(op)); + } + } + } +} + } // namespace tblgen } // namespace mlir diff --git a/imperative/tablegen/targets/cpp_class.cpp b/imperative/tablegen/targets/cpp_class.cpp new file mode 100644 index 000000000..e7285f14d --- /dev/null +++ b/imperative/tablegen/targets/cpp_class.cpp @@ -0,0 +1,309 @@ +/** + * \file imperative/tablegen/targets/cpp_class.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "./cpp_class.h" +#include "../emitter.h" + +namespace mlir::tblgen { +namespace { +llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) { + // Note: we have already registered the corresponding attr wrappers + // for following basic ctypes so we needn't handle them here + /* auto&& attr_type_name = attr.getAttrDefName(); + if (attr_type_name == "UI32Attr") { + return "uint32_t"; + } + if (attr_type_name == "UI64Attr") { + return "uint64_t"; + } + if (attr_type_name == "I32Attr") { + return "int32_t"; + } + if (attr_type_name == "F32Attr") { + return "float"; + } + if (attr_type_name == "F64Attr") { + return "double"; + } + if (attr_type_name == "StrAttr") { + return "std::string"; + } + if (attr_type_name == "BoolAttr") { + return "bool"; + }*/ + + auto&& attr = llvm::cast(attr_); + if (auto e = llvm::dyn_cast(&attr)) { + return e->getEnumName(); + } + return attr.getUnderlyingType(); +} + +class OpDefEmitter final: public EmitterBase { +public: + OpDefEmitter(MgbOp& op_, raw_ostream& os_): + EmitterBase(os_), op(op_) {} + void emit_header(); + void emit_tpl_spl(); + void emit_body(); +private: + MgbOp& op; +}; + +void OpDefEmitter::emit_header() { + os << formatv( + "class {0} : public OpDefImplBase<{0}> {{\n" + " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n" + "public:\n", + op.getCppClassName() + ); + // handle enum alias + for (auto &&i : op.getMgbAttributes()) { + if (auto attr = llvm::dyn_cast(&i.attr)) { + os << formatv( + " using {0} = {1};\n", + attr->getEnumName(), attr->getUnderlyingType() + ); + } + } + for (auto &&i : op.getMgbAttributes()) { + auto defaultValue = i.attr.getDefaultValue().str(); + if (!defaultValue.empty()) { + defaultValue = formatv(" = {0}", defaultValue); + } + os << formatv( + " {0} {1}{2};\n", + attr_to_ctype(i.attr), i.name, defaultValue + ); + } + + auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) { + os << formatv( + " {0}({1}){2}{3}\n", + op.getCppClassName(), paramList, memInitList, body + ); + }; + + gen_ctor("", "", " = default;"); + + if (!op.getMgbAttributes().empty()) { + std::vector paramList, initList; + for (auto &&i : op.getMgbAttributes()) { + paramList.push_back(formatv( + "{0} {1}_", attr_to_ctype(i.attr), i.name + )); + initList.push_back(formatv( + "{0}({0}_)", i.name + )); + } + paramList.push_back("std::string scope_ = {}"); + gen_ctor(llvm::join(paramList, ", "), + ": " + llvm::join(initList, ", "), + " { set_scope(scope_); }"); + } + + auto packedParams = op.getPackedParams(); + if (!packedParams.empty()) { + std::vector paramList, initList; + for (auto &&p : packedParams) { + auto&& paramFields = p.getFields(); + auto&& paramType = p.getFullName(); + auto&& paramName = formatv("packed_param_{0}", paramList.size()); + paramList.push_back( + paramFields.empty() ? paramType.str() + : formatv("{0} {1}", paramType, paramName) + ); + for (auto&& i : paramFields) { + initList.push_back(formatv( + "{0}({1}.{0})", i.name, paramName + )); + } + } + for (auto&& i : op.getExtraArguments()) { + paramList.push_back(formatv( + "{0} {1}_", attr_to_ctype(i.attr), i.name + )); + initList.push_back(formatv( + "{0}({0}_)", i.name + )); + } + gen_ctor(llvm::join(paramList, ", "), + initList.empty() ? "" : ": " + llvm::join(initList, ", "), + " {}"); + } + + if (!packedParams.empty()) { + for (auto&& p : packedParams) { + auto accessor = p.getAccessor(); + if (!accessor.empty()) { + os << formatv( + " {0} {1}() const {{\n", + p.getFullName(), accessor + ); + std::vector fields; + for (auto&& i : p.getFields()) { + fields.push_back(i.name); + } + os << formatv( + " return {{{0}};\n", + llvm::join(fields, ", ") + ); + os << " }\n"; + } + } + } + + if (auto decl = op.getExtraOpdefDecl()) { + os << decl.getValue(); + } + + os << formatv( + "};\n\n" + ); +} + +void OpDefEmitter::emit_tpl_spl() { + for (auto &&i : op.getMgbAttributes()) { + if (auto attr = llvm::dyn_cast(&i.attr)) { + if (attr->supportToString()) { + std::vector case_body; + std::string ename = formatv("{0}::{1}", + op.getCppClassName(), attr->getEnumName()); + llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ + case_body.push_back(formatv( + "case {0}::{1}: return \"{1}\";", ename, v)); + }); + os << formatv(R"( +template <> +struct ToStringTrait<{0}> { + std::string operator()({0} e) const { + switch (e) { + {1} + default: + return "{0}::Unknown"; + } + } +}; +)", ename, llvm::join(case_body, "\n")); + } + } + } +} + +void OpDefEmitter::emit_body() { + auto&& className = op.getCppClassName(); + os << formatv( + "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className + ); + auto formatMethImpl = [&](auto&& meth) { + return formatv( + "{0}_{1}_impl", className, meth + ); + }; + std::vector methods; + if (auto hashable = llvm::dyn_cast(&op)) { + os << "namespace {\n"; + + // generate hash() + mlir::tblgen::FmtContext ctx; + os << formatv( + "size_t {0}(const OpDef& def_) {{\n", + formatMethImpl("hash") + ); + os << formatv( + " auto&& op_ = def_.cast_final_safe<{0}>();\n" + " static_cast(op_);\n", + className + ); + ctx.withSelf("op_"); + os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx); + os << "}\n"; + + // generate is_same_st() + os << formatv( + "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n", + formatMethImpl("is_same_st") + ); + os << formatv( + " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n" + " &&b_ = rhs_.cast_final_safe<{0}>();\n" + " static_cast(a_);\n" + " static_cast(b_);\n", + className + ); + os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); + os << "}\n"; + + // generate props() + os << formatv( + "std::vector> {0}(const OpDef& def_) {{\n", + formatMethImpl("props") + ); + os << formatv( + " auto&& op_ = def_.cast_final_safe<{0}>();\n" + " static_cast(op_);\n", + className + ); + ctx.withSelf("op_"); + os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); + os << "}\n"; + + // generate make_name() + os << formatv( + "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") + ); + os << formatv( + " auto&& op_ = def_.cast_final_safe<{0}>();\n" + " static_cast(op_);\n", + className + ); + ctx.withSelf("op_"); + os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx); + os << "}\n"; + + os << "} // anonymous namespace\n"; + + methods.push_back("hash"); + methods.push_back("is_same_st"); + methods.push_back("props"); + methods.push_back("make_name"); + } + if (!methods.empty()) { + os << formatv( + "OP_TRAIT_REG({0}, {0})", op.getCppClassName() + ); + for (auto&& i : methods) { + os << formatv( + "\n .{0}({1})", i, formatMethImpl(i) + ); + } + os << ";\n\n"; + } +} +} // namespace + +bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper) { + foreach_operator(keeper, [&](MgbOp& op) { + OpDefEmitter emitter(op, os); + emitter.emit_header(); + emitter.emit_tpl_spl(); + }); + return false; +} + +bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper) { + foreach_operator(keeper, [&](MgbOp& op) { + OpDefEmitter emitter(op, os); + emitter.emit_body(); + }); + return false; +} +} // namespace mlir::tblgen diff --git a/imperative/tablegen/targets/cpp_class.h b/imperative/tablegen/targets/cpp_class.h new file mode 100644 index 000000000..1654b1b9f --- /dev/null +++ b/imperative/tablegen/targets/cpp_class.h @@ -0,0 +1,21 @@ +/** + * \file imperative/tablegen/targets/cpp_class.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "../helper.h" + +namespace mlir::tblgen { + +bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper); + +bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper); + +} // namespace mlir::tblgen diff --git a/imperative/tablegen/targets/pybind11.cpp b/imperative/tablegen/targets/pybind11.cpp new file mode 100644 index 000000000..714b4f960 --- /dev/null +++ b/imperative/tablegen/targets/pybind11.cpp @@ -0,0 +1,142 @@ +/** + * \file imperative/tablegen/targets/pybind11.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "./pybind11.h" +#include "../emitter.h" + +namespace mlir::tblgen { +namespace { +class OpDefEmitter final: public EmitterBase { +public: + OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_): + EmitterBase(os_, env_), op(op_) {} + + void emit(); +private: + MgbOp& op; +}; + +void OpDefEmitter::emit() { + auto className = op.getCppClassName(); + os << formatv( + "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", + className + ); + for (auto&& i : op.getMgbAttributes()) { + if (auto attr = llvm::dyn_cast(&i.attr)) { + unsigned int enumID; + if (auto alias = llvm::dyn_cast(attr)) { + auto&& aliasBase = alias->getAliasBase(); + enumID = + llvm::cast(aliasBase) + .getBaseRecord()->getID(); + } else { + enumID = attr->getBaseRecord()->getID(); + } + auto&& enumAlias = env().enumAlias; + auto&& iter = enumAlias.find(enumID); + if (iter == enumAlias.end()) { + os << formatv( + "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", + className, attr->getEnumName() + ); + std::vector body; + for (auto&& i: attr->getEnumMembers()) { + os << formatv( + "\n .value(\"{2}\", {0}::{1}::{2})", + className, attr->getEnumName(), i + ); + body.push_back(formatv( + "if (str == \"{2}\") return {0}::{1}::{2};", + className, attr->getEnumName(), i + )); + } + if (attr->getEnumCombinedFlag()) { + //! define operator | + os << formatv( + "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ " + "\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));" + "\n })", + className, attr->getEnumName()); + //! define operator & + os << formatv( + "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{" + "\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));" + "\n })", + className, attr->getEnumName()); + } + os << formatv( + "\n .def(py::init([](const std::string& in) {" + "\n auto&& str = normalize_enum(in);" + "\n {0}" + "\n throw py::cast_error(\"invalid enum value \" + in);" + "\n }));\n", + llvm::join(body, "\n ") + ); + os << formatv( + "py::implicitly_convertible();\n\n", + className, attr->getEnumName() + ); + enumAlias.emplace(enumID, + std::make_pair(className, attr->getEnumName())); + } else { + os << formatv( + "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", + className, attr->getEnumName(), + iter->second.first, iter->second.second + ); + } + } + } + // generate op class binding + os << formatv("{0}Inst", className); + bool hasDefaultCtor = op.getMgbAttributes().empty(); + if (!hasDefaultCtor) { + os << "\n .def(py::init<"; + std::vector targs; + for (auto &&i : op.getMgbAttributes()) { + targs.push_back(i.attr.getReturnType()); + } + os << llvm::join(targs, ", "); + os << ", std::string>()"; + for (auto &&i : op.getMgbAttributes()) { + os << formatv(", py::arg(\"{0}\")", i.name); + auto defaultValue = i.attr.getDefaultValue(); + if (!defaultValue.empty()) { + os << formatv(" = {0}", defaultValue); + } else { + hasDefaultCtor = true; + } + } + os << ", py::arg(\"scope\") = {})"; + } + if (hasDefaultCtor) { + os << "\n .def(py::init<>())"; + } + for (auto &&i : op.getMgbAttributes()) { + os << formatv( + "\n .def_readwrite(\"{0}\", &{1}::{0})", + i.name, className + ); + } + os << ";\n\n"; +} +} // namespace + +bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper) { + Environment env; + using namespace std::placeholders; + foreach_operator(keeper, [&](MgbOp& op) { + OpDefEmitter(op, os, env).emit(); + }); + return false; +} +} // namespace mlir::tblgen diff --git a/imperative/tablegen/targets/pybind11.h b/imperative/tablegen/targets/pybind11.h new file mode 100644 index 000000000..d27a84b24 --- /dev/null +++ b/imperative/tablegen/targets/pybind11.h @@ -0,0 +1,19 @@ +/** + * \file imperative/tablegen/targets/pybind11.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "../helper.h" + +namespace mlir::tblgen { + +bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper); + +} // namespace mlir::tblgen diff --git a/imperative/tablegen/targets/python_c_extension.cpp b/imperative/tablegen/targets/python_c_extension.cpp new file mode 100644 index 000000000..1de71a858 --- /dev/null +++ b/imperative/tablegen/targets/python_c_extension.cpp @@ -0,0 +1,313 @@ +/** + * \file imperative/tablegen/targets/python_c_extension.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "python_c_extension.h" +#include "../emitter.h" + +namespace mlir::tblgen { +namespace { +struct Initproc { + std::string func; + Initproc(std::string&& s): func(std::move(s)) {} + std::string operator()(std::string argument) { + return formatv("{0}({1})", func, argument); + } +}; + +class OpDefEmitter: public EmitterBase { +public: + OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_): + EmitterBase(os_, env_), op(op_) { + ctx.withSelf(op.getCppClassName()); + } + + Initproc emit(); +private: + void emit_class(); + void emit_py_init(); + void emit_py_getsetters(); + Initproc emit_initproc(); + + MgbOp& op; + std::vector subclasses; + mlir::tblgen::FmtContext ctx; +}; + +class EnumAttrEmitter: public EmitterBase { +public: + EnumAttrEmitter(llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_, Environment& env_): + EmitterBase(os_, env_), attr(attr_) { + unsigned int enumID; + if (auto alias = llvm::dyn_cast(attr)) { + auto&& aliasBase = alias->getAliasBase(); + enumID = llvm::cast(aliasBase).getBaseRecord()->getID(); + } else { + enumID = attr->getBaseRecord()->getID(); + } + ctx.addSubst("enumTpl", attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper"); + ctx.addSubst("opClass", parent); + ctx.addSubst("enumClass", attr->getEnumName()); + firstOccur = env().enumAlias.emplace(enumID, std::make_pair(parent, attr->getEnumName())).second; + } + + Initproc emit(); +protected: + void emit_tpl_spl(); + Initproc emit_initproc(); + + MgbEnumAttr* attr; + bool firstOccur; + mlir::tblgen::FmtContext ctx; +}; + +Initproc EnumAttrEmitter::emit() { + emit_tpl_spl(); + return emit_initproc(); +} + +void EnumAttrEmitter::emit_tpl_spl() { + if (!firstOccur) return; + + os << tgfmt( + "template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type={};\n", + &ctx); + + os << tgfmt( + "template<> const char* $enumTpl<$opClass::$enumClass>::name = " + "\"$opClass.$enumClass\";\n", + &ctx); + + if (attr->getEnumCombinedFlag()) { + os << tgfmt( + "template<> PyNumberMethods " + "$enumTpl<$opClass::$enumClass>::number_methods={};\n", + &ctx); + os << tgfmt( + "template<> struct EnumTrait<$opClass::$enumClass> { static constexpr " + "bool is_bit_combined = true;};\n", + &ctx); + } + + auto str2type = [&](auto&& i) -> std::string { + return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i); + }; + os << tgfmt(R"( +template<> std::unordered_map +$enumTpl<$opClass::$enumClass>::str2type = {$0}; +)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), str2type), ", ")); + + auto type2str = [&](auto&& i) -> std::string { + return tgfmt("{$opClass::$enumClass::$0, normalize_enum(\"$0\")}", &ctx, i); + }; + os << tgfmt(R"( +template<> std::unordered_map<$opClass::$enumClass, std::string> +$enumTpl<$opClass::$enumClass>::type2str = {$0}; +)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), type2str), ", ")); +} + +Initproc EnumAttrEmitter::emit_initproc() { + std::string initproc = formatv("_init_py_{0}_{1}", + ctx.getSubstFor("opClass"), ctx.getSubstFor("enumClass")); + + os << tgfmt(R"( +void $0(PyTypeObject& py_type) { + auto& e_type = $enumTpl<$opClass::$enumClass>::type; +)", &ctx, initproc); + + if (firstOccur) { + os << tgfmt(R"( + e_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + e_type.tp_name = "megengine.core._imperative_rt.ops.$opClass.$enumClass"; + e_type.tp_basicsize = sizeof($enumTpl<$opClass::$enumClass>); + e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + e_type.tp_doc = "$opClass.$enumClass"; + e_type.tp_base = &PyBaseObject_Type; + e_type.tp_repr = $enumTpl<$opClass::$enumClass>::py_repr; + e_type.tp_richcompare = $enumTpl<$opClass::$enumClass>::tp_richcompare; +)", &ctx); + if (attr->getEnumCombinedFlag()) { + // only bit combined enum could new instance because bitwise operation, + // others should always use singleton + os << tgfmt(R"( + e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum; + e_type.tp_init = $enumTpl<$opClass::$enumClass>::py_init; + auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods; + number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or; + number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and; + e_type.tp_as_number = &number_method; +)", &ctx); + } + + os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n"; + + + for (auto&& i : attr->getEnumMembers()) { + os << tgfmt(R"({ + PyObject* inst = e_type.tp_alloc(&e_type, 0); + reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0; + mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0); + PyType_Modified(&e_type); +})", &ctx, i); + } + } + + os << tgfmt(R"( + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "$enumClass", reinterpret_cast(&e_type)) >= 0); +)", &ctx); + os << "}\n"; + return initproc; +} + +Initproc OpDefEmitter::emit() { + for (auto&& i : op.getMgbAttributes()) { + if (auto attr = llvm::dyn_cast(&i.attr)) { + subclasses.push_back(EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit()); + } + } + + emit_class(); + emit_py_init(); + emit_py_getsetters(); + return emit_initproc(); +} + +void OpDefEmitter::emit_class() { + os << tgfmt(R"( +PyOpDefBegin($_self) // { + static PyGetSetDef py_getsetters[]; + static int py_init(PyObject *self, PyObject *args, PyObject *kwds); +// }; +PyOpDefEnd($_self) +)", &ctx); +} + +void OpDefEmitter::emit_py_init() { + std::string initBody; + if (!op.getMgbAttributes().empty()) { + initBody += "static const char* kwlist[] = {"; + + std::vector attr_name_list; + llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { + attr_name_list.push_back(attr.name); + }); + attr_name_list.push_back("scope"); + + llvm::for_each(attr_name_list, [&](auto&& attr) { + initBody += formatv("\"{0}\", ", attr); + }); + initBody += "NULL};\n"; + initBody += " PyObject "; + auto initializer = [&](auto&& attr) -> std::string { + return formatv("*{0} = NULL", attr); + }; + initBody += llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n"; + initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; + // an extra slot created for name + initBody += std::string(attr_name_list.size(), 'O'); + initBody += "\", const_cast(kwlist)"; + llvm::for_each(attr_name_list, [&](auto&& attr) { + initBody += formatv(", &{0}", attr); + }); + initBody += "))\n"; + initBody += " return -1;\n"; + + llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { + initBody += tgfmt(R"( + if ($0) { + try { + reinterpret_cast(self)->inst().$0 = + pyobj_convert_generic::from($0); + } CATCH_ALL(-1) + } +)", &ctx, attr.name); + }); + + initBody += tgfmt(R"( + if (scope) { + try { + reinterpret_cast(self)->op + ->set_scope(pyobj_convert_generic::from(scope)); + } CATCH_ALL(-1) + } +)", &ctx); + + } + initBody += "\n return 0;"; + + + os << tgfmt(R"( +int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { + $0 +} +)", &ctx, initBody); +} + +void OpDefEmitter::emit_py_getsetters() { + auto f = [&](auto&& attr) -> std::string { + return tgfmt( + "{const_cast(\"$0\"), py_get_generic($_self, $0), py_set_generic($_self, $0), const_cast(\"$0\"), NULL},", + &ctx, attr.name); + }; + os << tgfmt(R"( +PyGetSetDef PyOp($_self)::py_getsetters[] = { + $0 + {NULL} /* Sentinel */ +}; +)", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n ")); +} + +Initproc OpDefEmitter::emit_initproc() { + std::string initproc = formatv("_init_py_{0}", op.getCppClassName()); + std::string subclass_init_call; + for (auto&& i : subclasses) { + subclass_init_call += formatv(" {0};\n", i("py_type")); + } + os << tgfmt(R"( +void $0(py::module m) { + using py_op = PyOp($_self); + auto& py_type = PyOpType($_self); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.$_self"; + py_type.tp_basicsize = sizeof(PyOp($_self)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "$_self"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_new_generic; + py_type.tp_init = py_op::py_init; + py_type.tp_getset = py_op::py_getsetters; + mgb_assert(PyType_Ready(&py_type) >= 0); + $1 + PyType_Modified(&py_type); + m.add_object("$_self", reinterpret_cast(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second); +} +)", &ctx, initproc, subclass_init_call); + return initproc; +} +} // namespace + +bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper) { + Environment env; + using namespace std::placeholders; + std::vector initprocs; + foreach_operator(keeper, [&](MgbOp& op) { + initprocs.emplace_back(OpDefEmitter(op, os, env).emit()); + }); + os << "#define INIT_ALL_OP(m)"; + for(auto&& init : initprocs) { + os << formatv(" \\\n {0};", init("m")); + } + os << "\n"; + return false; +} +} // namespace mlir::tblgen \ No newline at end of file diff --git a/imperative/tablegen/targets/python_c_extension.h b/imperative/tablegen/targets/python_c_extension.h new file mode 100644 index 000000000..65692080f --- /dev/null +++ b/imperative/tablegen/targets/python_c_extension.h @@ -0,0 +1,19 @@ +/** + * \file imperative/tablegen/targets/python_c_extension.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "../helper.h" + +namespace mlir::tblgen { + +bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper); + +} // namespace mlir::tblgen -- GitLab