// Path: blob/main/contrib/llvm-project/llvm/lib/Target/DirectX/DXILCBufferAccess.cpp
// 213799 views
//===- DXILCBufferAccess.cpp - Translate CBuffer Loads --------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//78#include "DXILCBufferAccess.h"9#include "DirectX.h"10#include "llvm/Frontend/HLSL/CBuffer.h"11#include "llvm/Frontend/HLSL/HLSLResource.h"12#include "llvm/IR/IRBuilder.h"13#include "llvm/IR/IntrinsicInst.h"14#include "llvm/IR/IntrinsicsDirectX.h"15#include "llvm/InitializePasses.h"16#include "llvm/Pass.h"17#include "llvm/Support/FormatVariadic.h"18#include "llvm/Transforms/Utils/Local.h"1920#define DEBUG_TYPE "dxil-cbuffer-access"21using namespace llvm;2223namespace {24/// Helper for building a `load.cbufferrow` intrinsic given a simple type.25struct CBufferRowIntrin {26Intrinsic::ID IID;27Type *RetTy;28unsigned int EltSize;29unsigned int NumElts;3031CBufferRowIntrin(const DataLayout &DL, Type *Ty) {32assert(Ty == Ty->getScalarType() && "Expected scalar type");3334switch (DL.getTypeSizeInBits(Ty)) {35case 16:36IID = Intrinsic::dx_resource_load_cbufferrow_8;37RetTy = StructType::get(Ty, Ty, Ty, Ty, Ty, Ty, Ty, Ty);38EltSize = 2;39NumElts = 8;40break;41case 32:42IID = Intrinsic::dx_resource_load_cbufferrow_4;43RetTy = StructType::get(Ty, Ty, Ty, Ty);44EltSize = 4;45NumElts = 4;46break;47case 64:48IID = Intrinsic::dx_resource_load_cbufferrow_2;49RetTy = StructType::get(Ty, Ty);50EltSize = 8;51NumElts = 2;52break;53default:54llvm_unreachable("Only 16, 32, and 64 bit types supported");55}56}57};5859// Helper for creating CBuffer handles and loading data from them60struct CBufferResource {61GlobalVariable *GVHandle;62GlobalVariable *Member;63size_t MemberOffset;6465LoadInst *Handle;6667CBufferResource(GlobalVariable *GVHandle, GlobalVariable *Member,68size_t MemberOffset)69: GVHandle(GVHandle), 
Member(Member), MemberOffset(MemberOffset) {}7071const DataLayout &getDataLayout() { return GVHandle->getDataLayout(); }72Type *getValueType() { return Member->getValueType(); }73iterator_range<ConstantDataSequential::user_iterator> users() {74return Member->users();75}7677/// Get the byte offset of a Pointer-typed Value * `Val` relative to Member.78/// `Val` can either be Member itself, or a GEP of a constant offset from79/// Member80size_t getOffsetForCBufferGEP(Value *Val) {81assert(isa<PointerType>(Val->getType()) &&82"Expected a pointer-typed value");8384if (Val == Member)85return 0;8687if (auto *GEP = dyn_cast<GEPOperator>(Val)) {88// Since we should always have a constant offset, we should only ever have89// a single GEP of indirection from the Global.90assert(GEP->getPointerOperand() == Member &&91"Indirect access to resource handle");9293const DataLayout &DL = getDataLayout();94APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);95bool Success = GEP->accumulateConstantOffset(DL, ConstantOffset);96(void)Success;97assert(Success && "Offsets into cbuffer globals must be constant");9899if (auto *ATy = dyn_cast<ArrayType>(Member->getValueType()))100ConstantOffset =101hlsl::translateCBufArrayOffset(DL, ConstantOffset, ATy);102103return ConstantOffset.getZExtValue();104}105106llvm_unreachable("Expected Val to be a GlobalVariable or GEP");107}108109/// Create a handle for this cbuffer resource using the IRBuilder `Builder`110/// and sets the handle as the current one to use for subsequent calls to111/// `loadValue`112void createAndSetCurrentHandle(IRBuilder<> &Builder) {113Handle = Builder.CreateLoad(GVHandle->getValueType(), GVHandle,114GVHandle->getName());115}116117/// Load a value of type `Ty` at offset `Offset` using the handle from the118/// last call to `createAndSetCurrentHandle`119Value *loadValue(IRBuilder<> &Builder, Type *Ty, size_t Offset,120const Twine &Name = "") {121assert(Handle &&122"Expected a handle for this cbuffer global 
resource to be created "123"before loading a value from it");124const DataLayout &DL = getDataLayout();125126size_t TargetOffset = MemberOffset + Offset;127CBufferRowIntrin Intrin(DL, Ty->getScalarType());128// The cbuffer consists of some number of 16-byte rows.129unsigned int CurrentRow = TargetOffset / hlsl::CBufferRowSizeInBytes;130unsigned int CurrentIndex =131(TargetOffset % hlsl::CBufferRowSizeInBytes) / Intrin.EltSize;132133auto *CBufLoad = Builder.CreateIntrinsic(134Intrin.RetTy, Intrin.IID,135{Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr,136Name + ".load");137auto *Elt = Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},138Name + ".extract");139140Value *Result = nullptr;141unsigned int Remaining =142((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1;143144if (Remaining == 0) {145// We only have a single element, so we're done.146Result = Elt;147148// However, if we loaded a <1 x T>, then we need to adjust the type here.149if (auto *VT = dyn_cast<FixedVectorType>(Ty)) {150assert(VT->getNumElements() == 1 &&151"Can't have multiple elements here");152Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result,153Builder.getInt32(0), Name);154}155return Result;156}157158// Walk each element and extract it, wrapping to new rows as needed.159SmallVector<Value *> Extracts{Elt};160while (Remaining--) {161CurrentIndex %= Intrin.NumElts;162163if (CurrentIndex == 0)164CBufLoad = Builder.CreateIntrinsic(165Intrin.RetTy, Intrin.IID,166{Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)},167nullptr, Name + ".load");168169Extracts.push_back(Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},170Name + ".extract"));171}172173// Finally, we build up the original loaded value.174Result = PoisonValue::get(Ty);175for (int I = 0, E = Extracts.size(); I < E; ++I)176Result =177Builder.CreateInsertElement(Result, Extracts[I], Builder.getInt32(I),178Name + formatv(".upto{}", I));179return Result;180}181};182183} // 
namespace184185/// Replace load via cbuffer global with a load from the cbuffer handle itself.186static void replaceLoad(LoadInst *LI, CBufferResource &CBR,187SmallVectorImpl<WeakTrackingVH> &DeadInsts) {188size_t Offset = CBR.getOffsetForCBufferGEP(LI->getPointerOperand());189IRBuilder<> Builder(LI);190CBR.createAndSetCurrentHandle(Builder);191Value *Result = CBR.loadValue(Builder, LI->getType(), Offset, LI->getName());192LI->replaceAllUsesWith(Result);193DeadInsts.push_back(LI);194}195196/// This function recursively copies N array elements from the cbuffer resource197/// CBR to the MemCpy Destination. Recursion is used to unravel multidimensional198/// arrays into a sequence of scalar/vector extracts and stores.199static void copyArrayElemsForMemCpy(IRBuilder<> &Builder, MemCpyInst *MCI,200CBufferResource &CBR, ArrayType *ArrTy,201size_t ArrOffset, size_t N,202const Twine &Name = "") {203const DataLayout &DL = MCI->getDataLayout();204Type *ElemTy = ArrTy->getElementType();205size_t ElemTySize = DL.getTypeAllocSize(ElemTy);206for (unsigned I = 0; I < N; ++I) {207size_t Offset = ArrOffset + I * ElemTySize;208209// Recursively copy nested arrays210if (ArrayType *ElemArrTy = dyn_cast<ArrayType>(ElemTy)) {211copyArrayElemsForMemCpy(Builder, MCI, CBR, ElemArrTy, Offset,212ElemArrTy->getNumElements(), Name);213continue;214}215216// Load CBuffer value and store it in Dest217APInt CBufArrayOffset(218DL.getIndexTypeSizeInBits(MCI->getSource()->getType()), Offset);219CBufArrayOffset =220hlsl::translateCBufArrayOffset(DL, CBufArrayOffset, ArrTy);221Value *CBufferVal =222CBR.loadValue(Builder, ElemTy, CBufArrayOffset.getZExtValue(), Name);223Value *GEP =224Builder.CreateInBoundsGEP(Builder.getInt8Ty(), MCI->getDest(),225{Builder.getInt32(Offset)}, Name + ".dest");226Builder.CreateStore(CBufferVal, GEP, MCI->isVolatile());227}228}229230/// Replace memcpy from a cbuffer global with a memcpy from the cbuffer handle231/// itself. 
Assumes the cbuffer global is an array, and the length of bytes to232/// copy is divisible by array element allocation size.233/// The memcpy source must also be a direct cbuffer global reference, not a GEP.234static void replaceMemCpy(MemCpyInst *MCI, CBufferResource &CBR) {235236ArrayType *ArrTy = dyn_cast<ArrayType>(CBR.getValueType());237assert(ArrTy && "MemCpy lowering is only supported for array types");238239// This assumption vastly simplifies the implementation240if (MCI->getSource() != CBR.Member)241reportFatalUsageError(242"Expected MemCpy source to be a cbuffer global variable");243244ConstantInt *Length = dyn_cast<ConstantInt>(MCI->getLength());245uint64_t ByteLength = Length->getZExtValue();246247// If length to copy is zero, no memcpy is needed248if (ByteLength == 0) {249MCI->eraseFromParent();250return;251}252253const DataLayout &DL = CBR.getDataLayout();254255Type *ElemTy = ArrTy->getElementType();256size_t ElemSize = DL.getTypeAllocSize(ElemTy);257assert(ByteLength % ElemSize == 0 &&258"Length of bytes to MemCpy must be divisible by allocation size of "259"source/destination array elements");260size_t ElemsToCpy = ByteLength / ElemSize;261262IRBuilder<> Builder(MCI);263CBR.createAndSetCurrentHandle(Builder);264265copyArrayElemsForMemCpy(Builder, MCI, CBR, ArrTy, 0, ElemsToCpy,266"memcpy." + MCI->getDest()->getName() + "." 
+267MCI->getSource()->getName());268269MCI->eraseFromParent();270}271272static void replaceAccessesWithHandle(CBufferResource &CBR) {273SmallVector<WeakTrackingVH> DeadInsts;274275SmallVector<User *> ToProcess{CBR.users()};276while (!ToProcess.empty()) {277User *Cur = ToProcess.pop_back_val();278279// If we have a load instruction, replace the access.280if (auto *LI = dyn_cast<LoadInst>(Cur)) {281replaceLoad(LI, CBR, DeadInsts);282continue;283}284285// If we have a memcpy instruction, replace it with multiple accesses and286// subsequent stores to the destination287if (auto *MCI = dyn_cast<MemCpyInst>(Cur)) {288replaceMemCpy(MCI, CBR);289continue;290}291292// Otherwise, walk users looking for a load...293if (isa<GetElementPtrInst>(Cur) || isa<GEPOperator>(Cur)) {294ToProcess.append(Cur->user_begin(), Cur->user_end());295continue;296}297298llvm_unreachable("Unexpected user of Global");299}300RecursivelyDeleteTriviallyDeadInstructions(DeadInsts);301}302303static bool replaceCBufferAccesses(Module &M) {304std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M);305if (!CBufMD)306return false;307308for (const hlsl::CBufferMapping &Mapping : *CBufMD)309for (const hlsl::CBufferMember &Member : Mapping.Members) {310CBufferResource CBR(Mapping.Handle, Member.GV, Member.Offset);311replaceAccessesWithHandle(CBR);312Member.GV->removeFromParent();313}314315CBufMD->eraseFromModule();316return true;317}318319PreservedAnalyses DXILCBufferAccess::run(Module &M, ModuleAnalysisManager &AM) {320PreservedAnalyses PA;321bool Changed = replaceCBufferAccesses(M);322323if (!Changed)324return PreservedAnalyses::all();325return PA;326}327328namespace {329class DXILCBufferAccessLegacy : public ModulePass {330public:331bool runOnModule(Module &M) override { return replaceCBufferAccesses(M); }332StringRef getPassName() const override { return "DXIL CBuffer Access"; }333DXILCBufferAccessLegacy() : ModulePass(ID) {}334335static char ID; // Pass identification.336};337char 
DXILCBufferAccessLegacy::ID = 0;338} // end anonymous namespace339340INITIALIZE_PASS(DXILCBufferAccessLegacy, DEBUG_TYPE, "DXIL CBuffer Access",341false, false)342343ModulePass *llvm::createDXILCBufferAccessLegacyPass() {344return new DXILCBufferAccessLegacy();345}346347348