Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,32 @@ inline T Substitute(T input, const std::unordered_map<const VarNode*, PrimExpr>&
return Substitute(std::move(input), vmap);
}

/*!
* \brief Substitute the var specified by vmap and legalize data types after substitution.
* \param stmt The source statement to be substituted
* \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr.
*
* Unlike `Substitute`, this allows the substitution to change the data type of the expression.
*
* \sa Substitute
* \return The result.
*/
TVM_DLL Stmt SubstituteWithDataTypeLegalization(Stmt stmt,
std::function<Optional<PrimExpr>(const Var&)> vmap);

/*!
* \brief Substitute the var specified by vmap and legalize data types after substitution.
* \param expr The source statement to be substituted
* \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr.
*
* Unlike `Substitute`, this allows the substitution to change the data type of the expression.
*
* \sa Substitute
* \return The result.
*/
TVM_DLL PrimExpr SubstituteWithDataTypeLegalization(
PrimExpr expr, std::function<Optional<PrimExpr>(const Var&)> vmap);

/*!
* \brief Recursively visit the IR in pre DFS order node, apply fvisit.
* If fvisit returns false, it won't visit the children of the node.
Expand Down
25 changes: 21 additions & 4 deletions src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,11 @@ Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
analyzer = &local_analyzer;
}

Array<PrimExpr> output = final_indices.Map(
[&](PrimExpr index) { return analyzer->Simplify(Substitute(std::move(index), vmap)); });

Array<PrimExpr> output = final_indices.Map([&](PrimExpr index) {
PrimExpr result = SubstituteWithDataTypeLegalization(
std::move(index), [&](const Var& var) { return vmap.Get(var); });
return analyzer->Simplify(result);
});
return output;
}

Expand Down Expand Up @@ -218,6 +220,21 @@ Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges, arith::Analyzer
analyzer->Simplify(int_set.max() - int_set.min() + 1)));
}
}
auto output_dtype = [&]() {
int max_bits = 0;
for (const auto& range : ranges) {
max_bits = std::max(max_bits, range->extent.dtype().bits());
}
return DataType::Int(max_bits);
}();
output.MutateByApply([&](const Range& range) {
if (range->min.dtype() != output_dtype || range->extent.dtype() != output_dtype) {
return Range::FromMinExtent(cast(output_dtype, range->min),
cast(output_dtype, range->extent));
} else {
return range;
}
});
return output;
}

Expand All @@ -227,7 +244,7 @@ Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,

Array<Range> ranges;
for (auto& dim : shape) {
ranges.push_back(Range(0, dim));
ranges.push_back(Range(make_zero(dim.dtype()), dim));
}
Array<Range> mapped = MapRanges(std::move(ranges), analyzer);

Expand Down
89 changes: 89 additions & 0 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,95 @@ void PreOrderVisit(const ObjectRef& stmt_or_expr,
}
}

class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer {
public:
explicit IRSubstituteWithDataTypeLegalization(std::function<Optional<PrimExpr>(const Var&)> vmap)
: vmap_(vmap) {}

PrimExpr VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
auto ret = vmap_(var);
if (ret.defined()) {
return ret.value();
}
return std::move(var);
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return VisitBufferAccess(std::move(node));
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
return VisitBufferAccess(std::move(node));
}

template <typename Node>
Node VisitBufferAccess(Node node) {
Buffer new_buf = GetRemappedBuffer(node->buffer);

if (!new_buf.same_as(node->buffer)) {
auto writer = node.CopyOnWrite();
writer->buffer = new_buf;
}

return node;
}

Buffer GetRemappedBuffer(Buffer buf) {
auto key = buf.get();
auto it = buf_remap_.find(key);
if (it != buf_remap_.end()) {
return it->second;
}

auto new_buffer_var = vmap_(buf->data);
if (new_buffer_var.defined() && !new_buffer_var.value().same_as(buf->data)) {
auto writer = buf.CopyOnWrite();
writer->data = Downcast<Var>(new_buffer_var);
}

buf_remap_[key] = buf;
return buf;
}

Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();
// remap var node in attr
if (const auto* var_node = op->node.as<VarNode>()) {
if (auto mapped_var = vmap_(GetRef<Var>(var_node))) {
return AttrStmt(mapped_var, op->attr_key, op->value, op->body);
}
}
return ret;
}

private:
// Caller provided function that defines the variables to be remapped.
std::function<Optional<PrimExpr>(const Var&)> vmap_;

/* \brief Generated map to track buffers being remapped.
*
* If a `Var BufferNode::data` is remapped, then all buffers
* containing that data pointer should also be remapped. This map
* is used to track buffer modifications, and ensure all instances
* of a buffer are replaced by the same modified buffer object.
*/
std::unordered_map<const BufferNode*, Buffer> buf_remap_;
};

Stmt SubstituteWithDataTypeLegalization(Stmt stmt,
std::function<Optional<PrimExpr>(const Var&)> vmap) {
return IRSubstituteWithDataTypeLegalization(vmap)(std::move(stmt));
}

PrimExpr SubstituteWithDataTypeLegalization(PrimExpr expr,
std::function<Optional<PrimExpr>(const Var&)> vmap) {
return IRSubstituteWithDataTypeLegalization(vmap)(std::move(expr));
}

TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform);

TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) {
Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_index_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tvm.testing
from tvm.ir import assert_structural_equal
from tvm.tir import IndexMap, IntImm, floordiv, floormod
from tvm.runtime import const


def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None:
Expand All @@ -41,6 +42,9 @@ def test_index_mapping():
assert_structural_equal(index_map.map_indices([3]), [0, 3])
assert_structural_equal(index_map.map_indices([4]), [1, 0])
assert_structural_equal(index_map.map_indices([42]), [10, 2])
assert_structural_equal(
index_map.map_indices([const(42, "int64")]), [const(10, "int64"), const(2, "int64")]
)


def test_shape_mapping():
Expand All @@ -50,6 +54,12 @@ def test_shape_mapping():
assert_structural_equal(index_map.map_shape([16]), [4, 4])

assert_structural_equal(index_map.map_shape([14]), [4, 4])
assert_structural_equal(
index_map.map_shape([const(16, "int64")]), [const(4, "int64"), const(4, "int64")]
)
assert_structural_equal(
index_map.map_shape([const(14, "int64")]), [const(4, "int64"), const(4, "int64")]
)


def test_inverse():
Expand Down
35 changes: 35 additions & 0 deletions tests/python/unittest/test_tir_schedule_transform_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,41 @@ def test_transform_block_layout_fail_mixed_iter_type(use_block_name):
)


def test_transform_block_layout_int64_extent(use_block_name):
@T.prim_func
def elementwise_int64_extent(
A: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
B: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
) -> None:
for i, j in T.grid(T.int64(128), T.int64(128)):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0

@T.prim_func
def elementwise_int64_extent_transformed(
A: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
B: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
) -> None:
for i in range(T.int64(16384)):
with T.block("B"):
vi = T.axis.remap("S", [i])
B[vi // T.int64(128), vi % T.int64(128)] = (
A[vi // T.int64(128), vi % T.int64(128)] * 2.0
)

sch = tir.Schedule(elementwise_int64_extent, debug_mask="all")
block = "B" if use_block_name else sch.get_block("B")
sch.transform_block_layout(block, lambda i, j: (i * 128 + j,))
print(
tvm.ir.base.get_first_structural_mismatch(
elementwise_int64_extent_transformed, sch.mod["main"]
)
)
tvm.ir.assert_structural_equal(elementwise_int64_extent_transformed, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=elementwise_int64_extent)


class BasePaddingCompare(tvm.testing.CompareBeforeAfter):
pad_value = tvm.testing.parameter(None)

Expand Down