Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
213799 views
1
//===- DXILDataScalarization.cpp - Perform DXIL Data Legalization ---------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===---------------------------------------------------------------------===//
8
9
#include "DXILDataScalarization.h"
10
#include "DirectX.h"
11
#include "llvm/ADT/PostOrderIterator.h"
12
#include "llvm/ADT/STLExtras.h"
13
#include "llvm/IR/DerivedTypes.h"
14
#include "llvm/IR/GlobalVariable.h"
15
#include "llvm/IR/IRBuilder.h"
16
#include "llvm/IR/InstVisitor.h"
17
#include "llvm/IR/Instructions.h"
18
#include "llvm/IR/Module.h"
19
#include "llvm/IR/Operator.h"
20
#include "llvm/IR/PassManager.h"
21
#include "llvm/IR/ReplaceConstant.h"
22
#include "llvm/IR/Type.h"
23
#include "llvm/Support/Casting.h"
24
#include "llvm/Transforms/Utils/Cloning.h"
25
#include "llvm/Transforms/Utils/Local.h"
26
27
#define DEBUG_TYPE "dxil-data-scalarization"
28
// Inline capacity used for small vectors that hold per-lane values.
static constexpr int MaxVecSize = 4;
29
30
// The rest of this file refers to LLVM IR and support types unqualified.
using namespace llvm;
31
32
// Recursively creates an array-like version of a given vector type.
33
static Type *equivalentArrayTypeFromVector(Type *T) {
34
if (auto *VecTy = dyn_cast<VectorType>(T))
35
return ArrayType::get(VecTy->getElementType(),
36
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
37
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
38
Type *NewElementType =
39
equivalentArrayTypeFromVector(ArrayTy->getElementType());
40
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
41
}
42
// If it's not a vector or array, return the original type.
43
return T;
44
}
45
46
class DXILDataScalarizationLegacy : public ModulePass {
47
48
public:
49
bool runOnModule(Module &M) override;
50
DXILDataScalarizationLegacy() : ModulePass(ID) {}
51
52
static char ID; // Pass identification.
53
};
54
55
static bool findAndReplaceVectors(Module &M);
56
57
/// Instruction visitor that rewrites vector-typed allocas, accesses to
/// vector-typed globals, and dynamically indexed insert/extractelement
/// instructions into array-based equivalents.
class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
public:
  DataScalarizerVisitor() = default;

  /// Visits every instruction of \p F; returns true if anything was rewritten.
  bool visit(Function &F);

  // InstVisitor hooks. Each returns true when it rewrote the instruction and
  // false when the IR was left unchanged.
  bool visitAllocaInst(AllocaInst &AI);
  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
  bool visitInsertElementInst(InsertElementInst &IEI);
  bool visitExtractElementInst(ExtractElementInst &EEI);
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);

  // Instruction kinds this visitor deliberately leaves untouched.
  bool visitInstruction(Instruction &I) { return false; }
  bool visitSelectInst(SelectInst &SI) { return false; }
  bool visitICmpInst(ICmpInst &ICI) { return false; }
  bool visitFCmpInst(FCmpInst &FCI) { return false; }
  bool visitUnaryOperator(UnaryOperator &UO) { return false; }
  bool visitBinaryOperator(BinaryOperator &BO) { return false; }
  bool visitCastInst(CastInst &CI) { return false; }
  bool visitBitCastInst(BitCastInst &BCI) { return false; }
  bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
  bool visitPHINode(PHINode &PHI) { return false; }
  bool visitCallInst(CallInst &Call) { return false; }
  bool visitFreezeInst(FreezeInst &FI) { return false; }

  friend bool findAndReplaceVectors(llvm::Module &M);

private:
  // An array alloca standing in for a vector value, plus one GEP per element.
  using AllocaAndGEPs = std::pair<AllocaInst *, SmallVector<Value *, 4>>;
  // Maps a vector-typed value to its array alloca and per-element GEPs.
  using VectorToArrayMap = SmallDenseMap<Value *, AllocaAndGEPs>;

  VectorToArrayMap VectorAllocaMap;

  AllocaAndGEPs createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
                                      const Twine &Name);
  bool replaceDynamicInsertElementInst(InsertElementInst &IEI);
  bool replaceDynamicExtractElementInst(ExtractElementInst &EEI);

  GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);

  // Original vector-typed global -> its array-typed replacement.
  DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
};
97
98
bool DataScalarizerVisitor::visit(Function &F) {
  bool Changed = false;

  // Walk the blocks in reverse post-order. Both loops advance their iterator
  // before the body runs, so a hook may erase the instruction it is handed.
  ReversePostOrderTraversal<Function *> BlockOrder(&F);
  for (BasicBlock *Block : make_early_inc_range(BlockOrder)) {
    for (Instruction &Inst : make_early_inc_range(*Block)) {
      const bool Rewrote = InstVisitor::visit(Inst);
      Changed = Changed || Rewrote;
    }
  }

  // The vector->array cache refers to allocas created inside F; start the
  // next function with an empty cache.
  VectorAllocaMap.clear();
  return Changed;
}
108
109
GlobalVariable *
DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
  // Only global variables can have a replacement registered.
  auto *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand);
  if (!OldGlobal)
    return nullptr;
  // DenseMap::lookup yields a default-constructed (null) pointer when the
  // global has no replacement.
  return GlobalMap.lookup(OldGlobal);
}
119
120
// Helper function to check if a type is a vector or an array of vectors
121
static bool isVectorOrArrayOfVectors(Type *T) {
122
if (isa<VectorType>(T))
123
return true;
124
if (ArrayType *ArrType = dyn_cast<ArrayType>(T))
125
return isa<VectorType>(ArrType->getElementType()) ||
126
isVectorOrArrayOfVectors(ArrType->getElementType());
127
return false;
128
}
129
130
bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) {
  Type *OrigTy = AI.getAllocatedType();
  // Allocas that hold no vectors are left alone.
  if (!isVectorOrArrayOfVectors(OrigTy))
    return false;

  // Allocate the array-typed equivalent at the same position with the same
  // alignment, then redirect every user of the old alloca to it.
  IRBuilder<> Builder(&AI);
  Type *ArrayTy = equivalentArrayTypeFromVector(OrigTy);
  AllocaInst *Replacement =
      Builder.CreateAlloca(ArrayTy, nullptr, AI.getName() + ".scalarize");
  Replacement->setAlignment(AI.getAlign());

  AI.replaceAllUsesWith(Replacement);
  AI.eraseFromParent();
  return true;
}
144
145
// Redirects a load whose address is derived from a scalarized global.
// Returns true whenever the load's pointer operand was changed.
bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
  Value *PtrOperand = LI.getPointerOperand();

  ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
  if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
    // A constant-expression GEP cannot be retyped in place. Materialize it as
    // an instruction right before the load, and let visitGetElementPtrInst
    // rewrite it against any replacement global.
    auto *GEP = cast<GetElementPtrInst>(CE->getAsInstruction());
    GEP->insertBefore(LI.getIterator());
    // Repoint the existing load rather than recreating it. A fresh CreateLoad
    // would drop volatility, atomic ordering and metadata of the original.
    LI.setOperand(LI.getPointerOperandIndex(), GEP);
    visitGetElementPtrInst(*GEP);
    return true;
  }

  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
    LI.setOperand(LI.getPointerOperandIndex(), NewGlobal);
    // The IR changed, so report it per the visitor's contract.
    return true;
  }
  return false;
}
163
164
// Redirects a store whose address is derived from a scalarized global.
// Returns true whenever the store's pointer operand was changed.
bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
  Value *PtrOperand = SI.getPointerOperand();

  ConstantExpr *CE = dyn_cast<ConstantExpr>(PtrOperand);
  if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
    // A constant-expression GEP cannot be retyped in place. Materialize it as
    // an instruction right before the store, and let visitGetElementPtrInst
    // rewrite it against any replacement global.
    auto *GEP = cast<GetElementPtrInst>(CE->getAsInstruction());
    GEP->insertBefore(SI.getIterator());
    // Repoint the existing store rather than recreating it. A fresh
    // CreateStore would drop volatility, atomic ordering and metadata.
    SI.setOperand(SI.getPointerOperandIndex(), GEP);
    visitGetElementPtrInst(*GEP);
    return true;
  }

  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
    SI.setOperand(SI.getPointerOperandIndex(), NewGlobal);
    // The IR changed, so report it per the visitor's contract.
    return true;
  }
  return false;
}
184
185
// Returns an array alloca mirroring the vector Vec, plus one GEP per element.
// On first use for Vec it:
//   * allocates the array after the entry block's allocas;
//   * stores every lane of Vec into it right after Vec's definition (or right
//     after the alloca when Vec is not an instruction);
//   * caches the result for later requests within the current function.
// The builder's insertion point and debug location are restored on return.
DataScalarizerVisitor::AllocaAndGEPs
DataScalarizerVisitor::createArrayFromVector(IRBuilder<> &Builder, Value *Vec,
                                             const Twine &Name = "") {
  // Reuse the array already materialized for this vector, if any (single
  // lookup instead of contains() followed by operator[]).
  auto Cached = VectorAllocaMap.find(Vec);
  if (Cached != VectorAllocaMap.end())
    return Cached->second;

  // The code below moves the builder around; put it back on every exit path.
  IRBuilderBase::InsertPointGuard Guard(Builder);

  // Allocate the array to hold the vector elements.
  Builder.SetInsertPointPastAllocas(Builder.GetInsertBlock()->getParent());
  Type *ArrTy = equivalentArrayTypeFromVector(Vec->getType());
  AllocaInst *ArrAlloca =
      Builder.CreateAlloca(ArrTy, nullptr, Name + ".alloca");
  const uint64_t ArrNumElems = ArrTy->getArrayNumElements();

  // Populate the array right after the vector's definition when it is an
  // instruction. getInsertionPointAfterDef() skips over the whole PHI group
  // when Vec is a PHI, so no extractelement ends up among PHI nodes.
  if (auto *Instr = dyn_cast<Instruction>(Vec)) {
    auto AfterDef = Instr->getInsertionPointAfterDef();
    assert(AfterDef && "no valid insertion point after the vector definition");
    Builder.SetInsertPoint(*AfterDef);
  }

  SmallVector<Value *, 4> GEPs(ArrNumElems);
  for (unsigned I = 0; I < ArrNumElems; ++I) {
    Value *EE = Builder.CreateExtractElement(Vec, I, Name + ".extract");
    GEPs[I] = Builder.CreateInBoundsGEP(
        ArrTy, ArrAlloca, {Builder.getInt32(0), Builder.getInt32(I)},
        Name + ".index");
    Builder.CreateStore(EE, GEPs[I]);
  }

  VectorAllocaMap.insert({Vec, {ArrAlloca, GEPs}});
  return {ArrAlloca, GEPs};
}
219
220
/// Returns a pair of Value* with the first being a GEP into ArrAlloca using
221
/// indices {0, Index}, and the second Value* being a Load of the GEP
222
static std::pair<Value *, Value *>
223
dynamicallyLoadArray(IRBuilder<> &Builder, AllocaInst *ArrAlloca, Value *Index,
224
const Twine &Name = "") {
225
Type *ArrTy = ArrAlloca->getAllocatedType();
226
Value *GEP = Builder.CreateInBoundsGEP(
227
ArrTy, ArrAlloca, {Builder.getInt32(0), Index}, Name + ".index");
228
Value *Load =
229
Builder.CreateLoad(ArrTy->getArrayElementType(), GEP, Name + ".load");
230
return std::make_pair(GEP, Load);
231
}
232
233
bool DataScalarizerVisitor::replaceDynamicInsertElementInst(
234
InsertElementInst &IEI) {
235
IRBuilder<> Builder(&IEI);
236
237
Value *Vec = IEI.getOperand(0);
238
Value *Val = IEI.getOperand(1);
239
Value *Index = IEI.getOperand(2);
240
241
AllocaAndGEPs ArrAllocaAndGEPs =
242
createArrayFromVector(Builder, Vec, IEI.getName());
243
AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;
244
Type *ArrTy = ArrAlloca->getAllocatedType();
245
SmallVector<Value *, 4> &ArrGEPs = ArrAllocaAndGEPs.second;
246
247
auto GEPAndLoad =
248
dynamicallyLoadArray(Builder, ArrAlloca, Index, IEI.getName());
249
Value *GEP = GEPAndLoad.first;
250
Value *Load = GEPAndLoad.second;
251
252
Builder.CreateStore(Val, GEP);
253
Value *NewIEI = PoisonValue::get(Vec->getType());
254
for (unsigned I = 0; I < ArrTy->getArrayNumElements(); ++I) {
255
Value *Load = Builder.CreateLoad(ArrTy->getArrayElementType(), ArrGEPs[I],
256
IEI.getName() + ".load");
257
NewIEI = Builder.CreateInsertElement(NewIEI, Load, Builder.getInt32(I),
258
IEI.getName() + ".insert");
259
}
260
261
// Store back the original value so the Alloca can be reused for subsequent
262
// insertelement instructions on the same vector
263
Builder.CreateStore(Load, GEP);
264
265
IEI.replaceAllUsesWith(NewIEI);
266
IEI.eraseFromParent();
267
return true;
268
}
269
270
bool DataScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
  // A constant lane index needs no scalarization; only dynamic indices are
  // routed through the array mirror.
  const bool HasConstantIndex = isa<ConstantInt>(IEI.getOperand(2));
  return !HasConstantIndex && replaceDynamicInsertElementInst(IEI);
}
277
278
bool DataScalarizerVisitor::replaceDynamicExtractElementInst(
279
ExtractElementInst &EEI) {
280
IRBuilder<> Builder(&EEI);
281
282
AllocaAndGEPs ArrAllocaAndGEPs =
283
createArrayFromVector(Builder, EEI.getVectorOperand(), EEI.getName());
284
AllocaInst *ArrAlloca = ArrAllocaAndGEPs.first;
285
286
auto GEPAndLoad = dynamicallyLoadArray(Builder, ArrAlloca,
287
EEI.getIndexOperand(), EEI.getName());
288
Value *Load = GEPAndLoad.second;
289
290
EEI.replaceAllUsesWith(Load);
291
EEI.eraseFromParent();
292
return true;
293
}
294
295
bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
  // A constant lane index needs no scalarization; only dynamic indices are
  // routed through the array mirror.
  const bool HasConstantIndex = isa<ConstantInt>(EEI.getIndexOperand());
  return !HasConstantIndex && replaceDynamicExtractElementInst(EEI);
}
302
303
// Rebuilds a GEP whose base was scalarized (a replaced global, or an alloca
// whose allocated type is now an array) so that it indexes the new type.
// GEPs based on anything else are left untouched.
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
  Type *const SrcElemTy = GEPI.getSourceElementType();
  Value *Base = GEPI.getPointerOperand();
  Type *RetypedSrcTy = SrcElemTy;
  bool Rewrite = false;

  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(Base)) {
    // Index into the replacement global instead of the original.
    Base = NewGlobal;
    RetypedSrcTy = NewGlobal->getValueType();
    Rewrite = true;
  } else if (auto *Alloca = dyn_cast<AllocaInst>(Base)) {
    // Only retarget to the alloca's type when that type is an array.
    Type *AllocTy = Alloca->getAllocatedType();
    if (isa<ArrayType>(AllocTy) && AllocTy != SrcElemTy) {
      RetypedSrcTy = AllocTy;
      Rewrite = true;
    }
  }

  // Bail if this GEP was not touched by the alloca or global transformations.
  if (!Rewrite)
    return false;

  // Scalar GEPs stay scalar GEPs; the dxil-flatten-arrays pass converts them
  // into flattened array GEPs later.
  if (!isa<ArrayType>(SrcElemTy))
    RetypedSrcTy = SrcElemTy;

  IRBuilder<> Builder(&GEPI);
  SmallVector<Value *, MaxVecSize> Indices(GEPI.indices());
  Value *Replacement = Builder.CreateGEP(RetypedSrcTy, Base, Indices,
                                         GEPI.getName(), GEPI.getNoWrapFlags());
  GEPI.replaceAllUsesWith(Replacement);
  GEPI.eraseFromParent();
  return true;
}
341
342
// Converts the initializer Init of a global of type OrigType into an
// equivalent initializer of type NewType, which is expected to be
// equivalentArrayTypeFromVector(OrigType).
// - Vectors become arrays with the same elements.
// - Arrays are converted element by element.
// - Zero, poison and undef initializers keep their kind.
// Any other initializer is returned unchanged.
static Constant *transformInitializer(Constant *Init, Type *OrigType,
                                      Type *NewType, LLVMContext &Ctx) {
  // Zero-initializers carry no per-element data; rebuild for the new type.
  if (isa<ConstantAggregateZero>(Init))
    return ConstantAggregateZero::get(NewType);

  // PoisonValue derives from UndefValue, so test it first; otherwise a poison
  // initializer would be silently weakened to undef.
  if (isa<PoisonValue>(Init))
    return PoisonValue::get(NewType);
  if (isa<UndefValue>(Init))
    return UndefValue::get(NewType);

  // Vector -> array: copy the lanes over in order.
  if (isa<VectorType>(OrigType) && isa<ArrayType>(NewType)) {
    SmallVector<Constant *, MaxVecSize> ArrayElements;
    if (auto *ConstVecInit = dyn_cast<ConstantVector>(Init)) {
      for (unsigned I = 0; I < ConstVecInit->getNumOperands(); ++I)
        ArrayElements.push_back(ConstVecInit->getOperand(I));
    } else if (auto *ConstDataVecInit = dyn_cast<ConstantDataVector>(Init)) {
      for (unsigned I = 0; I < ConstDataVecInit->getNumElements(); ++I)
        ArrayElements.push_back(ConstDataVecInit->getElementAsConstant(I));
    } else {
      // assert(false) compiles away in release builds and would fall through
      // to build an array with the wrong number of elements.
      llvm_unreachable("Expected a ConstantVector or ConstantDataVector for "
                       "vector initializer!");
    }
    return ConstantArray::get(cast<ArrayType>(NewType), ArrayElements);
  }

  // Array (possibly of arrays) of vectors: convert each element recursively.
  if (auto *ArrayTy = dyn_cast<ArrayType>(OrigType)) {
    auto *ArrayInit = dyn_cast<ConstantArray>(Init);
    assert(ArrayInit && "Expected a ConstantArray for array initializer!");
    auto *NewArrayTy = cast<ArrayType>(NewType);

    SmallVector<Constant *, MaxVecSize> NewArrayElements;
    NewArrayElements.reserve(ArrayTy->getNumElements());
    for (unsigned I = 0; I < ArrayTy->getNumElements(); ++I)
      NewArrayElements.push_back(transformInitializer(
          ArrayInit->getOperand(I), ArrayTy->getElementType(),
          NewArrayTy->getElementType(), Ctx));

    return ConstantArray::get(NewArrayTy, NewArrayElements);
  }

  // Not a vector or array: the initializer needs no conversion.
  return Init;
}
393
394
// Scalarizes vector data in M. Three phases:
//   1. Create an array-typed twin for every global whose value type contains
//      vectors.
//   2. Rewrite every function body through DataScalarizerVisitor.
//   3. Erase the original globals.
// Returns true if the module was modified.
static bool findAndReplaceVectors(Module &M) {
  bool MadeChange = false;
  LLVMContext &Ctx = M.getContext();
  DataScalarizerVisitor Impl;

  for (GlobalVariable &G : M.globals()) {
    Type *OrigType = G.getValueType();
    Type *NewType = equivalentArrayTypeFromVector(OrigType);
    if (OrigType == NewType)
      continue;

    // The new global is inserted before G (InsertBefore = &G); its initializer
    // is filled in below via transformInitializer.
    GlobalVariable *NewGlobal = new GlobalVariable(
        M, NewType, G.isConstant(), G.getLinkage(),
        /*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
        G.getThreadLocalMode(), G.getAddressSpace(),
        G.isExternallyInitialized());

    // Copy relevant attributes. Copying an absent alignment is a no-op, so no
    // guard is needed.
    NewGlobal->setUnnamedAddr(G.getUnnamedAddr());
    NewGlobal->setAlignment(G.getAlign());

    if (G.hasInitializer())
      NewGlobal->setInitializer(
          transformInitializer(G.getInitializer(), OrigType, NewType, Ctx));

    // G.replaceAllUsesWith(NewGlobal) would require identical types, so uses
    // are rewritten by the visitor instead.
    Impl.GlobalMap[&G] = NewGlobal;
  }

  for (auto &F : make_early_inc_range(M.functions())) {
    if (F.isDeclaration())
      continue;
    MadeChange |= Impl.visit(F);
  }

  // Remove the old globals after the iteration. The load/store rewrites can
  // leave constant-expression GEPs with no users that still reference the old
  // global; drop those first so they do not linger as uses of the erased
  // global.
  for (auto &[Old, New] : Impl.GlobalMap) {
    Old->removeDeadConstantUsers();
    Old->eraseFromParent();
    MadeChange = true;
  }
  return MadeChange;
}
443
444
PreservedAnalyses DXILDataScalarization::run(Module &M,
                                             ModuleAnalysisManager &) {
  // Any rewrite invalidates everything; an untouched module keeps all results.
  const bool Changed = findAndReplaceVectors(M);
  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
452
453
bool DXILDataScalarizationLegacy::runOnModule(Module &M) {
  // The legacy wrapper forwards to the shared implementation.
  const bool Changed = findAndReplaceVectors(M);
  return Changed;
}
456
457
char DXILDataScalarizationLegacy::ID = 0;
458
459
INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE,
460
"DXIL Data Scalarization", false, false)
461
INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE,
462
"DXIL Data Scalarization", false, false)
463
464
ModulePass *llvm::createDXILDataScalarizationLegacyPass() {
465
return new DXILDataScalarizationLegacy();
466
}
467
468