Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
35266 views
1
//===- AggressiveInstCombine.cpp ------------------------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file implements the aggressive expression pattern combiner classes.
10
// Currently, it handles expression patterns for:
11
// * Truncate instruction
12
//
13
//===----------------------------------------------------------------------===//
14
15
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
16
#include "AggressiveInstCombineInternal.h"
17
#include "llvm/ADT/Statistic.h"
18
#include "llvm/Analysis/AliasAnalysis.h"
19
#include "llvm/Analysis/AssumptionCache.h"
20
#include "llvm/Analysis/BasicAliasAnalysis.h"
21
#include "llvm/Analysis/ConstantFolding.h"
22
#include "llvm/Analysis/DomTreeUpdater.h"
23
#include "llvm/Analysis/GlobalsModRef.h"
24
#include "llvm/Analysis/TargetLibraryInfo.h"
25
#include "llvm/Analysis/TargetTransformInfo.h"
26
#include "llvm/Analysis/ValueTracking.h"
27
#include "llvm/IR/DataLayout.h"
28
#include "llvm/IR/Dominators.h"
29
#include "llvm/IR/Function.h"
30
#include "llvm/IR/IRBuilder.h"
31
#include "llvm/IR/PatternMatch.h"
32
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
33
#include "llvm/Transforms/Utils/BuildLibCalls.h"
34
#include "llvm/Transforms/Utils/Local.h"
35
36
using namespace llvm;
37
using namespace PatternMatch;
38
39
#define DEBUG_TYPE "aggressive-instcombine"
40
41
STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded");
42
STATISTIC(NumGuardedRotates,
43
"Number of guarded rotates transformed into funnel shifts");
44
STATISTIC(NumGuardedFunnelShifts,
45
"Number of guarded funnel shifts transformed into funnel shifts");
46
STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized");
47
48
static cl::opt<unsigned> MaxInstrsToScan(
49
"aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden,
50
cl::desc("Max number of instructions to scan for aggressive instcombine."));
51
52
static cl::opt<unsigned> StrNCmpInlineThreshold(
53
"strncmp-inline-threshold", cl::init(3), cl::Hidden,
54
cl::desc("The maximum length of a constant string for a builtin string cmp "
55
"call eligible for inlining. The default value is 3."));
56
57
static cl::opt<unsigned>
58
MemChrInlineThreshold("memchr-inline-threshold", cl::init(3), cl::Hidden,
59
cl::desc("The maximum length of a constant string to "
60
"inline a memchr call."));
61
62
/// Match a pattern for a bitwise funnel/rotate operation that partially guards
63
/// against undefined behavior by branching around the funnel-shift/rotation
64
/// when the shift amount is 0.
65
static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
66
if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2)
67
return false;
68
69
// As with the one-use checks below, this is not strictly necessary, but we
70
// are being cautious to avoid potential perf regressions on targets that
71
// do not actually have a funnel/rotate instruction (where the funnel shift
72
// would be expanded back into math/shift/logic ops).
73
if (!isPowerOf2_32(I.getType()->getScalarSizeInBits()))
74
return false;
75
76
// Match V to funnel shift left/right and capture the source operands and
77
// shift amount.
78
auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1,
79
Value *&ShAmt) {
80
unsigned Width = V->getType()->getScalarSizeInBits();
81
82
// fshl(ShVal0, ShVal1, ShAmt)
83
// == (ShVal0 << ShAmt) | (ShVal1 >> (Width -ShAmt))
84
if (match(V, m_OneUse(m_c_Or(
85
m_Shl(m_Value(ShVal0), m_Value(ShAmt)),
86
m_LShr(m_Value(ShVal1),
87
m_Sub(m_SpecificInt(Width), m_Deferred(ShAmt))))))) {
88
return Intrinsic::fshl;
89
}
90
91
// fshr(ShVal0, ShVal1, ShAmt)
92
// == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt))
93
if (match(V,
94
m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width),
95
m_Value(ShAmt))),
96
m_LShr(m_Value(ShVal1), m_Deferred(ShAmt)))))) {
97
return Intrinsic::fshr;
98
}
99
100
return Intrinsic::not_intrinsic;
101
};
102
103
// One phi operand must be a funnel/rotate operation, and the other phi
104
// operand must be the source value of that funnel/rotate operation:
105
// phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ]
106
// phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ]
107
// phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ]
108
PHINode &Phi = cast<PHINode>(I);
109
unsigned FunnelOp = 0, GuardOp = 1;
110
Value *P0 = Phi.getOperand(0), *P1 = Phi.getOperand(1);
111
Value *ShVal0, *ShVal1, *ShAmt;
112
Intrinsic::ID IID = matchFunnelShift(P0, ShVal0, ShVal1, ShAmt);
113
if (IID == Intrinsic::not_intrinsic ||
114
(IID == Intrinsic::fshl && ShVal0 != P1) ||
115
(IID == Intrinsic::fshr && ShVal1 != P1)) {
116
IID = matchFunnelShift(P1, ShVal0, ShVal1, ShAmt);
117
if (IID == Intrinsic::not_intrinsic ||
118
(IID == Intrinsic::fshl && ShVal0 != P0) ||
119
(IID == Intrinsic::fshr && ShVal1 != P0))
120
return false;
121
assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
122
"Pattern must match funnel shift left or right");
123
std::swap(FunnelOp, GuardOp);
124
}
125
126
// The incoming block with our source operand must be the "guard" block.
127
// That must contain a cmp+branch to avoid the funnel/rotate when the shift
128
// amount is equal to 0. The other incoming block is the block with the
129
// funnel/rotate.
130
BasicBlock *GuardBB = Phi.getIncomingBlock(GuardOp);
131
BasicBlock *FunnelBB = Phi.getIncomingBlock(FunnelOp);
132
Instruction *TermI = GuardBB->getTerminator();
133
134
// Ensure that the shift values dominate each block.
135
if (!DT.dominates(ShVal0, TermI) || !DT.dominates(ShVal1, TermI))
136
return false;
137
138
ICmpInst::Predicate Pred;
139
BasicBlock *PhiBB = Phi.getParent();
140
if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()),
141
m_SpecificBB(PhiBB), m_SpecificBB(FunnelBB))))
142
return false;
143
144
if (Pred != CmpInst::ICMP_EQ)
145
return false;
146
147
IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt());
148
149
if (ShVal0 == ShVal1)
150
++NumGuardedRotates;
151
else
152
++NumGuardedFunnelShifts;
153
154
// If this is not a rotate then the select was blocking poison from the
155
// 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
156
bool IsFshl = IID == Intrinsic::fshl;
157
if (ShVal0 != ShVal1) {
158
if (IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal1))
159
ShVal1 = Builder.CreateFreeze(ShVal1);
160
else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal0))
161
ShVal0 = Builder.CreateFreeze(ShVal0);
162
}
163
164
// We matched a variation of this IR pattern:
165
// GuardBB:
166
// %cmp = icmp eq i32 %ShAmt, 0
167
// br i1 %cmp, label %PhiBB, label %FunnelBB
168
// FunnelBB:
169
// %sub = sub i32 32, %ShAmt
170
// %shr = lshr i32 %ShVal1, %sub
171
// %shl = shl i32 %ShVal0, %ShAmt
172
// %fsh = or i32 %shr, %shl
173
// br label %PhiBB
174
// PhiBB:
175
// %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ]
176
// -->
177
// llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt)
178
Function *F = Intrinsic::getDeclaration(Phi.getModule(), IID, Phi.getType());
179
Phi.replaceAllUsesWith(Builder.CreateCall(F, {ShVal0, ShVal1, ShAmt}));
180
return true;
181
}
182
183
/// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and
184
/// the bit indexes (Mask) needed by a masked compare. If we're matching a chain
185
/// of 'and' ops, then we also need to capture the fact that we saw an
186
/// "and X, 1", so that's an extra return value for that case.
187
struct MaskOps {
188
Value *Root = nullptr;
189
APInt Mask;
190
bool MatchAndChain;
191
bool FoundAnd1 = false;
192
193
MaskOps(unsigned BitWidth, bool MatchAnds)
194
: Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds) {}
195
};
196
197
/// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
198
/// chain of 'and' or 'or' instructions looking for shift ops of a common source
199
/// value. Examples:
200
/// or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
201
/// returns { X, 0x129 }
202
/// and (and (X >> 1), 1), (X >> 4)
203
/// returns { X, 0x12 }
204
static bool matchAndOrChain(Value *V, MaskOps &MOps) {
205
Value *Op0, *Op1;
206
if (MOps.MatchAndChain) {
207
// Recurse through a chain of 'and' operands. This requires an extra check
208
// vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere
209
// in the chain to know that all of the high bits are cleared.
210
if (match(V, m_And(m_Value(Op0), m_One()))) {
211
MOps.FoundAnd1 = true;
212
return matchAndOrChain(Op0, MOps);
213
}
214
if (match(V, m_And(m_Value(Op0), m_Value(Op1))))
215
return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
216
} else {
217
// Recurse through a chain of 'or' operands.
218
if (match(V, m_Or(m_Value(Op0), m_Value(Op1))))
219
return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
220
}
221
222
// We need a shift-right or a bare value representing a compare of bit 0 of
223
// the original source operand.
224
Value *Candidate;
225
const APInt *BitIndex = nullptr;
226
if (!match(V, m_LShr(m_Value(Candidate), m_APInt(BitIndex))))
227
Candidate = V;
228
229
// Initialize result source operand.
230
if (!MOps.Root)
231
MOps.Root = Candidate;
232
233
// The shift constant is out-of-range? This code hasn't been simplified.
234
if (BitIndex && BitIndex->uge(MOps.Mask.getBitWidth()))
235
return false;
236
237
// Fill in the mask bit derived from the shift constant.
238
MOps.Mask.setBit(BitIndex ? BitIndex->getZExtValue() : 0);
239
return MOps.Root == Candidate;
240
}
241
242
/// Match patterns that correspond to "any-bits-set" and "all-bits-set".
243
/// These will include a chain of 'or' or 'and'-shifted bits from a
244
/// common source value:
245
/// and (or (lshr X, C), ...), 1 --> (X & CMask) != 0
246
/// and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask
247
/// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns
248
/// that differ only with a final 'not' of the result. We expect that final
249
/// 'not' to be folded with the compare that we create here (invert predicate).
250
static bool foldAnyOrAllBitsSet(Instruction &I) {
251
// The 'any-bits-set' ('or' chain) pattern is simpler to match because the
252
// final "and X, 1" instruction must be the final op in the sequence.
253
bool MatchAllBitsSet;
254
if (match(&I, m_c_And(m_OneUse(m_And(m_Value(), m_Value())), m_Value())))
255
MatchAllBitsSet = true;
256
else if (match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One())))
257
MatchAllBitsSet = false;
258
else
259
return false;
260
261
MaskOps MOps(I.getType()->getScalarSizeInBits(), MatchAllBitsSet);
262
if (MatchAllBitsSet) {
263
if (!matchAndOrChain(cast<BinaryOperator>(&I), MOps) || !MOps.FoundAnd1)
264
return false;
265
} else {
266
if (!matchAndOrChain(cast<BinaryOperator>(&I)->getOperand(0), MOps))
267
return false;
268
}
269
270
// The pattern was found. Create a masked compare that replaces all of the
271
// shift and logic ops.
272
IRBuilder<> Builder(&I);
273
Constant *Mask = ConstantInt::get(I.getType(), MOps.Mask);
274
Value *And = Builder.CreateAnd(MOps.Root, Mask);
275
Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask)
276
: Builder.CreateIsNotNull(And);
277
Value *Zext = Builder.CreateZExt(Cmp, I.getType());
278
I.replaceAllUsesWith(Zext);
279
++NumAnyOrAllBitsSet;
280
return true;
281
}
282
283
// Try to recognize below function as popcount intrinsic.
284
// This is the "best" algorithm from
285
// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
286
// Also used in TargetLowering::expandCTPOP().
287
//
288
// int popcount(unsigned int i) {
289
// i = i - ((i >> 1) & 0x55555555);
290
// i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
291
// i = ((i + (i >> 4)) & 0x0F0F0F0F);
292
// return (i * 0x01010101) >> 24;
293
// }
294
static bool tryToRecognizePopCount(Instruction &I) {
295
if (I.getOpcode() != Instruction::LShr)
296
return false;
297
298
Type *Ty = I.getType();
299
if (!Ty->isIntOrIntVectorTy())
300
return false;
301
302
unsigned Len = Ty->getScalarSizeInBits();
303
// FIXME: fix Len == 8 and other irregular type lengths.
304
if (!(Len <= 128 && Len > 8 && Len % 8 == 0))
305
return false;
306
307
APInt Mask55 = APInt::getSplat(Len, APInt(8, 0x55));
308
APInt Mask33 = APInt::getSplat(Len, APInt(8, 0x33));
309
APInt Mask0F = APInt::getSplat(Len, APInt(8, 0x0F));
310
APInt Mask01 = APInt::getSplat(Len, APInt(8, 0x01));
311
APInt MaskShift = APInt(Len, Len - 8);
312
313
Value *Op0 = I.getOperand(0);
314
Value *Op1 = I.getOperand(1);
315
Value *MulOp0;
316
// Matching "(i * 0x01010101...) >> 24".
317
if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) &&
318
match(Op1, m_SpecificInt(MaskShift))) {
319
Value *ShiftOp0;
320
// Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
321
if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)),
322
m_Deferred(ShiftOp0)),
323
m_SpecificInt(Mask0F)))) {
324
Value *AndOp0;
325
// Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)".
326
if (match(ShiftOp0,
327
m_c_Add(m_And(m_Value(AndOp0), m_SpecificInt(Mask33)),
328
m_And(m_LShr(m_Deferred(AndOp0), m_SpecificInt(2)),
329
m_SpecificInt(Mask33))))) {
330
Value *Root, *SubOp1;
331
// Matching "i - ((i >> 1) & 0x55555555...)".
332
if (match(AndOp0, m_Sub(m_Value(Root), m_Value(SubOp1))) &&
333
match(SubOp1, m_And(m_LShr(m_Specific(Root), m_SpecificInt(1)),
334
m_SpecificInt(Mask55)))) {
335
LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
336
IRBuilder<> Builder(&I);
337
Function *Func = Intrinsic::getDeclaration(
338
I.getModule(), Intrinsic::ctpop, I.getType());
339
I.replaceAllUsesWith(Builder.CreateCall(Func, {Root}));
340
++NumPopCountRecognized;
341
return true;
342
}
343
}
344
}
345
}
346
347
return false;
348
}
349
350
/// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), providing C1 and
351
/// C2 saturate the value of the fp conversion. The transform is not reversable
352
/// as the fptosi.sat is more defined than the input - all values produce a
353
/// valid value for the fptosi.sat, where as some produce poison for original
354
/// that were out of range of the integer conversion. The reversed pattern may
355
/// use fmax and fmin instead. As we cannot directly reverse the transform, and
356
/// it is not always profitable, we make it conditional on the cost being
357
/// reported as lower by TTI.
358
static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
359
// Look for min(max(fptosi, converting to fptosi_sat.
360
Value *In;
361
const APInt *MinC, *MaxC;
362
if (!match(&I, m_SMax(m_OneUse(m_SMin(m_OneUse(m_FPToSI(m_Value(In))),
363
m_APInt(MinC))),
364
m_APInt(MaxC))) &&
365
!match(&I, m_SMin(m_OneUse(m_SMax(m_OneUse(m_FPToSI(m_Value(In))),
366
m_APInt(MaxC))),
367
m_APInt(MinC))))
368
return false;
369
370
// Check that the constants clamp a saturate.
371
if (!(*MinC + 1).isPowerOf2() || -*MaxC != *MinC + 1)
372
return false;
373
374
Type *IntTy = I.getType();
375
Type *FpTy = In->getType();
376
Type *SatTy =
377
IntegerType::get(IntTy->getContext(), (*MinC + 1).exactLogBase2() + 1);
378
if (auto *VecTy = dyn_cast<VectorType>(IntTy))
379
SatTy = VectorType::get(SatTy, VecTy->getElementCount());
380
381
// Get the cost of the intrinsic, and check that against the cost of
382
// fptosi+smin+smax
383
InstructionCost SatCost = TTI.getIntrinsicInstrCost(
384
IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}),
385
TTI::TCK_RecipThroughput);
386
SatCost += TTI.getCastInstrCost(Instruction::SExt, IntTy, SatTy,
387
TTI::CastContextHint::None,
388
TTI::TCK_RecipThroughput);
389
390
InstructionCost MinMaxCost = TTI.getCastInstrCost(
391
Instruction::FPToSI, IntTy, FpTy, TTI::CastContextHint::None,
392
TTI::TCK_RecipThroughput);
393
MinMaxCost += TTI.getIntrinsicInstrCost(
394
IntrinsicCostAttributes(Intrinsic::smin, IntTy, {IntTy}),
395
TTI::TCK_RecipThroughput);
396
MinMaxCost += TTI.getIntrinsicInstrCost(
397
IntrinsicCostAttributes(Intrinsic::smax, IntTy, {IntTy}),
398
TTI::TCK_RecipThroughput);
399
400
if (SatCost >= MinMaxCost)
401
return false;
402
403
IRBuilder<> Builder(&I);
404
Function *Fn = Intrinsic::getDeclaration(I.getModule(), Intrinsic::fptosi_sat,
405
{SatTy, FpTy});
406
Value *Sat = Builder.CreateCall(Fn, In);
407
I.replaceAllUsesWith(Builder.CreateSExt(Sat, IntTy));
408
return true;
409
}
410
411
/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
412
/// pessimistic codegen that has to account for setting errno and can enable
413
/// vectorization.
414
static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI,
415
TargetLibraryInfo &TLI, AssumptionCache &AC,
416
DominatorTree &DT) {
417
418
Module *M = Call->getModule();
419
420
// If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
421
// (because NNAN or the operand arg must not be less than -0.0) and (2) we
422
// would not end up lowering to a libcall anyway (which could change the value
423
// of errno), then:
424
// (1) errno won't be set.
425
// (2) it is safe to convert this to an intrinsic call.
426
Type *Ty = Call->getType();
427
Value *Arg = Call->getArgOperand(0);
428
if (TTI.haveFastSqrt(Ty) &&
429
(Call->hasNoNaNs() ||
430
cannotBeOrderedLessThanZero(
431
Arg, 0,
432
SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) {
433
IRBuilder<> Builder(Call);
434
IRBuilderBase::FastMathFlagGuard Guard(Builder);
435
Builder.setFastMathFlags(Call->getFastMathFlags());
436
437
Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
438
Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt");
439
Call->replaceAllUsesWith(NewSqrt);
440
441
// Explicitly erase the old call because a call with side effects is not
442
// trivially dead.
443
Call->eraseFromParent();
444
return true;
445
}
446
447
return false;
448
}
449
450
// Check if this array of constants represents a cttz table.
451
// Iterate over the elements from \p Table by trying to find/match all
452
// the numbers from 0 to \p InputBits that should represent cttz results.
453
static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
454
uint64_t Shift, uint64_t InputBits) {
455
unsigned Length = Table.getNumElements();
456
if (Length < InputBits || Length > InputBits * 2)
457
return false;
458
459
APInt Mask = APInt::getBitsSetFrom(InputBits, Shift);
460
unsigned Matched = 0;
461
462
for (unsigned i = 0; i < Length; i++) {
463
uint64_t Element = Table.getElementAsInteger(i);
464
if (Element >= InputBits)
465
continue;
466
467
// Check if \p Element matches a concrete answer. It could fail for some
468
// elements that are never accessed, so we keep iterating over each element
469
// from the table. The number of matched elements should be equal to the
470
// number of potential right answers which is \p InputBits actually.
471
if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i)
472
Matched++;
473
}
474
475
return Matched == InputBits;
476
}
477
478
// Try to recognize table-based ctz implementation.
479
// E.g., an example in C (for more cases please see the llvm/tests):
480
// int f(unsigned x) {
481
// static const char table[32] =
482
// {0, 1, 28, 2, 29, 14, 24, 3, 30,
483
// 22, 20, 15, 25, 17, 4, 8, 31, 27,
484
// 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};
485
// return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27];
486
// }
487
// this can be lowered to `cttz` instruction.
488
// There is also a special case when the element is 0.
489
//
490
// Here are some examples or LLVM IR for a 64-bit target:
491
//
492
// CASE 1:
493
// %sub = sub i32 0, %x
494
// %and = and i32 %sub, %x
495
// %mul = mul i32 %and, 125613361
496
// %shr = lshr i32 %mul, 27
497
// %idxprom = zext i32 %shr to i64
498
// %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0,
499
// i64 %idxprom
500
// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
501
//
502
// CASE 2:
503
// %sub = sub i32 0, %x
504
// %and = and i32 %sub, %x
505
// %mul = mul i32 %and, 72416175
506
// %shr = lshr i32 %mul, 26
507
// %idxprom = zext i32 %shr to i64
508
// %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table,
509
// i64 0, i64 %idxprom
510
// %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
511
//
512
// CASE 3:
513
// %sub = sub i32 0, %x
514
// %and = and i32 %sub, %x
515
// %mul = mul i32 %and, 81224991
516
// %shr = lshr i32 %mul, 27
517
// %idxprom = zext i32 %shr to i64
518
// %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table,
519
// i64 0, i64 %idxprom
520
// %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
521
//
522
// CASE 4:
523
// %sub = sub i64 0, %x
524
// %and = and i64 %sub, %x
525
// %mul = mul i64 %and, 283881067100198605
526
// %shr = lshr i64 %mul, 58
527
// %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0,
528
// i64 %shr
529
// %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
530
//
531
// All this can be lowered to @llvm.cttz.i32/64 intrinsic.
532
static bool tryToRecognizeTableBasedCttz(Instruction &I) {
533
LoadInst *LI = dyn_cast<LoadInst>(&I);
534
if (!LI)
535
return false;
536
537
Type *AccessType = LI->getType();
538
if (!AccessType->isIntegerTy())
539
return false;
540
541
GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand());
542
if (!GEP || !GEP->isInBounds() || GEP->getNumIndices() != 2)
543
return false;
544
545
if (!GEP->getSourceElementType()->isArrayTy())
546
return false;
547
548
uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements();
549
if (ArraySize != 32 && ArraySize != 64)
550
return false;
551
552
GlobalVariable *GVTable = dyn_cast<GlobalVariable>(GEP->getPointerOperand());
553
if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
554
return false;
555
556
ConstantDataArray *ConstData =
557
dyn_cast<ConstantDataArray>(GVTable->getInitializer());
558
if (!ConstData)
559
return false;
560
561
if (!match(GEP->idx_begin()->get(), m_ZeroInt()))
562
return false;
563
564
Value *Idx2 = std::next(GEP->idx_begin())->get();
565
Value *X1;
566
uint64_t MulConst, ShiftConst;
567
// FIXME: 64-bit targets have `i64` type for the GEP index, so this match will
568
// probably fail for other (e.g. 32-bit) targets.
569
if (!match(Idx2, m_ZExtOrSelf(
570
m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)),
571
m_ConstantInt(MulConst)),
572
m_ConstantInt(ShiftConst)))))
573
return false;
574
575
unsigned InputBits = X1->getType()->getScalarSizeInBits();
576
if (InputBits != 32 && InputBits != 64)
577
return false;
578
579
// Shift should extract top 5..7 bits.
580
if (InputBits - Log2_32(InputBits) != ShiftConst &&
581
InputBits - Log2_32(InputBits) - 1 != ShiftConst)
582
return false;
583
584
if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits))
585
return false;
586
587
auto ZeroTableElem = ConstData->getElementAsInteger(0);
588
bool DefinedForZero = ZeroTableElem == InputBits;
589
590
IRBuilder<> B(LI);
591
ConstantInt *BoolConst = B.getInt1(!DefinedForZero);
592
Type *XType = X1->getType();
593
auto Cttz = B.CreateIntrinsic(Intrinsic::cttz, {XType}, {X1, BoolConst});
594
Value *ZExtOrTrunc = nullptr;
595
596
if (DefinedForZero) {
597
ZExtOrTrunc = B.CreateZExtOrTrunc(Cttz, AccessType);
598
} else {
599
// If the value in elem 0 isn't the same as InputBits, we still want to
600
// produce the value from the table.
601
auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0));
602
auto Select =
603
B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz);
604
605
// NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target
606
// it should be handled as: `cttz(x) & (typeSize - 1)`.
607
608
ZExtOrTrunc = B.CreateZExtOrTrunc(Select, AccessType);
609
}
610
611
LI->replaceAllUsesWith(ZExtOrTrunc);
612
613
return true;
614
}
615
616
/// This is used by foldLoadsRecursive() to capture a Root Load node which is
617
/// of type or(load, load) and recursively build the wide load. Also capture the
618
/// shift amount, zero extend type and loadSize.
619
struct LoadOps {
620
LoadInst *Root = nullptr;
621
LoadInst *RootInsert = nullptr;
622
bool FoundRoot = false;
623
uint64_t LoadSize = 0;
624
const APInt *Shift = nullptr;
625
Type *ZextType;
626
AAMDNodes AATags;
627
};
628
629
// Identify and Merge consecutive loads recursively which is of the form
630
// (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
631
// (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
632
static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
633
AliasAnalysis &AA) {
634
const APInt *ShAmt2 = nullptr;
635
Value *X;
636
Instruction *L1, *L2;
637
638
// Go to the last node with loads.
639
if (match(V, m_OneUse(m_c_Or(
640
m_Value(X),
641
m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
642
m_APInt(ShAmt2)))))) ||
643
match(V, m_OneUse(m_Or(m_Value(X),
644
m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
645
if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
646
// Avoid Partial chain merge.
647
return false;
648
} else
649
return false;
650
651
// Check if the pattern has loads
652
LoadInst *LI1 = LOps.Root;
653
const APInt *ShAmt1 = LOps.Shift;
654
if (LOps.FoundRoot == false &&
655
(match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
656
match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
657
m_APInt(ShAmt1)))))) {
658
LI1 = dyn_cast<LoadInst>(L1);
659
}
660
LoadInst *LI2 = dyn_cast<LoadInst>(L2);
661
662
// Check if loads are same, atomic, volatile and having same address space.
663
if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() ||
664
LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace())
665
return false;
666
667
// Check if Loads come from same BB.
668
if (LI1->getParent() != LI2->getParent())
669
return false;
670
671
// Find the data layout
672
bool IsBigEndian = DL.isBigEndian();
673
674
// Check if loads are consecutive and same size.
675
Value *Load1Ptr = LI1->getPointerOperand();
676
APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
677
Load1Ptr =
678
Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset1,
679
/* AllowNonInbounds */ true);
680
681
Value *Load2Ptr = LI2->getPointerOperand();
682
APInt Offset2(DL.getIndexTypeSizeInBits(Load2Ptr->getType()), 0);
683
Load2Ptr =
684
Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset2,
685
/* AllowNonInbounds */ true);
686
687
// Verify if both loads have same base pointers and load sizes are same.
688
uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits();
689
uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits();
690
if (Load1Ptr != Load2Ptr || LoadSize1 != LoadSize2)
691
return false;
692
693
// Support Loadsizes greater or equal to 8bits and only power of 2.
694
if (LoadSize1 < 8 || !isPowerOf2_64(LoadSize1))
695
return false;
696
697
// Alias Analysis to check for stores b/w the loads.
698
LoadInst *Start = LOps.FoundRoot ? LOps.RootInsert : LI1, *End = LI2;
699
MemoryLocation Loc;
700
if (!Start->comesBefore(End)) {
701
std::swap(Start, End);
702
Loc = MemoryLocation::get(End);
703
if (LOps.FoundRoot)
704
Loc = Loc.getWithNewSize(LOps.LoadSize);
705
} else
706
Loc = MemoryLocation::get(End);
707
unsigned NumScanned = 0;
708
for (Instruction &Inst :
709
make_range(Start->getIterator(), End->getIterator())) {
710
if (Inst.mayWriteToMemory() && isModSet(AA.getModRefInfo(&Inst, Loc)))
711
return false;
712
713
// Ignore debug info so that's not counted against MaxInstrsToScan.
714
// Otherwise debug info could affect codegen.
715
if (!isa<DbgInfoIntrinsic>(Inst) && ++NumScanned > MaxInstrsToScan)
716
return false;
717
}
718
719
// Make sure Load with lower Offset is at LI1
720
bool Reverse = false;
721
if (Offset2.slt(Offset1)) {
722
std::swap(LI1, LI2);
723
std::swap(ShAmt1, ShAmt2);
724
std::swap(Offset1, Offset2);
725
std::swap(Load1Ptr, Load2Ptr);
726
std::swap(LoadSize1, LoadSize2);
727
Reverse = true;
728
}
729
730
// Big endian swap the shifts
731
if (IsBigEndian)
732
std::swap(ShAmt1, ShAmt2);
733
734
// Find Shifts values.
735
uint64_t Shift1 = 0, Shift2 = 0;
736
if (ShAmt1)
737
Shift1 = ShAmt1->getZExtValue();
738
if (ShAmt2)
739
Shift2 = ShAmt2->getZExtValue();
740
741
// First load is always LI1. This is where we put the new load.
742
// Use the merged load size available from LI1 for forward loads.
743
if (LOps.FoundRoot) {
744
if (!Reverse)
745
LoadSize1 = LOps.LoadSize;
746
else
747
LoadSize2 = LOps.LoadSize;
748
}
749
750
// Verify if shift amount and load index aligns and verifies that loads
751
// are consecutive.
752
uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
753
uint64_t PrevSize =
754
DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
755
if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
756
return false;
757
758
// Update LOps
759
AAMDNodes AATags1 = LOps.AATags;
760
AAMDNodes AATags2 = LI2->getAAMetadata();
761
if (LOps.FoundRoot == false) {
762
LOps.FoundRoot = true;
763
AATags1 = LI1->getAAMetadata();
764
}
765
LOps.LoadSize = LoadSize1 + LoadSize2;
766
LOps.RootInsert = Start;
767
768
// Concatenate the AATags of the Merged Loads.
769
LOps.AATags = AATags1.concat(AATags2);
770
771
LOps.Root = LI1;
772
LOps.Shift = ShAmt1;
773
LOps.ZextType = X->getType();
774
return true;
775
}
776
777
// For a given BB instruction, evaluate all loads in the chain that form a
778
// pattern which suggests that the loads can be combined. The one and only use
779
// of the loads is to form a wider load.
780
static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
781
TargetTransformInfo &TTI, AliasAnalysis &AA,
782
const DominatorTree &DT) {
783
// Only consider load chains of scalar values.
784
if (isa<VectorType>(I.getType()))
785
return false;
786
787
LoadOps LOps;
788
if (!foldLoadsRecursive(&I, LOps, DL, AA) || !LOps.FoundRoot)
789
return false;
790
791
IRBuilder<> Builder(&I);
792
LoadInst *NewLoad = nullptr, *LI1 = LOps.Root;
793
794
IntegerType *WiderType = IntegerType::get(I.getContext(), LOps.LoadSize);
795
// TTI based checks if we want to proceed with wider load
796
bool Allowed = TTI.isTypeLegal(WiderType);
797
if (!Allowed)
798
return false;
799
800
unsigned AS = LI1->getPointerAddressSpace();
801
unsigned Fast = 0;
802
Allowed = TTI.allowsMisalignedMemoryAccesses(I.getContext(), LOps.LoadSize,
803
AS, LI1->getAlign(), &Fast);
804
if (!Allowed || !Fast)
805
return false;
806
807
// Get the Index and Ptr for the new GEP.
808
Value *Load1Ptr = LI1->getPointerOperand();
809
Builder.SetInsertPoint(LOps.RootInsert);
810
if (!DT.dominates(Load1Ptr, LOps.RootInsert)) {
811
APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
812
Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets(
813
DL, Offset1, /* AllowNonInbounds */ true);
814
Load1Ptr = Builder.CreatePtrAdd(Load1Ptr, Builder.getInt(Offset1));
815
}
816
// Generate wider load.
817
NewLoad = Builder.CreateAlignedLoad(WiderType, Load1Ptr, LI1->getAlign(),
818
LI1->isVolatile(), "");
819
NewLoad->takeName(LI1);
820
// Set the New Load AATags Metadata.
821
if (LOps.AATags)
822
NewLoad->setAAMetadata(LOps.AATags);
823
824
Value *NewOp = NewLoad;
825
// Check if zero extend needed.
826
if (LOps.ZextType)
827
NewOp = Builder.CreateZExt(NewOp, LOps.ZextType);
828
829
// Check if shift needed. We need to shift with the amount of load1
830
// shift if not zero.
831
if (LOps.Shift)
832
NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift));
833
I.replaceAllUsesWith(NewOp);
834
835
return true;
836
}
837
838
// Calculate GEP Stride and accumulated const ModOffset. Return Stride and
839
// ModOffset
840
static std::pair<APInt, APInt>
841
getStrideAndModOffsetOfGEP(Value *PtrOp, const DataLayout &DL) {
842
unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
843
std::optional<APInt> Stride;
844
APInt ModOffset(BW, 0);
845
// Return a minimum gep stride, greatest common divisor of consective gep
846
// index scales(c.f. Bézout's identity).
847
while (auto *GEP = dyn_cast<GEPOperator>(PtrOp)) {
848
MapVector<Value *, APInt> VarOffsets;
849
if (!GEP->collectOffset(DL, BW, VarOffsets, ModOffset))
850
break;
851
852
for (auto [V, Scale] : VarOffsets) {
853
// Only keep a power of two factor for non-inbounds
854
if (!GEP->isInBounds())
855
Scale = APInt::getOneBitSet(Scale.getBitWidth(), Scale.countr_zero());
856
857
if (!Stride)
858
Stride = Scale;
859
else
860
Stride = APIntOps::GreatestCommonDivisor(*Stride, Scale);
861
}
862
863
PtrOp = GEP->getPointerOperand();
864
}
865
866
// Check whether pointer arrives back at Global Variable via at least one GEP.
867
// Even if it doesn't, we can check by alignment.
868
if (!isa<GlobalVariable>(PtrOp) || !Stride)
869
return {APInt(BW, 1), APInt(BW, 0)};
870
871
// In consideration of signed GEP indices, non-negligible offset become
872
// remainder of division by minimum GEP stride.
873
ModOffset = ModOffset.srem(*Stride);
874
if (ModOffset.isNegative())
875
ModOffset += *Stride;
876
877
return {*Stride, ModOffset};
878
}
879
880
/// If C is a constant patterned array and all valid loaded results for given
881
/// alignment are same to a constant, return that constant.
882
static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
883
auto *LI = dyn_cast<LoadInst>(&I);
884
if (!LI || LI->isVolatile())
885
return false;
886
887
// We can only fold the load if it is from a constant global with definitive
888
// initializer. Skip expensive logic if this is not the case.
889
auto *PtrOp = LI->getPointerOperand();
890
auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(PtrOp));
891
if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
892
return false;
893
894
// Bail for large initializers in excess of 4K to avoid too many scans.
895
Constant *C = GV->getInitializer();
896
uint64_t GVSize = DL.getTypeAllocSize(C->getType());
897
if (!GVSize || 4096 < GVSize)
898
return false;
899
900
Type *LoadTy = LI->getType();
901
unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
902
auto [Stride, ConstOffset] = getStrideAndModOffsetOfGEP(PtrOp, DL);
903
904
// Any possible offset could be multiple of GEP stride. And any valid
905
// offset is multiple of load alignment, so checking only multiples of bigger
906
// one is sufficient to say results' equality.
907
if (auto LA = LI->getAlign();
908
LA <= GV->getAlign().valueOrOne() && Stride.getZExtValue() < LA.value()) {
909
ConstOffset = APInt(BW, 0);
910
Stride = APInt(BW, LA.value());
911
}
912
913
Constant *Ca = ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL);
914
if (!Ca)
915
return false;
916
917
unsigned E = GVSize - DL.getTypeStoreSize(LoadTy);
918
for (; ConstOffset.getZExtValue() <= E; ConstOffset += Stride)
919
if (Ca != ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL))
920
return false;
921
922
I.replaceAllUsesWith(Ca);
923
924
return true;
925
}
926
927
namespace {
/// Expands a strcmp/strncmp call in which exactly one string argument is a
/// constant and the number of bytes to compare is a small constant into an
/// explicit chain of byte comparisons (see optimizeStrNCmp/inlineCompare).
class StrNCmpInliner {
public:
  StrNCmpInliner(CallInst *CI, LibFunc Func, DomTreeUpdater *DTU,
                 const DataLayout &DL)
      : CI(CI), Func(Func), DTU(DTU), DL(DL) {}

  /// Check whether the call qualifies and, if so, inline it.
  /// \returns true if the call was replaced (and erased).
  bool optimizeStrNCmp();

private:
  /// Emit the byte-by-byte comparison of \p LHS against the first \p N bytes
  /// of the constant \p RHS, replacing and erasing CI.
  void inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Swapped);

  CallInst *CI;        // The strcmp/strncmp call being rewritten.
  LibFunc Func;        // Which of the two library functions CI calls.
  DomTreeUpdater *DTU; // Receives the CFG edge updates; may be null.
  const DataLayout &DL;
};

} // namespace
946
947
/// First we normalize calls to strncmp/strcmp to the form of
948
/// compare(s1, s2, N), which means comparing first N bytes of s1 and s2
949
/// (without considering '\0').
950
///
951
/// Examples:
952
///
953
/// \code
954
/// strncmp(s, "a", 3) -> compare(s, "a", 2)
955
/// strncmp(s, "abc", 3) -> compare(s, "abc", 3)
956
/// strncmp(s, "a\0b", 3) -> compare(s, "a\0b", 2)
957
/// strcmp(s, "a") -> compare(s, "a", 2)
958
///
959
/// char s2[] = {'a'}
960
/// strncmp(s, s2, 3) -> compare(s, s2, 3)
961
///
962
/// char s2[] = {'a', 'b', 'c', 'd'}
963
/// strncmp(s, s2, 3) -> compare(s, s2, 3)
964
/// \endcode
965
///
966
/// We only handle cases where N and exactly one of s1 and s2 are constant.
967
/// Cases that s1 and s2 are both constant are already handled by the
968
/// instcombine pass.
969
///
970
/// We do not handle cases where N > StrNCmpInlineThreshold.
971
///
972
/// We also do not handles cases where N < 2, which are already
973
/// handled by the instcombine pass.
974
///
975
bool StrNCmpInliner::optimizeStrNCmp() {
976
if (StrNCmpInlineThreshold < 2)
977
return false;
978
979
if (!isOnlyUsedInZeroComparison(CI))
980
return false;
981
982
Value *Str1P = CI->getArgOperand(0);
983
Value *Str2P = CI->getArgOperand(1);
984
// Should be handled elsewhere.
985
if (Str1P == Str2P)
986
return false;
987
988
StringRef Str1, Str2;
989
bool HasStr1 = getConstantStringInfo(Str1P, Str1, /*TrimAtNul=*/false);
990
bool HasStr2 = getConstantStringInfo(Str2P, Str2, /*TrimAtNul=*/false);
991
if (HasStr1 == HasStr2)
992
return false;
993
994
// Note that '\0' and characters after it are not trimmed.
995
StringRef Str = HasStr1 ? Str1 : Str2;
996
Value *StrP = HasStr1 ? Str2P : Str1P;
997
998
size_t Idx = Str.find('\0');
999
uint64_t N = Idx == StringRef::npos ? UINT64_MAX : Idx + 1;
1000
if (Func == LibFunc_strncmp) {
1001
if (auto *ConstInt = dyn_cast<ConstantInt>(CI->getArgOperand(2)))
1002
N = std::min(N, ConstInt->getZExtValue());
1003
else
1004
return false;
1005
}
1006
// Now N means how many bytes we need to compare at most.
1007
if (N > Str.size() || N < 2 || N > StrNCmpInlineThreshold)
1008
return false;
1009
1010
// Cases where StrP has two or more dereferenceable bytes might be better
1011
// optimized elsewhere.
1012
bool CanBeNull = false, CanBeFreed = false;
1013
if (StrP->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed) > 1)
1014
return false;
1015
inlineCompare(StrP, Str, N, HasStr1);
1016
return true;
1017
}
1018
1019
/// Convert
///
/// \code
///   ret = compare(s1, s2, N)
/// \endcode
///
/// into
///
/// \code
///   ret = (int)s1[0] - (int)s2[0]
///   if (ret != 0)
///     goto NE
///   ...
///   ret = (int)s1[N-2] - (int)s2[N-2]
///   if (ret != 0)
///     goto NE
///   ret = (int)s1[N-1] - (int)s2[N-1]
///   NE:
/// \endcode
///
/// CFG before and after the transformation:
///
/// \code
///   (before)
///   BBCI
///
///   (after)
///   BBCI -> BBSubs[0] (sub,icmp) --NE-> BBNE -> BBTail
///               |                        ^
///               E                        |
///               |                        |
///           BBSubs[1] (sub,icmp) --NE----+
///              ...                       |
///           BBSubs[N-1] (sub) -----------+
/// \endcode
///
/// \param LHS     Non-constant string pointer; byte i is loaded from LHS + i.
/// \param RHS     Constant string; only its first \p N bytes are read.
/// \param N       Number of bytes to compare.
/// \param Swapped When true, each difference is computed as RHS - LHS instead
///                of LHS - RHS. optimizeStrNCmp passes true when the constant
///                string was the call's first argument.
void StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
                                   bool Swapped) {
  auto &Ctx = CI->getContext();
  IRBuilder<> B(Ctx);

  // Move CI and everything after it into BBTail. BBCI keeps a branch to
  // BBTail, which is redirected below.
  BasicBlock *BBCI = CI->getParent();
  BasicBlock *BBTail =
      SplitBlock(BBCI, CI, DTU, nullptr, nullptr, BBCI->getName() + ".tail");

  // One block per compared byte, plus the common exit block BBNE.
  SmallVector<BasicBlock *> BBSubs;
  for (uint64_t I = 0; I < N; ++I)
    BBSubs.push_back(
        BasicBlock::Create(Ctx, "sub_" + Twine(I), BBCI->getParent(), BBTail));
  BasicBlock *BBNE = BasicBlock::Create(Ctx, "ne", BBCI->getParent(), BBTail);

  // Enter the comparison chain instead of jumping straight to BBTail.
  cast<BranchInst>(BBCI->getTerminator())->setSuccessor(0, BBSubs[0]);

  // BBNE merges the difference from whichever byte block exits the chain.
  B.SetInsertPoint(BBNE);
  PHINode *Phi = B.CreatePHI(CI->getType(), N);
  B.CreateBr(BBTail);

  Value *Base = LHS;
  for (uint64_t i = 0; i < N; ++i) {
    B.SetInsertPoint(BBSubs[i]);
    // Load byte i and zero-extend it. Both operands are treated as unsigned
    // bytes (zext / unsigned char), like the C library's comparison.
    Value *VL =
        B.CreateZExt(B.CreateLoad(B.getInt8Ty(),
                                  B.CreateInBoundsPtrAdd(Base, B.getInt64(i))),
                     CI->getType());
    Value *VR =
        ConstantInt::get(CI->getType(), static_cast<unsigned char>(RHS[i]));
    Value *Sub = Swapped ? B.CreateSub(VR, VL) : B.CreateSub(VL, VR);
    // A non-zero difference decides the result early; otherwise move on to
    // the next byte. The last byte's difference is the result either way.
    if (i < N - 1)
      B.CreateCondBr(B.CreateICmpNE(Sub, ConstantInt::get(CI->getType(), 0)),
                     BBNE, BBSubs[i + 1]);
    else
      B.CreateBr(BBNE);

    Phi->addIncoming(Sub, BBSubs[i]);
  }

  // Only the sign/zero-ness of this value matters: optimizeStrNCmp requires
  // that CI is only used in comparisons with zero.
  CI->replaceAllUsesWith(Phi);
  CI->eraseFromParent();

  // Report the new edges, and the removal of the direct BBCI -> BBTail edge.
  if (DTU) {
    SmallVector<DominatorTree::UpdateType, 8> Updates;
    Updates.push_back({DominatorTree::Insert, BBCI, BBSubs[0]});
    for (uint64_t i = 0; i < N; ++i) {
      if (i < N - 1)
        Updates.push_back({DominatorTree::Insert, BBSubs[i], BBSubs[i + 1]});
      Updates.push_back({DominatorTree::Insert, BBSubs[i], BBNE});
    }
    Updates.push_back({DominatorTree::Insert, BBNE, BBTail});
    Updates.push_back({DominatorTree::Delete, BBCI, BBTail});
    DTU->applyUpdates(Updates);
  }
}
1109
1110
/// Convert memchr with a small constant string into a switch
1111
static bool foldMemChr(CallInst *Call, DomTreeUpdater *DTU,
1112
const DataLayout &DL) {
1113
if (isa<Constant>(Call->getArgOperand(1)))
1114
return false;
1115
1116
StringRef Str;
1117
Value *Base = Call->getArgOperand(0);
1118
if (!getConstantStringInfo(Base, Str, /*TrimAtNul=*/false))
1119
return false;
1120
1121
uint64_t N = Str.size();
1122
if (auto *ConstInt = dyn_cast<ConstantInt>(Call->getArgOperand(2))) {
1123
uint64_t Val = ConstInt->getZExtValue();
1124
// Ignore the case that n is larger than the size of string.
1125
if (Val > N)
1126
return false;
1127
N = Val;
1128
} else
1129
return false;
1130
1131
if (N > MemChrInlineThreshold)
1132
return false;
1133
1134
BasicBlock *BB = Call->getParent();
1135
BasicBlock *BBNext = SplitBlock(BB, Call, DTU);
1136
IRBuilder<> IRB(BB);
1137
IntegerType *ByteTy = IRB.getInt8Ty();
1138
BB->getTerminator()->eraseFromParent();
1139
SwitchInst *SI = IRB.CreateSwitch(
1140
IRB.CreateTrunc(Call->getArgOperand(1), ByteTy), BBNext, N);
1141
Type *IndexTy = DL.getIndexType(Call->getType());
1142
SmallVector<DominatorTree::UpdateType, 8> Updates;
1143
1144
BasicBlock *BBSuccess = BasicBlock::Create(
1145
Call->getContext(), "memchr.success", BB->getParent(), BBNext);
1146
IRB.SetInsertPoint(BBSuccess);
1147
PHINode *IndexPHI = IRB.CreatePHI(IndexTy, N, "memchr.idx");
1148
Value *FirstOccursLocation = IRB.CreateInBoundsPtrAdd(Base, IndexPHI);
1149
IRB.CreateBr(BBNext);
1150
if (DTU)
1151
Updates.push_back({DominatorTree::Insert, BBSuccess, BBNext});
1152
1153
SmallPtrSet<ConstantInt *, 4> Cases;
1154
for (uint64_t I = 0; I < N; ++I) {
1155
ConstantInt *CaseVal = ConstantInt::get(ByteTy, Str[I]);
1156
if (!Cases.insert(CaseVal).second)
1157
continue;
1158
1159
BasicBlock *BBCase = BasicBlock::Create(Call->getContext(), "memchr.case",
1160
BB->getParent(), BBSuccess);
1161
SI->addCase(CaseVal, BBCase);
1162
IRB.SetInsertPoint(BBCase);
1163
IndexPHI->addIncoming(ConstantInt::get(IndexTy, I), BBCase);
1164
IRB.CreateBr(BBSuccess);
1165
if (DTU) {
1166
Updates.push_back({DominatorTree::Insert, BB, BBCase});
1167
Updates.push_back({DominatorTree::Insert, BBCase, BBSuccess});
1168
}
1169
}
1170
1171
PHINode *PHI =
1172
PHINode::Create(Call->getType(), 2, Call->getName(), BBNext->begin());
1173
PHI->addIncoming(Constant::getNullValue(Call->getType()), BB);
1174
PHI->addIncoming(FirstOccursLocation, BBSuccess);
1175
1176
Call->replaceAllUsesWith(PHI);
1177
Call->eraseFromParent();
1178
1179
if (DTU)
1180
DTU->applyUpdates(Updates);
1181
1182
return true;
1183
}
1184
1185
static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
1186
TargetLibraryInfo &TLI, AssumptionCache &AC,
1187
DominatorTree &DT, const DataLayout &DL,
1188
bool &MadeCFGChange) {
1189
1190
auto *CI = dyn_cast<CallInst>(&I);
1191
if (!CI || CI->isNoBuiltin())
1192
return false;
1193
1194
Function *CalledFunc = CI->getCalledFunction();
1195
if (!CalledFunc)
1196
return false;
1197
1198
LibFunc LF;
1199
if (!TLI.getLibFunc(*CalledFunc, LF) ||
1200
!isLibFuncEmittable(CI->getModule(), &TLI, LF))
1201
return false;
1202
1203
DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy);
1204
1205
switch (LF) {
1206
case LibFunc_sqrt:
1207
case LibFunc_sqrtf:
1208
case LibFunc_sqrtl:
1209
return foldSqrt(CI, LF, TTI, TLI, AC, DT);
1210
case LibFunc_strcmp:
1211
case LibFunc_strncmp:
1212
if (StrNCmpInliner(CI, LF, &DTU, DL).optimizeStrNCmp()) {
1213
MadeCFGChange = true;
1214
return true;
1215
}
1216
break;
1217
case LibFunc_memchr:
1218
if (foldMemChr(CI, &DTU, DL)) {
1219
MadeCFGChange = true;
1220
return true;
1221
}
1222
break;
1223
default:;
1224
}
1225
return false;
1226
}
1227
1228
/// This is the entry point for folds that could be implemented in regular
1229
/// InstCombine, but they are separated because they are not expected to
1230
/// occur frequently and/or have more than a constant-length pattern match.
1231
static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
1232
TargetTransformInfo &TTI,
1233
TargetLibraryInfo &TLI, AliasAnalysis &AA,
1234
AssumptionCache &AC, bool &MadeCFGChange) {
1235
bool MadeChange = false;
1236
for (BasicBlock &BB : F) {
1237
// Ignore unreachable basic blocks.
1238
if (!DT.isReachableFromEntry(&BB))
1239
continue;
1240
1241
const DataLayout &DL = F.getDataLayout();
1242
1243
// Walk the block backwards for efficiency. We're matching a chain of
1244
// use->defs, so we're more likely to succeed by starting from the bottom.
1245
// Also, we want to avoid matching partial patterns.
1246
// TODO: It would be more efficient if we removed dead instructions
1247
// iteratively in this loop rather than waiting until the end.
1248
for (Instruction &I : make_early_inc_range(llvm::reverse(BB))) {
1249
MadeChange |= foldAnyOrAllBitsSet(I);
1250
MadeChange |= foldGuardedFunnelShift(I, DT);
1251
MadeChange |= tryToRecognizePopCount(I);
1252
MadeChange |= tryToFPToSat(I, TTI);
1253
MadeChange |= tryToRecognizeTableBasedCttz(I);
1254
MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
1255
MadeChange |= foldPatternedLoads(I, DL);
1256
// NOTE: This function introduces erasing of the instruction `I`, so it
1257
// needs to be called at the end of this sequence, otherwise we may make
1258
// bugs.
1259
MadeChange |= foldLibCalls(I, TTI, TLI, AC, DT, DL, MadeCFGChange);
1260
}
1261
}
1262
1263
// We're done with transforms, so remove dead instructions.
1264
if (MadeChange)
1265
for (BasicBlock &BB : F)
1266
SimplifyInstructionsInBlock(&BB);
1267
1268
return MadeChange;
1269
}
1270
1271
/// This is the entry point for all transforms. Pass manager differences are
1272
/// handled in the callers of this function.
1273
static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
1274
TargetLibraryInfo &TLI, DominatorTree &DT,
1275
AliasAnalysis &AA, bool &MadeCFGChange) {
1276
bool MadeChange = false;
1277
const DataLayout &DL = F.getDataLayout();
1278
TruncInstCombine TIC(AC, TLI, DL, DT);
1279
MadeChange |= TIC.run(F);
1280
MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, MadeCFGChange);
1281
return MadeChange;
1282
}
1283
1284
PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &AC = AM.getResult<AssumptionAnalysis>(F);
  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &AA = AM.getResult<AAManager>(F);

  bool MadeCFGChange = false;
  const bool Changed = runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange);

  // Nothing was touched, so every analysis is still valid.
  if (!Changed)
    return PreservedAnalyses::all();

  // When the CFG changed, only the dominator tree (kept current through
  // DomTreeUpdater) survives; otherwise all CFG-only analyses do.
  PreservedAnalyses PA;
  if (!MadeCFGChange)
    PA.preserveSet<CFGAnalyses>();
  else
    PA.preserve<DominatorTreeAnalysis>();
  return PA;
}
1304
1305