Path: blob/main/contrib/llvm-project/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
213799 views
//===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===---------------------------------------------------------------------===//7///8/// \file This file contains a pass to flatten arrays for the DirectX Backend.9///10//===----------------------------------------------------------------------===//1112#include "DXILFlattenArrays.h"13#include "DirectX.h"14#include "llvm/ADT/PostOrderIterator.h"15#include "llvm/ADT/STLExtras.h"16#include "llvm/IR/BasicBlock.h"17#include "llvm/IR/DerivedTypes.h"18#include "llvm/IR/IRBuilder.h"19#include "llvm/IR/InstVisitor.h"20#include "llvm/IR/ReplaceConstant.h"21#include "llvm/Support/Casting.h"22#include "llvm/Support/MathExtras.h"23#include "llvm/Transforms/Utils/Local.h"24#include <cassert>25#include <cstddef>26#include <cstdint>27#include <utility>2829#define DEBUG_TYPE "dxil-flatten-arrays"3031using namespace llvm;32namespace {3334class DXILFlattenArraysLegacy : public ModulePass {3536public:37bool runOnModule(Module &M) override;38DXILFlattenArraysLegacy() : ModulePass(ID) {}3940static char ID; // Pass identification.41};4243struct GEPInfo {44ArrayType *RootFlattenedArrayType;45Value *RootPointerOperand;46SmallMapVector<Value *, APInt, 4> VariableOffsets;47APInt ConstantOffset;48};4950class DXILFlattenArraysVisitor51: public InstVisitor<DXILFlattenArraysVisitor, bool> {52public:53DXILFlattenArraysVisitor(54SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap)55: GlobalMap(GlobalMap) {}56bool visit(Function &F);57// InstVisitor methods. They return true if the instruction was scalarized,58// false if nothing changed.59bool visitGetElementPtrInst(GetElementPtrInst &GEPI);60bool visitAllocaInst(AllocaInst &AI);61bool visitInstruction(Instruction &I) { return false; }62bool visitSelectInst(SelectInst &SI) { return false; }63bool visitICmpInst(ICmpInst &ICI) { return false; }64bool visitFCmpInst(FCmpInst &FCI) { return false; }65bool visitUnaryOperator(UnaryOperator &UO) { return false; }66bool visitBinaryOperator(BinaryOperator &BO) { return false; }67bool visitCastInst(CastInst &CI) { return false; }68bool visitBitCastInst(BitCastInst &BCI) { return false; }69bool visitInsertElementInst(InsertElementInst &IEI) { return false; }70bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }71bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }72bool visitPHINode(PHINode &PHI) { return false; }73bool visitLoadInst(LoadInst &LI);74bool visitStoreInst(StoreInst &SI);75bool visitCallInst(CallInst &ICI) { return false; }76bool visitFreezeInst(FreezeInst &FI) { return false; }77static bool isMultiDimensionalArray(Type *T);78static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy);7980private:81SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;82SmallDenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap;83SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap;84bool finish();85ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,86ArrayRef<uint64_t> Dims,87IRBuilder<> &Builder);88Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,89ArrayRef<uint64_t> Dims,90IRBuilder<> &Builder);91};92} // namespace9394bool DXILFlattenArraysVisitor::finish() {95GEPChainInfoMap.clear();96RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);97return true;98}99100bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {101if (ArrayType *ArrType = dyn_cast<ArrayType>(T))102return isa<ArrayType>(ArrType->getElementType());103return false;104}105106std::pair<unsigned, Type *>107DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) {108unsigned TotalElements = 1;109Type *CurrArrayTy = ArrayTy;110while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {111TotalElements *= InnerArrayTy->getNumElements();112CurrArrayTy = InnerArrayTy->getElementType();113}114return std::make_pair(TotalElements, CurrArrayTy);115}116117ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(118ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {119assert(Indices.size() == Dims.size() &&120"Indicies and dimmensions should be the same");121unsigned FlatIndex = 0;122unsigned Multiplier = 1;123124for (int I = Indices.size() - 1; I >= 0; --I) {125unsigned DimSize = Dims[I];126ConstantInt *CIndex = dyn_cast<ConstantInt>(Indices[I]);127assert(CIndex && "This function expects all indicies to be ConstantInt");128FlatIndex += CIndex->getZExtValue() * Multiplier;129Multiplier *= DimSize;130}131return Builder.getInt32(FlatIndex);132}133134Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(135ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {136if (Indices.size() == 1)137return Indices[0];138139Value *FlatIndex = Builder.getInt32(0);140unsigned Multiplier = 1;141142for (int I = Indices.size() - 1; I >= 0; --I) {143unsigned DimSize = Dims[I];144Value *VMultiplier = Builder.getInt32(Multiplier);145Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier);146FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex);147Multiplier *= DimSize;148}149return FlatIndex;150}151152bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {153unsigned NumOperands = LI.getNumOperands();154for (unsigned I = 0; I < NumOperands; ++I) {155Value *CurrOpperand = LI.getOperand(I);156ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);157if (CE && CE->getOpcode() == Instruction::GetElementPtr) {158GetElementPtrInst *OldGEP =159cast<GetElementPtrInst>(CE->getAsInstruction());160OldGEP->insertBefore(LI.getIterator());161162IRBuilder<> Builder(&LI);163LoadInst *NewLoad =164Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());165NewLoad->setAlignment(LI.getAlign());166LI.replaceAllUsesWith(NewLoad);167LI.eraseFromParent();168visitGetElementPtrInst(*OldGEP);169return true;170}171}172return false;173}174175bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {176unsigned NumOperands = SI.getNumOperands();177for (unsigned I = 0; I < NumOperands; ++I) {178Value *CurrOpperand = SI.getOperand(I);179ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);180if (CE && CE->getOpcode() == Instruction::GetElementPtr) {181GetElementPtrInst *OldGEP =182cast<GetElementPtrInst>(CE->getAsInstruction());183OldGEP->insertBefore(SI.getIterator());184185IRBuilder<> Builder(&SI);186StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);187NewStore->setAlignment(SI.getAlign());188SI.replaceAllUsesWith(NewStore);189SI.eraseFromParent();190visitGetElementPtrInst(*OldGEP);191return true;192}193}194return false;195}196197bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {198if (!isMultiDimensionalArray(AI.getAllocatedType()))199return false;200201ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());202IRBuilder<> Builder(&AI);203auto [TotalElements, BaseType] = getElementCountAndType(ArrType);204205ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);206AllocaInst *FlatAlloca =207Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".1dim");208FlatAlloca->setAlignment(AI.getAlign());209AI.replaceAllUsesWith(FlatAlloca);210AI.eraseFromParent();211return true;212}213214bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {215// Do not visit GEPs more than once216if (GEPChainInfoMap.contains(cast<GEPOperator>(&GEP)))217return false;218219Value *PtrOperand = GEP.getPointerOperand();220// It shouldn't(?) be possible for the pointer operand of a GEP to be a PHI221// node unless HLSL has pointers. If this assumption is incorrect or HLSL gets222// pointer types, then the handling of this case can be implemented later.223assert(!isa<PHINode>(PtrOperand) &&224"Pointer operand of GEP should not be a PHI Node");225226// Replace a GEP ConstantExpr pointer operand with a GEP instruction so that227// it can be visited228if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand);229PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {230GetElementPtrInst *OldGEPI =231cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction());232OldGEPI->insertBefore(GEP.getIterator());233234IRBuilder<> Builder(&GEP);235SmallVector<Value *> Indices(GEP.indices());236Value *NewGEP =237Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices,238GEP.getName(), GEP.getNoWrapFlags());239assert(isa<GetElementPtrInst>(NewGEP) &&240"Expected newly-created GEP to be an instruction");241GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);242243GEP.replaceAllUsesWith(NewGEPI);244GEP.eraseFromParent();245visitGetElementPtrInst(*OldGEPI);246visitGetElementPtrInst(*NewGEPI);247return true;248}249250// Construct GEPInfo for this GEP251GEPInfo Info;252253// Obtain the variable and constant byte offsets computed by this GEP254const DataLayout &DL = GEP.getDataLayout();255unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType());256Info.ConstantOffset = {BitWidth, 0};257[[maybe_unused]] bool Success = GEP.collectOffset(258DL, BitWidth, Info.VariableOffsets, Info.ConstantOffset);259assert(Success && "Failed to collect offsets for GEP");260261// If there is a parent GEP, inherit the root array type and pointer, and262// merge the byte offsets. Otherwise, this GEP is itself the root of a GEP263// chain and we need to deterine the root array type264if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {265assert(GEPChainInfoMap.contains(PtrOpGEP) &&266"Expected parent GEP to be visited before this GEP");267GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];268Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;269Info.RootPointerOperand = PGEPInfo.RootPointerOperand;270for (auto &VariableOffset : PGEPInfo.VariableOffsets)271Info.VariableOffsets.insert(VariableOffset);272Info.ConstantOffset += PGEPInfo.ConstantOffset;273} else {274Info.RootPointerOperand = PtrOperand;275276// We should try to determine the type of the root from the pointer rather277// than the GEP's source element type because this could be a scalar GEP278// into an array-typed pointer from an Alloca or Global Variable.279Type *RootTy = GEP.getSourceElementType();280if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {281if (GlobalMap.contains(GlobalVar))282GlobalVar = GlobalMap[GlobalVar];283Info.RootPointerOperand = GlobalVar;284RootTy = GlobalVar->getValueType();285} else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand))286RootTy = Alloca->getAllocatedType();287assert(!isMultiDimensionalArray(RootTy) &&288"Expected root array type to be flattened");289290// If the root type is not an array, we don't need to do any flattening291if (!isa<ArrayType>(RootTy))292return false;293294Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);295}296297// GEPs without users or GEPs with non-GEP users should be replaced such that298// the chain of GEPs they are a part of are collapsed to a single GEP into a299// flattened array.300bool ReplaceThisGEP = GEP.users().empty();301for (Value *User : GEP.users())302if (!isa<GetElementPtrInst>(User))303ReplaceThisGEP = true;304305if (ReplaceThisGEP) {306unsigned BytesPerElem =307DL.getTypeAllocSize(Info.RootFlattenedArrayType->getArrayElementType());308assert(isPowerOf2_32(BytesPerElem) &&309"Bytes per element should be a power of 2");310311// Compute the 32-bit index for this flattened GEP from the constant and312// variable byte offsets in the GEPInfo313IRBuilder<> Builder(&GEP);314Value *ZeroIndex = Builder.getInt32(0);315uint64_t ConstantOffset =316Info.ConstantOffset.udiv(BytesPerElem).getZExtValue();317assert(ConstantOffset < UINT32_MAX &&318"Constant byte offset for flat GEP index must fit within 32 bits");319Value *FlattenedIndex = Builder.getInt32(ConstantOffset);320for (auto [VarIndex, Multiplier] : Info.VariableOffsets) {321assert(Multiplier.getActiveBits() <= 32 &&322"The multiplier for a flat GEP index must fit within 32 bits");323assert(VarIndex->getType()->isIntegerTy(32) &&324"Expected i32-typed GEP indices");325Value *VI;326if (Multiplier.getZExtValue() % BytesPerElem != 0) {327// This can happen, e.g., with i8 GEPs. To handle this we just divide328// by BytesPerElem using an instruction after multiplying VarIndex by329// Multiplier.330VI = Builder.CreateMul(VarIndex,331Builder.getInt32(Multiplier.getZExtValue()));332VI = Builder.CreateLShr(VI, Builder.getInt32(Log2_32(BytesPerElem)));333} else334VI = Builder.CreateMul(335VarIndex,336Builder.getInt32(Multiplier.getZExtValue() / BytesPerElem));337FlattenedIndex = Builder.CreateAdd(FlattenedIndex, VI);338}339340// Construct a new GEP for the flattened array to replace the current GEP341Value *NewGEP = Builder.CreateGEP(342Info.RootFlattenedArrayType, Info.RootPointerOperand,343{ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags());344345// Replace the current GEP with the new GEP. Store GEPInfo into the map346// for later use in case this GEP was not the end of the chain347GEPChainInfoMap.insert({cast<GEPOperator>(NewGEP), std::move(Info)});348GEP.replaceAllUsesWith(NewGEP);349GEP.eraseFromParent();350return true;351}352353// This GEP is potentially dead at the end of the pass since it may not have354// any users anymore after GEP chains have been collapsed. We retain store355// GEPInfo for GEPs down the chain to use to compute their indices.356GEPChainInfoMap.insert({cast<GEPOperator>(&GEP), std::move(Info)});357PotentiallyDeadInstrs.emplace_back(&GEP);358return false;359}360361bool DXILFlattenArraysVisitor::visit(Function &F) {362bool MadeChange = false;363ReversePostOrderTraversal<Function *> RPOT(&F);364for (BasicBlock *BB : make_early_inc_range(RPOT)) {365for (Instruction &I : make_early_inc_range(*BB))366MadeChange |= InstVisitor::visit(I);367}368finish();369return MadeChange;370}371372static void collectElements(Constant *Init,373SmallVectorImpl<Constant *> &Elements) {374// Base case: If Init is not an array, add it directly to the vector.375auto *ArrayTy = dyn_cast<ArrayType>(Init->getType());376if (!ArrayTy) {377Elements.push_back(Init);378return;379}380unsigned ArrSize = ArrayTy->getNumElements();381if (isa<ConstantAggregateZero>(Init)) {382for (unsigned I = 0; I < ArrSize; ++I)383Elements.push_back(Constant::getNullValue(ArrayTy->getElementType()));384return;385}386387// Recursive case: Process each element in the array.388if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) {389for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) {390collectElements(ArrayConstant->getOperand(I), Elements);391}392} else if (auto *DataArrayConstant = dyn_cast<ConstantDataArray>(Init)) {393for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) {394collectElements(DataArrayConstant->getElementAsConstant(I), Elements);395}396} else {397llvm_unreachable(398"Expected a ConstantArray or ConstantDataArray for array initializer!");399}400}401402static Constant *transformInitializer(Constant *Init, Type *OrigType,403ArrayType *FlattenedType,404LLVMContext &Ctx) {405// Handle ConstantAggregateZero (zero-initialized constants)406if (isa<ConstantAggregateZero>(Init))407return ConstantAggregateZero::get(FlattenedType);408409// Handle UndefValue (undefined constants)410if (isa<UndefValue>(Init))411return UndefValue::get(FlattenedType);412413if (!isa<ArrayType>(OrigType))414return Init;415416SmallVector<Constant *> FlattenedElements;417collectElements(Init, FlattenedElements);418assert(FlattenedType->getNumElements() == FlattenedElements.size() &&419"The number of collected elements should match the FlattenedType");420return ConstantArray::get(FlattenedType, FlattenedElements);421}422423static void flattenGlobalArrays(424Module &M, SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {425LLVMContext &Ctx = M.getContext();426for (GlobalVariable &G : M.globals()) {427Type *OrigType = G.getValueType();428if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))429continue;430431ArrayType *ArrType = cast<ArrayType>(OrigType);432auto [TotalElements, BaseType] =433DXILFlattenArraysVisitor::getElementCountAndType(ArrType);434ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);435436// Create a new global variable with the updated type437// Note: Initializer is set via transformInitializer438GlobalVariable *NewGlobal =439new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(),440/*Initializer=*/nullptr, G.getName() + ".1dim", &G,441G.getThreadLocalMode(), G.getAddressSpace(),442G.isExternallyInitialized());443444// Copy relevant attributes445NewGlobal->setUnnamedAddr(G.getUnnamedAddr());446if (G.getAlignment() > 0) {447NewGlobal->setAlignment(G.getAlign());448}449450if (G.hasInitializer()) {451Constant *Init = G.getInitializer();452Constant *NewInit =453transformInitializer(Init, OrigType, FattenedArrayType, Ctx);454NewGlobal->setInitializer(NewInit);455}456GlobalMap[&G] = NewGlobal;457}458}459460static bool flattenArrays(Module &M) {461bool MadeChange = false;462SmallDenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;463flattenGlobalArrays(M, GlobalMap);464DXILFlattenArraysVisitor Impl(GlobalMap);465for (auto &F : make_early_inc_range(M.functions())) {466if (F.isDeclaration())467continue;468MadeChange |= Impl.visit(F);469}470for (auto &[Old, New] : GlobalMap) {471Old->replaceAllUsesWith(New);472Old->eraseFromParent();473MadeChange = true;474}475return MadeChange;476}477478PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) {479bool MadeChanges = flattenArrays(M);480if (!MadeChanges)481return PreservedAnalyses::all();482PreservedAnalyses PA;483return PA;484}485486bool DXILFlattenArraysLegacy::runOnModule(Module &M) {487return flattenArrays(M);488}489490char DXILFlattenArraysLegacy::ID = 0;491492INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE,493"DXIL Array Flattener", false, false)494INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener",495false, false)496497ModulePass *llvm::createDXILFlattenArraysLegacyPass() {498return new DXILFlattenArraysLegacy();499}500501502