diff --git a/Makefile b/Makefile index a3de1df1a820..4697746cdb27 100644 --- a/Makefile +++ b/Makefile @@ -231,7 +231,7 @@ LLVM_STATIC_LIBFILES = \ linker \ ipo \ passes \ - mcjit \ + orcjit \ $(X86_LLVM_CONFIG_LIB) \ $(ARM_LLVM_CONFIG_LIB) \ $(OPENCL_LLVM_CONFIG_LIB) \ diff --git a/dependencies/llvm/CMakeLists.txt b/dependencies/llvm/CMakeLists.txt index ef357ecd5baf..41be363b5744 100644 --- a/dependencies/llvm/CMakeLists.txt +++ b/dependencies/llvm/CMakeLists.txt @@ -53,7 +53,7 @@ endif () # Create options for including or excluding LLVM backends. ## -set(active_components mcjit bitwriter linker passes) +set(active_components orcjit bitwriter linker passes) set(known_components AArch64 AMDGPU ARM Hexagon Mips NVPTX PowerPC RISCV WebAssembly X86) foreach (comp IN LISTS known_components) diff --git a/src/JITModule.cpp b/src/JITModule.cpp index e595613ffb5e..0ebdfde3126d 100644 --- a/src/JITModule.cpp +++ b/src/JITModule.cpp @@ -127,15 +127,16 @@ class JITModuleContents { JITModuleContents() = default; ~JITModuleContents() { - if (execution_engine != nullptr) { - execution_engine->runStaticConstructorsDestructors(true); - delete execution_engine; + if (JIT != nullptr) { + auto err = dtorRunner->run(); + internal_assert(!err) << llvm::toString(std::move(err)) << "\n"; } } std::map exports; - llvm::LLVMContext context; - ExecutionEngine *execution_engine = nullptr; + std::unique_ptr context = std::make_unique(); + std::unique_ptr JIT = nullptr; + std::unique_ptr dtorRunner = nullptr; std::vector dependencies; JITModule::Symbol entrypoint; JITModule::Symbol argv_entrypoint; @@ -156,11 +157,17 @@ void destroy(const JITModuleContents *f) { namespace { // Retrieve a function pointer from an llvm module, possibly by compiling it. -JITModule::Symbol compile_and_get_function(ExecutionEngine &ee, const string &name) { +JITModule::Symbol compile_and_get_function(llvm::orc::LLJIT &JIT, const string &name) { debug(2) << "JIT Compiling " << name << "\n"; - llvm::Function *fn = ee.FindFunctionNamed(name); - internal_assert(fn->getName() == name); - void *f = (void *)ee.getFunctionAddress(name); + + auto addr = JIT.lookup(name); + internal_assert(addr) << llvm::toString(addr.takeError()) << "\n"; + +#if LLVM_VERSION >= 150 + void *f = (void *)addr->getValue(); +#else + void *f = (void *)addr->getAddress(); +#endif if (!f) { internal_error << "Compiling " << name << " returned nullptr\n"; } @@ -233,7 +240,7 @@ JITModule::JITModule() { JITModule::JITModule(const Module &m, const LoweredFunc &fn, const std::vector &dependencies) { jit_module = new JITModuleContents(); - std::unique_ptr llvm_module(compile_module_to_llvm_module(m, jit_module->context)); + std::unique_ptr llvm_module(compile_module_to_llvm_module(m, *jit_module->context)); std::vector deps_with_runtime = dependencies; std::vector shared_runtime = JITSharedRuntime::get(llvm_module.get(), m.target()); deps_with_runtime.insert(deps_with_runtime.end(), shared_runtime.begin(), shared_runtime.end()); @@ -262,43 +269,89 @@ void JITModule::compile_module(std::unique_ptr m, const string &fu DataLayout initial_module_data_layout = m->getDataLayout(); string module_name = m->getModuleIdentifier(); - llvm::EngineBuilder engine_builder((std::move(m))); - engine_builder.setTargetOptions(options); - engine_builder.setErrorStr(&error_string); - engine_builder.setEngineKind(llvm::EngineKind::JIT); - HalideJITMemoryManager *memory_manager = new HalideJITMemoryManager(dependencies); - engine_builder.setMCJITMemoryManager(std::unique_ptr(memory_manager)); + // Build TargetMachine + llvm::orc::JITTargetMachineBuilder tm_builder(llvm::Triple(m->getTargetTriple())); + tm_builder.setOptions(options); + tm_builder.setCodeGenOptLevel(CodeGenOpt::Aggressive); + if (target.arch == Target::Arch::RISCV) { + tm_builder.setCodeModel(llvm::CodeModel::Medium); + } - engine_builder.setOptLevel(CodeGenOpt::Aggressive); + auto tm = tm_builder.createTargetMachine(); + internal_assert(tm) << llvm::toString(tm.takeError()) << "\n"; - TargetMachine *tm = engine_builder.selectTarget(); - internal_assert(tm) << error_string << "\n"; - DataLayout target_data_layout(tm->createDataLayout()); + DataLayout target_data_layout(tm.get()->createDataLayout()); if (initial_module_data_layout != target_data_layout) { internal_error << "Warning: data layout mismatch between module (" << initial_module_data_layout.getStringRepresentation() << ") and what the execution engine expects (" << target_data_layout.getStringRepresentation() << ")\n"; } - ExecutionEngine *ee = engine_builder.create(tm); - if (!ee) { - std::cerr << error_string << "\n"; - } - internal_assert(ee) << "Couldn't create execution engine\n"; - - // Do any target-specific initialization - std::vector listeners; - - if (target.arch == Target::X86) { - listeners.push_back(llvm::JITEventListener::createIntelJITEventListener()); - } - // TODO: If this ever works in LLVM, this would allow profiling of JIT code with symbols with oprofile. - // listeners.push_back(llvm::createOProfileJITEventListener()); - - for (auto &listener : listeners) { - ee->RegisterJITEventListener(listener); + // Create LLJIT + const auto compilerBuilder = [&](const llvm::orc::JITTargetMachineBuilder & /*jtmb*/) + -> llvm::Expected> { + return std::make_unique(std::move(*tm)); + }; + + llvm::orc::LLJITBuilderState::ObjectLinkingLayerCreator linkerBuilder; + if ((target.arch == Target::Arch::X86 && target.bits == 32) || + (target.arch == Target::Arch::ARM && target.bits == 32)) { + // Fallback to RTDyld-based linking to workaround errors: + // i386: "JIT session error: Unsupported i386 relocation:4" (R_386_PLT32) + // ARM 32bit: Unsupported target machine architecture in ELF object shared runtime-jitted-objectbuffer + linkerBuilder = [&](llvm::orc::ExecutionSession &session, const llvm::Triple &) { + return std::make_unique(session, [&]() { + return std::make_unique(dependencies); + }); + }; + } else { + linkerBuilder = [](llvm::orc::ExecutionSession &session, const llvm::Triple &) { + return std::make_unique(session); + }; + } + + auto JIT = llvm::cantFail(llvm::orc::LLJITBuilder() + .setDataLayout(target_data_layout) + .setCompileFunctionCreator(compilerBuilder) + .setObjectLinkingLayerCreator(linkerBuilder) + .create()); + + auto ctors = llvm::orc::getConstructors(*m); + llvm::orc::CtorDtorRunner ctorRunner(JIT->getMainJITDylib()); + ctorRunner.add(ctors); + + auto dtors = llvm::orc::getDestructors(*m); + auto dtorRunner = std::make_unique(JIT->getMainJITDylib()); + dtorRunner->add(dtors); + + // Resolve system symbols (like pthread, dl and others) + auto gen = llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(target_data_layout.getGlobalPrefix()); + internal_assert(gen) << llvm::toString(gen.takeError()) << "\n"; + JIT->getMainJITDylib().addGenerator(std::move(gen.get())); + + llvm::orc::ThreadSafeModule tsm(std::move(m), std::move(jit_module->context)); + auto err = JIT->addIRModule(std::move(tsm)); + internal_assert(!err) << llvm::toString(std::move(err)) << "\n"; + + // Resolve symbol dependencies + llvm::orc::SymbolMap newSymbols; + auto symbolStringPool = JIT->getExecutionSession().getExecutorProcessControl().getSymbolStringPool(); + for (const auto &module : dependencies) { + for (auto const &iter : module.exports()) { + orc::SymbolStringPtr name = symbolStringPool->intern(iter.first); + orc::SymbolStringPtr _name = symbolStringPool->intern("_" + iter.first); + auto symbol = llvm::JITEvaluatedSymbol::fromPointer(iter.second.address); + if (!newSymbols.count(name)) { + newSymbols.insert({name, symbol}); + } + if (!newSymbols.count(_name)) { + newSymbols.insert({_name, symbol}); + } + } } + err = JIT->getMainJITDylib().define(orc::absoluteSymbols(std::move(newSymbols))); + internal_assert(!err) << llvm::toString(std::move(err)) << "\n"; // Retrieve function pointers from the compiled module (which also // triggers compilation) @@ -310,31 +363,23 @@ void JITModule::compile_module(std::unique_ptr m, const string &fu Symbol entrypoint; Symbol argv_entrypoint; if (!function_name.empty()) { - entrypoint = compile_and_get_function(*ee, function_name); + entrypoint = compile_and_get_function(*JIT, function_name); exports[function_name] = entrypoint; - argv_entrypoint = compile_and_get_function(*ee, function_name + "_argv"); + argv_entrypoint = compile_and_get_function(*JIT, function_name + "_argv"); exports[function_name + "_argv"] = argv_entrypoint; } for (const auto &requested_export : requested_exports) { - exports[requested_export] = compile_and_get_function(*ee, requested_export); - } - - debug(2) << "Finalizing object\n"; - ee->finalizeObject(); - // Do any target-specific post-compilation module meddling - for (auto &listener : listeners) { - ee->UnregisterJITEventListener(listener); - delete listener; + exports[requested_export] = compile_and_get_function(*JIT, requested_export); } - listeners.clear(); - // TODO: I don't think this is necessary, we shouldn't have any static constructors - ee->runStaticConstructorsDestructors(false); + err = ctorRunner.run(); + internal_assert(!err) << llvm::toString(std::move(err)) << "\n"; // Stash the various objects that need to stay alive behind a reference-counted pointer. jit_module->exports = exports; - jit_module->execution_engine = ee; + jit_module->JIT = std::move(JIT); + jit_module->dtorRunner = std::move(dtorRunner); jit_module->dependencies = dependencies; jit_module->entrypoint = entrypoint; jit_module->argv_entrypoint = argv_entrypoint; @@ -362,7 +407,7 @@ JITModule JITModule::make_trampolines_module(const Target &target_arg, } std::unique_ptr llvm_module = CodeGen_LLVM::compile_trampolines( - target, result.jit_module->context, suffix, extern_signatures); + target, *result.jit_module->context, suffix, extern_signatures); result.compile_module(std::move(llvm_module), /*function_name*/ "", target, deps, requested_exports); @@ -461,7 +506,7 @@ void JITModule::reuse_device_allocations(bool b) const { } bool JITModule::compiled() const { - return jit_module->execution_engine != nullptr; + return jit_module->JIT != nullptr; } namespace { @@ -760,7 +805,7 @@ JITModule &make_module(llvm::Module *for_module, Target target, // This function is protected by a mutex so this is thread safe. auto module = get_initial_module_for_target(one_gpu, - &runtime.jit_module->context, + runtime.jit_module->context.get(), true, runtime_kind != MainShared); if (for_module) { @@ -860,13 +905,21 @@ JITModule &make_module(llvm::Module *for_module, Target target, } } - uint64_t arg_addr = - runtime.jit_module->execution_engine->getGlobalValueAddress("halide_jit_module_argument"); - + uint64_t arg_addr = llvm::cantFail(runtime.jit_module->JIT->lookup("halide_jit_module_argument")) +#if LLVM_VERSION >= 150 + .getValue(); +#else + .getAddress(); +#endif internal_assert(arg_addr != 0); *((void **)arg_addr) = runtime.jit_module.get(); - uint64_t fun_addr = runtime.jit_module->execution_engine->getGlobalValueAddress("halide_jit_module_adjust_ref_count"); + uint64_t fun_addr = llvm::cantFail(runtime.jit_module->JIT->lookup("halide_jit_module_adjust_ref_count")) +#if LLVM_VERSION >= 150 + .getValue(); +#else + .getAddress(); +#endif internal_assert(fun_addr != 0); *(void (**)(void *arg, int32_t count))fun_addr = &adjust_module_ref_count; } diff --git a/src/LLVM_Headers.h b/src/LLVM_Headers.h index b3986a4d5ddb..da0f301e11fb 100644 --- a/src/LLVM_Headers.h +++ b/src/LLVM_Headers.h @@ -35,9 +35,10 @@ #include #include #include -#include #include -#include +#include +#include +#include #include #include #include diff --git a/test/correctness/c_function.cpp b/test/correctness/c_function.cpp index cadab3386f7f..846c977b6613 100644 --- a/test/correctness/c_function.cpp +++ b/test/correctness/c_function.cpp @@ -71,7 +71,7 @@ int main(int argc, char **argv) { } if (call_counter2 != 32 * 32) { - printf("C function my_func2 was called %d times instead of %d\n", call_counter, 32 * 32); + printf("C function my_func2 was called %d times instead of %d\n", call_counter2, 32 * 32); return -1; }