Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
35266 views
//===- LoopFlatten.cpp - Loop flattening pass------------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This pass flattens pairs nested loops into a single loop.9//10// The intention is to optimise loop nests like this, which together access an11// array linearly:12//13// for (int i = 0; i < N; ++i)14// for (int j = 0; j < M; ++j)15// f(A[i*M+j]);16//17// into one loop:18//19// for (int i = 0; i < (N*M); ++i)20// f(A[i]);21//22// It can also flatten loops where the induction variables are not used in the23// loop. This is only worth doing if the induction variables are only used in an24// expression like i*M+j. If they had any other uses, we would have to insert a25// div/mod to reconstruct the original values, so this wouldn't be profitable.26//27// We also need to prove that N*M will not overflow. The preferred solution is28// to widen the IV, which avoids overflow checks, so that is tried first. If29// the IV cannot be widened, then we try to determine that this new tripcount30// expression won't overflow.31//32// Q: Does LoopFlatten use SCEV?33// Short answer: Yes and no.34//35// Long answer:36// For this transformation to be valid, we require all uses of the induction37// variables to be linear expressions of the form i*M+j. The different Loop38// APIs are used to get some loop components like the induction variable,39// compare statement, etc. In addition, we do some pattern matching to find the40// linear expressions and other loop components like the loop increment. The41// latter are examples of expressions that do use the induction variable, but42// are safe to ignore when we check all uses to be of the form i*M+j. We keep43// track of all of this in bookkeeping struct FlattenInfo.44// We assume the loops to be canonical, i.e. starting at 0 and increment with45// 1. This makes RHS of the compare the loop tripcount (with the right46// predicate). We use SCEV to then sanity check that this tripcount matches47// with the tripcount as computed by SCEV.48//49//===----------------------------------------------------------------------===//5051#include "llvm/Transforms/Scalar/LoopFlatten.h"5253#include "llvm/ADT/Statistic.h"54#include "llvm/Analysis/AssumptionCache.h"55#include "llvm/Analysis/LoopInfo.h"56#include "llvm/Analysis/LoopNestAnalysis.h"57#include "llvm/Analysis/MemorySSAUpdater.h"58#include "llvm/Analysis/OptimizationRemarkEmitter.h"59#include "llvm/Analysis/ScalarEvolution.h"60#include "llvm/Analysis/TargetTransformInfo.h"61#include "llvm/Analysis/ValueTracking.h"62#include "llvm/IR/Dominators.h"63#include "llvm/IR/Function.h"64#include "llvm/IR/IRBuilder.h"65#include "llvm/IR/Module.h"66#include "llvm/IR/PatternMatch.h"67#include "llvm/Support/Debug.h"68#include "llvm/Support/raw_ostream.h"69#include "llvm/Transforms/Scalar/LoopPassManager.h"70#include "llvm/Transforms/Utils/Local.h"71#include "llvm/Transforms/Utils/LoopUtils.h"72#include "llvm/Transforms/Utils/LoopVersioning.h"73#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"74#include "llvm/Transforms/Utils/SimplifyIndVar.h"75#include <optional>7677using namespace llvm;78using namespace llvm::PatternMatch;7980#define DEBUG_TYPE "loop-flatten"8182STATISTIC(NumFlattened, "Number of loops flattened");8384static cl::opt<unsigned> RepeatedInstructionThreshold(85"loop-flatten-cost-threshold", cl::Hidden, cl::init(2),86cl::desc("Limit on the cost of instructions that can be repeated due to "87"loop flattening"));8889static cl::opt<bool>90AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden,91cl::init(false),92cl::desc("Assume that the product of the two iteration "93"trip counts will never overflow"));9495static cl::opt<bool>96WidenIV("loop-flatten-widen-iv", cl::Hidden, cl::init(true),97cl::desc("Widen the loop induction variables, if possible, so "98"overflow checks won't reject flattening"));99100static cl::opt<bool>101VersionLoops("loop-flatten-version-loops", cl::Hidden, cl::init(true),102cl::desc("Version loops if flattened loop could overflow"));103104namespace {105// We require all uses of both induction variables to match this pattern:106//107// (OuterPHI * InnerTripCount) + InnerPHI108//109// I.e., it needs to be a linear expression of the induction variables and the110// inner loop trip count. We keep track of all different expressions on which111// checks will be performed in this bookkeeping struct.112//113struct FlattenInfo {114Loop *OuterLoop = nullptr; // The loop pair to be flattened.115Loop *InnerLoop = nullptr;116117PHINode *InnerInductionPHI = nullptr; // These PHINodes correspond to loop118PHINode *OuterInductionPHI = nullptr; // induction variables, which are119// expected to start at zero and120// increment by one on each loop.121122Value *InnerTripCount = nullptr; // The product of these two tripcounts123Value *OuterTripCount = nullptr; // will be the new flattened loop124// tripcount. Also used to recognise a125// linear expression that will be replaced.126127SmallPtrSet<Value *, 4> LinearIVUses; // Contains the linear expressions128// of the form i*M+j that will be129// replaced.130131BinaryOperator *InnerIncrement = nullptr; // Uses of induction variables in132BinaryOperator *OuterIncrement = nullptr; // loop control statements that133BranchInst *InnerBranch = nullptr; // are safe to ignore.134135BranchInst *OuterBranch = nullptr; // The instruction that needs to be136// updated with new tripcount.137138SmallPtrSet<PHINode *, 4> InnerPHIsToTransform;139140bool Widened = false; // Whether this holds the flatten info before or after141// widening.142143PHINode *NarrowInnerInductionPHI = nullptr; // Holds the old/narrow induction144PHINode *NarrowOuterInductionPHI = nullptr; // phis, i.e. the Phis before IV145// has been applied. Used to skip146// checks on phi nodes.147148Value *NewTripCount = nullptr; // The tripcount of the flattened loop.149150FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){};151152bool isNarrowInductionPhi(PHINode *Phi) {153// This can't be the narrow phi if we haven't widened the IV first.154if (!Widened)155return false;156return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi;157}158bool isInnerLoopIncrement(User *U) {159return InnerIncrement == U;160}161bool isOuterLoopIncrement(User *U) {162return OuterIncrement == U;163}164bool isInnerLoopTest(User *U) {165return InnerBranch->getCondition() == U;166}167168bool checkOuterInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {169for (User *U : OuterInductionPHI->users()) {170if (isOuterLoopIncrement(U))171continue;172173auto IsValidOuterPHIUses = [&] (User *U) -> bool {174LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());175if (!ValidOuterPHIUses.count(U)) {176LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");177return false;178}179LLVM_DEBUG(dbgs() << "Use is optimisable\n");180return true;181};182183if (auto *V = dyn_cast<TruncInst>(U)) {184for (auto *K : V->users()) {185if (!IsValidOuterPHIUses(K))186return false;187}188continue;189}190191if (!IsValidOuterPHIUses(U))192return false;193}194return true;195}196197bool matchLinearIVUser(User *U, Value *InnerTripCount,198SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {199LLVM_DEBUG(dbgs() << "Checking linear i*M+j expression for: "; U->dump());200Value *MatchedMul = nullptr;201Value *MatchedItCount = nullptr;202203bool IsAdd = match(U, m_c_Add(m_Specific(InnerInductionPHI),204m_Value(MatchedMul))) &&205match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),206m_Value(MatchedItCount)));207208// Matches the same pattern as above, except it also looks for truncs209// on the phi, which can be the result of widening the induction variables.210bool IsAddTrunc =211match(U, m_c_Add(m_Trunc(m_Specific(InnerInductionPHI)),212m_Value(MatchedMul))) &&213match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)),214m_Value(MatchedItCount)));215216// Matches the pattern ptr+i*M+j, with the two additions being done via GEP.217bool IsGEP = match(U, m_GEP(m_GEP(m_Value(), m_Value(MatchedMul)),218m_Specific(InnerInductionPHI))) &&219match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),220m_Value(MatchedItCount)));221222if (!MatchedItCount)223return false;224225LLVM_DEBUG(dbgs() << "Matched multiplication: "; MatchedMul->dump());226LLVM_DEBUG(dbgs() << "Matched iteration count: "; MatchedItCount->dump());227228// The mul should not have any other uses. Widening may leave trivially dead229// uses, which can be ignored.230if (count_if(MatchedMul->users(), [](User *U) {231return !isInstructionTriviallyDead(cast<Instruction>(U));232}) > 1) {233LLVM_DEBUG(dbgs() << "Multiply has more than one use\n");234return false;235}236237// Look through extends if the IV has been widened. Don't look through238// extends if we already looked through a trunc.239if (Widened && (IsAdd || IsGEP) &&240(isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {241assert(MatchedItCount->getType() == InnerInductionPHI->getType() &&242"Unexpected type mismatch in types after widening");243MatchedItCount = isa<SExtInst>(MatchedItCount)244? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0)245: dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0);246}247248LLVM_DEBUG(dbgs() << "Looking for inner trip count: ";249InnerTripCount->dump());250251if ((IsAdd || IsAddTrunc || IsGEP) && MatchedItCount == InnerTripCount) {252LLVM_DEBUG(dbgs() << "Found. This sse is optimisable\n");253ValidOuterPHIUses.insert(MatchedMul);254LinearIVUses.insert(U);255return true;256}257258LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");259return false;260}261262bool checkInnerInductionPhiUsers(SmallPtrSet<Value *, 4> &ValidOuterPHIUses) {263Value *SExtInnerTripCount = InnerTripCount;264if (Widened &&265(isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))266SExtInnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);267268for (User *U : InnerInductionPHI->users()) {269LLVM_DEBUG(dbgs() << "Checking User: "; U->dump());270if (isInnerLoopIncrement(U)) {271LLVM_DEBUG(dbgs() << "Use is inner loop increment, continuing\n");272continue;273}274275// After widening the IVs, a trunc instruction might have been introduced,276// so look through truncs.277if (isa<TruncInst>(U)) {278if (!U->hasOneUse())279return false;280U = *U->user_begin();281}282283// If the use is in the compare (which is also the condition of the inner284// branch) then the compare has been altered by another transformation e.g285// icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is286// a constant. Ignore this use as the compare gets removed later anyway.287if (isInnerLoopTest(U)) {288LLVM_DEBUG(dbgs() << "Use is the inner loop test, continuing\n");289continue;290}291292if (!matchLinearIVUser(U, SExtInnerTripCount, ValidOuterPHIUses)) {293LLVM_DEBUG(dbgs() << "Not a linear IV user\n");294return false;295}296LLVM_DEBUG(dbgs() << "Linear IV users found!\n");297}298return true;299}300};301} // namespace302303static bool304setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment,305SmallPtrSetImpl<Instruction *> &IterationInstructions) {306TripCount = TC;307IterationInstructions.insert(Increment);308LLVM_DEBUG(dbgs() << "Found Increment: "; Increment->dump());309LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());310LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");311return true;312}313314// Given the RHS of the loop latch compare instruction, verify with SCEV315// that this is indeed the loop tripcount.316// TODO: This used to be a straightforward check but has grown to be quite317// complicated now. It is therefore worth revisiting what the additional318// benefits are of this (compared to relying on canonical loops and pattern319// matching).320static bool verifyTripCount(Value *RHS, Loop *L,321SmallPtrSetImpl<Instruction *> &IterationInstructions,322PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment,323BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {324const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);325if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {326LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");327return false;328}329330// Evaluating in the trip count's type can not overflow here as the overflow331// checks are performed in checkOverflow, but are first tried to avoid by332// widening the IV.333const SCEV *SCEVTripCount =334SE->getTripCountFromExitCount(BackedgeTakenCount,335BackedgeTakenCount->getType(), L);336337const SCEV *SCEVRHS = SE->getSCEV(RHS);338if (SCEVRHS == SCEVTripCount)339return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);340ConstantInt *ConstantRHS = dyn_cast<ConstantInt>(RHS);341if (ConstantRHS) {342const SCEV *BackedgeTCExt = nullptr;343if (IsWidened) {344const SCEV *SCEVTripCountExt;345// Find the extended backedge taken count and extended trip count using346// SCEV. One of these should now match the RHS of the compare.347BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());348SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt,349RHS->getType(), L);350if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {351LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");352return false;353}354}355// If the RHS of the compare is equal to the backedge taken count we need356// to add one to get the trip count.357if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {358Value *NewRHS = ConstantInt::get(ConstantRHS->getContext(),359ConstantRHS->getValue() + 1);360return setLoopComponents(NewRHS, TripCount, Increment,361IterationInstructions);362}363return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);364}365// If the RHS isn't a constant then check that the reason it doesn't match366// the SCEV trip count is because the RHS is a ZExt or SExt instruction367// (and take the trip count to be the RHS).368if (!IsWidened) {369LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");370return false;371}372auto *TripCountInst = dyn_cast<Instruction>(RHS);373if (!TripCountInst) {374LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");375return false;376}377if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||378SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {379LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");380return false;381}382return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);383}384385// Finds the induction variable, increment and trip count for a simple loop that386// we can flatten.387static bool findLoopComponents(388Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions,389PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment,390BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {391LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n");392393if (!L->isLoopSimplifyForm()) {394LLVM_DEBUG(dbgs() << "Loop is not in normal form\n");395return false;396}397398// Currently, to simplify the implementation, the Loop induction variable must399// start at zero and increment with a step size of one.400if (!L->isCanonical(*SE)) {401LLVM_DEBUG(dbgs() << "Loop is not canonical\n");402return false;403}404405// There must be exactly one exiting block, and it must be the same at the406// latch.407BasicBlock *Latch = L->getLoopLatch();408if (L->getExitingBlock() != Latch) {409LLVM_DEBUG(dbgs() << "Exiting and latch block are different\n");410return false;411}412413// Find the induction PHI. If there is no induction PHI, we can't do the414// transformation. TODO: could other variables trigger this? Do we have to415// search for the best one?416InductionPHI = L->getInductionVariable(*SE);417if (!InductionPHI) {418LLVM_DEBUG(dbgs() << "Could not find induction PHI\n");419return false;420}421LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump());422423bool ContinueOnTrue = L->contains(Latch->getTerminator()->getSuccessor(0));424auto IsValidPredicate = [&](ICmpInst::Predicate Pred) {425if (ContinueOnTrue)426return Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT;427else428return Pred == CmpInst::ICMP_EQ;429};430431// Find Compare and make sure it is valid. getLatchCmpInst checks that the432// back branch of the latch is conditional.433ICmpInst *Compare = L->getLatchCmpInst();434if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) ||435Compare->hasNUsesOrMore(2)) {436LLVM_DEBUG(dbgs() << "Could not find valid comparison\n");437return false;438}439BackBranch = cast<BranchInst>(Latch->getTerminator());440IterationInstructions.insert(BackBranch);441LLVM_DEBUG(dbgs() << "Found back branch: "; BackBranch->dump());442IterationInstructions.insert(Compare);443LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump());444445// Find increment and trip count.446// There are exactly 2 incoming values to the induction phi; one from the447// pre-header and one from the latch. The incoming latch value is the448// increment variable.449Increment =450cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch));451if ((Compare->getOperand(0) != Increment || !Increment->hasNUses(2)) &&452!Increment->hasNUses(1)) {453LLVM_DEBUG(dbgs() << "Could not find valid increment\n");454return false;455}456// The trip count is the RHS of the compare. If this doesn't match the trip457// count computed by SCEV then this is because the trip count variable458// has been widened so the types don't match, or because it is a constant and459// another transformation has changed the compare (e.g. icmp ult %inc,460// tripcount -> icmp ult %j, tripcount-1), or both.461Value *RHS = Compare->getOperand(1);462463return verifyTripCount(RHS, L, IterationInstructions, InductionPHI, TripCount,464Increment, BackBranch, SE, IsWidened);465}466467static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) {468// All PHIs in the inner and outer headers must either be:469// - The induction PHI, which we are going to rewrite as one induction in470// the new loop. This is already checked by findLoopComponents.471// - An outer header PHI with all incoming values from outside the loop.472// LoopSimplify guarantees we have a pre-header, so we don't need to473// worry about that here.474// - Pairs of PHIs in the inner and outer headers, which implement a475// loop-carried dependency that will still be valid in the new loop. To476// be valid, this variable must be modified only in the inner loop.477478// The set of PHI nodes in the outer loop header that we know will still be479// valid after the transformation. These will not need to be modified (with480// the exception of the induction variable), but we do need to check that481// there are no unsafe PHI nodes.482SmallPtrSet<PHINode *, 4> SafeOuterPHIs;483SafeOuterPHIs.insert(FI.OuterInductionPHI);484485// Check that all PHI nodes in the inner loop header match one of the valid486// patterns.487for (PHINode &InnerPHI : FI.InnerLoop->getHeader()->phis()) {488// The induction PHIs break these rules, and that's OK because we treat489// them specially when doing the transformation.490if (&InnerPHI == FI.InnerInductionPHI)491continue;492if (FI.isNarrowInductionPhi(&InnerPHI))493continue;494495// Each inner loop PHI node must have two incoming values/blocks - one496// from the pre-header, and one from the latch.497assert(InnerPHI.getNumIncomingValues() == 2);498Value *PreHeaderValue =499InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopPreheader());500Value *LatchValue =501InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopLatch());502503// The incoming value from the outer loop must be the PHI node in the504// outer loop header, with no modifications made in the top of the outer505// loop.506PHINode *OuterPHI = dyn_cast<PHINode>(PreHeaderValue);507if (!OuterPHI || OuterPHI->getParent() != FI.OuterLoop->getHeader()) {508LLVM_DEBUG(dbgs() << "value modified in top of outer loop\n");509return false;510}511512// The other incoming value must come from the inner loop, without any513// modifications in the tail end of the outer loop. We are in LCSSA form,514// so this will actually be a PHI in the inner loop's exit block, which515// only uses values from inside the inner loop.516PHINode *LCSSAPHI = dyn_cast<PHINode>(517OuterPHI->getIncomingValueForBlock(FI.OuterLoop->getLoopLatch()));518if (!LCSSAPHI) {519LLVM_DEBUG(dbgs() << "could not find LCSSA PHI\n");520return false;521}522523// The value used by the LCSSA PHI must be the same one that the inner524// loop's PHI uses.525if (LCSSAPHI->hasConstantValue() != LatchValue) {526LLVM_DEBUG(527dbgs() << "LCSSA PHI incoming value does not match latch value\n");528return false;529}530531LLVM_DEBUG(dbgs() << "PHI pair is safe:\n");532LLVM_DEBUG(dbgs() << " Inner: "; InnerPHI.dump());533LLVM_DEBUG(dbgs() << " Outer: "; OuterPHI->dump());534SafeOuterPHIs.insert(OuterPHI);535FI.InnerPHIsToTransform.insert(&InnerPHI);536}537538for (PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) {539if (FI.isNarrowInductionPhi(&OuterPHI))540continue;541if (!SafeOuterPHIs.count(&OuterPHI)) {542LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: "; OuterPHI.dump());543return false;544}545}546547LLVM_DEBUG(dbgs() << "checkPHIs: OK\n");548return true;549}550551static bool552checkOuterLoopInsts(FlattenInfo &FI,553SmallPtrSetImpl<Instruction *> &IterationInstructions,554const TargetTransformInfo *TTI) {555// Check for instructions in the outer but not inner loop. If any of these556// have side-effects then this transformation is not legal, and if there is557// a significant amount of code here which can't be optimised out that it's558// not profitable (as these instructions would get executed for each559// iteration of the inner loop).560InstructionCost RepeatedInstrCost = 0;561for (auto *B : FI.OuterLoop->getBlocks()) {562if (FI.InnerLoop->contains(B))563continue;564565for (auto &I : *B) {566if (!isa<PHINode>(&I) && !I.isTerminator() &&567!isSafeToSpeculativelyExecute(&I)) {568LLVM_DEBUG(dbgs() << "Cannot flatten because instruction may have "569"side effects: ";570I.dump());571return false;572}573// The execution count of the outer loop's iteration instructions574// (increment, compare and branch) will be increased, but the575// equivalent instructions will be removed from the inner loop, so576// they make a net difference of zero.577if (IterationInstructions.count(&I))578continue;579// The unconditional branch to the inner loop's header will turn into580// a fall-through, so adds no cost.581BranchInst *Br = dyn_cast<BranchInst>(&I);582if (Br && Br->isUnconditional() &&583Br->getSuccessor(0) == FI.InnerLoop->getHeader())584continue;585// Multiplies of the outer iteration variable and inner iteration586// count will be optimised out.587if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI),588m_Specific(FI.InnerTripCount))))589continue;590InstructionCost Cost =591TTI->getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);592LLVM_DEBUG(dbgs() << "Cost " << Cost << ": "; I.dump());593RepeatedInstrCost += Cost;594}595}596597LLVM_DEBUG(dbgs() << "Cost of instructions that will be repeated: "598<< RepeatedInstrCost << "\n");599// Bail out if flattening the loops would cause instructions in the outer600// loop but not in the inner loop to be executed extra times.601if (RepeatedInstrCost > RepeatedInstructionThreshold) {602LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: not profitable, bailing.\n");603return false;604}605606LLVM_DEBUG(dbgs() << "checkOuterLoopInsts: OK\n");607return true;608}609610611612// We require all uses of both induction variables to match this pattern:613//614// (OuterPHI * InnerTripCount) + InnerPHI615//616// Any uses of the induction variables not matching that pattern would617// require a div/mod to reconstruct in the flattened loop, so the618// transformation wouldn't be profitable.619static bool checkIVUsers(FlattenInfo &FI) {620// Check that all uses of the inner loop's induction variable match the621// expected pattern, recording the uses of the outer IV.622SmallPtrSet<Value *, 4> ValidOuterPHIUses;623if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses))624return false;625626// Check that there are no uses of the outer IV other than the ones found627// as part of the pattern above.628if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses))629return false;630631LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n";632dbgs() << "Found " << FI.LinearIVUses.size()633<< " value(s) that can be replaced:\n";634for (Value *V : FI.LinearIVUses) {635dbgs() << " ";636V->dump();637});638return true;639}640641// Return an OverflowResult dependant on if overflow of the multiplication of642// InnerTripCount and OuterTripCount can be assumed not to happen.643static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,644AssumptionCache *AC) {645Function *F = FI.OuterLoop->getHeader()->getParent();646const DataLayout &DL = F->getDataLayout();647648// For debugging/testing.649if (AssumeNoOverflow)650return OverflowResult::NeverOverflows;651652// Check if the multiply could not overflow due to known ranges of the653// input values.654OverflowResult OR = computeOverflowForUnsignedMul(655FI.InnerTripCount, FI.OuterTripCount,656SimplifyQuery(DL, DT, AC,657FI.OuterLoop->getLoopPreheader()->getTerminator()));658if (OR != OverflowResult::MayOverflow)659return OR;660661auto CheckGEP = [&](GetElementPtrInst *GEP, Value *GEPOperand) {662for (Value *GEPUser : GEP->users()) {663auto *GEPUserInst = cast<Instruction>(GEPUser);664if (!isa<LoadInst>(GEPUserInst) &&665!(isa<StoreInst>(GEPUserInst) && GEP == GEPUserInst->getOperand(1)))666continue;667if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst, FI.InnerLoop))668continue;669// The IV is used as the operand of a GEP which dominates the loop670// latch, and the IV is at least as wide as the address space of the671// GEP. In this case, the GEP would wrap around the address space672// before the IV increment wraps, which would be UB.673if (GEP->isInBounds() &&674GEPOperand->getType()->getIntegerBitWidth() >=675DL.getPointerTypeSizeInBits(GEP->getType())) {676LLVM_DEBUG(677dbgs() << "use of linear IV would be UB if overflow occurred: ";678GEP->dump());679return true;680}681}682return false;683};684685// Check if any IV user is, or is used by, a GEP that would cause UB if the686// multiply overflows.687for (Value *V : FI.LinearIVUses) {688if (auto *GEP = dyn_cast<GetElementPtrInst>(V))689if (GEP->getNumIndices() == 1 && CheckGEP(GEP, GEP->getOperand(1)))690return OverflowResult::NeverOverflows;691for (Value *U : V->users())692if (auto *GEP = dyn_cast<GetElementPtrInst>(U))693if (CheckGEP(GEP, V))694return OverflowResult::NeverOverflows;695}696697return OverflowResult::MayOverflow;698}699700static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,701ScalarEvolution *SE, AssumptionCache *AC,702const TargetTransformInfo *TTI) {703SmallPtrSet<Instruction *, 8> IterationInstructions;704if (!findLoopComponents(FI.InnerLoop, IterationInstructions,705FI.InnerInductionPHI, FI.InnerTripCount,706FI.InnerIncrement, FI.InnerBranch, SE, FI.Widened))707return false;708if (!findLoopComponents(FI.OuterLoop, IterationInstructions,709FI.OuterInductionPHI, FI.OuterTripCount,710FI.OuterIncrement, FI.OuterBranch, SE, FI.Widened))711return false;712713// Both of the loop trip count values must be invariant in the outer loop714// (non-instructions are all inherently invariant).715if (!FI.OuterLoop->isLoopInvariant(FI.InnerTripCount)) {716LLVM_DEBUG(dbgs() << "inner loop trip count not invariant\n");717return false;718}719if (!FI.OuterLoop->isLoopInvariant(FI.OuterTripCount)) {720LLVM_DEBUG(dbgs() << "outer loop trip count not invariant\n");721return false;722}723724if (!checkPHIs(FI, TTI))725return false;726727// FIXME: it should be possible to handle different types correctly.728if (FI.InnerInductionPHI->getType() != FI.OuterInductionPHI->getType())729return false;730731if (!checkOuterLoopInsts(FI, IterationInstructions, TTI))732return false;733734// Find the values in the loop that can be replaced with the linearized735// induction variable, and check that there are no other uses of the inner736// or outer induction variable. If there were, we could still do this737// transformation, but we'd have to insert a div/mod to calculate the738// original IVs, so it wouldn't be profitable.739if (!checkIVUsers(FI))740return false;741742LLVM_DEBUG(dbgs() << "CanFlattenLoopPair: OK\n");743return true;744}745746static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,747ScalarEvolution *SE, AssumptionCache *AC,748const TargetTransformInfo *TTI, LPMUpdater *U,749MemorySSAUpdater *MSSAU) {750Function *F = FI.OuterLoop->getHeader()->getParent();751LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n");752{753using namespace ore;754OptimizationRemark Remark(DEBUG_TYPE, "Flattened", FI.InnerLoop->getStartLoc(),755FI.InnerLoop->getHeader());756OptimizationRemarkEmitter ORE(F);757Remark << "Flattened into outer loop";758ORE.emit(Remark);759}760761if (!FI.NewTripCount) {762FI.NewTripCount = BinaryOperator::CreateMul(763FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount",764FI.OuterLoop->getLoopPreheader()->getTerminator()->getIterator());765LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";766FI.NewTripCount->dump());767}768769// Fix up PHI nodes that take values from the inner loop back-edge, which770// we are about to remove.771FI.InnerInductionPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());772773// The old Phi will be optimised away later, but for now we can't leave774// leave it in an invalid state, so are updating them too.775for (PHINode *PHI : FI.InnerPHIsToTransform)776PHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());777778// Modify the trip count of the outer loop to be the product of the two779// trip counts.780cast<User>(FI.OuterBranch->getCondition())->setOperand(1, FI.NewTripCount);781782// Replace the inner loop backedge with an unconditional branch to the exit.783BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock();784BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock();785Instruction *Term = InnerExitingBlock->getTerminator();786Instruction *BI = BranchInst::Create(InnerExitBlock, InnerExitingBlock);787BI->setDebugLoc(Term->getDebugLoc());788Term->eraseFromParent();789790// Update the DomTree and MemorySSA.791DT->deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader());792if (MSSAU)793MSSAU->removeEdge(InnerExitingBlock, FI.InnerLoop->getHeader());794795// Replace all uses of the polynomial calculated from the two induction796// variables with the one new one.797IRBuilder<> Builder(FI.OuterInductionPHI->getParent()->getTerminator());798for (Value *V : FI.LinearIVUses) {799Value *OuterValue = FI.OuterInductionPHI;800if (FI.Widened)801OuterValue = Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(),802"flatten.trunciv");803804if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {805// Replace the GEP with one that uses OuterValue as the offset.806auto *InnerGEP = cast<GetElementPtrInst>(GEP->getOperand(0));807Value *Base = InnerGEP->getOperand(0);808// When the base of the GEP doesn't dominate the outer induction phi then809// we need to insert the new GEP where the old GEP was.810if (!DT->dominates(Base, &*Builder.GetInsertPoint()))811Builder.SetInsertPoint(cast<Instruction>(V));812OuterValue =813Builder.CreateGEP(GEP->getSourceElementType(), Base, OuterValue,814"flatten." + V->getName(),815GEP->isInBounds() && InnerGEP->isInBounds());816}817818LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); dbgs() << "with: ";819OuterValue->dump());820V->replaceAllUsesWith(OuterValue);821}822823// Tell LoopInfo, SCEV and the pass manager that the inner loop has been824// deleted, and invalidate any outer loop information.825SE->forgetLoop(FI.OuterLoop);826SE->forgetBlockAndLoopDispositions();827if (U)828U->markLoopAsDeleted(*FI.InnerLoop, FI.InnerLoop->getName());829LI->erase(FI.InnerLoop);830831// Increment statistic value.832NumFlattened++;833834return true;835}836837static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,838ScalarEvolution *SE, AssumptionCache *AC,839const TargetTransformInfo *TTI) {840if (!WidenIV) {841LLVM_DEBUG(dbgs() << "Widening the IVs is disabled\n");842return false;843}844845LLVM_DEBUG(dbgs() << "Try widening the IVs\n");846Module *M = FI.InnerLoop->getHeader()->getParent()->getParent();847auto &DL = M->getDataLayout();848auto *InnerType = FI.InnerInductionPHI->getType();849auto *OuterType = FI.OuterInductionPHI->getType();850unsigned MaxLegalSize = DL.getLargestLegalIntTypeSizeInBits();851auto *MaxLegalType = DL.getLargestLegalIntType(M->getContext());852853// If both induction types are less than the maximum legal integer width,854// promote both to the widest type available so we know calculating855// (OuterTripCount * InnerTripCount) as the new trip count is safe.856if (InnerType != OuterType ||857InnerType->getScalarSizeInBits() >= MaxLegalSize ||858MaxLegalType->getScalarSizeInBits() <859InnerType->getScalarSizeInBits() * 2) {860LLVM_DEBUG(dbgs() << "Can't widen the IV\n");861return false;862}863864SCEVExpander Rewriter(*SE, DL, "loopflatten");865SmallVector<WeakTrackingVH, 4> DeadInsts;866unsigned ElimExt = 0;867unsigned Widened = 0;868869auto CreateWideIV = [&](WideIVInfo WideIV, bool &Deleted) -> bool {870PHINode *WidePhi =871createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts, ElimExt, Widened,872true /* HasGuards */, true /* UsePostIncrementRanges */);873if (!WidePhi)874return false;875LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump());876LLVM_DEBUG(dbgs() << "Deleting old phi: "; WideIV.NarrowIV->dump());877Deleted = RecursivelyDeleteDeadPHINode(WideIV.NarrowIV);878return true;879};880881bool Deleted;882if (!CreateWideIV({FI.InnerInductionPHI, MaxLegalType, false}, Deleted))883return false;884// Add the narrow phi to list, so that it will be adjusted later when the885// the transformation is performed.886if (!Deleted)887FI.InnerPHIsToTransform.insert(FI.InnerInductionPHI);888889if (!CreateWideIV({FI.OuterInductionPHI, MaxLegalType, false}, Deleted))890return false;891892assert(Widened && "Widened IV expected");893FI.Widened = true;894895// Save the old/narrow induction phis, which we need to ignore in CheckPHIs.896FI.NarrowInnerInductionPHI = FI.InnerInductionPHI;897FI.NarrowOuterInductionPHI = FI.OuterInductionPHI;898899// After widening, rediscover all the loop components.900return CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI);901}902903static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,904ScalarEvolution *SE, AssumptionCache *AC,905const TargetTransformInfo *TTI, LPMUpdater *U,906MemorySSAUpdater *MSSAU,907const LoopAccessInfo &LAI) {908LLVM_DEBUG(909dbgs() << "Loop flattening running on outer loop "910<< FI.OuterLoop->getHeader()->getName() << " and inner loop "911<< FI.InnerLoop->getHeader()->getName() << " in "912<< FI.OuterLoop->getHeader()->getParent()->getName() << "\n");913914if (!CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI))915return false;916917// Check if we can widen the induction variables to avoid overflow checks.918bool CanFlatten = CanWidenIV(FI, DT, LI, SE, AC, TTI);919920// It can happen that after widening of the IV, flattening may not be921// possible/happening, e.g. when it is deemed unprofitable. So bail here if922// that is the case.923// TODO: IV widening without performing the actual flattening transformation924// is not ideal. While this codegen change should not matter much, it is an925// unnecessary change which is better to avoid. It's unlikely this happens926// often, because if it's unprofitibale after widening, it should be927// unprofitabe before widening as checked in the first round of checks. But928// 'RepeatedInstructionThreshold' is set to only 2, which can probably be929// relaxed. Because this is making a code change (the IV widening, but not930// the flattening), we return true here.931if (FI.Widened && !CanFlatten)932return true;933934// If we have widened and can perform the transformation, do that here.935if (CanFlatten)936return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);937938// Otherwise, if we haven't widened the IV, check if the new iteration939// variable might overflow. In this case, we need to version the loop, and940// select the original version at runtime if the iteration space is too941// large.942OverflowResult OR = checkOverflow(FI, DT, AC);943if (OR == OverflowResult::AlwaysOverflowsHigh ||944OR == OverflowResult::AlwaysOverflowsLow) {945LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n");946return false;947} else if (OR == OverflowResult::MayOverflow) {948Module *M = FI.OuterLoop->getHeader()->getParent()->getParent();949const DataLayout &DL = M->getDataLayout();950if (!VersionLoops) {951LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n");952return false;953} else if (!DL.isLegalInteger(954FI.OuterTripCount->getType()->getScalarSizeInBits())) {955// If the trip count type isn't legal then it won't be possible to check956// for overflow using only a single multiply instruction, so don't957// flatten.958LLVM_DEBUG(959dbgs() << "Can't check overflow efficiently, not flattening\n");960return false;961}962LLVM_DEBUG(dbgs() << "Multiply might overflow, versioning loop\n");963964// Version the loop. The overflow check isn't a runtime pointer check, so we965// pass an empty list of runtime pointer checks, causing LoopVersioning to966// emit 'false' as the branch condition, and add our own check afterwards.967BasicBlock *CheckBlock = FI.OuterLoop->getLoopPreheader();968ArrayRef<RuntimePointerCheck> Checks(nullptr, nullptr);969LoopVersioning LVer(LAI, Checks, FI.OuterLoop, LI, DT, SE);970LVer.versionLoop();971972// Check for overflow by calculating the new tripcount using973// umul_with_overflow and then checking if it overflowed.974BranchInst *Br = cast<BranchInst>(CheckBlock->getTerminator());975assert(Br->isConditional() &&976"Expected LoopVersioning to generate a conditional branch");977assert(match(Br->getCondition(), m_Zero()) &&978"Expected branch condition to be false");979IRBuilder<> Builder(Br);980Function *F = Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow,981FI.OuterTripCount->getType());982Value *Call = Builder.CreateCall(F, {FI.OuterTripCount, FI.InnerTripCount},983"flatten.mul");984FI.NewTripCount = Builder.CreateExtractValue(Call, 0, "flatten.tripcount");985Value *Overflow = Builder.CreateExtractValue(Call, 1, "flatten.overflow");986Br->setCondition(Overflow);987} else {988LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");989}990991return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);992}993994PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,995LoopStandardAnalysisResults &AR,996LPMUpdater &U) {997998bool Changed = false;9991000std::optional<MemorySSAUpdater> MSSAU;1001if (AR.MSSA) {1002MSSAU = MemorySSAUpdater(AR.MSSA);1003if (VerifyMemorySSA)1004AR.MSSA->verifyMemorySSA();1005}10061007// The loop flattening pass requires loops to be1008// in simplified form, and also needs LCSSA. Running1009// this pass will simplify all loops that contain inner loops,1010// regardless of whether anything ends up being flattened.1011LoopAccessInfoManager LAIM(AR.SE, AR.AA, AR.DT, AR.LI, &AR.TTI, nullptr);1012for (Loop *InnerLoop : LN.getLoops()) {1013auto *OuterLoop = InnerLoop->getParentLoop();1014if (!OuterLoop)1015continue;1016FlattenInfo FI(OuterLoop, InnerLoop);1017Changed |=1018FlattenLoopPair(FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U,1019MSSAU ? &*MSSAU : nullptr, LAIM.getInfo(*OuterLoop));1020}10211022if (!Changed)1023return PreservedAnalyses::all();10241025if (AR.MSSA && VerifyMemorySSA)1026AR.MSSA->verifyMemorySSA();10271028auto PA = getLoopPassPreservedAnalyses();1029if (AR.MSSA)1030PA.preserve<MemorySSAAnalysis>();1031return PA;1032}103310341035