// Source: GitHub repository freebsd/freebsd-src
// Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp
1
#include "llvm/Transforms/Utils/LoopConstrainer.h"
2
#include "llvm/Analysis/LoopInfo.h"
3
#include "llvm/Analysis/ScalarEvolution.h"
4
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
5
#include "llvm/IR/Dominators.h"
6
#include "llvm/Transforms/Utils/Cloning.h"
7
#include "llvm/Transforms/Utils/LoopSimplify.h"
8
#include "llvm/Transforms/Utils/LoopUtils.h"
9
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
10
11
using namespace llvm;

/// Metadata kind attached to the latch terminator of every loop clone created
/// by LoopConstrainer::cloneLoop. LoopStructure::parseLoopStructure refuses to
/// parse a loop whose latch terminator carries it. The pointer itself is const
/// because the tag is only ever read.
static const char *const ClonedLoopTag = "loop_constrainer.loop.clone";

#define DEBUG_TYPE "loop-constrainer"
16
17
/// Given a loop with an deccreasing induction variable, is it possible to
18
/// safely calculate the bounds of a new loop using the given Predicate.
19
static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
20
const SCEV *Step, ICmpInst::Predicate Pred,
21
unsigned LatchBrExitIdx, Loop *L,
22
ScalarEvolution &SE) {
23
if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
24
Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
25
return false;
26
27
if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
28
return false;
29
30
assert(SE.isKnownNegative(Step) && "expecting negative step");
31
32
LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n");
33
LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
34
LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
35
LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
36
LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
37
LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
38
39
bool IsSigned = ICmpInst::isSigned(Pred);
40
// The predicate that we need to check that the induction variable lies
41
// within bounds.
42
ICmpInst::Predicate BoundPred =
43
IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
44
45
auto StartLG = SE.applyLoopGuards(Start, L);
46
auto BoundLG = SE.applyLoopGuards(BoundSCEV, L);
47
48
if (LatchBrExitIdx == 1)
49
return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG);
50
51
assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1");
52
53
const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
54
unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
55
APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth)
56
: APInt::getMinValue(BitWidth);
57
const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);
58
59
const SCEV *MinusOne =
60
SE.getMinusSCEV(BoundLG, SE.getOne(BoundLG->getType()));
61
62
return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, MinusOne) &&
63
SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit);
64
}
65
66
/// Given a loop with an increasing induction variable, is it possible to
67
/// safely calculate the bounds of a new loop using the given Predicate.
68
static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
69
const SCEV *Step, ICmpInst::Predicate Pred,
70
unsigned LatchBrExitIdx, Loop *L,
71
ScalarEvolution &SE) {
72
if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
73
Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
74
return false;
75
76
if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
77
return false;
78
79
LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n");
80
LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
81
LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
82
LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
83
LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
84
LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
85
86
bool IsSigned = ICmpInst::isSigned(Pred);
87
// The predicate that we need to check that the induction variable lies
88
// within bounds.
89
ICmpInst::Predicate BoundPred =
90
IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
91
92
auto StartLG = SE.applyLoopGuards(Start, L);
93
auto BoundLG = SE.applyLoopGuards(BoundSCEV, L);
94
95
if (LatchBrExitIdx == 1)
96
return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG);
97
98
assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
99
100
const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
101
unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
102
APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth)
103
: APInt::getMaxValue(BitWidth);
104
const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);
105
106
return (SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG,
107
SE.getAddExpr(BoundLG, Step)) &&
108
SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit));
109
}
110
111
/// Returns estimate for max latch taken count of the loop of the narrowest
112
/// available type. If the latch block has such estimate, it is returned.
113
/// Otherwise, we use max exit count of whole loop (that is potentially of wider
114
/// type than latch check itself), which is still better than no estimate.
115
static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
116
const Loop &L) {
117
const SCEV *FromBlock =
118
SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
119
if (isa<SCEVCouldNotCompute>(FromBlock))
120
return SE.getSymbolicMaxBackedgeTakenCount(&L);
121
return FromBlock;
122
}
123
124
std::optional<LoopStructure>
125
LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
126
bool AllowUnsignedLatchCond,
127
const char *&FailureReason) {
128
if (!L.isLoopSimplifyForm()) {
129
FailureReason = "loop not in LoopSimplify form";
130
return std::nullopt;
131
}
132
133
BasicBlock *Latch = L.getLoopLatch();
134
assert(Latch && "Simplified loops only have one latch!");
135
136
if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
137
FailureReason = "loop has already been cloned";
138
return std::nullopt;
139
}
140
141
if (!L.isLoopExiting(Latch)) {
142
FailureReason = "no loop latch";
143
return std::nullopt;
144
}
145
146
BasicBlock *Header = L.getHeader();
147
BasicBlock *Preheader = L.getLoopPreheader();
148
if (!Preheader) {
149
FailureReason = "no preheader";
150
return std::nullopt;
151
}
152
153
BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
154
if (!LatchBr || LatchBr->isUnconditional()) {
155
FailureReason = "latch terminator not conditional branch";
156
return std::nullopt;
157
}
158
159
unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;
160
161
ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
162
if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
163
FailureReason = "latch terminator branch not conditional on integral icmp";
164
return std::nullopt;
165
}
166
167
const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
168
if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
169
FailureReason = "could not compute latch count";
170
return std::nullopt;
171
}
172
assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
173
ScalarEvolution::LoopInvariant &&
174
"loop variant exit count doesn't make sense!");
175
176
ICmpInst::Predicate Pred = ICI->getPredicate();
177
Value *LeftValue = ICI->getOperand(0);
178
const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
179
IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());
180
181
Value *RightValue = ICI->getOperand(1);
182
const SCEV *RightSCEV = SE.getSCEV(RightValue);
183
184
// We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
185
if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
186
if (isa<SCEVAddRecExpr>(RightSCEV)) {
187
std::swap(LeftSCEV, RightSCEV);
188
std::swap(LeftValue, RightValue);
189
Pred = ICmpInst::getSwappedPredicate(Pred);
190
} else {
191
FailureReason = "no add recurrences in the icmp";
192
return std::nullopt;
193
}
194
}
195
196
auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
197
if (AR->getNoWrapFlags(SCEV::FlagNSW))
198
return true;
199
200
IntegerType *Ty = cast<IntegerType>(AR->getType());
201
IntegerType *WideTy =
202
IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
203
204
const SCEVAddRecExpr *ExtendAfterOp =
205
dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
206
if (ExtendAfterOp) {
207
const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
208
const SCEV *ExtendedStep =
209
SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
210
211
bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
212
ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
213
214
if (NoSignedWrap)
215
return true;
216
}
217
218
// We may have proved this when computing the sign extension above.
219
return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
220
};
221
222
// `ICI` is interpreted as taking the backedge if the *next* value of the
223
// induction variable satisfies some constraint.
224
225
const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
226
if (IndVarBase->getLoop() != &L) {
227
FailureReason = "LHS in cmp is not an AddRec for this loop";
228
return std::nullopt;
229
}
230
if (!IndVarBase->isAffine()) {
231
FailureReason = "LHS in icmp not induction variable";
232
return std::nullopt;
233
}
234
const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
235
if (!isa<SCEVConstant>(StepRec)) {
236
FailureReason = "LHS in icmp not induction variable";
237
return std::nullopt;
238
}
239
ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();
240
241
if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
242
FailureReason = "LHS in icmp needs nsw for equality predicates";
243
return std::nullopt;
244
}
245
246
assert(!StepCI->isZero() && "Zero step?");
247
bool IsIncreasing = !StepCI->isNegative();
248
bool IsSignedPredicate;
249
const SCEV *StartNext = IndVarBase->getStart();
250
const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
251
const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
252
const SCEV *Step = SE.getSCEV(StepCI);
253
254
const SCEV *FixedRightSCEV = nullptr;
255
256
// If RightValue resides within loop (but still being loop invariant),
257
// regenerate it as preheader.
258
if (auto *I = dyn_cast<Instruction>(RightValue))
259
if (L.contains(I->getParent()))
260
FixedRightSCEV = RightSCEV;
261
262
if (IsIncreasing) {
263
bool DecreasedRightValueByOne = false;
264
if (StepCI->isOne()) {
265
// Try to turn eq/ne predicates to those we can work with.
266
if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
267
// while (++i != len) { while (++i < len) {
268
// ... ---> ...
269
// } }
270
// If both parts are known non-negative, it is profitable to use
271
// unsigned comparison in increasing loop. This allows us to make the
272
// comparison check against "RightSCEV + 1" more optimistic.
273
if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
274
isKnownNonNegativeInLoop(RightSCEV, &L, SE))
275
Pred = ICmpInst::ICMP_ULT;
276
else
277
Pred = ICmpInst::ICMP_SLT;
278
else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
279
// while (true) { while (true) {
280
// if (++i == len) ---> if (++i > len - 1)
281
// break; break;
282
// ... ...
283
// } }
284
if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
285
cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
286
Pred = ICmpInst::ICMP_UGT;
287
RightSCEV =
288
SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
289
DecreasedRightValueByOne = true;
290
} else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
291
Pred = ICmpInst::ICMP_SGT;
292
RightSCEV =
293
SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
294
DecreasedRightValueByOne = true;
295
}
296
}
297
}
298
299
bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
300
bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
301
bool FoundExpectedPred =
302
(LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);
303
304
if (!FoundExpectedPred) {
305
FailureReason = "expected icmp slt semantically, found something else";
306
return std::nullopt;
307
}
308
309
IsSignedPredicate = ICmpInst::isSigned(Pred);
310
if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
311
FailureReason = "unsigned latch conditions are explicitly prohibited";
312
return std::nullopt;
313
}
314
315
if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
316
LatchBrExitIdx, &L, SE)) {
317
FailureReason = "Unsafe loop bounds";
318
return std::nullopt;
319
}
320
if (LatchBrExitIdx == 0) {
321
// We need to increase the right value unless we have already decreased
322
// it virtually when we replaced EQ with SGT.
323
if (!DecreasedRightValueByOne)
324
FixedRightSCEV =
325
SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
326
} else {
327
assert(!DecreasedRightValueByOne &&
328
"Right value can be decreased only for LatchBrExitIdx == 0!");
329
}
330
} else {
331
bool IncreasedRightValueByOne = false;
332
if (StepCI->isMinusOne()) {
333
// Try to turn eq/ne predicates to those we can work with.
334
if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
335
// while (--i != len) { while (--i > len) {
336
// ... ---> ...
337
// } }
338
// We intentionally don't turn the predicate into UGT even if we know
339
// that both operands are non-negative, because it will only pessimize
340
// our check against "RightSCEV - 1".
341
Pred = ICmpInst::ICMP_SGT;
342
else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
343
// while (true) { while (true) {
344
// if (--i == len) ---> if (--i < len + 1)
345
// break; break;
346
// ... ...
347
// } }
348
if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
349
cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
350
Pred = ICmpInst::ICMP_ULT;
351
RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
352
IncreasedRightValueByOne = true;
353
} else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
354
Pred = ICmpInst::ICMP_SLT;
355
RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
356
IncreasedRightValueByOne = true;
357
}
358
}
359
}
360
361
bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
362
bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
363
364
bool FoundExpectedPred =
365
(GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);
366
367
if (!FoundExpectedPred) {
368
FailureReason = "expected icmp sgt semantically, found something else";
369
return std::nullopt;
370
}
371
372
IsSignedPredicate =
373
Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;
374
375
if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
376
FailureReason = "unsigned latch conditions are explicitly prohibited";
377
return std::nullopt;
378
}
379
380
if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
381
LatchBrExitIdx, &L, SE)) {
382
FailureReason = "Unsafe bounds";
383
return std::nullopt;
384
}
385
386
if (LatchBrExitIdx == 0) {
387
// We need to decrease the right value unless we have already increased
388
// it virtually when we replaced EQ with SLT.
389
if (!IncreasedRightValueByOne)
390
FixedRightSCEV =
391
SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
392
} else {
393
assert(!IncreasedRightValueByOne &&
394
"Right value can be increased only for LatchBrExitIdx == 0!");
395
}
396
}
397
BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
398
399
assert(!L.contains(LatchExit) && "expected an exit block!");
400
const DataLayout &DL = Preheader->getDataLayout();
401
SCEVExpander Expander(SE, DL, "loop-constrainer");
402
Instruction *Ins = Preheader->getTerminator();
403
404
if (FixedRightSCEV)
405
RightValue =
406
Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);
407
408
Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
409
IndVarStartV->setName("indvar.start");
410
411
LoopStructure Result;
412
413
Result.Tag = "main";
414
Result.Header = Header;
415
Result.Latch = Latch;
416
Result.LatchBr = LatchBr;
417
Result.LatchExit = LatchExit;
418
Result.LatchBrExitIdx = LatchBrExitIdx;
419
Result.IndVarStart = IndVarStartV;
420
Result.IndVarStep = StepCI;
421
Result.IndVarBase = LeftValue;
422
Result.IndVarIncreasing = IsIncreasing;
423
Result.LoopExitAt = RightValue;
424
Result.IsSignedPredicate = IsSignedPredicate;
425
Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());
426
427
FailureReason = nullptr;
428
429
return Result;
430
}
431
432
// Add metadata to the loop L to disable loop optimizations. Callers need to
433
// confirm that optimizing loop L is not beneficial.
434
static void DisableAllLoopOptsOnLoop(Loop &L) {
435
// We do not care about any existing loopID related metadata for L, since we
436
// are setting all loop metadata to false.
437
LLVMContext &Context = L.getHeader()->getContext();
438
// Reserve first location for self reference to the LoopID metadata node.
439
MDNode *Dummy = MDNode::get(Context, {});
440
MDNode *DisableUnroll = MDNode::get(
441
Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
442
Metadata *FalseVal =
443
ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
444
MDNode *DisableVectorize = MDNode::get(
445
Context,
446
{MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
447
MDNode *DisableLICMVersioning = MDNode::get(
448
Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
449
MDNode *DisableDistribution = MDNode::get(
450
Context,
451
{MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
452
MDNode *NewLoopID =
453
MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
454
DisableLICMVersioning, DisableDistribution});
455
// Set operand 0 to refer to the loop id itself.
456
NewLoopID->replaceOperandWith(0, NewLoopID);
457
L.setLoopID(NewLoopID);
458
}
459
460
/// Captures everything needed to later constrain \p L: the enclosing function
/// and context (both taken from the loop header), the analyses, the callback
/// used to announce newly created loops, the range type \p T, the parsed main
/// loop structure \p LS and the requested sub-ranges \p SR. No IR is touched
/// here.
LoopConstrainer::LoopConstrainer(
    Loop &L, LoopInfo &LI, function_ref<void(Loop *, bool)> LPMAddNewLoop,
    const LoopStructure &LS, ScalarEvolution &SE, DominatorTree &DT, Type *T,
    SubRanges SR)
    : F(*L.getHeader()->getParent()),
      Ctx(L.getHeader()->getContext()),
      SE(SE),
      DT(DT),
      LI(LI),
      LPMAddNewLoop(LPMAddNewLoop),
      OriginalLoop(L),
      RangeTy(T),
      MainLoopStructure(LS),
      SR(SR) {}
467
468
void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
469
const char *Tag) const {
470
for (BasicBlock *BB : OriginalLoop.getBlocks()) {
471
BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
472
Result.Blocks.push_back(Clone);
473
Result.Map[BB] = Clone;
474
}
475
476
auto GetClonedValue = [&Result](Value *V) {
477
assert(V && "null values not in domain!");
478
auto It = Result.Map.find(V);
479
if (It == Result.Map.end())
480
return V;
481
return static_cast<Value *>(It->second);
482
};
483
484
auto *ClonedLatch =
485
cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
486
ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
487
MDNode::get(Ctx, {}));
488
489
Result.Structure = MainLoopStructure.map(GetClonedValue);
490
Result.Structure.Tag = Tag;
491
492
for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
493
BasicBlock *ClonedBB = Result.Blocks[i];
494
BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
495
496
assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
497
498
for (Instruction &I : *ClonedBB)
499
RemapInstruction(&I, Result.Map,
500
RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
501
502
// Exit blocks will now have one more predecessor and their PHI nodes need
503
// to be edited to reflect that. No phi nodes need to be introduced because
504
// the loop is in LCSSA.
505
506
for (auto *SBB : successors(OriginalBB)) {
507
if (OriginalLoop.contains(SBB))
508
continue; // not an exit block
509
510
for (PHINode &PN : SBB->phis()) {
511
Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
512
PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
513
SE.forgetValue(&PN);
514
}
515
}
516
}
517
}
518
519
/// Rewrites the loop described by \p LS so that it stops iterating once its
/// induction variable reaches \p ExitSubloopAt, and hands control over to
/// \p ContinuationBlock when iterations of the original range remain.
///
/// We start with a loop with a single latch:
///
///   preheader -> header -> ... -> latch
///   latch -> header          (backedge)
///   latch -> original exit
///
/// We change the control flow to look like:
///
///   preheader      -> header            (if the loop may be entered)
///   preheader      -> .pseudo.exit      (otherwise)
///   latch          -> header            (backedge, now tested against
///                                        ExitSubloopAt)
///   latch          -> .exit.selector
///   .exit.selector -> .pseudo.exit      (if iterations are left w.r.t. the
///                                        original bound LS.LoopExitAt)
///   .exit.selector -> original exit     (otherwise)
///   .pseudo.exit   -> ContinuationBlock
///
/// The returned RewrittenRangeInfo holds the two new blocks, one PHI in
/// `.pseudo.exit` per header PHI carrying its latest value, and a PHI with the
/// final induction variable value.
LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
    BasicBlock *ContinuationBlock) const {
  RewrittenRangeInfo RRI;

  // Both new blocks are placed right after the latch in the function layout.
  BasicBlock *InsertBefore = LS.Latch->getNextNode();
  RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
                                        &F, InsertBefore);
  RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
                                      InsertBefore);

  BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
  const bool Increasing = LS.IndVarIncreasing;
  const bool IsSignedPredicate = LS.IsSignedPredicate;

  IRBuilder<> B(PreheaderJump);

  // Brings a value to RangeTy at the builder's current insertion point, sign-
  // or zero-extending according to the signedness of the latch predicate.
  auto NoopOrExt = [&](Value *V) -> Value * {
    if (V->getType() == RangeTy)
      return V;
    if (IsSignedPredicate)
      return B.CreateSExt(V, RangeTy, "wide." + V->getName());
    return B.CreateZExt(V, RangeTy, "wide." + V->getName());
  };

  // "Still within the range" predicate for this direction and signedness.
  CmpInst::Predicate Pred;
  if (Increasing)
    Pred = IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
  else
    Pred = IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;

  // EnterLoopCond - is it okay to start executing this `LS'?
  Value *IndVarStart = NoopOrExt(LS.IndVarStart);
  Value *EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);

  B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
  PreheaderJump->eraseFromParent();

  // Re-target the latch exit edge and replace the latch condition with a test
  // against ExitSubloopAt.
  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
  B.SetInsertPoint(LS.LatchBr);
  Value *IndVarBase = NoopOrExt(LS.IndVarBase);
  Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);

  // The branch condition must be inverted when successor 0 is the exit.
  Value *CondForBranch = TakeBackedgeLoopCond;
  if (LS.LatchBrExitIdx != 1)
    CondForBranch = B.CreateNot(TakeBackedgeLoopCond);

  LS.LatchBr->setCondition(CondForBranch);

  B.SetInsertPoint(RRI.ExitSelector);

  // IterationsLeft - are there any more iterations left, given the original
  // upper bound on the induction variable? If not, we branch to the "real"
  // exit.
  Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
  Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
  B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);

  BranchInst *BranchToContinuation =
      BranchInst::Create(ContinuationBlock, RRI.PseudoExit);

  // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
  // each of the PHI nodes in the loop header. This feeds into the initial
  // value of the same PHI nodes if/when we continue execution.
  for (PHINode &HeaderPHI : LS.Header->phis()) {
    PHINode *Copy =
        PHINode::Create(HeaderPHI.getType(), 2, HeaderPHI.getName() + ".copy",
                        BranchToContinuation->getIterator());

    Copy->addIncoming(HeaderPHI.getIncomingValueForBlock(Preheader),
                      Preheader);
    Copy->addIncoming(HeaderPHI.getIncomingValueForBlock(LS.Latch),
                      RRI.ExitSelector);
    RRI.PHIValuesAtPseudoExit.push_back(Copy);
  }

  // Final induction variable value: the start value when the loop was skipped,
  // the latch value otherwise.
  RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
                                  BranchToContinuation->getIterator());
  RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
  RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);

  // The latch exit now has a branch from `RRI.ExitSelector' instead of
  // `LS.Latch'. The PHI nodes need to be updated to reflect that.
  LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);

  return RRI;
}
672
673
void LoopConstrainer::rewriteIncomingValuesForPHIs(
674
LoopStructure &LS, BasicBlock *ContinuationBlock,
675
const LoopConstrainer::RewrittenRangeInfo &RRI) const {
676
unsigned PHIIndex = 0;
677
for (PHINode &PN : LS.Header->phis())
678
PN.setIncomingValueForBlock(ContinuationBlock,
679
RRI.PHIValuesAtPseudoExit[PHIIndex++]);
680
681
LS.IndVarStart = RRI.IndVarEnd;
682
}
683
684
/// Creates a block named \p Tag immediately before the header of \p LS that
/// branches unconditionally to that header, and rewrites the header PHIs so
/// that uses of \p OldPreheader as an incoming block refer to the new block.
/// Returns the new block.
BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
                                             BasicBlock *OldPreheader,
                                             const char *Tag) const {
  BasicBlock *NewPreheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
  BranchInst::Create(LS.Header, NewPreheader);

  LS.Header->replacePhiUsesWith(OldPreheader, NewPreheader);
  return NewPreheader;
}
694
695
/// Registers every block in \p BBs with the parent of the original loop. Does
/// nothing when the original loop has no parent loop.
void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
  if (Loop *Parent = OriginalLoop.getParentLoop()) {
    for (BasicBlock *BB : BBs)
      Parent->addBasicBlockToLoop(BB, LI);
  }
}
703
704
/// Recursively builds the LoopInfo structure for a clone of \p Original.
///
/// A fresh loop is allocated and attached to \p Parent (or registered as a
/// top-level loop when \p Parent is null), reported through LPMAddNewLoop
/// together with \p IsSubloop, and populated with the clones (looked up in
/// \p VM) of the blocks that belong directly to \p Original. Subloops are
/// handled the same way with the new loop as their parent. Returns the new
/// loop.
Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
                                                 ValueToValueMapTy &VM,
                                                 bool IsSubloop) {
  Loop *NewLoop = LI.AllocateLoop();
  if (Parent)
    Parent->addChildLoop(NewLoop);
  else
    LI.addTopLevelLoop(NewLoop);
  LPMAddNewLoop(NewLoop, IsSubloop);

  // Add all of the blocks in Original to the new loop; blocks owned by a
  // subloop are skipped here and added when that subloop is cloned.
  for (BasicBlock *BB : Original->blocks()) {
    if (LI.getLoopFor(BB) != Original)
      continue;
    NewLoop->addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);
  }

  // Add all of the subloops to the new loop.
  for (Loop *Child : *Original)
    createClonedLoopStructure(Child, NewLoop, VM, /* IsSubloop */ true);

  return NewLoop;
}
725
726
bool LoopConstrainer::run() {
727
BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
728
assert(Preheader != nullptr && "precondition!");
729
730
OriginalPreheader = Preheader;
731
MainLoopPreheader = Preheader;
732
bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
733
bool Increasing = MainLoopStructure.IndVarIncreasing;
734
IntegerType *IVTy = cast<IntegerType>(RangeTy);
735
736
SCEVExpander Expander(SE, F.getDataLayout(), "loop-constrainer");
737
Instruction *InsertPt = OriginalPreheader->getTerminator();
738
739
// It would have been better to make `PreLoop' and `PostLoop'
740
// `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
741
// constructor.
742
ClonedLoop PreLoop, PostLoop;
743
bool NeedsPreLoop =
744
Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
745
bool NeedsPostLoop =
746
Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();
747
748
Value *ExitPreLoopAt = nullptr;
749
Value *ExitMainLoopAt = nullptr;
750
const SCEVConstant *MinusOneS =
751
cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));
752
753
if (NeedsPreLoop) {
754
const SCEV *ExitPreLoopAtSCEV = nullptr;
755
756
if (Increasing)
757
ExitPreLoopAtSCEV = *SR.LowLimit;
758
else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
759
IsSignedPredicate))
760
ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
761
else {
762
LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
763
<< "preloop exit limit. HighLimit = "
764
<< *(*SR.HighLimit) << "\n");
765
return false;
766
}
767
768
if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
769
LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
770
<< " preloop exit limit " << *ExitPreLoopAtSCEV
771
<< " at block " << InsertPt->getParent()->getName()
772
<< "\n");
773
return false;
774
}
775
776
ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
777
ExitPreLoopAt->setName("exit.preloop.at");
778
}
779
780
if (NeedsPostLoop) {
781
const SCEV *ExitMainLoopAtSCEV = nullptr;
782
783
if (Increasing)
784
ExitMainLoopAtSCEV = *SR.HighLimit;
785
else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
786
IsSignedPredicate))
787
ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
788
else {
789
LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
790
<< "mainloop exit limit. LowLimit = "
791
<< *(*SR.LowLimit) << "\n");
792
return false;
793
}
794
795
if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
796
LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
797
<< " main loop exit limit " << *ExitMainLoopAtSCEV
798
<< " at block " << InsertPt->getParent()->getName()
799
<< "\n");
800
return false;
801
}
802
803
ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
804
ExitMainLoopAt->setName("exit.mainloop.at");
805
}
806
807
// We clone these ahead of time so that we don't have to deal with changing
808
// and temporarily invalid IR as we transform the loops.
809
if (NeedsPreLoop)
810
cloneLoop(PreLoop, "preloop");
811
if (NeedsPostLoop)
812
cloneLoop(PostLoop, "postloop");
813
814
RewrittenRangeInfo PreLoopRRI;
815
816
if (NeedsPreLoop) {
817
Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
818
PreLoop.Structure.Header);
819
820
MainLoopPreheader =
821
createPreheader(MainLoopStructure, Preheader, "mainloop");
822
PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
823
ExitPreLoopAt, MainLoopPreheader);
824
rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
825
PreLoopRRI);
826
}
827
828
BasicBlock *PostLoopPreheader = nullptr;
829
RewrittenRangeInfo PostLoopRRI;
830
831
if (NeedsPostLoop) {
832
PostLoopPreheader =
833
createPreheader(PostLoop.Structure, Preheader, "postloop");
834
PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
835
ExitMainLoopAt, PostLoopPreheader);
836
rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
837
PostLoopRRI);
838
}
839
840
BasicBlock *NewMainLoopPreheader =
841
MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
842
BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit,
843
PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit,
844
PostLoopRRI.ExitSelector, NewMainLoopPreheader};
845
846
// Some of the above may be nullptr, filter them out before passing to
847
// addToParentLoopIfNeeded.
848
auto NewBlocksEnd =
849
std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);
850
851
addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));
852
853
DT.recalculate(F);
854
855
// We need to first add all the pre and post loop blocks into the loop
856
// structures (as part of createClonedLoopStructure), and then update the
857
// LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
858
// LI when LoopSimplifyForm is generated.
859
Loop *PreL = nullptr, *PostL = nullptr;
860
if (!PreLoop.Blocks.empty()) {
861
PreL = createClonedLoopStructure(&OriginalLoop,
862
OriginalLoop.getParentLoop(), PreLoop.Map,
863
/* IsSubLoop */ false);
864
}
865
866
if (!PostLoop.Blocks.empty()) {
867
PostL =
868
createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
869
PostLoop.Map, /* IsSubLoop */ false);
870
}
871
872
// This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
873
auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
874
formLCSSARecursively(*L, DT, &LI, &SE);
875
simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
876
// Pre/post loops are slow paths, we do not need to perform any loop
877
// optimizations on them.
878
if (!IsOriginalLoop)
879
DisableAllLoopOptsOnLoop(*L);
880
};
881
if (PreL)
882
CanonicalizeLoop(PreL, false);
883
if (PostL)
884
CanonicalizeLoop(PostL, false);
885
CanonicalizeLoop(&OriginalLoop, true);
886
887
/// At this point:
888
/// - We've broken a "main loop" out of the loop in a way that the "main loop"
889
/// runs with the induction variable in a subset of [Begin, End).
890
/// - There is no overflow when computing "main loop" exit limit.
891
/// - Max latch taken count of the loop is limited.
892
/// It guarantees that induction variable will not overflow iterating in the
893
/// "main loop".
894
if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
895
if (IsSignedPredicate)
896
cast<BinaryOperator>(MainLoopStructure.IndVarBase)
897
->setHasNoSignedWrap(true);
898
/// TODO: support unsigned predicate.
899
/// To add NUW flag we need to prove that both operands of BO are
900
/// non-negative. E.g:
901
/// ...
902
/// %iv.next = add nsw i32 %iv, -1
903
/// %cmp = icmp ult i32 %iv.next, %n
904
/// br i1 %cmp, label %loopexit, label %loop
905
///
906
/// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
907
/// overflow, therefore NUW flag is not legal here.
908
909
return true;
910
}
911
912