// Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp
// 35271 views
#include "llvm/Transforms/Utils/LoopConstrainer.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"

using namespace llvm;

// Metadata kind attached (with an empty node) to the latch terminator of every
// loop produced by LoopConstrainer::cloneLoop. parseLoopStructure rejects
// loops whose latch carries it, so a clone is never constrained again.
static const char *ClonedLoopTag = "loop_constrainer.loop.clone";

#define DEBUG_TYPE "loop-constrainer"

/// Given a loop with a decreasing induction variable, is it possible to
/// safely calculate the bounds of a new loop using the given Predicate.
///
/// \p Start is the initial value of the induction variable and \p Step its
/// (negative) step. The latch compares the induction variable against
/// \p BoundSCEV using \p Pred. \p LatchBrExitIdx is the index of the latch
/// branch successor that leaves the loop.
///
/// Only slt/sgt/ult/ugt predicates are handled, and \p BoundSCEV must be
/// available at loop entry. Let BoundPred be sgt or ult/ugt matching the
/// signedness of \p Pred (sgt for signed, ugt for unsigned). The following must
/// be provable on loop entry (SE.isLoopEntryGuardedByCond), after applying
/// loop guards to Start and Bound:
///  - LatchBrExitIdx == 1: Start BoundPred Bound;
///  - LatchBrExitIdx == 0: Start BoundPred (Bound - 1) and
///    Bound BoundPred (Min - (Step + 1)), where Min is the minimum value of
///    the bound's integer type for that signedness.
static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
                                  const SCEV *Step, ICmpInst::Predicate Pred,
                                  unsigned LatchBrExitIdx, Loop *L,
                                  ScalarEvolution &SE) {
  if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
      Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
    return false;

  if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
    return false;

  assert(SE.isKnownNegative(Step) && "expecting negative step");

  LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n");
  LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
  LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
  LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
  LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
  LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");

  bool IsSigned = ICmpInst::isSigned(Pred);
  // The predicate that we need to check that the induction variable lies
  // within bounds.
  ICmpInst::Predicate BoundPred =
      IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;

  // Strengthen both operands with facts implied by the loop's guards before
  // asking whether the conditions hold on entry.
  auto StartLG = SE.applyLoopGuards(Start, L);
  auto BoundLG = SE.applyLoopGuards(BoundSCEV, L);

  if (LatchBrExitIdx == 1)
    return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG);

  assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1");

  // Limit = Min - (Step + 1).
  const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
  unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
  APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth)
                       : APInt::getMinValue(BitWidth);
  const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);

  const SCEV *MinusOne =
      SE.getMinusSCEV(BoundLG, SE.getOne(BoundLG->getType()));

  return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, MinusOne) &&
         SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit);
}

/// Given a loop with an increasing induction variable, is it possible to
/// safely calculate the bounds of a new loop using the given Predicate.
///
/// Parameters mean the same as for isSafeDecreasingBound, except that \p Step
/// is the (positive) step.
///
/// Only slt/sgt/ult/ugt predicates are handled, and \p BoundSCEV must be
/// available at loop entry. Let BoundPred be slt (signed) or ult (unsigned),
/// matching the signedness of \p Pred. The following must be provable on loop
/// entry, after applying loop guards to Start and Bound:
///  - LatchBrExitIdx == 1: Start BoundPred Bound;
///  - LatchBrExitIdx == 0: Start BoundPred (Bound + Step) and
///    Bound BoundPred (Max - (Step - 1)), where Max is the maximum value of
///    the bound's integer type for that signedness.
static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
                                  const SCEV *Step, ICmpInst::Predicate Pred,
                                  unsigned LatchBrExitIdx, Loop *L,
                                  ScalarEvolution &SE) {
  if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
      Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
    return false;

  if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
    return false;

  LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n");
  LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
  LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
  LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
  LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
  LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");

  bool IsSigned = ICmpInst::isSigned(Pred);
  // The predicate that we need to check that the induction variable lies
  // within bounds.
  ICmpInst::Predicate BoundPred =
      IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;

  // Strengthen both operands with facts implied by the loop's guards before
  // asking whether the conditions hold on entry.
  auto StartLG = SE.applyLoopGuards(Start, L);
  auto BoundLG = SE.applyLoopGuards(BoundSCEV, L);

  if (LatchBrExitIdx == 1)
    return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG);

  assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");

  // Limit = Max - (Step - 1).
  const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
  unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
  APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth)
                       : APInt::getMaxValue(BitWidth);
  const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);

  return (SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG,
                                      SE.getAddExpr(BoundLG, Step)) &&
          SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit));
}

/// Returns estimate for max latch taken count of the loop of the narrowest
/// available type. If the latch block has such estimate, it is returned.
/// Otherwise, we use max exit count of whole loop (that is potentially of wider
/// type than latch check itself), which is still better than no estimate.
static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
                                                          const Loop &L) {
  const SCEV *FromBlock =
      SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
  if (isa<SCEVCouldNotCompute>(FromBlock))
    return SE.getSymbolicMaxBackedgeTakenCount(&L);
  return FromBlock;
}

/// Try to describe \p L as a LoopStructure that LoopConstrainer can rewrite.
///
/// Each failed requirement sets \p FailureReason to a string literal
/// describing it and returns std::nullopt. The requirements are:
///  - L is in LoopSimplify form, has a preheader, and its latch does not carry
///    ClonedLoopTag, i.e. L was not produced by LoopConstrainer::cloneLoop.
///  - The latch exits the loop through a conditional branch on an icmp of
///    integers.
///  - A symbolic max latch taken count is computable.
///  - One icmp operand is an affine add recurrence of L with a constant step.
///    For eq/ne predicates it must also be provably free of signed wrap.
///  - After eq/ne is rewritten into a relational predicate (only possible for
///    a step of +1/-1), the predicate must agree with the step direction and
///    the exit edge. It may be unsigned only if \p AllowUnsignedLatchCond.
///  - The bound passes isSafeIncreasingBound / isSafeDecreasingBound.
///
/// On success \p FailureReason is set to nullptr. Success may insert
/// instructions before the preheader terminator: the expanded initial value
/// "indvar.start" and, when needed, an adjusted or re-materialized exit bound.
std::optional<LoopStructure>
LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
                                  bool AllowUnsignedLatchCond,
                                  const char *&FailureReason) {
  if (!L.isLoopSimplifyForm()) {
    FailureReason = "loop not in LoopSimplify form";
    return std::nullopt;
  }

  BasicBlock *Latch = L.getLoopLatch();
  assert(Latch && "Simplified loops only have one latch!");

  if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
    FailureReason = "loop has already been cloned";
    return std::nullopt;
  }

  // NOTE: the latch exists at this point; this rejects loops whose latch is
  // not an exiting block (the failure message is historical).
  if (!L.isLoopExiting(Latch)) {
    FailureReason = "no loop latch";
    return std::nullopt;
  }

  BasicBlock *Header = L.getHeader();
  BasicBlock *Preheader = L.getLoopPreheader();
  if (!Preheader) {
    FailureReason = "no preheader";
    return std::nullopt;
  }

  BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
  if (!LatchBr || LatchBr->isUnconditional()) {
    FailureReason = "latch terminator not conditional branch";
    return std::nullopt;
  }

  // Index of the latch successor that leaves the loop: 1 when the "true"
  // successor is the header (the loop exits when the condition is false),
  // 0 otherwise.
  unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;

  ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
  if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
    FailureReason = "latch terminator branch not conditional on integral icmp";
    return std::nullopt;
  }

  const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
  if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
    FailureReason = "could not compute latch count";
    return std::nullopt;
  }
  assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
             ScalarEvolution::LoopInvariant &&
         "loop variant exit count doesn't make sense!");

  ICmpInst::Predicate Pred = ICI->getPredicate();
  Value *LeftValue = ICI->getOperand(0);
  const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
  IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());

  Value *RightValue = ICI->getOperand(1);
  const SCEV *RightSCEV = SE.getSCEV(RightValue);

  // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
  if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
    if (isa<SCEVAddRecExpr>(RightSCEV)) {
      std::swap(LeftSCEV, RightSCEV);
      std::swap(LeftValue, RightValue);
      Pred = ICmpInst::getSwappedPredicate(Pred);
    } else {
      FailureReason = "no add recurrences in the icmp";
      return std::nullopt;
    }
  }

  // Returns true if AR is known not to wrap in the signed sense. That holds
  // either because it already carries nsw, or because sign-extending it to
  // twice its width yields the add recurrence of the sign-extended start and
  // step.
  auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
    if (AR->getNoWrapFlags(SCEV::FlagNSW))
      return true;

    IntegerType *Ty = cast<IntegerType>(AR->getType());
    IntegerType *WideTy =
        IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);

    const SCEVAddRecExpr *ExtendAfterOp =
        dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
    if (ExtendAfterOp) {
      const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
      const SCEV *ExtendedStep =
          SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);

      bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
                          ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;

      if (NoSignedWrap)
        return true;
    }

    // We may have proved this when computing the sign extension above.
    return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
  };

  // `ICI` is interpreted as taking the backedge if the *next* value of the
  // induction variable satisfies some constraint.

  const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
  if (IndVarBase->getLoop() != &L) {
    FailureReason = "LHS in cmp is not an AddRec for this loop";
    return std::nullopt;
  }
  if (!IndVarBase->isAffine()) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
  if (!isa<SCEVConstant>(StepRec)) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();

  if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
    FailureReason = "LHS in icmp needs nsw for equality predicates";
    return std::nullopt;
  }

  assert(!StepCI->isZero() && "Zero step?");
  bool IsIncreasing = !StepCI->isNegative();
  bool IsSignedPredicate;
  // The latch compares the already-stepped value, so the induction variable's
  // value before the first step is Start - Step.
  const SCEV *StartNext = IndVarBase->getStart();
  const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
  const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
  const SCEV *Step = SE.getSCEV(StepCI);

  // When non-null, RightValue is replaced by the expansion of this SCEV in the
  // preheader before the LoopStructure is built.
  const SCEV *FixedRightSCEV = nullptr;

  // If RightValue resides within the loop (while still being loop invariant),
  // regenerate it in the preheader.
  if (auto *I = dyn_cast<Instruction>(RightValue))
    if (L.contains(I->getParent()))
      FixedRightSCEV = RightSCEV;

  if (IsIncreasing) {
    bool DecreasedRightValueByOne = false;
    if (StepCI->isOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (++i != len) {         while (++i < len) {
        //   ...                 --->     ...
        // }                            }
        // If both parts are known non-negative, it is profitable to use
        // unsigned comparison in increasing loop. This allows us to make the
        // comparison check against "RightSCEV + 1" more optimistic.
        if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
            isKnownNonNegativeInLoop(RightSCEV, &L, SE))
          Pred = ICmpInst::ICMP_ULT;
        else
          Pred = ICmpInst::ICMP_SLT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (++i == len)     --->     if (++i > len - 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
          Pred = ICmpInst::ICMP_UGT;
          RightSCEV =
              SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
          Pred = ICmpInst::ICMP_SGT;
          RightSCEV =
              SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        }
      }
    }

    // For an increasing IV the backedge must be taken while IV < bound, i.e.
    // "lt" when exiting on the false edge or "gt" when exiting on the true
    // edge.
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
    bool FoundExpectedPred =
        (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp slt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate = ICmpInst::isSigned(Pred);
    if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe loop bounds";
      return std::nullopt;
    }
    if (LatchBrExitIdx == 0) {
      // We need to increase the right value unless we have already decreased
      // it virtually when we replaced EQ with SGT.
      if (!DecreasedRightValueByOne)
        FixedRightSCEV =
            SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!DecreasedRightValueByOne &&
             "Right value can be decreased only for LatchBrExitIdx == 0!");
    }
  } else {
    bool IncreasedRightValueByOne = false;
    if (StepCI->isMinusOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (--i != len) {         while (--i > len) {
        //   ...                 --->     ...
        // }                            }
        // We intentionally don't turn the predicate into UGT even if we know
        // that both operands are non-negative, because it will only pessimize
        // our check against "RightSCEV - 1".
        Pred = ICmpInst::ICMP_SGT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (--i == len)     --->     if (--i < len + 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
          Pred = ICmpInst::ICMP_ULT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
          Pred = ICmpInst::ICMP_SLT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        }
      }
    }

    // Mirror image of the increasing case: the backedge must be taken while
    // IV > bound.
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);

    bool FoundExpectedPred =
        (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp sgt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate =
        Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;

    if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe bounds";
      return std::nullopt;
    }

    if (LatchBrExitIdx == 0) {
      // We need to decrease the right value unless we have already increased
      // it virtually when we replaced EQ with SLT.
      if (!IncreasedRightValueByOne)
        FixedRightSCEV =
            SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!IncreasedRightValueByOne &&
             "Right value can be increased only for LatchBrExitIdx == 0!");
    }
  }
  BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);

  assert(!L.contains(LatchExit) && "expected an exit block!");
  const DataLayout &DL = Preheader->getDataLayout();
  SCEVExpander Expander(SE, DL, "loop-constrainer");
  Instruction *Ins = Preheader->getTerminator();

  // Materialize the (possibly adjusted) bound and the start value before the
  // preheader terminator.
  if (FixedRightSCEV)
    RightValue =
        Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);

  Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
  IndVarStartV->setName("indvar.start");

  LoopStructure Result;

  Result.Tag = "main";
  Result.Header = Header;
  Result.Latch = Latch;
  Result.LatchBr = LatchBr;
  Result.LatchExit = LatchExit;
  Result.LatchBrExitIdx = LatchBrExitIdx;
  Result.IndVarStart = IndVarStartV;
  Result.IndVarStep = StepCI;
  Result.IndVarBase = LeftValue;
  Result.IndVarIncreasing = IsIncreasing;
  Result.LoopExitAt = RightValue;
  Result.IsSignedPredicate = IsSignedPredicate;
  Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());

  FailureReason = nullptr;

  return Result;
}

/// Replace the loop ID metadata of \p L with one that disables unrolling,
/// vectorization, LICM versioning and loop distribution. Callers need to
/// confirm that optimizing loop \p L is not beneficial.
static void DisableAllLoopOptsOnLoop(Loop &L) {
  // We do not care about any existing loopID related metadata for L, since we
  // are setting all loop metadata to false.
  LLVMContext &Context = L.getHeader()->getContext();
  // Reserve first location for self reference to the LoopID metadata node.
  MDNode *Dummy = MDNode::get(Context, {});
  MDNode *DisableUnroll = MDNode::get(
      Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
  Metadata *FalseVal =
      ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
  MDNode *DisableVectorize = MDNode::get(
      Context,
      {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
  MDNode *DisableLICMVersioning = MDNode::get(
      Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
  MDNode *DisableDistribution = MDNode::get(
      Context,
      {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
  MDNode *NewLoopID =
      MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
                            DisableLICMVersioning, DisableDistribution});
  // Set operand 0 to refer to the loop id itself.
  NewLoopID->replaceOperandWith(0, NewLoopID);
  L.setLoopID(NewLoopID);
}

/// Records the loop to split, the analyses to keep up to date and the
/// description of the main loop. \p T is the integer type in which the
/// iteration-space limits are expressed and compared. \p SR holds the
/// optional low/high limits of the subrange the main loop should run in.
/// No IR is modified until run().
LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI,
                                 function_ref<void(Loop *, bool)> LPMAddNewLoop,
                                 const LoopStructure &LS, ScalarEvolution &SE,
                                 DominatorTree &DT, Type *T, SubRanges SR)
    : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE),
      DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T),
      MainLoopStructure(LS), SR(SR) {}

/// Clone every block of OriginalLoop into F, suffixing names with "." + Tag.
///
/// The mapping is recorded in Result.Map and the cloned instructions are
/// remapped to refer to the clones. The cloned latch terminator is tagged
/// with ClonedLoopTag. Result.Structure becomes MainLoopStructure mapped onto
/// the clone. PHIs in the original loop's exit blocks gain an incoming value
/// for each cloned exiting block.
///
/// Nothing branches into the clone yet; the caller wires it into the CFG.
void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
                                const char *Tag) const {
  for (BasicBlock *BB : OriginalLoop.getBlocks()) {
    BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
    Result.Blocks.push_back(Clone);
    Result.Map[BB] = Clone;
  }

  // Maps a value to its clone, or returns it unchanged if it was not cloned
  // (e.g. values defined outside the loop).
  auto GetClonedValue = [&Result](Value *V) {
    assert(V && "null values not in domain!");
    auto It = Result.Map.find(V);
    if (It == Result.Map.end())
      return V;
    return static_cast<Value *>(It->second);
  };

  auto *ClonedLatch =
      cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
  ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
                                            MDNode::get(Ctx, {}));

  Result.Structure = MainLoopStructure.map(GetClonedValue);
  Result.Structure.Tag = Tag;

  for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
    BasicBlock *ClonedBB = Result.Blocks[i];
    BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];

    assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");

    for (Instruction &I : *ClonedBB)
      RemapInstruction(&I, Result.Map,
                       RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);

    // Exit blocks will now have one more predecessor and their PHI nodes need
    // to be edited to reflect that.  No phi nodes need to be introduced because
    // the loop is in LCSSA.

    for (auto *SBB : successors(OriginalBB)) {
      if (OriginalLoop.contains(SBB))
        continue; // not an exit block

      for (PHINode &PN : SBB->phis()) {
        Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
        PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
        SE.forgetValue(&PN);
      }
    }
  }
}

/// Rewrite the loop described by \p LS so that it stops at \p ExitSubloopAt.
///
/// Once the induction variable no longer satisfies the loop's predicate
/// against \p ExitSubloopAt, the loop leaves through a new
/// "<Tag>.pseudo.exit" block, which branches to \p ContinuationBlock. The
/// latch exit is routed through a new "<Tag>.exit.selector" block. That block
/// still goes to the original latch exit when the original bound
/// LS.LoopExitAt admits no more iterations.
///
/// \p Preheader must end in a BranchInst. It is replaced by a conditional
/// branch that skips the loop and goes straight to the pseudo exit when the
/// start value already fails the check. Comparisons are done in RangeTy:
/// narrower IV-related operands are sign- or zero-extended according to
/// LS.IsSignedPredicate.
///
/// Returns the two new blocks and a PHI in the pseudo exit for each header
/// PHI of LS, in header order, holding its latest value. Also returns
/// "indvar.end", the induction variable's value on reaching the pseudo exit.
LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
    BasicBlock *ContinuationBlock) const {
  // We start with a loop with a single latch:
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      |
  //    |                    |
  //    +--------+-----------+
  //             |      ----------------+
  //             |     /                |
  //    +--------v----v------+          |
  //    |                    |          |
  //    |      header        |          |
  //    |                    |          |
  //    +--------------------+          |
  //                                    |
  //            .....                   |
  //                                    |
  //    +--------------------+          |
  //    |                    |          |
  //    |       latch        >----------/
  //    |                    |
  //    +-------v------------+
  //            |
  //            |
  //            |   +--------------------+
  //            |   |                    |
  //            +--->   original exit    |
  //                |                    |
  //                +--------------------+
  //
  // We change the control flow to look like
  //
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      >-------------------------+
  //    |                    |                         |
  //    +--------v-----------+                         |
  //             |    /-------------+                  |
  //             |   /              |                  |
  //    +--------v--v--------+      |                  |
  //    |                    |      |                  |
  //    |      header        |      |   +--------+     |
  //    |                    |      |   |        |     |
  //    +--------------------+      |   |  +-----v-----v-----------+
  //                                |   |  |                       |
  //                                |   |  |     .pseudo.exit      |
  //                                |   |  |                       |
  //                                |   |  +-----------v-----------+
  //                                |   |              |
  //            .....               |   |              |
  //                                |   |     +--------v-------------+
  //    +--------------------+      |   |     |                      |
  //    |                    |      |   |     |   ContinuationBlock  |
  //    |       latch        >------+   |     |                      |
  //    |                    |          |     +----------------------+
  //    +---------v----------+          |
  //              |                     |
  //              |                     |
  //              |     +---------------^-----+
  //              |     |                     |
  //              +----->    .exit.selector   |
  //                    |                     |
  //                    +----------v----------+
  //                               |
  //     +--------------------+    |
  //     |                    |    |
  //     |   original exit    <----+
  //     |                    |
  //     +--------------------+

  RewrittenRangeInfo RRI;

  BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
  RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
                                        &F, BBInsertLocation);
  RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
                                      BBInsertLocation);

  BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
  bool Increasing = LS.IndVarIncreasing;
  bool IsSignedPredicate = LS.IsSignedPredicate;

  IRBuilder<> B(PreheaderJump);
  // Widens V to RangeTy (sext for signed predicates, zext otherwise), or
  // returns it unchanged if it already has that type. New instructions go at
  // the builder's current insertion point.
  auto NoopOrExt = [&](Value *V) {
    if (V->getType() == RangeTy)
      return V;
    return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
                             : B.CreateZExt(V, RangeTy, "wide." + V->getName());
  };

  // EnterLoopCond - is it okay to start executing this `LS'?
  Value *EnterLoopCond = nullptr;
  // "IV may keep iterating" predicate: IV < limit for an increasing IV,
  // IV > limit for a decreasing one.
  auto Pred =
      Increasing
          ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
          : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
  Value *IndVarStart = NoopOrExt(LS.IndVarStart);
  EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);

  B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
  PreheaderJump->eraseFromParent();

  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
  B.SetInsertPoint(LS.LatchBr);
  Value *IndVarBase = NoopOrExt(LS.IndVarBase);
  Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);

  // Keep the latch's successor layout: when the exit is successor 0 the
  // branch condition must be true to leave, so invert the "keep going" test.
  Value *CondForBranch = LS.LatchBrExitIdx == 1
                             ? TakeBackedgeLoopCond
                             : B.CreateNot(TakeBackedgeLoopCond);

  LS.LatchBr->setCondition(CondForBranch);

  B.SetInsertPoint(RRI.ExitSelector);

  // IterationsLeft - are there any more iterations left, given the original
  // upper bound on the induction variable?  If not, we branch to the "real"
  // exit.
  Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
  Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
  B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);

  BranchInst *BranchToContinuation =
      BranchInst::Create(ContinuationBlock, RRI.PseudoExit);

  // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
  // each of the PHI nodes in the loop header.  This feeds into the initial
  // value of the same PHI nodes if/when we continue execution.
  // rewriteIncomingValuesForPHIs relies on these being in header-PHI order.
  for (PHINode &PN : LS.Header->phis()) {
    PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
                                      BranchToContinuation->getIterator());

    NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
    NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
                        RRI.ExitSelector);
    RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
  }

  RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
                                  BranchToContinuation->getIterator());
  RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
  RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);

  // The latch exit now has a branch from `RRI.ExitSelector' instead of
  // `LS.Latch'.  The PHI nodes need to be updated to reflect that.
  LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);

  return RRI;
}

/// Make the loop \p LS continue from where a preceding loop stopped.
///
/// For each header PHI of \p LS, its incoming value from
/// \p ContinuationBlock is set to the corresponding entry of
/// RRI.PHIValuesAtPseudoExit. LS then starts at RRI.IndVarEnd.
///
/// Entries are matched by position. This assumes LS's header PHIs are in the
/// same order as those of the loop RRI was computed for. In run() the two
/// loops are the original loop and a clone of it. TODO(review): confirm
/// cloning always preserves PHI order.
void LoopConstrainer::rewriteIncomingValuesForPHIs(
    LoopStructure &LS, BasicBlock *ContinuationBlock,
    const LoopConstrainer::RewrittenRangeInfo &RRI) const {
  unsigned PHIIndex = 0;
  for (PHINode &PN : LS.Header->phis())
    PN.setIncomingValueForBlock(ContinuationBlock,
                                RRI.PHIValuesAtPseudoExit[PHIIndex++]);

  LS.IndVarStart = RRI.IndVarEnd;
}

/// Create an empty block named \p Tag, placed before LS.Header in F, that
/// branches unconditionally to LS.Header. LS.Header's PHIs are redirected to
/// take their \p OldPreheader incoming values from the new block instead.
/// No existing branch is retargeted; the caller must add the edge into the
/// returned block.
BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
                                             BasicBlock *OldPreheader,
                                             const char *Tag) const {
  BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
  BranchInst::Create(LS.Header, Preheader);

  LS.Header->replacePhiUsesWith(OldPreheader, Preheader);

  return Preheader;
}

/// If OriginalLoop is nested, register \p BBs as blocks of its parent loop in
/// LI; otherwise do nothing.
void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
  Loop *ParentLoop = OriginalLoop.getParentLoop();
  if (!ParentLoop)
    return;

  for (BasicBlock *BB : BBs)
    ParentLoop->addBasicBlockToLoop(BB, LI);
}

/// Build the LoopInfo structure for a clone of \p Original whose blocks are
/// given by \p VM, and return the new loop.
///
/// The new Loop is attached under \p Parent, or as a top-level loop when
/// \p Parent is null, and reported through LPMAddNewLoop. The clones of
/// blocks that belong directly to \p Original are added to it. Subloops are
/// handled recursively (reported with IsSubloop = true).
Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
                                                 ValueToValueMapTy &VM,
                                                 bool IsSubloop) {
  Loop &New = *LI.AllocateLoop();
  if (Parent)
    Parent->addChildLoop(&New);
  else
    LI.addTopLevelLoop(&New);
  LPMAddNewLoop(&New, IsSubloop);

  // Add all of the blocks in Original to the new loop.
  for (auto *BB : Original->blocks())
    if (LI.getLoopFor(BB) == Original)
      New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);

  // Add all of the subloops to the new loop.
  for (Loop *SubLoop : *Original)
    createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);

  return &New;
}

/// Split OriginalLoop into consecutive loops so that the main loop only runs
/// with the induction variable inside the subrange described by SR.
///
/// The loops are, in order: an optional pre-loop (a clone), the main loop
/// (OriginalLoop itself) and an optional post-loop (a clone). Which of
/// SR.LowLimit / SR.HighLimit produces the pre- or post-loop depends on
/// whether the induction variable increases.
///
/// Returns false before anything is cloned if an exit limit cannot be proven
/// free of overflow or cannot be safely expanded in the preheader.
/// Instructions for an already expanded limit may remain in that case.
///
/// Returns true after the rewrite. By then DT is recomputed, LI is updated,
/// every resulting loop is in LCSSA and LoopSimplify form, and further loop
/// optimizations are disabled on the pre/post loops.
bool LoopConstrainer::run() {
  BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
  assert(Preheader != nullptr && "precondition!");

  OriginalPreheader = Preheader;
  MainLoopPreheader = Preheader;
  bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
  bool Increasing = MainLoopStructure.IndVarIncreasing;
  IntegerType *IVTy = cast<IntegerType>(RangeTy);

  SCEVExpander Expander(SE, F.getDataLayout(), "loop-constrainer");
  Instruction *InsertPt = OriginalPreheader->getTerminator();

  // It would have been better to make `PreLoop' and `PostLoop'
  // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
  // constructor.
  ClonedLoop PreLoop, PostLoop;
  bool NeedsPreLoop =
      Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
  bool NeedsPostLoop =
      Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();

  Value *ExitPreLoopAt = nullptr;
  Value *ExitMainLoopAt = nullptr;
  const SCEVConstant *MinusOneS =
      cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));

  // A decreasing loop keeps iterating while IV > exit value (see
  // changeIterationSpaceEnd), so its exit value is Limit - 1, which requires
  // proving that Limit - 1 does not wrap.
  if (NeedsPreLoop) {
    const SCEV *ExitPreLoopAtSCEV = nullptr;

    if (Increasing)
      ExitPreLoopAtSCEV = *SR.LowLimit;
    else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
                        << "preloop exit limit. HighLimit = "
                        << *(*SR.HighLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
                        << " preloop exit limit " << *ExitPreLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
    ExitPreLoopAt->setName("exit.preloop.at");
  }

  if (NeedsPostLoop) {
    const SCEV *ExitMainLoopAtSCEV = nullptr;

    if (Increasing)
      ExitMainLoopAtSCEV = *SR.HighLimit;
    else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
                        << "mainloop exit limit. LowLimit = "
                        << *(*SR.LowLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
                        << " main loop exit limit " << *ExitMainLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
    ExitMainLoopAt->setName("exit.mainloop.at");
  }

  // We clone these ahead of time so that we don't have to deal with changing
  // and temporarily invalid IR as we transform the loops.
  if (NeedsPreLoop)
    cloneLoop(PreLoop, "preloop");
  if (NeedsPostLoop)
    cloneLoop(PostLoop, "postloop");

  RewrittenRangeInfo PreLoopRRI;

  // Original preheader -> pre-loop -> (pseudo exit) -> "mainloop" preheader
  // -> main loop.
  if (NeedsPreLoop) {
    Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
                                                  PreLoop.Structure.Header);

    MainLoopPreheader =
        createPreheader(MainLoopStructure, Preheader, "mainloop");
    PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
                                         ExitPreLoopAt, MainLoopPreheader);
    rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
                                 PreLoopRRI);
  }

  BasicBlock *PostLoopPreheader = nullptr;
  RewrittenRangeInfo PostLoopRRI;

  // Main loop -> (pseudo exit) -> "postloop" preheader -> post-loop.
  if (NeedsPostLoop) {
    PostLoopPreheader =
        createPreheader(PostLoop.Structure, Preheader, "postloop");
    PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
                                          ExitMainLoopAt, PostLoopPreheader);
    rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
                                 PostLoopRRI);
  }

  BasicBlock *NewMainLoopPreheader =
      MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
  BasicBlock *NewBlocks[] = {PostLoopPreheader,        PreLoopRRI.PseudoExit,
                             PreLoopRRI.ExitSelector,  PostLoopRRI.PseudoExit,
                             PostLoopRRI.ExitSelector, NewMainLoopPreheader};

  // Some of the above may be nullptr, filter them out before passing to
  // addToParentLoopIfNeeded.
  auto NewBlocksEnd =
      std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);

  addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));

  DT.recalculate(F);

  // We need to first add all the pre and post loop blocks into the loop
  // structures (as part of createClonedLoopStructure), and then update the
  // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
  // LI when LoopSimplifyForm is generated.
  Loop *PreL = nullptr, *PostL = nullptr;
  if (!PreLoop.Blocks.empty()) {
    PreL = createClonedLoopStructure(&OriginalLoop,
                                     OriginalLoop.getParentLoop(), PreLoop.Map,
                                     /* IsSubLoop */ false);
  }

  if (!PostLoop.Blocks.empty()) {
    PostL =
        createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
                                  PostLoop.Map, /* IsSubLoop */ false);
  }

  // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
  auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
    formLCSSARecursively(*L, DT, &LI, &SE);
    simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
    // Pre/post loops are slow paths, we do not need to perform any loop
    // optimizations on them.
    if (!IsOriginalLoop)
      DisableAllLoopOptsOnLoop(*L);
  };
  if (PreL)
    CanonicalizeLoop(PreL, false);
  if (PostL)
    CanonicalizeLoop(PostL, false);
  CanonicalizeLoop(&OriginalLoop, true);

  // At this point:
  // - We've broken a "main loop" out of the loop in a way that the "main loop"
  //   runs with the induction variable in a subset of [Begin, End).
  // - There is no overflow when computing "main loop" exit limit.
  // - Max latch taken count of the loop is limited.
  // It guarantees that induction variable will not overflow iterating in the
  // "main loop".
  if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
    if (IsSignedPredicate)
      cast<BinaryOperator>(MainLoopStructure.IndVarBase)
          ->setHasNoSignedWrap(true);
  // TODO: support unsigned predicate.
  // To add NUW flag we need to prove that both operands of BO are
  // non-negative. E.g:
  //   ...
  //   %iv.next = add nsw i32 %iv, -1
  //   %cmp = icmp ult i32 %iv.next, %n
  //   br i1 %cmp, label %loopexit, label %loop
  //
  // -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
  // overflow, therefore NUW flag is not legal here.

  return true;
}