Skip to content

[TIR] Prevent loop binding over-simplification#11578

Merged
spectrometerHBH merged 1 commit into
apache:mainfrom
junrushao:bugfix/2022-06-04/tir-split-over-simplification
Jun 5, 2022
Merged

[TIR] Prevent loop binding over-simplification#11578
spectrometerHBH merged 1 commit into
apache:mainfrom
junrushao:bugfix/2022-06-04/tir-split-over-simplification

Conversation

@junrushao

@junrushao junrushao commented Jun 5, 2022

Copy link
Copy Markdown
Member

@vinx13 @jinhongyii and I observe a recent regression on TVM mainline: over-simplification in Schedule.split leads to information loss that negatively impacts search space generation.

Impact. This affects common operators like softmax and even simpler reductions.

Example. Consider splitting a simple reduction loop:

@T.prim_func
def main(
    A: T.Buffer[2, "float32"],
    B: T.Buffer[2, "float32"],
    C: T.Buffer[(), "float32"],
) -> None:
    for i in T.serial(2):   # <= split `i` into `i_0` and `i_1`, where `i_0` is a trivial loop
        with T.block("C"):
            k = T.axis.reduce(2, i)
            with T.init():
                C[()] = T.float32(1)
            C[()] = T.min(C[()], A[k] / B[k])

Splitting loop i by factors [1, 2], we get:

@T.prim_func
def main(
    A: T.Buffer[2, "float32"],
    B: T.Buffer[2, "float32"],
    C: T.Buffer[(), "float32"],
) -> None:
    for i_0, i_1 in T.grid(1, 2):
        with T.block("C"):
            k = T.axis.reduce(2, i_1)   # <= i_0 is not part of the binding, so the system cannot tell if i_0 is a reduction loop
            with T.init():
                C[()] = T.float32(1)
            C[()] = T.min(C[()], A[k] / B[k])

In this case, loop i_0 will be considered as a spatial loop, even it’s the outcome of splitting a reduction loop. However, if we change the factors from [1, 2] to [2, 1], loop i_0 becomes a reduction loop. This means the loop iteration property depends on the loop extent.

Why is it problematic? MetaSchedule has an assumption: extremely seldomly, a loop extent would impact the iteration property of the loop itself, i.e. no matter the extent is 1 or 2 or anything, the fact that the loop is a reduction loop should rarely change.

As an example, Auto-Bind finds the outer k spatial loops, which are fused together and bound to thread axis. In the trace, the number (k) of the outer loops has to be a constant.

However, if Auto-Bind thinks there are k=3 outer loops to fuse during search space generation, where the last loop happens to be a reduction loop with extent 1, as shown below:

for spatial_loop_0 in range(...):
  for spatial_loop_1 in range(...):
    for reduction_loop in range(1):  # <= Auto-Bind mistakes this loop as spatial, because its extent is 1

During evolutionary search, the extent of reduction_loop will change and become larger than 1. In this case, the binding strategy will consistently fail because it considers fusing k=3 loops - which means the entire search strategy will fail with almost no valid candidates.

Thanks @MasterJH5574 for figuring out the root cause of the issue, and @jinhongyii for valuable pointers to the right fix!

@junrushao

Copy link
Copy Markdown
Member Author

CC: @wrongtest @vinx13 @spectrometerHBH @Hzfengsy

@junrushao junrushao changed the title [TIR] Prevent loop bining over-simplification [TIR] Prevent loop binding over-simplification Jun 5, 2022
@junrushao junrushao force-pushed the bugfix/2022-06-04/tir-split-over-simplification branch 2 times, most recently from 3567e1f to 554b701 Compare June 5, 2022 02:58
Comment thread tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py Outdated
@vinx13 @jinhongyii and I observe a recent regression on TVM mainline: over-simplification in
`Schedule.split` leads to information loss that negatively impacts search space generation.

**Impact.** This affects common operators like `softmax` and even simpler reductions.

**Example.** Consider splitting a simple reduction loop:

```python
@T.prim_func
def main(
    A: T.Buffer[2, "float32"],
    B: T.Buffer[2, "float32"],
    C: T.Buffer[(), "float32"],
) -> None:
    for i in T.serial(2):  # <= split `i` into `i_0` and `i_1`, where `i_0` is a trivial loop
        with T.block("C"):
            k = T.axis.reduce(2, i)
            with T.init():
                C[()] = T.float32(1)
            C[()] = T.min(C[()], A[k] / B[k])
```

Splitting loop `i`  by factors `[1, 2]`, we get:

```python
@T.prim_func
def main(
    A: T.Buffer[2, "float32"],
    B: T.Buffer[2, "float32"],
    C: T.Buffer[(), "float32"],
) -> None:
    for i_0, i_1 in T.grid(1, 2):
        with T.block("C"):
            k = T.axis.reduce(2, i_1)  # <= i_0 is not part of the binding,
                                       # so the system cannot tell if i_0 is a reduction loop
            with T.init():
                C[()] = T.float32(1)
            C[()] = T.min(C[()], A[k] / B[k])
```

In this case, loop `i_0` will be considered as a spatial loop, even it’s the outcome of splitting
a reduction loop. However, if we change the factors from `[1, 2]` to `[2, 1]`, loop `i_0` becomes
a reduction loop. This means the loop iteration property depends on the loop extent.

**Why is it problematic**? MetaSchedule has an assumption: extremely seldomly, a loop extent would
impact the iteration property of the loop itself, i.e. no matter the extent is 1 or 2 or anything,
the fact that the loop is a reduction loop should rarely change.

As an example, `Auto-Bind` finds the outer `k` spatial loops, which are fused together and bound to
thread axis. In the trace, the number (`k`) of the outer loops has to be a constant.

However, if Auto-Bind thinks there are `k=3` outer loops to fuse during search space generation,
where the last loop happens to be a reduction loop with extent 1, as shown below:

```python
for spatial_loop_0 in range(...):
  for spatial_loop_1 in range(...):
    for reduction_loop in range(1):  # <= Auto-Bind mistakes this loop as spatial, because extent==1
```

During evolutionary search, the extent of reduction_loop will change and become larger than 1.
In this case, the binding strategy will consistently fail because it considers fusing `k=3` loops
- which means the entire search strategy will fail with almost no valid candidates.

Thanks @MasterJH5574 for figuring out the root cause of the issue,
and @jinhongyii for valuable pointers to the right fix!
@junrushao junrushao force-pushed the bugfix/2022-06-04/tir-split-over-simplification branch from 554b701 to 644576c Compare June 5, 2022 04:46
@spectrometerHBH spectrometerHBH merged commit c732828 into apache:main Jun 5, 2022
junrushao added a commit to junrushao/tvm that referenced this pull request Jun 5, 2022
Follow-up of apache#11578, which enforces structural stability of TIR by
avoiding over-simplification in affine analysis. On the other hand, it
is possible that over-simplification could be desirable behavior.
Therefore, following the precedent of `preserve-unit-loops` in
`Compute-At`, this PR introduces `preserve-unit-iters` in block binding
for cases where users don't need structural stability (which is
admittedly rare).
junrushao added a commit to junrushao/tvm that referenced this pull request Jun 5, 2022
Follow-up of apache#11578, which enforces structural stability of TIR by
avoiding over-simplification in affine analysis. On the other hand, it
is possible that over-simplification could be desirable behavior.
Therefore, following the precedent of `preserve-unit-loops` in
`Compute-At`, this PR introduces `preserve-unit-iters` in block binding
for cases where users don't need structural stability (which is
admittedly rare).

This PR does not affect any existing functionalities.
junrushao added a commit to junrushao/tvm that referenced this pull request Jun 16, 2022
Follow-up of apache#11578, which enforces structural stability of TIR by
avoiding over-simplification in affine analysis. On the other hand, it
is possible that over-simplification could be desirable behavior.
Therefore, following the precedent of `preserve-unit-loops` in
`Compute-At`, this PR introduces `preserve-unit-iters` in block binding
for cases where users don't need structural stability (which is
admittedly rare).

This PR does not affect any existing functionalities.
junrushao added a commit to junrushao/tvm that referenced this pull request Jun 16, 2022
Follow-up of apache#11578, which enforces structural stability of TIR by
avoiding over-simplification in affine analysis. On the other hand, it
is possible that over-simplification could be desirable behavior.
Therefore, following the precedent of `preserve-unit-loops` in
`Compute-At`, this PR introduces `preserve-unit-iters` in block binding
for cases where users don't need structural stability (which is
admittedly rare).

This PR does not affect any existing functionalities.

Example:

```python
for i in T.serial(2):
    with T.block("C"):
        k = T.axis.reduce(2, i)

Split(i, [1, 2], preserve-unit-iters=True/False)

for i_0, i_1 in T.grid(1, 2):
    with T.block("C"):
        k = T.axis.reduce(2, i_0 * 2 + i_1)

for i_0, i_1 in T.grid(1, 2):
    with T.block("C"):
        k = T.axis.reduce(2, i_1)
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants