Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
35269 views
1
//===- PGOCtxProfLowering.cpp - Contextual PGO Instr. Lowering ------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
10
#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
11
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
12
#include "llvm/IR/Analysis.h"
13
#include "llvm/IR/DiagnosticInfo.h"
14
#include "llvm/IR/IRBuilder.h"
15
#include "llvm/IR/Instructions.h"
16
#include "llvm/IR/IntrinsicInst.h"
17
#include "llvm/IR/Module.h"
18
#include "llvm/IR/PassManager.h"
19
#include "llvm/Support/CommandLine.h"
20
#include <utility>
21
22
using namespace llvm;
23
24
#define DEBUG_TYPE "ctx-instr-lower"
25
26
// Hidden command-line list of function names that act as context roots.
// Each occurrence of -profile-context-root adds one name to the list.
static cl::list<std::string> ContextRoots(
    "profile-context-root", cl::Hidden,
    cl::desc("A function name, assumed to be global, "
             "which will be treated as the root of an interesting graph, "
             "which will be profiled independently from other similar "
             "graphs."));
32
33
bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() {
34
return !ContextRoots.empty();
35
}
36
37
// The names of the symbols we expect in compiler-rt. Using a namespace for
// readability. The names are compile-time constants: both the pointers and the
// characters they point to are immutable, so they cannot be reassigned by
// accident.
namespace CompilerRtAPINames {
static constexpr const char *StartCtx = "__llvm_ctx_profile_start_context";
static constexpr const char *ReleaseCtx = "__llvm_ctx_profile_release_context";
static constexpr const char *GetCtx = "__llvm_ctx_profile_get_context";
static constexpr const char *ExpectedCalleeTLS =
    "__llvm_ctx_profile_expected_callee";
static constexpr const char *CallsiteTLS = "__llvm_ctx_profile_callsite";
} // namespace CompilerRtAPINames
46
47
namespace {
48
// The lowering logic and state.
49
class CtxInstrumentationLowerer final {
50
Module &M;
51
ModuleAnalysisManager &MAM;
52
Type *ContextNodeTy = nullptr;
53
Type *ContextRootTy = nullptr;
54
55
DenseMap<const Function *, Constant *> ContextRootMap;
56
Function *StartCtx = nullptr;
57
Function *GetCtx = nullptr;
58
Function *ReleaseCtx = nullptr;
59
GlobalVariable *ExpectedCalleeTLS = nullptr;
60
GlobalVariable *CallsiteInfoTLS = nullptr;
61
62
public:
63
CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
64
// return true if lowering happened (i.e. a change was made)
65
bool lowerFunction(Function &F);
66
};
67
68
// llvm.instrprof.increment[.step] captures the total number of counters as one
69
// of its parameters, and llvm.instrprof.callsite captures the total number of
70
// callsites. Those values are the same for instances of those intrinsics in
71
// this function. Find the first instance of each and return them.
72
std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) {
73
uint32_t NrCounters = 0;
74
uint32_t NrCallsites = 0;
75
for (const auto &BB : F) {
76
for (const auto &I : BB) {
77
if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
78
uint32_t V =
79
static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
80
assert((!NrCounters || V == NrCounters) &&
81
"expected all llvm.instrprof.increment[.step] intrinsics to "
82
"have the same total nr of counters parameter");
83
NrCounters = V;
84
} else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
85
uint32_t V =
86
static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
87
assert((!NrCallsites || V == NrCallsites) &&
88
"expected all llvm.instrprof.callsite intrinsics to have the "
89
"same total nr of callsites parameter");
90
NrCallsites = V;
91
}
92
#if NDEBUG
93
if (NrCounters && NrCallsites)
94
return std::make_pair(NrCounters, NrCallsites);
95
#endif
96
}
97
}
98
return {NrCounters, NrCallsites};
99
}
100
} // namespace
101
102
// set up tie-in with compiler-rt.
103
// NOTE!!!
104
// These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
105
CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
106
ModuleAnalysisManager &MAM)
107
: M(M), MAM(MAM) {
108
auto *PointerTy = PointerType::get(M.getContext(), 0);
109
auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
110
auto *I32Ty = Type::getInt32Ty(M.getContext());
111
auto *I64Ty = Type::getInt64Ty(M.getContext());
112
113
// The ContextRoot type
114
ContextRootTy =
115
StructType::get(M.getContext(), {
116
PointerTy, /*FirstNode*/
117
PointerTy, /*FirstMemBlock*/
118
PointerTy, /*CurrentMem*/
119
SanitizerMutexType, /*Taken*/
120
});
121
// The Context header.
122
ContextNodeTy = StructType::get(M.getContext(), {
123
I64Ty, /*Guid*/
124
PointerTy, /*Next*/
125
I32Ty, /*NrCounters*/
126
I32Ty, /*NrCallsites*/
127
});
128
129
// Define a global for each entrypoint. We'll reuse the entrypoint's name as
130
// prefix. We assume the entrypoint names to be unique.
131
for (const auto &Fname : ContextRoots) {
132
if (const auto *F = M.getFunction(Fname)) {
133
if (F->isDeclaration())
134
continue;
135
auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);
136
cast<GlobalVariable>(G)->setInitializer(
137
Constant::getNullValue(ContextRootTy));
138
ContextRootMap.insert(std::make_pair(F, G));
139
for (const auto &BB : *F)
140
for (const auto &I : BB)
141
if (const auto *CB = dyn_cast<CallBase>(&I))
142
if (CB->isMustTailCall()) {
143
M.getContext().emitError(
144
"The function " + Fname +
145
" was indicated as a context root, but it features musttail "
146
"calls, which is not supported.");
147
}
148
}
149
}
150
151
// Declare the functions we will call.
152
StartCtx = cast<Function>(
153
M.getOrInsertFunction(
154
CompilerRtAPINames::StartCtx,
155
FunctionType::get(ContextNodeTy->getPointerTo(),
156
{ContextRootTy->getPointerTo(), /*ContextRoot*/
157
I64Ty, /*Guid*/ I32Ty,
158
/*NrCounters*/ I32Ty /*NrCallsites*/},
159
false))
160
.getCallee());
161
GetCtx = cast<Function>(
162
M.getOrInsertFunction(CompilerRtAPINames::GetCtx,
163
FunctionType::get(ContextNodeTy->getPointerTo(),
164
{PointerTy, /*Callee*/
165
I64Ty, /*Guid*/
166
I32Ty, /*NrCounters*/
167
I32Ty}, /*NrCallsites*/
168
false))
169
.getCallee());
170
ReleaseCtx = cast<Function>(
171
M.getOrInsertFunction(
172
CompilerRtAPINames::ReleaseCtx,
173
FunctionType::get(Type::getVoidTy(M.getContext()),
174
{
175
ContextRootTy->getPointerTo(), /*ContextRoot*/
176
},
177
false))
178
.getCallee());
179
180
// Declare the TLSes we will need to use.
181
CallsiteInfoTLS =
182
new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
183
nullptr, CompilerRtAPINames::CallsiteTLS);
184
CallsiteInfoTLS->setThreadLocal(true);
185
CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
186
ExpectedCalleeTLS =
187
new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
188
nullptr, CompilerRtAPINames::ExpectedCalleeTLS);
189
ExpectedCalleeTLS->setThreadLocal(true);
190
ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
191
}
192
193
// Runs the lowering over every function in M. All analyses are reported as
// preserved only when no function was changed.
PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
                                              ModuleAnalysisManager &MAM) {
  CtxInstrumentationLowerer Lowerer(M, MAM);
  bool AnyLowered = false;
  for (auto &Fn : M) {
    if (Lowerer.lowerFunction(Fn))
      AnyLowered = true;
  }
  if (!AnyLowered)
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}
201
202
// Lowers the contextual profiling intrinsics of F.
//
// The llvm.instrprof.increment found in the entry block is replaced by a call
// that obtains the function's context (StartCtx for context roots, GetCtx
// otherwise). Every remaining counter intrinsic is then replaced by loads and
// stores against that context, and callsite intrinsics by volatile stores to
// the two TLS variables. For context roots, a ReleaseCtx call is inserted
// before each `ret`.
//
// Returns false, leaving F untouched, if F is a declaration or if its entry
// block has no increment intrinsic; returns true otherwise.
bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
  if (F.isDeclaration())
    return false;
  auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);

  auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(F);

  // The value returned by compiler-rt (may have its LSB set, see below) and
  // the same value with the LSB cleared.
  Value *Context = nullptr;
  Value *RealContext = nullptr;

  StructType *ThisContextType = nullptr;
  Value *TheRootContext = nullptr;
  Value *ExpectedCalleeTLSAddr = nullptr;
  Value *CallsiteInfoTLSAddr = nullptr;

  auto &Head = F.getEntryBlock();
  for (auto &I : Head) {
    // Find the increment intrinsic in the entry basic block.
    auto *Mark = dyn_cast<InstrProfIncrementInst>(&I);
    if (!Mark)
      continue;
    assert(Mark->getIndex()->isZero());

    IRBuilder<> Builder(Mark);
    // FIXME(mtrofin): use InstrProfSymtab::getCanonicalName
    Value *Guid = Builder.getInt64(F.getGUID());
    // The type of the context of this function is now knowable since we have
    // NrCallsites and NrCounters. We declare it here because it's more
    // convenient - we have the Builder.
    ThisContextType = StructType::get(
        F.getContext(),
        {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters),
         ArrayType::get(Builder.getPtrTy(), NrCallsites)});
    // Figure out which way we obtain the context object for this function -
    // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
    // former case, we also set TheRootContext since we need to release it
    // at the end (plus it can be used to know if we have an entrypoint or a
    // regular function)
    auto Iter = ContextRootMap.find(&F);
    if (Iter != ContextRootMap.end()) {
      TheRootContext = Iter->second;
      Context = Builder.CreateCall(StartCtx, {TheRootContext, Guid,
                                              Builder.getInt32(NrCounters),
                                              Builder.getInt32(NrCallsites)});
      ORE.emit(
          [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });
    } else {
      Context =
          Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NrCounters),
                                      Builder.getInt32(NrCallsites)});
      ORE.emit([&] {
        return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F);
      });
    }
    // The context could be scratch.
    auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
    if (NrCallsites > 0) {
      // Figure out which index of the TLS 2-element buffers to use.
      // Scratch context => we use index == 1. Real contexts => index == 0.
      auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
      // The GEPs corresponding to that index, in the respective TLS. Both
      // buffers hold pointers, so both GEPs must step by the size of a
      // pointer; stepping the callsite buffer by i32 would make slot 1
      // overlap slot 0 wherever pointers are wider than 32 bits.
      auto *SlotTy = Builder.getPtrTy();
      ExpectedCalleeTLSAddr = Builder.CreateGEP(
          SlotTy, Builder.CreateThreadLocalAddress(ExpectedCalleeTLS),
          {Index});
      CallsiteInfoTLSAddr = Builder.CreateGEP(
          SlotTy, Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
    }
    // Because the context pointer may have LSB set (to indicate scratch),
    // clear it for the value we use as base address for the counter vector.
    // This way, if later we want to have "real" (not clobbered) buffers
    // acting as scratch, the lowering (at least this part of it that deals
    // with counters) stays the same.
    auto *CtxWithoutLSB = Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2));
    RealContext =
        Builder.CreateIntToPtr(CtxWithoutLSB, ThisContextType->getPointerTo());
    // The entry marker has served its purpose; stop iterating right after
    // erasing it.
    I.eraseFromParent();
    break;
  }
  if (!Context) {
    ORE.emit([&] {
      return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)
             << "Function doesn't have instrumentation, skipping";
    });
    return false;
  }

  bool ContextWasReleased = false;
  for (auto &BB : F) {
    // Early-inc iteration: the current instruction may be erased below.
    for (auto &I : llvm::make_early_inc_range(BB)) {
      if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {
        IRBuilder<> Builder(Instr);
        switch (Instr->getIntrinsicID()) {
        case llvm::Intrinsic::instrprof_increment:
        case llvm::Intrinsic::instrprof_increment_step: {
          // Increments (or increment-steps) are just a typical load - increment
          // - store in the RealContext.
          auto *AsStep = cast<InstrProfIncrementInst>(Instr);
          auto *GEP = Builder.CreateGEP(
              ThisContextType, RealContext,
              {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
          Builder.CreateStore(
              Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
                                AsStep->getStep()),
              GEP);
        } break;
        case llvm::Intrinsic::instrprof_callsite: {
          // callsite lowering: write the called value in the expected callee
          // TLS we treat the TLS as volatile because of signal handlers and to
          // avoid these being moved away from the callsite they decorate.
          // The intrinsic ID was just checked, so cast<> (which asserts on a
          // mismatch) is used rather than an unchecked dyn_cast<>.
          auto *CSIntrinsic = cast<InstrProfCallsite>(Instr);
          Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
                              true);
          // write the GEP of the slot in the sub-contexts portion of the
          // context in TLS. Now, here, we use the actual Context value - as
          // returned from compiler-rt - which may have the LSB set if the
          // Context was scratch. Since the header of the context object and
          // then the values are all 8-aligned (or, really, insofar as we care,
          // they are even) - if the context is scratch (meaning, an odd value),
          // so will the GEP. This is important because this is then visible to
          // compiler-rt which will produce scratch contexts for callers that
          // have a scratch context.
          Builder.CreateStore(
              Builder.CreateGEP(ThisContextType, Context,
                                {Builder.getInt32(0), Builder.getInt32(2),
                                 CSIntrinsic->getIndex()}),
              CallsiteInfoTLSAddr, true);
          break;
        }
        }
        // Any counter intrinsic, handled above or not, is removed.
        I.eraseFromParent();
      } else if (TheRootContext && isa<ReturnInst>(I)) {
        // Remember to release the context if we are an entrypoint.
        IRBuilder<> Builder(&I);
        Builder.CreateCall(ReleaseCtx, {TheRootContext});
        ContextWasReleased = true;
      }
    }
  }
  // FIXME: This would happen if the entrypoint tailcalls. A way to fix would be
  // to disallow this, (so this then stays as an error), another is to detect
  // that and then do a wrapper or disallow the tail call. This only affects
  // instrumentation, when we want to detect the call graph.
  if (TheRootContext && !ContextWasReleased)
    F.getContext().emitError(
        "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
        "instructions above which to release the context: " +
        F.getName());
  return true;
}
352
353