Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
213799 views
1
//===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===---------------------------------------------------------------------===//
8
///
9
/// \file This file contains a pass to flatten arrays for the DirectX Backend.
10
///
11
//===----------------------------------------------------------------------===//
12
13
#include "DXILFlattenArrays.h"
14
#include "DirectX.h"
15
#include "llvm/ADT/PostOrderIterator.h"
16
#include "llvm/ADT/STLExtras.h"
17
#include "llvm/IR/BasicBlock.h"
18
#include "llvm/IR/DerivedTypes.h"
19
#include "llvm/IR/IRBuilder.h"
20
#include "llvm/IR/InstVisitor.h"
21
#include "llvm/IR/ReplaceConstant.h"
22
#include "llvm/Support/Casting.h"
23
#include "llvm/Support/MathExtras.h"
24
#include "llvm/Transforms/Utils/Local.h"
25
#include <cassert>
26
#include <cstddef>
27
#include <cstdint>
28
#include <utility>
29
30
#define DEBUG_TYPE "dxil-flatten-arrays"
31
32
using namespace llvm;
33
namespace {
34
35
class DXILFlattenArraysLegacy : public ModulePass {
36
37
public:
38
bool runOnModule(Module &M) override;
39
DXILFlattenArraysLegacy() : ModulePass(ID) {}
40
41
static char ID; // Pass identification.
42
};
43
44
struct GEPInfo {
45
ArrayType *RootFlattenedArrayType;
46
Value *RootPointerOperand;
47
SmallMapVector<Value *, APInt, 4> VariableOffsets;
48
APInt ConstantOffset;
49
};
50
51
class DXILFlattenArraysVisitor
52
: public InstVisitor<DXILFlattenArraysVisitor, bool> {
53
public:
54
DXILFlattenArraysVisitor(
55
SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap)
56
: GlobalMap(GlobalMap) {}
57
bool visit(Function &F);
58
// InstVisitor methods. They return true if the instruction was scalarized,
59
// false if nothing changed.
60
bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
61
bool visitAllocaInst(AllocaInst &AI);
62
bool visitInstruction(Instruction &I) { return false; }
63
bool visitSelectInst(SelectInst &SI) { return false; }
64
bool visitICmpInst(ICmpInst &ICI) { return false; }
65
bool visitFCmpInst(FCmpInst &FCI) { return false; }
66
bool visitUnaryOperator(UnaryOperator &UO) { return false; }
67
bool visitBinaryOperator(BinaryOperator &BO) { return false; }
68
bool visitCastInst(CastInst &CI) { return false; }
69
bool visitBitCastInst(BitCastInst &BCI) { return false; }
70
bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
71
bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
72
bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
73
bool visitPHINode(PHINode &PHI) { return false; }
74
bool visitLoadInst(LoadInst &LI);
75
bool visitStoreInst(StoreInst &SI);
76
bool visitCallInst(CallInst &ICI) { return false; }
77
bool visitFreezeInst(FreezeInst &FI) { return false; }
78
static bool isMultiDimensionalArray(Type *T);
79
static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy);
80
81
private:
82
SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
83
SmallDenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap;
84
SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap;
85
bool finish();
86
ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
87
ArrayRef<uint64_t> Dims,
88
IRBuilder<> &Builder);
89
Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
90
ArrayRef<uint64_t> Dims,
91
IRBuilder<> &Builder);
92
};
93
} // namespace
94
95
bool DXILFlattenArraysVisitor::finish() {
96
GEPChainInfoMap.clear();
97
RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
98
return true;
99
}
100
101
/// Returns true when \p T is an array whose elements are themselves arrays
/// (for example `[2 x [3 x i32]]`), i.e. a type this pass flattens. Any other
/// type, including a one-dimensional array, yields false.
bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) {
  auto *OuterArrayTy = dyn_cast<ArrayType>(T);
  return OuterArrayTy != nullptr &&
         OuterArrayTy->getElementType()->isArrayTy();
}
106
107
std::pair<unsigned, Type *>
108
DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) {
109
unsigned TotalElements = 1;
110
Type *CurrArrayTy = ArrayTy;
111
while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) {
112
TotalElements *= InnerArrayTy->getNumElements();
113
CurrArrayTy = InnerArrayTy->getElementType();
114
}
115
return std::make_pair(TotalElements, CurrArrayTy);
116
}
117
118
/// Folds constant per-dimension indices into one row-major flat index and
/// returns it as an i32 constant.
///
/// \p Indices[K] selects along dimension K, whose extent is \p Dims[K]; the
/// last dimension varies fastest. Every index must be a ConstantInt, and the
/// arithmetic is performed in `unsigned` like the returned i32.
ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
    ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
  assert(Indices.size() == Dims.size() &&
         "Indicies and dimmensions should be the same");
  unsigned Flat = 0;
  unsigned Stride = 1;

  // Walk dimensions from innermost (last) to outermost (first).
  for (size_t Pos = Indices.size(); Pos > 0; --Pos) {
    ConstantInt *ConstIdx = dyn_cast<ConstantInt>(Indices[Pos - 1]);
    assert(ConstIdx && "This function expects all indicies to be ConstantInt");
    Flat += ConstIdx->getZExtValue() * Stride;
    Stride *= static_cast<unsigned>(Dims[Pos - 1]);
  }
  return Builder.getInt32(Flat);
}
134
135
/// Emits instructions computing the row-major flat index for \p Indices into
/// an array whose per-dimension extents are \p Dims (the last dimension varies
/// fastest), and returns the resulting value.
///
/// Multipliers are materialized as i32 constants, so any index that needs
/// scaling must be an i32 value. Trivial operations (multiplying by 1, adding
/// to an initial 0) are not emitted; a single index is therefore returned
/// unchanged, and an empty index list yields the constant i32 0.
Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
    ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) {
  // Same precondition as genConstFlattenIndices; Dims[Pos - 1] below would
  // read out of bounds otherwise.
  assert(Indices.size() == Dims.size() &&
         "Indices and dimensions should be the same size");

  Value *FlatIndex = nullptr;
  unsigned Multiplier = 1;
  for (size_t Pos = Indices.size(); Pos > 0; --Pos) {
    Value *Scaled = Indices[Pos - 1];
    if (Multiplier != 1)
      Scaled = Builder.CreateMul(Scaled, Builder.getInt32(Multiplier));
    FlatIndex = FlatIndex ? Builder.CreateAdd(FlatIndex, Scaled) : Scaled;
    Multiplier *= static_cast<unsigned>(Dims[Pos - 1]);
  }
  if (!FlatIndex)
    return Builder.getInt32(0);
  return FlatIndex;
}
152
153
/// If the load's pointer operand is a GEP constant expression, materializes
/// that expression as a GEP instruction right before the load, points the load
/// at it, and visits the new GEP so it can be flattened.
///
/// The pointer operand is rewritten in place rather than re-creating the
/// load. This keeps the original load's volatility, atomic ordering,
/// alignment, metadata and name intact.
///
/// \returns true if the load was changed, false otherwise.
bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
  auto *CE = dyn_cast<ConstantExpr>(LI.getPointerOperand());
  if (!CE || CE->getOpcode() != Instruction::GetElementPtr)
    return false;

  auto *GEPInst = cast<GetElementPtrInst>(CE->getAsInstruction());
  GEPInst->insertBefore(LI.getIterator());
  LI.setOperand(LoadInst::getPointerOperandIndex(), GEPInst);
  // Flattening may replace GEPInst; RAUW then updates the load's operand.
  visitGetElementPtrInst(*GEPInst);
  return true;
}
175
176
/// If the store's pointer operand is a GEP constant expression, materializes
/// that expression as a GEP instruction right before the store, points the
/// store at it, and visits the new GEP so it can be flattened.
///
/// Only the pointer operand is considered. A GEP constant expression used as
/// the stored *value* is an address being written to memory, not the store's
/// destination. The previous operand loop treated it as the destination and
/// redirected the store to the wrong address. The operand is rewritten in
/// place, so volatility, atomic ordering, alignment and metadata are preserved.
///
/// \returns true if the store was changed, false otherwise.
bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
  auto *CE = dyn_cast<ConstantExpr>(SI.getPointerOperand());
  if (!CE || CE->getOpcode() != Instruction::GetElementPtr)
    return false;

  auto *GEPInst = cast<GetElementPtrInst>(CE->getAsInstruction());
  GEPInst->insertBefore(SI.getIterator());
  SI.setOperand(StoreInst::getPointerOperandIndex(), GEPInst);
  // Flattening may replace GEPInst; RAUW then updates the store's operand.
  visitGetElementPtrInst(*GEPInst);
  return true;
}
197
198
/// Replaces an alloca of a multi-dimensional array with an alloca of the
/// equivalent one-dimensional array (named `<name>.1dim`), keeping its
/// alignment, address space and element count.
///
/// \returns true if the alloca was replaced, false if its type is not a
/// multi-dimensional array.
bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
  if (!isMultiDimensionalArray(AI.getAllocatedType()))
    return false;

  ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType());
  IRBuilder<> Builder(&AI);
  auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
  ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements);

  // An array allocation (`alloca T, i32 %n`) must keep its count, otherwise
  // the replacement would reserve memory for a single object only. For a
  // single-object alloca pass nullptr so the builder emits its default count.
  Value *ArraySize = AI.isArrayAllocation() ? AI.getArraySize() : nullptr;
  AllocaInst *FlatAlloca =
      Builder.CreateAlloca(FlattenedArrayType, AI.getAddressSpace(), ArraySize,
                           AI.getName() + ".1dim");
  FlatAlloca->setAlignment(AI.getAlign());
  AI.replaceAllUsesWith(FlatAlloca);
  AI.eraseFromParent();
  return true;
}
214
215
/// Flattens a GEP that indexes, directly or through a chain of GEPs, into an
/// array-typed alloca or global.
///
/// Each GEP of a chain gets a GEPInfo holding the chain's root pointer, the
/// root's (already flattened) array type, and the constant and variable byte
/// offsets accumulated from the root. A GEP whose users are all GEPs is only
/// recorded and queued in PotentiallyDeadInstrs. A GEP with no users, or with
/// any non-GEP user, is replaced by a single `GEP [0, FlatIndex]` into the root.
///
/// \returns true if the GEP was replaced, false otherwise.
bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
  // Do not visit GEPs more than once
  if (GEPChainInfoMap.contains(cast<GEPOperator>(&GEP)))
    return false;

  Value *PtrOperand = GEP.getPointerOperand();
  // It shouldn't(?) be possible for the pointer operand of a GEP to be a PHI
  // node unless HLSL has pointers. If this assumption is incorrect or HLSL gets
  // pointer types, then the handling of this case can be implemented later.
  assert(!isa<PHINode>(PtrOperand) &&
         "Pointer operand of GEP should not be a PHI Node");

  // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that
  // it can be visited
  if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand);
      PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
    GetElementPtrInst *OldGEPI =
        cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction());
    OldGEPI->insertBefore(GEP.getIterator());

    IRBuilder<> Builder(&GEP);
    SmallVector<Value *> Indices(GEP.indices());
    Value *NewGEP =
        Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices,
                          GEP.getName(), GEP.getNoWrapFlags());
    assert(isa<GetElementPtrInst>(NewGEP) &&
           "Expected newly-created GEP to be an instruction");
    GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);

    GEP.replaceAllUsesWith(NewGEPI);
    GEP.eraseFromParent();
    visitGetElementPtrInst(*OldGEPI);
    visitGetElementPtrInst(*NewGEPI);
    return true;
  }

  // Construct GEPInfo for this GEP
  GEPInfo Info;

  // Obtain the variable and constant byte offsets computed by this GEP
  const DataLayout &DL = GEP.getDataLayout();
  unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType());
  Info.ConstantOffset = {BitWidth, 0};
  [[maybe_unused]] bool Success = GEP.collectOffset(
      DL, BitWidth, Info.VariableOffsets, Info.ConstantOffset);
  assert(Success && "Failed to collect offsets for GEP");

  // If there is a parent GEP, inherit the root array type and pointer, and
  // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP
  // chain and we need to determine the root array type
  if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {
    // Blocks are visited in reverse post-order, so a parent GEP (which
    // dominates this one) has already been visited. It has no entry when it
    // bailed out below because its chain is not rooted at an array (e.g. a
    // struct member access). Such memory is not flattened, so leave this GEP
    // alone instead of reading a default-constructed entry.
    auto ParentIt = GEPChainInfoMap.find(PtrOpGEP);
    if (ParentIt == GEPChainInfoMap.end())
      return false;
    const GEPInfo &PGEPInfo = ParentIt->second;
    Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;
    Info.RootPointerOperand = PGEPInfo.RootPointerOperand;
    // The same index value may be used by both this GEP and its parent chain
    // (e.g. `a[i][i]` split across two GEPs). Its byte multipliers must be
    // summed; a plain insert would silently drop the parent's contribution.
    for (const auto &[VarIndex, Multiplier] : PGEPInfo.VariableOffsets) {
      auto [It, Inserted] = Info.VariableOffsets.insert({VarIndex, Multiplier});
      if (!Inserted)
        It->second += Multiplier;
    }
    Info.ConstantOffset += PGEPInfo.ConstantOffset;
  } else {
    Info.RootPointerOperand = PtrOperand;

    // We should try to determine the type of the root from the pointer rather
    // than the GEP's source element type because this could be a scalar GEP
    // into an array-typed pointer from an Alloca or Global Variable.
    Type *RootTy = GEP.getSourceElementType();
    if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
      if (GlobalMap.contains(GlobalVar))
        GlobalVar = GlobalMap[GlobalVar];
      Info.RootPointerOperand = GlobalVar;
      RootTy = GlobalVar->getValueType();
    } else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand))
      RootTy = Alloca->getAllocatedType();
    assert(!isMultiDimensionalArray(RootTy) &&
           "Expected root array type to be flattened");

    // If the root type is not an array, we don't need to do any flattening
    if (!isa<ArrayType>(RootTy))
      return false;

    Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);
  }

  // GEPs without users or GEPs with non-GEP users should be replaced such that
  // the chain of GEPs they are a part of are collapsed to a single GEP into a
  // flattened array.
  bool ReplaceThisGEP =
      GEP.users().empty() || any_of(GEP.users(), [](const User *U) {
        return !isa<GetElementPtrInst>(U);
      });

  if (ReplaceThisGEP) {
    unsigned BytesPerElem =
        DL.getTypeAllocSize(Info.RootFlattenedArrayType->getArrayElementType());
    assert(isPowerOf2_32(BytesPerElem) &&
           "Bytes per element should be a power of 2");

    // Compute the 32-bit index for this flattened GEP from the constant and
    // variable byte offsets in the GEPInfo
    IRBuilder<> Builder(&GEP);
    Value *ZeroIndex = Builder.getInt32(0);
    uint64_t ConstantOffset =
        Info.ConstantOffset.udiv(BytesPerElem).getZExtValue();
    assert(ConstantOffset <= UINT32_MAX &&
           "Constant byte offset for flat GEP index must fit within 32 bits");
    Value *FlattenedIndex = Builder.getInt32(ConstantOffset);
    // Bind by reference: each entry holds an APInt that need not be copied.
    for (const auto &[VarIndex, Multiplier] : Info.VariableOffsets) {
      assert(Multiplier.getActiveBits() <= 32 &&
             "The multiplier for a flat GEP index must fit within 32 bits");
      assert(VarIndex->getType()->isIntegerTy(32) &&
             "Expected i32-typed GEP indices");
      Value *VI;
      if (Multiplier.getZExtValue() % BytesPerElem != 0) {
        // This can happen, e.g., with i8 GEPs. To handle this we just divide
        // by BytesPerElem using an instruction after multiplying VarIndex by
        // Multiplier.
        VI = Builder.CreateMul(VarIndex,
                               Builder.getInt32(Multiplier.getZExtValue()));
        VI = Builder.CreateLShr(VI, Builder.getInt32(Log2_32(BytesPerElem)));
      } else
        VI = Builder.CreateMul(
            VarIndex,
            Builder.getInt32(Multiplier.getZExtValue() / BytesPerElem));
      FlattenedIndex = Builder.CreateAdd(FlattenedIndex, VI);
    }

    // Construct a new GEP for the flattened array to replace the current GEP
    Value *NewGEP = Builder.CreateGEP(
        Info.RootFlattenedArrayType, Info.RootPointerOperand,
        {ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags());

    // Replace the current GEP with the new GEP. Store GEPInfo into the map
    // for later use in case this GEP was not the end of the chain
    GEPChainInfoMap.insert({cast<GEPOperator>(NewGEP), std::move(Info)});
    GEP.replaceAllUsesWith(NewGEP);
    GEP.eraseFromParent();
    return true;
  }

  // This GEP is potentially dead at the end of the pass since it may not have
  // any users anymore after GEP chains have been collapsed. We retain its
  // GEPInfo for GEPs down the chain to use to compute their indices.
  GEPChainInfoMap.insert({cast<GEPOperator>(&GEP), std::move(Info)});
  PotentiallyDeadInstrs.emplace_back(&GEP);
  return false;
}
361
362
/// Runs the instruction visitor over every block of \p F reachable from the
/// entry, in reverse post-order. A block is reached only after its dominators,
/// so a GEP's parent GEP is seen first. Per-function cleanup runs afterwards
/// via finish().
///
/// \returns true if any instruction visitor reported a change.
bool DXILFlattenArraysVisitor::visit(Function &F) {
  bool Changed = false;
  ReversePostOrderTraversal<Function *> BlockOrder(&F);
  for (BasicBlock *Block : BlockOrder) {
    // Visitors may erase the instruction being visited, so advance first.
    for (Instruction &Inst : make_early_inc_range(*Block)) {
      if (InstVisitor::visit(Inst))
        Changed = true;
    }
  }
  finish();
  return Changed;
}
372
373
static void collectElements(Constant *Init,
374
SmallVectorImpl<Constant *> &Elements) {
375
// Base case: If Init is not an array, add it directly to the vector.
376
auto *ArrayTy = dyn_cast<ArrayType>(Init->getType());
377
if (!ArrayTy) {
378
Elements.push_back(Init);
379
return;
380
}
381
unsigned ArrSize = ArrayTy->getNumElements();
382
if (isa<ConstantAggregateZero>(Init)) {
383
for (unsigned I = 0; I < ArrSize; ++I)
384
Elements.push_back(Constant::getNullValue(ArrayTy->getElementType()));
385
return;
386
}
387
388
// Recursive case: Process each element in the array.
389
if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) {
390
for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) {
391
collectElements(ArrayConstant->getOperand(I), Elements);
392
}
393
} else if (auto *DataArrayConstant = dyn_cast<ConstantDataArray>(Init)) {
394
for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) {
395
collectElements(DataArrayConstant->getElementAsConstant(I), Elements);
396
}
397
} else {
398
llvm_unreachable(
399
"Expected a ConstantArray or ConstantDataArray for array initializer!");
400
}
401
}
402
403
/// Builds the initializer for a flattened global from the original
/// initializer \p Init of type \p OrigType.
///
/// zeroinitializer, poison and undef map directly to the same kind of
/// constant of \p FlattenedType. A non-array \p OrigType returns \p Init
/// unchanged. Otherwise the scalar leaves are collected in row-major order
/// into a ConstantArray of \p FlattenedType. \p Ctx is currently unused.
static Constant *transformInitializer(Constant *Init, Type *OrigType,
                                      ArrayType *FlattenedType,
                                      LLVMContext &Ctx) {
  // Handle ConstantAggregateZero (zero-initialized constants)
  if (isa<ConstantAggregateZero>(Init))
    return ConstantAggregateZero::get(FlattenedType);

  // Check PoisonValue before UndefValue: PoisonValue is a subclass of
  // UndefValue, and a poison initializer should remain poison.
  if (isa<PoisonValue>(Init))
    return PoisonValue::get(FlattenedType);

  // Handle UndefValue (undefined constants)
  if (isa<UndefValue>(Init))
    return UndefValue::get(FlattenedType);

  if (!isa<ArrayType>(OrigType))
    return Init;

  SmallVector<Constant *> FlattenedElements;
  collectElements(Init, FlattenedElements);
  assert(FlattenedType->getNumElements() == FlattenedElements.size() &&
         "The number of collected elements should match the FlattenedType");
  return ConstantArray::get(FlattenedType, FlattenedElements);
}
423
424
static void flattenGlobalArrays(
425
Module &M, SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) {
426
LLVMContext &Ctx = M.getContext();
427
for (GlobalVariable &G : M.globals()) {
428
Type *OrigType = G.getValueType();
429
if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
430
continue;
431
432
ArrayType *ArrType = cast<ArrayType>(OrigType);
433
auto [TotalElements, BaseType] =
434
DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
435
ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements);
436
437
// Create a new global variable with the updated type
438
// Note: Initializer is set via transformInitializer
439
GlobalVariable *NewGlobal =
440
new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(),
441
/*Initializer=*/nullptr, G.getName() + ".1dim", &G,
442
G.getThreadLocalMode(), G.getAddressSpace(),
443
G.isExternallyInitialized());
444
445
// Copy relevant attributes
446
NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
447
if (G.getAlignment() > 0) {
448
NewGlobal->setAlignment(G.getAlign());
449
}
450
451
if (G.hasInitializer()) {
452
Constant *Init = G.getInitializer();
453
Constant *NewInit =
454
transformInitializer(Init, OrigType, FattenedArrayType, Ctx);
455
NewGlobal->setInitializer(NewInit);
456
}
457
GlobalMap[&G] = NewGlobal;
458
}
459
}
460
461
/// Module-level driver shared by both pass managers. It does three things:
/// - flattens multi-dimensional globals;
/// - rewrites every function that has a body;
/// - redirects all remaining uses of each original global to its flattened
///   replacement and erases the original.
///
/// \returns true if the module was modified.
static bool flattenArrays(Module &M) {
  SmallDenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
  flattenGlobalArrays(M, GlobalMap);

  bool Changed = false;
  DXILFlattenArraysVisitor Visitor(GlobalMap);
  for (Function &Fn : make_early_inc_range(M.functions())) {
    if (!Fn.isDeclaration())
      Changed |= Visitor.visit(Fn);
  }

  for (auto &Entry : GlobalMap) {
    GlobalVariable *Original = Entry.first;
    Original->replaceAllUsesWith(Entry.second);
    Original->eraseFromParent();
  }
  // Replacing any global counts as a change.
  return Changed || !GlobalMap.empty();
}
478
479
/// New-pass-manager entry point. All analyses are preserved only when the
/// module was left untouched; otherwise nothing is claimed to be preserved.
PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) {
  const bool Changed = flattenArrays(M);
  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
486
487
/// Legacy-pass-manager entry point; returns whether the module was modified.
bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
  const bool Modified = flattenArrays(M);
  return Modified;
}
490
491
char DXILFlattenArraysLegacy::ID = 0;
492
493
INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE,
494
"DXIL Array Flattener", false, false)
495
INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener",
496
false, false)
497
498
ModulePass *llvm::createDXILFlattenArraysLegacyPass() {
499
return new DXILFlattenArraysLegacy();
500
}
501
502