Path: blob/main/contrib/llvm-project/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
213799 views
//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===---------------------------------------------------------------------===//78#include "DXILDataScalarization.h"9#include "DirectX.h"10#include "llvm/ADT/PostOrderIterator.h"11#include "llvm/ADT/STLExtras.h"12#include "llvm/IR/DerivedTypes.h"13#include "llvm/IR/GlobalVariable.h"14#include "llvm/IR/IRBuilder.h"15#include "llvm/IR/InstVisitor.h"16#include "llvm/IR/Instructions.h"17#include "llvm/IR/Module.h"18#include "llvm/IR/Operator.h"19#include "llvm/IR/PassManager.h"20#include "llvm/IR/ReplaceConstant.h"21#include "llvm/IR/Type.h"22#include "llvm/Support/Casting.h"23#include "llvm/Transforms/Utils/Cloning.h"24#include "llvm/Transforms/Utils/Local.h"2526#define DEBUG_TYPE "dxil-data-scalarization"27static const int MaxVecSize = 4;2829using namespace llvm;3031// Recursively creates an array-like version of a given vector type.32static Type *equivalentArrayTypeFromVector(Type *T) {33if (auto *VecTy = dyn_cast<VectorType>(T))34return ArrayType::get(VecTy->getElementType(),35dyn_cast<FixedVectorType>(VecTy)->getNumElements());36if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {37Type *NewElementType =38equivalentArrayTypeFromVector(ArrayTy->getElementType());39return ArrayType::get(NewElementType, ArrayTy->getNumElements());40}41// If it's not a vector or array, return the original type.42return T;43}4445class DXILDataScalarizationLegacy : public ModulePass {4647public:48bool runOnModule(Module &M) override;49DXILDataScalarizationLegacy() : ModulePass(ID) {}5051static char ID; // Pass identification.52};5354static bool findAndReplaceVectors(Module &M);5556class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {57public:58DataScalarizerVisitor() : GlobalMap() {}59bool visit(Function &F);60// InstVisitor methods. They return true if the instruction was scalarized,61// false if nothing changed.62bool visitAllocaInst(AllocaInst &AI);63bool visitInstruction(Instruction &I) { return false; }64bool visitSelectInst(SelectInst &SI) { return false; }65bool visitICmpInst(ICmpInst &ICI) { return false; }66bool visitFCmpInst(FCmpInst &FCI) { return false; }67bool visitUnaryOperator(UnaryOperator &UO) { return false; }68bool visitBinaryOperator(BinaryOperator &BO) { return false; }69bool visitGetElementPtrInst(GetElementPtrInst &GEPI);70bool visitCastInst(CastInst &CI) { return false; }71bool visitBitCastInst(BitCastInst &BCI) { return false; }72bool visitInsertElementInst(InsertElementInst &IEI);73bool visitExtractElementInst(ExtractElementInst &EEI);74bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }75bool visitPHINode(PHINode &PHI) { return false; }76bool visitLoadInst(LoadInst &LI);77bool visitStoreInst(StoreInst &SI);78bool visitCallInst(CallInst &ICI) { return false; }79bool visitFreezeInst(FreezeInst &FI) { return false; }80friend bool findAndReplaceVectors(llvm::Module &M);8182private:83typedef std::pair<AllocaInst *, SmallVector<Value *, 4>> AllocaAndGEPs;84typedef SmallDenseMap<Value *, AllocaAndGEPs>85VectorToArrayMap; // A map from a vector-typed Value to its corresponding86// AllocaInst and GEPs to each element of an array87VectorToArrayMap VectorAllocaMap;88AllocaAndGEPs createArrayFromVector(IRBuilder<> &Builder, Value *Vec,89const Twine &Name);90bool replaceDynamicInsertElementInst(InsertElementInst &IEI);91bool replaceDynamicExtractElementInst(ExtractElementInst &EEI);9293GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);94DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;95};9697bool DataScalarizerVisitor::visit(Function &F) {98bool MadeChange = false;99ReversePostOrderTraversal<Function *> RPOT(&F);100for (BasicBlock *BB : make_early_inc_range(RPOT)) {101for (Instruction &I : make_early_inc_range(*BB))102MadeChange |= InstVisitor::visit(I);103}104VectorAllocaMap.clear();105return MadeChange;106}107108GlobalVariable *109DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {110if (GlobalVariable *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand)) {111auto It = GlobalMap.find(OldGlobal);112if (It != GlobalMap.end()) {113return It->second; // Found, return the new global114}115}116return nullptr; // Not found117}118119// Helper function to check if a type is a vector or an array of vectors120static bool isVectorOrArrayOfVectors(Type *T) {121if (isa<VectorType>(T))122return true;123if (ArrayType *ArrType = dyn_cast<ArrayType>(T))124return isa<VectorType>(ArrType->getElementType()) ||125isVectorOrArrayOfVectors(ArrType->getElementType());126return false;127}128129bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {130Type *AllocatedType = AI.getAllocatedType();131if (!isVectorOrArrayOfVectors(AllocatedType))132return false;133134IRBuilder<> Builder(&AI);135Type *NewType = equivalentArrayTypeFromVector(AllocatedType);136AllocaInst *ArrAlloca =137Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize");138ArrAlloca->setAlignment(AI.getAlign());139AI.replaceAllUsesWith(ArrAlloca);140AI.eraseFromParent();141return true;142}143144bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {145Value *PtrOperand = LI.getPointerOperand();146ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);147if (CE && CE->getOpcode() == Instruction::GetElementPtr) {148GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());149OldGEP->insertBefore(LI.getIterator());150IRBuilder<> Builder(&LI);151LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());152NewLoad->setAlignment(LI.getAlign());153LI.replaceAllUsesWith(NewLoad);154LI.eraseFromParent();155visitGetElementPtrInst(*OldGEP);156return true;157}158if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))159LI.setOperand(LI.getPointerOperandIndex(), NewGlobal);160return false;161}162163bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {164165Value *PtrOperand = SI.getPointerOperand();166ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);167if (CE && CE->getOpcode() == Instruction::GetElementPtr) {168GetElementPtrInst *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());169OldGEP->insertBefore(SI.getIterator());170IRBuilder<> Builder(&SI);171StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);172NewStore->setAlignment(SI.getAlign());173SI.replaceAllUsesWith(NewStore);174SI.eraseFromParent();175visitGetElementPtrInst(*OldGEP);176return true;177}178if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand))179SI.setOperand(SI.getPointerOperandIndex(), NewGlobal);180181return false;182}183184DataScalarizerVisitor::AllocaAndGEPs185DataScalarizerVisitor::createArrayFromVector(IRBuilder<> &Builder, Value *Vec,186const Twine &Name = "") {187// If there is already an alloca for this vector, return it188if (VectorAllocaMap.contains(Vec))189return VectorAllocaMap[Vec];190191auto InsertPoint = Builder.GetInsertPoint();192193// Allocate the array to hold the vector elements194Builder.SetInsertPointPastAllocas(Builder.GetInsertBlock()->getParent());195Type *ArrTy = equivalentArrayTypeFromVector(Vec->getType());196AllocaInst *ArrAlloca =197Builder.CreateAlloca(ArrTy, nullptr, Name + ".alloca");198const uint64_t ArrNumElems = ArrTy->getArrayNumElements();199200// Create loads and stores to populate the array immediately after the201// original vector's defining instruction if available, else immediately after202// the alloca203if (auto *Instr = dyn_cast<Instruction>(Vec))204Builder.SetInsertPoint(Instr->getNextNonDebugInstruction());205SmallVector<Value *, 4> GEPs(ArrNumElems);206for (unsigned I = 0; I < ArrNumElems; ++I) {207Value *EE = Builder.CreateExtractElement(Vec, I, Name + ".extract");208GEPs[I] = Builder.CreateInBoundsGEP(209ArrTy, ArrAlloca, {Builder.getInt32(0), Builder.getInt32(I)},210Name + ".index");211Builder.CreateStore(EE, GEPs[I]);212}213214VectorAllocaMap.insert({Vec, {ArrAlloca, GEPs}});215Builder.SetInsertPoint(InsertPoint);216return {ArrAlloca, GEPs};217}218219/// Returns a pair of Value* with the first being a GEP into ArrAlloca using220/// indices {0, Index}, and the second Value* being a Load of the GEP221static std::pair<Value *, Value *>222dynamicallyLoadArray(IRBuilder<> &Builder, AllocaInst *ArrAlloca, Value *Index,223const Twine &Name = "") {224Type *ArrTy = ArrAlloca->getAllocatedType();225Value *GEP = Builder.CreateInBoundsGEP(226ArrTy, ArrAlloca, {Builder.getInt32(0), Index}, Name + ".index");227Value *Load =228Builder.CreateLoad(ArrTy->getArrayElementType(), GEP, Name + ".load");229return std::make_pair(GEP, Load);230}231232bool DataScalarizerVisitor::replaceDynamicInsertElementInst(233InsertElementInst &IEI) {234IRBuilder<> Builder(&IEI);235236Value *Vec = IEI.getOperand(0);237Value *Val = IEI.getOperand(1);238Value *Index = IEI.getOperand(2);239240AllocaAndGEPs ArrAllocaAndGEPs =241createArrayFromVector(Builder, Vec, IEI.getName());242AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;243Type *ArrTy = ArrAlloca->getAllocatedType();244SmallVector<Value *, 4> &ArrGEPs = ArrAllocaAndGEPs.second;245246auto GEPAndLoad =247dynamicallyLoadArray(Builder, ArrAlloca, Index, IEI.getName());248Value *GEP = GEPAndLoad.first;249Value *Load = GEPAndLoad.second;250251Builder.CreateStore(Val, GEP);252Value *NewIEI = PoisonValue::get(Vec->getType());253for (unsigned I = 0; I < ArrTy->getArrayNumElements(); ++I) {254Value *Load = Builder.CreateLoad(ArrTy->getArrayElementType(), ArrGEPs[I],255IEI.getName() + ".load");256NewIEI = Builder.CreateInsertElement(NewIEI, Load, Builder.getInt32(I),257IEI.getName() + ".insert");258}259260// Store back the original value so the Alloca can be reused for subsequent261// insertelement instructions on the same vector262Builder.CreateStore(Load, GEP);263264IEI.replaceAllUsesWith(NewIEI);265IEI.eraseFromParent();266return true;267}268269bool DataScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {270// If the index is a constant then we don't need to scalarize it271Value *Index = IEI.getOperand(2);272if (isa<ConstantInt>(Index))273return false;274return replaceDynamicInsertElementInst(IEI);275}276277bool DataScalarizerVisitor::replaceDynamicExtractElementInst(278ExtractElementInst &EEI) {279IRBuilder<> Builder(&EEI);280281AllocaAndGEPs ArrAllocaAndGEPs =282createArrayFromVector(Builder, EEI.getVectorOperand(), EEI.getName());283AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;284285auto GEPAndLoad = dynamicallyLoadArray(Builder, ArrAlloca,286EEI.getIndexOperand(), EEI.getName());287Value *Load = GEPAndLoad.second;288289EEI.replaceAllUsesWith(Load);290EEI.eraseFromParent();291return true;292}293294bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {295// If the index is a constant then we don't need to scalarize it296Value *Index = EEI.getIndexOperand();297if (isa<ConstantInt>(Index))298return false;299return replaceDynamicExtractElementInst(EEI);300}301302bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {303Value *PtrOperand = GEPI.getPointerOperand();304Type *OrigGEPType = GEPI.getSourceElementType();305Type *NewGEPType = OrigGEPType;306bool NeedsTransform = false;307308if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {309NewGEPType = NewGlobal->getValueType();310PtrOperand = NewGlobal;311NeedsTransform = true;312} else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {313Type *AllocatedType = Alloca->getAllocatedType();314// Only transform if the allocated type is an array315if (AllocatedType != OrigGEPType && isa<ArrayType>(AllocatedType)) {316NewGEPType = AllocatedType;317NeedsTransform = true;318}319}320321// Scalar geps should remain scalars geps. The dxil-flatten-arrays pass will322// convert these scalar geps into flattened array geps323if (!isa<ArrayType>(OrigGEPType))324NewGEPType = OrigGEPType;325326// Note: We bail if this isn't a gep touched via alloca or global327// transformations328if (!NeedsTransform)329return false;330331IRBuilder<> Builder(&GEPI);332SmallVector<Value *, MaxVecSize> Indices(GEPI.indices());333334Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices,335GEPI.getName(), GEPI.getNoWrapFlags());336GEPI.replaceAllUsesWith(NewGEP);337GEPI.eraseFromParent();338return true;339}340341static Constant *transformInitializer(Constant *Init, Type *OrigType,342Type *NewType, LLVMContext &Ctx) {343// Handle ConstantAggregateZero (zero-initialized constants)344if (isa<ConstantAggregateZero>(Init)) {345return ConstantAggregateZero::get(NewType);346}347348// Handle UndefValue (undefined constants)349if (isa<UndefValue>(Init)) {350return UndefValue::get(NewType);351}352353// Handle vector to array transformation354if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {355// Convert vector initializer to array initializer356SmallVector<Constant *, MaxVecSize> ArrayElements;357if (ConstantVector *ConstVecInit = dyn_cast<ConstantVector>(Init)) {358for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)359ArrayElements.push_back(ConstVecInit->getOperand(I));360} else if (ConstantDataVector *ConstDataVecInit =361llvm::dyn_cast<llvm::ConstantDataVector>(Init)) {362for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)363ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));364} else {365assert(false && "Expected a ConstantVector or ConstantDataVector for "366"vector initializer!");367}368369return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);370}371372// Handle array of vectors transformation373if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {374auto *ArrayInit = dyn_cast<ConstantArray>(Init);375assert(ArrayInit && "Expected a ConstantArray for array initializer!");376377SmallVector<Constant *, MaxVecSize> NewArrayElements;378for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I) {379// Recursively transform array elements380Constant *NewElemInit = transformInitializer(381ArrayInit->getOperand(I), ArrayTy->getElementType(),382cast<ArrayType>(NewType)->getElementType(), Ctx);383NewArrayElements.push_back(NewElemInit);384}385386return ConstantArray::get(cast<ArrayType>(NewType), NewArrayElements);387}388389// If not a vector or array, return the original initializer390return Init;391}392393static bool findAndReplaceVectors(Module &M) {394bool MadeChange = false;395LLVMContext &Ctx = M.getContext();396IRBuilder<> Builder(Ctx);397DataScalarizerVisitor Impl;398for (GlobalVariable &G : M.globals()) {399Type *OrigType = G.getValueType();400401Type *NewType = equivalentArrayTypeFromVector(OrigType);402if (OrigType != NewType) {403// Create a new global variable with the updated type404// Note: Initializer is set via transformInitializer405GlobalVariable *NewGlobal = new GlobalVariable(406M, NewType, G.isConstant(), G.getLinkage(),407/*Initializer=*/nullptr, G.getName() + ".scalarized", &G,408G.getThreadLocalMode(), G.getAddressSpace(),409G.isExternallyInitialized());410411// Copy relevant attributes412NewGlobal->setUnnamedAddr(G.getUnnamedAddr());413if (G.getAlignment() > 0) {414NewGlobal->setAlignment(G.getAlign());415}416417if (G.hasInitializer()) {418Constant *Init = G.getInitializer();419Constant *NewInit = transformInitializer(Init, OrigType, NewType, Ctx);420NewGlobal->setInitializer(NewInit);421}422423// Note: we want to do G.replaceAllUsesWith(NewGlobal);, but it assumes424// type equality. Instead we will use the visitor pattern.425Impl.GlobalMap[&G] = NewGlobal;426}427}428429for (auto &F : make_early_inc_range(M.functions())) {430if (F.isDeclaration())431continue;432MadeChange |= Impl.visit(F);433}434435// Remove the old globals after the iteration436for (auto &[Old, New] : Impl.GlobalMap) {437Old->eraseFromParent();438MadeChange = true;439}440return MadeChange;441}442443PreservedAnalyses DXILDataScalarization::run(Module &M,444ModuleAnalysisManager &) {445bool MadeChanges = findAndReplaceVectors(M);446if (!MadeChanges)447return PreservedAnalyses::all();448PreservedAnalyses PA;449return PA;450}451452bool DXILDataScalarizationLegacy::runOnModule(Module &M) {453return findAndReplaceVectors(M);454}455456char DXILDataScalarizationLegacy::ID = 0;457458INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE,459"DXIL Data Scalarization", false, false)460INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE,461"DXIL Data Scalarization", false, false)462463ModulePass *llvm::createDXILDataScalarizationLegacyPass() {464return new DXILDataScalarizationLegacy();465}466467468