Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
35294 views
1
//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file implements the SPIRVTargetLowering class.
10
//
11
//===----------------------------------------------------------------------===//
12
13
#include "SPIRVISelLowering.h"
14
#include "SPIRV.h"
15
#include "SPIRVInstrInfo.h"
16
#include "SPIRVRegisterBankInfo.h"
17
#include "SPIRVRegisterInfo.h"
18
#include "SPIRVSubtarget.h"
19
#include "SPIRVTargetMachine.h"
20
#include "llvm/CodeGen/MachineInstrBuilder.h"
21
#include "llvm/CodeGen/MachineRegisterInfo.h"
22
#include "llvm/IR/IntrinsicsSPIRV.h"
23
24
#define DEBUG_TYPE "spirv-lower"
25
26
using namespace llvm;
27
28
unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
29
LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
30
// This code avoids CallLowering fail inside getVectorTypeBreakdown
31
// on v3i1 arguments. Maybe we need to return 1 for all types.
32
// TODO: remove it once this case is supported by the default implementation.
33
if (VT.isVector() && VT.getVectorNumElements() == 3 &&
34
(VT.getVectorElementType() == MVT::i1 ||
35
VT.getVectorElementType() == MVT::i8))
36
return 1;
37
if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
38
return 1;
39
return getNumRegisters(Context, VT);
40
}
41
42
// Register type used to pass a value of type VT under calling convention CC.
// 3-element i1/i8 vectors are widened to their 4-element counterparts; all
// other types defer to the generic getRegisterType().
MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
                                                       CallingConv::ID CC,
                                                       EVT VT) const {
  // This code avoids CallLowering fail inside getVectorTypeBreakdown
  // on v3i1 arguments. Maybe we need to return i32 for all types.
  // TODO: remove it once this case is supported by the default implementation.
  if (VT.isVector() && VT.getVectorNumElements() == 3) {
    const EVT Elt = VT.getVectorElementType();
    if (Elt == MVT::i1 || Elt == MVT::i8)
      return Elt == MVT::i1 ? MVT::v4i1 : MVT::v4i8;
  }
  return getRegisterType(Context, VT);
}
56
57
bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
58
const CallInst &I,
59
MachineFunction &MF,
60
unsigned Intrinsic) const {
61
unsigned AlignIdx = 3;
62
switch (Intrinsic) {
63
case Intrinsic::spv_load:
64
AlignIdx = 2;
65
[[fallthrough]];
66
case Intrinsic::spv_store: {
67
if (I.getNumOperands() >= AlignIdx + 1) {
68
auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
69
Info.align = Align(AlignOp->getZExtValue());
70
}
71
Info.flags = static_cast<MachineMemOperand::Flags>(
72
cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
73
Info.memVT = MVT::i64;
74
// TODO: take into account opaque pointers (don't use getElementType).
75
// MVT::getVT(PtrTy->getElementType());
76
return true;
77
break;
78
}
79
default:
80
break;
81
}
82
return false;
83
}
84
85
std::pair<unsigned, const TargetRegisterClass *>
86
SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
87
StringRef Constraint,
88
MVT VT) const {
89
const TargetRegisterClass *RC = nullptr;
90
if (Constraint.starts_with("{"))
91
return std::make_pair(0u, RC);
92
93
if (VT.isFloatingPoint())
94
RC = VT.isVector() ? &SPIRV::vfIDRegClass
95
: (VT.getScalarSizeInBits() > 32 ? &SPIRV::fID64RegClass
96
: &SPIRV::fIDRegClass);
97
else if (VT.isInteger())
98
RC = VT.isVector() ? &SPIRV::vIDRegClass
99
: (VT.getScalarSizeInBits() > 32 ? &SPIRV::ID64RegClass
100
: &SPIRV::IDRegClass);
101
else
102
RC = &SPIRV::IDRegClass;
103
104
return std::make_pair(0u, RC);
105
}
106
107
// If OpReg is defined by an OpFunctionParameter, return the register held in
// that instruction's operand 1. Otherwise (including when OpReg has no
// defining instruction) return OpReg itself.
inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
  SPIRVType *Def = MRI->getVRegDef(OpReg);
  if (!Def || Def->getOpcode() != SPIRV::OpFunctionParameter)
    return OpReg;
  return Def->getOperand(1).getReg();
}
113
114
static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
115
SPIRVGlobalRegistry &GR, MachineInstr &I,
116
Register OpReg, unsigned OpIdx,
117
SPIRVType *NewPtrType) {
118
Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
119
MachineIRBuilder MIB(I);
120
bool Res = MIB.buildInstr(SPIRV::OpBitcast)
121
.addDef(NewReg)
122
.addUse(GR.getSPIRVTypeID(NewPtrType))
123
.addUse(OpReg)
124
.constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
125
*STI.getRegBankInfo());
126
if (!Res)
127
report_fatal_error("insert validation bitcast: cannot constrain all uses");
128
MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
129
GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
130
I.getOperand(OpIdx).setReg(NewReg);
131
}
132
133
// Get or create a pointer type in the same storage class as the pointer type
// OpType (its operand 1). The pointee is ResType when ReuseType is set.
// Otherwise it is the SPIR-V type obtained from the LLVM type ResTy (ReadWrite
// access qualifier, EmitIR forwarded to the registry).
static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
                                   SPIRVType *OpType, bool ReuseType,
                                   bool EmitIR, SPIRVType *ResType,
                                   const Type *ResTy) {
  const auto SC = static_cast<SPIRV::StorageClass::StorageClass>(
      OpType->getOperand(1).getImm());
  MachineIRBuilder MIB(I);
  SPIRVType *Pointee = ResType;
  if (!ReuseType)
    Pointee = GR.getOrCreateSPIRVType(
        ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
  return GR.getOrCreateSPIRVPointerType(Pointee, MIB, SC);
}
147
148
// Insert a bitcast before the instruction to keep SPIR-V code valid
149
// when there is a type mismatch between results and operand types.
150
static void validatePtrTypes(const SPIRVSubtarget &STI,
151
MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
152
MachineInstr &I, unsigned OpIdx,
153
SPIRVType *ResType, const Type *ResTy = nullptr) {
154
// Get operand type
155
MachineFunction *MF = I.getParent()->getParent();
156
Register OpReg = I.getOperand(OpIdx).getReg();
157
Register OpTypeReg = getTypeReg(MRI, OpReg);
158
SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
159
if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
160
return;
161
// Get operand's pointee type
162
Register ElemTypeReg = OpType->getOperand(2).getReg();
163
SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
164
if (!ElemType)
165
return;
166
// Check if we need a bitcast to make a statement valid
167
bool IsSameMF = MF == ResType->getParent()->getParent();
168
bool IsEqualTypes = IsSameMF ? ElemType == ResType
169
: GR.getTypeForSPIRVType(ElemType) == ResTy;
170
if (IsEqualTypes)
171
return;
172
// There is a type mismatch between results and operand types
173
// and we insert a bitcast before the instruction to keep SPIR-V code valid
174
SPIRVType *NewPtrType =
175
createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy);
176
if (!GR.isBitcastCompatible(NewPtrType, OpType))
177
report_fatal_error(
178
"insert validation bitcast: incompatible result and operand types");
179
doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
180
}
181
182
// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
183
// that doesn't point to OpTypeEvent.
184
static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
185
MachineRegisterInfo *MRI,
186
SPIRVGlobalRegistry &GR,
187
MachineInstr &I) {
188
constexpr unsigned OpIdx = 2;
189
MachineFunction *MF = I.getParent()->getParent();
190
Register OpReg = I.getOperand(OpIdx).getReg();
191
Register OpTypeReg = getTypeReg(MRI, OpReg);
192
SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
193
if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
194
return;
195
SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
196
if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
197
return;
198
// Insert a bitcast before the instruction to keep SPIR-V code valid.
199
LLVMContext &Context = MF->getFunction().getContext();
200
SPIRVType *NewPtrType =
201
createNewPtrType(GR, I, OpType, false, true, nullptr,
202
TargetExtType::get(Context, "spirv.Event"));
203
doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
204
}
205
206
static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
207
MachineRegisterInfo *MRI,
208
SPIRVGlobalRegistry &GR, MachineInstr &I,
209
unsigned OpIdx) {
210
MachineFunction *MF = I.getParent()->getParent();
211
Register OpReg = I.getOperand(OpIdx).getReg();
212
Register OpTypeReg = getTypeReg(MRI, OpReg);
213
SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
214
if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
215
return;
216
SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
217
if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
218
ElemType->getNumOperands() != 2)
219
return;
220
// It's a structure-wrapper around another type with a single member field.
221
SPIRVType *MemberType =
222
GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
223
if (!MemberType)
224
return;
225
unsigned MemberTypeOp = MemberType->getOpcode();
226
if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
227
MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
228
return;
229
// It's a structure-wrapper around a valid type. Insert a bitcast before the
230
// instruction to keep SPIR-V code valid.
231
SPIRV::StorageClass::StorageClass SC =
232
static_cast<SPIRV::StorageClass::StorageClass>(
233
OpType->getOperand(1).getImm());
234
MachineIRBuilder MIB(I);
235
SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
236
doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
237
}
238
239
// Insert a bitcast before the function call instruction to keep SPIR-V code
240
// valid when there is a type mismatch between actual and expected types of an
241
// argument:
242
// %formal = OpFunctionParameter %formal_type
243
// ...
244
// %res = OpFunctionCall %ty %fun %actual ...
245
// implies that %actual is of %formal_type, and in case of opaque pointers.
246
// We may need to insert a bitcast to ensure this.
247
void validateFunCallMachineDef(const SPIRVSubtarget &STI,
248
MachineRegisterInfo *DefMRI,
249
MachineRegisterInfo *CallMRI,
250
SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
251
MachineInstr *FunDef) {
252
if (FunDef->getOpcode() != SPIRV::OpFunction)
253
return;
254
unsigned OpIdx = 3;
255
for (FunDef = FunDef->getNextNode();
256
FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
257
OpIdx < FunCall.getNumOperands();
258
FunDef = FunDef->getNextNode(), OpIdx++) {
259
SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
260
SPIRVType *DefElemType =
261
DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
262
? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
263
DefPtrType->getParent()->getParent())
264
: nullptr;
265
if (DefElemType) {
266
const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
267
// validatePtrTypes() works in the context if the call site
268
// When we process historical records about forward calls
269
// we need to switch context to the (forward) call site and
270
// then restore it back to the current machine function.
271
MachineFunction *CurMF =
272
GR.setCurrentFunc(*FunCall.getParent()->getParent());
273
validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
274
DefElemTy);
275
GR.setCurrentFunc(*CurMF);
276
}
277
}
278
}
279
280
// Ensure there is no mismatch between actual and expected arg types: calls
281
// with a processed definition. Return Function pointer if it's a forward
282
// call (ahead of definition), and nullptr otherwise.
283
const Function *validateFunCall(const SPIRVSubtarget &STI,
284
MachineRegisterInfo *CallMRI,
285
SPIRVGlobalRegistry &GR,
286
MachineInstr &FunCall) {
287
const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
288
const Function *F = dyn_cast<Function>(GV);
289
MachineInstr *FunDef =
290
const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
291
if (!FunDef)
292
return F;
293
MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
294
validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
295
return nullptr;
296
}
297
298
// Ensure there is no mismatch between actual and expected arg types: calls
299
// ahead of a processed definition.
300
void validateForwardCalls(const SPIRVSubtarget &STI,
301
MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
302
MachineInstr &FunDef) {
303
const Function *F = GR.getFunctionByDefinition(&FunDef);
304
if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
305
for (MachineInstr *FunCall : *FwdCalls) {
306
MachineRegisterInfo *CallMRI =
307
&FunCall->getParent()->getParent()->getRegInfo();
308
validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
309
}
310
}
311
312
// Validation of an access chain.
313
void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
314
SPIRVGlobalRegistry &GR, MachineInstr &I) {
315
SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
316
if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
317
SPIRVType *BaseElemType =
318
GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
319
validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
320
}
321
}
322
323
// TODO: the logic of inserting additional bitcast's is to be moved
324
// to pre-IRTranslation passes eventually
325
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
326
// finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
327
// We'd like to avoid the needless second processing pass.
328
if (ProcessedMF.find(&MF) != ProcessedMF.end())
329
return;
330
331
MachineRegisterInfo *MRI = &MF.getRegInfo();
332
SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
333
GR.setCurrentFunc(MF);
334
for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
335
MachineBasicBlock *MBB = &*I;
336
for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
337
MBBI != MBBE;) {
338
MachineInstr &MI = *MBBI++;
339
switch (MI.getOpcode()) {
340
case SPIRV::OpAtomicLoad:
341
case SPIRV::OpAtomicExchange:
342
case SPIRV::OpAtomicCompareExchange:
343
case SPIRV::OpAtomicCompareExchangeWeak:
344
case SPIRV::OpAtomicIIncrement:
345
case SPIRV::OpAtomicIDecrement:
346
case SPIRV::OpAtomicIAdd:
347
case SPIRV::OpAtomicISub:
348
case SPIRV::OpAtomicSMin:
349
case SPIRV::OpAtomicUMin:
350
case SPIRV::OpAtomicSMax:
351
case SPIRV::OpAtomicUMax:
352
case SPIRV::OpAtomicAnd:
353
case SPIRV::OpAtomicOr:
354
case SPIRV::OpAtomicXor:
355
// for the above listed instructions
356
// OpAtomicXXX <ResType>, ptr %Op, ...
357
// implies that %Op is a pointer to <ResType>
358
case SPIRV::OpLoad:
359
// OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
360
validatePtrTypes(STI, MRI, GR, MI, 2,
361
GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
362
break;
363
case SPIRV::OpAtomicStore:
364
// OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
365
// implies that %Op points to the <Obj>'s type
366
validatePtrTypes(STI, MRI, GR, MI, 0,
367
GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
368
break;
369
case SPIRV::OpStore:
370
// OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
371
validatePtrTypes(STI, MRI, GR, MI, 0,
372
GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
373
break;
374
case SPIRV::OpPtrCastToGeneric:
375
case SPIRV::OpGenericCastToPtr:
376
validateAccessChain(STI, MRI, GR, MI);
377
break;
378
case SPIRV::OpInBoundsPtrAccessChain:
379
if (MI.getNumOperands() == 4)
380
validateAccessChain(STI, MRI, GR, MI);
381
break;
382
383
case SPIRV::OpFunctionCall:
384
// ensure there is no mismatch between actual and expected arg types:
385
// calls with a processed definition
386
if (MI.getNumOperands() > 3)
387
if (const Function *F = validateFunCall(STI, MRI, GR, MI))
388
GR.addForwardCall(F, &MI);
389
break;
390
case SPIRV::OpFunction:
391
// ensure there is no mismatch between actual and expected arg types:
392
// calls ahead of a processed definition
393
validateForwardCalls(STI, MRI, GR, MI);
394
break;
395
396
// ensure that LLVM IR bitwise instructions result in logical SPIR-V
397
// instructions when applied to bool type
398
case SPIRV::OpBitwiseOrS:
399
case SPIRV::OpBitwiseOrV:
400
if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
401
SPIRV::OpTypeBool))
402
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
403
break;
404
case SPIRV::OpBitwiseAndS:
405
case SPIRV::OpBitwiseAndV:
406
if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
407
SPIRV::OpTypeBool))
408
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
409
break;
410
case SPIRV::OpBitwiseXorS:
411
case SPIRV::OpBitwiseXorV:
412
if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
413
SPIRV::OpTypeBool))
414
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
415
break;
416
case SPIRV::OpGroupAsyncCopy:
417
validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
418
validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
419
break;
420
case SPIRV::OpGroupWaitEvents:
421
// OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
422
validateGroupWaitEventsPtr(STI, MRI, GR, MI);
423
break;
424
case SPIRV::OpConstantI: {
425
SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
426
if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
427
MI.getOperand(2).getImm() == 0) {
428
// Validate the null constant of a target extension type
429
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
430
for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
431
MI.removeOperand(i);
432
}
433
} break;
434
}
435
}
436
}
437
ProcessedMF.insert(&MF);
438
TargetLowering::finalizeLowering(MF);
439
}
440
441