Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ LLVM_STATIC_LIBFILES = \
linker \
ipo \
passes \
mcjit \
orcjit \
$(X86_LLVM_CONFIG_LIB) \
$(ARM_LLVM_CONFIG_LIB) \
$(OPENCL_LLVM_CONFIG_LIB) \
Expand Down
2 changes: 1 addition & 1 deletion dependencies/llvm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ endif ()
# Create options for including or excluding LLVM backends.
##

set(active_components mcjit bitwriter linker passes)
set(active_components orcjit bitwriter linker passes)
set(known_components AArch64 AMDGPU ARM Hexagon Mips NVPTX PowerPC RISCV WebAssembly X86)

foreach (comp IN LISTS known_components)
Expand Down
171 changes: 112 additions & 59 deletions src/JITModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,16 @@ class JITModuleContents {
JITModuleContents() = default;

~JITModuleContents() {
if (execution_engine != nullptr) {
execution_engine->runStaticConstructorsDestructors(true);
delete execution_engine;
if (JIT != nullptr) {
auto err = dtorRunner->run();
internal_assert(!err) << llvm::toString(std::move(err)) << "\n";
}
}

std::map<std::string, JITModule::Symbol> exports;
llvm::LLVMContext context;
ExecutionEngine *execution_engine = nullptr;
std::unique_ptr<llvm::LLVMContext> context = std::make_unique<llvm::LLVMContext>();
std::unique_ptr<llvm::orc::LLJIT> JIT = nullptr;
std::unique_ptr<llvm::orc::CtorDtorRunner> dtorRunner = nullptr;
std::vector<JITModule> dependencies;
JITModule::Symbol entrypoint;
JITModule::Symbol argv_entrypoint;
Expand All @@ -156,11 +157,17 @@ void destroy<JITModuleContents>(const JITModuleContents *f) {
namespace {

// Retrieve a function pointer from an llvm module, possibly by compiling it.
JITModule::Symbol compile_and_get_function(ExecutionEngine &ee, const string &name) {
JITModule::Symbol compile_and_get_function(llvm::orc::LLJIT &JIT, const string &name) {
debug(2) << "JIT Compiling " << name << "\n";
llvm::Function *fn = ee.FindFunctionNamed(name);
internal_assert(fn->getName() == name);
void *f = (void *)ee.getFunctionAddress(name);

auto addr = JIT.lookup(name);
internal_assert(addr) << llvm::toString(addr.takeError()) << "\n";

#if LLVM_VERSION >= 150
void *f = (void *)addr->getValue();
#else
void *f = (void *)addr->getAddress();
#endif
if (!f) {
internal_error << "Compiling " << name << " returned nullptr\n";
}
Expand Down Expand Up @@ -233,7 +240,7 @@ JITModule::JITModule() {
JITModule::JITModule(const Module &m, const LoweredFunc &fn,
const std::vector<JITModule> &dependencies) {
jit_module = new JITModuleContents();
std::unique_ptr<llvm::Module> llvm_module(compile_module_to_llvm_module(m, jit_module->context));
std::unique_ptr<llvm::Module> llvm_module(compile_module_to_llvm_module(m, *jit_module->context));
std::vector<JITModule> deps_with_runtime = dependencies;
std::vector<JITModule> shared_runtime = JITSharedRuntime::get(llvm_module.get(), m.target());
deps_with_runtime.insert(deps_with_runtime.end(), shared_runtime.begin(), shared_runtime.end());
Expand Down Expand Up @@ -262,43 +269,89 @@ void JITModule::compile_module(std::unique_ptr<llvm::Module> m, const string &fu
DataLayout initial_module_data_layout = m->getDataLayout();
string module_name = m->getModuleIdentifier();

llvm::EngineBuilder engine_builder((std::move(m)));
engine_builder.setTargetOptions(options);
engine_builder.setErrorStr(&error_string);
engine_builder.setEngineKind(llvm::EngineKind::JIT);
HalideJITMemoryManager *memory_manager = new HalideJITMemoryManager(dependencies);
engine_builder.setMCJITMemoryManager(std::unique_ptr<RTDyldMemoryManager>(memory_manager));
// Build TargetMachine
llvm::orc::JITTargetMachineBuilder tm_builder(llvm::Triple(m->getTargetTriple()));
tm_builder.setOptions(options);
tm_builder.setCodeGenOptLevel(CodeGenOpt::Aggressive);
if (target.arch == Target::Arch::RISCV) {
tm_builder.setCodeModel(llvm::CodeModel::Medium);
}

engine_builder.setOptLevel(CodeGenOpt::Aggressive);
auto tm = tm_builder.createTargetMachine();
internal_assert(tm) << llvm::toString(tm.takeError()) << "\n";

TargetMachine *tm = engine_builder.selectTarget();
internal_assert(tm) << error_string << "\n";
DataLayout target_data_layout(tm->createDataLayout());
DataLayout target_data_layout(tm.get()->createDataLayout());
if (initial_module_data_layout != target_data_layout) {
internal_error << "Warning: data layout mismatch between module ("
<< initial_module_data_layout.getStringRepresentation()
<< ") and what the execution engine expects ("
<< target_data_layout.getStringRepresentation() << ")\n";
}
ExecutionEngine *ee = engine_builder.create(tm);

if (!ee) {
std::cerr << error_string << "\n";
}
internal_assert(ee) << "Couldn't create execution engine\n";

// Do any target-specific initialization
std::vector<llvm::JITEventListener *> listeners;

if (target.arch == Target::X86) {
listeners.push_back(llvm::JITEventListener::createIntelJITEventListener());
}
// TODO: If this ever works in LLVM, this would allow profiling of JIT code with symbols with oprofile.
// listeners.push_back(llvm::createOProfileJITEventListener());

for (auto &listener : listeners) {
ee->RegisterJITEventListener(listener);
Comment thread
dkurt marked this conversation as resolved.
// Create LLJIT
const auto compilerBuilder = [&](const llvm::orc::JITTargetMachineBuilder & /*jtmb*/)
-> llvm::Expected<std::unique_ptr<llvm::orc::IRCompileLayer::IRCompiler>> {
return std::make_unique<llvm::orc::TMOwningSimpleCompiler>(std::move(*tm));
};

llvm::orc::LLJITBuilderState::ObjectLinkingLayerCreator linkerBuilder;
if ((target.arch == Target::Arch::X86 && target.bits == 32) ||
(target.arch == Target::Arch::ARM && target.bits == 32)) {
// Fallback to RTDyld-based linking to workaround errors:
// i386: "JIT session error: Unsupported i386 relocation:4" (R_386_PLT32)
// ARM 32bit: Unsupported target machine architecture in ELF object shared runtime-jitted-objectbuffer
linkerBuilder = [&](llvm::orc::ExecutionSession &session, const llvm::Triple &) {
return std::make_unique<llvm::orc::RTDyldObjectLinkingLayer>(session, [&]() {
return std::make_unique<HalideJITMemoryManager>(dependencies);
});
};
} else {
linkerBuilder = [](llvm::orc::ExecutionSession &session, const llvm::Triple &) {
return std::make_unique<llvm::orc::ObjectLinkingLayer>(session);
};
}

auto JIT = llvm::cantFail(llvm::orc::LLJITBuilder()
.setDataLayout(target_data_layout)
.setCompileFunctionCreator(compilerBuilder)
.setObjectLinkingLayerCreator(linkerBuilder)
.create());

auto ctors = llvm::orc::getConstructors(*m);
llvm::orc::CtorDtorRunner ctorRunner(JIT->getMainJITDylib());
ctorRunner.add(ctors);

auto dtors = llvm::orc::getDestructors(*m);
auto dtorRunner = std::make_unique<llvm::orc::CtorDtorRunner>(JIT->getMainJITDylib());
dtorRunner->add(dtors);

// Resolve system symbols (like pthread, dl and others)
auto gen = llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(target_data_layout.getGlobalPrefix());
internal_assert(gen) << llvm::toString(gen.takeError()) << "\n";
JIT->getMainJITDylib().addGenerator(std::move(gen.get()));

llvm::orc::ThreadSafeModule tsm(std::move(m), std::move(jit_module->context));
auto err = JIT->addIRModule(std::move(tsm));
internal_assert(!err) << llvm::toString(std::move(err)) << "\n";

// Resolve symbol dependencies
llvm::orc::SymbolMap newSymbols;
auto symbolStringPool = JIT->getExecutionSession().getExecutorProcessControl().getSymbolStringPool();
for (const auto &module : dependencies) {
for (auto const &iter : module.exports()) {
orc::SymbolStringPtr name = symbolStringPool->intern(iter.first);
orc::SymbolStringPtr _name = symbolStringPool->intern("_" + iter.first);
auto symbol = llvm::JITEvaluatedSymbol::fromPointer(iter.second.address);
if (!newSymbols.count(name)) {
newSymbols.insert({name, symbol});
}
if (!newSymbols.count(_name)) {
newSymbols.insert({_name, symbol});
}
}
}
err = JIT->getMainJITDylib().define(orc::absoluteSymbols(std::move(newSymbols)));
internal_assert(!err) << llvm::toString(std::move(err)) << "\n";

// Retrieve function pointers from the compiled module (which also
// triggers compilation)
Expand All @@ -310,31 +363,23 @@ void JITModule::compile_module(std::unique_ptr<llvm::Module> m, const string &fu
Symbol entrypoint;
Symbol argv_entrypoint;
if (!function_name.empty()) {
entrypoint = compile_and_get_function(*ee, function_name);
entrypoint = compile_and_get_function(*JIT, function_name);
exports[function_name] = entrypoint;
argv_entrypoint = compile_and_get_function(*ee, function_name + "_argv");
argv_entrypoint = compile_and_get_function(*JIT, function_name + "_argv");
exports[function_name + "_argv"] = argv_entrypoint;
}

for (const auto &requested_export : requested_exports) {
exports[requested_export] = compile_and_get_function(*ee, requested_export);
}

debug(2) << "Finalizing object\n";
ee->finalizeObject();
// Do any target-specific post-compilation module meddling
for (auto &listener : listeners) {
ee->UnregisterJITEventListener(listener);
delete listener;
exports[requested_export] = compile_and_get_function(*JIT, requested_export);
}
listeners.clear();

// TODO: I don't think this is necessary, we shouldn't have any static constructors
ee->runStaticConstructorsDestructors(false);
err = ctorRunner.run();
internal_assert(!err) << llvm::toString(std::move(err)) << "\n";

// Stash the various objects that need to stay alive behind a reference-counted pointer.
jit_module->exports = exports;
jit_module->execution_engine = ee;
jit_module->JIT = std::move(JIT);
jit_module->dtorRunner = std::move(dtorRunner);
jit_module->dependencies = dependencies;
jit_module->entrypoint = entrypoint;
jit_module->argv_entrypoint = argv_entrypoint;
Expand Down Expand Up @@ -362,7 +407,7 @@ JITModule JITModule::make_trampolines_module(const Target &target_arg,
}

std::unique_ptr<llvm::Module> llvm_module = CodeGen_LLVM::compile_trampolines(
target, result.jit_module->context, suffix, extern_signatures);
target, *result.jit_module->context, suffix, extern_signatures);

result.compile_module(std::move(llvm_module), /*function_name*/ "", target, deps, requested_exports);

Expand Down Expand Up @@ -461,7 +506,7 @@ void JITModule::reuse_device_allocations(bool b) const {
}

bool JITModule::compiled() const {
return jit_module->execution_engine != nullptr;
return jit_module->JIT != nullptr;
}

namespace {
Expand Down Expand Up @@ -760,7 +805,7 @@ JITModule &make_module(llvm::Module *for_module, Target target,
// This function is protected by a mutex so this is thread safe.
auto module =
get_initial_module_for_target(one_gpu,
&runtime.jit_module->context,
runtime.jit_module->context.get(),
true,
runtime_kind != MainShared);
if (for_module) {
Expand Down Expand Up @@ -860,13 +905,21 @@ JITModule &make_module(llvm::Module *for_module, Target target,
}
}

uint64_t arg_addr =
runtime.jit_module->execution_engine->getGlobalValueAddress("halide_jit_module_argument");

uint64_t arg_addr = llvm::cantFail(runtime.jit_module->JIT->lookup("halide_jit_module_argument"))
#if LLVM_VERSION >= 150
.getValue();
#else
.getAddress();
#endif
internal_assert(arg_addr != 0);
*((void **)arg_addr) = runtime.jit_module.get();

uint64_t fun_addr = runtime.jit_module->execution_engine->getGlobalValueAddress("halide_jit_module_adjust_ref_count");
uint64_t fun_addr = llvm::cantFail(runtime.jit_module->JIT->lookup("halide_jit_module_adjust_ref_count"))
#if LLVM_VERSION >= 150
.getValue();
#else
.getAddress();
#endif
internal_assert(fun_addr != 0);
*(void (**)(void *arg, int32_t count))fun_addr = &adjust_module_ref_count;
}
Expand Down
5 changes: 3 additions & 2 deletions src/LLVM_Headers.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@
#include <llvm/Analysis/TargetTransformInfo.h>
#include <llvm/Bitcode/BitcodeReader.h>
#include <llvm/Bitcode/BitcodeWriter.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/JITEventListener.h>
#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include <llvm/IR/Constant.h>
#include <llvm/IR/Constants.h>
Expand Down
2 changes: 1 addition & 1 deletion test/correctness/c_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ int main(int argc, char **argv) {
}

if (call_counter2 != 32 * 32) {
printf("C function my_func2 was called %d times instead of %d\n", call_counter, 32 * 32);
printf("C function my_func2 was called %d times instead of %d\n", call_counter2, 32 * 32);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eek, nice catch :-)

return -1;
}

Expand Down