Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp
35266 views
1
//===- JumpTableToSwitch.cpp ----------------------------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
10
#include "llvm/ADT/SmallVector.h"
11
#include "llvm/Analysis/ConstantFolding.h"
12
#include "llvm/Analysis/DomTreeUpdater.h"
13
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
14
#include "llvm/Analysis/PostDominators.h"
15
#include "llvm/IR/IRBuilder.h"
16
#include "llvm/Support/CommandLine.h"
17
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
18
19
using namespace llvm;
20
21
// Maximum number of entries a jump table may have to be expanded. The entry
// count is the table's allocation size divided by the GEP stride (see
// parseJumpTable).
static cl::opt<unsigned>
    JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
                           cl::desc("Only split jump tables with size less or "
                                    "equal than JumpTableSizeThreshold."),
                           cl::init(10));

// TODO: Consider adding a cost model for profitability analysis of this
// transformation. Currently we replace a jump table with a switch only if every
// function in the jump table has at most FunctionSizeThreshold instructions
// (as counted by Function::getInstructionCount).
static cl::opt<unsigned> FunctionSizeThreshold(
    "jump-table-to-switch-function-size-threshold", cl::Hidden,
    cl::desc("Only split jump tables containing functions whose sizes are less "
             "or equal than this threshold."),
    cl::init(50));

#define DEBUG_TYPE "jump-table-to-switch"
37
38
namespace {
39
struct JumpTableTy {
40
Value *Index;
41
SmallVector<Function *, 10> Funcs;
42
};
43
} // anonymous namespace
44
45
static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
46
PointerType *PtrTy) {
47
Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand());
48
if (!Ptr)
49
return std::nullopt;
50
51
GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr);
52
if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
53
return std::nullopt;
54
55
Function &F = *GEP->getParent()->getParent();
56
const DataLayout &DL = F.getDataLayout();
57
const unsigned BitWidth =
58
DL.getIndexSizeInBits(GEP->getPointerAddressSpace());
59
MapVector<Value *, APInt> VariableOffsets;
60
APInt ConstantOffset(BitWidth, 0);
61
if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
62
return std::nullopt;
63
if (VariableOffsets.size() != 1)
64
return std::nullopt;
65
// TODO: consider supporting more general patterns
66
if (!ConstantOffset.isZero())
67
return std::nullopt;
68
APInt StrideBytes = VariableOffsets.front().second;
69
const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType());
70
if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0)
71
return std::nullopt;
72
const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue();
73
if (N > JumpTableSizeThreshold)
74
return std::nullopt;
75
76
JumpTableTy JumpTable;
77
JumpTable.Index = VariableOffsets.front().first;
78
JumpTable.Funcs.reserve(N);
79
for (uint64_t Index = 0; Index < N; ++Index) {
80
// ConstantOffset is zero.
81
APInt Offset = Index * StrideBytes;
82
Constant *C =
83
ConstantFoldLoadFromConst(GV->getInitializer(), PtrTy, Offset, DL);
84
auto *Func = dyn_cast_or_null<Function>(C);
85
if (!Func || Func->isDeclaration() ||
86
Func->getInstructionCount() > FunctionSizeThreshold)
87
return std::nullopt;
88
JumpTable.Funcs.push_back(Func);
89
}
90
return JumpTable;
91
}
92
93
/// Replace the indirect call \p CB, whose callee is loaded from the jump table
/// \p JT, with a switch on the table index that dispatches to one direct call
/// per table entry.
///
/// The block containing \p CB is split at \p CB. The head block ends in a
/// switch over JT.Index: case I calls JT.Funcs[I], and the default destination
/// is a block containing only `unreachable`. Every case block branches to the
/// tail. If the call is non-void, a PHI in the tail merges the results and
/// replaces all uses of \p CB. \p CB is erased.
///
/// Dominator-tree updates are applied through \p DTU, and a remark is emitted
/// through \p ORE.
///
/// \returns the split-off tail block, which holds the instructions that
/// followed \p CB.
static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
                                  DomTreeUpdater &DTU,
                                  OptimizationRemarkEmitter &ORE) {
  const bool IsVoid = CB->getType()->isVoidTy();

  SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
  BasicBlock *BB = CB->getParent();
  BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr,
                                BB->getName() + Twine(".tail"));
  // The branch BB -> Tail created by the split is replaced by the switch below.
  DTUpdates.push_back({DominatorTree::Delete, BB, Tail});
  BB->getTerminator()->eraseFromParent();

  Function &F = *BB->getParent();
  BasicBlock *BBUnreachable = BasicBlock::Create(
      F.getContext(), "default.switch.case.unreachable", &F, Tail);
  IRBuilder<> BuilderUnreachable(BBUnreachable);
  BuilderUnreachable.CreateUnreachable();

  IRBuilder<> Builder(BB);
  SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable);
  DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable});

  // CB now heads Tail, so the PHI (if any) is placed at the top of Tail.
  IRBuilder<> BuilderTail(CB);
  PHINode *PHI =
      IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());

  for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
    BasicBlock *B = BasicBlock::Create(Func->getContext(),
                                       "call." + Twine(Index), &F, Tail);
    DTUpdates.push_back({DominatorTree::Insert, BB, B});
    DTUpdates.push_back({DominatorTree::Insert, B, Tail});

    CallBase *Call = cast<CallBase>(CB->clone());
    // Keep the call site's function type. The table entry's own type may
    // differ, and the cloned arguments, attributes and result all match the
    // call site. setCalledFunction(Func) would overwrite the type with Func's
    // and yield malformed IR (or assert) on a signature mismatch.
    Call->setCalledOperand(Func);
    Call->insertInto(B, B->end());
    Switch->addCase(
        cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
    BranchInst::Create(Tail, B);
    if (PHI)
      PHI->addIncoming(Call, B);
  }
  DTU.applyUpdates(DTUpdates);
  ORE.emit([&]() {
    return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
           << "expanded indirect call into switch";
  });
  if (PHI)
    CB->replaceAllUsesWith(PHI);
  CB->eraseFromParent();
  return Tail;
}
144
145
/// Expand indirect calls in \p F whose callee is a simple (non-atomic,
/// non-volatile) load from a small constant jump table into a switch of direct
/// calls.
///
/// The dominator and post-dominator trees are updated only if they were
/// already cached. In that case they are reported as preserved.
PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
                                             FunctionAnalysisManager &AM) {
  OptimizationRemarkEmitter &ORE =
      AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
  DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
  bool Changed = false;
  for (BasicBlock &BB : make_early_inc_range(F)) {
    // Expanding a call splits its block. Rescan the returned tail so that
    // later jump-table calls from the same original block are expanded too.
    BasicBlock *CurrentBB = &BB;
    while (CurrentBB) {
      BasicBlock *SplittedOutTail = nullptr;
      for (Instruction &I : make_early_inc_range(*CurrentBB)) {
        auto *Call = dyn_cast<CallInst>(&I);
        // Only indirect calls qualify. A musttail call must stay immediately
        // before its return, which splitting the block would break.
        if (!Call || Call->getCalledFunction() || Call->isMustTailCall())
          continue;
        auto *L = dyn_cast<LoadInst>(Call->getCalledOperand());
        // Skip atomic or volatile loads.
        if (!L || !L->isSimple())
          continue;
        auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand());
        if (!GEP)
          continue;
        // The loaded value is the callee, so it is always a pointer.
        auto *PtrTy = cast<PointerType>(L->getType());
        std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
        if (!JumpTable)
          continue;
        SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE);
        Changed = true;
        break;
      }
      CurrentBB = SplittedOutTail;
    }
  }

  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  if (DT)
    PA.preserve<DominatorTreeAnalysis>();
  if (PDT)
    PA.preserve<PostDominatorTreeAnalysis>();
  return PA;
}
191
192