Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
35269 views
//===- PGOCtxProfLowering.cpp - Contextual PGO Instr. Lowering ------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//89#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"10#include "llvm/Analysis/OptimizationRemarkEmitter.h"11#include "llvm/IR/Analysis.h"12#include "llvm/IR/DiagnosticInfo.h"13#include "llvm/IR/IRBuilder.h"14#include "llvm/IR/Instructions.h"15#include "llvm/IR/IntrinsicInst.h"16#include "llvm/IR/Module.h"17#include "llvm/IR/PassManager.h"18#include "llvm/Support/CommandLine.h"19#include <utility>2021using namespace llvm;2223#define DEBUG_TYPE "ctx-instr-lower"2425static cl::list<std::string> ContextRoots(26"profile-context-root", cl::Hidden,27cl::desc(28"A function name, assumed to be global, which will be treated as the "29"root of an interesting graph, which will be profiled independently "30"from other similar graphs."));3132bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() {33return !ContextRoots.empty();34}3536// the names of symbols we expect in compiler-rt. Using a namespace for37// readability.38namespace CompilerRtAPINames {39static auto StartCtx = "__llvm_ctx_profile_start_context";40static auto ReleaseCtx = "__llvm_ctx_profile_release_context";41static auto GetCtx = "__llvm_ctx_profile_get_context";42static auto ExpectedCalleeTLS = "__llvm_ctx_profile_expected_callee";43static auto CallsiteTLS = "__llvm_ctx_profile_callsite";44} // namespace CompilerRtAPINames4546namespace {47// The lowering logic and state.48class CtxInstrumentationLowerer final {49Module &M;50ModuleAnalysisManager &MAM;51Type *ContextNodeTy = nullptr;52Type *ContextRootTy = nullptr;5354DenseMap<const Function *, Constant *> ContextRootMap;55Function *StartCtx = nullptr;56Function *GetCtx = nullptr;57Function *ReleaseCtx = nullptr;58GlobalVariable *ExpectedCalleeTLS = nullptr;59GlobalVariable *CallsiteInfoTLS = nullptr;6061public:62CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);63// return true if lowering happened (i.e. a change was made)64bool lowerFunction(Function &F);65};6667// llvm.instrprof.increment[.step] captures the total number of counters as one68// of its parameters, and llvm.instrprof.callsite captures the total number of69// callsites. Those values are the same for instances of those intrinsics in70// this function. Find the first instance of each and return them.71std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) {72uint32_t NrCounters = 0;73uint32_t NrCallsites = 0;74for (const auto &BB : F) {75for (const auto &I : BB) {76if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {77uint32_t V =78static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());79assert((!NrCounters || V == NrCounters) &&80"expected all llvm.instrprof.increment[.step] intrinsics to "81"have the same total nr of counters parameter");82NrCounters = V;83} else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {84uint32_t V =85static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());86assert((!NrCallsites || V == NrCallsites) &&87"expected all llvm.instrprof.callsite intrinsics to have the "88"same total nr of callsites parameter");89NrCallsites = V;90}91#if NDEBUG92if (NrCounters && NrCallsites)93return std::make_pair(NrCounters, NrCallsites);94#endif95}96}97return {NrCounters, NrCallsites};98}99} // namespace100101// set up tie-in with compiler-rt.102// NOTE!!!103// These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h104CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,105ModuleAnalysisManager &MAM)106: M(M), MAM(MAM) {107auto *PointerTy = PointerType::get(M.getContext(), 0);108auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());109auto *I32Ty = Type::getInt32Ty(M.getContext());110auto *I64Ty = Type::getInt64Ty(M.getContext());111112// The ContextRoot type113ContextRootTy =114StructType::get(M.getContext(), {115PointerTy, /*FirstNode*/116PointerTy, /*FirstMemBlock*/117PointerTy, /*CurrentMem*/118SanitizerMutexType, /*Taken*/119});120// The Context header.121ContextNodeTy = StructType::get(M.getContext(), {122I64Ty, /*Guid*/123PointerTy, /*Next*/124I32Ty, /*NrCounters*/125I32Ty, /*NrCallsites*/126});127128// Define a global for each entrypoint. We'll reuse the entrypoint's name as129// prefix. We assume the entrypoint names to be unique.130for (const auto &Fname : ContextRoots) {131if (const auto *F = M.getFunction(Fname)) {132if (F->isDeclaration())133continue;134auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);135cast<GlobalVariable>(G)->setInitializer(136Constant::getNullValue(ContextRootTy));137ContextRootMap.insert(std::make_pair(F, G));138for (const auto &BB : *F)139for (const auto &I : BB)140if (const auto *CB = dyn_cast<CallBase>(&I))141if (CB->isMustTailCall()) {142M.getContext().emitError(143"The function " + Fname +144" was indicated as a context root, but it features musttail "145"calls, which is not supported.");146}147}148}149150// Declare the functions we will call.151StartCtx = cast<Function>(152M.getOrInsertFunction(153CompilerRtAPINames::StartCtx,154FunctionType::get(ContextNodeTy->getPointerTo(),155{ContextRootTy->getPointerTo(), /*ContextRoot*/156I64Ty, /*Guid*/ I32Ty,157/*NrCounters*/ I32Ty /*NrCallsites*/},158false))159.getCallee());160GetCtx = cast<Function>(161M.getOrInsertFunction(CompilerRtAPINames::GetCtx,162FunctionType::get(ContextNodeTy->getPointerTo(),163{PointerTy, /*Callee*/164I64Ty, /*Guid*/165I32Ty, /*NrCounters*/166I32Ty}, /*NrCallsites*/167false))168.getCallee());169ReleaseCtx = cast<Function>(170M.getOrInsertFunction(171CompilerRtAPINames::ReleaseCtx,172FunctionType::get(Type::getVoidTy(M.getContext()),173{174ContextRootTy->getPointerTo(), /*ContextRoot*/175},176false))177.getCallee());178179// Declare the TLSes we will need to use.180CallsiteInfoTLS =181new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,182nullptr, CompilerRtAPINames::CallsiteTLS);183CallsiteInfoTLS->setThreadLocal(true);184CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);185ExpectedCalleeTLS =186new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,187nullptr, CompilerRtAPINames::ExpectedCalleeTLS);188ExpectedCalleeTLS->setThreadLocal(true);189ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);190}191192PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,193ModuleAnalysisManager &MAM) {194CtxInstrumentationLowerer Lowerer(M, MAM);195bool Changed = false;196for (auto &F : M)197Changed |= Lowerer.lowerFunction(F);198return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();199}200201bool CtxInstrumentationLowerer::lowerFunction(Function &F) {202if (F.isDeclaration())203return false;204auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();205auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);206207Value *Guid = nullptr;208auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(F);209210Value *Context = nullptr;211Value *RealContext = nullptr;212213StructType *ThisContextType = nullptr;214Value *TheRootContext = nullptr;215Value *ExpectedCalleeTLSAddr = nullptr;216Value *CallsiteInfoTLSAddr = nullptr;217218auto &Head = F.getEntryBlock();219for (auto &I : Head) {220// Find the increment intrinsic in the entry basic block.221if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {222assert(Mark->getIndex()->isZero());223224IRBuilder<> Builder(Mark);225// FIXME(mtrofin): use InstrProfSymtab::getCanonicalName226Guid = Builder.getInt64(F.getGUID());227// The type of the context of this function is now knowable since we have228// NrCallsites and NrCounters. We delcare it here because it's more229// convenient - we have the Builder.230ThisContextType = StructType::get(231F.getContext(),232{ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters),233ArrayType::get(Builder.getPtrTy(), NrCallsites)});234// Figure out which way we obtain the context object for this function -235// if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the236// former case, we also set TheRootContext since we need to release it237// at the end (plus it can be used to know if we have an entrypoint or a238// regular function)239auto Iter = ContextRootMap.find(&F);240if (Iter != ContextRootMap.end()) {241TheRootContext = Iter->second;242Context = Builder.CreateCall(StartCtx, {TheRootContext, Guid,243Builder.getInt32(NrCounters),244Builder.getInt32(NrCallsites)});245ORE.emit(246[&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });247} else {248Context =249Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NrCounters),250Builder.getInt32(NrCallsites)});251ORE.emit([&] {252return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F);253});254}255// The context could be scratch.256auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());257if (NrCallsites > 0) {258// Figure out which index of the TLS 2-element buffers to use.259// Scratch context => we use index == 1. Real contexts => index == 0.260auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));261// The GEPs corresponding to that index, in the respective TLS.262ExpectedCalleeTLSAddr = Builder.CreateGEP(263Builder.getInt8Ty()->getPointerTo(),264Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});265CallsiteInfoTLSAddr = Builder.CreateGEP(266Builder.getInt32Ty(),267Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});268}269// Because the context pointer may have LSB set (to indicate scratch),270// clear it for the value we use as base address for the counter vector.271// This way, if later we want to have "real" (not clobbered) buffers272// acting as scratch, the lowering (at least this part of it that deals273// with counters) stays the same.274RealContext = Builder.CreateIntToPtr(275Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),276ThisContextType->getPointerTo());277I.eraseFromParent();278break;279}280}281if (!Context) {282ORE.emit([&] {283return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)284<< "Function doesn't have instrumentation, skipping";285});286return false;287}288289bool ContextWasReleased = false;290for (auto &BB : F) {291for (auto &I : llvm::make_early_inc_range(BB)) {292if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {293IRBuilder<> Builder(Instr);294switch (Instr->getIntrinsicID()) {295case llvm::Intrinsic::instrprof_increment:296case llvm::Intrinsic::instrprof_increment_step: {297// Increments (or increment-steps) are just a typical load - increment298// - store in the RealContext.299auto *AsStep = cast<InstrProfIncrementInst>(Instr);300auto *GEP = Builder.CreateGEP(301ThisContextType, RealContext,302{Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});303Builder.CreateStore(304Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),305AsStep->getStep()),306GEP);307} break;308case llvm::Intrinsic::instrprof_callsite:309// callsite lowering: write the called value in the expected callee310// TLS we treat the TLS as volatile because of signal handlers and to311// avoid these being moved away from the callsite they decorate.312auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);313Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,314true);315// write the GEP of the slot in the sub-contexts portion of the316// context in TLS. Now, here, we use the actual Context value - as317// returned from compiler-rt - which may have the LSB set if the318// Context was scratch. Since the header of the context object and319// then the values are all 8-aligned (or, really, insofar as we care,320// they are even) - if the context is scratch (meaning, an odd value),321// so will the GEP. This is important because this is then visible to322// compiler-rt which will produce scratch contexts for callers that323// have a scratch context.324Builder.CreateStore(325Builder.CreateGEP(ThisContextType, Context,326{Builder.getInt32(0), Builder.getInt32(2),327CSIntrinsic->getIndex()}),328CallsiteInfoTLSAddr, true);329break;330}331I.eraseFromParent();332} else if (TheRootContext && isa<ReturnInst>(I)) {333// Remember to release the context if we are an entrypoint.334IRBuilder<> Builder(&I);335Builder.CreateCall(ReleaseCtx, {TheRootContext});336ContextWasReleased = true;337}338}339}340// FIXME: This would happen if the entrypoint tailcalls. A way to fix would be341// to disallow this, (so this then stays as an error), another is to detect342// that and then do a wrapper or disallow the tail call. This only affects343// instrumentation, when we want to detect the call graph.344if (TheRootContext && !ContextWasReleased)345F.getContext().emitError(346"[ctx_prof] An entrypoint was instrumented but it has no `ret` "347"instructions above which to release the context: " +348F.getName());349return true;350}351352353