// Path: blob/main/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
// 35269 views
//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This pass custom lowers llvm.gather and llvm.scatter instructions to9// RISC-V intrinsics.10//11//===----------------------------------------------------------------------===//1213#include "RISCV.h"14#include "RISCVTargetMachine.h"15#include "llvm/Analysis/InstSimplifyFolder.h"16#include "llvm/Analysis/LoopInfo.h"17#include "llvm/Analysis/ValueTracking.h"18#include "llvm/Analysis/VectorUtils.h"19#include "llvm/CodeGen/TargetPassConfig.h"20#include "llvm/IR/GetElementPtrTypeIterator.h"21#include "llvm/IR/IRBuilder.h"22#include "llvm/IR/IntrinsicInst.h"23#include "llvm/IR/IntrinsicsRISCV.h"24#include "llvm/IR/PatternMatch.h"25#include "llvm/Transforms/Utils/Local.h"26#include <optional>2728using namespace llvm;29using namespace PatternMatch;3031#define DEBUG_TYPE "riscv-gather-scatter-lowering"3233namespace {3435class RISCVGatherScatterLowering : public FunctionPass {36const RISCVSubtarget *ST = nullptr;37const RISCVTargetLowering *TLI = nullptr;38LoopInfo *LI = nullptr;39const DataLayout *DL = nullptr;4041SmallVector<WeakTrackingVH> MaybeDeadPHIs;4243// Cache of the BasePtr and Stride determined from this GEP. 
When a GEP is44// used by multiple gathers/scatters, this allow us to reuse the scalar45// instructions we created for the first gather/scatter for the others.46DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;4748public:49static char ID; // Pass identification, replacement for typeid5051RISCVGatherScatterLowering() : FunctionPass(ID) {}5253bool runOnFunction(Function &F) override;5455void getAnalysisUsage(AnalysisUsage &AU) const override {56AU.setPreservesCFG();57AU.addRequired<TargetPassConfig>();58AU.addRequired<LoopInfoWrapperPass>();59}6061StringRef getPassName() const override {62return "RISC-V gather/scatter lowering";63}6465private:66bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,67Value *AlignOp);6869std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,70IRBuilderBase &Builder);7172bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,73PHINode *&BasePtr, BinaryOperator *&Inc,74IRBuilderBase &Builder);75};7677} // end anonymous namespace7879char RISCVGatherScatterLowering::ID = 0;8081INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,82"RISC-V gather/scatter lowering pass", false, false)8384FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {85return new RISCVGatherScatterLowering();86}8788// TODO: Should we consider the mask when looking for a stride?89static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {90if (!isa<FixedVectorType>(StartC->getType()))91return std::make_pair(nullptr, nullptr);9293unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();9495// Check that the start value is a strided constant.96auto *StartVal =97dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));98if (!StartVal)99return std::make_pair(nullptr, nullptr);100APInt StrideVal(StartVal->getValue().getBitWidth(), 0);101ConstantInt *Prev = StartVal;102for (unsigned i = 1; i != NumElts; ++i) {103auto *C = 
dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));104if (!C)105return std::make_pair(nullptr, nullptr);106107APInt LocalStride = C->getValue() - Prev->getValue();108if (i == 1)109StrideVal = LocalStride;110else if (StrideVal != LocalStride)111return std::make_pair(nullptr, nullptr);112113Prev = C;114}115116Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);117118return std::make_pair(StartVal, Stride);119}120121static std::pair<Value *, Value *> matchStridedStart(Value *Start,122IRBuilderBase &Builder) {123// Base case, start is a strided constant.124auto *StartC = dyn_cast<Constant>(Start);125if (StartC)126return matchStridedConstant(StartC);127128// Base case, start is a stepvector129if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {130auto *Ty = Start->getType()->getScalarType();131return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));132}133134// Not a constant, maybe it's a strided constant with a splat added or135// multipled.136auto *BO = dyn_cast<BinaryOperator>(Start);137if (!BO || (BO->getOpcode() != Instruction::Add &&138BO->getOpcode() != Instruction::Or &&139BO->getOpcode() != Instruction::Shl &&140BO->getOpcode() != Instruction::Mul))141return std::make_pair(nullptr, nullptr);142143if (BO->getOpcode() == Instruction::Or &&144!cast<PossiblyDisjointInst>(BO)->isDisjoint())145return std::make_pair(nullptr, nullptr);146147// Look for an operand that is splatted.148unsigned OtherIndex = 0;149Value *Splat = getSplatValue(BO->getOperand(1));150if (!Splat && Instruction::isCommutative(BO->getOpcode())) {151Splat = getSplatValue(BO->getOperand(0));152OtherIndex = 1;153}154if (!Splat)155return std::make_pair(nullptr, nullptr);156157Value *Stride;158std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),159Builder);160if (!Start)161return std::make_pair(nullptr, nullptr);162163Builder.SetInsertPoint(BO);164Builder.SetCurrentDebugLocation(DebugLoc());165// Add the splat value to the 
start or multiply the start and stride by the166// splat.167switch (BO->getOpcode()) {168default:169llvm_unreachable("Unexpected opcode");170case Instruction::Or:171// TODO: We'd be better off creating disjoint or here, but we don't yet172// have an IRBuilder API for that.173[[fallthrough]];174case Instruction::Add:175Start = Builder.CreateAdd(Start, Splat);176break;177case Instruction::Mul:178Start = Builder.CreateMul(Start, Splat);179Stride = Builder.CreateMul(Stride, Splat);180break;181case Instruction::Shl:182Start = Builder.CreateShl(Start, Splat);183Stride = Builder.CreateShl(Stride, Splat);184break;185}186187return std::make_pair(Start, Stride);188}189190// Recursively, walk about the use-def chain until we find a Phi with a strided191// start value. Build and update a scalar recurrence as we unwind the recursion.192// We also update the Stride as we unwind. Our goal is to move all of the193// arithmetic out of the loop.194bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,195Value *&Stride,196PHINode *&BasePtr,197BinaryOperator *&Inc,198IRBuilderBase &Builder) {199// Our base case is a Phi.200if (auto *Phi = dyn_cast<PHINode>(Index)) {201// A phi node we want to perform this function on should be from the202// loop header.203if (Phi->getParent() != L->getHeader())204return false;205206Value *Step, *Start;207if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||208Inc->getOpcode() != Instruction::Add)209return false;210assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");211unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 
0 : 1;212assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&213"Expected one operand of phi to be Inc");214215// Only proceed if the step is loop invariant.216if (!L->isLoopInvariant(Step))217return false;218219// Step should be a splat.220Step = getSplatValue(Step);221if (!Step)222return false;223224std::tie(Start, Stride) = matchStridedStart(Start, Builder);225if (!Start)226return false;227assert(Stride != nullptr);228229// Build scalar phi and increment.230BasePtr =231PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi->getIterator());232Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",233Inc->getIterator());234BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));235BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));236237// Note that this Phi might be eligible for removal.238MaybeDeadPHIs.push_back(Phi);239return true;240}241242// Otherwise look for binary operator.243auto *BO = dyn_cast<BinaryOperator>(Index);244if (!BO)245return false;246247switch (BO->getOpcode()) {248default:249return false;250case Instruction::Or:251// We need to be able to treat Or as Add.252if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())253return false;254break;255case Instruction::Add:256break;257case Instruction::Shl:258break;259case Instruction::Mul:260break;261}262263// We should have one operand in the loop and one splat.264Value *OtherOp;265if (isa<Instruction>(BO->getOperand(0)) &&266L->contains(cast<Instruction>(BO->getOperand(0)))) {267Index = cast<Instruction>(BO->getOperand(0));268OtherOp = BO->getOperand(1);269} else if (isa<Instruction>(BO->getOperand(1)) &&270L->contains(cast<Instruction>(BO->getOperand(1))) &&271Instruction::isCommutative(BO->getOpcode())) {272Index = cast<Instruction>(BO->getOperand(1));273OtherOp = BO->getOperand(0);274} else {275return false;276}277278// Make sure other op is loop invariant.279if (!L->isLoopInvariant(OtherOp))280return false;281282// Make sure we 
have a splat.283Value *SplatOp = getSplatValue(OtherOp);284if (!SplatOp)285return false;286287// Recurse up the use-def chain.288if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))289return false;290291// Locate the Step and Start values from the recurrence.292unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;293unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;294Value *Step = Inc->getOperand(StepIndex);295Value *Start = BasePtr->getOperand(StartBlock);296297// We need to adjust the start value in the preheader.298Builder.SetInsertPoint(299BasePtr->getIncomingBlock(StartBlock)->getTerminator());300Builder.SetCurrentDebugLocation(DebugLoc());301302switch (BO->getOpcode()) {303default:304llvm_unreachable("Unexpected opcode!");305case Instruction::Add:306case Instruction::Or: {307// An add only affects the start value. It's ok to do this for Or because308// we already checked that there are no common set bits.309Start = Builder.CreateAdd(Start, SplatOp, "start");310break;311}312case Instruction::Mul: {313Start = Builder.CreateMul(Start, SplatOp, "start");314Step = Builder.CreateMul(Step, SplatOp, "step");315Stride = Builder.CreateMul(Stride, SplatOp, "stride");316break;317}318case Instruction::Shl: {319Start = Builder.CreateShl(Start, SplatOp, "start");320Step = Builder.CreateShl(Step, SplatOp, "step");321Stride = Builder.CreateShl(Stride, SplatOp, "stride");322break;323}324}325326Inc->setOperand(StepIndex, Step);327BasePtr->setIncomingValue(StartBlock, Start);328return true;329}330331std::pair<Value *, Value *>332RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,333IRBuilderBase &Builder) {334335// A gather/scatter of a splat is a zero strided load/store.336if (auto *BasePtr = getSplatValue(Ptr)) {337Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());338return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));339}340341auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);342if (!GEP)343return 
std::make_pair(nullptr, nullptr);344345auto I = StridedAddrs.find(GEP);346if (I != StridedAddrs.end())347return I->second;348349SmallVector<Value *, 2> Ops(GEP->operands());350351// If the base pointer is a vector, check if it's strided.352Value *Base = GEP->getPointerOperand();353if (auto *BaseInst = dyn_cast<Instruction>(Base);354BaseInst && BaseInst->getType()->isVectorTy()) {355// If GEP's offset is scalar then we can add it to the base pointer's base.356auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); };357if (all_of(GEP->indices(), IsScalar)) {358auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);359if (BaseBase) {360Builder.SetInsertPoint(GEP);361SmallVector<Value *> Indices(GEP->indices());362Value *OffsetBase =363Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices,364GEP->getName() + "offset", GEP->isInBounds());365return {OffsetBase, Stride};366}367}368}369370// Base pointer needs to be a scalar.371Value *ScalarBase = Base;372if (ScalarBase->getType()->isVectorTy()) {373ScalarBase = getSplatValue(ScalarBase);374if (!ScalarBase)375return std::make_pair(nullptr, nullptr);376}377378std::optional<unsigned> VecOperand;379unsigned TypeScale = 0;380381// Look for a vector operand and scale.382gep_type_iterator GTI = gep_type_begin(GEP);383for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {384if (!Ops[i]->getType()->isVectorTy())385continue;386387if (VecOperand)388return std::make_pair(nullptr, nullptr);389390VecOperand = i;391392TypeSize TS = GTI.getSequentialElementStride(*DL);393if (TS.isScalable())394return std::make_pair(nullptr, nullptr);395396TypeScale = TS.getFixedValue();397}398399// We need to find a vector index to simplify.400if (!VecOperand)401return std::make_pair(nullptr, nullptr);402403// We can't extract the stride if the arithmetic is done at a different size404// than the pointer type. 
Adding the stride later may not wrap correctly.405// Technically we could handle wider indices, but I don't expect that in406// practice. Handle one special case here - constants. This simplifies407// writing test cases.408Value *VecIndex = Ops[*VecOperand];409Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());410if (VecIndex->getType() != VecIntPtrTy) {411auto *VecIndexC = dyn_cast<Constant>(VecIndex);412if (!VecIndexC)413return std::make_pair(nullptr, nullptr);414if (VecIndex->getType()->getScalarSizeInBits() > VecIntPtrTy->getScalarSizeInBits())415VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC, VecIntPtrTy);416else417VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC, VecIntPtrTy);418}419420// Handle the non-recursive case. This is what we see if the vectorizer421// decides to use a scalar IV + vid on demand instead of a vector IV.422auto [Start, Stride] = matchStridedStart(VecIndex, Builder);423if (Start) {424assert(Stride);425Builder.SetInsertPoint(GEP);426427// Replace the vector index with the scalar start and build a scalar GEP.428Ops[*VecOperand] = Start;429Type *SourceTy = GEP->getSourceElementType();430Value *BasePtr =431Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());432433// Convert stride to pointer size if needed.434Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());435assert(Stride->getType() == IntPtrTy && "Unexpected type");436437// Scale the stride by the size of the indexed type.438if (TypeScale != 1)439Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));440441auto P = std::make_pair(BasePtr, Stride);442StridedAddrs[GEP] = P;443return P;444}445446// Make sure we're in a loop and that has a pre-header and a single latch.447Loop *L = LI->getLoopFor(GEP->getParent());448if (!L || !L->getLoopPreheader() || !L->getLoopLatch())449return std::make_pair(nullptr, nullptr);450451BinaryOperator *Inc;452PHINode *BasePhi;453if (!matchStridedRecurrence(VecIndex, L, Stride, 
BasePhi, Inc, Builder))454return std::make_pair(nullptr, nullptr);455456assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");457unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;458assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&459"Expected one operand of phi to be Inc");460461Builder.SetInsertPoint(GEP);462463// Replace the vector index with the scalar phi and build a scalar GEP.464Ops[*VecOperand] = BasePhi;465Type *SourceTy = GEP->getSourceElementType();466Value *BasePtr =467Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());468469// Final adjustments to stride should go in the start block.470Builder.SetInsertPoint(471BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());472473// Convert stride to pointer size if needed.474Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());475assert(Stride->getType() == IntPtrTy && "Unexpected type");476477// Scale the stride by the size of the indexed type.478if (TypeScale != 1)479Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));480481auto P = std::make_pair(BasePtr, Stride);482StridedAddrs[GEP] = P;483return P;484}485486bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,487Type *DataType,488Value *Ptr,489Value *AlignOp) {490// Make sure the operation will be supported by the backend.491MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();492EVT DataTypeVT = TLI->getValueType(*DL, DataType);493if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))494return false;495496// FIXME: Let the backend type legalize by splitting/widening?497if (!TLI->isTypeLegal(DataTypeVT))498return false;499500// Pointer should be an instruction.501auto *PtrI = dyn_cast<Instruction>(Ptr);502if (!PtrI)503return false;504505LLVMContext &Ctx = PtrI->getContext();506IRBuilder<InstSimplifyFolder> Builder(Ctx, *DL);507Builder.SetInsertPoint(PtrI);508509Value *BasePtr, *Stride;510std::tie(BasePtr, Stride) = 
determineBaseAndStride(PtrI, Builder);511if (!BasePtr)512return false;513assert(Stride != nullptr);514515Builder.SetInsertPoint(II);516517CallInst *Call;518if (II->getIntrinsicID() == Intrinsic::masked_gather)519Call = Builder.CreateIntrinsic(520Intrinsic::riscv_masked_strided_load,521{DataType, BasePtr->getType(), Stride->getType()},522{II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});523else524Call = Builder.CreateIntrinsic(525Intrinsic::riscv_masked_strided_store,526{DataType, BasePtr->getType(), Stride->getType()},527{II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});528529Call->takeName(II);530II->replaceAllUsesWith(Call);531II->eraseFromParent();532533if (PtrI->use_empty())534RecursivelyDeleteTriviallyDeadInstructions(PtrI);535536return true;537}538539bool RISCVGatherScatterLowering::runOnFunction(Function &F) {540if (skipFunction(F))541return false;542543auto &TPC = getAnalysis<TargetPassConfig>();544auto &TM = TPC.getTM<RISCVTargetMachine>();545ST = &TM.getSubtarget<RISCVSubtarget>(F);546if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())547return false;548549TLI = ST->getTargetLowering();550DL = &F.getDataLayout();551LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();552553StridedAddrs.clear();554555SmallVector<IntrinsicInst *, 4> Gathers;556SmallVector<IntrinsicInst *, 4> Scatters;557558bool Changed = false;559560for (BasicBlock &BB : F) {561for (Instruction &I : BB) {562IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);563if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {564Gathers.push_back(II);565} else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {566Scatters.push_back(II);567}568}569}570571// Rewrite gather/scatter to form strided load/store if possible.572for (auto *II : Gathers)573Changed |= tryCreateStridedLoadStore(574II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));575for (auto *II : Scatters)576Changed |=577tryCreateStridedLoadStore(II, 
II->getArgOperand(0)->getType(),578II->getArgOperand(1), II->getArgOperand(2));579580// Remove any dead phis.581while (!MaybeDeadPHIs.empty()) {582if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))583RecursivelyDeleteDeadPHINode(Phi);584}585586return Changed;587}588589590