Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/tir/transforms/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ class LoopPartitioner : public StmtMutator {
}

Stmt VisitStmt_(const ForNode* op) final {
analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent), true);
auto fs = GetRef<Stmt>(op);
if (selector.candidates.count(fs)) {
Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false);
Expand Down Expand Up @@ -697,12 +698,13 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b
const ForNode* for_node = static_cast<const ForNode*>(node);
ICHECK(for_node);
if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) &&
!no_unroll_loop_with_extent_one_) {
!no_unroll_loop_with_extent_one_ && for_node->annotations.empty()) {
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
ICHECK(for_node->kind != ForKind::kThreadBinding);
return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->kind, body);
return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->kind, body,
for_node->thread_binding, for_node->annotations);
}
}

Expand Down
128 changes: 110 additions & 18 deletions tests/python/unittest/test_tir_transform_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,17 @@ def test_explicit_partition_hint():
assert tvm.ir.structural_equal(mod["main"], partitioned_concat)


def partition_from_scheduled_tir(prim_func, pass_cfg):
    """Run the loop-partition test pipeline on a single PrimFunc.

    The function is wrapped into an IRModule and then passed through
    LowerOpaqueBlock, FlattenBuffer, LoopPartition, Simplify and RemoveNoOp,
    in that order, all inside one PassContext.

    Parameters
    ----------
    prim_func
        The function handed to ``IRModule.from_expr``.
    pass_cfg
        Value forwarded as ``config`` to ``tvm.transform.PassContext``.

    Returns
    -------
    The module produced by the last pass in the pipeline.
    """
    # Pass factories, listed in application order; each one is instantiated
    # right before it is applied.
    pipeline = (
        tvm.tir.transform.LowerOpaqueBlock,
        tvm.tir.transform.FlattenBuffer,
        tvm.tir.transform.LoopPartition,
        tvm.tir.transform.Simplify,
        tvm.tir.transform.RemoveNoOp,
    )
    with tvm.transform.PassContext(config=pass_cfg):
        module = IRModule.from_expr(prim_func)
        for make_pass in pipeline:
            module = make_pass()(module)
    return module


@T.prim_func
def partitioned_concat_3(
placeholder: T.Buffer[(50176,), "int8"],
Expand Down Expand Up @@ -609,13 +620,9 @@ def concat_func_3(


def test_condition_mutually_exclusive():
    """Partition ``concat_func_3`` and compare with ``partitioned_concat_3``.

    The pipeline is run through ``partition_from_scheduled_tir`` with
    ``partition_const_loop`` enabled for ``tir.LoopPartition``; the resulting
    ``main`` function must be structurally equal to the expected PrimFunc.
    """
    # The inline copy of the pass pipeline was dropped: its result was
    # overwritten by the helper call below, so it only duplicated work.
    mod = partition_from_scheduled_tir(
        concat_func_3, {"tir.LoopPartition": {"partition_const_loop": True}}
    )
    assert tvm.ir.structural_equal(mod["main"], partitioned_concat_3)


Expand Down Expand Up @@ -650,23 +657,108 @@ def partitioned_main(A: T.Buffer[150528, "int8"], B: T.Buffer[25088, "int8"]) ->
if ax2 < 5 and ax3 < 3:
B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax2 + 219]

mod = tvm.ir.module.IRModule.from_expr(main)
with tvm.transform.PassContext(
config={
mod = partition_from_scheduled_tir(
main,
{
"tir.LoopPartition": {
"partition_const_loop": True,
"unroll_loop_with_partition_hint_no_interval": True,
}
}
):
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
mod = tvm.tir.transform.FlattenBuffer()(mod)
mod = tvm.tir.transform.LoopPartition()(mod)
mod = tvm.tir.transform.UnrollLoop()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
mod = tvm.tir.transform.Simplify()(mod)
},
)
mod = tvm.tir.transform.UnrollLoop()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
mod = tvm.tir.transform.Simplify()(mod)
assert tvm.ir.structural_equal(mod["main"], partitioned_main)


# Checks that the loops produced by LoopPartition still carry the
# user-supplied annotation of the loop they were split from.
def test_loop_partition_keep_loop_annotations():
# Input: one serial loop of extent 160 annotated with both the partition hint
# and an extra {"key": "value"} entry. The body branches on i into the three
# ranges i < 10, 10 <= i < 150, and the remainder.
@T.prim_func
def before(A: T.Buffer[160, "int32"], B: T.Buffer[160, "int32"]) -> None:
for i in T.serial(
160,
annotations={"pragma_loop_partition_hint": True, "key": "value"},
):
if i < 10:
B[i] = A[i] + 1
elif 10 <= i and i < 150:
B[i] = A[i] + 2
else:
B[i] = A[i] + 3

# Expected output: three branch-free loops of extent 10, 140 and 10. Each one
# is annotated with {"key": "value"} only; the partition hint is not present.
@T.prim_func
def after(A: T.Buffer[160, "int32"], B: T.Buffer[160, "int32"]) -> None:
T.preflattened_buffer(A, [160], dtype="int32", data=A.data)
T.preflattened_buffer(B, [160], dtype="int32", data=B.data)
for i in T.serial(10, annotations={"key": "value"}):
B[i] = A[i] + 1
for i in T.serial(140, annotations={"key": "value"}):
B[i + 10] = A[i + 10] + 2
for i in T.serial(10, annotations={"key": "value"}):
B[i + 150] = A[i + 150] + 3

# Run the shared pipeline with constant-extent loop partitioning enabled and
# compare the resulting "main" function against the expected PrimFunc.
mod = partition_from_scheduled_tir(
before,
{
"tir.LoopPartition": {
"partition_const_loop": True,
}
},
)
assert tvm.ir.structural_equal(mod["main"], after)


# Checks partitioning of a hinted loop whose conditions also mention the
# variable of an enclosing extent-1 loop.
def test_loop_partition_with_unit_loop_in_condition():
# Input: an outer loop k of extent 1 annotated with "preserve_unit_loop",
# wrapping a 128-iteration loop i1 that carries the partition hint. The three
# conditions test k * 128 + i1 against the boundaries 64 and 96.
@T.prim_func
def before(
placeholder: T.Buffer[(50176,), "int8"],
placeholder_1: T.Buffer[(25088,), "int8"],
placeholder_2: T.Buffer[(25088,), "int8"],
T_concat: T.Buffer[(100352,), "int8"],
) -> None:
for k in range(1, annotations={"preserve_unit_loop": True}):
for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
for i2, i3 in T.grid(28, 28):
if 96 <= k * 128 + i1:
# NOTE(review): this store index multiplies i1 by k, unlike the other two
# branches. The expected output below has no i1 term in the matching store,
# which is consistent with k being 0 -- confirm this is intentional.
T_concat[k * i1 * 784 + i2 * 28 + i3] = placeholder_2[
i1 * 784 + i2 * 28 + i3 - 75264
]
if 64 <= k * 128 + i1 and k * 128 + i1 < 96:
T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_1[
i1 * 784 + i2 * 28 + i3 - 50176
]
if k * 128 + i1 < 64:
T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3]

# Expected output: the annotated unit loop is still present, and the i1 loop
# is replaced by three condition-free loop nests with i1 extents 64, 32, 32.
@T.prim_func
def after(
placeholder: T.Buffer[50176, "int8"],
placeholder_1: T.Buffer[25088, "int8"],
placeholder_2: T.Buffer[25088, "int8"],
T_concat: T.Buffer[100352, "int8"],
) -> None:
T.preflattened_buffer(placeholder, [50176], dtype="int8", data=placeholder.data)
T.preflattened_buffer(placeholder_1, [25088], dtype="int8", data=placeholder_1.data)
T.preflattened_buffer(placeholder_2, [25088], dtype="int8", data=placeholder_2.data)
T.preflattened_buffer(T_concat, [100352], dtype="int8", data=T_concat.data)
for _ in T.serial(1, annotations={"preserve_unit_loop": True}):
for i1, i2, i3 in T.grid(64, 28, 28):
T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3]
for i1, i2, i3 in T.grid(32, 28, 28):
T_concat[i1 * 784 + i2 * 28 + i3 + 50176] = placeholder_1[i1 * 784 + i2 * 28 + i3]
for i1, i2, i3 in T.grid(32, 28, 28):
T_concat[i2 * 28 + i3] = placeholder_2[i1 * 784 + i2 * 28 + i3]

# Run the shared pipeline with constant-extent loop partitioning enabled and
# compare the resulting "main" function against the expected PrimFunc.
mod = partition_from_scheduled_tir(
before,
{
"tir.LoopPartition": {
"partition_const_loop": True,
}
},
)
assert tvm.ir.structural_equal(mod["main"], after)


# When this file is executed directly, hand control to tvm.testing.main().
if __name__ == "__main__":
tvm.testing.main()