Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Vectorize/EVLIndVarSimplify.cpp
213799 views
//===---- EVLIndVarSimplify.cpp - Optimize vectorized loops w/ EVL IV------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This pass optimizes a vectorized loop with canonical IV to using EVL-based9// IV if it was tail-folded by predicated EVL.10//11//===----------------------------------------------------------------------===//1213#include "llvm/Transforms/Vectorize/EVLIndVarSimplify.h"14#include "llvm/ADT/Statistic.h"15#include "llvm/Analysis/IVDescriptors.h"16#include "llvm/Analysis/LoopInfo.h"17#include "llvm/Analysis/LoopPass.h"18#include "llvm/Analysis/OptimizationRemarkEmitter.h"19#include "llvm/Analysis/ScalarEvolution.h"20#include "llvm/Analysis/ScalarEvolutionExpressions.h"21#include "llvm/Analysis/ValueTracking.h"22#include "llvm/IR/IRBuilder.h"23#include "llvm/IR/PatternMatch.h"24#include "llvm/Support/CommandLine.h"25#include "llvm/Support/Debug.h"26#include "llvm/Support/MathExtras.h"27#include "llvm/Support/raw_ostream.h"28#include "llvm/Transforms/Scalar/LoopPassManager.h"29#include "llvm/Transforms/Utils/Local.h"3031#define DEBUG_TYPE "evl-iv-simplify"3233using namespace llvm;3435STATISTIC(NumEliminatedCanonicalIV, "Number of canonical IVs we eliminated");3637static cl::opt<bool> EnableEVLIndVarSimplify(38"enable-evl-indvar-simplify",39cl::desc("Enable EVL-based induction variable simplify Pass"), cl::Hidden,40cl::init(true));4142namespace {43struct EVLIndVarSimplifyImpl {44ScalarEvolution &SE;45OptimizationRemarkEmitter *ORE = nullptr;4647EVLIndVarSimplifyImpl(LoopStandardAnalysisResults &LAR,48OptimizationRemarkEmitter *ORE)49: SE(LAR.SE), ORE(ORE) {}5051/// Returns true if modify the loop.52bool run(Loop &L);53};54} // anonymous namespace5556/// Returns the constant part of vectorization factor from the induction57/// variable's step value SCEV expression.58static uint32_t getVFFromIndVar(const SCEV *Step, const Function &F) {59if (!Step)60return 0U;6162// Looking for loops with IV step value in the form of `(<constant VF> x63// vscale)`.64if (const auto *Mul = dyn_cast<SCEVMulExpr>(Step)) {65if (Mul->getNumOperands() == 2) {66const SCEV *LHS = Mul->getOperand(0);67const SCEV *RHS = Mul->getOperand(1);68if (const auto *Const = dyn_cast<SCEVConstant>(LHS);69Const && isa<SCEVVScale>(RHS)) {70uint64_t V = Const->getAPInt().getLimitedValue();71if (llvm::isUInt<32>(V))72return V;73}74}75}7677// If not, see if the vscale_range of the parent function is a fixed value,78// which makes the step value to be replaced by a constant.79if (F.hasFnAttribute(Attribute::VScaleRange))80if (const auto *ConstStep = dyn_cast<SCEVConstant>(Step)) {81APInt V = ConstStep->getAPInt().abs();82ConstantRange CR = llvm::getVScaleRange(&F, 64);83if (const APInt *Fixed = CR.getSingleElement()) {84V = V.zextOrTrunc(Fixed->getBitWidth());85uint64_t VF = V.udiv(*Fixed).getLimitedValue();86if (VF && llvm::isUInt<32>(VF) &&87// Make sure step is divisible by vscale.88V.urem(*Fixed).isZero())89return VF;90}91}9293return 0U;94}9596bool EVLIndVarSimplifyImpl::run(Loop &L) {97if (!EnableEVLIndVarSimplify)98return false;99100if (!getBooleanLoopAttribute(&L, "llvm.loop.isvectorized"))101return false;102const MDOperand *EVLMD =103findStringMetadataForLoop(&L, "llvm.loop.isvectorized.tailfoldingstyle")104.value_or(nullptr);105if (!EVLMD || !EVLMD->equalsStr("evl"))106return false;107108BasicBlock *LatchBlock = L.getLoopLatch();109ICmpInst *OrigLatchCmp = L.getLatchCmpInst();110if (!LatchBlock || !OrigLatchCmp)111return false;112113InductionDescriptor IVD;114PHINode *IndVar = L.getInductionVariable(SE);115if (!IndVar || !L.getInductionDescriptor(SE, IVD)) {116const char *Reason = (IndVar ? "induction descriptor is not available"117: "cannot recognize induction variable");118LLVM_DEBUG(dbgs() << "Cannot retrieve IV from loop " << L.getName()119<< " because" << Reason << "\n");120if (ORE) {121ORE->emit([&]() {122return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar",123L.getStartLoc(), L.getHeader())124<< "Cannot retrieve IV because " << ore::NV("Reason", Reason);125});126}127return false;128}129130BasicBlock *InitBlock, *BackEdgeBlock;131if (!L.getIncomingAndBackEdge(InitBlock, BackEdgeBlock)) {132LLVM_DEBUG(dbgs() << "Expect unique incoming and backedge in "133<< L.getName() << "\n");134if (ORE) {135ORE->emit([&]() {136return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure",137L.getStartLoc(), L.getHeader())138<< "Does not have a unique incoming and backedge";139});140}141return false;142}143144// Retrieve the loop bounds.145std::optional<Loop::LoopBounds> Bounds = L.getBounds(SE);146if (!Bounds) {147LLVM_DEBUG(dbgs() << "Could not obtain the bounds for loop " << L.getName()148<< "\n");149if (ORE) {150ORE->emit([&]() {151return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure",152L.getStartLoc(), L.getHeader())153<< "Could not obtain the loop bounds";154});155}156return false;157}158Value *CanonicalIVInit = &Bounds->getInitialIVValue();159Value *CanonicalIVFinal = &Bounds->getFinalIVValue();160161const SCEV *StepV = IVD.getStep();162uint32_t VF = getVFFromIndVar(StepV, *L.getHeader()->getParent());163if (!VF) {164LLVM_DEBUG(dbgs() << "Could not infer VF from IndVar step '" << *StepV165<< "'\n");166if (ORE) {167ORE->emit([&]() {168return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar",169L.getStartLoc(), L.getHeader())170<< "Could not infer VF from IndVar step "171<< ore::NV("Step", StepV);172});173}174return false;175}176LLVM_DEBUG(dbgs() << "Using VF=" << VF << " for loop " << L.getName()177<< "\n");178179// Try to find the EVL-based induction variable.180using namespace PatternMatch;181BasicBlock *BB = IndVar->getParent();182183Value *EVLIndVar = nullptr;184Value *RemTC = nullptr;185Value *TC = nullptr;186auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>(187m_Value(RemTC), m_SpecificInt(VF),188/*Scalable=*/m_SpecificInt(1));189for (PHINode &PN : BB->phis()) {190if (&PN == IndVar)191continue;192193// Check 1: it has to contain both incoming (init) & backedge blocks194// from IndVar.195if (PN.getBasicBlockIndex(InitBlock) < 0 ||196PN.getBasicBlockIndex(BackEdgeBlock) < 0)197continue;198// Check 2: EVL index is always increasing, thus its inital value has to be199// equal to either the initial IV value (when the canonical IV is also200// increasing) or the last IV value (when canonical IV is decreasing).201Value *Init = PN.getIncomingValueForBlock(InitBlock);202using Direction = Loop::LoopBounds::Direction;203switch (Bounds->getDirection()) {204case Direction::Increasing:205if (Init != CanonicalIVInit)206continue;207break;208case Direction::Decreasing:209if (Init != CanonicalIVFinal)210continue;211break;212case Direction::Unknown:213// To be more permissive and see if either the initial or final IV value214// matches PN's init value.215if (Init != CanonicalIVInit && Init != CanonicalIVFinal)216continue;217break;218}219Value *RecValue = PN.getIncomingValueForBlock(BackEdgeBlock);220assert(RecValue && "expect recurrent IndVar value");221222LLVM_DEBUG(dbgs() << "Found candidate PN of EVL-based IndVar: " << PN223<< "\n");224225// Check 3: Pattern match to find the EVL-based index and total trip count226// (TC).227if (match(RecValue,228m_c_Add(m_ZExtOrSelf(IntrinsicMatch), m_Specific(&PN))) &&229match(RemTC, m_Sub(m_Value(TC), m_Specific(&PN)))) {230EVLIndVar = RecValue;231break;232}233}234235if (!EVLIndVar || !TC)236return false;237238LLVM_DEBUG(dbgs() << "Using " << *EVLIndVar << " for EVL-based IndVar\n");239if (ORE) {240ORE->emit([&]() {241DebugLoc DL;242BasicBlock *Region = nullptr;243if (auto *I = dyn_cast<Instruction>(EVLIndVar)) {244DL = I->getDebugLoc();245Region = I->getParent();246} else {247DL = L.getStartLoc();248Region = L.getHeader();249}250return OptimizationRemark(DEBUG_TYPE, "UseEVLIndVar", DL, Region)251<< "Using " << ore::NV("EVLIndVar", EVLIndVar)252<< " for EVL-based IndVar";253});254}255256// Create an EVL-based comparison and replace the branch to use it as257// predicate.258259// Loop::getLatchCmpInst check at the beginning of this function has ensured260// that latch block ends in a conditional branch.261auto *LatchBranch = cast<BranchInst>(LatchBlock->getTerminator());262assert(LatchBranch->isConditional() &&263"expect the loop latch to be ended with a conditional branch");264ICmpInst::Predicate Pred;265if (LatchBranch->getSuccessor(0) == L.getHeader())266Pred = ICmpInst::ICMP_NE;267else268Pred = ICmpInst::ICMP_EQ;269270IRBuilder<> Builder(OrigLatchCmp);271auto *NewLatchCmp = Builder.CreateICmp(Pred, EVLIndVar, TC);272OrigLatchCmp->replaceAllUsesWith(NewLatchCmp);273274// llvm::RecursivelyDeleteDeadPHINode only deletes cycles whose values are275// not used outside the cycles. However, in this case the now-RAUW-ed276// OrigLatchCmp will be considered a use outside the cycle while in reality277// it's practically dead. Thus we need to remove it before calling278// RecursivelyDeleteDeadPHINode.279(void)RecursivelyDeleteTriviallyDeadInstructions(OrigLatchCmp);280if (llvm::RecursivelyDeleteDeadPHINode(IndVar))281LLVM_DEBUG(dbgs() << "Removed original IndVar\n");282283++NumEliminatedCanonicalIV;284285return true;286}287288PreservedAnalyses EVLIndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &LAM,289LoopStandardAnalysisResults &AR,290LPMUpdater &U) {291Function &F = *L.getHeader()->getParent();292auto &FAMProxy = LAM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR);293OptimizationRemarkEmitter *ORE =294FAMProxy.getCachedResult<OptimizationRemarkEmitterAnalysis>(F);295296if (EVLIndVarSimplifyImpl(AR, ORE).run(L))297return PreservedAnalyses::allInSet<CFGAnalyses>();298return PreservedAnalyses::all();299}300301302