Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
35294 views
1
//===- SPIRVConvergenceRegionAnalysis.cpp ----------------------*- C++ -*--===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// The analysis determines the convergence region for each basic block of
10
// the module, and provides a tree-like structure describing the region
11
// hierarchy.
12
//
13
//===----------------------------------------------------------------------===//
14
15
#include "SPIRVConvergenceRegionAnalysis.h"
16
#include "llvm/Analysis/LoopInfo.h"
17
#include "llvm/IR/Dominators.h"
18
#include "llvm/IR/IntrinsicInst.h"
19
#include "llvm/InitializePasses.h"
20
#include "llvm/Transforms/Utils/LoopSimplify.h"
21
#include <optional>
22
#include <queue>
23
24
#define DEBUG_TYPE "spirv-convergence-region-analysis"
25
26
using namespace llvm;
27
28
namespace llvm {
29
void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
30
} // namespace llvm
31
32
INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass,
33
"convergence-region",
34
"SPIRV convergence regions analysis", true, true)
35
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
36
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
37
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
38
INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass,
39
"convergence-region", "SPIRV convergence regions analysis",
40
true, true)
41
42
namespace llvm {
43
namespace SPIRV {
44
namespace {
45
46
template <typename BasicBlockType, typename IntrinsicInstType>
47
std::optional<IntrinsicInstType *>
48
getConvergenceTokenInternal(BasicBlockType *BB) {
49
static_assert(std::is_const_v<IntrinsicInstType> ==
50
std::is_const_v<BasicBlockType>,
51
"Constness must match between input and output.");
52
static_assert(std::is_same_v<BasicBlock, std::remove_const_t<BasicBlockType>>,
53
"Input must be a basic block.");
54
static_assert(
55
std::is_same_v<IntrinsicInst, std::remove_const_t<IntrinsicInstType>>,
56
"Output type must be an intrinsic instruction.");
57
58
for (auto &I : *BB) {
59
if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
60
switch (II->getIntrinsicID()) {
61
case Intrinsic::experimental_convergence_entry:
62
case Intrinsic::experimental_convergence_loop:
63
return II;
64
case Intrinsic::experimental_convergence_anchor: {
65
auto Bundle = II->getOperandBundle(LLVMContext::OB_convergencectrl);
66
assert(Bundle->Inputs.size() == 1 &&
67
Bundle->Inputs[0]->getType()->isTokenTy());
68
auto TII = dyn_cast<IntrinsicInst>(Bundle->Inputs[0].get());
69
assert(TII != nullptr);
70
return TII;
71
}
72
}
73
}
74
75
if (auto *CI = dyn_cast<CallInst>(&I)) {
76
auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);
77
if (!OB.has_value())
78
continue;
79
return dyn_cast<IntrinsicInst>(OB.value().Inputs[0]);
80
}
81
}
82
83
return std::nullopt;
84
}
85
86
// Given a ConvergenceRegion tree with |Start| as its root, finds the smallest
87
// region |Entry| belongs to. If |Entry| does not belong to the region defined
88
// by |Start|, this function returns |nullptr|.
89
ConvergenceRegion *findParentRegion(ConvergenceRegion *Start,
90
BasicBlock *Entry) {
91
ConvergenceRegion *Candidate = nullptr;
92
ConvergenceRegion *NextCandidate = Start;
93
94
while (Candidate != NextCandidate && NextCandidate != nullptr) {
95
Candidate = NextCandidate;
96
NextCandidate = nullptr;
97
98
// End of the search, we can return.
99
if (Candidate->Children.size() == 0)
100
return Candidate;
101
102
for (auto *Child : Candidate->Children) {
103
if (Child->Blocks.count(Entry) != 0) {
104
NextCandidate = Child;
105
break;
106
}
107
}
108
}
109
110
return Candidate;
111
}
112
113
} // anonymous namespace
114
115
/// Mutable overload: forwards to getConvergenceTokenInternal() and returns the
/// token it finds for |BB|, or std::nullopt if it finds none.
std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB) {
  std::optional<IntrinsicInst *> Token =
      getConvergenceTokenInternal<BasicBlock, IntrinsicInst>(BB);
  return Token;
}
118
119
std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB) {
120
return getConvergenceTokenInternal<const BasicBlock, const IntrinsicInst>(BB);
121
}
122
123
/// Builds the region covering the whole function |F|. It has no parent, is
/// entered through the entry block of |F|, contains every block of |F|, and
/// exits through each block whose terminator is a return instruction.
ConvergenceRegion::ConvergenceRegion(DominatorTree &DT, LoopInfo &LI,
                                     Function &F)
    : DT(DT), LI(LI), Parent(nullptr) {
  Entry = &F.getEntryBlock();
  ConvergenceToken = getConvergenceToken(Entry);

  for (BasicBlock &Block : F) {
    BasicBlock *BB = &Block;
    Blocks.insert(BB);

    // Only returning blocks leave the function-wide region.
    if (!isa<ReturnInst>(BB->getTerminator()))
      continue;
    Exits.insert(BB);
  }
}
134
135
/// Builds a region from already computed parts. |Blocks| and |Exits| are moved
/// into the region. In asserts-enabled builds, checks that |Entry| and every
/// exit are among the region's blocks.
ConvergenceRegion::ConvergenceRegion(
    DominatorTree &DT, LoopInfo &LI,
    std::optional<IntrinsicInst *> ConvergenceToken, BasicBlock *Entry,
    SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits)
    : DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry),
      Exits(std::move(Exits)), Blocks(std::move(Blocks)) {
  // The parameters shadow the members and have been moved from: only the
  // members (through `this->`) may be inspected from here on.
  for ([[maybe_unused]] BasicBlock *ExitBlock : this->Exits) {
    assert(this->Blocks.count(ExitBlock) != 0);
  }

  assert(this->Blocks.count(this->Entry) != 0);
}
145
146
void ConvergenceRegion::releaseMemory() {
147
// Parent memory is owned by the parent.
148
Parent = nullptr;
149
for (auto *Child : Children) {
150
Child->releaseMemory();
151
delete Child;
152
}
153
Children.resize(0);
154
}
155
156
void ConvergenceRegion::dump(const unsigned IndentSize) const {
157
const std::string Indent(IndentSize, '\t');
158
dbgs() << Indent << this << ": {\n";
159
dbgs() << Indent << " Parent: " << Parent << "\n";
160
161
if (ConvergenceToken.value_or(nullptr)) {
162
dbgs() << Indent
163
<< " ConvergenceToken: " << ConvergenceToken.value()->getName()
164
<< "\n";
165
}
166
167
if (Entry->getName() != "")
168
dbgs() << Indent << " Entry: " << Entry->getName() << "\n";
169
else
170
dbgs() << Indent << " Entry: " << Entry << "\n";
171
172
dbgs() << Indent << " Exits: { ";
173
for (const auto &Exit : Exits) {
174
if (Exit->getName() != "")
175
dbgs() << Exit->getName() << ", ";
176
else
177
dbgs() << Exit << ", ";
178
}
179
dbgs() << " }\n";
180
181
dbgs() << Indent << " Blocks: { ";
182
for (const auto &Block : Blocks) {
183
if (Block->getName() != "")
184
dbgs() << Block->getName() << ", ";
185
else
186
dbgs() << Block << ", ";
187
}
188
dbgs() << " }\n";
189
190
dbgs() << Indent << " Children: {\n";
191
for (const auto Child : Children)
192
Child->dump(IndentSize + 2);
193
dbgs() << Indent << " }\n";
194
195
dbgs() << Indent << "}\n";
196
}
197
198
class ConvergenceRegionAnalyzer {
199
200
public:
201
ConvergenceRegionAnalyzer(Function &F, DominatorTree &DT, LoopInfo &LI)
202
: DT(DT), LI(LI), F(F) {}
203
204
private:
205
bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const {
206
assert(From != To && "From == To. This is awkward.");
207
208
// We only handle loop in the simplified form. This means:
209
// - a single back-edge, a single latch.
210
// - meaning the back-edge target can only be the loop header.
211
// - meaning the From can only be the loop latch.
212
if (!LI.isLoopHeader(To))
213
return false;
214
215
auto *L = LI.getLoopFor(To);
216
if (L->contains(From) && L->isLoopLatch(From))
217
return true;
218
219
return false;
220
}
221
222
std::unordered_set<BasicBlock *>
223
findPathsToMatch(LoopInfo &LI, BasicBlock *From,
224
std::function<bool(const BasicBlock *)> isMatch) const {
225
std::unordered_set<BasicBlock *> Output;
226
227
if (isMatch(From))
228
Output.insert(From);
229
230
auto *Terminator = From->getTerminator();
231
for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
232
auto *To = Terminator->getSuccessor(i);
233
if (isBackEdge(From, To))
234
continue;
235
236
auto ChildSet = findPathsToMatch(LI, To, isMatch);
237
if (ChildSet.size() == 0)
238
continue;
239
240
Output.insert(ChildSet.begin(), ChildSet.end());
241
Output.insert(From);
242
if (LI.isLoopHeader(From)) {
243
auto *L = LI.getLoopFor(From);
244
for (auto *BB : L->getBlocks()) {
245
Output.insert(BB);
246
}
247
}
248
}
249
250
return Output;
251
}
252
253
SmallPtrSet<BasicBlock *, 2>
254
findExitNodes(const SmallPtrSetImpl<BasicBlock *> &RegionBlocks) {
255
SmallPtrSet<BasicBlock *, 2> Exits;
256
257
for (auto *B : RegionBlocks) {
258
auto *Terminator = B->getTerminator();
259
for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
260
auto *Child = Terminator->getSuccessor(i);
261
if (RegionBlocks.count(Child) == 0)
262
Exits.insert(B);
263
}
264
}
265
266
return Exits;
267
}
268
269
public:
270
ConvergenceRegionInfo analyze() {
271
ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F);
272
std::queue<Loop *> ToProcess;
273
for (auto *L : LI.getLoopsInPreorder())
274
ToProcess.push(L);
275
276
while (ToProcess.size() != 0) {
277
auto *L = ToProcess.front();
278
ToProcess.pop();
279
assert(L->isLoopSimplifyForm());
280
281
auto CT = getConvergenceToken(L->getHeader());
282
SmallPtrSet<BasicBlock *, 8> RegionBlocks(L->block_begin(),
283
L->block_end());
284
SmallVector<BasicBlock *> LoopExits;
285
L->getExitingBlocks(LoopExits);
286
if (CT.has_value()) {
287
for (auto *Exit : LoopExits) {
288
auto N = findPathsToMatch(LI, Exit, [&CT](const BasicBlock *block) {
289
auto Token = getConvergenceToken(block);
290
if (Token == std::nullopt)
291
return false;
292
return Token.value() == CT.value();
293
});
294
RegionBlocks.insert(N.begin(), N.end());
295
}
296
}
297
298
auto RegionExits = findExitNodes(RegionBlocks);
299
ConvergenceRegion *Region = new ConvergenceRegion(
300
DT, LI, CT, L->getHeader(), std::move(RegionBlocks),
301
std::move(RegionExits));
302
Region->Parent = findParentRegion(TopLevelRegion, Region->Entry);
303
assert(Region->Parent != nullptr && "This is impossible.");
304
Region->Parent->Children.push_back(Region);
305
}
306
307
return ConvergenceRegionInfo(TopLevelRegion);
308
}
309
310
private:
311
DominatorTree &DT;
312
LoopInfo &LI;
313
Function &F;
314
};
315
316
/// Computes the convergence regions of |F| by running a
/// ConvergenceRegionAnalyzer built from |F|, |DT| and |LI|.
ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,
                                            LoopInfo &LI) {
  return ConvergenceRegionAnalyzer(F, DT, LI).analyze();
}
321
322
} // namespace SPIRV
323
324
char SPIRVConvergenceRegionAnalysisWrapperPass::ID = 0;
325
326
SPIRVConvergenceRegionAnalysisWrapperPass::
327
SPIRVConvergenceRegionAnalysisWrapperPass()
328
: FunctionPass(ID) {}
329
330
/// Legacy pass-manager entry point: computes the convergence regions of |F|
/// from the dominator tree and loop info analyses and stores them in CRI.
/// Always returns false, as |F| is left unmodified.
bool SPIRVConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) {
  DominatorTree &DomTree =
      getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  LoopInfo &Loops = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  CRI = SPIRV::getConvergenceRegions(F, DomTree, Loops);

  // Nothing was modified.
  return false;
}
338
339
SPIRVConvergenceRegionAnalysis::Result
340
SPIRVConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
341
Result CRI;
342
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
343
auto &LI = AM.getResult<LoopAnalysis>(F);
344
CRI = SPIRV::getConvergenceRegions(F, DT, LI);
345
return CRI;
346
}
347
348
AnalysisKey SPIRVConvergenceRegionAnalysis::Key;
349
350
} // namespace llvm
351
352