Path: blob/main/contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp
35232 views
//===----- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -----===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This pass implements IR expansion for vector predication intrinsics, allowing9// targets to enable vector predication until just before codegen.10//11//===----------------------------------------------------------------------===//1213#include "llvm/CodeGen/ExpandVectorPredication.h"14#include "llvm/ADT/Statistic.h"15#include "llvm/Analysis/TargetTransformInfo.h"16#include "llvm/Analysis/ValueTracking.h"17#include "llvm/Analysis/VectorUtils.h"18#include "llvm/CodeGen/Passes.h"19#include "llvm/IR/Constants.h"20#include "llvm/IR/Function.h"21#include "llvm/IR/IRBuilder.h"22#include "llvm/IR/InstIterator.h"23#include "llvm/IR/Instructions.h"24#include "llvm/IR/IntrinsicInst.h"25#include "llvm/IR/Intrinsics.h"26#include "llvm/InitializePasses.h"27#include "llvm/Pass.h"28#include "llvm/Support/CommandLine.h"29#include "llvm/Support/Compiler.h"30#include "llvm/Support/Debug.h"31#include <optional>3233using namespace llvm;3435using VPLegalization = TargetTransformInfo::VPLegalization;36using VPTransform = TargetTransformInfo::VPLegalization::VPTransform;3738// Keep this in sync with TargetTransformInfo::VPLegalization.39#define VPINTERNAL_VPLEGAL_CASES \40VPINTERNAL_CASE(Legal) \41VPINTERNAL_CASE(Discard) \42VPINTERNAL_CASE(Convert)4344#define VPINTERNAL_CASE(X) "|" #X4546// Override options.47static cl::opt<std::string> EVLTransformOverride(48"expandvp-override-evl-transform", cl::init(""), cl::Hidden,49cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES50". If non-empty, ignore "51"TargetTransformInfo and "52"always use this transformation for the %evl parameter (Used in "53"testing)."));5455static cl::opt<std::string> MaskTransformOverride(56"expandvp-override-mask-transform", cl::init(""), cl::Hidden,57cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES58". If non-empty, Ignore "59"TargetTransformInfo and "60"always use this transformation for the %mask parameter (Used in "61"testing)."));6263#undef VPINTERNAL_CASE64#define VPINTERNAL_CASE(X) .Case(#X, VPLegalization::X)6566static VPTransform parseOverrideOption(const std::string &TextOpt) {67return StringSwitch<VPTransform>(TextOpt) VPINTERNAL_VPLEGAL_CASES;68}6970#undef VPINTERNAL_VPLEGAL_CASES7172// Whether any override options are set.73static bool anyExpandVPOverridesSet() {74return !EVLTransformOverride.empty() || !MaskTransformOverride.empty();75}7677#define DEBUG_TYPE "expandvp"7879STATISTIC(NumFoldedVL, "Number of folded vector length params");80STATISTIC(NumLoweredVPOps, "Number of folded vector predication operations");8182///// Helpers {8384/// \returns Whether the vector mask \p MaskVal has all lane bits set.85static bool isAllTrueMask(Value *MaskVal) {86if (Value *SplattedVal = getSplatValue(MaskVal))87if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))88return ConstValue->isAllOnesValue();8990return false;91}9293/// \returns A non-excepting divisor constant for this type.94static Constant *getSafeDivisor(Type *DivTy) {95assert(DivTy->isIntOrIntVectorTy() && "Unsupported divisor type");96return ConstantInt::get(DivTy, 1u, false);97}9899/// Transfer operation properties from \p OldVPI to \p NewVal.100static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {101auto *NewInst = dyn_cast<Instruction>(&NewVal);102if (!NewInst || !isa<FPMathOperator>(NewVal))103return;104105auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);106if (!OldFMOp)107return;108109NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());110}111112/// Transfer all properties from \p OldOp to \p NewOp and replace all uses.113/// OldVP gets erased.114static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {115transferDecorations(NewOp, OldOp);116OldOp.replaceAllUsesWith(&NewOp);117OldOp.eraseFromParent();118}119120static bool maySpeculateLanes(VPIntrinsic &VPI) {121// The result of VP reductions depends on the mask and evl.122if (isa<VPReductionIntrinsic>(VPI))123return false;124// Fallback to whether the intrinsic is speculatable.125if (auto IntrID = VPI.getFunctionalIntrinsicID())126return Intrinsic::getAttributes(VPI.getContext(), *IntrID)127.hasFnAttr(Attribute::AttrKind::Speculatable);128if (auto Opc = VPI.getFunctionalOpcode())129return isSafeToSpeculativelyExecuteWithOpcode(*Opc, &VPI);130return false;131}132133//// } Helpers134135namespace {136137// Expansion pass state at function scope.138struct CachingVPExpander {139Function &F;140const TargetTransformInfo &TTI;141142/// \returns A (fixed length) vector with ascending integer indices143/// (<0, 1, ..., NumElems-1>).144/// \p Builder145/// Used for instruction creation.146/// \p LaneTy147/// Integer element type of the result vector.148/// \p NumElems149/// Number of vector elements.150Value *createStepVector(IRBuilder<> &Builder, Type *LaneTy,151unsigned NumElems);152153/// \returns A bitmask that is true where the lane position is less-than \p154/// EVLParam155///156/// \p Builder157/// Used for instruction creation.158/// \p VLParam159/// The explicit vector length parameter to test against the lane160/// positions.161/// \p ElemCount162/// Static (potentially scalable) number of vector elements.163Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,164ElementCount ElemCount);165166Value *foldEVLIntoMask(VPIntrinsic &VPI);167168/// "Remove" the %evl parameter of \p PI by setting it to the static vector169/// length of the operation.170void discardEVLParameter(VPIntrinsic &PI);171172/// Lower this VP binary operator to a unpredicated binary operator.173Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,174VPIntrinsic &PI);175176/// Lower this VP int call to a unpredicated int call.177Value *expandPredicationToIntCall(IRBuilder<> &Builder, VPIntrinsic &PI,178unsigned UnpredicatedIntrinsicID);179180/// Lower this VP fp call to a unpredicated fp call.181Value *expandPredicationToFPCall(IRBuilder<> &Builder, VPIntrinsic &PI,182unsigned UnpredicatedIntrinsicID);183184/// Lower this VP reduction to a call to an unpredicated reduction intrinsic.185Value *expandPredicationInReduction(IRBuilder<> &Builder,186VPReductionIntrinsic &PI);187188/// Lower this VP cast operation to a non-VP intrinsic.189Value *expandPredicationToCastIntrinsic(IRBuilder<> &Builder,190VPIntrinsic &VPI);191192/// Lower this VP memory operation to a non-VP intrinsic.193Value *expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,194VPIntrinsic &VPI);195196/// Lower this VP comparison to a call to an unpredicated comparison.197Value *expandPredicationInComparison(IRBuilder<> &Builder,198VPCmpIntrinsic &PI);199200/// Query TTI and expand the vector predication in \p P accordingly.201Value *expandPredication(VPIntrinsic &PI);202203/// Determine how and whether the VPIntrinsic \p VPI shall be expanded. This204/// overrides TTI with the cl::opts listed at the top of this file.205VPLegalization getVPLegalizationStrategy(const VPIntrinsic &VPI) const;206bool UsingTTIOverrides;207208public:209CachingVPExpander(Function &F, const TargetTransformInfo &TTI)210: F(F), TTI(TTI), UsingTTIOverrides(anyExpandVPOverridesSet()) {}211212bool expandVectorPredication();213};214215//// CachingVPExpander {216217Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder, Type *LaneTy,218unsigned NumElems) {219// TODO add caching220SmallVector<Constant *, 16> ConstElems;221222for (unsigned Idx = 0; Idx < NumElems; ++Idx)223ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));224225return ConstantVector::get(ConstElems);226}227228Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,229Value *EVLParam,230ElementCount ElemCount) {231// TODO add caching232// Scalable vector %evl conversion.233if (ElemCount.isScalable()) {234auto *M = Builder.GetInsertBlock()->getModule();235Type *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);236Function *ActiveMaskFunc = Intrinsic::getDeclaration(237M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});238// `get_active_lane_mask` performs an implicit less-than comparison.239Value *ConstZero = Builder.getInt32(0);240return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});241}242243// Fixed vector %evl conversion.244Type *LaneTy = EVLParam->getType();245unsigned NumElems = ElemCount.getFixedValue();246Value *VLSplat = Builder.CreateVectorSplat(NumElems, EVLParam);247Value *IdxVec = createStepVector(Builder, LaneTy, NumElems);248return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);249}250251Value *252CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,253VPIntrinsic &VPI) {254assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&255"Implicitly dropping %evl in non-speculatable operator!");256257auto OC = static_cast<Instruction::BinaryOps>(*VPI.getFunctionalOpcode());258assert(Instruction::isBinaryOp(OC));259260Value *Op0 = VPI.getOperand(0);261Value *Op1 = VPI.getOperand(1);262Value *Mask = VPI.getMaskParam();263264// Blend in safe operands.265if (Mask && !isAllTrueMask(Mask)) {266switch (OC) {267default:268// Can safely ignore the predicate.269break;270271// Division operators need a safe divisor on masked-off lanes (1).272case Instruction::UDiv:273case Instruction::SDiv:274case Instruction::URem:275case Instruction::SRem:276// 2nd operand must not be zero.277Value *SafeDivisor = getSafeDivisor(VPI.getType());278Op1 = Builder.CreateSelect(Mask, Op1, SafeDivisor);279}280}281282Value *NewBinOp = Builder.CreateBinOp(OC, Op0, Op1, VPI.getName());283284replaceOperation(*NewBinOp, VPI);285return NewBinOp;286}287288Value *CachingVPExpander::expandPredicationToIntCall(289IRBuilder<> &Builder, VPIntrinsic &VPI, unsigned UnpredicatedIntrinsicID) {290switch (UnpredicatedIntrinsicID) {291case Intrinsic::abs:292case Intrinsic::smax:293case Intrinsic::smin:294case Intrinsic::umax:295case Intrinsic::umin: {296Value *Op0 = VPI.getOperand(0);297Value *Op1 = VPI.getOperand(1);298Function *Fn = Intrinsic::getDeclaration(299VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});300Value *NewOp = Builder.CreateCall(Fn, {Op0, Op1}, VPI.getName());301replaceOperation(*NewOp, VPI);302return NewOp;303}304case Intrinsic::bswap:305case Intrinsic::bitreverse: {306Value *Op = VPI.getOperand(0);307Function *Fn = Intrinsic::getDeclaration(308VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});309Value *NewOp = Builder.CreateCall(Fn, {Op}, VPI.getName());310replaceOperation(*NewOp, VPI);311return NewOp;312}313}314return nullptr;315}316317Value *CachingVPExpander::expandPredicationToFPCall(318IRBuilder<> &Builder, VPIntrinsic &VPI, unsigned UnpredicatedIntrinsicID) {319assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&320"Implicitly dropping %evl in non-speculatable operator!");321322switch (UnpredicatedIntrinsicID) {323case Intrinsic::fabs:324case Intrinsic::sqrt: {325Value *Op0 = VPI.getOperand(0);326Function *Fn = Intrinsic::getDeclaration(327VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});328Value *NewOp = Builder.CreateCall(Fn, {Op0}, VPI.getName());329replaceOperation(*NewOp, VPI);330return NewOp;331}332case Intrinsic::maxnum:333case Intrinsic::minnum: {334Value *Op0 = VPI.getOperand(0);335Value *Op1 = VPI.getOperand(1);336Function *Fn = Intrinsic::getDeclaration(337VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});338Value *NewOp = Builder.CreateCall(Fn, {Op0, Op1}, VPI.getName());339replaceOperation(*NewOp, VPI);340return NewOp;341}342case Intrinsic::fma:343case Intrinsic::fmuladd:344case Intrinsic::experimental_constrained_fma:345case Intrinsic::experimental_constrained_fmuladd: {346Value *Op0 = VPI.getOperand(0);347Value *Op1 = VPI.getOperand(1);348Value *Op2 = VPI.getOperand(2);349Function *Fn = Intrinsic::getDeclaration(350VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});351Value *NewOp;352if (Intrinsic::isConstrainedFPIntrinsic(UnpredicatedIntrinsicID))353NewOp =354Builder.CreateConstrainedFPCall(Fn, {Op0, Op1, Op2}, VPI.getName());355else356NewOp = Builder.CreateCall(Fn, {Op0, Op1, Op2}, VPI.getName());357replaceOperation(*NewOp, VPI);358return NewOp;359}360}361362return nullptr;363}364365static Value *getNeutralReductionElement(const VPReductionIntrinsic &VPI,366Type *EltTy) {367bool Negative = false;368unsigned EltBits = EltTy->getScalarSizeInBits();369Intrinsic::ID VID = VPI.getIntrinsicID();370switch (VID) {371default:372llvm_unreachable("Expecting a VP reduction intrinsic");373case Intrinsic::vp_reduce_add:374case Intrinsic::vp_reduce_or:375case Intrinsic::vp_reduce_xor:376case Intrinsic::vp_reduce_umax:377return Constant::getNullValue(EltTy);378case Intrinsic::vp_reduce_mul:379return ConstantInt::get(EltTy, 1, /*IsSigned*/ false);380case Intrinsic::vp_reduce_and:381case Intrinsic::vp_reduce_umin:382return ConstantInt::getAllOnesValue(EltTy);383case Intrinsic::vp_reduce_smin:384return ConstantInt::get(EltTy->getContext(),385APInt::getSignedMaxValue(EltBits));386case Intrinsic::vp_reduce_smax:387return ConstantInt::get(EltTy->getContext(),388APInt::getSignedMinValue(EltBits));389case Intrinsic::vp_reduce_fmax:390case Intrinsic::vp_reduce_fmaximum:391Negative = true;392[[fallthrough]];393case Intrinsic::vp_reduce_fmin:394case Intrinsic::vp_reduce_fminimum: {395bool PropagatesNaN = VID == Intrinsic::vp_reduce_fminimum ||396VID == Intrinsic::vp_reduce_fmaximum;397FastMathFlags Flags = VPI.getFastMathFlags();398const fltSemantics &Semantics = EltTy->getFltSemantics();399return (!Flags.noNaNs() && !PropagatesNaN)400? ConstantFP::getQNaN(EltTy, Negative)401: !Flags.noInfs()402? ConstantFP::getInfinity(EltTy, Negative)403: ConstantFP::get(EltTy,404APFloat::getLargest(Semantics, Negative));405}406case Intrinsic::vp_reduce_fadd:407return ConstantFP::getNegativeZero(EltTy);408case Intrinsic::vp_reduce_fmul:409return ConstantFP::get(EltTy, 1.0);410}411}412413Value *414CachingVPExpander::expandPredicationInReduction(IRBuilder<> &Builder,415VPReductionIntrinsic &VPI) {416assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&417"Implicitly dropping %evl in non-speculatable operator!");418419Value *Mask = VPI.getMaskParam();420Value *RedOp = VPI.getOperand(VPI.getVectorParamPos());421422// Insert neutral element in masked-out positions423if (Mask && !isAllTrueMask(Mask)) {424auto *NeutralElt = getNeutralReductionElement(VPI, VPI.getType());425auto *NeutralVector = Builder.CreateVectorSplat(426cast<VectorType>(RedOp->getType())->getElementCount(), NeutralElt);427RedOp = Builder.CreateSelect(Mask, RedOp, NeutralVector);428}429430Value *Reduction;431Value *Start = VPI.getOperand(VPI.getStartParamPos());432433switch (VPI.getIntrinsicID()) {434default:435llvm_unreachable("Impossible reduction kind");436case Intrinsic::vp_reduce_add:437Reduction = Builder.CreateAddReduce(RedOp);438Reduction = Builder.CreateAdd(Reduction, Start);439break;440case Intrinsic::vp_reduce_mul:441Reduction = Builder.CreateMulReduce(RedOp);442Reduction = Builder.CreateMul(Reduction, Start);443break;444case Intrinsic::vp_reduce_and:445Reduction = Builder.CreateAndReduce(RedOp);446Reduction = Builder.CreateAnd(Reduction, Start);447break;448case Intrinsic::vp_reduce_or:449Reduction = Builder.CreateOrReduce(RedOp);450Reduction = Builder.CreateOr(Reduction, Start);451break;452case Intrinsic::vp_reduce_xor:453Reduction = Builder.CreateXorReduce(RedOp);454Reduction = Builder.CreateXor(Reduction, Start);455break;456case Intrinsic::vp_reduce_smax:457Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ true);458Reduction =459Builder.CreateBinaryIntrinsic(Intrinsic::smax, Reduction, Start);460break;461case Intrinsic::vp_reduce_smin:462Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ true);463Reduction =464Builder.CreateBinaryIntrinsic(Intrinsic::smin, Reduction, Start);465break;466case Intrinsic::vp_reduce_umax:467Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ false);468Reduction =469Builder.CreateBinaryIntrinsic(Intrinsic::umax, Reduction, Start);470break;471case Intrinsic::vp_reduce_umin:472Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ false);473Reduction =474Builder.CreateBinaryIntrinsic(Intrinsic::umin, Reduction, Start);475break;476case Intrinsic::vp_reduce_fmax:477Reduction = Builder.CreateFPMaxReduce(RedOp);478transferDecorations(*Reduction, VPI);479Reduction =480Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, Reduction, Start);481break;482case Intrinsic::vp_reduce_fmin:483Reduction = Builder.CreateFPMinReduce(RedOp);484transferDecorations(*Reduction, VPI);485Reduction =486Builder.CreateBinaryIntrinsic(Intrinsic::minnum, Reduction, Start);487break;488case Intrinsic::vp_reduce_fmaximum:489Reduction = Builder.CreateFPMaximumReduce(RedOp);490transferDecorations(*Reduction, VPI);491Reduction =492Builder.CreateBinaryIntrinsic(Intrinsic::maximum, Reduction, Start);493break;494case Intrinsic::vp_reduce_fminimum:495Reduction = Builder.CreateFPMinimumReduce(RedOp);496transferDecorations(*Reduction, VPI);497Reduction =498Builder.CreateBinaryIntrinsic(Intrinsic::minimum, Reduction, Start);499break;500case Intrinsic::vp_reduce_fadd:501Reduction = Builder.CreateFAddReduce(Start, RedOp);502break;503case Intrinsic::vp_reduce_fmul:504Reduction = Builder.CreateFMulReduce(Start, RedOp);505break;506}507508replaceOperation(*Reduction, VPI);509return Reduction;510}511512Value *CachingVPExpander::expandPredicationToCastIntrinsic(IRBuilder<> &Builder,513VPIntrinsic &VPI) {514Value *CastOp = nullptr;515switch (VPI.getIntrinsicID()) {516default:517llvm_unreachable("Not a VP cast intrinsic");518case Intrinsic::vp_sext:519CastOp =520Builder.CreateSExt(VPI.getOperand(0), VPI.getType(), VPI.getName());521break;522case Intrinsic::vp_zext:523CastOp =524Builder.CreateZExt(VPI.getOperand(0), VPI.getType(), VPI.getName());525break;526case Intrinsic::vp_trunc:527CastOp =528Builder.CreateTrunc(VPI.getOperand(0), VPI.getType(), VPI.getName());529break;530case Intrinsic::vp_inttoptr:531CastOp =532Builder.CreateIntToPtr(VPI.getOperand(0), VPI.getType(), VPI.getName());533break;534case Intrinsic::vp_ptrtoint:535CastOp =536Builder.CreatePtrToInt(VPI.getOperand(0), VPI.getType(), VPI.getName());537break;538case Intrinsic::vp_fptosi:539CastOp =540Builder.CreateFPToSI(VPI.getOperand(0), VPI.getType(), VPI.getName());541break;542543case Intrinsic::vp_fptoui:544CastOp =545Builder.CreateFPToUI(VPI.getOperand(0), VPI.getType(), VPI.getName());546break;547case Intrinsic::vp_sitofp:548CastOp =549Builder.CreateSIToFP(VPI.getOperand(0), VPI.getType(), VPI.getName());550break;551case Intrinsic::vp_uitofp:552CastOp =553Builder.CreateUIToFP(VPI.getOperand(0), VPI.getType(), VPI.getName());554break;555case Intrinsic::vp_fptrunc:556CastOp =557Builder.CreateFPTrunc(VPI.getOperand(0), VPI.getType(), VPI.getName());558break;559case Intrinsic::vp_fpext:560CastOp =561Builder.CreateFPExt(VPI.getOperand(0), VPI.getType(), VPI.getName());562break;563}564replaceOperation(*CastOp, VPI);565return CastOp;566}567568Value *569CachingVPExpander::expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,570VPIntrinsic &VPI) {571assert(VPI.canIgnoreVectorLengthParam());572573const auto &DL = F.getDataLayout();574575Value *MaskParam = VPI.getMaskParam();576Value *PtrParam = VPI.getMemoryPointerParam();577Value *DataParam = VPI.getMemoryDataParam();578bool IsUnmasked = isAllTrueMask(MaskParam);579580MaybeAlign AlignOpt = VPI.getPointerAlignment();581582Value *NewMemoryInst = nullptr;583switch (VPI.getIntrinsicID()) {584default:585llvm_unreachable("Not a VP memory intrinsic");586case Intrinsic::vp_store:587if (IsUnmasked) {588StoreInst *NewStore =589Builder.CreateStore(DataParam, PtrParam, /*IsVolatile*/ false);590if (AlignOpt.has_value())591NewStore->setAlignment(*AlignOpt);592NewMemoryInst = NewStore;593} else594NewMemoryInst = Builder.CreateMaskedStore(595DataParam, PtrParam, AlignOpt.valueOrOne(), MaskParam);596597break;598case Intrinsic::vp_load:599if (IsUnmasked) {600LoadInst *NewLoad =601Builder.CreateLoad(VPI.getType(), PtrParam, /*IsVolatile*/ false);602if (AlignOpt.has_value())603NewLoad->setAlignment(*AlignOpt);604NewMemoryInst = NewLoad;605} else606NewMemoryInst = Builder.CreateMaskedLoad(607VPI.getType(), PtrParam, AlignOpt.valueOrOne(), MaskParam);608609break;610case Intrinsic::vp_scatter: {611auto *ElementType =612cast<VectorType>(DataParam->getType())->getElementType();613NewMemoryInst = Builder.CreateMaskedScatter(614DataParam, PtrParam,615AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam);616break;617}618case Intrinsic::vp_gather: {619auto *ElementType = cast<VectorType>(VPI.getType())->getElementType();620NewMemoryInst = Builder.CreateMaskedGather(621VPI.getType(), PtrParam,622AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam, nullptr,623VPI.getName());624break;625}626}627628assert(NewMemoryInst);629replaceOperation(*NewMemoryInst, VPI);630return NewMemoryInst;631}632633Value *CachingVPExpander::expandPredicationInComparison(IRBuilder<> &Builder,634VPCmpIntrinsic &VPI) {635assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&636"Implicitly dropping %evl in non-speculatable operator!");637638assert(*VPI.getFunctionalOpcode() == Instruction::ICmp ||639*VPI.getFunctionalOpcode() == Instruction::FCmp);640641Value *Op0 = VPI.getOperand(0);642Value *Op1 = VPI.getOperand(1);643auto Pred = VPI.getPredicate();644645auto *NewCmp = Builder.CreateCmp(Pred, Op0, Op1);646647replaceOperation(*NewCmp, VPI);648return NewCmp;649}650651void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {652LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");653654if (VPI.canIgnoreVectorLengthParam())655return;656657Value *EVLParam = VPI.getVectorLengthParam();658if (!EVLParam)659return;660661ElementCount StaticElemCount = VPI.getStaticVectorLength();662Value *MaxEVL = nullptr;663Type *Int32Ty = Type::getInt32Ty(VPI.getContext());664if (StaticElemCount.isScalable()) {665// TODO add caching666auto *M = VPI.getModule();667Function *VScaleFunc =668Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);669IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());670Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());671Value *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");672MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",673/*NUW*/ true, /*NSW*/ false);674} else {675MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);676}677VPI.setVectorLengthParam(MaxEVL);678}679680Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {681LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');682683IRBuilder<> Builder(&VPI);684685// Ineffective %evl parameter and so nothing to do here.686if (VPI.canIgnoreVectorLengthParam())687return &VPI;688689// Only VP intrinsics can have an %evl parameter.690Value *OldMaskParam = VPI.getMaskParam();691Value *OldEVLParam = VPI.getVectorLengthParam();692assert(OldMaskParam && "no mask param to fold the vl param into");693assert(OldEVLParam && "no EVL param to fold away");694695LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');696LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');697698// Convert the %evl predication into vector mask predication.699ElementCount ElemCount = VPI.getStaticVectorLength();700Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);701Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);702VPI.setMaskParam(NewMaskParam);703704// Drop the %evl parameter.705discardEVLParameter(VPI);706assert(VPI.canIgnoreVectorLengthParam() &&707"transformation did not render the evl param ineffective!");708709// Reassess the modified instruction.710return &VPI;711}712713Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {714LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');715716IRBuilder<> Builder(&VPI);717718// Try lowering to a LLVM instruction first.719auto OC = VPI.getFunctionalOpcode();720721if (OC && Instruction::isBinaryOp(*OC))722return expandPredicationInBinaryOperator(Builder, VPI);723724if (auto *VPRI = dyn_cast<VPReductionIntrinsic>(&VPI))725return expandPredicationInReduction(Builder, *VPRI);726727if (auto *VPCmp = dyn_cast<VPCmpIntrinsic>(&VPI))728return expandPredicationInComparison(Builder, *VPCmp);729730if (VPCastIntrinsic::isVPCast(VPI.getIntrinsicID())) {731return expandPredicationToCastIntrinsic(Builder, VPI);732}733734switch (VPI.getIntrinsicID()) {735default:736break;737case Intrinsic::vp_fneg: {738Value *NewNegOp = Builder.CreateFNeg(VPI.getOperand(0), VPI.getName());739replaceOperation(*NewNegOp, VPI);740return NewNegOp;741}742case Intrinsic::vp_abs:743case Intrinsic::vp_smax:744case Intrinsic::vp_smin:745case Intrinsic::vp_umax:746case Intrinsic::vp_umin:747case Intrinsic::vp_bswap:748case Intrinsic::vp_bitreverse:749return expandPredicationToIntCall(Builder, VPI,750VPI.getFunctionalIntrinsicID().value());751case Intrinsic::vp_fabs:752case Intrinsic::vp_sqrt:753case Intrinsic::vp_maxnum:754case Intrinsic::vp_minnum:755case Intrinsic::vp_maximum:756case Intrinsic::vp_minimum:757case Intrinsic::vp_fma:758case Intrinsic::vp_fmuladd:759return expandPredicationToFPCall(Builder, VPI,760VPI.getFunctionalIntrinsicID().value());761case Intrinsic::vp_load:762case Intrinsic::vp_store:763case Intrinsic::vp_gather:764case Intrinsic::vp_scatter:765return expandPredicationInMemoryIntrinsic(Builder, VPI);766}767768if (auto CID = VPI.getConstrainedIntrinsicID())769if (Value *Call = expandPredicationToFPCall(Builder, VPI, *CID))770return Call;771772return &VPI;773}774775//// } CachingVPExpander776777struct TransformJob {778VPIntrinsic *PI;779TargetTransformInfo::VPLegalization Strategy;780TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)781: PI(PI), Strategy(InitStrat) {}782783bool isDone() const { return Strategy.shouldDoNothing(); }784};785786void sanitizeStrategy(VPIntrinsic &VPI, VPLegalization &LegalizeStrat) {787// Operations with speculatable lanes do not strictly need predication.788if (maySpeculateLanes(VPI)) {789// Converting a speculatable VP intrinsic means dropping %mask and %evl.790// No need to expand %evl into the %mask only to ignore that code.791if (LegalizeStrat.OpStrategy == VPLegalization::Convert)792LegalizeStrat.EVLParamStrategy = VPLegalization::Discard;793return;794}795796// We have to preserve the predicating effect of %evl for this797// non-speculatable VP intrinsic.798// 1) Never discard %evl.799// 2) If this VP intrinsic will be expanded to non-VP code, make sure that800// %evl gets folded into %mask.801if ((LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) ||802(LegalizeStrat.OpStrategy == VPLegalization::Convert)) {803LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;804}805}806807VPLegalization808CachingVPExpander::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {809auto VPStrat = TTI.getVPLegalizationStrategy(VPI);810if (LLVM_LIKELY(!UsingTTIOverrides)) {811// No overrides - we are in production.812return VPStrat;813}814815// Overrides set - we are in testing, the following does not need to be816// efficient.817VPStrat.EVLParamStrategy = parseOverrideOption(EVLTransformOverride);818VPStrat.OpStrategy = parseOverrideOption(MaskTransformOverride);819return VPStrat;820}821822/// Expand llvm.vp.* intrinsics as requested by \p TTI.823bool CachingVPExpander::expandVectorPredication() {824SmallVector<TransformJob, 16> Worklist;825826// Collect all VPIntrinsics that need expansion and determine their expansion827// strategy.828for (auto &I : instructions(F)) {829auto *VPI = dyn_cast<VPIntrinsic>(&I);830if (!VPI)831continue;832auto VPStrat = getVPLegalizationStrategy(*VPI);833sanitizeStrategy(*VPI, VPStrat);834if (!VPStrat.shouldDoNothing())835Worklist.emplace_back(VPI, VPStrat);836}837if (Worklist.empty())838return false;839840// Transform all VPIntrinsics on the worklist.841LLVM_DEBUG(dbgs() << "\n:::: Transforming " << Worklist.size()842<< " instructions ::::\n");843for (TransformJob Job : Worklist) {844// Transform the EVL parameter.845switch (Job.Strategy.EVLParamStrategy) {846case VPLegalization::Legal:847break;848case VPLegalization::Discard:849discardEVLParameter(*Job.PI);850break;851case VPLegalization::Convert:852if (foldEVLIntoMask(*Job.PI))853++NumFoldedVL;854break;855}856Job.Strategy.EVLParamStrategy = VPLegalization::Legal;857858// Replace with a non-predicated operation.859switch (Job.Strategy.OpStrategy) {860case VPLegalization::Legal:861break;862case VPLegalization::Discard:863llvm_unreachable("Invalid strategy for operators.");864case VPLegalization::Convert:865expandPredication(*Job.PI);866++NumLoweredVPOps;867break;868}869Job.Strategy.OpStrategy = VPLegalization::Legal;870871assert(Job.isDone() && "incomplete transformation");872}873874return true;875}876class ExpandVectorPredication : public FunctionPass {877public:878static char ID;879ExpandVectorPredication() : FunctionPass(ID) {880initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());881}882883bool runOnFunction(Function &F) override {884const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);885CachingVPExpander VPExpander(F, *TTI);886return VPExpander.expandVectorPredication();887}888889void getAnalysisUsage(AnalysisUsage &AU) const override {890AU.addRequired<TargetTransformInfoWrapperPass>();891AU.setPreservesCFG();892}893};894} // namespace895896char ExpandVectorPredication::ID;897INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expandvp",898"Expand vector predication intrinsics", false, false)899INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)900INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)901INITIALIZE_PASS_END(ExpandVectorPredication, "expandvp",902"Expand vector predication intrinsics", false, false)903904FunctionPass *llvm::createExpandVectorPredicationPass() {905return new ExpandVectorPredication();906}907908PreservedAnalyses909ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {910const auto &TTI = AM.getResult<TargetIRAnalysis>(F);911CachingVPExpander VPExpander(F, TTI);912if (!VPExpander.expandVectorPredication())913return PreservedAnalyses::all();914PreservedAnalyses PA;915PA.preserveSet<CFGAnalyses>();916return PA;917}918919920