Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
35266 views
//===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This file implements a Loop Data Prefetching Pass.9//10//===----------------------------------------------------------------------===//1112#include "llvm/Transforms/Scalar/LoopDataPrefetch.h"13#include "llvm/InitializePasses.h"1415#include "llvm/ADT/DepthFirstIterator.h"16#include "llvm/ADT/Statistic.h"17#include "llvm/Analysis/AssumptionCache.h"18#include "llvm/Analysis/CodeMetrics.h"19#include "llvm/Analysis/LoopInfo.h"20#include "llvm/Analysis/OptimizationRemarkEmitter.h"21#include "llvm/Analysis/ScalarEvolution.h"22#include "llvm/Analysis/ScalarEvolutionExpressions.h"23#include "llvm/Analysis/TargetTransformInfo.h"24#include "llvm/IR/Dominators.h"25#include "llvm/IR/Function.h"26#include "llvm/IR/Module.h"27#include "llvm/Support/CommandLine.h"28#include "llvm/Support/Debug.h"29#include "llvm/Transforms/Scalar.h"30#include "llvm/Transforms/Utils.h"31#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"3233#define DEBUG_TYPE "loop-data-prefetch"3435using namespace llvm;3637// By default, we limit this to creating 16 PHIs (which is a little over half38// of the allocatable register set).39static cl::opt<bool>40PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),41cl::desc("Prefetch write addresses"));4243static cl::opt<unsigned>44PrefetchDistance("prefetch-distance",45cl::desc("Number of instructions to prefetch ahead"),46cl::Hidden);4748static cl::opt<unsigned>49MinPrefetchStride("min-prefetch-stride",50cl::desc("Min stride to add prefetches"), cl::Hidden);5152static cl::opt<unsigned> MaxPrefetchIterationsAhead(53"max-prefetch-iters-ahead",54cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden);5556STATISTIC(NumPrefetches, "Number of prefetches inserted");5758namespace {5960/// Loop prefetch implementation class.61class LoopDataPrefetch {62public:63LoopDataPrefetch(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,64ScalarEvolution *SE, const TargetTransformInfo *TTI,65OptimizationRemarkEmitter *ORE)66: AC(AC), DT(DT), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}6768bool run();6970private:71bool runOnLoop(Loop *L);7273/// Check if the stride of the accesses is large enough to74/// warrant a prefetch.75bool isStrideLargeEnough(const SCEVAddRecExpr *AR, unsigned TargetMinStride);7677unsigned getMinPrefetchStride(unsigned NumMemAccesses,78unsigned NumStridedMemAccesses,79unsigned NumPrefetches,80bool HasCall) {81if (MinPrefetchStride.getNumOccurrences() > 0)82return MinPrefetchStride;83return TTI->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,84NumPrefetches, HasCall);85}8687unsigned getPrefetchDistance() {88if (PrefetchDistance.getNumOccurrences() > 0)89return PrefetchDistance;90return TTI->getPrefetchDistance();91}9293unsigned getMaxPrefetchIterationsAhead() {94if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0)95return MaxPrefetchIterationsAhead;96return TTI->getMaxPrefetchIterationsAhead();97}9899bool doPrefetchWrites() {100if (PrefetchWrites.getNumOccurrences() > 0)101return PrefetchWrites;102return TTI->enableWritePrefetching();103}104105AssumptionCache *AC;106DominatorTree *DT;107LoopInfo *LI;108ScalarEvolution *SE;109const TargetTransformInfo *TTI;110OptimizationRemarkEmitter *ORE;111};112113/// Legacy class for inserting loop data prefetches.114class LoopDataPrefetchLegacyPass : public FunctionPass {115public:116static char ID; // Pass ID, replacement for typeid117LoopDataPrefetchLegacyPass() : FunctionPass(ID) {118initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());119}120121void getAnalysisUsage(AnalysisUsage &AU) const override {122AU.addRequired<AssumptionCacheTracker>();123AU.addRequired<DominatorTreeWrapperPass>();124AU.addPreserved<DominatorTreeWrapperPass>();125AU.addRequired<LoopInfoWrapperPass>();126AU.addPreserved<LoopInfoWrapperPass>();127AU.addRequiredID(LoopSimplifyID);128AU.addPreservedID(LoopSimplifyID);129AU.addRequired<OptimizationRemarkEmitterWrapperPass>();130AU.addRequired<ScalarEvolutionWrapperPass>();131AU.addPreserved<ScalarEvolutionWrapperPass>();132AU.addRequired<TargetTransformInfoWrapperPass>();133}134135bool runOnFunction(Function &F) override;136};137}138139char LoopDataPrefetchLegacyPass::ID = 0;140INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch",141"Loop Data Prefetch", false, false)142INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)143INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)144INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)145INITIALIZE_PASS_DEPENDENCY(LoopSimplify)146INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)147INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)148INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch",149"Loop Data Prefetch", false, false)150151FunctionPass *llvm::createLoopDataPrefetchPass() {152return new LoopDataPrefetchLegacyPass();153}154155bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR,156unsigned TargetMinStride) {157// No need to check if any stride goes.158if (TargetMinStride <= 1)159return true;160161const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));162// If MinStride is set, don't prefetch unless we can ensure that stride is163// larger.164if (!ConstStride)165return false;166167unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue());168return TargetMinStride <= AbsStride;169}170171PreservedAnalyses LoopDataPrefetchPass::run(Function &F,172FunctionAnalysisManager &AM) {173DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);174LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);175ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);176AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);177OptimizationRemarkEmitter *ORE =178&AM.getResult<OptimizationRemarkEmitterAnalysis>(F);179const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);180181LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);182bool Changed = LDP.run();183184if (Changed) {185PreservedAnalyses PA;186PA.preserve<DominatorTreeAnalysis>();187PA.preserve<LoopAnalysis>();188return PA;189}190191return PreservedAnalyses::all();192}193194bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {195if (skipFunction(F))196return false;197198DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();199LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();200ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();201AssumptionCache *AC =202&getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);203OptimizationRemarkEmitter *ORE =204&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();205const TargetTransformInfo *TTI =206&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);207208LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);209return LDP.run();210}211212bool LoopDataPrefetch::run() {213// If PrefetchDistance is not set, don't run the pass. This gives an214// opportunity for targets to run this pass for selected subtargets only215// (whose TTI sets PrefetchDistance and CacheLineSize).216if (getPrefetchDistance() == 0 || TTI->getCacheLineSize() == 0) {217LLVM_DEBUG(dbgs() << "Please set both PrefetchDistance and CacheLineSize "218"for loop data prefetch.\n");219return false;220}221222bool MadeChange = false;223224for (Loop *I : *LI)225for (Loop *L : depth_first(I))226MadeChange |= runOnLoop(L);227228return MadeChange;229}230231/// A record for a potential prefetch made during the initial scan of the232/// loop. This is used to let a single prefetch target multiple memory accesses.233struct Prefetch {234/// The address formula for this prefetch as returned by ScalarEvolution.235const SCEVAddRecExpr *LSCEVAddRec;236/// The point of insertion for the prefetch instruction.237Instruction *InsertPt = nullptr;238/// True if targeting a write memory access.239bool Writes = false;240/// The (first seen) prefetched instruction.241Instruction *MemI = nullptr;242243/// Constructor to create a new Prefetch for \p I.244Prefetch(const SCEVAddRecExpr *L, Instruction *I) : LSCEVAddRec(L) {245addInstruction(I);246};247248/// Add the instruction \param I to this prefetch. If it's not the first249/// one, 'InsertPt' and 'Writes' will be updated as required.250/// \param PtrDiff the known constant address difference to the first added251/// instruction.252void addInstruction(Instruction *I, DominatorTree *DT = nullptr,253int64_t PtrDiff = 0) {254if (!InsertPt) {255MemI = I;256InsertPt = I;257Writes = isa<StoreInst>(I);258} else {259BasicBlock *PrefBB = InsertPt->getParent();260BasicBlock *InsBB = I->getParent();261if (PrefBB != InsBB) {262BasicBlock *DomBB = DT->findNearestCommonDominator(PrefBB, InsBB);263if (DomBB != PrefBB)264InsertPt = DomBB->getTerminator();265}266267if (isa<StoreInst>(I) && PtrDiff == 0)268Writes = true;269}270}271};272273bool LoopDataPrefetch::runOnLoop(Loop *L) {274bool MadeChange = false;275276// Only prefetch in the inner-most loop277if (!L->isInnermost())278return MadeChange;279280SmallPtrSet<const Value *, 32> EphValues;281CodeMetrics::collectEphemeralValues(L, AC, EphValues);282283// Calculate the number of iterations ahead to prefetch284CodeMetrics Metrics;285bool HasCall = false;286for (const auto BB : L->blocks()) {287// If the loop already has prefetches, then assume that the user knows288// what they are doing and don't add any more.289for (auto &I : *BB) {290if (isa<CallInst>(&I) || isa<InvokeInst>(&I)) {291if (const Function *F = cast<CallBase>(I).getCalledFunction()) {292if (F->getIntrinsicID() == Intrinsic::prefetch)293return MadeChange;294if (TTI->isLoweredToCall(F))295HasCall = true;296} else { // indirect call.297HasCall = true;298}299}300}301Metrics.analyzeBasicBlock(BB, *TTI, EphValues);302}303304if (!Metrics.NumInsts.isValid())305return MadeChange;306307unsigned LoopSize = *Metrics.NumInsts.getValue();308if (!LoopSize)309LoopSize = 1;310311unsigned ItersAhead = getPrefetchDistance() / LoopSize;312if (!ItersAhead)313ItersAhead = 1;314315if (ItersAhead > getMaxPrefetchIterationsAhead())316return MadeChange;317318unsigned ConstantMaxTripCount = SE->getSmallConstantMaxTripCount(L);319if (ConstantMaxTripCount && ConstantMaxTripCount < ItersAhead + 1)320return MadeChange;321322unsigned NumMemAccesses = 0;323unsigned NumStridedMemAccesses = 0;324SmallVector<Prefetch, 16> Prefetches;325for (const auto BB : L->blocks())326for (auto &I : *BB) {327Value *PtrValue;328Instruction *MemI;329330if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) {331MemI = LMemI;332PtrValue = LMemI->getPointerOperand();333} else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {334if (!doPrefetchWrites()) continue;335MemI = SMemI;336PtrValue = SMemI->getPointerOperand();337} else continue;338339unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();340if (!TTI->shouldPrefetchAddressSpace(PtrAddrSpace))341continue;342NumMemAccesses++;343if (L->isLoopInvariant(PtrValue))344continue;345346const SCEV *LSCEV = SE->getSCEV(PtrValue);347const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);348if (!LSCEVAddRec)349continue;350NumStridedMemAccesses++;351352// We don't want to double prefetch individual cache lines. If this353// access is known to be within one cache line of some other one that354// has already been prefetched, then don't prefetch this one as well.355bool DupPref = false;356for (auto &Pref : Prefetches) {357const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, Pref.LSCEVAddRec);358if (const SCEVConstant *ConstPtrDiff =359dyn_cast<SCEVConstant>(PtrDiff)) {360int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());361if (PD < (int64_t) TTI->getCacheLineSize()) {362Pref.addInstruction(MemI, DT, PD);363DupPref = true;364break;365}366}367}368if (!DupPref)369Prefetches.push_back(Prefetch(LSCEVAddRec, MemI));370}371372unsigned TargetMinStride =373getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,374Prefetches.size(), HasCall);375376LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead377<< " iterations ahead (loop size: " << LoopSize << ") in "378<< L->getHeader()->getParent()->getName() << ": " << *L);379LLVM_DEBUG(dbgs() << "Loop has: "380<< NumMemAccesses << " memory accesses, "381<< NumStridedMemAccesses << " strided memory accesses, "382<< Prefetches.size() << " potential prefetch(es), "383<< "a minimum stride of " << TargetMinStride << ", "384<< (HasCall ? "calls" : "no calls") << ".\n");385386for (auto &P : Prefetches) {387// Check if the stride of the accesses is large enough to warrant a388// prefetch.389if (!isStrideLargeEnough(P.LSCEVAddRec, TargetMinStride))390continue;391392BasicBlock *BB = P.InsertPt->getParent();393SCEVExpander SCEVE(*SE, BB->getDataLayout(), "prefaddr");394const SCEV *NextLSCEV = SE->getAddExpr(P.LSCEVAddRec, SE->getMulExpr(395SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead),396P.LSCEVAddRec->getStepRecurrence(*SE)));397if (!SCEVE.isSafeToExpand(NextLSCEV))398continue;399400unsigned PtrAddrSpace = NextLSCEV->getType()->getPointerAddressSpace();401Type *I8Ptr = PointerType::get(BB->getContext(), PtrAddrSpace);402Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt);403404IRBuilder<> Builder(P.InsertPt);405Module *M = BB->getParent()->getParent();406Type *I32 = Type::getInt32Ty(BB->getContext());407Function *PrefetchFunc = Intrinsic::getDeclaration(408M, Intrinsic::prefetch, PrefPtrValue->getType());409Builder.CreateCall(410PrefetchFunc,411{PrefPtrValue,412ConstantInt::get(I32, P.Writes),413ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});414++NumPrefetches;415LLVM_DEBUG(dbgs() << " Access: "416<< *P.MemI->getOperand(isa<LoadInst>(P.MemI) ? 0 : 1)417<< ", SCEV: " << *P.LSCEVAddRec << "\n");418ORE->emit([&]() {419return OptimizationRemark(DEBUG_TYPE, "Prefetched", P.MemI)420<< "prefetched memory access";421});422423MadeChange = true;424}425426return MadeChange;427}428429430