diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 60973577ac92..fdb0a0aa9d1a 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -409,6 +409,32 @@ inline T Substitute(T input, const std::unordered_map& return Substitute(std::move(input), vmap); } +/*! + * \brief Substitute the var specified by vmap and legalize data types after substitution. + * \param stmt The source statement to be substituted + * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. + * + * Unlike `Substitute`, this allows the substitution to change the data type of the expression. + * + * \sa Substitute + * \return The result. + */ +TVM_DLL Stmt SubstituteWithDataTypeLegalization(Stmt stmt, + std::function(const Var&)> vmap); + +/*! + * \brief Substitute the var specified by vmap and legalize data types after substitution. + * \param expr The source statement to be substituted + * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. + * + * Unlike `Substitute`, this allows the substitution to change the data type of the expression. + * + * \sa Substitute + * \return The result. + */ +TVM_DLL PrimExpr SubstituteWithDataTypeLegalization( + PrimExpr expr, std::function(const Var&)> vmap); + /*! * \brief Recursively visit the IR in pre DFS order node, apply fvisit. * If fvisit returns false, it won't visit the children of the node. diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index e1cc9dbdd093..03a2f29bd129 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -162,9 +162,11 @@ Array IndexMapNode::MapIndices(const Array& indices, analyzer = &local_analyzer; } - Array output = final_indices.Map( - [&](PrimExpr index) { return analyzer->Simplify(Substitute(std::move(index), vmap)); }); - + Array output = final_indices.Map([&](PrimExpr index) { + PrimExpr result = SubstituteWithDataTypeLegalization( + std::move(index), [&](const Var& var) { return vmap.Get(var); }); + return analyzer->Simplify(result); + }); return output; } @@ -218,6 +220,21 @@ Array IndexMapNode::MapRanges(const Array& ranges, arith::Analyzer analyzer->Simplify(int_set.max() - int_set.min() + 1))); } } + auto output_dtype = [&]() { + int max_bits = 0; + for (const auto& range : ranges) { + max_bits = std::max(max_bits, range->extent.dtype().bits()); + } + return DataType::Int(max_bits); + }(); + output.MutateByApply([&](const Range& range) { + if (range->min.dtype() != output_dtype || range->extent.dtype() != output_dtype) { + return Range::FromMinExtent(cast(output_dtype, range->min), + cast(output_dtype, range->extent)); + } else { + return range; + } + }); return output; } @@ -227,7 +244,7 @@ Array IndexMapNode::MapShape(const Array& shape, Array ranges; for (auto& dim : shape) { - ranges.push_back(Range(0, dim)); + ranges.push_back(Range(make_zero(dim.dtype()), dim)); } Array mapped = MapRanges(std::move(ranges), analyzer); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index c2e2489cba92..6d0ee134c805 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -809,6 +809,95 @@ void PreOrderVisit(const ObjectRef& stmt_or_expr, } } +class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer { + public: + explicit IRSubstituteWithDataTypeLegalization(std::function(const Var&)> vmap) + : vmap_(vmap) {} + + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + auto ret = vmap_(var); + if (ret.defined()) { + return ret.value(); + } + return std::move(var); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto node = Downcast(StmtExprMutator::VisitExpr_(op)); + return VisitBufferAccess(std::move(node)); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto node = Downcast(StmtExprMutator::VisitStmt_(op)); + return VisitBufferAccess(std::move(node)); + } + + template + Node VisitBufferAccess(Node node) { + Buffer new_buf = GetRemappedBuffer(node->buffer); + + if (!new_buf.same_as(node->buffer)) { + auto writer = node.CopyOnWrite(); + writer->buffer = new_buf; + } + + return node; + } + + Buffer GetRemappedBuffer(Buffer buf) { + auto key = buf.get(); + auto it = buf_remap_.find(key); + if (it != buf_remap_.end()) { + return it->second; + } + + auto new_buffer_var = vmap_(buf->data); + if (new_buffer_var.defined() && !new_buffer_var.value().same_as(buf->data)) { + auto writer = buf.CopyOnWrite(); + writer->data = Downcast(new_buffer_var); + } + + buf_remap_[key] = buf; + return buf; + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + // remap var node in attr + if (const auto* var_node = op->node.as()) { + if (auto mapped_var = vmap_(GetRef(var_node))) { + return AttrStmt(mapped_var, op->attr_key, op->value, op->body); + } + } + return ret; + } + + private: + // Caller provided function that defines the variables to be remapped. + std::function(const Var&)> vmap_; + + /* \brief Generated map to track buffers being remapped. + * + * If a `Var BufferNode::data` is remapped, then all buffers + * containing that data pointer should also be remapped. This map + * is used to track buffer modifications, and ensure all instances + * of a buffer are replaced by the same modified buffer object. + */ + std::unordered_map buf_remap_; +}; + +Stmt SubstituteWithDataTypeLegalization(Stmt stmt, + std::function(const Var&)> vmap) { + return IRSubstituteWithDataTypeLegalization(vmap)(std::move(stmt)); +} + +PrimExpr SubstituteWithDataTypeLegalization(PrimExpr expr, + std::function(const Var&)> vmap) { + return IRSubstituteWithDataTypeLegalization(vmap)(std::move(expr)); +} + TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform); TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) { diff --git a/tests/python/unittest/test_index_map.py b/tests/python/unittest/test_index_map.py index 6882c2b42634..ac128690c415 100644 --- a/tests/python/unittest/test_index_map.py +++ b/tests/python/unittest/test_index_map.py @@ -21,6 +21,7 @@ import tvm.testing from tvm.ir import assert_structural_equal from tvm.tir import IndexMap, IntImm, floordiv, floormod +from tvm.runtime import const def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: @@ -41,6 +42,9 @@ def test_index_mapping(): assert_structural_equal(index_map.map_indices([3]), [0, 3]) assert_structural_equal(index_map.map_indices([4]), [1, 0]) assert_structural_equal(index_map.map_indices([42]), [10, 2]) + assert_structural_equal( + index_map.map_indices([const(42, "int64")]), [const(10, "int64"), const(2, "int64")] + ) def test_shape_mapping(): @@ -50,6 +54,12 @@ def test_shape_mapping(): assert_structural_equal(index_map.map_shape([16]), [4, 4]) assert_structural_equal(index_map.map_shape([14]), [4, 4]) + assert_structural_equal( + index_map.map_shape([const(16, "int64")]), [const(4, "int64"), const(4, "int64")] + ) + assert_structural_equal( + index_map.map_shape([const(14, "int64")]), [const(4, "int64"), const(4, "int64")] + ) def test_inverse(): diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 174e9eb25cc0..0bf75becb2c0 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -376,6 +376,41 @@ def test_transform_block_layout_fail_mixed_iter_type(use_block_name): ) +def test_transform_block_layout_int64_extent(use_block_name): + @T.prim_func + def elementwise_int64_extent( + A: T.Buffer[(T.int64(128), T.int64(128)), "float32"], + B: T.Buffer[(T.int64(128), T.int64(128)), "float32"], + ) -> None: + for i, j in T.grid(T.int64(128), T.int64(128)): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + @T.prim_func + def elementwise_int64_extent_transformed( + A: T.Buffer[(T.int64(128), T.int64(128)), "float32"], + B: T.Buffer[(T.int64(128), T.int64(128)), "float32"], + ) -> None: + for i in range(T.int64(16384)): + with T.block("B"): + vi = T.axis.remap("S", [i]) + B[vi // T.int64(128), vi % T.int64(128)] = ( + A[vi // T.int64(128), vi % T.int64(128)] * 2.0 + ) + + sch = tir.Schedule(elementwise_int64_extent, debug_mask="all") + block = "B" if use_block_name else sch.get_block("B") + sch.transform_block_layout(block, lambda i, j: (i * 128 + j,)) + print( + tvm.ir.base.get_first_structural_mismatch( + elementwise_int64_extent_transformed, sch.mod["main"] + ) + ) + tvm.ir.assert_structural_equal(elementwise_int64_extent_transformed, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise_int64_extent) + + class BasePaddingCompare(tvm.testing.CompareBeforeAfter): pad_value = tvm.testing.parameter(None)