Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,12 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
* \param input_iters Map from variable to iterator's range.
* \param input_pred The predicate constraints on the input iterators
* \param check_level The iter mapping checking level.
*
* \param simplify_trivial_iterators If true, iterators with unit extents are simplified
* \return The indices after rewrite
*/
Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, IterMapLevel check_level);
const PrimExpr& input_pred, IterMapLevel check_level,
bool simplify_trivial_iterators = true);

/*!
* \brief Apply the inverse of the affine transformation to the outputs.
Expand Down
6 changes: 4 additions & 2 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1720,10 +1720,12 @@ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) {
TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIterMapToExpr);

Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, IterMapLevel check_level) {
const PrimExpr& input_pred, IterMapLevel check_level,
bool simplify_trivial_iterators) {
if (!IterRangeSanityCheck(input_iters)) return indices;
Analyzer analyzer;
auto res = DetectIterMap(indices, input_iters, input_pred, check_level, &analyzer);
auto res = DetectIterMap(indices, input_iters, input_pred, check_level, &analyzer,
/*simplify_trivial_iterators=*/simplify_trivial_iterators);
Array<IterSumExpr> rewrite = res->indices;

if (rewrite.empty()) {
Expand Down
5 changes: 3 additions & 2 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
Array<PrimExpr> v = arith::IterMapSimplify(/*indices=*/op->iter_values,
/*input_iters=*/loop_var2extent_,
/*input_pred=*/op->predicate,
/*check_level=*/arith::IterMapLevel::Surjective);
/*check_level=*/arith::IterMapLevel::Surjective,
/*simplify_trivial_iterators=*/false);
if (v.same_as(op->iter_values)) {
return GetRef<Stmt>(op);
} else {
Expand Down Expand Up @@ -397,7 +398,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
for (int i = 0; i < n; i++) {
const PrimExpr& factor = factors[i];
Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i));
if (!is_one(factor)) substitute_value = substitute_value * factor + var;
substitute_value = substitute_value * factor + var;
analyzer.Bind(var, Range::FromMinExtent(0, factor));
new_loop_vars.emplace_back(std::move(var));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
with T.block("C"):
i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4)
k = T.axis.reduce(512, i2_1 * 32 + i2_2)
k = T.axis.reduce(512, i2_0 * 512 + i2_1 * 32 + i2_2)
T.reads([A_shared[i, k], B_shared[k, j]])
T.writes([C_local[i, j]])
with T.init():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring

from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
from tvm.meta_schedule.testing.schedule_rule import auto_bind
from tvm.meta_schedule.testing.space_generation import check_trace
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tir_schedule_reorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def cascade_pool_ops_tile_reordered(
)
for h_i, w, kh, kw in T.grid(4, 108, 3, 3):
with T.block("pool_1"):
ax0 = T.axis.spatial(1, 0)
ax0 = T.axis.spatial(1, n)
ax1 = T.axis.spatial(16, c)
ax2 = T.axis.spatial(108, h_o * 4 + h_i)
ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw])
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_tir_schedule_split_fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def elementwise_split_case0(a: T.handle, b: T.handle) -> None:
B = T.match_buffer(b, [128, 128, 128])
for i1, i2, i3, j1, j2, k1, k2 in T.grid(2, 1, 64, 4, 32, 16, 8):
with T.block("B"):
vi = T.axis.S(128, i1 * 64 + i3)
vi = T.axis.S(128, (i1 + i2) * 64 + i3)
vj = T.axis.S(128, j1 * 32 + j2)
vk = T.axis.S(128, k1 * 8 + k2)
T.reads([A[vi, vj, vk]])
Expand All @@ -192,9 +192,9 @@ def elementwise_split_case1(a: T.handle, b: T.handle) -> None:
B = T.match_buffer(b, [128, 128, 128])
for i1, i2, i3, j1, j2, j3, k1, k2, k3 in T.grid(2, 1, 64, 2, 1, 64, 2, 1, 64):
with T.block("B"):
vi = T.axis.S(128, i1 * 64 + i3)
vj = T.axis.S(128, j1 * 64 + j3)
vk = T.axis.S(128, k1 * 64 + k3)
vi = T.axis.S(128, (i1 + i2) * 64 + i3)
vj = T.axis.S(128, (j1 + j2) * 64 + j3)
vk = T.axis.S(128, (k1 + k2) * 64 + k3)
T.reads([A[vi, vj, vk]])
T.writes([B[vi, vj, vk]])
B[vi, vj, vk] = A[vi, vj, vk] * 2.0
Expand All @@ -206,10 +206,10 @@ def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [128, 128, 128])
for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43):
with T.block("B"):
T.where((i0 * 2 + i1) * 3 + i2 < 128 and j1 < 128 and k0 * 43 + k1 < 128)
vi = T.axis.S(128, i0 * 6 + i1 * 3 + i2)
vj = T.axis.S(128, j1)
vj = T.axis.S(128, j0 * 129 + j1)
vk = T.axis.S(128, k0 * 43 + k1)
T.where((i0 * 2 + i1) * 3 + i2 < 128 and j0 * 129 + j1 < 128 and k0 * 43 + k1 < 128)
T.reads([A[vi, vj, vk]])
T.writes([B[vi, vj, vk]])
B[vi, vj, vk] = A[vi, vj, vk] * 2.0
Expand Down
18 changes: 6 additions & 12 deletions tests/python/unittest/test_tir_schedule_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN

from tvm.tir import Schedule
from tvm.script import tir as T
from tvm.tir import Schedule
from tvm.tir.schedule.transform import tile_with_tensor_intrin
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN


@tvm.script.ir_module
Expand Down Expand Up @@ -128,11 +127,10 @@ def main(
1, 16, 56, 56, 1, 1, 1, 4, 4, 1, 16, 4
):
with T.block("conv2d_NCHWc_int8"):
n = T.axis.spatial(1, 0)
oc_chunk, oh, ow, oc_block = T.axis.remap("SSSS", [i1, i2, i3, i4_1])
kh = T.axis.reduce(1, 0)
kw = T.axis.reduce(1, 0)
ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("RRR", [i7, i8, i9_1])
n, oc_chunk, oh, ow = T.axis.remap("SSSS", [i0, i1, i2, i3])
oc_block = T.axis.spatial(16, i4_0 * 16 + i4_1)
kh, kw, ic_outer, ic_f_inner = T.axis.remap("RRRR", [i5, i6, i7, i8])
ic_s_inner = T.axis.reduce(4, i9_0 * 4 + i9_1)
T.reads(
placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
Expand Down Expand Up @@ -165,14 +163,10 @@ def test_tile_with_tensor_intrin_dense_vnni():
def test_tile_with_tensor_intrin_conv2d_nchwc_vnni():
s = Schedule(Conv2dNCHWcVNNIModule)
block = s.get_block("conv2d_NCHWc_int8")

tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN)

tiled_loops = s.get_loops(block)

assert len(tiled_loops) == 12
assert s.get(tiled_loop) == s.get(tiled_loops[-2])

tvm.ir.assert_structural_equal(s.mod, Conv2dNCHWcVNNIModuleTiled)


Expand Down