提交 d47cf332 编写于 作者: M Megvii Engine Team

build(third_party): update llvm-project and adapt to mlir interface changes

GitOrigin-RevId: dd45984ccaeef39094f9a209abe490d10e54a77e
上级 95f6b531
......@@ -49,6 +49,14 @@ function(external_tablegen_library)
install(TARGETS ${_NAME} EXPORT ${MGE_EXPORT_TARGETS})
endfunction()
set(LLVM_LIBS LLVMCore LLVMSupport LLVMX86CodeGen LLVMOrcJIT LLVMNVPTXCodeGen LLVMNVPTXDesc LLVMNVPTXInfo)
set(MLIR_CORE_LIBS MLIRAnalysis MLIRExecutionEngine MLIRIR MLIRParser MLIRPass MLIRSideEffectInterfaces MLIRTransforms)
set(MLIR_DIALECT_LIBS MLIRAsync MLIRAVX512 MLIRGPU MLIRLLVMAVX512 MLIRNVVMIR MLIROpenACC MLIRPDL MLIRPDLInterp MLIRQuant MLIRROCDLIR MLIRSDBM MLIRShape MLIRSPIRV MLIRStandardOpsTransforms)
set(MLIR_CONVERSION_LIBS MLIRAffineToStandard MLIRAVX512ToLLVM MLIRGPUToGPURuntimeTransforms MLIRGPUToNVVMTransforms MLIRSCFToStandard)
set(MLIR_TRANSLATION_LIBS MLIRTargetLLVMIR MLIRTargetNVVMIR)
set(MLIR_LIBS ${MLIR_CORE_LIBS} ${MLIR_DIALECT_LIBS} ${MLIR_CONVERSION_LIBS} ${MLIR_TRANSLATION_LIBS})
set(MLIR_LLVM_LIBS ${LLVM_LIBS} ${MLIR_LIBS})
if (MGE_USE_SYSTEM_LIB)
find_package(ZLIB)
find_package(MLIR REQUIRED CONFIG)
......@@ -77,9 +85,7 @@ if (MGE_USE_SYSTEM_LIB)
endif()
endfunction(find_mlir_llvm_lib)
set(MLIR_COMPONENTS MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRShape;MLIRGPUToNVVMTransforms;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms;MLIRStandardOpsTransforms)
foreach(c ${MLIR_COMPONENTS})
foreach(c ${MLIR_LIBS})
find_mlir_llvm_lib(${c})
endforeach()
return()
......@@ -119,5 +125,3 @@ set(MLIR_LLVM_INCLUDE_DIR
${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/tools/mlir/include
)
set(MLIR_TABLEGEN_EXE mlir-tblgen)
set(MLIR_LLVM_LIBS LLVMCore;LLVMSupport;LLVMX86CodeGen;LLVMOrcJIT;LLVMNVPTXCodeGen;LLVMNVPTXDesc;LLVMNVPTXInfo;MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRGPUToNVVMTransforms;MLIRShape;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms;MLIRStandardOpsTransforms)
......@@ -67,8 +67,8 @@ mlir::OwnedBlob compile_ptx_to_cubin(const std::string ptx, mlir::Location,
}
std::unique_ptr<llvm::Module> translate_module_to_nvvm_ir_and_link_device(
Operation* m) {
std::unique_ptr<llvm::Module> module = mlir::translateModuleToNVVMIR(m);
Operation* m, llvm::LLVMContext& llvmContext, llvm::StringRef name) {
std::unique_ptr<llvm::Module> module = mlir::translateModuleToNVVMIR(m, llvmContext);
auto get_device_path = []() -> std::string {
auto cuda_path = getenv("CUDA_BIN_PATH");
std::string device_dir;
......@@ -223,6 +223,7 @@ void MLIRCompiler::run_lowering_pass(mlir::OwningModuleRef& module,
std::unique_ptr<Executable> MLIRCompiler::do_compile(
const InternalGraph& graph, const JITExecutor::Args& args) {
mlir::MLIRContext ctx;
ctx.getOrLoadDialect<MgbDialect>();
ctx.printStackTraceOnDiagnostic(true);
ctx.printOpOnDiagnostic(true);
......
......@@ -24,7 +24,8 @@
using namespace mgb;
using namespace jit;
MgbDialect::MgbDialect(mlir::MLIRContext* ctx) : mlir::Dialect("mgb", ctx) {
MgbDialect::MgbDialect(mlir::MLIRContext* ctx)
: mlir::Dialect("mgb", ctx, mlir::TypeID::get<MgbDialect>()) {
addOperations<
#define GET_OP_LIST
#include "megbrain/jit/mlir/ir/ops.cpp.inc"
......
......@@ -209,6 +209,11 @@ struct ConstantScalarOpLowering
class MgbToAffineLoweringPass
: public PassWrapper<MgbToAffineLoweringPass, FunctionPass> {
public:
void getDependentDialects(mlir::DialectRegistry& registry) const override {
registry.insert<mlir::AffineDialect>();
registry.insert<mlir::StandardOpsDialect>();
}
void runOnFunction() override final {
ConversionTarget target(getContext());
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
......
......@@ -259,6 +259,11 @@ private:
class MgbToGpuLoweringPass
: public PassWrapper<MgbToGpuLoweringPass, FunctionPass> {
public:
void getDependentDialects(mlir::DialectRegistry& registry) const override {
registry.insert<mlir::gpu::GPUDialect>();
registry.insert<mlir::StandardOpsDialect>();
}
void runOnFunction() override final {
auto func_op = getFunction();
Location loc = func_op.getLoc();
......
......@@ -21,6 +21,8 @@
#include <mlir/Conversion/SCFToStandard/SCFToStandard.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/SCF/SCF.h>
#include <mlir/Dialect/StandardOps/Transforms/Passes.h>
using namespace mgb;
......@@ -30,6 +32,12 @@ namespace {
class AffineToLLVMLoweringPass : public PassWrapper<AffineToLLVMLoweringPass,
OperationPass<ModuleOp>> {
public:
void getDependentDialects(mlir::DialectRegistry& registry) const override {
registry.insert<mlir::LLVM::LLVMDialect>();
registry.insert<mlir::scf::SCFDialect>();
}
void runOnOperation() final {
LLVMConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
......
......@@ -21,7 +21,7 @@ namespace jit {
inline bool is_elemwise_float(const mlir::Type& dt) {
if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) {
if (cast.getElementType().getKind() == mlir::StandardTypes::F32) {
if (cast.getElementType().isF32()) {
return true;
}
}
......
......@@ -82,13 +82,12 @@ megdnn::DType jit::mlir_type_to_dtype(mlir::Type type) {
if (auto cast = type.dyn_cast_or_null<mlir::MemRefType>()) {
element_type = cast.getElementType();
}
switch (element_type.getKind()) {
case mlir::StandardTypes::F32:
return megdnn::dtype::Float32{};
default:
mgb_throw(InternalError,
"Unsupport mlir type for MemRefType, got: %s\n",
mlir_type_to_string(type).c_str());
if (element_type.isF32()) {
return megdnn::dtype::Float32{};
} else {
mgb_throw(InternalError,
"Unsupport mlir type for MemRefType, got: %s\n",
mlir_type_to_string(type).c_str());
}
return {};
}
......
......@@ -34,13 +34,13 @@ public:
static llvm::StringRef getDialectNamespace() { return "mgb::jit"; }
};
} // namespace jit
} // namespace mgb
#define GET_OP_CLASSES
using namespace mlir;
#include "megbrain/jit/mlir/ir/ops.h.inc"
} // namespace jit
} // namespace mgb
#endif // MGB_JIT && MGB_JIT_MLIR
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册