// Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp
// 35266 views
//===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//78#include "llvm/Transforms/Scalar/LoopBoundSplit.h"9#include "llvm/ADT/Sequence.h"10#include "llvm/Analysis/LoopAnalysisManager.h"11#include "llvm/Analysis/LoopInfo.h"12#include "llvm/Analysis/ScalarEvolution.h"13#include "llvm/Analysis/ScalarEvolutionExpressions.h"14#include "llvm/IR/PatternMatch.h"15#include "llvm/Transforms/Scalar/LoopPassManager.h"16#include "llvm/Transforms/Utils/BasicBlockUtils.h"17#include "llvm/Transforms/Utils/Cloning.h"18#include "llvm/Transforms/Utils/LoopSimplify.h"19#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"2021#define DEBUG_TYPE "loop-bound-split"2223namespace llvm {2425using namespace PatternMatch;2627namespace {28struct ConditionInfo {29/// Branch instruction with this condition30BranchInst *BI = nullptr;31/// ICmp instruction with this condition32ICmpInst *ICmp = nullptr;33/// Preciate info34ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;35/// AddRec llvm value36Value *AddRecValue = nullptr;37/// Non PHI AddRec llvm value38Value *NonPHIAddRecValue;39/// Bound llvm value40Value *BoundValue = nullptr;41/// AddRec SCEV42const SCEVAddRecExpr *AddRecSCEV = nullptr;43/// Bound SCEV44const SCEV *BoundSCEV = nullptr;4546ConditionInfo() = default;47};48} // namespace4950static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,51ConditionInfo &Cond, const Loop &L) {52Cond.ICmp = ICmp;53if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue),54m_Value(Cond.BoundValue)))) {55const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue);56const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue);57const SCEVAddRecExpr *LHSAddRecSCEV = 
dyn_cast<SCEVAddRecExpr>(AddRecSCEV);58const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV);59// Locate AddRec in LHSSCEV and Bound in RHSSCEV.60if (!LHSAddRecSCEV && RHSAddRecSCEV) {61std::swap(Cond.AddRecValue, Cond.BoundValue);62std::swap(AddRecSCEV, BoundSCEV);63Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred);64}6566Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);67Cond.BoundSCEV = BoundSCEV;68Cond.NonPHIAddRecValue = Cond.AddRecValue;6970// If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with71// value from backedge.72if (Cond.AddRecSCEV && isa<PHINode>(Cond.AddRecValue)) {73PHINode *PN = cast<PHINode>(Cond.AddRecValue);74Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(L.getLoopLatch());75}76}77}7879static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,80ConditionInfo &Cond, bool IsExitCond) {81if (IsExitCond) {82const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent());83if (isa<SCEVCouldNotCompute>(ExitCount))84return false;8586Cond.BoundSCEV = ExitCount;87return true;88}8990// For non-exit condtion, if pred is LT, keep existing bound.91if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)92return true;9394// For non-exit condition, if pre is LE, try to convert it to LT.95// Range Range96// AddRec <= Bound --> AddRec < Bound + 197if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)98return false;99100if (IntegerType *BoundSCEVIntType =101dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) {102unsigned BitWidth = BoundSCEVIntType->getBitWidth();103APInt Max = ICmpInst::isSigned(Cond.Pred)104? APInt::getSignedMaxValue(BitWidth)105: APInt::getMaxValue(BitWidth);106const SCEV *MaxSCEV = SE.getConstant(Max);107// Check Bound < INT_MAX108ICmpInst::Predicate Pred =109ICmpInst::isSigned(Cond.Pred) ? 
ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;110if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) {111const SCEV *BoundPlusOneSCEV =112SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType));113Cond.BoundSCEV = BoundPlusOneSCEV;114Cond.Pred = Pred;115return true;116}117}118119// ToDo: Support ICMP_NE/EQ.120121return false;122}123124static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,125ICmpInst *ICmp, ConditionInfo &Cond,126bool IsExitCond) {127analyzeICmp(SE, ICmp, Cond, L);128129// The BoundSCEV should be evaluated at loop entry.130if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L))131return false;132133// Allowed AddRec as induction variable.134if (!Cond.AddRecSCEV)135return false;136137if (!Cond.AddRecSCEV->isAffine())138return false;139140const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE);141// Allowed constant step.142if (!isa<SCEVConstant>(StepRecSCEV))143return false;144145ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue();146// Allowed positive step for now.147// TODO: Support negative step.148if (StepCI->isNegative() || StepCI->isZero())149return false;150151// Calculate upper bound.152if (!calculateUpperBound(L, SE, Cond, IsExitCond))153return false;154155return true;156}157158static bool isProcessableCondBI(const ScalarEvolution &SE,159const BranchInst *BI) {160BasicBlock *TrueSucc = nullptr;161BasicBlock *FalseSucc = nullptr;162ICmpInst::Predicate Pred;163Value *LHS, *RHS;164if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),165m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))166return false;167168if (!SE.isSCEVable(LHS->getType()))169return false;170assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");171172if (TrueSucc == FalseSucc)173return false;174175return true;176}177178static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,179ScalarEvolution &SE, ConditionInfo &Cond) {180// Skip function with optsize.181if 
(L.getHeader()->getParent()->hasOptSize())182return false;183184// Split only innermost loop.185if (!L.isInnermost())186return false;187188// Check loop is in simplified form.189if (!L.isLoopSimplifyForm())190return false;191192// Check loop is in LCSSA form.193if (!L.isLCSSAForm(DT))194return false;195196// Skip loop that cannot be cloned.197if (!L.isSafeToClone())198return false;199200BasicBlock *ExitingBB = L.getExitingBlock();201// Assumed only one exiting block.202if (!ExitingBB)203return false;204205BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator());206if (!ExitingBI)207return false;208209// Allowed only conditional branch with ICmp.210if (!isProcessableCondBI(SE, ExitingBI))211return false;212213// Check the condition is processable.214ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition());215if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))216return false;217218Cond.BI = ExitingBI;219return true;220}221222static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {223// If the conditional branch splits a loop into two halves, we could224// generally say it is profitable.225//226// ToDo: Add more profitable cases here.227228// Check this branch causes diamond CFG.229BasicBlock *Succ0 = BI->getSuccessor(0);230BasicBlock *Succ1 = BI->getSuccessor(1);231232BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();233BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();234if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)235return false;236237// ToDo: Calculate each successor's instruction cost.238239return true;240}241242static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,243ConditionInfo &ExitingCond,244ConditionInfo &SplitCandidateCond) {245for (auto *BB : L.blocks()) {246// Skip condition of backedge.247if (L.getLoopLatch() == BB)248continue;249250auto *BI = dyn_cast<BranchInst>(BB->getTerminator());251if (!BI)252continue;253254// Check conditional branch with ICmp.255if 
(!isProcessableCondBI(SE, BI))256continue;257258// Skip loop invariant condition.259if (L.isLoopInvariant(BI->getCondition()))260continue;261262// Check the condition is processable.263ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition());264if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond,265/*IsExitCond*/ false))266continue;267268if (ExitingCond.BoundSCEV->getType() !=269SplitCandidateCond.BoundSCEV->getType())270continue;271272// After transformation, we assume the split condition of the pre-loop is273// always true. In order to guarantee it, we need to check the start value274// of the split cond AddRec satisfies the split condition.275if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred,276SplitCandidateCond.AddRecSCEV->getStart(),277SplitCandidateCond.BoundSCEV))278continue;279280SplitCandidateCond.BI = BI;281return BI;282}283284return nullptr;285}286287static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,288ScalarEvolution &SE, LPMUpdater &U) {289ConditionInfo SplitCandidateCond;290ConditionInfo ExitingCond;291292// Check we can split this loop's bound.293if (!canSplitLoopBound(L, DT, SE, ExitingCond))294return false;295296if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))297return false;298299if (!isProfitableToTransform(L, SplitCandidateCond.BI))300return false;301302// Now, we have a split candidate. 
Let's build a form as below.303// +--------------------+304// | preheader |305// | set up newbound |306// +--------------------+307// | /----------------\308// +--------v----v------+ |309// | header |---\ |310// | with true condition| | |311// +--------------------+ | |312// | | |313// +--------v-----------+ | |314// | if.then.BB | | |315// +--------------------+ | |316// | | |317// +--------v-----------<---/ |318// | latch >----------/319// | with newbound |320// +--------------------+321// |322// +--------v-----------+323// | preheader2 |--------------\324// | if (AddRec i != | |325// | org bound) | |326// +--------------------+ |327// | /----------------\ |328// +--------v----v------+ | |329// | header2 |---\ | |330// | conditional branch | | | |331// |with false condition| | | |332// +--------------------+ | | |333// | | | |334// +--------v-----------+ | | |335// | if.then.BB2 | | | |336// +--------------------+ | | |337// | | | |338// +--------v-----------<---/ | |339// | latch2 >----------/ |340// | with org bound | |341// +--------v-----------+ |342// | |343// | +---------------+ |344// +--> exit <-------/345// +---------------+346347// Let's create post loop.348SmallVector<BasicBlock *, 8> PostLoopBlocks;349Loop *PostLoop;350ValueToValueMapTy VMap;351BasicBlock *PreHeader = L.getLoopPreheader();352BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI);353PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap,354".split", &LI, &DT, PostLoopBlocks);355remapInstructionsInBlocks(PostLoopBlocks, VMap);356357BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();358IRBuilder<> Builder(&PostLoopPreHeader->front());359360// Update phi nodes in header of post-loop.361bool isExitingLatch =362(L.getExitingBlock() == L.getLoopLatch()) ? 
true : false;363Value *ExitingCondLCSSAPhi = nullptr;364for (PHINode &PN : L.getHeader()->phis()) {365// Create LCSSA phi node in preheader of post-loop.366PHINode *LCSSAPhi =367Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");368LCSSAPhi->setDebugLoc(PN.getDebugLoc());369// If the exiting block is loop latch, the phi does not have the update at370// last iteration. In this case, update lcssa phi with value from backedge.371LCSSAPhi->addIncoming(372isExitingLatch ? PN.getIncomingValueForBlock(L.getLoopLatch()) : &PN,373L.getExitingBlock());374375// Update the start value of phi node in post-loop with the LCSSA phi node.376PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]);377PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, LCSSAPhi);378379// Find PHI with exiting condition from pre-loop. The PHI should be380// SCEVAddRecExpr and have same incoming value from backedge with381// ExitingCond.382if (!SE.isSCEVable(PN.getType()))383continue;384385const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));386if (PhiSCEV && ExitingCond.NonPHIAddRecValue ==387PN.getIncomingValueForBlock(L.getLoopLatch()))388ExitingCondLCSSAPhi = LCSSAPhi;389}390391// Add conditional branch to check we can skip post-loop in its preheader.392Instruction *OrigBI = PostLoopPreHeader->getTerminator();393ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;394Value *Cond =395Builder.CreateICmp(Pred, ExitingCondLCSSAPhi, ExitingCond.BoundValue);396Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock());397OrigBI->eraseFromParent();398399// Create new loop bound and add it into preheader of pre-loop.400const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;401const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;402NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred)403? 
SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV)404: SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV);405406SCEVExpander Expander(407SE, L.getHeader()->getDataLayout(), "split");408Instruction *InsertPt = SplitLoopPH->getTerminator();409Value *NewBoundValue =410Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt);411NewBoundValue->setName("new.bound");412413// Replace exiting bound value of pre-loop NewBound.414ExitingCond.ICmp->setOperand(1, NewBoundValue);415416// Replace SplitCandidateCond.BI's condition of pre-loop by True.417LLVMContext &Context = PreHeader->getContext();418SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));419420// Replace cloned SplitCandidateCond.BI's condition in post-loop by False.421BranchInst *ClonedSplitCandidateBI =422cast<BranchInst>(VMap[SplitCandidateCond.BI]);423ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));424425// Replace exit branch target of pre-loop by post-loop's preheader.426if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0))427ExitingCond.BI->setSuccessor(0, PostLoopPreHeader);428else429ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);430431// Update phi node in exit block of post-loop.432Builder.SetInsertPoint(PostLoopPreHeader, PostLoopPreHeader->begin());433for (PHINode &PN : PostLoop->getExitBlock()->phis()) {434for (auto i : seq<int>(0, PN.getNumOperands())) {435// Check incoming block is pre-loop's exiting block.436if (PN.getIncomingBlock(i) == L.getExitingBlock()) {437Value *IncomingValue = PN.getIncomingValue(i);438439// Create LCSSA phi node for incoming value.440PHINode *LCSSAPhi =441Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");442LCSSAPhi->setDebugLoc(PN.getDebugLoc());443LCSSAPhi->addIncoming(IncomingValue, PN.getIncomingBlock(i));444445// Replace pre-loop's exiting block by post-loop's preheader.446PN.setIncomingBlock(i, PostLoopPreHeader);447// Replace incoming value by LCSSAPhi.448PN.setIncomingValue(i, LCSSAPhi);449// Add a new incoming 
value with post-loop's exiting block.450PN.addIncoming(VMap[IncomingValue], PostLoop->getExitingBlock());451}452}453}454455// Update dominator tree.456DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock());457DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader);458459// Invalidate cached SE information.460SE.forgetLoop(&L);461462// Canonicalize loops.463simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true);464simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true);465466// Add new post-loop to loop pass manager.467U.addSiblingLoops(PostLoop);468469return true;470}471472PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,473LoopStandardAnalysisResults &AR,474LPMUpdater &U) {475Function &F = *L.getHeader()->getParent();476(void)F;477478LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L479<< "\n");480481if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U))482return PreservedAnalyses::all();483484assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));485AR.LI.verify(AR.DT);486487return getLoopPassPreservedAnalyses();488}489490} // end namespace llvm491492493