Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/DirectX/DXILCBufferAccess.cpp
213799 views
1
//===- DXILCBufferAccess.cpp - Translate CBuffer Loads --------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#include "DXILCBufferAccess.h"
10
#include "DirectX.h"
11
#include "llvm/Frontend/HLSL/CBuffer.h"
12
#include "llvm/Frontend/HLSL/HLSLResource.h"
13
#include "llvm/IR/IRBuilder.h"
14
#include "llvm/IR/IntrinsicInst.h"
15
#include "llvm/IR/IntrinsicsDirectX.h"
16
#include "llvm/InitializePasses.h"
17
#include "llvm/Pass.h"
18
#include "llvm/Support/FormatVariadic.h"
19
#include "llvm/Transforms/Utils/Local.h"
20
21
#define DEBUG_TYPE "dxil-cbuffer-access"
22
using namespace llvm;
23
24
namespace {
25
/// Helper for building a `load.cbufferrow` intrinsic given a simple type.
26
struct CBufferRowIntrin {
27
Intrinsic::ID IID;
28
Type *RetTy;
29
unsigned int EltSize;
30
unsigned int NumElts;
31
32
CBufferRowIntrin(const DataLayout &DL, Type *Ty) {
33
assert(Ty == Ty->getScalarType() && "Expected scalar type");
34
35
switch (DL.getTypeSizeInBits(Ty)) {
36
case 16:
37
IID = Intrinsic::dx_resource_load_cbufferrow_8;
38
RetTy = StructType::get(Ty, Ty, Ty, Ty, Ty, Ty, Ty, Ty);
39
EltSize = 2;
40
NumElts = 8;
41
break;
42
case 32:
43
IID = Intrinsic::dx_resource_load_cbufferrow_4;
44
RetTy = StructType::get(Ty, Ty, Ty, Ty);
45
EltSize = 4;
46
NumElts = 4;
47
break;
48
case 64:
49
IID = Intrinsic::dx_resource_load_cbufferrow_2;
50
RetTy = StructType::get(Ty, Ty);
51
EltSize = 8;
52
NumElts = 2;
53
break;
54
default:
55
llvm_unreachable("Only 16, 32, and 64 bit types supported");
56
}
57
}
58
};
59
60
// Helper for creating CBuffer handles and loading data from them
61
struct CBufferResource {
62
GlobalVariable *GVHandle;
63
GlobalVariable *Member;
64
size_t MemberOffset;
65
66
LoadInst *Handle;
67
68
CBufferResource(GlobalVariable *GVHandle, GlobalVariable *Member,
69
size_t MemberOffset)
70
: GVHandle(GVHandle), Member(Member), MemberOffset(MemberOffset) {}
71
72
const DataLayout &getDataLayout() { return GVHandle->getDataLayout(); }
73
Type *getValueType() { return Member->getValueType(); }
74
iterator_range<ConstantDataSequential::user_iterator> users() {
75
return Member->users();
76
}
77
78
/// Get the byte offset of a Pointer-typed Value * `Val` relative to Member.
79
/// `Val` can either be Member itself, or a GEP of a constant offset from
80
/// Member
81
size_t getOffsetForCBufferGEP(Value *Val) {
82
assert(isa<PointerType>(Val->getType()) &&
83
"Expected a pointer-typed value");
84
85
if (Val == Member)
86
return 0;
87
88
if (auto *GEP = dyn_cast<GEPOperator>(Val)) {
89
// Since we should always have a constant offset, we should only ever have
90
// a single GEP of indirection from the Global.
91
assert(GEP->getPointerOperand() == Member &&
92
"Indirect access to resource handle");
93
94
const DataLayout &DL = getDataLayout();
95
APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
96
bool Success = GEP->accumulateConstantOffset(DL, ConstantOffset);
97
(void)Success;
98
assert(Success && "Offsets into cbuffer globals must be constant");
99
100
if (auto *ATy = dyn_cast<ArrayType>(Member->getValueType()))
101
ConstantOffset =
102
hlsl::translateCBufArrayOffset(DL, ConstantOffset, ATy);
103
104
return ConstantOffset.getZExtValue();
105
}
106
107
llvm_unreachable("Expected Val to be a GlobalVariable or GEP");
108
}
109
110
/// Create a handle for this cbuffer resource using the IRBuilder `Builder`
111
/// and sets the handle as the current one to use for subsequent calls to
112
/// `loadValue`
113
void createAndSetCurrentHandle(IRBuilder<> &Builder) {
114
Handle = Builder.CreateLoad(GVHandle->getValueType(), GVHandle,
115
GVHandle->getName());
116
}
117
118
/// Load a value of type `Ty` at offset `Offset` using the handle from the
119
/// last call to `createAndSetCurrentHandle`
120
Value *loadValue(IRBuilder<> &Builder, Type *Ty, size_t Offset,
121
const Twine &Name = "") {
122
assert(Handle &&
123
"Expected a handle for this cbuffer global resource to be created "
124
"before loading a value from it");
125
const DataLayout &DL = getDataLayout();
126
127
size_t TargetOffset = MemberOffset + Offset;
128
CBufferRowIntrin Intrin(DL, Ty->getScalarType());
129
// The cbuffer consists of some number of 16-byte rows.
130
unsigned int CurrentRow = TargetOffset / hlsl::CBufferRowSizeInBytes;
131
unsigned int CurrentIndex =
132
(TargetOffset % hlsl::CBufferRowSizeInBytes) / Intrin.EltSize;
133
134
auto *CBufLoad = Builder.CreateIntrinsic(
135
Intrin.RetTy, Intrin.IID,
136
{Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr,
137
Name + ".load");
138
auto *Elt = Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
139
Name + ".extract");
140
141
Value *Result = nullptr;
142
unsigned int Remaining =
143
((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1;
144
145
if (Remaining == 0) {
146
// We only have a single element, so we're done.
147
Result = Elt;
148
149
// However, if we loaded a <1 x T>, then we need to adjust the type here.
150
if (auto *VT = dyn_cast<FixedVectorType>(Ty)) {
151
assert(VT->getNumElements() == 1 &&
152
"Can't have multiple elements here");
153
Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result,
154
Builder.getInt32(0), Name);
155
}
156
return Result;
157
}
158
159
// Walk each element and extract it, wrapping to new rows as needed.
160
SmallVector<Value *> Extracts{Elt};
161
while (Remaining--) {
162
CurrentIndex %= Intrin.NumElts;
163
164
if (CurrentIndex == 0)
165
CBufLoad = Builder.CreateIntrinsic(
166
Intrin.RetTy, Intrin.IID,
167
{Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)},
168
nullptr, Name + ".load");
169
170
Extracts.push_back(Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
171
Name + ".extract"));
172
}
173
174
// Finally, we build up the original loaded value.
175
Result = PoisonValue::get(Ty);
176
for (int I = 0, E = Extracts.size(); I < E; ++I)
177
Result =
178
Builder.CreateInsertElement(Result, Extracts[I], Builder.getInt32(I),
179
Name + formatv(".upto{}", I));
180
return Result;
181
}
182
};
183
184
} // namespace
185
186
/// Replace load via cbuffer global with a load from the cbuffer handle itself.
187
static void replaceLoad(LoadInst *LI, CBufferResource &CBR,
188
SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
189
size_t Offset = CBR.getOffsetForCBufferGEP(LI->getPointerOperand());
190
IRBuilder<> Builder(LI);
191
CBR.createAndSetCurrentHandle(Builder);
192
Value *Result = CBR.loadValue(Builder, LI->getType(), Offset, LI->getName());
193
LI->replaceAllUsesWith(Result);
194
DeadInsts.push_back(LI);
195
}
196
197
/// This function recursively copies N array elements from the cbuffer resource
198
/// CBR to the MemCpy Destination. Recursion is used to unravel multidimensional
199
/// arrays into a sequence of scalar/vector extracts and stores.
200
static void copyArrayElemsForMemCpy(IRBuilder<> &Builder, MemCpyInst *MCI,
201
CBufferResource &CBR, ArrayType *ArrTy,
202
size_t ArrOffset, size_t N,
203
const Twine &Name = "") {
204
const DataLayout &DL = MCI->getDataLayout();
205
Type *ElemTy = ArrTy->getElementType();
206
size_t ElemTySize = DL.getTypeAllocSize(ElemTy);
207
for (unsigned I = 0; I < N; ++I) {
208
size_t Offset = ArrOffset + I * ElemTySize;
209
210
// Recursively copy nested arrays
211
if (ArrayType *ElemArrTy = dyn_cast<ArrayType>(ElemTy)) {
212
copyArrayElemsForMemCpy(Builder, MCI, CBR, ElemArrTy, Offset,
213
ElemArrTy->getNumElements(), Name);
214
continue;
215
}
216
217
// Load CBuffer value and store it in Dest
218
APInt CBufArrayOffset(
219
DL.getIndexTypeSizeInBits(MCI->getSource()->getType()), Offset);
220
CBufArrayOffset =
221
hlsl::translateCBufArrayOffset(DL, CBufArrayOffset, ArrTy);
222
Value *CBufferVal =
223
CBR.loadValue(Builder, ElemTy, CBufArrayOffset.getZExtValue(), Name);
224
Value *GEP =
225
Builder.CreateInBoundsGEP(Builder.getInt8Ty(), MCI->getDest(),
226
{Builder.getInt32(Offset)}, Name + ".dest");
227
Builder.CreateStore(CBufferVal, GEP, MCI->isVolatile());
228
}
229
}
230
231
/// Replace memcpy from a cbuffer global with a memcpy from the cbuffer handle
232
/// itself. Assumes the cbuffer global is an array, and the length of bytes to
233
/// copy is divisible by array element allocation size.
234
/// The memcpy source must also be a direct cbuffer global reference, not a GEP.
235
static void replaceMemCpy(MemCpyInst *MCI, CBufferResource &CBR) {
236
237
ArrayType *ArrTy = dyn_cast<ArrayType>(CBR.getValueType());
238
assert(ArrTy && "MemCpy lowering is only supported for array types");
239
240
// This assumption vastly simplifies the implementation
241
if (MCI->getSource() != CBR.Member)
242
reportFatalUsageError(
243
"Expected MemCpy source to be a cbuffer global variable");
244
245
ConstantInt *Length = dyn_cast<ConstantInt>(MCI->getLength());
246
uint64_t ByteLength = Length->getZExtValue();
247
248
// If length to copy is zero, no memcpy is needed
249
if (ByteLength == 0) {
250
MCI->eraseFromParent();
251
return;
252
}
253
254
const DataLayout &DL = CBR.getDataLayout();
255
256
Type *ElemTy = ArrTy->getElementType();
257
size_t ElemSize = DL.getTypeAllocSize(ElemTy);
258
assert(ByteLength % ElemSize == 0 &&
259
"Length of bytes to MemCpy must be divisible by allocation size of "
260
"source/destination array elements");
261
size_t ElemsToCpy = ByteLength / ElemSize;
262
263
IRBuilder<> Builder(MCI);
264
CBR.createAndSetCurrentHandle(Builder);
265
266
copyArrayElemsForMemCpy(Builder, MCI, CBR, ArrTy, 0, ElemsToCpy,
267
"memcpy." + MCI->getDest()->getName() + "." +
268
MCI->getSource()->getName());
269
270
MCI->eraseFromParent();
271
}
272
273
static void replaceAccessesWithHandle(CBufferResource &CBR) {
274
SmallVector<WeakTrackingVH> DeadInsts;
275
276
SmallVector<User *> ToProcess{CBR.users()};
277
while (!ToProcess.empty()) {
278
User *Cur = ToProcess.pop_back_val();
279
280
// If we have a load instruction, replace the access.
281
if (auto *LI = dyn_cast<LoadInst>(Cur)) {
282
replaceLoad(LI, CBR, DeadInsts);
283
continue;
284
}
285
286
// If we have a memcpy instruction, replace it with multiple accesses and
287
// subsequent stores to the destination
288
if (auto *MCI = dyn_cast<MemCpyInst>(Cur)) {
289
replaceMemCpy(MCI, CBR);
290
continue;
291
}
292
293
// Otherwise, walk users looking for a load...
294
if (isa<GetElementPtrInst>(Cur) || isa<GEPOperator>(Cur)) {
295
ToProcess.append(Cur->user_begin(), Cur->user_end());
296
continue;
297
}
298
299
llvm_unreachable("Unexpected user of Global");
300
}
301
RecursivelyDeleteTriviallyDeadInstructions(DeadInsts);
302
}
303
304
static bool replaceCBufferAccesses(Module &M) {
305
std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M);
306
if (!CBufMD)
307
return false;
308
309
for (const hlsl::CBufferMapping &Mapping : *CBufMD)
310
for (const hlsl::CBufferMember &Member : Mapping.Members) {
311
CBufferResource CBR(Mapping.Handle, Member.GV, Member.Offset);
312
replaceAccessesWithHandle(CBR);
313
Member.GV->removeFromParent();
314
}
315
316
CBufMD->eraseFromModule();
317
return true;
318
}
319
320
PreservedAnalyses DXILCBufferAccess::run(Module &M, ModuleAnalysisManager &AM) {
321
PreservedAnalyses PA;
322
bool Changed = replaceCBufferAccesses(M);
323
324
if (!Changed)
325
return PreservedAnalyses::all();
326
return PA;
327
}
328
329
namespace {
330
class DXILCBufferAccessLegacy : public ModulePass {
331
public:
332
bool runOnModule(Module &M) override { return replaceCBufferAccesses(M); }
333
StringRef getPassName() const override { return "DXIL CBuffer Access"; }
334
DXILCBufferAccessLegacy() : ModulePass(ID) {}
335
336
static char ID; // Pass identification.
337
};
338
char DXILCBufferAccessLegacy::ID = 0;
339
} // end anonymous namespace
340
341
INITIALIZE_PASS(DXILCBufferAccessLegacy, DEBUG_TYPE, "DXIL CBuffer Access",
342
false, false)
343
344
ModulePass *llvm::createDXILCBufferAccessLegacyPass() {
345
return new DXILCBufferAccessLegacy();
346
}
347
348