Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopTermFold.cpp
213799 views
//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//===----------------------------------------------------------------------===//89#include "llvm/Transforms/Scalar/LoopTermFold.h"10#include "llvm/ADT/Statistic.h"11#include "llvm/Analysis/LoopAnalysisManager.h"12#include "llvm/Analysis/LoopInfo.h"13#include "llvm/Analysis/LoopPass.h"14#include "llvm/Analysis/MemorySSA.h"15#include "llvm/Analysis/MemorySSAUpdater.h"16#include "llvm/Analysis/ScalarEvolution.h"17#include "llvm/Analysis/ScalarEvolutionExpressions.h"18#include "llvm/Analysis/TargetLibraryInfo.h"19#include "llvm/Analysis/TargetTransformInfo.h"20#include "llvm/Analysis/ValueTracking.h"21#include "llvm/Config/llvm-config.h"22#include "llvm/IR/BasicBlock.h"23#include "llvm/IR/Dominators.h"24#include "llvm/IR/IRBuilder.h"25#include "llvm/IR/InstrTypes.h"26#include "llvm/IR/Instruction.h"27#include "llvm/IR/Instructions.h"28#include "llvm/IR/Type.h"29#include "llvm/IR/Value.h"30#include "llvm/InitializePasses.h"31#include "llvm/Pass.h"32#include "llvm/Support/Debug.h"33#include "llvm/Support/raw_ostream.h"34#include "llvm/Transforms/Scalar.h"35#include "llvm/Transforms/Utils.h"36#include "llvm/Transforms/Utils/BasicBlockUtils.h"37#include "llvm/Transforms/Utils/Local.h"38#include "llvm/Transforms/Utils/LoopUtils.h"39#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"40#include <cassert>41#include <optional>4243using namespace llvm;4445#define DEBUG_TYPE "loop-term-fold"4647STATISTIC(NumTermFold,48"Number of terminating condition fold recognized and performed");4950static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>51canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,52const LoopInfo &LI, const TargetTransformInfo &TTI) {53if (!L->isInnermost()) {54LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");55return std::nullopt;56}57// Only inspect on simple loop structure58if (!L->isLoopSimplifyForm()) {59LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");60return std::nullopt;61}6263if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {64LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");65return std::nullopt;66}6768BasicBlock *LoopLatch = L->getLoopLatch();69BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());70if (!BI || BI->isUnconditional())71return std::nullopt;72auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());73if (!TermCond) {74LLVM_DEBUG(75dbgs() << "Cannot fold on branching condition that is not an ICmpInst");76return std::nullopt;77}78if (!TermCond->hasOneUse()) {79LLVM_DEBUG(80dbgs()81<< "Cannot replace terminating condition with more than one use\n");82return std::nullopt;83}8485BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));86Value *RHS = TermCond->getOperand(1);87if (!LHS || !L->isLoopInvariant(RHS))88// We could pattern match the inverse form of the icmp, but that is89// non-canonical, and this pass is running *very* late in the pipeline.90return std::nullopt;9192// Find the IV used by the current exit condition.93PHINode *ToFold;94Value *ToFoldStart, *ToFoldStep;95if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))96return std::nullopt;9798// Ensure the simple recurrence is a part of the current loop.99if (ToFold->getParent() != L->getHeader())100return std::nullopt;101102// If that IV isn't dead after we rewrite the exit condition in terms of103// another IV, there's no point in doing the transform.104if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))105return std::nullopt;106107// Inserting instructions in the preheader has a runtime cost, scale108// the allowed cost with the loops trip count as best we can.109const unsigned ExpansionBudget = [&]() {110unsigned Budget = 2 * SCEVCheapExpansionBudget;111if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))112return std::min(Budget, SmallTC);113if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))114return std::min(Budget, *SmallTC);115// Unknown trip count, assume long running by default.116return Budget;117}();118119const SCEV *BECount = SE.getBackedgeTakenCount(L);120const DataLayout &DL = L->getHeader()->getDataLayout();121SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");122123PHINode *ToHelpFold = nullptr;124const SCEV *TermValueS = nullptr;125bool MustDropPoison = false;126auto InsertPt = L->getLoopPreheader()->getTerminator();127for (PHINode &PN : L->getHeader()->phis()) {128if (ToFold == &PN)129continue;130131if (!SE.isSCEVable(PN.getType())) {132LLVM_DEBUG(dbgs() << "IV of phi '" << PN133<< "' is not SCEV-able, not qualified for the "134"terminating condition folding.\n");135continue;136}137const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));138// Only speculate on affine AddRec139if (!AddRec || !AddRec->isAffine()) {140LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN141<< "' is not an affine add recursion, not qualified "142"for the terminating condition folding.\n");143continue;144}145146// Check that we can compute the value of AddRec on the exiting iteration147// without soundness problems. evaluateAtIteration internally needs148// to multiply the stride of the iteration number - which may wrap around.149// The issue here is subtle because computing the result accounting for150// wrap is insufficient. In order to use the result in an exit test, we151// must also know that AddRec doesn't take the same value on any previous152// iteration. The simplest case to consider is a candidate IV which is153// narrower than the trip count (and thus original IV), but this can154// also happen due to non-unit strides on the candidate IVs.155if (!AddRec->hasNoSelfWrap() ||156!SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))157continue;158159const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);160const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);161if (!Expander.isSafeToExpand(TermValueSLocal)) {162LLVM_DEBUG(163dbgs() << "Is not safe to expand terminating value for phi node" << PN164<< "\n");165continue;166}167168if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI,169InsertPt)) {170LLVM_DEBUG(171dbgs() << "Is too expensive to expand terminating value for phi node"172<< PN << "\n");173continue;174}175176// The candidate IV may have been otherwise dead and poison from the177// very first iteration. If we can't disprove that, we can't use the IV.178if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {179LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n");180continue;181}182183// The candidate IV may become poison on the last iteration. If this184// value is not branched on, this is a well defined program. We're185// about to add a new use to this IV, and we have to ensure we don't186// insert UB which didn't previously exist.187bool MustDropPoisonLocal = false;188Instruction *PostIncV =189cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));190if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),191&DT)) {192LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN193<< "\n");194195// If this is a complex recurrance with multiple instructions computing196// the backedge value, we might need to strip poison flags from all of197// them.198if (PostIncV->getOperand(0) != &PN)199continue;200201// In order to perform the transform, we need to drop the poison202// generating flags on this instruction (if any).203MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();204}205206// We pick the last legal alternate IV. We could expore choosing an optimal207// alternate IV if we had a decent heuristic to do so.208ToHelpFold = &PN;209TermValueS = TermValueSLocal;210MustDropPoison = MustDropPoisonLocal;211}212213LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()214<< "Cannot find other AddRec IV to help folding\n";);215216LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs()217<< "\nFound loop that can fold terminating condition\n"218<< " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n"219<< " TermCond: " << *TermCond << "\n"220<< " BrandInst: " << *BI << "\n"221<< " ToFold: " << *ToFold << "\n"222<< " ToHelpFold: " << *ToHelpFold << "\n");223224if (!ToFold || !ToHelpFold)225return std::nullopt;226return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);227}228229static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT,230LoopInfo &LI, const TargetTransformInfo &TTI,231TargetLibraryInfo &TLI, MemorySSA *MSSA) {232std::unique_ptr<MemorySSAUpdater> MSSAU;233if (MSSA)234MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);235236auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);237if (!Opt)238return false;239240auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;241242NumTermFold++;243244BasicBlock *LoopPreheader = L->getLoopPreheader();245BasicBlock *LoopLatch = L->getLoopLatch();246247(void)ToFold;248LLVM_DEBUG(dbgs() << "To fold phi-node:\n"249<< *ToFold << "\n"250<< "New term-cond phi-node:\n"251<< *ToHelpFold << "\n");252253Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);254(void)StartValue;255Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);256257// See comment in canFoldTermCondOfLoop on why this is sufficient.258if (MustDrop)259cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();260261// SCEVExpander for both use in preheader and latch262const DataLayout &DL = L->getHeader()->getDataLayout();263SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");264265assert(Expander.isSafeToExpand(TermValueS) &&266"Terminating value was checked safe in canFoldTerminatingCondition");267268// Create new terminating value at loop preheader269Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),270LoopPreheader->getTerminator());271272LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"273<< *StartValue << "\n"274<< "Terminating value of new term-cond phi-node:\n"275<< *TermValue << "\n");276277// Create new terminating condition at loop latch278BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());279ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());280IRBuilder<> LatchBuilder(LoopLatch->getTerminator());281Value *NewTermCond =282LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,283"lsr_fold_term_cond.replaced_term_cond");284// Swap successors to exit loop body if IV equals to new TermValue285if (BI->getSuccessor(0) == L->getHeader())286BI->swapSuccessors();287288LLVM_DEBUG(dbgs() << "Old term-cond:\n"289<< *OldTermCond << "\n"290<< "New term-cond:\n"291<< *NewTermCond << "\n");292293BI->setCondition(NewTermCond);294295Expander.clear();296OldTermCond->eraseFromParent();297DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());298return true;299}300301namespace {302303class LoopTermFold : public LoopPass {304public:305static char ID; // Pass ID, replacement for typeid306307LoopTermFold();308309private:310bool runOnLoop(Loop *L, LPPassManager &LPM) override;311void getAnalysisUsage(AnalysisUsage &AU) const override;312};313314} // end anonymous namespace315316LoopTermFold::LoopTermFold() : LoopPass(ID) {317initializeLoopTermFoldPass(*PassRegistry::getPassRegistry());318}319320void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {321AU.addRequired<LoopInfoWrapperPass>();322AU.addPreserved<LoopInfoWrapperPass>();323AU.addPreservedID(LoopSimplifyID);324AU.addRequiredID(LoopSimplifyID);325AU.addRequired<DominatorTreeWrapperPass>();326AU.addPreserved<DominatorTreeWrapperPass>();327AU.addRequired<ScalarEvolutionWrapperPass>();328AU.addPreserved<ScalarEvolutionWrapperPass>();329AU.addRequired<TargetLibraryInfoWrapperPass>();330AU.addRequired<TargetTransformInfoWrapperPass>();331AU.addPreserved<MemorySSAWrapperPass>();332}333334bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {335if (skipLoop(L))336return false;337338auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();339auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();340auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();341const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(342*L->getHeader()->getParent());343auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(344*L->getHeader()->getParent());345auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();346MemorySSA *MSSA = nullptr;347if (MSSAAnalysis)348MSSA = &MSSAAnalysis->getMSSA();349return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA);350}351352PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM,353LoopStandardAnalysisResults &AR,354LPMUpdater &) {355if (!RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA))356return PreservedAnalyses::all();357358auto PA = getLoopPassPreservedAnalyses();359if (AR.MSSA)360PA.preserve<MemorySSAAnalysis>();361return PA;362}363364char LoopTermFold::ID = 0;365366INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",367false, false)368INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)369INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)370INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)371INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)372INITIALIZE_PASS_DEPENDENCY(LoopSimplify)373INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",374false, false)375376Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); }377378379