Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
35267 views
1
//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// Merge the multiple exit targets of a convergence region into a single block.
10
// Each exit target will be assigned a constant value, and a phi node + switch
11
// will allow the new exit target to re-route to the correct basic block.
12
//
13
//===----------------------------------------------------------------------===//
14
15
#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
16
#include "SPIRV.h"
17
#include "SPIRVSubtarget.h"
18
#include "SPIRVTargetMachine.h"
19
#include "SPIRVUtils.h"
20
#include "llvm/ADT/DenseMap.h"
21
#include "llvm/ADT/SmallPtrSet.h"
22
#include "llvm/Analysis/LoopInfo.h"
23
#include "llvm/CodeGen/IntrinsicLowering.h"
24
#include "llvm/IR/CFG.h"
25
#include "llvm/IR/Dominators.h"
26
#include "llvm/IR/IRBuilder.h"
27
#include "llvm/IR/IntrinsicInst.h"
28
#include "llvm/IR/Intrinsics.h"
29
#include "llvm/IR/IntrinsicsSPIRV.h"
30
#include "llvm/InitializePasses.h"
31
#include "llvm/Transforms/Utils/Cloning.h"
32
#include "llvm/Transforms/Utils/LoopSimplify.h"
33
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
34
35
using namespace llvm;
36
37
namespace llvm {
38
void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);
39
40
class SPIRVMergeRegionExitTargets : public FunctionPass {
41
public:
42
static char ID;
43
44
SPIRVMergeRegionExitTargets() : FunctionPass(ID) {
45
initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
46
};
47
48
// Gather all the successors of |BB|.
49
// This function asserts if the terminator neither a branch, switch or return.
50
std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
51
std::unordered_set<BasicBlock *> output;
52
auto *T = BB->getTerminator();
53
54
if (auto *BI = dyn_cast<BranchInst>(T)) {
55
output.insert(BI->getSuccessor(0));
56
if (BI->isConditional())
57
output.insert(BI->getSuccessor(1));
58
return output;
59
}
60
61
if (auto *SI = dyn_cast<SwitchInst>(T)) {
62
output.insert(SI->getDefaultDest());
63
for (auto &Case : SI->cases())
64
output.insert(Case.getCaseSuccessor());
65
return output;
66
}
67
68
assert(isa<ReturnInst>(T) && "Unhandled terminator type.");
69
return output;
70
}
71
72
/// Create a value in BB set to the value associated with the branch the block
73
/// terminator will take.
74
llvm::Value *createExitVariable(
75
BasicBlock *BB,
76
const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
77
auto *T = BB->getTerminator();
78
if (isa<ReturnInst>(T))
79
return nullptr;
80
81
IRBuilder<> Builder(BB);
82
Builder.SetInsertPoint(T);
83
84
if (auto *BI = dyn_cast<BranchInst>(T)) {
85
86
BasicBlock *LHSTarget = BI->getSuccessor(0);
87
BasicBlock *RHSTarget =
88
BI->isConditional() ? BI->getSuccessor(1) : nullptr;
89
90
Value *LHS = TargetToValue.count(LHSTarget) != 0
91
? TargetToValue.at(LHSTarget)
92
: nullptr;
93
Value *RHS = TargetToValue.count(RHSTarget) != 0
94
? TargetToValue.at(RHSTarget)
95
: nullptr;
96
97
if (LHS == nullptr || RHS == nullptr)
98
return LHS == nullptr ? RHS : LHS;
99
return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
100
}
101
102
// TODO: add support for switch cases.
103
llvm_unreachable("Unhandled terminator type.");
104
}
105
106
/// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
107
void replaceBranchTargets(BasicBlock *BB,
108
const SmallPtrSet<BasicBlock *, 4> &ToReplace,
109
BasicBlock *NewTarget) {
110
auto *T = BB->getTerminator();
111
if (isa<ReturnInst>(T))
112
return;
113
114
if (auto *BI = dyn_cast<BranchInst>(T)) {
115
for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
116
if (ToReplace.count(BI->getSuccessor(i)) != 0)
117
BI->setSuccessor(i, NewTarget);
118
}
119
return;
120
}
121
122
if (auto *SI = dyn_cast<SwitchInst>(T)) {
123
for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
124
if (ToReplace.count(SI->getSuccessor(i)) != 0)
125
SI->setSuccessor(i, NewTarget);
126
}
127
return;
128
}
129
130
assert(false && "Unhandled terminator type.");
131
}
132
133
// Run the pass on the given convergence region, ignoring the sub-regions.
134
// Returns true if the CFG changed, false otherwise.
135
bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
136
const SPIRV::ConvergenceRegion *CR) {
137
// Gather all the exit targets for this region.
138
SmallPtrSet<BasicBlock *, 4> ExitTargets;
139
for (BasicBlock *Exit : CR->Exits) {
140
for (BasicBlock *Target : gatherSuccessors(Exit)) {
141
if (CR->Blocks.count(Target) == 0)
142
ExitTargets.insert(Target);
143
}
144
}
145
146
// If we have zero or one exit target, nothing do to.
147
if (ExitTargets.size() <= 1)
148
return false;
149
150
// Create the new single exit target.
151
auto F = CR->Entry->getParent();
152
auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
153
IRBuilder<> Builder(NewExitTarget);
154
155
// CodeGen output needs to be stable. Using the set as-is would order
156
// the targets differently depending on the allocation pattern.
157
// Sorting per basic-block ordering in the function.
158
std::vector<BasicBlock *> SortedExitTargets;
159
std::vector<BasicBlock *> SortedExits;
160
for (BasicBlock &BB : *F) {
161
if (ExitTargets.count(&BB) != 0)
162
SortedExitTargets.push_back(&BB);
163
if (CR->Exits.count(&BB) != 0)
164
SortedExits.push_back(&BB);
165
}
166
167
// Creating one constant per distinct exit target. This will be route to the
168
// correct target.
169
DenseMap<BasicBlock *, ConstantInt *> TargetToValue;
170
for (BasicBlock *Target : SortedExitTargets)
171
TargetToValue.insert(
172
std::make_pair(Target, Builder.getInt32(TargetToValue.size())));
173
174
// Creating one variable per exit node, set to the constant matching the
175
// targeted external block.
176
std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
177
for (auto Exit : SortedExits) {
178
llvm::Value *Value = createExitVariable(Exit, TargetToValue);
179
ExitToVariable.emplace_back(std::make_pair(Exit, Value));
180
}
181
182
// Gather the correct value depending on the exit we came from.
183
llvm::PHINode *node =
184
Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
185
for (auto [BB, Value] : ExitToVariable) {
186
node->addIncoming(Value, BB);
187
}
188
189
// Creating the switch to jump to the correct exit target.
190
llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],
191
SortedExitTargets.size() - 1);
192
for (size_t i = 1; i < SortedExitTargets.size(); i++) {
193
BasicBlock *BB = SortedExitTargets[i];
194
Sw->addCase(TargetToValue[BB], BB);
195
}
196
197
// Fix exit branches to redirect to the new exit.
198
for (auto Exit : CR->Exits)
199
replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
200
201
return true;
202
}
203
204
/// Run the pass on the given convergence region and sub-regions (DFS).
205
/// Returns true if a region/sub-region was modified, false otherwise.
206
/// This returns as soon as one region/sub-region has been modified.
207
bool runOnConvergenceRegion(LoopInfo &LI,
208
const SPIRV::ConvergenceRegion *CR) {
209
for (auto *Child : CR->Children)
210
if (runOnConvergenceRegion(LI, Child))
211
return true;
212
213
return runOnConvergenceRegionNoRecurse(LI, CR);
214
}
215
216
#if !NDEBUG
217
/// Validates each edge exiting the region has the same destination basic
218
/// block.
219
void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
220
for (auto *Child : CR->Children)
221
validateRegionExits(Child);
222
223
std::unordered_set<BasicBlock *> ExitTargets;
224
for (auto *Exit : CR->Exits) {
225
auto Set = gatherSuccessors(Exit);
226
for (auto *BB : Set) {
227
if (CR->Blocks.count(BB) == 0)
228
ExitTargets.insert(BB);
229
}
230
}
231
232
assert(ExitTargets.size() <= 1);
233
}
234
#endif
235
236
virtual bool runOnFunction(Function &F) override {
237
LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
238
const auto *TopLevelRegion =
239
getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
240
.getRegionInfo()
241
.getTopLevelRegion();
242
243
// FIXME: very inefficient method: each time a region is modified, we bubble
244
// back up, and recompute the whole convergence region tree. Once the
245
// algorithm is completed and test coverage good enough, rewrite this pass
246
// to be efficient instead of simple.
247
bool modified = false;
248
while (runOnConvergenceRegion(LI, TopLevelRegion)) {
249
TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
250
.getRegionInfo()
251
.getTopLevelRegion();
252
modified = true;
253
}
254
255
#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
256
validateRegionExits(TopLevelRegion);
257
#endif
258
return modified;
259
}
260
261
void getAnalysisUsage(AnalysisUsage &AU) const override {
262
AU.addRequired<DominatorTreeWrapperPass>();
263
AU.addRequired<LoopInfoWrapperPass>();
264
AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
265
FunctionPass::getAnalysisUsage(AU);
266
}
267
};
268
} // namespace llvm
269
270
char SPIRVMergeRegionExitTargets::ID = 0;
271
272
INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
273
"SPIRV split region exit blocks", false, false)
274
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
275
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
276
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
277
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)
278
279
INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
280
"SPIRV split region exit blocks", false, false)
281
282
FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
283
return new SPIRVMergeRegionExitTargets();
284
}
285
286