diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index bde4ba993f69e..1a319bf404838 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1143,6 +1143,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                        ISD::SIGN_EXTEND_INREG, ISD::CONCAT_VECTORS,
                        ISD::EXTRACT_SUBVECTOR, ISD::INSERT_SUBVECTOR,
                        ISD::STORE, ISD::BUILD_VECTOR});
+  setTargetDAGCombine(ISD::SMIN);
   setTargetDAGCombine(ISD::TRUNCATE);
   setTargetDAGCombine(ISD::LOAD);
 
@@ -2392,6 +2393,18 @@ static bool isIntImmediate(const SDNode *N, uint64_t &Imm) {
   return false;
 }
 
+/// Returns true for AArch64-specific two-operand vector opcodes that
+/// performConcatVectorsCombine may treat like a generic ISD binop, i.e.
+/// concat(op(a, b), op(c, d)) -> op(concat(a, c), concat(b, d)).
+static bool isVectorizedBinOp(unsigned Opcode) {
+  switch (Opcode) {
+  case AArch64ISD::SQDMULH:
+    return true;
+  default:
+    return false;
+  }
+}
+
 // isOpcWithIntImmediate - This method tests to see if the node is a specific
 // opcode and that it has a immediate integer right operand.
 // If so Imm will receive the value.
@@ -20126,8 +20139,9 @@ static SDValue performConcatVectorsCombine(SDNode *N,
   // size, combine into an binop of two contacts of the source vectors. eg:
   // concat(uhadd(a,b), uhadd(c, d)) -> uhadd(concat(a, c), concat(b, d))
   if (N->getNumOperands() == 2 && N0Opc == N1Opc && VT.is128BitVector() &&
-      DAG.getTargetLoweringInfo().isBinOp(N0Opc) && N0->hasOneUse() &&
-      N1->hasOneUse()) {
+      (DAG.getTargetLoweringInfo().isBinOp(N0Opc) ||
+       isVectorizedBinOp(N0Opc)) &&
+      N0->hasOneUse() && N1->hasOneUse()) {
     SDValue N00 = N0->getOperand(0);
     SDValue N01 = N0->getOperand(1);
     SDValue N10 = N1->getOperand(0);
@@ -20986,6 +21000,89 @@ static SDValue performBuildVectorCombine(SDNode *N,
   return SDValue();
 }
 
+// A special combine for the sqdmulh family of instructions.
+//   smin(sra(mul(sext(v0), sext(v1)), EltBits - 1), SignedMax(EltBits))
+// is sext(sqdmulh(v0, v1)) for EltBits in {16, 32}: SQDMULH returns the high
+// half of the doubled product, i.e. (v0 * v1) >> (EltBits - 1), saturated to
+// the signed maximum. Only the upper clamp is needed because the shifted
+// product can never go below the signed minimum.
+//
+// Returns the replacement value, or an empty SDValue if N does not match.
+static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
+  if (N->getOpcode() != ISD::SMIN)
+    return SDValue();
+
+  EVT DestVT = N->getValueType(0);
+  if (!DestVT.isVector() || DestVT.getScalarSizeInBits() > 64 ||
+      DestVT.isScalableVector())
+    return SDValue();
+
+  ConstantSDNode *Clamp = isConstOrConstSplat(N->getOperand(1));
+  if (!Clamp)
+    return SDValue();
+
+  // The saturation bound selects the source element width.
+  MVT ScalarType;
+  unsigned EltBits = 0;
+  switch (Clamp->getSExtValue()) {
+  case (1ULL << 15) - 1:
+    ScalarType = MVT::i16;
+    EltBits = 16;
+    break;
+  case (1ULL << 31) - 1:
+    ScalarType = MVT::i32;
+    EltBits = 32;
+    break;
+  default:
+    return SDValue();
+  }
+
+  // The multiply is performed in DestVT. If its elements are narrower than
+  // 2 * EltBits the product can wrap, and the wrapped value differs from what
+  // SQDMULH computes from the full-width product.
+  if (DestVT.getScalarSizeInBits() < 2 * EltBits)
+    return SDValue();
+
+  SDValue Sra = N->getOperand(0);
+  if (Sra.getOpcode() != ISD::SRA || !Sra.hasOneUse())
+    return SDValue();
+
+  ConstantSDNode *ShiftC = isConstOrConstSplat(Sra.getOperand(1));
+  if (!ShiftC || ShiftC->getZExtValue() != EltBits - 1)
+    return SDValue();
+
+  SDValue Mul = Sra.getOperand(0);
+  if (Mul.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  SDValue SExt0 = Mul.getOperand(0);
+  SDValue SExt1 = Mul.getOperand(1);
+  if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
+      SExt1.getOpcode() != ISD::SIGN_EXTEND)
+    return SDValue();
+
+  SDValue V0 = SExt0.getOperand(0);
+  SDValue V1 = SExt1.getOperand(0);
+  EVT SrcVT = V0.getValueType();
+
+  // Only form SQDMULH on sources that already fill a 64- or 128-bit vector
+  // register. Narrower vectors (v2i16) must not be widened per lane:
+  // SQDMULH on the wider lanes would return (v0 * v1) >> (2 * EltBits - 1)
+  // instead of (v0 * v1) >> (EltBits - 1).
+  if (SrcVT != V1.getValueType() || SrcVT.getScalarType() != ScalarType ||
+      !(SrcVT.is64BitVector() || SrcVT.is128BitVector()))
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue SQDMULH = DAG.getNode(AArch64ISD::SQDMULH, DL, SrcVT, V0, V1);
+  return DAG.getNode(ISD::SIGN_EXTEND, DL, DestVT, SQDMULH);
+}
+
+// Target DAG combines rooted at ISD::SMIN.
+static SDValue performSMINCombine(SDNode *N, SelectionDAG &DAG) {
+  return trySQDMULHCombine(N, DAG);
+}
+
 static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
                                       TargetLowering::DAGCombinerInfo &DCI) {
   SDLoc DL(N);
@@ -26737,6 +26834,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performAddSubCombine(N, DCI);
   case ISD::BUILD_VECTOR:
     return performBuildVectorCombine(N, DCI, DAG);
+  case ISD::SMIN:
+    return performSMINCombine(N, DAG);
   case ISD::TRUNCATE:
     return performTruncateCombine(N, DAG, DCI);
   case AArch64ISD::ANDS:
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index ddc685fae5e9a..ce91b72fa24e5 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1022,6 +1022,7 @@ def AArch64smull : SDNode<"AArch64ISD::SMULL", SDT_AArch64mull,
                           [SDNPCommutative]>;
 def AArch64umull : SDNode<"AArch64ISD::UMULL", SDT_AArch64mull,
                           [SDNPCommutative]>;
+def AArch64sqdmulh : SDNode<"AArch64ISD::SQDMULH", SDT_AArch64mull>;
 
 // Reciprocal estimates and steps.
 def AArch64frecpe : SDNode<"AArch64ISD::FRECPE", SDTFPUnaryOp>;
@@ -9439,6 +9440,16 @@ def : Pat<(v4i32 (mulhu V128:$Rn, V128:$Rm)),
                                        (EXTRACT_SUBREG V128:$Rm, dsub)),
                      (UMULLv4i32_v2i64 V128:$Rn, V128:$Rm))>;
 
+// Select the AArch64sqdmulh node to the AdvSIMD SQDMULH (vector) instructions.
+def : Pat<(v4i16 (AArch64sqdmulh (v4i16 V64:$Rn), (v4i16 V64:$Rm))),
+          (SQDMULHv4i16 V64:$Rn, V64:$Rm)>;
+def : Pat<(v2i32 (AArch64sqdmulh (v2i32 V64:$Rn), (v2i32 V64:$Rm))),
+          (SQDMULHv2i32 V64:$Rn, V64:$Rm)>;
+def : Pat<(v8i16 (AArch64sqdmulh (v8i16 V128:$Rn), (v8i16 V128:$Rm))),
+          (SQDMULHv8i16 V128:$Rn, V128:$Rm)>;
+def : Pat<(v4i32 (AArch64sqdmulh (v4i32 V128:$Rn), (v4i32 V128:$Rm))),
+          (SQDMULHv4i32 V128:$Rn, V128:$Rm)>;
+
 // Conversions within AdvSIMD types in the same register size are free.
 // But because we need a consistent lane ordering, in big endian many
 // conversions require one or more REV instructions.
diff --git a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
new file mode 100644
index 0000000000000..b647daf72ca35
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
@@ -0,0 +1,231 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-none-elf < %s | FileCheck %s
+; Checks when smin(ashr(mul(sext a, sext b), N - 1), 2^(N-1) - 1) on fixed
+; vectors is (and is not) selected as the sqdmulh instruction.
+
+
+; NOTE(review): the checks below still show the <2 x i16> inputs being folded
+; to sqdmulh on .2s lanes, which yields (a * b) >> 31 instead of the
+; (a * b) >> 15 the IR computes. Regenerate them with
+; utils/update_llc_test_checks.py once the combine rejects this case.
+define <2 x i16> @saturating_2xi16(<2 x i16> %a, <2 x i16> %b) {
+; CHECK-LABEL: saturating_2xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v0.2s, v0.2s, #16
+; CHECK-NEXT:    shl v1.2s, v1.2s, #16
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #16
+; CHECK-NEXT:    sshr v1.2s, v1.2s, #16
+; CHECK-NEXT:    sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    ret
+  %as = sext <2 x i16> %a to <2 x i32>
+  %bs = sext <2 x i16> %b to <2 x i32>
+  %m = mul <2 x i32> %bs, %as
+  %sh = ashr <2 x i32> %m, splat (i32 15)
+  %ma = tail call <2 x i32> @llvm.smin.v4i32(<2 x i32> %sh, <2 x i32> splat (i32 32767))
+  %t = trunc <2 x i32> %ma to <2 x i16>
+  ret <2 x i16> %t
+}
+
+define <4 x i16> @saturating_4xi16(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: saturating_4xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqdmulh v0.4h, v1.4h, v0.4h
+; CHECK-NEXT:    ret
+  %as = sext <4 x i16> %a to <4 x i32>
+  %bs = sext <4 x i16> %b to <4 x i32>
+  %m = mul <4 x i32> %bs, %as
+  %sh = ashr <4 x i32> %m, splat (i32 15)
+  %ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 32767))
+  %t = trunc <4 x i32> %ma to <4 x i16>
+  ret <4 x i16> %t
+}
+
+define <8 x i16> @saturating_8xi16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: saturating_8xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqdmulh v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    ret
+  %as = sext <8 x i16> %a to <8 x i32>
+  %bs = sext <8 x i16> %b to <8 x i32>
+  %m = mul <8 x i32> %bs, %as
+  %sh = ashr <8 x i32> %m, splat (i32 15)
+  %ma = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> %sh, <8 x i32> splat (i32 32767))
+  %t = trunc <8 x i32> %ma to <8 x i16>
+  ret <8 x i16> %t
+}
+
+define <2 x i32> @saturating_2xi32(<2 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: saturating_2xi32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    ret
+  %as = sext <2 x i32> %a to <2 x i64>
+  %bs = sext <2 x i32> %b to <2 x i64>
+  %m = mul <2 x i64> %bs, %as
+  %sh = ashr <2 x i64> %m, splat (i64 31)
+  %ma = tail call <2 x i64> @llvm.smin.v8i64(<2 x i64> %sh, <2 x i64> splat (i64 2147483647))
+  %t = trunc <2 x i64> %ma to <2 x i32>
+  ret <2 x i32> %t
+}
+
+define <4 x i32> @saturating_4xi32(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: saturating_4xi32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqdmulh v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    ret
+  %as = sext <4 x i32> %a to <4 x i64>
+  %bs = sext <4 x i32> %b to <4 x i64>
+  %m = mul <4 x i64> %bs, %as
+  %sh = ashr <4 x i64> %m, splat (i64 31)
+  %ma = tail call <4 x i64> @llvm.smin.v4i64(<4 x i64> %sh, <4 x i64> splat (i64 2147483647))
+  %t = trunc <4 x i64> %ma to <4 x i32>
+  ret <4 x i32> %t
+}
+
+define <8 x i32> @saturating_8xi32(<8 x i32> %a, <8 x i32> %b) {
+; CHECK-LABEL: saturating_8xi32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqdmulh v1.4s, v3.4s, v1.4s
+; CHECK-NEXT:    sqdmulh v0.4s, v2.4s, v0.4s
+; CHECK-NEXT:    ret
+  %as = sext <8 x i32> %a to <8 x i64>
+  %bs = sext <8 x i32> %b to <8 x i64>
+  %m = mul <8 x i64> %bs, %as
+  %sh = ashr <8 x i64> %m, splat (i64 31)
+  %ma = tail call <8 x i64> @llvm.smin.v8i64(<8 x i64> %sh, <8 x i64> splat (i64 2147483647))
+  %t = trunc <8 x i64> %ma to <8 x i32>
+  ret <8 x i32> %t
+}
+
+define <2 x i64> @saturating_2xi32_2xi64(<2 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: saturating_2xi32_2xi64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    sshll v0.2d, v0.2s, #0
+; CHECK-NEXT:    ret
+  %as = sext <2 x i32> %a to <2 x i64>
+  %bs = sext <2 x i32> %b to <2 x i64>
+  %m = mul <2 x i64> %bs, %as
+  %sh = ashr <2 x i64> %m, splat (i64 31)
+  %ma = tail call <2 x i64> @llvm.smin.v8i64(<2 x i64> %sh, <2 x i64> splat (i64 2147483647))
+  ret <2 x i64> %ma
+}
+
+define <6 x i16> @saturating_6xi16(<6 x i16> %a, <6 x i16> %b) {
+; CHECK-LABEL: saturating_6xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    smull2 v3.4s, v1.8h, v0.8h
+; CHECK-NEXT:    movi v2.4s, #127, msl #8
+; CHECK-NEXT:    sqdmulh v0.4h, v1.4h, v0.4h
+; CHECK-NEXT:    sshr v3.4s, v3.4s, #15
+; CHECK-NEXT:    smin v2.4s, v3.4s, v2.4s
+; CHECK-NEXT:    xtn2 v0.8h, v2.4s
+; CHECK-NEXT:    ret
+  %as = sext <6 x i16> %a to <6 x i32>
+  %bs = sext <6 x i16> %b to <6 x i32>
+  %m = mul <6 x i32> %bs, %as
+  %sh = ashr <6 x i32> %m, splat (i32 15)
+  %ma = tail call <6 x i32> @llvm.smin.v6i32(<6 x i32> %sh, <6 x i32> splat (i32 32767))
+  %t = trunc <6 x i32> %ma to <6 x i16>
+  ret <6 x i16> %t
+}
+
+define <4 x i16> @unsupported_saturation_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: unsupported_saturation_value_v4i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    smull v0.4s, v1.4h, v0.4h
+; CHECK-NEXT:    movi v1.4s, #42
+; CHECK-NEXT:    sshr v0.4s, v0.4s, #15
+; CHECK-NEXT:    smin v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    ret
+  %as = sext <4 x i16> %a to <4 x i32>
+  %bs = sext <4 x i16> %b to <4 x i32>
+  %m = mul <4 x i32> %bs, %as
+  %sh = ashr <4 x i32> %m, splat (i32 15)
+  %ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 42))
+  %t = trunc <4 x i32> %ma to <4 x i16>
+  ret <4 x i16> %t
+}
+
+define <4 x i16> @unsupported_shift_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: unsupported_shift_value_v4i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    smull v0.4s, v1.4h, v0.4h
+; CHECK-NEXT:    movi v1.4s, #127, msl #8
+; CHECK-NEXT:    sshr v0.4s, v0.4s, #3
+; CHECK-NEXT:    smin v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    ret
+  %as = sext <4 x i16> %a to <4 x i32>
+  %bs = sext <4 x i16> %b to <4 x i32>
+  %m = mul <4 x i32> %bs, %as
+  %sh = ashr <4 x i32> %m, splat (i32 3)
+  %ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 32767))
+  %t = trunc <4 x i32> %ma to <4 x i16>
+  ret <4 x i16> %t
+}
+
+; NOTE(review): same <2 x i16> source as saturating_2xi16; these checks also
+; need regenerating with utils/update_llc_test_checks.py.
+define <2 x i16> @extend_to_illegal_type(<2 x i16> %a, <2 x i16> %b) {
+; CHECK-LABEL: extend_to_illegal_type:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v0.2s, v0.2s, #16
+; CHECK-NEXT:    shl v1.2s, v1.2s, #16
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #16
+; CHECK-NEXT:    sshr v1.2s, v1.2s, #16
+; CHECK-NEXT:    sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    ret
+  %as = sext <2 x i16> %a to <2 x i48>
+  %bs = sext <2 x i16> %b to <2 x i48>
+  %m = mul <2 x i48> %bs, %as
+  %sh = ashr <2 x i48> %m, splat (i48 15)
+  %ma = tail call <2 x i48> @llvm.smin.v4i32(<2 x i48> %sh, <2 x i48> splat (i48 32767))
+  %t = trunc <2 x i48> %ma to <2 x i16>
+  ret <2 x i16> %t
+}
+
+define <2 x i11> @illegal_source(<2 x i11> %a, <2 x i11> %b) {
+; CHECK-LABEL: illegal_source:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v0.2s, v0.2s, #21
+; CHECK-NEXT:    shl v1.2s, v1.2s, #21
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #21
+; CHECK-NEXT:    sshr v1.2s, v1.2s, #21
+; CHECK-NEXT:    mul v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    movi v1.2s, #127, msl #8
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #15
+; CHECK-NEXT:    smin v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    ret
+  %as = sext <2 x i11> %a to <2 x i32>
+  %bs = sext <2 x i11> %b to <2 x i32>
+  %m = mul <2 x i32> %bs, %as
+  %sh = ashr <2 x i32> %m, splat (i32 15)
+  %ma = tail call <2 x i32> @llvm.smin.v2i32(<2 x i32> %sh, <2 x i32> splat (i32 32767))
+  %t = trunc <2 x i32> %ma to <2 x i11>
+  ret <2 x i11> %t
+}
+define <1 x i16> @saturating_1xi16(<1 x i16> %a, <1 x i16> %b) {
+; CHECK-LABEL: saturating_1xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    zip1 v0.4h, v0.4h, v0.4h
+; CHECK-NEXT:    zip1 v1.4h, v1.4h, v0.4h
+; CHECK-NEXT:    shl v0.2s, v0.2s, #16
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #16
+; CHECK-NEXT:    shl v1.2s, v1.2s, #16
+; CHECK-NEXT:    sshr v1.2s, v1.2s, #16
+; CHECK-NEXT:    mul v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    movi v1.2s, #127, msl #8
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #15
+; CHECK-NEXT:    smin v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    uzp1 v0.4h, v0.4h, v0.4h
+; CHECK-NEXT:    ret
+  %as = sext <1 x i16> %a to <1 x i32>
+  %bs = sext <1 x i16> %b to <1 x i32>
+  %m = mul <1 x i32> %bs, %as
+  %sh = ashr <1 x i32> %m, splat (i32 15)
+  %ma = tail call <1 x i32> @llvm.smin.v1i32(<1 x i32> %sh, <1 x i32> splat (i32 32767))
+  %t = trunc <1 x i32> %ma to <1 x i16>
+  ret <1 x i16> %t
+}