[mlir][SPIR-V] Fix math.powf lowering for non-integer exponents#197727
Conversation
The ConvertFToS usage only works when y is an integer. Use it only for integer constants, for others: lower as GL.Exp(y * GL.Log(x)).
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Arseniy Obolenskiy (aobolensk) ChangesThe ConvertFToS usage only works when y is an integer. Use it only for integer constants, for others: lower as GL.Exp(y * GL.Log(x)) Full diff: https://github.com/llvm/llvm-project/pull/197727.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 01285c6c0ec09..c973b2b927f9c 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
@@ -360,62 +361,41 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
if (!dstType)
return failure();
- // Get the scalar float type.
- FloatType scalarFloatType;
- if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
- scalarFloatType = scalarType;
- } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
- scalarFloatType = cast<FloatType>(vectorType.getElementType());
- } else {
- return failure();
- }
-
- // Get int type of the same shape as the float type.
- Type scalarIntType = rewriter.getIntegerType(32);
- Type intType = scalarIntType;
+ Location loc = powfOp.getLoc();
auto operandType = adaptor.getRhs().getType();
- if (auto vectorType = dyn_cast<VectorType>(operandType)) {
- auto shape = vectorType.getShape();
- intType = VectorType::get(shape, scalarIntType);
+
+ // ConvertFToS-based parity needs an integer-valued exponent. Otherwise
+ // fall back to exp(y*log(x)), which yields NaN for x<0 (matches C).
+ auto isIntegerValuedConstant = [](Value v) -> bool {
+ Attribute attr;
+ if (!matchPattern(v, m_Constant(&attr)))
+ return false;
+ if (auto fAttr = dyn_cast<FloatAttr>(attr))
+ return fAttr.getValue().isInteger();
+ if (auto dense = dyn_cast<DenseFPElementsAttr>(attr))
+ return llvm::all_of(dense.getValues<APFloat>(),
+ [](const APFloat &v) { return v.isInteger(); });
+ return false;
+ };
+
+ if (!isIntegerValuedConstant(adaptor.getRhs())) {
+ Value log = spirv::GLLogOp::create(rewriter, loc, adaptor.getLhs());
+ Value mul = spirv::FMulOp::create(rewriter, loc, adaptor.getRhs(), log);
+ rewriter.replaceOpWithNewOp<spirv::GLExpOp>(powfOp, mul);
+ return success();
}
- // Per GL Pow extended instruction spec:
- // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
- Location loc = powfOp.getLoc();
+ // GL.Pow is undefined for x < 0; take abs and conditionally negate the
+ // result when the exponent is odd.
+ Type intType = rewriter.getIntegerType(32);
+ if (auto vectorType = dyn_cast<VectorType>(operandType))
+ intType = VectorType::get(vectorType.getShape(), intType);
+
Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
Value lessThan =
spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
+ Value abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getLhs());
- // Per C/C++ spec:
- // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
- // > finite and negative and exponent is finite and non-integer.
- // Calculate the reminder from the exponent and check whether it is zero.
- Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
- Value expRem =
- spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne);
- Value expRemNonZero =
- spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero);
- Value cmpNegativeWithFractionalExp =
- spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan);
- // Create NaN result and replace base value if conditions are met.
- const auto &floatSemantics = scalarFloatType.getFloatSemantics();
- const auto nan = APFloat::getNaN(floatSemantics);
- Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
- if (auto vectorType = dyn_cast<VectorType>(operandType))
- nanAttr = DenseElementsAttr::get(vectorType, nan);
-
- Value nanValue =
- spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
- Value lhs =
- spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
- nanValue, adaptor.getLhs());
- Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
-
- // TODO: The following just forcefully casts y into an integer value in
- // order to properly propagate the sign, assuming integer y cases. It
- // doesn't cover other cases and should be fixed.
-
- // Cast exponent to integer and calculate exponent % 2 != 0.
Value intRhs =
spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs());
Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index 8eb533eeff2a9..e3fce6fa40dfd 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -183,45 +183,76 @@ func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
return %0 : vector<2xi32>
}
+// Dynamic exponent: exp(y * log(x)); yields NaN for x<0.
// CHECK-LABEL: @powf_scalar
// CHECK-SAME: (%[[LHS:.+]]: f32, %[[RHS:.+]]: f32)
func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
+ // CHECK: %[[LOG:.+]] = spirv.GL.Log %[[LHS]] : f32
+ // CHECK: %[[MUL:.+]] = spirv.FMul %[[RHS]], %[[LOG]] : f32
+ // CHECK: %[[EXP:.+]] = spirv.GL.Exp %[[MUL]] : f32
+ %0 = math.powf %lhs, %rhs : f32
+ // CHECK: return %[[EXP]]
+ return %0: f32
+}
+
+// CHECK-LABEL: @powf_vector
+func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> {
+ // CHECK: spirv.GL.Log %{{.*}} : vector<4xf32>
+ // CHECK: spirv.FMul %{{.*}} : vector<4xf32>
+ // CHECK: spirv.GL.Exp %{{.*}} : vector<4xf32>
+ %0 = math.powf %lhs, %rhs : vector<4xf32>
+ return %0: vector<4xf32>
+}
+
+// Constant integer exponent: parity-based path preserves sign (pow(-2,3)=-8).
+// CHECK-LABEL: @powf_const_int_exp
+// CHECK-SAME: (%[[LHS:.+]]: f32)
+func.func @powf_const_int_exp(%lhs: f32) -> f32 {
+ // CHECK: %[[RHS:.+]] = arith.constant 3.000000e+00 : f32
// CHECK: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f32
// CHECK: %[[LT:.+]] = spirv.FOrdLessThan %[[LHS]], %[[F0]] : f32
- // CHECK: %[[F1:.+]] = spirv.Constant 1.000000e+00 : f32
- // CHECK: %[[REM:.+]] = spirv.FRem %[[RHS]], %[[F1]] : f32
- // CHECK: %[[IS_FRACTION:.+]] = spirv.FOrdNotEqual %[[REM]], %[[F0]] : f32
- // CHECK: %[[AND:.+]] = spirv.LogicalAnd %[[IS_FRACTION]], %[[LT]] : i1
- // CHECK: %[[NAN:.+]] = spirv.Constant 0x7FC00000 : f32
- // CHECK: %[[NEW_LHS:.+]] = spirv.Select %[[AND]], %[[NAN]], %[[LHS]] : i1, f32
- // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[NEW_LHS]] : f32
- // CHECK: %[[IRHS:.+]] = spirv.ConvertFToS
+ // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[LHS]] : f32
+ // CHECK: %[[IRHS:.+]] = spirv.ConvertFToS %[[RHS]] : f32 to i32
// CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
- // CHECK: %[[REM:.+]] = spirv.BitwiseAnd %[[IRHS]]
+ // CHECK: %[[REM:.+]] = spirv.BitwiseAnd %[[IRHS]], %[[CST1]] : i32
// CHECK: %[[ODD:.+]] = spirv.IEqual %[[REM]], %[[CST1]] : i32
// CHECK: %[[POW:.+]] = spirv.GL.Pow %[[ABS]], %[[RHS]] : f32
// CHECK: %[[NEG:.+]] = spirv.FNegate %[[POW]] : f32
// CHECK: %[[SNEG:.+]] = spirv.LogicalAnd %[[LT]], %[[ODD]] : i1
// CHECK: %[[SEL:.+]] = spirv.Select %[[SNEG]], %[[NEG]], %[[POW]] : i1, f32
- %0 = math.powf %lhs, %rhs : f32
+ %c = arith.constant 3.0 : f32
+ %0 = math.powf %lhs, %c : f32
// CHECK: return %[[SEL]]
return %0: f32
}
-// CHECK-LABEL: @powf_vector
-func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> {
+// Constant non-integer exponent: falls into the dynamic exp(y*log(x)) path.
+// CHECK-LABEL: @powf_const_frac_exp
+// CHECK-SAME: (%[[LHS:.+]]: f32)
+func.func @powf_const_frac_exp(%lhs: f32) -> f32 {
+ // CHECK: %[[RHS:.+]] = arith.constant 2.500000e+00 : f32
+ // CHECK: %[[LOG:.+]] = spirv.GL.Log %[[LHS]] : f32
+ // CHECK: %[[MUL:.+]] = spirv.FMul %[[RHS]], %[[LOG]] : f32
+ // CHECK: %[[EXP:.+]] = spirv.GL.Exp %[[MUL]] : f32
+ %c = arith.constant 2.5 : f32
+ %0 = math.powf %lhs, %c : f32
+ // CHECK: return %[[EXP]]
+ return %0: f32
+}
+
+// Splat constant integer-valued vector exponent: parity-based path.
+// CHECK-LABEL: @powf_const_int_exp_vector
+func.func @powf_const_int_exp_vector(%lhs: vector<4xf32>) -> vector<4xf32> {
// CHECK: spirv.FOrdLessThan
- // CHECK: spirv.FRem
- // CHECK: spirv.FOrdNotEqual
- // CHECK: spirv.LogicalAnd
- // CHECK: spirv.Select
// CHECK: spirv.GL.FAbs
+ // CHECK: spirv.ConvertFToS %{{.*}} : vector<4xf32> to vector<4xi32>
// CHECK: spirv.BitwiseAnd %{{.*}} : vector<4xi32>
// CHECK: spirv.IEqual %{{.*}} : vector<4xi32>
// CHECK: spirv.GL.Pow %{{.*}}: vector<4xf32>
// CHECK: spirv.FNegate
// CHECK: spirv.Select
- %0 = math.powf %lhs, %rhs : vector<4xf32>
+ %c = arith.constant dense<3.0> : vector<4xf32>
+ %0 = math.powf %lhs, %c : vector<4xf32>
return %0: vector<4xf32>
}
|
| if (auto fAttr = dyn_cast<FloatAttr>(attr)) | ||
| return fAttr.getValue().isInteger(); | ||
| if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) | ||
| return llvm::all_of(dense.getValues<APFloat>(), | ||
| [](const APFloat &v) { return v.isInteger(); }); |
There was a problem hiding this comment.
turn this into a type switch and handle SplatElementsAttr separately, so that we don't repeatedly check the same values
| // doesn't cover other cases and should be fixed. | ||
|
|
||
| // Cast exponent to integer and calculate exponent % 2 != 0. | ||
| Value intRhs = |
There was a problem hiding this comment.
The rhs is guaranteed to be a constant here, because we would hit if (!isIntegerValuedConstant(adaptor.getRhs())) { otherwise, right? Based on that can we simplify the code here?
- Not bother with
ConvertFToSOpand just materialize an integer constant of a right type directly? - Remove the whole
shouldNegatelogic and just emit negate based on the sign of the constant?
There was a problem hiding this comment.
Actually, why do we disallow, non-constant integers in here?
There was a problem hiding this comment.
We cannot know that this number is potentially the integer, unless it is constant. There current implementation that is residing on the main branch is good for integers, but incorrect for numbers with fractional part. If you have any suggestions on how to know that the number is integer without requiring it to be constant, I can consider that as well
There was a problem hiding this comment.
That makes sense. I missed the fact the value type is float. In that case is my suggestion around simplify the lowering sensible?
There was a problem hiding this comment.
Yes, they are applicable. Done
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
dee55c4 to
7e71a06
Compare
| } else { | ||
| return failure(); | ||
| Location loc = powfOp.getLoc(); | ||
| auto operandType = adaptor.getRhs().getType(); |
There was a problem hiding this comment.
| auto operandType = adaptor.getRhs().getType(); | |
| Type operandType = adaptor.getRhs().getType(); |
| spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd); | ||
|
|
||
| Value shouldNegate; | ||
| if (llvm::all_of(oddMask, [](bool b) { return b; })) { |
| SmallVector<bool> oddMask; | ||
| bool isIntegerValued = false; | ||
| if (matchPattern(adaptor.getRhs(), m_Constant(&rhsAttr))) { | ||
| isIntegerValued = TypeSwitch<Attribute, bool>(rhsAttr) |
There was a problem hiding this comment.
Can we make the formatting less ugly somehow? If absolutely necessary you can wrap it in // clang-format off and // clang-format on.
There was a problem hiding this comment.
Maybe make it back into lambda with oddMask captured?
There was a problem hiding this comment.
Maybe make it back into lambda with
oddMaskcaptured?
That helps, thanks
| }; | ||
|
|
||
| SmallVector<bool> oddMask; | ||
| auto isIntegerValuedConstant = [&](Value v) -> bool { |
There was a problem hiding this comment.
Thinking about it, I think we could simplify this a lot:
- Remove the lambda completely
- Have a
TypeSwitchonly populate theoddMask - Instead of checking
!isIntegerValuedConstant(adaptor.getRhs())just checkoddMask.empty().
What do you think?
There was a problem hiding this comment.
That makes sense. Applied
…#197727) The ConvertFToS usage only works when y is an integer. Use it only for integer constants, for others: lower as GL.Exp(y * GL.Log(x))
The ConvertFToS usage only works when y is an integer. Use it only for integer constants, for others: lower as GL.Exp(y * GL.Log(x))