Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp
213799 views
//===- PGOCtxProfFlattening.cpp - Contextual Instr. Flattening ------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// Flattens the contextual profile and lowers it to MD_prof.9// This should happen after all IPO (which is assumed to have maintained the10// contextual profile) happened. Flattening consists of summing the values at11// the same index of the counters belonging to all the contexts of a function.12// The lowering consists of materializing the counter values to function13// entrypoint counts and branch probabilities.14//15// This pass also removes contextual instrumentation, which has been kept around16// to facilitate its functionality.17//18//===----------------------------------------------------------------------===//1920#include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h"21#include "llvm/ADT/STLExtras.h"22#include "llvm/ADT/ScopeExit.h"23#include "llvm/Analysis/CFG.h"24#include "llvm/Analysis/CtxProfAnalysis.h"25#include "llvm/Analysis/ProfileSummaryInfo.h"26#include "llvm/IR/Analysis.h"27#include "llvm/IR/CFG.h"28#include "llvm/IR/Dominators.h"29#include "llvm/IR/Instructions.h"30#include "llvm/IR/IntrinsicInst.h"31#include "llvm/IR/Module.h"32#include "llvm/IR/PassManager.h"33#include "llvm/IR/ProfileSummary.h"34#include "llvm/ProfileData/ProfileCommon.h"35#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"36#include "llvm/Transforms/Scalar/DCE.h"37#include "llvm/Transforms/Utils/BasicBlockUtils.h"3839using namespace llvm;4041#define DEBUG_TYPE "ctx_prof_flatten"4243namespace {4445/// Assign branch weights and function entry count. Also update the PSI46/// builder.47void assignProfileData(Function &F, ArrayRef<uint64_t> RawCounters) {48assert(!RawCounters.empty());49ProfileAnnotator PA(F, RawCounters);5051F.setEntryCount(RawCounters[0]);52SmallVector<uint64_t, 2> ProfileHolder;5354for (auto &BB : F) {55for (auto &I : BB)56if (auto *SI = dyn_cast<SelectInst>(&I)) {57uint64_t TrueCount, FalseCount = 0;58if (!PA.getSelectInstrProfile(*SI, TrueCount, FalseCount))59continue;60setProfMetadata(F.getParent(), SI, {TrueCount, FalseCount},61std::max(TrueCount, FalseCount));62}63if (succ_size(&BB) < 2)64continue;65uint64_t MaxCount = 0;66if (!PA.getOutgoingBranchWeights(BB, ProfileHolder, MaxCount))67continue;68assert(MaxCount > 0);69setProfMetadata(F.getParent(), BB.getTerminator(), ProfileHolder, MaxCount);70}71}7273[[maybe_unused]] bool areAllBBsReachable(const Function &F,74FunctionAnalysisManager &FAM) {75auto &DT = FAM.getResult<DominatorTreeAnalysis>(const_cast<Function &>(F));76return llvm::all_of(77F, [&](const BasicBlock &BB) { return DT.isReachableFromEntry(&BB); });78}7980void clearColdFunctionProfile(Function &F) {81for (auto &BB : F)82BB.getTerminator()->setMetadata(LLVMContext::MD_prof, nullptr);83F.setEntryCount(0U);84}8586void removeInstrumentation(Function &F) {87for (auto &BB : F)88for (auto &I : llvm::make_early_inc_range(BB))89if (isa<InstrProfCntrInstBase>(I))90I.eraseFromParent();91}9293void annotateIndirectCall(94Module &M, CallBase &CB,95const DenseMap<uint32_t, FlatIndirectTargets> &FlatProf,96const InstrProfCallsite &Ins) {97auto Idx = Ins.getIndex()->getZExtValue();98auto FIt = FlatProf.find(Idx);99if (FIt == FlatProf.end())100return;101const auto &Targets = FIt->second;102SmallVector<InstrProfValueData, 2> Data;103uint64_t Sum = 0;104for (auto &[Guid, Count] : Targets) {105Data.push_back({/*.Value=*/Guid, /*.Count=*/Count});106Sum += Count;107}108109llvm::sort(Data,110[](const InstrProfValueData &A, const InstrProfValueData &B) {111return A.Count > B.Count;112});113llvm::annotateValueSite(M, CB, Data, Sum,114InstrProfValueKind::IPVK_IndirectCallTarget,115Data.size());116LLVM_DEBUG(dbgs() << "[ctxprof] flat indirect call prof: " << CB117<< CB.getMetadata(LLVMContext::MD_prof) << "\n");118}119120// We normally return a "Changed" bool, but the calling pass' run assumes121// something will change - some profile will be added - so this won't add much122// by returning false when applicable.123void annotateIndirectCalls(Module &M, const CtxProfAnalysis::Result &CtxProf) {124const auto FlatIndCalls = CtxProf.flattenVirtCalls();125for (auto &F : M) {126if (F.isDeclaration())127continue;128auto FlatProfIter = FlatIndCalls.find(AssignGUIDPass::getGUID(F));129if (FlatProfIter == FlatIndCalls.end())130continue;131const auto &FlatProf = FlatProfIter->second;132for (auto &BB : F) {133for (auto &I : BB) {134auto *CB = dyn_cast<CallBase>(&I);135if (!CB || !CB->isIndirectCall())136continue;137if (auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(*CB))138annotateIndirectCall(M, *CB, FlatProf, *Ins);139}140}141}142}143144} // namespace145146PreservedAnalyses PGOCtxProfFlatteningPass::run(Module &M,147ModuleAnalysisManager &MAM) {148// Ensure in all cases the instrumentation is removed: if this module had no149// roots, the contextual profile would evaluate to false, but there would150// still be instrumentation.151// Note: in such cases we leave as-is any other profile info (if present -152// e.g. synthetic weights, etc) because it wouldn't interfere with the153// contextual - based one (which would be in other modules)154auto OnExit = llvm::make_scope_exit([&]() {155if (IsPreThinlink)156return;157for (auto &F : M)158removeInstrumentation(F);159});160auto &CtxProf = MAM.getResult<CtxProfAnalysis>(M);161// post-thinlink, we only reprocess for the module(s) containing the162// contextual tree. For everything else, OnExit will just clean the163// instrumentation.164if (!IsPreThinlink && !CtxProf.isInSpecializedModule())165return PreservedAnalyses::none();166167if (IsPreThinlink)168annotateIndirectCalls(M, CtxProf);169const auto FlattenedProfile = CtxProf.flatten();170171for (auto &F : M) {172if (F.isDeclaration())173continue;174175assert(areAllBBsReachable(176F, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M)177.getManager()) &&178"Function has unreacheable basic blocks. The expectation was that "179"DCE was run before.");180181auto It = FlattenedProfile.find(AssignGUIDPass::getGUID(F));182// If this function didn't appear in the contextual profile, it's cold.183if (It == FlattenedProfile.end())184clearColdFunctionProfile(F);185else186assignProfileData(F, It->second);187}188InstrProfSummaryBuilder PB(ProfileSummaryBuilder::DefaultCutoffs);189// use here the flat profiles just so the importer doesn't complain about190// how different the PSIs are between the module with the roots and the191// various modules it imports.192for (auto &C : FlattenedProfile) {193PB.addEntryCount(C.second[0]);194for (auto V : llvm::drop_begin(C.second))195PB.addInternalCount(V);196}197198M.setProfileSummary(PB.getSummary()->getMD(M.getContext()),199ProfileSummary::Kind::PSK_Instr);200PreservedAnalyses PA;201PA.abandon<ProfileSummaryAnalysis>();202MAM.invalidate(M, PA);203auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M);204PSI.refresh(PB.getSummary());205return PreservedAnalyses::none();206}207208209