Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
213799 views
1
//===-- SPIRVLegalizePointerCast.cpp ----------------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// The LLVM IR has multiple legal patterns we cannot lower to Logical SPIR-V.
10
// This pass rewrites such loads and stores into IR we can directly lower to valid
11
// logical SPIR-V.
12
// OpenCL can avoid this because it relies on ptrcast, which is not supported
13
// by logical SPIR-V.
14
//
15
// This pass relies on the assign_ptr_type intrinsic to deduce the type of the
16
// pointed values, and must replace all occurrences of `ptrcast`. This is why
17
// unhandled cases are reported as unreachable: we MUST cover all cases.
18
//
19
// 1. Loading the first element of an array
20
//
21
// %array = [10 x i32]
22
// %value = load i32, ptr %array
23
//
24
// LLVM can skip the GEP instruction, and only request loading the first 4
25
// bytes. In logical SPIR-V, we need an OpAccessChain to access the first
26
// element. This pass will add a getelementptr instruction before the load.
27
//
28
//
29
// 2. Implicit downcast from load
30
//
31
// %1 = getelementptr <4 x i32>, ptr %vec4, i64 0
32
// %2 = load <3 x i32>, ptr %1
33
//
34
// The pointer in the GEP instruction is only used for offset computations,
35
// but it doesn't NEED to match the pointed type. OpAccessChain however
36
// requires this. Also, LLVM loads define the bitwidth of the load, not the
37
// pointer. In this example, we can guess %vec4 is a vec4 thanks to the GEP
38
// instruction basetype, but we only want to load the first 3 elements, hence
39
// do a partial load. In logical SPIR-V, this is not legal. What we must do
40
// is load the full vector (basetype), extract 3 elements, and recombine them
41
// to form a 3-element vector.
42
//
43
//===----------------------------------------------------------------------===//
44
45
#include "SPIRV.h"
46
#include "SPIRVSubtarget.h"
47
#include "SPIRVTargetMachine.h"
48
#include "SPIRVUtils.h"
49
#include "llvm/CodeGen/IntrinsicLowering.h"
50
#include "llvm/IR/IRBuilder.h"
51
#include "llvm/IR/IntrinsicInst.h"
52
#include "llvm/IR/Intrinsics.h"
53
#include "llvm/IR/IntrinsicsSPIRV.h"
54
#include "llvm/Transforms/Utils/Cloning.h"
55
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
56
57
using namespace llvm;
58
59
namespace {
60
class SPIRVLegalizePointerCast : public FunctionPass {
61
62
// Builds the `spv_assign_type` assigning |Ty| to |Value| at the current
63
// builder position.
64
void buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg) {
65
Value *OfType = PoisonValue::get(Ty);
66
CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type,
67
{Arg->getType()}, OfType, Arg, {}, B);
68
GR->addAssignPtrTypeInstr(Arg, AssignCI);
69
}
70
71
// Loads parts of the vector of type |SourceType| from the pointer |Source|
72
// and create a new vector of type |TargetType|. |TargetType| must be a vector
73
// type, and element types of |TargetType| and |SourceType| must match.
74
// Returns the loaded value.
75
Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
76
FixedVectorType *TargetType, Value *Source) {
77
// We expect the codegen to avoid doing implicit bitcast from a load.
78
assert(TargetType->getElementType() == SourceType->getElementType());
79
assert(TargetType->getNumElements() < SourceType->getNumElements());
80
81
LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
82
buildAssignType(B, SourceType, NewLoad);
83
84
SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
85
for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
86
Mask[I] = I;
87
Value *Output = B.CreateShuffleVector(NewLoad, NewLoad, Mask);
88
buildAssignType(B, TargetType, Output);
89
return Output;
90
}
91
92
// Loads the first value in an aggregate pointed by |Source| of containing
93
// elements of type |ElementType|. Load flags will be copied from |BadLoad|,
94
// which should be the load being legalized. Returns the loaded value.
95
Value *loadFirstValueFromAggregate(IRBuilder<> &B, Type *ElementType,
96
Value *Source, LoadInst *BadLoad) {
97
SmallVector<Type *, 2> Types = {BadLoad->getPointerOperandType(),
98
BadLoad->getPointerOperandType()};
99
SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(false), Source,
100
B.getInt32(0), B.getInt32(0)};
101
auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
102
GR->buildAssignPtr(B, ElementType, GEP);
103
104
LoadInst *LI = B.CreateLoad(ElementType, GEP);
105
LI->setAlignment(BadLoad->getAlign());
106
buildAssignType(B, ElementType, LI);
107
return LI;
108
}
109
110
// Replaces the load instruction to get rid of the ptrcast used as source
111
// operand.
112
void transformLoad(IRBuilder<> &B, LoadInst *LI, Value *CastedOperand,
113
Value *OriginalOperand) {
114
Type *FromTy = GR->findDeducedElementType(OriginalOperand);
115
Type *ToTy = GR->findDeducedElementType(CastedOperand);
116
Value *Output = nullptr;
117
118
auto *SAT = dyn_cast<ArrayType>(FromTy);
119
auto *SVT = dyn_cast<FixedVectorType>(FromTy);
120
auto *SST = dyn_cast<StructType>(FromTy);
121
auto *DVT = dyn_cast<FixedVectorType>(ToTy);
122
123
B.SetInsertPoint(LI);
124
125
// Destination is the element type of Source, and source is an array ->
126
// Loading 1st element.
127
// - float a = array[0];
128
if (SAT && SAT->getElementType() == ToTy)
129
Output = loadFirstValueFromAggregate(B, SAT->getElementType(),
130
OriginalOperand, LI);
131
// Destination is the element type of Source, and source is a vector ->
132
// Vector to scalar.
133
// - float a = vector.x;
134
else if (!DVT && SVT && SVT->getElementType() == ToTy) {
135
Output = loadFirstValueFromAggregate(B, SVT->getElementType(),
136
OriginalOperand, LI);
137
}
138
// Destination is a smaller vector than source.
139
// - float3 v3 = vector4;
140
else if (SVT && DVT)
141
Output = loadVectorFromVector(B, SVT, DVT, OriginalOperand);
142
// Destination is the scalar type stored at the start of an aggregate.
143
// - struct S { float m };
144
// - float v = s.m;
145
else if (SST && SST->getTypeAtIndex(0u) == ToTy)
146
Output = loadFirstValueFromAggregate(B, ToTy, OriginalOperand, LI);
147
else
148
llvm_unreachable("Unimplemented implicit down-cast from load.");
149
150
GR->replaceAllUsesWith(LI, Output, /* DeleteOld= */ true);
151
DeadInstructions.push_back(LI);
152
}
153
154
// Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
155
Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
156
unsigned Index) {
157
Type *Int32Ty = Type::getInt32Ty(B.getContext());
158
SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),
159
Element->getType(), Int32Ty};
160
SmallVector<Value *> Args = {Vector, Element, B.getInt32(Index)};
161
Instruction *NewI =
162
B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
163
buildAssignType(B, Vector->getType(), NewI);
164
return NewI;
165
}
166
167
// Creates an spv_extractelt instruction (equivalent to llvm's
168
// extractelement).
169
Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
170
unsigned Index) {
171
Type *Int32Ty = Type::getInt32Ty(B.getContext());
172
SmallVector<Type *, 3> Types = {ElementType, Vector->getType(), Int32Ty};
173
SmallVector<Value *> Args = {Vector, B.getInt32(Index)};
174
Instruction *NewI =
175
B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
176
buildAssignType(B, ElementType, NewI);
177
return NewI;
178
}
179
180
// Stores the given Src vector operand into the Dst vector, adjusting the size
181
// if required.
182
Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
183
Align Alignment) {
184
FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
185
FixedVectorType *DstType =
186
cast<FixedVectorType>(GR->findDeducedElementType(Dst));
187
assert(DstType->getNumElements() >= SrcType->getNumElements());
188
189
LoadInst *LI = B.CreateLoad(DstType, Dst);
190
LI->setAlignment(Alignment);
191
Value *OldValues = LI;
192
buildAssignType(B, OldValues->getType(), OldValues);
193
Value *NewValues = Src;
194
195
for (unsigned I = 0; I < SrcType->getNumElements(); ++I) {
196
Value *Element =
197
makeExtractElement(B, SrcType->getElementType(), NewValues, I);
198
OldValues = makeInsertElement(B, OldValues, Element, I);
199
}
200
201
StoreInst *SI = B.CreateStore(OldValues, Dst);
202
SI->setAlignment(Alignment);
203
return SI;
204
}
205
206
void buildGEPIndexChain(IRBuilder<> &B, Type *Search, Type *Aggregate,
207
SmallVectorImpl<Value *> &Indices) {
208
Indices.push_back(B.getInt32(0));
209
210
if (Search == Aggregate)
211
return;
212
213
if (auto *ST = dyn_cast<StructType>(Aggregate))
214
buildGEPIndexChain(B, Search, ST->getTypeAtIndex(0u), Indices);
215
else if (auto *AT = dyn_cast<ArrayType>(Aggregate))
216
buildGEPIndexChain(B, Search, AT->getElementType(), Indices);
217
else if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
218
buildGEPIndexChain(B, Search, VT->getElementType(), Indices);
219
else
220
llvm_unreachable("Bad access chain?");
221
}
222
223
// Stores the given Src value into the first entry of the Dst aggregate.
224
Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
225
Type *DstPointeeType, Align Alignment) {
226
SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
227
SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst};
228
buildGEPIndexChain(B, Src->getType(), DstPointeeType, Args);
229
auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
230
GR->buildAssignPtr(B, Src->getType(), GEP);
231
StoreInst *SI = B.CreateStore(Src, GEP);
232
SI->setAlignment(Alignment);
233
return SI;
234
}
235
236
bool isTypeFirstElementAggregate(Type *Search, Type *Aggregate) {
237
if (Search == Aggregate)
238
return true;
239
if (auto *ST = dyn_cast<StructType>(Aggregate))
240
return isTypeFirstElementAggregate(Search, ST->getTypeAtIndex(0u));
241
if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
242
return isTypeFirstElementAggregate(Search, VT->getElementType());
243
if (auto *AT = dyn_cast<ArrayType>(Aggregate))
244
return isTypeFirstElementAggregate(Search, AT->getElementType());
245
return false;
246
}
247
248
// Transforms a store instruction (or SPV intrinsic) using a ptrcast as
249
// operand into a valid logical SPIR-V store with no ptrcast.
250
void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
251
Value *Dst, Align Alignment) {
252
Type *ToTy = GR->findDeducedElementType(Dst);
253
Type *FromTy = Src->getType();
254
255
auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
256
auto *D_ST = dyn_cast<StructType>(ToTy);
257
auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
258
259
B.SetInsertPoint(BadStore);
260
if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))
261
storeToFirstValueAggregate(B, Src, Dst, D_ST, Alignment);
262
else if (D_VT && S_VT)
263
storeVectorFromVector(B, Src, Dst, Alignment);
264
else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
265
storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
266
else
267
llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
268
269
DeadInstructions.push_back(BadStore);
270
}
271
272
void legalizePointerCast(IntrinsicInst *II) {
273
Value *CastedOperand = II;
274
Value *OriginalOperand = II->getOperand(0);
275
276
IRBuilder<> B(II->getContext());
277
std::vector<Value *> Users;
278
for (Use &U : II->uses())
279
Users.push_back(U.getUser());
280
281
for (Value *User : Users) {
282
if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
283
transformLoad(B, LI, CastedOperand, OriginalOperand);
284
continue;
285
}
286
287
if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
288
transformStore(B, SI, SI->getValueOperand(), OriginalOperand,
289
SI->getAlign());
290
continue;
291
}
292
293
if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(User)) {
294
if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
295
DeadInstructions.push_back(Intrin);
296
continue;
297
}
298
299
if (Intrin->getIntrinsicID() == Intrinsic::spv_gep) {
300
GR->replaceAllUsesWith(CastedOperand, OriginalOperand,
301
/* DeleteOld= */ false);
302
continue;
303
}
304
305
if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
306
Align Alignment;
307
if (ConstantInt *C = dyn_cast<ConstantInt>(Intrin->getOperand(3)))
308
Alignment = Align(C->getZExtValue());
309
transformStore(B, Intrin, Intrin->getArgOperand(0), OriginalOperand,
310
Alignment);
311
continue;
312
}
313
}
314
315
llvm_unreachable("Unsupported ptrcast user. Please fix.");
316
}
317
318
DeadInstructions.push_back(II);
319
}
320
321
public:
322
SPIRVLegalizePointerCast(SPIRVTargetMachine *TM) : FunctionPass(ID), TM(TM) {}
323
324
virtual bool runOnFunction(Function &F) override {
325
const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);
326
GR = ST.getSPIRVGlobalRegistry();
327
DeadInstructions.clear();
328
329
std::vector<IntrinsicInst *> WorkList;
330
for (auto &BB : F) {
331
for (auto &I : BB) {
332
auto *II = dyn_cast<IntrinsicInst>(&I);
333
if (II && II->getIntrinsicID() == Intrinsic::spv_ptrcast)
334
WorkList.push_back(II);
335
}
336
}
337
338
for (IntrinsicInst *II : WorkList)
339
legalizePointerCast(II);
340
341
for (Instruction *I : DeadInstructions)
342
I->eraseFromParent();
343
344
return DeadInstructions.size() != 0;
345
}
346
347
private:
348
SPIRVTargetMachine *TM = nullptr;
349
SPIRVGlobalRegistry *GR = nullptr;
350
std::vector<Instruction *> DeadInstructions;
351
352
public:
353
static char ID;
354
};
355
} // namespace
356
357
char SPIRVLegalizePointerCast::ID = 0;
358
INITIALIZE_PASS(SPIRVLegalizePointerCast, "spirv-legalize-bitcast",
359
"SPIRV legalize bitcast pass", false, false)
360
361
// Factory used by the SPIR-V target to add this pass to its pipeline. The
// returned pass is owned by the caller (the pass manager it is added to).
FunctionPass *llvm::createSPIRVLegalizePointerCastPass(SPIRVTargetMachine *TM) {
  auto *Legalizer = new SPIRVLegalizePointerCast(TM);
  return Legalizer;
}
364
365