Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,50 @@ TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
TVM_DLL Pass SimplifyExpr();

/*!
* \brief Run any registered RelayToTIR passes registered on the functions in a module.
* \brief Run any custom passes registered under "RelayToTIR" attributes on TargetKinds.
*
* This pass looks for inline, let-bound or global functions which have a "Compiler" attribute.
* If the attribute value corresponds to a TargetKind with a "RelayToTIR" attribute, then the
* 'custom' pass bound to that attribute is run (at most once) on the IRModule as a whole.
*
* If, in addition, the \p config has a Target with a matching TargetKind, that Target is set
* as the 'current' target before the custom pass is executed. In this way it is possible
* for custom passes to pick up target options which may guide how they transform the IRModule.
* (Those targets are referred to as 'extern codegen targets' elsewhere).
*
* A typical custom pass will:
* - Find calls to "Compiler" attributes functions with matching compiler name.
* - Lower those function to TIR PrimFuncs.
* - Bind those functions into the IRModule under the the functions' "global_symbol" attribute.
* - Replace all calls to those functions with 'call_lowered' to the matching global.
* Care should be taken to handle multiple calls to the same function.
* See src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc for an example custom pass.
*
* It is also possible (despite the pass and attribute names!) for the custom pass to proceed
* directly to a runtime::Module, which can be attached to the output IRModules "external_mods"
* attribute (taking care not to clobber any existing modules). In this case the flow is as above,
* except:
* - The runtime::Module must contain a binding for each compiled function under their
* "global_symbol" (ie runtime::Module::ImplementsFunction should return true).
* - A Relay Function must be bound (or re-bound) into the result IRModule, again with the same
* "global_symbol", but with only the "Extern" attribute set to Integer(1). The function body
* should be the original function body. In this way we always have a TVM definition matching
* every global function name.
*
* There are many existing runtime::Modules, ranging from source to object to dynamic libaries to
* entirely custom implementations. Some of those may require additional compilation using
* 'export_library' on the final build artifact.
*
* The OutlineCompilerFunctionsWithExistingGlobalSymbols and MarkCompilerFunctionsAsExtern utility
* passes can be used by custom passes to take care of some of the boilerplate.
*
* TODO(mbs): Rename PreLoweringTargetHooks?
*
* \param config All available targets.
*
* \return The pass.
*/
TVM_DLL Pass RelayToTIRTargetHook();
TVM_DLL Pass RelayToTIRTargetHook(CompilationConfig config);

/*!
* \brief A pass for manifesting explicit memory allocations and rewriting
Expand Down
10 changes: 10 additions & 0 deletions include/tvm/target/target_kind.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,16 @@ namespace attr {
* See also \p Target::IsExternalCodegenFor
*/
constexpr const char* kIsExternalCodegen = "is_external_codegen";

/*!
* \brief A \p TargetKind attribute of type \p FTVMRelayToTIR. If set, then the target kind name
* also corresponds to an external codegen 'compiler' name, and the bound value is a \p Pass
* to apply before the TVM lowering.
*
* See also \p Target::IsExternalCodegenFor
*/
constexpr const char* kRelayToTIR = "RelayToTIR";

} // namespace attr

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment);
},
config_->host_virtual_device)(mod);
config_)(mod);

auto lowered_main = lowered_mod->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/cmsisnn/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ tvm::transform::Pass RelayToTIR();
runtime::Module TIRToRuntime(IRModule mod, Target target);

TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
.set_attr<FTVMRelayToTIR>("RelayToTIR", RelayToTIR())
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);

} // namespace cmsisnn
Expand Down
12 changes: 12 additions & 0 deletions src/relay/backend/contrib/codegen_c/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,14 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
Array<String> variables = std::get<0>(res);
String func_name = std::get<1>(res);

Optional<Target> opt_target = Target::Current();
if (opt_target.defined() && opt_target.value()->kind->name == "ccompiler") {
Optional<String> header = opt_target.value()->GetAttr<String>("header");
if (header.defined() && !header.value().empty()) {
code_stream_ << header.value().c_str() << "\n";
}
}

// Create headers
code_stream_ << "#include <stdio.h>\n";
code_stream_ << "#include <stdlib.h>\n";
Expand Down Expand Up @@ -293,6 +301,10 @@ runtime::Module CCompiler(const ObjectRef& ref) {

TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler);

TVM_REGISTER_TARGET_KIND("ccompiler", kDLCPU)
.set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
.add_attr_option<String>("header", String("")); // value is prepended to every output CModule

} // namespace contrib
} // namespace relay
} // namespace tvm
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/ethosu/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {

TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU)
.set_attr<Bool>("use_device_api", Bool(true))
.set_attr<FTVMRelayToTIR>("RelayToTIR", RelayToTIR())
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);

} // namespace ethosu
Expand Down
200 changes: 144 additions & 56 deletions src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,37 @@
#include <tvm/tir/op.h>

#include "../../../op/call/call.h"
#include "tvm/tir/function.h"

namespace tvm {
namespace relay {
namespace contrib {
namespace example_target_hooks {

namespace {

/*!
* \brief An example mutator for a "RelayToTIR" custom pass. Replaces every call to a Relay
* Function with "external_symbol" attribute of "replace_add_with_subtract" with a call to a
* TIR PrimFunc implementing subtraction.
*
* Illustrates six aspects a custom 'lowering' style pass may need to account for:
* - Lowerable functions can appear inline as call ops, bound to let-bound variables, or as
* global functions.
* - Let-bound lowerable functions should be inlined on-the-fly since after processing the
* let-binding is no longer required.
* - There may be multiple calls to the same lowerable function. All calls need to be
* rewritten, even though the function itself need be rewritten only once.
* - GlobalVars must be shared between all calls and the new definition itself.
* - Calls to lowered functions must use the "call_lowered" calling convention.
* - The Target::Current() may hold an instance of the TargetKind from which the custom Pass
* was extracted.
*
* Though not illustrated here, it is also valid for a "RelayToTIR" custom pass to add
* runtime::Modules to the output IRModule's "external_mods" attribute. In this case the
* IRModule must be left with an 'extern' Function definition with the matching "external_symbol"
* name.
*/
class ConvertAddToSubtract : public MixedModeMutator {
public:
explicit ConvertAddToSubtract(IRModule ir_module, Target host_target)
Expand All @@ -56,51 +81,102 @@ class ConvertAddToSubtract : public MixedModeMutator {
return tir::BufferLoad(buffer, {index});
}

void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) {
tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x");
tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y");
tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32));
GlobalVar ReplaceAddWithSubtractPrimFunc(const Function& func) {
auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);
ICHECK(func_name.defined());

tir::Var x_var("x", DataType::Handle());
tir::Var y_var("y", DataType::Handle());
tir::Var out_var("out", DataType::Handle());
// --------------------------------------------------------------------------------------------
// Cases:
// - Inline function:
// - First encounter: create global var, rewrite to PrimFunc, add binding, replace call.
// - Thereafter (via object sharing): discover global var already in module, replace call
// - Global function:
// - Assume func_name == global_var->name_hint
// - First encounter: create global var, rewrite to PrimFunc, update binding, replace call
// - Thereafter (via global var): discover global var already in module, replace call
// --------------------------------------------------------------------------------------------

Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", new_global_var->name_hint);
dict_attrs.Set("tir.noalias", Bool(true));
// If necessary, introduce a new global var to map the function to and copy the source type
// over for InferType.
GlobalVar global_var;
bool need_rewriting;
if (ir_module_->ContainGlobalVar(func_name.value())) {
global_var = ir_module_->GetGlobalVar(func_name.value());
// Only rewrite to a PrimFunc if the global definition is still a Relay function.
need_rewriting = ir_module_->Lookup(global_var)->IsInstance<FunctionNode>();
} else {
global_var = GlobalVar(func_name.value());
global_var->checked_type_ = func->checked_type();
need_rewriting = true;
}

te::Var index("index", DataType::Int(32));
tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index));
tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index});
tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body);
// For illustration only, check if the current target matches the example_target_hook kind,
// and if so extract the example attribute value.
int64_t example_attribute_value = 0;
Optional<Target> opt_current_target = Target::Current();
if (opt_current_target.defined() &&
opt_current_target.value()->kind->name == "example_target_hook") {
example_attribute_value =
opt_current_target.value()->GetAttr<Integer>("example_attribute").value()->value;
}

Map<tir::Var, tir::Buffer> buffer_map = {
{x_var, x_buffer},
{y_var, y_buffer},
{out_var, out_buffer},
};
if (need_rewriting) {
// The called function is still in Relay form. Convert to TIR.
tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x");
tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y");
tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32));

tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(),
buffer_map, {}, DictAttrs(dict_attrs));
tir::Var x_var("x", DataType::Handle());
tir::Var y_var("y", DataType::Handle());
tir::Var out_var("out", DataType::Handle());

// Switch to TIRToRuntime hook for testing
Bool tir_to_runtime = func->GetAttr<Bool>("tir_to_runtime").value_or(Bool(false));
if (tir_to_runtime) {
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_);
} else {
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_);
Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", global_var->name_hint);
dict_attrs.Set("tir.noalias", Bool(true));

te::Var index("index", DataType::Int(32));
tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index));
if (example_attribute_value > 0) {
// For illustration only, fold the example attribute into the result.
indexed_sub = tir::Sub(indexed_sub, FloatImm(DataType::Float(32),
static_cast<double>(example_attribute_value)));
}

tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index});
tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body);

Map<tir::Var, tir::Buffer> buffer_map = {
{x_var, x_buffer},
{y_var, y_buffer},
{out_var, out_buffer},
};

tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(),
buffer_map, {}, DictAttrs(dict_attrs));

// Switch to TIRToRuntime hook for testing
Bool tir_to_runtime = func->GetAttr<Bool>("tir_to_runtime").value_or(Bool(false));
if (tir_to_runtime) {
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_);
} else {
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_);
}

ir_module_->Update(global_var, replacement_func); // Will Add if global_var is new.
}

ir_module_->Add(new_global_var, replacement_func);
return global_var;
}

using MixedModeMutator::VisitExpr_;

Expr VisitExpr_(const LetNode* op) final {
auto pre_visit = [this](const LetNode* op) {
Expr var = this->VisitExpr(op->var);
Expr value = this->VisitExpr(op->value);

// Outlineable function no longer needs let binding
if (this->CanLowerExpr(value)) {
if (AsLowerableFunction(value)) {
// Inline on-the-fly if the let-bound value is lowerable.
this->memo_[var] = value;
}
};
Expand All @@ -110,8 +186,8 @@ class ConvertAddToSubtract : public MixedModeMutator {
Expr body = this->VisitExpr(op->body);
auto expr = GetRef<Expr>(op);

// Drop the let binding
if (this->CanLowerExpr(value)) {
if (AsLowerableFunction(value)) {
// The let binding is no longer needed since inlined on-the-fly above.
this->memo_[expr] = this->VisitExpr(op->body);
} else {
Var var = Downcast<Var>(this->VisitExpr(op->var));
Expand All @@ -126,39 +202,49 @@ class ConvertAddToSubtract : public MixedModeMutator {
return memo_[GetRef<Expr>(op)];
}

bool CanLowerExpr(const Expr& expr) {
const auto* func = expr.as<FunctionNode>();
if (func == nullptr) {
return false;
}
auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);
if (!func_name.defined()) {
return false;
const FunctionNode* AsLowerableFunction(const Expr& expr) {
if (const auto* function_node = expr.as<FunctionNode>()) {
auto func_name = function_node->GetAttr<String>(::tvm::attr::kGlobalSymbol);
if (!func_name.defined()) {
return nullptr;
}
if (func_name != "replace_add_with_subtract") {
return nullptr;
}
return function_node;
} else if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
return AsLowerableFunction(ir_module_->Lookup(GetRef<GlobalVar>(global_var_node)));
} else {
return nullptr;
}
if (func_name != "replace_add_with_subtract") {
return false;
}

const GlobalVarNode* AsAlreadyLoweredFunction(const Expr& expr) {
if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
if (ir_module_->Lookup(GetRef<GlobalVar>(global_var_node)).as<tir::PrimFuncNode>()) {
return global_var_node;
}
}
return true;
return nullptr;
}

Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (const CallNode* call = post.as<CallNode>()) {
if (CanLowerExpr(call->op)) {
auto* func = call->op.as<FunctionNode>();
auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);

// Introduce a new global var to map the function to and copy the source type
// over for InferType
GlobalVar new_global_var(func_name.value());
new_global_var->checked_type_ = func->checked_type();
ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func));

if (const auto* call = post.as<CallNode>()) {
GlobalVar new_op;
if (const auto* function_node = AsLowerableFunction(call->op)) {
// Add or replace the function with a PrimFunc.
new_op = ReplaceAddWithSubtractPrimFunc(GetRef<Function>(function_node));
} else if (const auto* global_var_node = AsAlreadyLoweredFunction(call->op)) {
// The function has already been rewritten, so we just need to update the call.
new_op = GetRef<GlobalVar>(global_var_node);
}
if (new_op.defined()) {
// Since we are replacing the Relay function with a call to a TIR function, we must use
// the call_lowered op.
CallLoweredAttrs attrs;
attrs.metadata.Set("relay_attrs", call->attrs);
ICHECK(call->type_args.empty()) << "lowered functions cannot be polymorphic";
return CallLowered(std::move(new_global_var), call->args, std::move(attrs), call->span);
return CallLowered(std::move(new_op), call->args, std::move(attrs), call->span);
}
}

Expand All @@ -171,10 +257,12 @@ class ConvertAddToSubtract : public MixedModeMutator {
Target custom_target_;
};

} // namespace

transform::Pass RelayToTIR() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule ir_module, transform::PassContext pass_context) {
auto relay_to_tir = ConvertAddToSubtract(ir_module, Target("c"));
ConvertAddToSubtract relay_to_tir(std::move(ir_module), Target("c"));
return relay_to_tir.Mutate();
};
return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {});
Expand Down
Loading