[mlir][SPIR-V] Fix math.powf lowering for non-integer exponents #197727

Merged: aobolensk merged 6 commits into llvm:main from aobolensk:mlir-spirv-fix-math-powf-non-inf on May 17, 2026

@aobolensk

Contributor

The ConvertFToS-based sign handling only works when y is an integer. Use it only for integer-valued constant exponents; lower all other cases as GL.Exp(y * GL.Log(x)).
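
To make the two paths concrete, here is a minimal scalar sketch in plain C++ (not MLIR). The names powViaExpLog and powViaParity are hypothetical, and the <cmath> calls merely stand in for the GL.Log, GL.Exp, GL.FAbs, and GL.Pow ops; the actual lowering emits SPIR-V ops rather than calling these functions.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical scalar model of the fallback path: exp(y * log(x)).
// For x < 0, std::log returns NaN, so the whole expression is NaN.
inline float powViaExpLog(float x, float y) {
  return std::exp(y * std::log(x));
}

// Hypothetical scalar model of the parity path. It is only meaningful for
// an integer-valued y that fits in 32 bits (an assumption of this sketch).
inline float powViaParity(float x, float y) {
  assert(y == std::trunc(y) && "parity path requires an integer-valued y");
  auto yInt = static_cast<std::int32_t>(y);     // stands in for ConvertFToS
  bool isOdd = (yInt & 1) == 1;                 // BitwiseAnd + IEqual
  float magnitude = std::pow(std::fabs(x), y);  // GL.FAbs + GL.Pow
  // Negate only when the base is negative and the exponent is odd.
  return (x < 0.0f && isOdd) ? -magnitude : magnitude;
}
```

Under this model, powViaParity(-2.0f, 3.0f) is approximately -8, while powViaExpLog(-2.0f, 2.5f) is NaN because the logarithm of a negative number is NaN.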

llvmorg-github-actions Bot commented May 14, 2026

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Arseniy Obolenskiy (aobolensk)

Changes

The ConvertFToS usage only works when y is an integer. Use it only for integer constants, for others: lower as GL.Exp(y * GL.Log(x))


Full diff: https://github.com/llvm/llvm-project/pull/197727.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+29-49)
  • (modified) mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir (+48-17)
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 01285c6c0ec09..c973b2b927f9c 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/STLExtras.h"
@@ -360,62 +361,41 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     if (!dstType)
       return failure();
 
-    // Get the scalar float type.
-    FloatType scalarFloatType;
-    if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
-      scalarFloatType = scalarType;
-    } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
-      scalarFloatType = cast<FloatType>(vectorType.getElementType());
-    } else {
-      return failure();
-    }
-
-    // Get int type of the same shape as the float type.
-    Type scalarIntType = rewriter.getIntegerType(32);
-    Type intType = scalarIntType;
+    Location loc = powfOp.getLoc();
     auto operandType = adaptor.getRhs().getType();
-    if (auto vectorType = dyn_cast<VectorType>(operandType)) {
-      auto shape = vectorType.getShape();
-      intType = VectorType::get(shape, scalarIntType);
+
+    // ConvertFToS-based parity needs an integer-valued exponent. Otherwise
+    // fall back to exp(y*log(x)), which yields NaN for x<0 (matches C).
+    auto isIntegerValuedConstant = [](Value v) -> bool {
+      Attribute attr;
+      if (!matchPattern(v, m_Constant(&attr)))
+        return false;
+      if (auto fAttr = dyn_cast<FloatAttr>(attr))
+        return fAttr.getValue().isInteger();
+      if (auto dense = dyn_cast<DenseFPElementsAttr>(attr))
+        return llvm::all_of(dense.getValues<APFloat>(),
+                            [](const APFloat &v) { return v.isInteger(); });
+      return false;
+    };
+
+    if (!isIntegerValuedConstant(adaptor.getRhs())) {
+      Value log = spirv::GLLogOp::create(rewriter, loc, adaptor.getLhs());
+      Value mul = spirv::FMulOp::create(rewriter, loc, adaptor.getRhs(), log);
+      rewriter.replaceOpWithNewOp<spirv::GLExpOp>(powfOp, mul);
+      return success();
     }
 
-    // Per GL Pow extended instruction spec:
-    // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
-    Location loc = powfOp.getLoc();
+    // GL.Pow is undefined for x < 0; take abs and conditionally negate the
+    // result when the exponent is odd.
+    Type intType = rewriter.getIntegerType(32);
+    if (auto vectorType = dyn_cast<VectorType>(operandType))
+      intType = VectorType::get(vectorType.getShape(), intType);
+
     Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
     Value lessThan =
         spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
+    Value abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getLhs());
 
-    // Per C/C++ spec:
-    // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
-    // > finite and negative and exponent is finite and non-integer.
-    // Calculate the reminder from the exponent and check whether it is zero.
-    Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
-    Value expRem =
-        spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne);
-    Value expRemNonZero =
-        spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero);
-    Value cmpNegativeWithFractionalExp =
-        spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan);
-    // Create NaN result and replace base value if conditions are met.
-    const auto &floatSemantics = scalarFloatType.getFloatSemantics();
-    const auto nan = APFloat::getNaN(floatSemantics);
-    Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
-    if (auto vectorType = dyn_cast<VectorType>(operandType))
-      nanAttr = DenseElementsAttr::get(vectorType, nan);
-
-    Value nanValue =
-        spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
-    Value lhs =
-        spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
-                                nanValue, adaptor.getLhs());
-    Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
-
-    // TODO: The following just forcefully casts y into an integer value in
-    // order to properly propagate the sign, assuming integer y cases. It
-    // doesn't cover other cases and should be fixed.
-
-    // Cast exponent to integer and calculate exponent % 2 != 0.
     Value intRhs =
         spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs());
     Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index 8eb533eeff2a9..e3fce6fa40dfd 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -183,45 +183,76 @@ func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
   return %0 : vector<2xi32>
 }
 
+// Dynamic exponent: exp(y * log(x)); yields NaN for x<0.
 // CHECK-LABEL: @powf_scalar
 //  CHECK-SAME: (%[[LHS:.+]]: f32, %[[RHS:.+]]: f32)
 func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
+  // CHECK: %[[LOG:.+]] = spirv.GL.Log %[[LHS]] : f32
+  // CHECK: %[[MUL:.+]] = spirv.FMul %[[RHS]], %[[LOG]] : f32
+  // CHECK: %[[EXP:.+]] = spirv.GL.Exp %[[MUL]] : f32
+  %0 = math.powf %lhs, %rhs : f32
+  // CHECK: return %[[EXP]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @powf_vector
+func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: spirv.GL.Log %{{.*}} : vector<4xf32>
+  // CHECK: spirv.FMul %{{.*}} : vector<4xf32>
+  // CHECK: spirv.GL.Exp %{{.*}} : vector<4xf32>
+  %0 = math.powf %lhs, %rhs : vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// Constant integer exponent: parity-based path preserves sign (pow(-2,3)=-8).
+// CHECK-LABEL: @powf_const_int_exp
+//  CHECK-SAME: (%[[LHS:.+]]: f32)
+func.func @powf_const_int_exp(%lhs: f32) -> f32 {
+  // CHECK: %[[RHS:.+]] = arith.constant 3.000000e+00 : f32
   // CHECK: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f32
   // CHECK: %[[LT:.+]] = spirv.FOrdLessThan %[[LHS]], %[[F0]] : f32
-  // CHECK: %[[F1:.+]] = spirv.Constant 1.000000e+00 : f32
-  // CHECK: %[[REM:.+]] = spirv.FRem %[[RHS]], %[[F1]] : f32
-  // CHECK: %[[IS_FRACTION:.+]] = spirv.FOrdNotEqual %[[REM]], %[[F0]] : f32
-  // CHECK: %[[AND:.+]] = spirv.LogicalAnd %[[IS_FRACTION]], %[[LT]] : i1
-  // CHECK: %[[NAN:.+]] = spirv.Constant 0x7FC00000 : f32
-  // CHECK: %[[NEW_LHS:.+]] = spirv.Select %[[AND]], %[[NAN]], %[[LHS]] : i1, f32
-  // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[NEW_LHS]] : f32
-  // CHECK: %[[IRHS:.+]] = spirv.ConvertFToS
+  // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[LHS]] : f32
+  // CHECK: %[[IRHS:.+]] = spirv.ConvertFToS %[[RHS]] : f32 to i32
   // CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
-  // CHECK: %[[REM:.+]] = spirv.BitwiseAnd %[[IRHS]]
+  // CHECK: %[[REM:.+]] = spirv.BitwiseAnd %[[IRHS]], %[[CST1]] : i32
   // CHECK: %[[ODD:.+]] = spirv.IEqual %[[REM]], %[[CST1]] : i32
   // CHECK: %[[POW:.+]] = spirv.GL.Pow %[[ABS]], %[[RHS]] : f32
   // CHECK: %[[NEG:.+]] = spirv.FNegate %[[POW]] : f32
   // CHECK: %[[SNEG:.+]] = spirv.LogicalAnd %[[LT]], %[[ODD]] : i1
   // CHECK: %[[SEL:.+]] = spirv.Select %[[SNEG]], %[[NEG]], %[[POW]] : i1, f32
-  %0 = math.powf %lhs, %rhs : f32
+  %c = arith.constant 3.0 : f32
+  %0 = math.powf %lhs, %c : f32
   // CHECK: return %[[SEL]]
   return %0: f32
 }
 
-// CHECK-LABEL: @powf_vector
-func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> {
+// Constant non-integer exponent: falls into the dynamic exp(y*log(x)) path.
+// CHECK-LABEL: @powf_const_frac_exp
+//  CHECK-SAME: (%[[LHS:.+]]: f32)
+func.func @powf_const_frac_exp(%lhs: f32) -> f32 {
+  // CHECK: %[[RHS:.+]] = arith.constant 2.500000e+00 : f32
+  // CHECK: %[[LOG:.+]] = spirv.GL.Log %[[LHS]] : f32
+  // CHECK: %[[MUL:.+]] = spirv.FMul %[[RHS]], %[[LOG]] : f32
+  // CHECK: %[[EXP:.+]] = spirv.GL.Exp %[[MUL]] : f32
+  %c = arith.constant 2.5 : f32
+  %0 = math.powf %lhs, %c : f32
+  // CHECK: return %[[EXP]]
+  return %0: f32
+}
+
+// Splat constant integer-valued vector exponent: parity-based path.
+// CHECK-LABEL: @powf_const_int_exp_vector
+func.func @powf_const_int_exp_vector(%lhs: vector<4xf32>) -> vector<4xf32> {
   // CHECK: spirv.FOrdLessThan
-  // CHECK: spirv.FRem
-  // CHECK: spirv.FOrdNotEqual
-  // CHECK: spirv.LogicalAnd
-  // CHECK: spirv.Select
   // CHECK: spirv.GL.FAbs
+  // CHECK: spirv.ConvertFToS %{{.*}} : vector<4xf32> to vector<4xi32>
   // CHECK: spirv.BitwiseAnd %{{.*}} : vector<4xi32>
   // CHECK: spirv.IEqual %{{.*}} : vector<4xi32>
   // CHECK: spirv.GL.Pow %{{.*}}: vector<4xf32>
   // CHECK: spirv.FNegate
   // CHECK: spirv.Select
-  %0 = math.powf %lhs, %rhs : vector<4xf32>
+  %c = arith.constant dense<3.0> : vector<4xf32>
+  %0 = math.powf %lhs, %c : vector<4xf32>
   return %0: vector<4xf32>
 }
 

Comment on lines +373 to +377
if (auto fAttr = dyn_cast<FloatAttr>(attr))
return fAttr.getValue().isInteger();
if (auto dense = dyn_cast<DenseFPElementsAttr>(attr))
return llvm::all_of(dense.getValues<APFloat>(),
[](const APFloat &v) { return v.isInteger(); });

Member

Turn this into a type switch and handle SplatElementsAttr separately, so that we don't repeatedly check the same values.

Contributor Author

Done

Comment thread mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp Outdated
// doesn't cover other cases and should be fixed.

// Cast exponent to integer and calculate exponent % 2 != 0.
Value intRhs =

Contributor

The rhs is guaranteed to be a constant here, because we would hit if (!isIntegerValuedConstant(adaptor.getRhs())) { otherwise, right? Based on that can we simplify the code here?

  1. Not bother with ConvertFToSOp and just materialize an integer constant of a right type directly?
  2. Remove the whole shouldNegate logic and just emit negate based on the sign of the constant?
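
The idea behind these two points can be sketched outside MLIR: when the exponent is a known constant, its parity can be computed once at lowering time instead of by emitted ops. The following is a hypothetical plain C++ model, not the code in this PR. SignFixup and planSignFixup are illustrative names, and reading point 2 as depending on the parity of the constant is an assumption based on the oddMask snippets quoted later in the thread.

```cpp
#include <cmath>

// Hypothetical outcome of inspecting a constant exponent at lowering time.
enum class SignFixup {
  NotIntegerValued,       // use the exp(y * log(x)) fallback instead
  KeepSign,               // even exponent: pow(|x|, y) is already the result
  NegateWhenBaseNegative  // odd exponent: negate pow(|x|, y) when x < 0
};

// Classifies a constant exponent without emitting any runtime conversion.
inline SignFixup planSignFixup(double y) {
  // Non-finite or fractional exponents cannot use the parity path.
  if (!std::isfinite(y) || y != std::trunc(y))
    return SignFixup::NotIntegerValued;
  // fmod is exact, so this parity test also holds for large magnitudes.
  bool isOdd = std::fmod(std::fabs(y), 2.0) == 1.0;
  return isOdd ? SignFixup::NegateWhenBaseNegative : SignFixup::KeepSign;
}
```

With the parity decided up front, only the comparison against zero and the conditional negate would remain as runtime work in the odd case.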

Contributor

Actually, why do we disallow non-constant integers here?

Contributor Author

We cannot know whether this number is an integer unless it is a constant. The current implementation on the main branch is correct for integers but incorrect for numbers with a fractional part. If you have any suggestions on how to tell that the number is an integer without requiring it to be constant, I can consider that as well.

Contributor

That makes sense. I missed the fact that the value type is float. In that case, are my suggestions for simplifying the lowering sensible?

Contributor Author

Yes, they are applicable. Done

github-actions Bot commented May 15, 2026

✅ With the latest revision this PR passed the C/C++ code formatter.

aobolensk force-pushed the mlir-spirv-fix-math-powf-non-inf branch from dee55c4 to 7e71a06 on May 15, 2026 14:13
} else {
return failure();
Location loc = powfOp.getLoc();
auto operandType = adaptor.getRhs().getType();

Contributor

Suggested change
-auto operandType = adaptor.getRhs().getType();
+Type operandType = adaptor.getRhs().getType();

Contributor Author

Done

spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd);

Value shouldNegate;
if (llvm::all_of(oddMask, [](bool b) { return b; })) {

Contributor

llvm::all_equal
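
For readers unfamiliar with the helper: llvm::all_equal (from llvm/ADT/STLExtras.h) checks that every element of a range has the same value, whereas llvm::all_of with an identity predicate only checks that every element is true. A rough standard-library equivalent is sketched below. allTrue and allEqual are hypothetical names, and how the final patch uses the result is not shown in this thread.

```cpp
#include <algorithm>
#include <vector>

// Mirrors llvm::all_of(mask, [](bool b) { return b; }): true only when
// every lane is odd.
inline bool allTrue(const std::vector<bool> &mask) {
  return std::all_of(mask.begin(), mask.end(), [](bool b) { return b; });
}

// Rough analogue of llvm::all_equal: true when all lanes share one value,
// whether that value is odd (true) or even (false).
inline bool allEqual(const std::vector<bool> &mask) {
  if (mask.empty())
    return true;
  const bool first = mask.front();
  return std::all_of(mask.begin(), mask.end(),
                     [first](bool b) { return b == first; });
}
```

For an all-even mask such as {false, false}, allTrue is false but allEqual is true, so a uniform mask is detected in both the all-odd and the all-even case.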

Contributor Author

Done

SmallVector<bool> oddMask;
bool isIntegerValued = false;
if (matchPattern(adaptor.getRhs(), m_Constant(&rhsAttr))) {
isIntegerValued = TypeSwitch<Attribute, bool>(rhsAttr)

Contributor

Can we make the formatting less ugly somehow? If absolutely necessary you can wrap it in // clang-format off and // clang-format on.

Contributor

Maybe make it back into lambda with oddMask captured?

Contributor Author

Maybe make it back into lambda with oddMask captured?

That helps, thanks

};

SmallVector<bool> oddMask;
auto isIntegerValuedConstant = [&](Value v) -> bool {

Contributor

Thinking about it, I think we could simplify this a lot:

  1. Remove the lambda completely
  2. Have a TypeSwitch only populate the oddMask
  3. Instead of checking !isIntegerValuedConstant(adaptor.getRhs()), just check oddMask.empty().

What do you think?
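
A minimal model of this proposal, in plain C++ rather than MLIR: a single helper fills oddMask with one parity bit per lane when every lane of the constant exponent is integer-valued, and leaves it empty otherwise, so oddMask.empty() doubles as the fallback check. computeOddMask and the use of std::vector<double> for the constant's lanes are assumptions for illustration; the patch itself inspects MLIR attributes via TypeSwitch. An empty input models a non-constant exponent.

```cpp
#include <cmath>
#include <vector>

// Hypothetical helper: returns per-lane parity (true = odd) when every lane
// is a finite integer value, or an empty mask to request the
// exp(y * log(x)) fallback.
inline std::vector<bool> computeOddMask(const std::vector<double> &lanes) {
  std::vector<bool> oddMask;
  oddMask.reserve(lanes.size());
  for (double y : lanes) {
    // Any fractional or non-finite lane disqualifies the whole constant.
    if (!std::isfinite(y) || y != std::trunc(y))
      return {};
    oddMask.push_back(std::fmod(std::fabs(y), 2.0) == 1.0);
  }
  return oddMask;
}
```

A caller would then branch on computeOddMask(lanes).empty() to choose between the fallback and the parity-based path, with no separate integer-valued predicate.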

Contributor Author

That makes sense. Applied

kuhar left a comment

Member

LGTM but also give @IgWod a chance to review before landing

IgWod left a comment

Contributor

Thanks!

aobolensk merged commit 3a2876d into llvm:main on May 17, 2026
10 checks passed
pedroMVicente pushed a commit to pedroMVicente/llvm-project that referenced this pull request May 19, 2026
…#197727)

The ConvertFToS usage only works when y is an integer. Use it only for
integer constants, for others: lower as GL.Exp(y * GL.Log(x))