Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopTermFold.cpp
213799 views
1
//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//===----------------------------------------------------------------------===//
9
10
#include "llvm/Transforms/Scalar/LoopTermFold.h"
11
#include "llvm/ADT/Statistic.h"
12
#include "llvm/Analysis/LoopAnalysisManager.h"
13
#include "llvm/Analysis/LoopInfo.h"
14
#include "llvm/Analysis/LoopPass.h"
15
#include "llvm/Analysis/MemorySSA.h"
16
#include "llvm/Analysis/MemorySSAUpdater.h"
17
#include "llvm/Analysis/ScalarEvolution.h"
18
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
19
#include "llvm/Analysis/TargetLibraryInfo.h"
20
#include "llvm/Analysis/TargetTransformInfo.h"
21
#include "llvm/Analysis/ValueTracking.h"
22
#include "llvm/Config/llvm-config.h"
23
#include "llvm/IR/BasicBlock.h"
24
#include "llvm/IR/Dominators.h"
25
#include "llvm/IR/IRBuilder.h"
26
#include "llvm/IR/InstrTypes.h"
27
#include "llvm/IR/Instruction.h"
28
#include "llvm/IR/Instructions.h"
29
#include "llvm/IR/Type.h"
30
#include "llvm/IR/Value.h"
31
#include "llvm/InitializePasses.h"
32
#include "llvm/Pass.h"
33
#include "llvm/Support/Debug.h"
34
#include "llvm/Support/raw_ostream.h"
35
#include "llvm/Transforms/Scalar.h"
36
#include "llvm/Transforms/Utils.h"
37
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
38
#include "llvm/Transforms/Utils/Local.h"
39
#include "llvm/Transforms/Utils/LoopUtils.h"
40
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
41
#include <cassert>
42
#include <optional>
43
44
using namespace llvm;
45
46
// Debug category name for the LLVM_DEBUG output and the statistic below.
#define DEBUG_TYPE "loop-term-fold"
47
48
// Incremented once per loop whose exit test RunTermFold goes on to rewrite.
STATISTIC(NumTermFold,
49
"Number of terminating condition fold recognized and performed");
50
51
/// Decide whether the latch exit test of \p L can be re-expressed in terms of
/// a different header phi, so that the IV the test currently uses becomes
/// dead.
///
/// Requirements checked here:
///  - \p L is innermost, in loop-simplify form, and has a loop-invariant
///    backedge-taken count;
///  - the latch ends in a conditional branch on a single-use ICmp whose LHS is
///    a BinaryOperator forming a simple recurrence with a header phi, and
///    whose RHS is loop invariant;
///  - that phi is "almost dead" once the exit test no longer uses it.
/// It then scans the other header phis for an affine, non-self-wrapping
/// AddRec with a known non-zero step. The candidate's post-increment value on
/// the exiting iteration must be safe and cheap enough to expand in the
/// preheader, and adding a use of it must be poison safe. The last such phi
/// wins.
///
/// \returns std::nullopt if no fold is possible. Otherwise it returns a tuple:
///   (phi whose use in the exit test goes away,
///    phi the new exit test is based on,
///    SCEV of that phi's post-increment value at the backedge-taken count,
///    whether poison-generating flags must be dropped from that phi's
///    backedge value before the new use is added).
static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
                      const LoopInfo &LI, const TargetTransformInfo &TTI) {
  if (!L->isInnermost()) {
    LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
    return std::nullopt;
  }
  // Only inspect on simple loop structure
  if (!L->isLoopSimplifyForm()) {
    LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
    return std::nullopt;
  }

  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
    LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
    return std::nullopt;
  }

  BasicBlock *LoopLatch = L->getLoopLatch();
  BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
  if (!BI || BI->isUnconditional())
    return std::nullopt;
  auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
  if (!TermCond) {
    LLVM_DEBUG(dbgs() << "Cannot fold on branching condition that is not an "
                         "ICmpInst\n");
    return std::nullopt;
  }
  if (!TermCond->hasOneUse()) {
    LLVM_DEBUG(
        dbgs()
        << "Cannot replace terminating condition with more than one use\n");
    return std::nullopt;
  }

  BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
  Value *RHS = TermCond->getOperand(1);
  if (!LHS || !L->isLoopInvariant(RHS))
    // We could pattern match the inverse form of the icmp, but that is
    // non-canonical, and this pass is running *very* late in the pipeline.
    return std::nullopt;

  // Find the IV used by the current exit condition. On success
  // matchSimpleRecurrence sets ToFold, so it is non-null from here on.
  PHINode *ToFold = nullptr;
  Value *ToFoldStart = nullptr;
  Value *ToFoldStep = nullptr;
  if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
    return std::nullopt;

  // Ensure the simple recurrence is a part of the current loop.
  if (ToFold->getParent() != L->getHeader())
    return std::nullopt;

  // If that IV isn't dead after we rewrite the exit condition in terms of
  // another IV, there's no point in doing the transform.
  if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
    return std::nullopt;

  // Inserting instructions in the preheader has a runtime cost, scale
  // the allowed cost with the loops trip count as best we can.
  const unsigned ExpansionBudget = [&]() {
    unsigned Budget = 2 * SCEVCheapExpansionBudget;
    if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
      return std::min(Budget, SmallTC);
    if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
      return std::min(Budget, *SmallTC);
    // Unknown trip count, assume long running by default.
    return Budget;
  }();

  const SCEV *BECount = SE.getBackedgeTakenCount(L);
  const DataLayout &DL = L->getHeader()->getDataLayout();
  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");

  PHINode *ToHelpFold = nullptr;
  const SCEV *TermValueS = nullptr;
  bool MustDropPoison = false;
  auto InsertPt = L->getLoopPreheader()->getTerminator();
  for (PHINode &PN : L->getHeader()->phis()) {
    if (ToFold == &PN)
      continue;

    if (!SE.isSCEVable(PN.getType())) {
      LLVM_DEBUG(dbgs() << "IV of phi '" << PN
                        << "' is not SCEV-able, not qualified for the "
                           "terminating condition folding.\n");
      continue;
    }
    const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
    // Only speculate on affine AddRec
    if (!AddRec || !AddRec->isAffine()) {
      LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
                        << "' is not an affine add recursion, not qualified "
                           "for the terminating condition folding.\n");
      continue;
    }

    // Check that we can compute the value of AddRec on the exiting iteration
    // without soundness problems. evaluateAtIteration internally needs
    // to multiply the stride of the iteration number - which may wrap around.
    // The issue here is subtle because computing the result accounting for
    // wrap is insufficient. In order to use the result in an exit test, we
    // must also know that AddRec doesn't take the same value on any previous
    // iteration. The simplest case to consider is a candidate IV which is
    // narrower than the trip count (and thus original IV), but this can
    // also happen due to non-unit strides on the candidate IVs.
    if (!AddRec->hasNoSelfWrap() ||
        !SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))
      continue;

    const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
    const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
    if (!Expander.isSafeToExpand(TermValueSLocal)) {
      LLVM_DEBUG(
          dbgs() << "Is not safe to expand terminating value for phi node" << PN
                 << "\n");
      continue;
    }

    if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI,
                                     InsertPt)) {
      LLVM_DEBUG(
          dbgs() << "Is too expensive to expand terminating value for phi node"
                 << PN << "\n");
      continue;
    }

    // The candidate IV may have been otherwise dead and poison from the
    // very first iteration. If we can't disprove that, we can't use the IV.
    if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
      LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n");
      continue;
    }

    // The candidate IV may become poison on the last iteration. If this
    // value is not branched on, this is a well defined program. We're
    // about to add a new use to this IV, and we have to ensure we don't
    // insert UB which didn't previously exist.
    bool MustDropPoisonLocal = false;
    Instruction *PostIncV =
        cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
    if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
                                       &DT)) {
      LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN
                        << "\n");

      // If this is a complex recurrence with multiple instructions computing
      // the backedge value, we might need to strip poison flags from all of
      // them.
      if (PostIncV->getOperand(0) != &PN)
        continue;

      // In order to perform the transform, we need to drop the poison
      // generating flags on this instruction (if any).
      MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
    }

    // We pick the last legal alternate IV. We could explore choosing an
    // optimal alternate IV if we had a decent heuristic to do so.
    ToHelpFold = &PN;
    TermValueS = TermValueSLocal;
    MustDropPoison = MustDropPoisonLocal;
  }

  if (!ToHelpFold) {
    LLVM_DEBUG(dbgs() << "Cannot find other AddRec IV to help folding\n");
    return std::nullopt;
  }

  // BECount is the same (uniqued) SCEV that getBackedgeTakenCount returns.
  LLVM_DEBUG(dbgs() << "\nFound loop that can fold terminating condition\n"
                    << " BECount (SCEV): " << *BECount << "\n"
                    << " TermCond: " << *TermCond << "\n"
                    << " BranchInst: " << *BI << "\n"
                    << " ToFold: " << *ToFold << "\n"
                    << " ToHelpFold: " << *ToHelpFold << "\n");

  return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
}
229
230
static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
231
LoopInfo &LI, const TargetTransformInfo &TTI,
232
TargetLibraryInfo &TLI, MemorySSA *MSSA) {
233
std::unique_ptr<MemorySSAUpdater> MSSAU;
234
if (MSSA)
235
MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
236
237
auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
238
if (!Opt)
239
return false;
240
241
auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
242
243
NumTermFold++;
244
245
BasicBlock *LoopPreheader = L->getLoopPreheader();
246
BasicBlock *LoopLatch = L->getLoopLatch();
247
248
(void)ToFold;
249
LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
250
<< *ToFold << "\n"
251
<< "New term-cond phi-node:\n"
252
<< *ToHelpFold << "\n");
253
254
Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);
255
(void)StartValue;
256
Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);
257
258
// See comment in canFoldTermCondOfLoop on why this is sufficient.
259
if (MustDrop)
260
cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();
261
262
// SCEVExpander for both use in preheader and latch
263
const DataLayout &DL = L->getHeader()->getDataLayout();
264
SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
265
266
assert(Expander.isSafeToExpand(TermValueS) &&
267
"Terminating value was checked safe in canFoldTerminatingCondition");
268
269
// Create new terminating value at loop preheader
270
Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),
271
LoopPreheader->getTerminator());
272
273
LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
274
<< *StartValue << "\n"
275
<< "Terminating value of new term-cond phi-node:\n"
276
<< *TermValue << "\n");
277
278
// Create new terminating condition at loop latch
279
BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
280
ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
281
IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
282
Value *NewTermCond =
283
LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
284
"lsr_fold_term_cond.replaced_term_cond");
285
// Swap successors to exit loop body if IV equals to new TermValue
286
if (BI->getSuccessor(0) == L->getHeader())
287
BI->swapSuccessors();
288
289
LLVM_DEBUG(dbgs() << "Old term-cond:\n"
290
<< *OldTermCond << "\n"
291
<< "New term-cond:\n"
292
<< *NewTermCond << "\n");
293
294
BI->setCondition(NewTermCond);
295
296
Expander.clear();
297
OldTermCond->eraseFromParent();
298
DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
299
return true;
300
}
301
302
namespace {
303
304
class LoopTermFold : public LoopPass {
305
public:
306
static char ID; // Pass ID, replacement for typeid
307
308
LoopTermFold();
309
310
private:
311
bool runOnLoop(Loop *L, LPPassManager &LPM) override;
312
void getAnalysisUsage(AnalysisUsage &AU) const override;
313
};
314
315
} // end anonymous namespace
316
317
/// Constructs the pass and makes sure it is initialized in the global
/// PassRegistry.
LoopTermFold::LoopTermFold() : LoopPass(ID) {
  auto &Registry = *PassRegistry::getPassRegistry();
  initializeLoopTermFoldPass(Registry);
}
320
321
/// Declare the analyses this pass depends on and the ones it leaves valid.
void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
  // Required inputs. Their relative order is intentionally unchanged.
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addRequiredID(LoopSimplifyID);
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequired<ScalarEvolutionWrapperPass>();
  AU.addRequired<TargetLibraryInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();

  // Preserved results. MemorySSA is not required (runOnLoop only uses it when
  // already available), but when present it is updated through a
  // MemorySSAUpdater while dead phis are removed.
  AU.addPreserved<LoopInfoWrapperPass>();
  AU.addPreservedID(LoopSimplifyID);
  AU.addPreserved<DominatorTreeWrapperPass>();
  AU.addPreserved<ScalarEvolutionWrapperPass>();
  AU.addPreserved<MemorySSAWrapperPass>();
}
334
335
/// Legacy-PM entry point: collects the analyses RunTermFold needs for \p L
/// and runs it. MemorySSA is forwarded only if it is already available.
bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
  if (skipLoop(L))
    return false;

  auto &Fn = *L->getHeader()->getParent();
  auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(Fn);
  auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(Fn);

  auto *MSSAWrapper = getAnalysisIfAvailable<MemorySSAWrapperPass>();
  MemorySSA *MSSA = MSSAWrapper ? &MSSAWrapper->getMSSA() : nullptr;

  return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA);
}
352
353
/// New-PM entry point: runs the fold on \p L using the standard loop analysis
/// results in \p AR.
PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM,
                                        LoopStandardAnalysisResults &AR,
                                        LPMUpdater &) {
  const bool Changed =
      RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA);
  if (!Changed)
    return PreservedAnalyses::all();

  // Report the usual loop-pass preserved set, plus MemorySSA when it was
  // available (RunTermFold updates it in that case).
  PreservedAnalyses Preserved = getLoopPassPreservedAnalyses();
  if (AR.MSSA != nullptr)
    Preserved.preserve<MemorySSAAnalysis>();
  return Preserved;
}
364
365
char LoopTermFold::ID = 0;
366
367
INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
368
false, false)
369
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
370
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
371
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
372
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
373
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
374
INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
375
false, false)
376
377
/// Create a new heap-allocated instance of the legacy LoopTermFold pass.
/// Ownership passes to the caller.
Pass *llvm::createLoopTermFoldPass() {
  auto *FoldPass = new LoopTermFold();
  return FoldPass;
}
378
379