Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,12 @@ TVM_DLL Pass FoldConstant(bool fold_qnn = false);
/*!
* \brief Split function with huge number of arguments to smaller pieces.
*
* \param max_function_args Maximum number of function arguments. If it equals 0 then SplitArgs
* shouldn't split the function.
*
* \return The pass.
*/
TVM_DLL Pass SplitArgs(int max_function_args);
TVM_DLL Pass SplitArgs(uint64_t max_function_args);

/*!
* \brief Fuse operations into expr into separate functions.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
}

/*!
* \brief Calcluate the output shape of strided_slice, the entry point for Relay type relation
* \brief Calculate the output shape of strided_slice, the entry point for Relay type relation
*
* \param ishape The input tensor shape
* \param begin The indices to begin with in the slicing
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from . import _make
from .dyn import _make as _dyn_make
from ..expr import Tuple, Expr, Constant
from ..expr import Tuple, Expr, Constant, Call
from . import op as reg


Expand Down Expand Up @@ -1141,12 +1141,15 @@ def concatenate(data, axis):
result: relay.Expr
The concatenated tensor.
"""
data = list(data)
if not isinstance(data, Call):
data = list(data)
if not data:
raise ValueError("relay.concatenate requires data to be non-empty.")
if not isinstance(data, Call):
data = Tuple(data)
if not isinstance(axis, int):
raise ValueError("For now, we only support integer axis")
return _make.concatenate(Tuple(data), axis)
return _make.concatenate(data, axis)


def einsum(data, equation):
Expand Down
9 changes: 8 additions & 1 deletion python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,10 +1376,17 @@ def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1):
def SplitArgs(max_function_args):
"""Split function with huge number of arguments to smaller pieces.

Parameters
----------
max_function_args: int
Maximum number of function arguments. If it equals 0 then SplitArgs
shouldn't split the function.


Returns
-------
ret : tvm.transform.Pass
The registered pass for constant folding.
The registered pass.
"""
return _ffi_api.SplitArgs(max_function_args)

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def max_shared_memory_per_block(self):

@property
def max_function_args(self):
return int(self.attrs.get("max_function_args", -1))
return int(self.attrs.get("max_function_args", 0))

@property
def vtcm_capacity(self):
Expand Down
136 changes: 136 additions & 0 deletions src/relay/analysis/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ void GraphPartitioner::MergeFromTo(Group* child, Group* parent) {
if (child == parent) return;
// update the number of nodes of the parent group
parent->num_nodes += child->num_nodes;
parent->args_num += child->args_num;
child->parent = parent;
// update anchor ref and pattern
if (child->anchor_ref != nullptr) {
Expand All @@ -180,6 +181,10 @@ void GraphPartitioner::MergeFromTo(Group* child, Group* parent) {

void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
Group* target) {
if (postpone_node_ != nullptr) {
postponed_fusing_map_.insert({postpone_node_, src});
return;
}
if (src == sink) return;
if (visited_.count(src)) return;
visited_.insert(src);
Expand Down Expand Up @@ -220,7 +225,113 @@ size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node*
return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
}

size_t GraphPartitioner::CountAdditionalArgs_(const TensorTypeNode* ttype, bool with_strides) {
size_t any_dims = 0;
for (const auto& dim : ttype->shape) {
if (dim.as<AnyNode>()) {
any_dims++;
}
}
if (with_strides && any_dims > 0) any_dims += ttype->shape.size();
return any_dims;
}

size_t GraphPartitioner::CountArgs_(IndexedForwardGraph::Node* src,
const IndexedForwardGraph& graph, bool update_postpone) {
std::unordered_set<Group*> visited_groups;
Group* gnode = groups_[src->index];
ICHECK(gnode != nullptr);
auto sum = gnode->args_num;
visited_groups.insert(gnode->FindRoot());
auto calc_args_number = [this, src, &graph, &visited_groups,
update_postpone](const relay::Expr& arg) -> size_t {
if (arg.as<VarNode>()) return 0;
auto* node = graph.node_map.at(arg.get());
Group* prev_group = groups_[node->index]->FindRoot();
if (visited_groups.count(prev_group) == 0) {
visited_groups.insert(prev_group);
if (prev_group->args_num > 0) {
// Get the number of arguments from the group
return prev_group->args_num;
} else if (update_postpone) {
// Update pointer to the node which should be postponed for deferred fusing
postpone_node_ = src;
} else {
// Calculate the number of arguments for the node which wasn't processed before
return CountArgs_(node, graph, update_postpone);
}
}
return 0;
};
if (auto call_node = GetRef<ObjectRef>(src->ref).as<CallNode>()) {
for (auto& it : call_node->args) {
sum += calc_args_number(it);
}
} else if (auto tuple_node = GetRef<ObjectRef>(src->ref).as<TupleNode>()) {
for (auto& it : tuple_node->fields) {
sum += calc_args_number(it);
}
}
return sum;
}

size_t GraphPartitioner::CountArgsLimit_(const IndexedForwardGraph::Node* child) {
auto* outputs_list = child->outputs.head;
size_t output_args = 0;
while (outputs_list != nullptr) {
output_args++;
if (auto call_node = GetRef<ObjectRef>(outputs_list->value.node->ref).as<CallNode>()) {
if (const auto* ttype = call_node->checked_type().as<TensorTypeNode>()) {
output_args += CountAdditionalArgs_(ttype, false);
}
}
outputs_list = outputs_list->next;
}
return (max_function_args_ > output_args) ? max_function_args_ - output_args : 0;
}

size_t GraphPartitioner::CountFusedArgs(const IndexedForwardGraph& graph,
IndexedForwardGraph::Node* child) {
size_t args_num = 0;
auto* outputs_list = child->outputs.head;
while (outputs_list != nullptr) {
args_num = std::max(args_num, CountArgs_(outputs_list->value.node, graph));
outputs_list = outputs_list->next;
}
return args_num;
}

void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
auto args_counter = [this](const tvm::Object* obj) {
size_t args_num = 0;
if (auto call_node = GetRef<ObjectRef>(obj).as<CallNode>()) {
for (auto& it : call_node->args) {
if (it.as<VarNode>() || it.as<TupleGetItemNode>()) {
args_num++;
if (const auto* ttype = it.as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
args_num += CountAdditionalArgs_(ttype);
}
}
}
} else if (auto tuple_node = GetRef<ObjectRef>(obj).as<TupleNode>()) {
for (auto& it : tuple_node->fields) {
if (it.as<VarNode>() || it.as<TupleGetItemNode>()) {
args_num++;
if (const auto* ttype = it.as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
args_num += CountAdditionalArgs_(ttype);
}
}
}
} else if (GetRef<ObjectRef>(obj).as<VarNode>()) {
args_num++;
if (const auto* ttype =
GetRef<ObjectRef>(obj).as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
args_num += CountAdditionalArgs_(ttype);
}
}
return args_num;
};

groups_.resize(graph.post_dfs_order.size());
for (size_t nid = 0; nid < groups_.size(); ++nid) {
const auto* graph_node = graph.post_dfs_order[nid];
Expand All @@ -231,6 +342,7 @@ void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
if (group_node->pattern == relay::kOutEWiseFusable) {
group_node->anchor_ref = graph_node->ref;
}
group_node->args_num = args_counter(graph_node->ref);
groups_[nid] = group_node;
}
}
Expand All @@ -244,6 +356,21 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, //
auto* dom_node = post_dom_tree.nodes[nid];
Group* group_node = groups_[nid];
ICHECK(group_node != nullptr);
postpone_node_ = nullptr;
// Check if the fusing of some inputs was postponed
if (postponed_fusing_map_.count(graph_node)) {
auto range = postponed_fusing_map_.equal_range(graph_node);
for (auto it = range.first; it != range.second; ++it) {
// If the number of arguments is less than the limit then the input can be fused
if (CountArgs_(graph_node, graph, false) <= CountArgsLimit_(graph_node)) {
auto* src = it->second;
auto* snode = post_dom_tree.nodes[src->index]->parent->gnode;
if (groups_[snode->index]->anchor_ref != nullptr) continue;
CommitFuse(src, snode);
}
}
postponed_fusing_map_.erase(graph_node);
}
// no actions for opaque nodes
if (group_node->pattern == kOpaque) continue;
// no actions needed if the current node have no dominator
Expand All @@ -254,6 +381,15 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, //
// refuse the fusion if too many ops are going to be fused together
if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
continue;
// Refuse the fusion if too many arguments are going to be in the fused function
if (max_function_args_ > 0) {
auto limit = CountArgsLimit_(graph_node);
if (limit > 0) {
if (CountFusedArgs(graph, graph_node) > limit) {
continue;
}
}
}

if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
Expand Down
46 changes: 43 additions & 3 deletions src/relay/analysis/graph_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class IndexedForwardGraph {
std::vector<Node*> post_dfs_order;

/*! \brief Dump the graph into string. */
void DebugDump() {
void DebugDump() const {
std::ostringstream os;
for (size_t i = 0; i < post_dfs_order.size(); ++i) {
Node* node = post_dfs_order[i];
Expand Down Expand Up @@ -162,8 +162,12 @@ class DominatorTree {
*/
class GraphPartitioner {
public:
explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth)
: arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {}
explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth,
size_t max_function_args)
: arena_(arena),
opt_level_(opt_level),
max_fuse_depth_(max_fuse_depth),
max_function_args_(max_function_args) {}
/*!
* \brief Group as a union find data structure.
*/
Expand All @@ -183,6 +187,10 @@ class GraphPartitioner {
* \brief The number of nodes belonging to this group
*/
uint32_t num_nodes{1};
/*!
* \brief The number of function arguments belonging to this group
*/
size_t args_num{0};

/*! \brief Optional attributes to annotate the grouped function. */
runtime::Map<runtime::String, ObjectRef> attrs;
Expand All @@ -205,10 +213,21 @@ class GraphPartitioner {
int opt_level_;
/*! \brief The maximum number of operations in one fused function */
size_t max_fuse_depth_;
/*! \brief The maximum number of arguments in one fused function */
size_t max_function_args_;
/*! \brief The internal groups. */
std::vector<Group*> groups_;
/*! \brief internal field used for deduplication */
std::unordered_set<IndexedForwardGraph::Node*> visited_;
/*! \brief The map with nodes which were postponed for fusing. */
std::unordered_multimap<const IndexedForwardGraph::Node*, IndexedForwardGraph::Node*>
postponed_fusing_map_;
/*!
* \brief Fusing of this node should be postponed till all child nodes are evaluated.
* It is used to calculate the number of arguments which will be passed to this node in
* the generated function.
*/
const IndexedForwardGraph::Node* postpone_node_{nullptr};
// Internal implementation of CheckPath
template <typename F>
bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond);
Expand Down Expand Up @@ -247,6 +266,23 @@ class GraphPartitioner {
void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);

size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);
// Count the number of additional arguments. In the case of dynamic shape,
// generated function takes several additional arguments, such as the sizes of
// the dynamic dimensions and strides.
// This function calculates the number of such additional arguments.
size_t CountAdditionalArgs_(const TensorTypeNode* ttype, bool with_strides = true);
Comment thread
echuraev marked this conversation as resolved.
// Calculate the number of arguments for the node.
size_t CountArgs_(IndexedForwardGraph::Node* src, const IndexedForwardGraph& graph,
bool update_postpone = true);
// Count the actual limit of arguments for a generated function.
// max_function_args_ specifies the number of maximum function arguments. But
// usually, output tensors are also passed to the function as arguments.
// Additionally, in the case of dynamic shape, it is necessary to take into
// account the number of parameters which specifies the sizes of the dynamic
// dimensions.
// This function computes the maximum number of arguments by the following formula:
// limit = max_function_args_ - output_args_count
size_t CountArgsLimit_(const IndexedForwardGraph::Node* child);
Comment thread
echuraev marked this conversation as resolved.

// Count the number of nodes in a fused subgraph if child is additionally fused.
// dom_parent is already known to be a part of the subgraph.
Expand All @@ -256,6 +292,10 @@ class GraphPartitioner {
// is important for correct calculation.
size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
IndexedForwardGraph::Node* dom_parent);
// Count the number of arguments in a fused subgraph. This function also takes into account the
// number of the child's output node argument. It helps to stop fusing before the node when the
// limit will be exceeded.
size_t CountFusedArgs(const IndexedForwardGraph& graph, IndexedForwardGraph::Node* child);

// Initialize the groups.
void InitGroups(const IndexedForwardGraph& graph);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if (config_->optional_homogeneous_target.defined()) {
// This pass currently only supports the homogeneous case.
pass_seqs.push_back(transform::SplitArgs(
config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", -1)
config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", 0)
.value()
.IntValue()));
}
Expand Down
7 changes: 7 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,13 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
// Always plan devices so the remaining passes don't need to distinguish homogeneous vs
// heterogeneous execution.
pass_seqs.push_back(transform::PlanDevices(config_));
if (config_->optional_homogeneous_target.defined()) {
// This pass currently only supports the homogeneous case.
pass_seqs.push_back(transform::SplitArgs(
config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", 0)
.value()
.IntValue()));
}

pass_seqs.push_back(transform::FuseOps());

Expand Down
Loading