diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index f455695438ec..ac6bf94b1198 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -67,7 +67,18 @@ struct ModularSetAnalyzer::Entry { Entry() = default; Entry(int64_t coeff, int64_t base) { - ICHECK_GE(coeff, 0); + if (coeff < 0) { + // `analyzer->canonical_simplify()` can generate expressions with + // negative coefficients (e.g. simplifying `floormod(-i, 2)` + // into `floormod(i, -2) * -1`). When this happens, the + // ModularSet may enter a constraint based on this expression. + // + // Handling a negative coeff uses the same sign convention as + // canonical_simplify, requiring that + // `floormod(var, coeff) == -floormod(var, -coeff)`. + coeff *= -1; + base *= -1; + } this->coeff = coeff; if (coeff != 0) { base = base % coeff; diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 81a163d0d431..9f187685991e 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -97,6 +97,8 @@ def test_split_index_simplify(): # cannot simplify mixed case, unless we canonicalize into one mode. ck.verify(tdiv(x, 6) * 2 + tmod(fld(x, 3), 2), tdiv(x, 6) * 2 + tmod(fld(x, 3), 2)) + ck.verify(tmod(-x, 2), tmod(x, -2) * -1) + def test_div_simplify(): ck = CanonicalChecker() @@ -129,6 +131,8 @@ def test_floormod_simplify(): ck.verify(flm(flm((x * 4) + y - 466036, 24528) - 24512, 16), flm((x * 4) + y + 12, 16)) ck.verify(flm(flm((x * 4), 16), 8), flm(x, 2) * 4) + ck.verify(flm(-x, 2), flm(x, -2) * -1) + def test_canonical_mixed(): ck = CanonicalChecker() diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py index 2eb9c3546ee5..46b6858ec773 100644 --- a/tests/python/unittest/test_tir_transform_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -816,5 +816,34 @@ def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32): A[0] = (i != 30) or (j == 0) +class TestConditionalFloorMod(BaseBeforeAfter): + """A regression test for negative floormod denominator + + Previously, simplifying this function could throw an error. First, the + `canonical_simplify` would rewrite `floormod(0-i,2)` to the equivalent + `floormod(i,-2)`. Then, the rewrite_simplifier would enter a + constrained context in which `floormod(i,-2)==1`. Passing this + expression to `ModularSet::EnterConstraint`, which previously did not + support a negative value for the second argument, threw an error. + + The analogous failure mode never occurred for `truncmod`, because + `truncmod(0-i,2)` would be canonicalized to `truncmod(i, -2) * -1`, and + the pattern matching in `ModularSet` didn't recognize the constant + factor. + + This failure mode was resolved by supporting negative arguments in + `ModularSet`, using the same sign convention as is used by + `canonical_simplify`. + """ + + def before(A: T.Buffer[1, "bool"], i: T.int32): + if T.floormod(0 - i, 2) == 0: + A[0] = T.floormod(i, 2) == 0 + + def expected(A: T.Buffer[1, "bool"], i: T.int32): + if T.floormod(i, -2) == 0: + A[0] = True + + if __name__ == "__main__": tvm.testing.main()