Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Vectorize/EVLIndVarSimplify.cpp
213799 views
1
//===---- EVLIndVarSimplify.cpp - Optimize vectorized loops w/ EVL IV------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This pass optimizes a vectorized loop with canonical IV to using EVL-based
10
// IV if it was tail-folded by predicated EVL.
11
//
12
//===----------------------------------------------------------------------===//
13
14
#include "llvm/Transforms/Vectorize/EVLIndVarSimplify.h"
15
#include "llvm/ADT/Statistic.h"
16
#include "llvm/Analysis/IVDescriptors.h"
17
#include "llvm/Analysis/LoopInfo.h"
18
#include "llvm/Analysis/LoopPass.h"
19
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
20
#include "llvm/Analysis/ScalarEvolution.h"
21
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
22
#include "llvm/Analysis/ValueTracking.h"
23
#include "llvm/IR/IRBuilder.h"
24
#include "llvm/IR/PatternMatch.h"
25
#include "llvm/Support/CommandLine.h"
26
#include "llvm/Support/Debug.h"
27
#include "llvm/Support/MathExtras.h"
28
#include "llvm/Support/raw_ostream.h"
29
#include "llvm/Transforms/Scalar/LoopPassManager.h"
30
#include "llvm/Transforms/Utils/Local.h"
31
32
#define DEBUG_TYPE "evl-iv-simplify"
33
34
using namespace llvm;
35
36
STATISTIC(NumEliminatedCanonicalIV, "Number of canonical IVs we eliminated");
37
38
static cl::opt<bool> EnableEVLIndVarSimplify(
39
"enable-evl-indvar-simplify",
40
cl::desc("Enable EVL-based induction variable simplify Pass"), cl::Hidden,
41
cl::init(true));
42
43
namespace {
44
struct EVLIndVarSimplifyImpl {
45
ScalarEvolution &SE;
46
OptimizationRemarkEmitter *ORE = nullptr;
47
48
EVLIndVarSimplifyImpl(LoopStandardAnalysisResults &LAR,
49
OptimizationRemarkEmitter *ORE)
50
: SE(LAR.SE), ORE(ORE) {}
51
52
/// Returns true if modify the loop.
53
bool run(Loop &L);
54
};
55
} // anonymous namespace
56
57
/// Returns the constant part of vectorization factor from the induction
58
/// variable's step value SCEV expression.
59
static uint32_t getVFFromIndVar(const SCEV *Step, const Function &F) {
60
if (!Step)
61
return 0U;
62
63
// Looking for loops with IV step value in the form of `(<constant VF> x
64
// vscale)`.
65
if (const auto *Mul = dyn_cast<SCEVMulExpr>(Step)) {
66
if (Mul->getNumOperands() == 2) {
67
const SCEV *LHS = Mul->getOperand(0);
68
const SCEV *RHS = Mul->getOperand(1);
69
if (const auto *Const = dyn_cast<SCEVConstant>(LHS);
70
Const && isa<SCEVVScale>(RHS)) {
71
uint64_t V = Const->getAPInt().getLimitedValue();
72
if (llvm::isUInt<32>(V))
73
return V;
74
}
75
}
76
}
77
78
// If not, see if the vscale_range of the parent function is a fixed value,
79
// which makes the step value to be replaced by a constant.
80
if (F.hasFnAttribute(Attribute::VScaleRange))
81
if (const auto *ConstStep = dyn_cast<SCEVConstant>(Step)) {
82
APInt V = ConstStep->getAPInt().abs();
83
ConstantRange CR = llvm::getVScaleRange(&F, 64);
84
if (const APInt *Fixed = CR.getSingleElement()) {
85
V = V.zextOrTrunc(Fixed->getBitWidth());
86
uint64_t VF = V.udiv(*Fixed).getLimitedValue();
87
if (VF && llvm::isUInt<32>(VF) &&
88
// Make sure step is divisible by vscale.
89
V.urem(*Fixed).isZero())
90
return VF;
91
}
92
}
93
94
return 0U;
95
}
96
97
/// Replaces the latch exit test of an EVL tail-folded vectorized loop with a
/// comparison between the EVL-based induction variable and the total trip
/// count, then deletes the canonical induction variable if that leaves it
/// dead.
///
/// The loop is only considered when it carries the "llvm.loop.isvectorized"
/// attribute and its "llvm.loop.isvectorized.tailfoldingstyle" metadata is
/// "evl". Every bail-out before the rewrite leaves the IR untouched.
///
/// \returns true if the loop was modified.
bool EVLIndVarSimplifyImpl::run(Loop &L) {
  if (!EnableEVLIndVarSimplify)
    return false;

  if (!getBooleanLoopAttribute(&L, "llvm.loop.isvectorized"))
    return false;
  const MDOperand *EVLMD =
      findStringMetadataForLoop(&L, "llvm.loop.isvectorized.tailfoldingstyle")
          .value_or(nullptr);
  if (!EVLMD || !EVLMD->equalsStr("evl"))
    return false;

  BasicBlock *LatchBlock = L.getLoopLatch();
  ICmpInst *OrigLatchCmp = L.getLatchCmpInst();
  if (!LatchBlock || !OrigLatchCmp)
    return false;

  InductionDescriptor IVD;
  PHINode *IndVar = L.getInductionVariable(SE);
  if (!IndVar || !L.getInductionDescriptor(SE, IVD)) {
    const char *Reason = (IndVar ? "induction descriptor is not available"
                                 : "cannot recognize induction variable");
    // Note the trailing space after "because": Reason does not start with one.
    LLVM_DEBUG(dbgs() << "Cannot retrieve IV from loop " << L.getName()
                      << " because " << Reason << "\n");
    if (ORE) {
      ORE->emit([&]() {
        return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar",
                                        L.getStartLoc(), L.getHeader())
               << "Cannot retrieve IV because " << ore::NV("Reason", Reason);
      });
    }
    return false;
  }

  BasicBlock *InitBlock, *BackEdgeBlock;
  if (!L.getIncomingAndBackEdge(InitBlock, BackEdgeBlock)) {
    LLVM_DEBUG(dbgs() << "Expect unique incoming and backedge in "
                      << L.getName() << "\n");
    if (ORE) {
      ORE->emit([&]() {
        return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure",
                                        L.getStartLoc(), L.getHeader())
               << "Does not have a unique incoming and backedge";
      });
    }
    return false;
  }

  // Retrieve the loop bounds.
  std::optional<Loop::LoopBounds> Bounds = L.getBounds(SE);
  if (!Bounds) {
    LLVM_DEBUG(dbgs() << "Could not obtain the bounds for loop " << L.getName()
                      << "\n");
    if (ORE) {
      ORE->emit([&]() {
        return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure",
                                        L.getStartLoc(), L.getHeader())
               << "Could not obtain the loop bounds";
      });
    }
    return false;
  }
  Value *CanonicalIVInit = &Bounds->getInitialIVValue();
  Value *CanonicalIVFinal = &Bounds->getFinalIVValue();

  const SCEV *StepV = IVD.getStep();
  uint32_t VF = getVFFromIndVar(StepV, *L.getHeader()->getParent());
  if (!VF) {
    LLVM_DEBUG(dbgs() << "Could not infer VF from IndVar step '" << *StepV
                      << "'\n");
    if (ORE) {
      ORE->emit([&]() {
        // Render the step expression to text before handing it to the remark.
        // NOTE(review): ore::NV appears to have no SCEV overload, so passing
        // the raw `const SCEV *` would convert to bool and print "true"
        // rather than the expression.
        std::string StepStr;
        raw_string_ostream StepOS(StepStr);
        StepOS << *StepV;
        return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar",
                                        L.getStartLoc(), L.getHeader())
               << "Could not infer VF from IndVar step "
               << ore::NV("Step", StepOS.str());
      });
    }
    return false;
  }
  LLVM_DEBUG(dbgs() << "Using VF=" << VF << " for loop " << L.getName()
                    << "\n");

  // Try to find the EVL-based induction variable.
  using namespace PatternMatch;
  BasicBlock *BB = IndVar->getParent();

  Value *EVLIndVar = nullptr;
  Value *RemTC = nullptr;
  Value *TC = nullptr;
  // Matches get.vector.length(RemTC, VF, /*Scalable=*/1) and captures the
  // remaining-trip-count operand in RemTC.
  auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>(
      m_Value(RemTC), m_SpecificInt(VF),
      /*Scalable=*/m_SpecificInt(1));
  for (PHINode &PN : BB->phis()) {
    if (&PN == IndVar)
      continue;

    // Check 1: it has to contain both incoming (init) & backedge blocks
    // from IndVar.
    if (PN.getBasicBlockIndex(InitBlock) < 0 ||
        PN.getBasicBlockIndex(BackEdgeBlock) < 0)
      continue;
    // Check 2: EVL index is always increasing, thus its initial value has to
    // be equal to either the initial IV value (when the canonical IV is also
    // increasing) or the last IV value (when canonical IV is decreasing).
    Value *Init = PN.getIncomingValueForBlock(InitBlock);
    using Direction = Loop::LoopBounds::Direction;
    switch (Bounds->getDirection()) {
    case Direction::Increasing:
      if (Init != CanonicalIVInit)
        continue;
      break;
    case Direction::Decreasing:
      if (Init != CanonicalIVFinal)
        continue;
      break;
    case Direction::Unknown:
      // To be more permissive and see if either the initial or final IV value
      // matches PN's init value.
      if (Init != CanonicalIVInit && Init != CanonicalIVFinal)
        continue;
      break;
    }
    Value *RecValue = PN.getIncomingValueForBlock(BackEdgeBlock);
    assert(RecValue && "expect recurrent IndVar value");

    LLVM_DEBUG(dbgs() << "Found candidate PN of EVL-based IndVar: " << PN
                      << "\n");

    // Check 3: Pattern match to find the EVL-based index and total trip count
    // (TC). The recurrence must be `PN + [zext] get.vector.length(...)`, and
    // the intrinsic's first operand must be `TC - PN`.
    if (match(RecValue,
              m_c_Add(m_ZExtOrSelf(IntrinsicMatch), m_Specific(&PN))) &&
        match(RemTC, m_Sub(m_Value(TC), m_Specific(&PN)))) {
      EVLIndVar = RecValue;
      break;
    }
  }

  // TC may have been bound by a partially matching candidate, so EVLIndVar is
  // what actually tells us whether a full match was found.
  if (!EVLIndVar || !TC)
    return false;

  LLVM_DEBUG(dbgs() << "Using " << *EVLIndVar << " for EVL-based IndVar\n");
  if (ORE) {
    ORE->emit([&]() {
      DebugLoc DL;
      BasicBlock *Region = nullptr;
      if (auto *I = dyn_cast<Instruction>(EVLIndVar)) {
        DL = I->getDebugLoc();
        Region = I->getParent();
      } else {
        DL = L.getStartLoc();
        Region = L.getHeader();
      }
      return OptimizationRemark(DEBUG_TYPE, "UseEVLIndVar", DL, Region)
             << "Using " << ore::NV("EVLIndVar", EVLIndVar)
             << " for EVL-based IndVar";
    });
  }

  // Create an EVL-based comparison and replace the branch to use it as
  // predicate.

  // Loop::getLatchCmpInst check at the beginning of this function has ensured
  // that latch block ends in a conditional branch.
  auto *LatchBranch = cast<BranchInst>(LatchBlock->getTerminator());
  assert(LatchBranch->isConditional() &&
         "expect the loop latch to be ended with a conditional branch");
  // Keep looping while EVLIndVar != TC: use NE when the true successor is the
  // header, and EQ when the true successor leaves the loop.
  ICmpInst::Predicate Pred;
  if (LatchBranch->getSuccessor(0) == L.getHeader())
    Pred = ICmpInst::ICMP_NE;
  else
    Pred = ICmpInst::ICMP_EQ;

  IRBuilder<> Builder(OrigLatchCmp);
  auto *NewLatchCmp = Builder.CreateICmp(Pred, EVLIndVar, TC);
  OrigLatchCmp->replaceAllUsesWith(NewLatchCmp);

  // llvm::RecursivelyDeleteDeadPHINode only deletes cycles whose values are
  // not used outside the cycles. However, in this case the now-RAUW-ed
  // OrigLatchCmp will be considered a use outside the cycle while in reality
  // it's practically dead. Thus we need to remove it before calling
  // RecursivelyDeleteDeadPHINode.
  (void)RecursivelyDeleteTriviallyDeadInstructions(OrigLatchCmp);
  if (llvm::RecursivelyDeleteDeadPHINode(IndVar))
    LLVM_DEBUG(dbgs() << "Removed original IndVar\n");

  ++NumEliminatedCanonicalIV;

  return true;
}
288
289
/// New-pass-manager entry point: runs the EVL induction-variable
/// simplification on \p L.
///
/// The remark emitter is looked up as a cached function-level result only, so
/// it may be null; the implementation skips remarks in that case.
///
/// \returns the CFG analyses set as preserved when the loop was changed, and
///          all analyses preserved otherwise.
PreservedAnalyses EVLIndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &LAM,
                                             LoopStandardAnalysisResults &AR,
                                             LPMUpdater &U) {
  Function &ParentFn = *L.getHeader()->getParent();
  auto &FnProxy = LAM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR);
  OptimizationRemarkEmitter *Remarks =
      FnProxy.getCachedResult<OptimizationRemarkEmitterAnalysis>(ParentFn);

  EVLIndVarSimplifyImpl Impl(AR, Remarks);
  const bool Changed = Impl.run(L);
  if (!Changed)
    return PreservedAnalyses::all();
  return PreservedAnalyses::allInSet<CFGAnalyses>();
}
301
302