提交 eeeddbbc 编写于 作者: M Megvii Engine Team

refactor(imperative): refactor tablegen code generator

GitOrigin-RevId: b81b085762c47da9a901ec3e25778f3c75f21395
上级 0ad85a41
# mgb tablegen executable
set(TABLE_TARGET mgb-mlir-autogen)
add_executable(${TABLE_TARGET} autogen.cpp)
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.h ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
add_executable(${TABLE_TARGET} ${SRCS})
target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR})
target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport)
set(MGB_TABLEGEN_EXE ${TABLE_TARGET})
......
此差异已折叠。
/**
* \file imperative/tablegen/emitter.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <unordered_map>
#include <stdexcept>
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir::tblgen {
struct Environment {
std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias;
};
struct EmitterBase {
EmitterBase(raw_ostream& os_): os(os_) {}
EmitterBase(raw_ostream& os_, Environment& env): os(os_), env_p(&env) {}
protected:
void newline() { os << "\n"; }
Environment& env() {
if (env_p) {
return *env_p;
}
throw std::runtime_error("access global environment via non-environment emitter");
}
raw_ostream& os;
Environment* env_p = nullptr;
};
} // namespace mlir::tblgen
\ No newline at end of file
/**
* \file imperative/tablegen/helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <iostream>
#include <string>
#include <vector>
......@@ -278,5 +291,28 @@ public:
}
};
using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
using MgbOp = mlir::tblgen::MgbOpBase;
using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;
static inline void foreach_operator(llvm::RecordKeeper &keeper,
std::function<void(MgbOp&)> callback) {
auto op_base_class = keeper.getClass("Op");
ASSERT(op_base_class, "could not find base class Op");
for (auto&& i: keeper.getDefs()) {
auto&& r = i.second;
if (r->isSubClassOf(op_base_class)) {
auto op = mlir::tblgen::Operator(r.get());
if (op.getDialectName().str() == "mgb") {
std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl;
callback(llvm::cast<MgbOp>(op));
}
}
}
}
} // namespace tblgen
} // namespace mlir
/**
* \file imperative/tablegen/targets/cpp_class.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./cpp_class.h"
#include "../emitter.h"
namespace mlir::tblgen {
namespace {
llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
// Note: we have already registered the corresponding attr wrappers
// for following basic ctypes so we needn't handle them here
/* auto&& attr_type_name = attr.getAttrDefName();
if (attr_type_name == "UI32Attr") {
return "uint32_t";
}
if (attr_type_name == "UI64Attr") {
return "uint64_t";
}
if (attr_type_name == "I32Attr") {
return "int32_t";
}
if (attr_type_name == "F32Attr") {
return "float";
}
if (attr_type_name == "F64Attr") {
return "double";
}
if (attr_type_name == "StrAttr") {
return "std::string";
}
if (attr_type_name == "BoolAttr") {
return "bool";
}*/
auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
return e->getEnumName();
}
return attr.getUnderlyingType();
}
class OpDefEmitter final: public EmitterBase {
public:
OpDefEmitter(MgbOp& op_, raw_ostream& os_):
EmitterBase(os_), op(op_) {}
void emit_header();
void emit_tpl_spl();
void emit_body();
private:
MgbOp& op;
};
void OpDefEmitter::emit_header() {
os << formatv(
"class {0} : public OpDefImplBase<{0}> {{\n"
" MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
"public:\n",
op.getCppClassName()
);
// handle enum alias
for (auto &&i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
os << formatv(
" using {0} = {1};\n",
attr->getEnumName(), attr->getUnderlyingType()
);
}
}
for (auto &&i : op.getMgbAttributes()) {
auto defaultValue = i.attr.getDefaultValue().str();
if (!defaultValue.empty()) {
defaultValue = formatv(" = {0}", defaultValue);
}
os << formatv(
" {0} {1}{2};\n",
attr_to_ctype(i.attr), i.name, defaultValue
);
}
auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
os << formatv(
" {0}({1}){2}{3}\n",
op.getCppClassName(), paramList, memInitList, body
);
};
gen_ctor("", "", " = default;");
if (!op.getMgbAttributes().empty()) {
std::vector<std::string> paramList, initList;
for (auto &&i : op.getMgbAttributes()) {
paramList.push_back(formatv(
"{0} {1}_", attr_to_ctype(i.attr), i.name
));
initList.push_back(formatv(
"{0}({0}_)", i.name
));
}
paramList.push_back("std::string scope_ = {}");
gen_ctor(llvm::join(paramList, ", "),
": " + llvm::join(initList, ", "),
" { set_scope(scope_); }");
}
auto packedParams = op.getPackedParams();
if (!packedParams.empty()) {
std::vector<std::string> paramList, initList;
for (auto &&p : packedParams) {
auto&& paramFields = p.getFields();
auto&& paramType = p.getFullName();
auto&& paramName = formatv("packed_param_{0}", paramList.size());
paramList.push_back(
paramFields.empty() ? paramType.str()
: formatv("{0} {1}", paramType, paramName)
);
for (auto&& i : paramFields) {
initList.push_back(formatv(
"{0}({1}.{0})", i.name, paramName
));
}
}
for (auto&& i : op.getExtraArguments()) {
paramList.push_back(formatv(
"{0} {1}_", attr_to_ctype(i.attr), i.name
));
initList.push_back(formatv(
"{0}({0}_)", i.name
));
}
gen_ctor(llvm::join(paramList, ", "),
initList.empty() ? "" : ": " + llvm::join(initList, ", "),
" {}");
}
if (!packedParams.empty()) {
for (auto&& p : packedParams) {
auto accessor = p.getAccessor();
if (!accessor.empty()) {
os << formatv(
" {0} {1}() const {{\n",
p.getFullName(), accessor
);
std::vector<llvm::StringRef> fields;
for (auto&& i : p.getFields()) {
fields.push_back(i.name);
}
os << formatv(
" return {{{0}};\n",
llvm::join(fields, ", ")
);
os << " }\n";
}
}
}
if (auto decl = op.getExtraOpdefDecl()) {
os << decl.getValue();
}
os << formatv(
"};\n\n"
);
}
void OpDefEmitter::emit_tpl_spl() {
for (auto &&i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
if (attr->supportToString()) {
std::vector<std::string> case_body;
std::string ename = formatv("{0}::{1}",
op.getCppClassName(), attr->getEnumName());
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){
case_body.push_back(formatv(
"case {0}::{1}: return \"{1}\";", ename, v));
});
os << formatv(R"(
template <>
struct ToStringTrait<{0}> {
std::string operator()({0} e) const {
switch (e) {
{1}
default:
return "{0}::Unknown";
}
}
};
)", ename, llvm::join(case_body, "\n"));
}
}
}
}
void OpDefEmitter::emit_body() {
auto&& className = op.getCppClassName();
os << formatv(
"MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
);
auto formatMethImpl = [&](auto&& meth) {
return formatv(
"{0}_{1}_impl", className, meth
);
};
std::vector<std::string> methods;
if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
os << "namespace {\n";
// generate hash()
mlir::tblgen::FmtContext ctx;
os << formatv(
"size_t {0}(const OpDef& def_) {{\n",
formatMethImpl("hash")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
os << "}\n";
// generate is_same_st()
os << formatv(
"bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
formatMethImpl("is_same_st")
);
os << formatv(
" auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
" &&b_ = rhs_.cast_final_safe<{0}>();\n"
" static_cast<void>(a_);\n"
" static_cast<void>(b_);\n",
className
);
os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
os << "}\n";
// generate props()
os << formatv(
"std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n",
formatMethImpl("props")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
os << "}\n";
// generate make_name()
os << formatv(
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
os << "}\n";
os << "} // anonymous namespace\n";
methods.push_back("hash");
methods.push_back("is_same_st");
methods.push_back("props");
methods.push_back("make_name");
}
if (!methods.empty()) {
os << formatv(
"OP_TRAIT_REG({0}, {0})", op.getCppClassName()
);
for (auto&& i : methods) {
os << formatv(
"\n .{0}({1})", i, formatMethImpl(i)
);
}
os << ";\n\n";
}
}
} // namespace
bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper) {
foreach_operator(keeper, [&](MgbOp& op) {
OpDefEmitter emitter(op, os);
emitter.emit_header();
emitter.emit_tpl_spl();
});
return false;
}
bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper) {
foreach_operator(keeper, [&](MgbOp& op) {
OpDefEmitter emitter(op, os);
emitter.emit_body();
});
return false;
}
} // namespace mlir::tblgen
/**
* \file imperative/tablegen/targets/cpp_class.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "../helper.h"
namespace mlir::tblgen {
bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper);
bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper);
} // namespace mlir::tblgen
/**
* \file imperative/tablegen/targets/pybind11.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./pybind11.h"
#include "../emitter.h"
namespace mlir::tblgen {
namespace {
class OpDefEmitter final: public EmitterBase {
public:
OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_):
EmitterBase(os_, env_), op(op_) {}
void emit();
private:
MgbOp& op;
};
void OpDefEmitter::emit() {
auto className = op.getCppClassName();
os << formatv(
"py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
className
);
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID =
llvm::cast<MgbEnumAttr>(aliasBase)
.getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
auto&& enumAlias = env().enumAlias;
auto&& iter = enumAlias.find(enumID);
if (iter == enumAlias.end()) {
os << formatv(
"py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
className, attr->getEnumName()
);
std::vector<std::string> body;
for (auto&& i: attr->getEnumMembers()) {
os << formatv(
"\n .value(\"{2}\", {0}::{1}::{2})",
className, attr->getEnumName(), i
);
body.push_back(formatv(
"if (str == \"{2}\") return {0}::{1}::{2};",
className, attr->getEnumName(), i
));
}
if (attr->getEnumCombinedFlag()) {
//! define operator |
os << formatv(
"\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ "
"\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));"
"\n })",
className, attr->getEnumName());
//! define operator &
os << formatv(
"\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{"
"\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));"
"\n })",
className, attr->getEnumName());
}
os << formatv(
"\n .def(py::init([](const std::string& in) {"
"\n auto&& str = normalize_enum(in);"
"\n {0}"
"\n throw py::cast_error(\"invalid enum value \" + in);"
"\n }));\n",
llvm::join(body, "\n ")
);
os << formatv(
"py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
className, attr->getEnumName()
);
enumAlias.emplace(enumID,
std::make_pair(className, attr->getEnumName()));
} else {
os << formatv(
"{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n",
className, attr->getEnumName(),
iter->second.first, iter->second.second
);
}
}
}
// generate op class binding
os << formatv("{0}Inst", className);
bool hasDefaultCtor = op.getMgbAttributes().empty();
if (!hasDefaultCtor) {
os << "\n .def(py::init<";
std::vector<llvm::StringRef> targs;
for (auto &&i : op.getMgbAttributes()) {
targs.push_back(i.attr.getReturnType());
}
os << llvm::join(targs, ", ");
os << ", std::string>()";
for (auto &&i : op.getMgbAttributes()) {
os << formatv(", py::arg(\"{0}\")", i.name);
auto defaultValue = i.attr.getDefaultValue();
if (!defaultValue.empty()) {
os << formatv(" = {0}", defaultValue);
} else {
hasDefaultCtor = true;
}
}
os << ", py::arg(\"scope\") = {})";
}
if (hasDefaultCtor) {
os << "\n .def(py::init<>())";
}
for (auto &&i : op.getMgbAttributes()) {
os << formatv(
"\n .def_readwrite(\"{0}\", &{1}::{0})",
i.name, className
);
}
os << ";\n\n";
}
} // namespace
bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper) {
Environment env;
using namespace std::placeholders;
foreach_operator(keeper, [&](MgbOp& op) {
OpDefEmitter(op, os, env).emit();
});
return false;
}
} // namespace mlir::tblgen
/**
* \file imperative/tablegen/targets/pybind11.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "../helper.h"
namespace mlir::tblgen {
bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper);
} // namespace mlir::tblgen
/**
* \file imperative/tablegen/targets/python_c_extension.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "python_c_extension.h"
#include "../emitter.h"
namespace mlir::tblgen {
namespace {
struct Initproc {
std::string func;
Initproc(std::string&& s): func(std::move(s)) {}
std::string operator()(std::string argument) {
return formatv("{0}({1})", func, argument);
}
};
class OpDefEmitter: public EmitterBase {
public:
OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_):
EmitterBase(os_, env_), op(op_) {
ctx.withSelf(op.getCppClassName());
}
Initproc emit();
private:
void emit_class();
void emit_py_init();
void emit_py_getsetters();
Initproc emit_initproc();
MgbOp& op;
std::vector<Initproc> subclasses;
mlir::tblgen::FmtContext ctx;
};
class EnumAttrEmitter: public EmitterBase {
public:
EnumAttrEmitter(llvm::StringRef parent, MgbEnumAttr* attr_, raw_ostream& os_, Environment& env_):
EmitterBase(os_, env_), attr(attr_) {
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
ctx.addSubst("enumTpl", attr->getEnumCombinedFlag() ? "BitCombinedEnumWrapper" : "EnumWrapper");
ctx.addSubst("opClass", parent);
ctx.addSubst("enumClass", attr->getEnumName());
firstOccur = env().enumAlias.emplace(enumID, std::make_pair(parent, attr->getEnumName())).second;
}
Initproc emit();
protected:
void emit_tpl_spl();
Initproc emit_initproc();
MgbEnumAttr* attr;
bool firstOccur;
mlir::tblgen::FmtContext ctx;
};
Initproc EnumAttrEmitter::emit() {
emit_tpl_spl();
return emit_initproc();
}
void EnumAttrEmitter::emit_tpl_spl() {
if (!firstOccur) return;
os << tgfmt(
"template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type={};\n",
&ctx);
os << tgfmt(
"template<> const char* $enumTpl<$opClass::$enumClass>::name = "
"\"$opClass.$enumClass\";\n",
&ctx);
if (attr->getEnumCombinedFlag()) {
os << tgfmt(
"template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods={};\n",
&ctx);
os << tgfmt(
"template<> struct EnumTrait<$opClass::$enumClass> { static constexpr "
"bool is_bit_combined = true;};\n",
&ctx);
}
auto str2type = [&](auto&& i) -> std::string {
return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i);
};
os << tgfmt(R"(
template<> std::unordered_map<std::string, $opClass::$enumClass>
$enumTpl<$opClass::$enumClass>::str2type = {$0};
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), str2type), ", "));
auto type2str = [&](auto&& i) -> std::string {
return tgfmt("{$opClass::$enumClass::$0, normalize_enum(\"$0\")}", &ctx, i);
};
os << tgfmt(R"(
template<> std::unordered_map<$opClass::$enumClass, std::string>
$enumTpl<$opClass::$enumClass>::type2str = {$0};
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), type2str), ", "));
}
Initproc EnumAttrEmitter::emit_initproc() {
std::string initproc = formatv("_init_py_{0}_{1}",
ctx.getSubstFor("opClass"), ctx.getSubstFor("enumClass"));
os << tgfmt(R"(
void $0(PyTypeObject& py_type) {
auto& e_type = $enumTpl<$opClass::$enumClass>::type;
)", &ctx, initproc);
if (firstOccur) {
os << tgfmt(R"(
e_type = {PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.$opClass.$enumClass";
e_type.tp_basicsize = sizeof($enumTpl<$opClass::$enumClass>);
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
e_type.tp_doc = "$opClass.$enumClass";
e_type.tp_base = &PyBaseObject_Type;
e_type.tp_repr = $enumTpl<$opClass::$enumClass>::py_repr;
e_type.tp_richcompare = $enumTpl<$opClass::$enumClass>::tp_richcompare;
)", &ctx);
if (attr->getEnumCombinedFlag()) {
// only bit combined enum could new instance because bitwise operation,
// others should always use singleton
os << tgfmt(R"(
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum;
e_type.tp_init = $enumTpl<$opClass::$enumClass>::py_init;
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods;
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or;
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and;
e_type.tp_as_number = &number_method;
)", &ctx);
}
os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n";
for (auto&& i : attr->getEnumMembers()) {
os << tgfmt(R"({
PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0);
PyType_Modified(&e_type);
})", &ctx, i);
}
}
os << tgfmt(R"(
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(&e_type)) >= 0);
)", &ctx);
os << "}\n";
return initproc;
}
Initproc OpDefEmitter::emit() {
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
subclasses.push_back(EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit());
}
}
emit_class();
emit_py_init();
emit_py_getsetters();
return emit_initproc();
}
void OpDefEmitter::emit_class() {
os << tgfmt(R"(
PyOpDefBegin($_self) // {
static PyGetSetDef py_getsetters[];
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd($_self)
)", &ctx);
}
void OpDefEmitter::emit_py_init() {
std::string initBody;
if (!op.getMgbAttributes().empty()) {
initBody += "static const char* kwlist[] = {";
std::vector<llvm::StringRef> attr_name_list;
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
attr_name_list.push_back(attr.name);
});
attr_name_list.push_back("scope");
llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr);
});
initBody += "NULL};\n";
initBody += " PyObject ";
auto initializer = [&](auto&& attr) -> std::string {
return formatv("*{0} = NULL", attr);
};
initBody += llvm::join(llvm::map_range(attr_name_list, initializer), ", ") + ";\n";
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
// an extra slot created for name
initBody += std::string(attr_name_list.size(), 'O');
initBody += "\", const_cast<char**>(kwlist)";
llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv(", &{0}", attr);
});
initBody += "))\n";
initBody += " return -1;\n";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += tgfmt(R"(
if ($0) {
try {
reinterpret_cast<PyOp($_self)*>(self)->inst().$0 =
pyobj_convert_generic<decltype($_self::$0)>::from($0);
} CATCH_ALL(-1)
}
)", &ctx, attr.name);
});
initBody += tgfmt(R"(
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(pyobj_convert_generic<std::string>::from(scope));
} CATCH_ALL(-1)
}
)", &ctx);
}
initBody += "\n return 0;";
os << tgfmt(R"(
int PyOp($_self)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
$0
}
)", &ctx, initBody);
}
void OpDefEmitter::emit_py_getsetters() {
auto f = [&](auto&& attr) -> std::string {
return tgfmt(
"{const_cast<char*>(\"$0\"), py_get_generic($_self, $0), py_set_generic($_self, $0), const_cast<char*>(\"$0\"), NULL},",
&ctx, attr.name);
};
os << tgfmt(R"(
PyGetSetDef PyOp($_self)::py_getsetters[] = {
$0
{NULL} /* Sentinel */
};
)", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n "));
}
Initproc OpDefEmitter::emit_initproc() {
std::string initproc = formatv("_init_py_{0}", op.getCppClassName());
std::string subclass_init_call;
for (auto&& i : subclasses) {
subclass_init_call += formatv(" {0};\n", i("py_type"));
}
os << tgfmt(R"(
void $0(py::module m) {
using py_op = PyOp($_self);
auto& py_type = PyOpType($_self);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.$_self";
py_type.tp_basicsize = sizeof(PyOp($_self));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "$_self";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
$1
PyType_Modified(&py_type);
m.add_object("$_self", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace($_self::typeinfo(), &py_type).second);
}
)", &ctx, initproc, subclass_init_call);
return initproc;
}
} // namespace
bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper) {
Environment env;
using namespace std::placeholders;
std::vector<Initproc> initprocs;
foreach_operator(keeper, [&](MgbOp& op) {
initprocs.emplace_back(OpDefEmitter(op, os, env).emit());
});
os << "#define INIT_ALL_OP(m)";
for(auto&& init : initprocs) {
os << formatv(" \\\n {0};", init("m"));
}
os << "\n";
return false;
}
} // namespace mlir::tblgen
\ No newline at end of file
/**
* \file imperative/tablegen/targets/python_c_extension.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "../helper.h"
namespace mlir::tblgen {
bool gen_op_def_python_c_extension(raw_ostream &os, llvm::RecordKeeper &keeper);
} // namespace mlir::tblgen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册