Skip to content

Commit 427b644

Browse files
Revert "[LLVM][ISel][AArch64 Remove AArch64ISD::FCM##z nodes. (#135817)"
This reverts commit 15d8b3c.
1 parent ebceb73 commit 427b644

11 files changed

+167
-93
lines changed

llvm/lib/CodeGen/GlobalISel/Utils.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -1385,8 +1385,7 @@ bool llvm::isBuildVectorConstantSplat(const Register Reg,
13851385
const MachineRegisterInfo &MRI,
13861386
int64_t SplatValue, bool AllowUndef) {
13871387
if (auto SplatValAndReg = getAnyConstantSplat(Reg, MRI, AllowUndef))
1388-
return SplatValAndReg->Value.getSExtValue() == SplatValue;
1389-
1388+
return mi_match(SplatValAndReg->VReg, MRI, m_SpecificICst(SplatValue));
13901389
return false;
13911390
}
13921391

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

+36-1
Original file line numberDiff line numberDiff line change
@@ -2602,6 +2602,11 @@ unsigned AArch64TargetLowering::ComputeNumSignBitsForTargetNode(
26022602
case AArch64ISD::FCMEQ:
26032603
case AArch64ISD::FCMGE:
26042604
case AArch64ISD::FCMGT:
2605+
case AArch64ISD::FCMEQz:
2606+
case AArch64ISD::FCMGEz:
2607+
case AArch64ISD::FCMGTz:
2608+
case AArch64ISD::FCMLEz:
2609+
case AArch64ISD::FCMLTz:
26052610
// Compares return either 0 or all-ones
26062611
return VTBits;
26072612
case AArch64ISD::VASHR: {
@@ -2818,6 +2823,11 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
28182823
MAKE_CASE(AArch64ISD::FCMEQ)
28192824
MAKE_CASE(AArch64ISD::FCMGE)
28202825
MAKE_CASE(AArch64ISD::FCMGT)
2826+
MAKE_CASE(AArch64ISD::FCMEQz)
2827+
MAKE_CASE(AArch64ISD::FCMGEz)
2828+
MAKE_CASE(AArch64ISD::FCMGTz)
2829+
MAKE_CASE(AArch64ISD::FCMLEz)
2830+
MAKE_CASE(AArch64ISD::FCMLTz)
28212831
MAKE_CASE(AArch64ISD::SADDV)
28222832
MAKE_CASE(AArch64ISD::UADDV)
28232833
MAKE_CASE(AArch64ISD::UADDLV)
@@ -15830,33 +15840,58 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
1583015840
assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
1583115841
"function only supposed to emit natural comparisons");
1583215842

15843+
APInt SplatValue;
15844+
APInt SplatUndef;
15845+
unsigned SplatBitSize = 0;
15846+
bool HasAnyUndefs;
15847+
15848+
BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(RHS.getNode());
15849+
bool IsCnst = BVN && BVN->isConstantSplat(SplatValue, SplatUndef,
15850+
SplatBitSize, HasAnyUndefs);
15851+
15852+
bool IsZero = IsCnst && SplatValue == 0;
15853+
1583315854
if (SrcVT.getVectorElementType().isFloatingPoint()) {
1583415855
switch (CC) {
1583515856
default:
1583615857
return SDValue();
1583715858
case AArch64CC::NE: {
15838-
SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
15859+
SDValue Fcmeq;
15860+
if (IsZero)
15861+
Fcmeq = DAG.getNode(AArch64ISD::FCMEQz, dl, VT, LHS);
15862+
else
15863+
Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
1583915864
return DAG.getNOT(dl, Fcmeq, VT);
1584015865
}
1584115866
case AArch64CC::EQ:
15867+
if (IsZero)
15868+
return DAG.getNode(AArch64ISD::FCMEQz, dl, VT, LHS);
1584215869
return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
1584315870
case AArch64CC::GE:
15871+
if (IsZero)
15872+
return DAG.getNode(AArch64ISD::FCMGEz, dl, VT, LHS);
1584415873
return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
1584515874
case AArch64CC::GT:
15875+
if (IsZero)
15876+
return DAG.getNode(AArch64ISD::FCMGTz, dl, VT, LHS);
1584615877
return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
1584715878
case AArch64CC::LE:
1584815879
if (!NoNans)
1584915880
return SDValue();
1585015881
// If we ignore NaNs then we can use to the LS implementation.
1585115882
[[fallthrough]];
1585215883
case AArch64CC::LS:
15884+
if (IsZero)
15885+
return DAG.getNode(AArch64ISD::FCMLEz, dl, VT, LHS);
1585315886
return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
1585415887
case AArch64CC::LT:
1585515888
if (!NoNans)
1585615889
return SDValue();
1585715890
// If we ignore NaNs then we can use to the MI implementation.
1585815891
[[fallthrough]];
1585915892
case AArch64CC::MI:
15893+
if (IsZero)
15894+
return DAG.getNode(AArch64ISD::FCMLTz, dl, VT, LHS);
1586015895
return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
1586115896
}
1586215897
}

llvm/lib/Target/AArch64/AArch64ISelLowering.h

+7
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,13 @@ enum NodeType : unsigned {
245245
FCMGE,
246246
FCMGT,
247247

248+
// Vector zero comparisons
249+
FCMEQz,
250+
FCMGEz,
251+
FCMGTz,
252+
FCMLEz,
253+
FCMLTz,
254+
248255
// Round wide FP to narrow FP with inexact results to odd.
249256
FCVTXN,
250257

llvm/lib/Target/AArch64/AArch64InstrFormats.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -7136,7 +7136,7 @@ multiclass SIMDCmpTwoVector<bit U, bits<5> opc, string asm,
71367136

71377137
// FP Comparisons support only S and D element sizes (and H for v8.2a).
71387138
multiclass SIMDFPCmpTwoVector<bit U, bit S, bits<5> opc,
7139-
string asm, SDPatternOperator OpNode> {
7139+
string asm, SDNode OpNode> {
71407140

71417141
let mayRaiseFPException = 1, Uses = [FPCR] in {
71427142
let Predicates = [HasNEON, HasFullFP16] in {

llvm/lib/Target/AArch64/AArch64InstrGISel.td

+36
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,36 @@ def G_FCMGT : AArch64GenericInstruction {
179179
let hasSideEffects = 0;
180180
}
181181

182+
def G_FCMEQZ : AArch64GenericInstruction {
183+
let OutOperandList = (outs type0:$dst);
184+
let InOperandList = (ins type0:$src);
185+
let hasSideEffects = 0;
186+
}
187+
188+
def G_FCMGEZ : AArch64GenericInstruction {
189+
let OutOperandList = (outs type0:$dst);
190+
let InOperandList = (ins type0:$src);
191+
let hasSideEffects = 0;
192+
}
193+
194+
def G_FCMGTZ : AArch64GenericInstruction {
195+
let OutOperandList = (outs type0:$dst);
196+
let InOperandList = (ins type0:$src);
197+
let hasSideEffects = 0;
198+
}
199+
200+
def G_FCMLEZ : AArch64GenericInstruction {
201+
let OutOperandList = (outs type0:$dst);
202+
let InOperandList = (ins type0:$src);
203+
let hasSideEffects = 0;
204+
}
205+
206+
def G_FCMLTZ : AArch64GenericInstruction {
207+
let OutOperandList = (outs type0:$dst);
208+
let InOperandList = (ins type0:$src);
209+
let hasSideEffects = 0;
210+
}
211+
182212
def G_AARCH64_PREFETCH : AArch64GenericInstruction {
183213
let OutOperandList = (outs);
184214
let InOperandList = (ins type0:$imm, ptype0:$src1);
@@ -265,6 +295,12 @@ def : GINodeEquiv<G_FCMEQ, AArch64fcmeq>;
265295
def : GINodeEquiv<G_FCMGE, AArch64fcmge>;
266296
def : GINodeEquiv<G_FCMGT, AArch64fcmgt>;
267297

298+
def : GINodeEquiv<G_FCMEQZ, AArch64fcmeqz>;
299+
def : GINodeEquiv<G_FCMGEZ, AArch64fcmgez>;
300+
def : GINodeEquiv<G_FCMGTZ, AArch64fcmgtz>;
301+
def : GINodeEquiv<G_FCMLEZ, AArch64fcmlez>;
302+
def : GINodeEquiv<G_FCMLTZ, AArch64fcmltz>;
303+
268304
def : GINodeEquiv<G_BSP, AArch64bsp>;
269305

270306
def : GINodeEquiv<G_UMULL, AArch64umull>;

llvm/lib/Target/AArch64/AArch64InstrInfo.td

+5-14
Original file line numberDiff line numberDiff line change
@@ -882,20 +882,11 @@ def AArch64cmltz : PatFrag<(ops node:$lhs),
882882
def AArch64cmtst : PatFrag<(ops node:$LHS, node:$RHS),
883883
(vnot (AArch64cmeqz (and node:$LHS, node:$RHS)))>;
884884

885-
def AArch64fcmeqz : PatFrag<(ops node:$lhs),
886-
(AArch64fcmeq node:$lhs, immAllZerosV)>;
887-
888-
def AArch64fcmgez : PatFrag<(ops node:$lhs),
889-
(AArch64fcmge node:$lhs, immAllZerosV)>;
890-
891-
def AArch64fcmgtz : PatFrag<(ops node:$lhs),
892-
(AArch64fcmgt node:$lhs, immAllZerosV)>;
893-
894-
def AArch64fcmlez : PatFrag<(ops node:$lhs),
895-
(AArch64fcmge immAllZerosV, node:$lhs)>;
896-
897-
def AArch64fcmltz : PatFrag<(ops node:$lhs),
898-
(AArch64fcmgt immAllZerosV, node:$lhs)>;
885+
def AArch64fcmeqz: SDNode<"AArch64ISD::FCMEQz", SDT_AArch64fcmpz>;
886+
def AArch64fcmgez: SDNode<"AArch64ISD::FCMGEz", SDT_AArch64fcmpz>;
887+
def AArch64fcmgtz: SDNode<"AArch64ISD::FCMGTz", SDT_AArch64fcmpz>;
888+
def AArch64fcmlez: SDNode<"AArch64ISD::FCMLEz", SDT_AArch64fcmpz>;
889+
def AArch64fcmltz: SDNode<"AArch64ISD::FCMLTz", SDT_AArch64fcmpz>;
899890

900891
def AArch64fcvtxn_n: SDNode<"AArch64ISD::FCVTXN", SDTFPRoundOp>;
901892
def AArch64fcvtxnsdr: PatFrags<(ops node:$Rn),

llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp

+48-22
Original file line numberDiff line numberDiff line change
@@ -808,14 +808,16 @@ void applyScalarizeVectorUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
808808

809809
bool matchBuildVectorToDup(MachineInstr &MI, MachineRegisterInfo &MRI) {
810810
assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
811-
811+
auto Splat = getAArch64VectorSplat(MI, MRI);
812+
if (!Splat)
813+
return false;
814+
if (Splat->isReg())
815+
return true;
812816
// Later, during selection, we'll try to match imported patterns using
813817
// immAllOnesV and immAllZerosV. These require G_BUILD_VECTOR. Don't lower
814818
// G_BUILD_VECTORs which could match those patterns.
815-
if (isBuildVectorAllZeros(MI, MRI) || isBuildVectorAllOnes(MI, MRI))
816-
return false;
817-
818-
return getAArch64VectorSplat(MI, MRI).has_value();
819+
int64_t Cst = Splat->getCst();
820+
return (Cst != 0 && Cst != -1);
819821
}
820822

821823
void applyBuildVectorToDup(MachineInstr &MI, MachineRegisterInfo &MRI,
@@ -931,40 +933,58 @@ void applySwapICmpOperands(MachineInstr &MI, GISelChangeObserver &Observer) {
931933

932934
/// \returns a function which builds a vector floating point compare instruction
933935
/// for a condition code \p CC.
936+
/// \param [in] IsZero - True if the comparison is against 0.
934937
/// \param [in] NoNans - True if the target has NoNansFPMath.
935938
std::function<Register(MachineIRBuilder &)>
936-
getVectorFCMP(AArch64CC::CondCode CC, Register LHS, Register RHS, bool NoNans,
937-
MachineRegisterInfo &MRI) {
939+
getVectorFCMP(AArch64CC::CondCode CC, Register LHS, Register RHS, bool IsZero,
940+
bool NoNans, MachineRegisterInfo &MRI) {
938941
LLT DstTy = MRI.getType(LHS);
939942
assert(DstTy.isVector() && "Expected vector types only?");
940943
assert(DstTy == MRI.getType(RHS) && "Src and Dst types must match!");
941944
switch (CC) {
942945
default:
943946
llvm_unreachable("Unexpected condition code!");
944947
case AArch64CC::NE:
945-
return [LHS, RHS, DstTy](MachineIRBuilder &MIB) {
946-
auto FCmp = MIB.buildInstr(AArch64::G_FCMEQ, {DstTy}, {LHS, RHS});
948+
return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
949+
auto FCmp = IsZero
950+
? MIB.buildInstr(AArch64::G_FCMEQZ, {DstTy}, {LHS})
951+
: MIB.buildInstr(AArch64::G_FCMEQ, {DstTy}, {LHS, RHS});
947952
return MIB.buildNot(DstTy, FCmp).getReg(0);
948953
};
949954
case AArch64CC::EQ:
950-
return [LHS, RHS, DstTy](MachineIRBuilder &MIB) {
951-
return MIB.buildInstr(AArch64::G_FCMEQ, {DstTy}, {LHS, RHS}).getReg(0);
955+
return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
956+
return IsZero
957+
? MIB.buildInstr(AArch64::G_FCMEQZ, {DstTy}, {LHS}).getReg(0)
958+
: MIB.buildInstr(AArch64::G_FCMEQ, {DstTy}, {LHS, RHS})
959+
.getReg(0);
952960
};
953961
case AArch64CC::GE:
954-
return [LHS, RHS, DstTy](MachineIRBuilder &MIB) {
955-
return MIB.buildInstr(AArch64::G_FCMGE, {DstTy}, {LHS, RHS}).getReg(0);
962+
return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
963+
return IsZero
964+
? MIB.buildInstr(AArch64::G_FCMGEZ, {DstTy}, {LHS}).getReg(0)
965+
: MIB.buildInstr(AArch64::G_FCMGE, {DstTy}, {LHS, RHS})
966+
.getReg(0);
956967
};
957968
case AArch64CC::GT:
958-
return [LHS, RHS, DstTy](MachineIRBuilder &MIB) {
959-
return MIB.buildInstr(AArch64::G_FCMGT, {DstTy}, {LHS, RHS}).getReg(0);
969+
return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
970+
return IsZero
971+
? MIB.buildInstr(AArch64::G_FCMGTZ, {DstTy}, {LHS}).getReg(0)
972+
: MIB.buildInstr(AArch64::G_FCMGT, {DstTy}, {LHS, RHS})
973+
.getReg(0);
960974
};
961975
case AArch64CC::LS:
962-
return [LHS, RHS, DstTy](MachineIRBuilder &MIB) {
963-
return MIB.buildInstr(AArch64::G_FCMGE, {DstTy}, {RHS, LHS}).getReg(0);
976+
return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
977+
return IsZero
978+
? MIB.buildInstr(AArch64::G_FCMLEZ, {DstTy}, {LHS}).getReg(0)
979+
: MIB.buildInstr(AArch64::G_FCMGE, {DstTy}, {RHS, LHS})
980+
.getReg(0);
964981
};
965982
case AArch64CC::MI:
966-
return [LHS, RHS, DstTy](MachineIRBuilder &MIB) {
967-
return MIB.buildInstr(AArch64::G_FCMGT, {DstTy}, {RHS, LHS}).getReg(0);
983+
return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
984+
return IsZero
985+
? MIB.buildInstr(AArch64::G_FCMLTZ, {DstTy}, {LHS}).getReg(0)
986+
: MIB.buildInstr(AArch64::G_FCMGT, {DstTy}, {RHS, LHS})
987+
.getReg(0);
968988
};
969989
}
970990
}
@@ -1004,17 +1024,23 @@ void applyLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
10041024

10051025
LLT DstTy = MRI.getType(Dst);
10061026

1027+
auto Splat = getAArch64VectorSplat(*MRI.getVRegDef(RHS), MRI);
1028+
1029+
// Compares against 0 have special target-specific pseudos.
1030+
bool IsZero = Splat && Splat->isCst() && Splat->getCst() == 0;
1031+
10071032
bool Invert = false;
10081033
AArch64CC::CondCode CC, CC2 = AArch64CC::AL;
10091034
if ((Pred == CmpInst::Predicate::FCMP_ORD ||
10101035
Pred == CmpInst::Predicate::FCMP_UNO) &&
1011-
isBuildVectorAllZeros(*MRI.getVRegDef(RHS), MRI)) {
1036+
IsZero) {
10121037
// The special case "fcmp ord %a, 0" is the canonical check that LHS isn't
10131038
// NaN, so equivalent to a == a and doesn't need the two comparisons an
10141039
// "ord" normally would.
10151040
// Similarly, "fcmp uno %a, 0" is the canonical check that LHS is NaN and is
10161041
// thus equivalent to a != a.
10171042
RHS = LHS;
1043+
IsZero = false;
10181044
CC = Pred == CmpInst::Predicate::FCMP_ORD ? AArch64CC::EQ : AArch64CC::NE;
10191045
} else
10201046
changeVectorFCMPPredToAArch64CC(Pred, CC, CC2, Invert);
@@ -1025,12 +1051,12 @@ void applyLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
10251051
const bool NoNans =
10261052
ST.getTargetLowering()->getTargetMachine().Options.NoNaNsFPMath;
10271053

1028-
auto Cmp = getVectorFCMP(CC, LHS, RHS, NoNans, MRI);
1054+
auto Cmp = getVectorFCMP(CC, LHS, RHS, IsZero, NoNans, MRI);
10291055
Register CmpRes;
10301056
if (CC2 == AArch64CC::AL)
10311057
CmpRes = Cmp(MIB);
10321058
else {
1033-
auto Cmp2 = getVectorFCMP(CC2, LHS, RHS, NoNans, MRI);
1059+
auto Cmp2 = getVectorFCMP(CC2, LHS, RHS, IsZero, NoNans, MRI);
10341060
auto Cmp2Dst = Cmp2(MIB);
10351061
auto Cmp1Dst = Cmp(MIB);
10361062
CmpRes = MIB.buildOr(DstTy, Cmp1Dst, Cmp2Dst).getReg(0);

0 commit comments

Comments
 (0)