diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index b68262f09f485..ee401cca8f552 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -707,6 +707,19 @@ class IntegerRelation { /// this for uniformity with `applyDomain`. void applyRange(const IntegerRelation &rel); + /// Let the relation `this` be R1, and the relation `rel` be R2. Requires + /// R1 and R2 to have the same domain. + /// + /// Let R3 be the rangeProduct of R1 and R2. Then x R3 (y, z) iff + /// (x R1 y and x R2 z). + /// + /// Example: + /// + /// R1: (i, j) -> k : f(i, j, k) = 0 + /// R2: (i, j) -> l : g(i, j, l) = 0 + /// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0 + IntegerRelation rangeProduct(const IntegerRelation &rel); + /// Given a relation `other: (A -> B)`, this operation merges the symbol and /// local variables and then takes the composition of `other` on `this: (B -> /// C)`. The resulting relation represents tuples of the form: `A -> C`. diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 17e48e0d069b7..5c4d4d13580a0 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -2481,6 +2481,44 @@ void IntegerRelation::applyDomain(const IntegerRelation &rel) { void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); } +IntegerRelation IntegerRelation::rangeProduct(const IntegerRelation &rel) { + /// R1: (i, j) -> k : f(i, j, k) = 0 + /// R2: (i, j) -> l : g(i, j, l) = 0 + /// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0 + assert(getNumDomainVars() == rel.getNumDomainVars() && + "Range product is only defined for relations with equal domains"); + + // explicit copy of `this` + IntegerRelation result = *this; + unsigned relRangeVarStart = rel.getVarKindOffset(VarKind::Range); + unsigned numThisRangeVars = getNumRangeVars(); + unsigned numNewSymbolVars = result.getNumSymbolVars() - getNumSymbolVars(); + + result.appendVar(VarKind::Range, rel.getNumRangeVars()); + + // Copy each equality from `rel` and update the copy to account for range + // variables from `this`. The `rel` equality is a list of coefficients of the + // variables from `rel`, and so the range variables need to be shifted right + // by the number of `this` range variables and symbols. + for (unsigned i = 0; i < rel.getNumEqualities(); ++i) { + SmallVector copy = + SmallVector(rel.getEquality(i)); + copy.insert(copy.begin() + relRangeVarStart, + numThisRangeVars + numNewSymbolVars, DynamicAPInt(0)); + result.addEquality(copy); + } + + for (unsigned i = 0; i < rel.getNumInequalities(); ++i) { + SmallVector copy = + SmallVector(rel.getInequality(i)); + copy.insert(copy.begin() + relRangeVarStart, + numThisRangeVars + numNewSymbolVars, DynamicAPInt(0)); + result.addInequality(copy); + } + + return result; +} + void IntegerRelation::printSpace(raw_ostream &os) const { space.print(os); os << getNumConstraints() << " constraints\n"; diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp index 7df500bc9568a..dd0b09f7f05d2 100644 --- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp @@ -608,3 +608,97 @@ TEST(IntegerRelationTest, convertVarKindToLocal) { EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3])); EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4])); } + +TEST(IntegerRelationTest, rangeProduct) { + IntegerRelation r1 = parseRelationFromSet( + "(i, j, k) : (2*i + 3*k == 0, i >= 0, j >= 0, k >= 0)", 2); + IntegerRelation r2 = parseRelationFromSet( + "(i, j, l) : (4*i + 6*j + 9*l == 0, i >= 0, j >= 0, l >= 0)", 2); + + IntegerRelation rangeProd = r1.rangeProduct(r2); + IntegerRelation expected = + parseRelationFromSet("(i, j, k, l) : (2*i + 3*k == 0, 4*i + 6*j + 9*l == " + "0, i >= 0, j >= 0, k >= 0, l >= 0)", + 2); + + EXPECT_TRUE(expected.isEqual(rangeProd)); +} + +TEST(IntegerRelationTest, rangeProductMultdimRange) { + IntegerRelation r1 = + parseRelationFromSet("(i, k) : (2*i + 3*k == 0, i >= 0, k >= 0)", 1); + IntegerRelation r2 = parseRelationFromSet( + "(i, l, m) : (4*i + 6*m + 9*l == 0, i >= 0, l >= 0, m >= 0)", 1); + + IntegerRelation rangeProd = r1.rangeProduct(r2); + IntegerRelation expected = + parseRelationFromSet("(i, k, l, m) : (2*i + 3*k == 0, 4*i + 6*m + 9*l == " + "0, i >= 0, k >= 0, l >= 0, m >= 0)", + 1); + + EXPECT_TRUE(expected.isEqual(rangeProd)); +} + +TEST(IntegerRelationTest, rangeProductMultdimRangeSwapped) { + IntegerRelation r1 = parseRelationFromSet( + "(i, l, m) : (4*i + 6*m + 9*l == 0, i >= 0, l >= 0, m >= 0)", 1); + IntegerRelation r2 = + parseRelationFromSet("(i, k) : (2*i + 3*k == 0, i >= 0, k >= 0)", 1); + + IntegerRelation rangeProd = r1.rangeProduct(r2); + IntegerRelation expected = + parseRelationFromSet("(i, l, m, k) : (2*i + 3*k == 0, 4*i + 6*m + 9*l == " + "0, i >= 0, k >= 0, l >= 0, m >= 0)", + 1); + + EXPECT_TRUE(expected.isEqual(rangeProd)); +} + +TEST(IntegerRelationTest, rangeProductEmptyDomain) { + IntegerRelation r1 = + parseRelationFromSet("(i, j) : (4*i + 9*j == 0, i >= 0, j >= 0)", 0); + IntegerRelation r2 = + parseRelationFromSet("(k, l) : (2*k + 3*l == 0, k >= 0, l >= 0)", 0); + IntegerRelation rangeProd = r1.rangeProduct(r2); + IntegerRelation expected = + parseRelationFromSet("(i, j, k, l) : (2*k + 3*l == 0, 4*i + 9*j == " + "0, i >= 0, j >= 0, k >= 0, l >= 0)", + 0); + EXPECT_TRUE(expected.isEqual(rangeProd)); +} + +TEST(IntegerRelationTest, rangeProductEmptyRange) { + IntegerRelation r1 = + parseRelationFromSet("(i, j) : (4*i + 9*j == 0, i >= 0, j >= 0)", 2); + IntegerRelation r2 = + parseRelationFromSet("(i, j) : (2*i + 3*j == 0, i >= 0, j >= 0)", 2); + IntegerRelation rangeProd = r1.rangeProduct(r2); + IntegerRelation expected = + parseRelationFromSet("(i, j) : (2*i + 3*j == 0, 4*i + 9*j == " + "0, i >= 0, j >= 0)", + 2); + EXPECT_TRUE(expected.isEqual(rangeProd)); +} + +TEST(IntegerRelationTest, rangeProductEmptyDomainAndRange) { + IntegerRelation r1 = parseRelationFromSet("() : ()", 0); + IntegerRelation r2 = parseRelationFromSet("() : ()", 0); + IntegerRelation rangeProd = r1.rangeProduct(r2); + IntegerRelation expected = parseRelationFromSet("() : ()", 0); + EXPECT_TRUE(expected.isEqual(rangeProd)); +} + +TEST(IntegerRelationTest, rangeProductSymbols) { + IntegerRelation r1 = parseRelationFromSet( + "(i, j)[s] : (2*i + 3*j + s == 0, i >= 0, j >= 0)", 1); + IntegerRelation r2 = parseRelationFromSet( + "(i, l)[s] : (3*i + 4*l + s == 0, i >= 0, l >= 0)", 1); + + IntegerRelation rangeProd = r1.rangeProduct(r2); + IntegerRelation expected = parseRelationFromSet( + "(i, j, l)[s] : (2*i + 3*j + s == 0, 3*i + 4*l + s == " + "0, i >= 0, j >= 0, l >= 0)", + 1); + + EXPECT_TRUE(expected.isEqual(rangeProd)); +}