Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
35268 views
1
//===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file contains the implementation of the SPIRVGlobalRegistry class,
10
// which is used to maintain rich type information required for SPIR-V even
11
// after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12
// an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13
// and supports consistency of constants and global variables.
14
//
15
//===----------------------------------------------------------------------===//
16
17
#include "SPIRVGlobalRegistry.h"
18
#include "SPIRV.h"
19
#include "SPIRVBuiltins.h"
20
#include "SPIRVSubtarget.h"
21
#include "SPIRVTargetMachine.h"
22
#include "SPIRVUtils.h"
23
#include "llvm/ADT/APInt.h"
24
#include "llvm/IR/Constants.h"
25
#include "llvm/IR/Type.h"
26
#include "llvm/Support/Casting.h"
27
#include <cassert>
28
29
using namespace llvm;
30
/// Constructs the registry, remembering \p PointerSize and starting the
/// Bound counter at zero.
SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
    : PointerSize(PointerSize), Bound(0) {}
32
33
/// Obtains the SPIR-V integer type of width \p BitWidth (creating it next to
/// \p I if necessary), records it as the type of \p VReg in the current
/// machine function, and returns it.
SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
                                                    Register VReg,
                                                    MachineInstr &I,
                                                    const SPIRVInstrInfo &TII) {
  SPIRVType *IntTy = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
  assignSPIRVTypeToVReg(IntTy, VReg, *CurMF);
  return IntTy;
}
41
42
/// Obtains the SPIR-V floating point type of width \p BitWidth (creating it
/// next to \p I if necessary), records it as the type of \p VReg in the
/// current machine function, and returns it.
SPIRVType *
SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
                                           MachineInstr &I,
                                           const SPIRVInstrInfo &TII) {
  SPIRVType *FloatTy = getOrCreateSPIRVFloatType(BitWidth, I, TII);
  assignSPIRVTypeToVReg(FloatTy, VReg, *CurMF);
  return FloatTy;
}
50
51
/// Obtains the SPIR-V vector type with \p NumElements elements of
/// \p BaseType, records it as the type of \p VReg in the current machine
/// function, and returns it.
SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
    SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
    const SPIRVInstrInfo &TII) {
  SPIRVType *VecTy = getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
  assignSPIRVTypeToVReg(VecTy, VReg, *CurMF);
  return VecTy;
}
59
60
/// Obtains the SPIR-V type corresponding to the LLVM type \p Type and records
/// it as the type of \p VReg in the machine function \p MIRBuilder works on.
/// Returns the SPIR-V type.
SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
    const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  SPIRVType *ResolvedTy =
      getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
  assignSPIRVTypeToVReg(ResolvedTy, VReg, MIRBuilder.getMF());
  return ResolvedTy;
}
68
69
void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
70
Register VReg,
71
MachineFunction &MF) {
72
VRegToTypeMap[&MF][VReg] = SpirvType;
73
}
74
75
/// Creates a fresh 32-bit scalar virtual register in the function
/// \p MIRBuilder works on and places it in the TYPE register class.
static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
  MachineRegisterInfo &RegInfo = MIRBuilder.getMF().getRegInfo();
  Register TypeReg = RegInfo.createGenericVirtualRegister(LLT::scalar(32));
  RegInfo.setRegClass(TypeReg, &SPIRV::TYPERegClass);
  return TypeReg;
}
81
82
/// Creates a fresh 32-bit scalar virtual register in \p MRI and places it in
/// the TYPE register class.
static Register createTypeVReg(MachineRegisterInfo &MRI) {
  Register TypeReg = MRI.createGenericVirtualRegister(LLT::scalar(32));
  MRI.setRegClass(TypeReg, &SPIRV::TYPERegClass);
  return TypeReg;
}
87
88
/// Emits an OpTypeBool instruction defining a new type register and returns
/// the created instruction.
SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeBool);
  MIB.addDef(createTypeVReg(MIRBuilder));
  return MIB;
}
92
93
unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
94
if (Width > 64)
95
report_fatal_error("Unsupported integer width!");
96
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
97
if (ST.canUseExtension(
98
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
99
return Width;
100
if (Width <= 8)
101
Width = 8;
102
else if (Width <= 16)
103
Width = 16;
104
else if (Width <= 32)
105
Width = 32;
106
else
107
Width = 64;
108
return Width;
109
}
110
111
/// Emits an OpTypeInt instruction for an integer of (adjusted) width
/// \p Width and signedness \p IsSigned, and returns the created instruction.
///
/// If the subtarget can use SPV_INTEL_arbitrary_precision_integers, the
/// matching OpExtension and OpCapability instructions are emitted first.
SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
                                             MachineIRBuilder &MIRBuilder,
                                             bool IsSigned) {
  Width = adjustOpTypeIntWidth(Width);

  const auto ArbPrecisionExt =
      SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers;
  const SPIRVSubtarget &ST =
      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
  if (ST.canUseExtension(ArbPrecisionExt)) {
    MIRBuilder.buildInstr(SPIRV::OpExtension).addImm(ArbPrecisionExt);
    MIRBuilder.buildInstr(SPIRV::OpCapability)
        .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
  }

  auto TypeInstr = MIRBuilder.buildInstr(SPIRV::OpTypeInt);
  TypeInstr.addDef(createTypeVReg(MIRBuilder));
  TypeInstr.addImm(Width);
  TypeInstr.addImm(IsSigned ? 1 : 0);
  return TypeInstr;
}
130
131
/// Emits an OpTypeFloat instruction for a floating point type of \p Width
/// bits and returns the created instruction.
SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
                                               MachineIRBuilder &MIRBuilder) {
  auto TypeInstr = MIRBuilder.buildInstr(SPIRV::OpTypeFloat);
  TypeInstr.addDef(createTypeVReg(MIRBuilder));
  TypeInstr.addImm(Width);
  return TypeInstr;
}
138
139
/// Emits an OpTypeVoid instruction defining a new type register and returns
/// the created instruction.
SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVoid);
  MIB.addDef(createTypeVReg(MIRBuilder));
  return MIB;
}
143
144
/// Emits an OpTypeVector instruction for a vector of \p NumElems elements of
/// \p ElemType and returns the created instruction.
///
/// \p ElemType must be an integer, floating point or boolean SPIR-V type
/// (checked by assertion only).
SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
                                                SPIRVType *ElemType,
                                                MachineIRBuilder &MIRBuilder) {
  [[maybe_unused]] const auto ElemOpcode = ElemType->getOpcode();
  assert((ElemOpcode == SPIRV::OpTypeInt || ElemOpcode == SPIRV::OpTypeFloat ||
          ElemOpcode == SPIRV::OpTypeBool) &&
         "Invalid vector element type");

  auto TypeInstr = MIRBuilder.buildInstr(SPIRV::OpTypeVector);
  TypeInstr.addDef(createTypeVReg(MIRBuilder));
  TypeInstr.addUse(getSPIRVTypeID(ElemType));
  TypeInstr.addImm(NumElems);
  return TypeInstr;
}
159
160
std::tuple<Register, ConstantInt *, bool>
161
SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
162
MachineIRBuilder *MIRBuilder,
163
MachineInstr *I,
164
const SPIRVInstrInfo *TII) {
165
const IntegerType *LLVMIntTy;
166
if (SpvType)
167
LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
168
else
169
LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
170
bool NewInstr = false;
171
// Find a constant in DT or build a new one.
172
ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
173
Register Res = DT.find(CI, CurMF);
174
if (!Res.isValid()) {
175
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
176
// TODO: handle cases where the type is not 32bit wide
177
// TODO: https://github.com/llvm/llvm-project/issues/88129
178
LLT LLTy = LLT::scalar(32);
179
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
180
CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
181
if (MIRBuilder)
182
assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
183
else
184
assignIntTypeToVReg(BitWidth, Res, *I, *TII);
185
DT.add(CI, CurMF, Res);
186
NewInstr = true;
187
}
188
return std::make_tuple(Res, CI, NewInstr);
189
}
190
191
std::tuple<Register, ConstantFP *, bool, unsigned>
192
SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
193
MachineIRBuilder *MIRBuilder,
194
MachineInstr *I,
195
const SPIRVInstrInfo *TII) {
196
const Type *LLVMFloatTy;
197
LLVMContext &Ctx = CurMF->getFunction().getContext();
198
unsigned BitWidth = 32;
199
if (SpvType)
200
LLVMFloatTy = getTypeForSPIRVType(SpvType);
201
else {
202
LLVMFloatTy = Type::getFloatTy(Ctx);
203
if (MIRBuilder)
204
SpvType = getOrCreateSPIRVType(LLVMFloatTy, *MIRBuilder);
205
}
206
bool NewInstr = false;
207
// Find a constant in DT or build a new one.
208
auto *const CI = ConstantFP::get(Ctx, Val);
209
Register Res = DT.find(CI, CurMF);
210
if (!Res.isValid()) {
211
if (SpvType)
212
BitWidth = getScalarOrVectorBitWidth(SpvType);
213
// TODO: handle cases where the type is not 32bit wide
214
// TODO: https://github.com/llvm/llvm-project/issues/88129
215
LLT LLTy = LLT::scalar(32);
216
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
217
CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
218
if (MIRBuilder)
219
assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
220
else
221
assignFloatTypeToVReg(BitWidth, Res, *I, *TII);
222
DT.add(CI, CurMF, Res);
223
NewInstr = true;
224
}
225
return std::make_tuple(Res, CI, NewInstr, BitWidth);
226
}
227
228
/// Returns a register holding the floating point constant \p Val of type
/// \p SpvType, emitting the defining instruction before \p I when needed.
///
/// An already tracked register is returned as is, unless it is the register
/// defined by \p I itself, in which case the constant instruction still has
/// to be created. Positive zero is emitted as OpConstantNull when
/// \p ZeroAsNull is set; everything else becomes OpConstantF.
Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
                                                 SPIRVType *SpvType,
                                                 const SPIRVInstrInfo &TII,
                                                 bool ZeroAsNull) {
  assert(SpvType);
  auto [Res, CI, New, BitWidth] =
      getOrCreateConstFloatReg(Val, SpvType, nullptr, &I, &TII);

  // If we have found Res register which is defined by the passed G_CONSTANT
  // machine instruction, a new constant instruction should be created.
  const MachineOperand &FirstOp = I.getOperand(0);
  const bool DefinedByI = FirstOp.isReg() && Res == FirstOp.getReg();
  if (!New && !DefinedByI)
    return Res;

  MachineBasicBlock &BB = *I.getParent();
  // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
  const bool EmitNull = Val.isPosZero() && ZeroAsNull;
  MachineInstrBuilder MIB =
      BuildMI(BB, I, I.getDebugLoc(),
              TII.get(EmitNull ? SPIRV::OpConstantNull : SPIRV::OpConstantF))
          .addDef(Res)
          .addUse(getSPIRVTypeID(SpvType));
  if (!EmitNull)
    addNumImm(
        APInt(BitWidth, CI->getValueAPF().bitcastToAPInt().getZExtValue()),
        MIB);

  const auto &ST = CurMF->getSubtarget();
  constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
                                   *ST.getRegisterInfo(), *ST.getRegBankInfo());
  return Res;
}
263
264
/// Returns a register holding the integer constant \p Val of type
/// \p SpvType, emitting the defining instruction before \p I when needed.
///
/// An already tracked register is returned as is, unless it is the register
/// defined by \p I itself, in which case the constant instruction still has
/// to be created. Zero is emitted as OpConstantNull when \p ZeroAsNull is
/// set; everything else becomes OpConstantI.
Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
                                                  SPIRVType *SpvType,
                                                  const SPIRVInstrInfo &TII,
                                                  bool ZeroAsNull) {
  assert(SpvType);
  auto [Res, CI, New] = getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
  (void)CI;

  // If we have found Res register which is defined by the passed G_CONSTANT
  // machine instruction, a new constant instruction should be created.
  const MachineOperand &FirstOp = I.getOperand(0);
  const bool DefinedByI = FirstOp.isReg() && Res == FirstOp.getReg();
  if (!New && !DefinedByI)
    return Res;

  MachineBasicBlock &BB = *I.getParent();
  const bool EmitNull = !Val && ZeroAsNull;
  MachineInstrBuilder MIB =
      BuildMI(BB, I, I.getDebugLoc(),
              TII.get(EmitNull ? SPIRV::OpConstantNull : SPIRV::OpConstantI))
          .addDef(Res)
          .addUse(getSPIRVTypeID(SpvType));
  if (!EmitNull)
    addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);

  const auto &ST = CurMF->getSubtarget();
  constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
                                   *ST.getRegisterInfo(), *ST.getRegBankInfo());
  return Res;
}
295
296
/// Returns a register holding the integer constant \p Val of type
/// \p SpvType (i32 when \p SpvType is null), building it with \p MIRBuilder
/// if it is not tracked yet.
///
/// With \p EmitIR set a generic constant is built; otherwise an OpConstantI
/// (or OpConstantNull for zero) instruction is emitted and its register
/// operands are constrained.
Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
                                               MachineIRBuilder &MIRBuilder,
                                               SPIRVType *SpvType,
                                               bool EmitIR) {
  MachineFunction &MF = MIRBuilder.getMF();
  const IntegerType *LLVMIntTy =
      SpvType ? cast<IntegerType>(getTypeForSPIRVType(SpvType))
              : IntegerType::getInt32Ty(MF.getFunction().getContext());

  // Find a constant in DT or build a new one.
  ConstantInt *const ConstInt =
      ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
  Register Res = DT.find(ConstInt, &MF);
  if (Res.isValid())
    return Res;

  unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
  LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
  Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
  MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
  assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
                   SPIRV::AccessQualifier::ReadWrite, EmitIR);
  DT.add(ConstInt, &MIRBuilder.getMF(), Res);

  if (EmitIR) {
    MIRBuilder.buildConstant(Res, *ConstInt);
    return Res;
  }

  if (!SpvType)
    SpvType = getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
  MachineInstrBuilder MIB;
  if (Val) {
    MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
              .addDef(Res)
              .addUse(getSPIRVTypeID(SpvType));
    addNumImm(APInt(BitWidth, Val), MIB);
  } else {
    MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
              .addDef(Res)
              .addUse(getSPIRVTypeID(SpvType));
  }
  const auto &Subtarget = CurMF->getSubtarget();
  constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
                                   *Subtarget.getRegisterInfo(),
                                   *Subtarget.getRegBankInfo());
  return Res;
}
342
343
/// Returns a register holding the floating point constant \p Val, emitting
/// an OpConstantF with \p MIRBuilder if the constant is not tracked yet.
/// When \p SpvType is null the SPIR-V type for LLVM `float` is used.
Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
                                              MachineIRBuilder &MIRBuilder,
                                              SPIRVType *SpvType) {
  MachineFunction &MF = MIRBuilder.getMF();
  LLVMContext &Ctx = MF.getFunction().getContext();
  if (!SpvType) {
    const Type *LLVMFPTy = Type::getFloatTy(Ctx);
    SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder);
  }

  // Find a constant in DT or build a new one.
  ConstantFP *const FPConst = ConstantFP::get(Ctx, Val);
  Register Res = DT.find(FPConst, &MF);
  if (Res.isValid())
    return Res;

  MachineRegisterInfo &RegInfo = MF.getRegInfo();
  Res = RegInfo.createGenericVirtualRegister(LLT::scalar(32));
  RegInfo.setRegClass(Res, &SPIRV::IDRegClass);
  assignSPIRVTypeToVReg(SpvType, Res, MF);
  DT.add(FPConst, &MF, Res);

  MachineInstrBuilder MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
                                .addDef(Res)
                                .addUse(getSPIRVTypeID(SpvType));
  addNumImm(FPConst->getValueAPF().bitcastToAPInt(), MIB);
  return Res;
}
370
371
/// Returns a register holding the scalar constant \p Val, typed as the
/// scalar element of \p SpvType with width \p BitWidth.
///
/// For vector and array types the element type is read from operand 1 of
/// \p SpvType; any other type is used as is. Floating point element types
/// are materialized through getOrCreateConstFP, integer ones through
/// getOrCreateConstInt. Any new instructions are inserted before \p I.
Register SPIRVGlobalRegistry::getOrCreateBaseRegister(Constant *Val,
                                                      MachineInstr &I,
                                                      SPIRVType *SpvType,
                                                      const SPIRVInstrInfo &TII,
                                                      unsigned BitWidth) {
  SPIRVType *Type = SpvType;
  if (SpvType->getOpcode() == SPIRV::OpTypeVector ||
      SpvType->getOpcode() == SPIRV::OpTypeArray) {
    auto EleTypeReg = SpvType->getOperand(1).getReg();
    Type = getSPIRVTypeForVReg(EleTypeReg);
  }
  assert(Type && "Element type of a composite must be known");
  if (Type->getOpcode() == SPIRV::OpTypeFloat) {
    SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
    // The value is dereferenced unconditionally, so use cast<> (which asserts
    // on a type mismatch) rather than an unchecked dyn_cast<>.
    return getOrCreateConstFP(cast<ConstantFP>(Val)->getValue(), I,
                              SpvBaseType, TII);
  }
  assert(Type->getOpcode() == SPIRV::OpTypeInt);
  SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
  return getOrCreateConstInt(Val->getUniqueInteger().getSExtValue(), I,
                             SpvBaseType, TII);
}
392
393
/// Returns a register holding a composite constant of type \p SpvType whose
/// \p ElemCnt elements all equal \p Val, keyed in the duplicate tracker by
/// \p CA.
///
/// If the key is not tracked yet, an OpConstantComposite is emitted before
/// \p I, or an OpConstantNull when \p Val is a null value and \p ZeroAsNull
/// is set.
Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
    Constant *Val, MachineInstr &I, SPIRVType *SpvType,
    const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
    unsigned ElemCnt, bool ZeroAsNull) {
  // Find a constant vector or array in DT or build a new one.
  Register Res = DT.find(CA, CurMF);
  if (Res.isValid())
    return Res;

  // If no values are attached, the composite is null constant.
  const bool IsNull = Val->isNullValue() && ZeroAsNull;

  // The scalar constant must be created before the composite to avoid an
  // undefined ID error on validation.
  // TODO: can be moved below once sorting of types/consts/defs is
  // implemented.
  Register SpvScalConst;
  if (!IsNull)
    SpvScalConst = getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth);

  // TODO: handle cases where the type is not 32bit wide
  // TODO: https://github.com/llvm/llvm-project/issues/88129
  MachineRegisterInfo &RegInfo = CurMF->getRegInfo();
  Register SpvVecConst = RegInfo.createGenericVirtualRegister(LLT::scalar(32));
  RegInfo.setRegClass(SpvVecConst, &SPIRV::IDRegClass);
  assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
  DT.add(CA, CurMF, SpvVecConst);

  MachineBasicBlock &BB = *I.getParent();
  MachineInstrBuilder MIB;
  if (IsNull) {
    MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
              .addDef(SpvVecConst)
              .addUse(getSPIRVTypeID(SpvType));
  } else {
    MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
              .addDef(SpvVecConst)
              .addUse(getSPIRVTypeID(SpvType));
    for (unsigned Idx = 0; Idx != ElemCnt; ++Idx)
      MIB.addUse(SpvScalConst);
  }

  const auto &Subtarget = CurMF->getSubtarget();
  constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
                                   *Subtarget.getRegisterInfo(),
                                   *Subtarget.getRegBankInfo());
  return SpvVecConst;
}
438
439
/// Returns a register holding an integer vector constant of type \p SpvType
/// with every element equal to \p Val. The element count is taken from
/// operand 2 of \p SpvType.
Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
                                                     MachineInstr &I,
                                                     SPIRVType *SpvType,
                                                     const SPIRVInstrInfo &TII,
                                                     bool ZeroAsNull) {
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  assert(LLVMTy->isVectorTy());
  const FixedVectorType *VecTy = cast<FixedVectorType>(LLVMTy);
  Type *ElemTy = VecTy->getElementType();
  assert(ElemTy->isIntegerTy());

  Constant *Scalar = ConstantInt::get(ElemTy, Val);
  // The splat constant only serves as the duplicate tracker key.
  Constant *Splat = ConstantVector::getSplat(VecTy->getElementCount(), Scalar);
  const unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
  const unsigned NumElems = SpvType->getOperand(2).getImm();
  return getOrCreateCompositeOrNull(Scalar, I, SpvType, TII, Splat, BitWidth,
                                    NumElems, ZeroAsNull);
}
457
458
/// Returns a register holding a floating point vector constant of type
/// \p SpvType with every element equal to \p Val. The element count is taken
/// from operand 2 of \p SpvType.
Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
                                                     MachineInstr &I,
                                                     SPIRVType *SpvType,
                                                     const SPIRVInstrInfo &TII,
                                                     bool ZeroAsNull) {
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  assert(LLVMTy->isVectorTy());
  const FixedVectorType *VecTy = cast<FixedVectorType>(LLVMTy);
  Type *ElemTy = VecTy->getElementType();
  assert(ElemTy->isFloatingPointTy());

  Constant *Scalar = ConstantFP::get(ElemTy, Val);
  // The splat constant only serves as the duplicate tracker key.
  Constant *Splat = ConstantVector::getSplat(VecTy->getElementCount(), Scalar);
  const unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
  const unsigned NumElems = SpvType->getOperand(2).getImm();
  return getOrCreateCompositeOrNull(Scalar, I, SpvType, TII, Splat, BitWidth,
                                    NumElems, ZeroAsNull);
}
476
477
/// Returns a register holding an array constant of type \p SpvType with
/// every element equal to \p Val. The number of emitted elements is the
/// element count of the LLVM array type; \p Num only takes part in the
/// duplicate tracker key.
Register SPIRVGlobalRegistry::getOrCreateConstIntArray(
    uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType,
    const SPIRVInstrInfo &TII) {
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  assert(LLVMTy->isArrayTy());
  const ArrayType *ArrTy = cast<ArrayType>(LLVMTy);
  Type *ElemTy = ArrTy->getElementType();
  Constant *ElemConst = ConstantInt::get(ElemTy, Val);
  SPIRVType *SpvElemTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
  const unsigned BitWidth = getScalarOrVectorBitWidth(SpvElemTy);

  // The following is a reasonably unique key that is better than [Val]. The
  // naive alternative would be something along the lines of:
  //   SmallVector<Constant *> NumCI(Num, CI);
  //   Constant *UniqueKey =
  //     ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI);
  // That would be a truly unique but dangerous key, because it could lead to
  // the creation of constants of arbitrary length (that is, the parameter of
  // memset) which were missing in the original module.
  Constant *UniqueKey = ConstantStruct::getAnon(
      {PoisonValue::get(const_cast<ArrayType *>(ArrTy)),
       ConstantInt::get(ElemTy, Val), ConstantInt::get(ElemTy, Num)});
  return getOrCreateCompositeOrNull(ElemConst, I, SpvType, TII, UniqueKey,
                                    BitWidth, ArrTy->getNumElements());
}
501
502
/// Returns a register holding an integer composite constant of type
/// \p SpvType with \p ElemCnt elements equal to \p Val, keyed in the
/// duplicate tracker by \p CA.
///
/// If the key is not tracked yet the constant is built with \p MIRBuilder:
/// as a generic splat vector when \p EmitIR is set, otherwise as
/// OpConstantComposite (non-zero \p Val) or OpConstantNull (zero \p Val).
Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
    uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
    Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
  Register Res = DT.find(CA, CurMF);
  if (Res.isValid())
    return Res;

  // The scalar element is only needed for a splat or a non-null composite.
  Register SpvScalConst;
  if (Val || EmitIR) {
    SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
    SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
  }

  LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
  MachineRegisterInfo &RegInfo = CurMF->getRegInfo();
  Register SpvVecConst = RegInfo.createGenericVirtualRegister(LLTy);
  RegInfo.setRegClass(SpvVecConst, &SPIRV::IDRegClass);
  assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
  DT.add(CA, CurMF, SpvVecConst);

  if (EmitIR) {
    MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst);
    return SpvVecConst;
  }

  if (!Val) {
    MIRBuilder.buildInstr(SPIRV::OpConstantNull)
        .addDef(SpvVecConst)
        .addUse(getSPIRVTypeID(SpvType));
    return SpvVecConst;
  }

  auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
                 .addDef(SpvVecConst)
                 .addUse(getSPIRVTypeID(SpvType));
  for (unsigned Idx = 0; Idx != ElemCnt; ++Idx)
    MIB.addUse(SpvScalConst);
  return SpvVecConst;
}
538
539
/// Returns a register holding an integer vector constant of type \p SpvType
/// with every element equal to \p Val, built with \p MIRBuilder if needed.
/// The element count is taken from operand 2 of \p SpvType.
Register
SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
                                              MachineIRBuilder &MIRBuilder,
                                              SPIRVType *SpvType, bool EmitIR) {
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  assert(LLVMTy->isVectorTy());
  const FixedVectorType *VecTy = cast<FixedVectorType>(LLVMTy);
  Type *ElemTy = VecTy->getElementType();

  Constant *Scalar = ConstantInt::get(ElemTy, Val);
  // The splat constant only serves as the duplicate tracker key.
  Constant *Splat = ConstantVector::getSplat(VecTy->getElementCount(), Scalar);
  const unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
  const unsigned NumElems = SpvType->getOperand(2).getImm();
  return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR, Splat,
                                       BitWidth, NumElems);
}
555
556
/// Returns a register holding the null pointer constant of the pointer type
/// \p SpvType, emitting an OpConstantNull with \p MIRBuilder if the constant
/// is not tracked yet.
Register
SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
                                             SPIRVType *SpvType) {
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  const TypedPointerType *TypedPtrTy = cast<TypedPointerType>(LLVMTy);
  const unsigned AddrSpace = TypedPtrTy->getAddressSpace();

  // Find a constant in DT or build a new one.
  Constant *NullPtr = ConstantPointerNull::get(
      PointerType::get(TypedPtrTy->getElementType(), AddrSpace));
  Register Res = DT.find(NullPtr, CurMF);
  if (Res.isValid())
    return Res;

  MachineRegisterInfo &RegInfo = CurMF->getRegInfo();
  Res = RegInfo.createGenericVirtualRegister(
      LLT::pointer(AddrSpace, PointerSize));
  RegInfo.setRegClass(Res, &SPIRV::IDRegClass);
  assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
  MIRBuilder.buildInstr(SPIRV::OpConstantNull)
      .addDef(Res)
      .addUse(getSPIRVTypeID(SpvType));
  DT.add(NullPtr, CurMF, Res);
  return Res;
}
577
578
/// Emits an OpConstantSampler with the given addressing mode, parameter and
/// filter mode and returns the register it defines.
///
/// The result is written to \p ResReg when it is valid, otherwise to a new
/// virtual register. The sampler type is derived from \p SpvType when given;
/// otherwise it is looked up by the name "opencl.sampler_t", which is a
/// fatal error if it cannot be resolved.
Register SPIRVGlobalRegistry::buildConstantSampler(
    Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode,
    MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
  SPIRVType *SampTy = nullptr;
  if (SpvType) {
    SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
  } else {
    SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder);
    if (SampTy == nullptr)
      report_fatal_error(
          "Unable to recognize SPIRV type name: opencl.sampler_t");
  }

  Register Sampler = ResReg;
  if (!Sampler.isValid())
    Sampler = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);

  auto SamplerInstr = MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
                          .addDef(Sampler)
                          .addUse(getSPIRVTypeID(SampTy))
                          .addImm(AddrMode)
                          .addImm(Param)
                          .addImm(FilerMode);
  assert(SamplerInstr->getOperand(0).isReg());
  return SamplerInstr->getOperand(0).getReg();
}
601
602
/// Emits an OpVariable for a global variable, together with its OpName and
/// decorations, and returns the register that holds it.
///
/// The variable is \p GV when given; otherwise it is looked up in the module
/// by \p Name and created with external linkage if missing. If the variable
/// is already tracked for the current function, only a copy into \p ResVReg
/// is built (when the registers differ) and \p ResVReg is returned.
///
/// \p Init, when non-null, supplies the initializer through its first
/// operand. \p IsInstSelector requests constraining the register operands of
/// the new instruction.
Register SPIRVGlobalRegistry::buildGlobalVariable(
    Register ResVReg, SPIRVType *BaseType, StringRef Name,
    const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
    const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
    SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
    bool IsInstSelector) {
  MachineFunction &MF = MIRBuilder.getMF();

  const GlobalVariable *GVar = nullptr;
  if (GV) {
    GVar = cast<const GlobalVariable>(GV);
  } else {
    // If GV is not passed explicitly, use the name to find or construct
    // the global variable.
    Module *M = MF.getFunction().getParent();
    GVar = M->getGlobalVariable(Name);
    if (GVar == nullptr) {
      const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
      // Module takes ownership of the global var.
      GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
                                GlobalValue::ExternalLinkage, nullptr,
                                Twine(Name));
    }
    GV = GVar;
  }

  // Reuse the register of an already emitted variable.
  Register Reg = DT.find(GVar, &MF);
  if (Reg.isValid()) {
    if (Reg != ResVReg)
      MIRBuilder.buildCopy(ResVReg, Reg);
    return ResVReg;
  }

  auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
                 .addDef(ResVReg)
                 .addUse(getSPIRVTypeID(BaseType))
                 .addImm(static_cast<uint32_t>(Storage));
  if (Init != nullptr)
    MIB.addUse(Init->getOperand(0).getReg());

  // ISel may introduce a new register on this step, so we need to add it to
  // DT and correct its type avoiding fails on the next stage.
  if (IsInstSelector) {
    const auto &Subtarget = CurMF->getSubtarget();
    constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
                                     *Subtarget.getRegisterInfo(),
                                     *Subtarget.getRegBankInfo());
  }
  Reg = MIB->getOperand(0).getReg();
  DT.add(GVar, &MF, Reg);

  // Set to Reg the same type as ResVReg has.
  auto MRI = MIRBuilder.getMRI();
  assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
  if (Reg != ResVReg) {
    const unsigned AddrSpace = MRI->getType(ResVReg).getAddressSpace();
    MRI->setType(Reg, LLT::pointer(AddrSpace, getPointerSize()));
    assignSPIRVTypeToVReg(BaseType, Reg, MF);
  } else {
    // Our knowledge about the type may be updated.
    // If that's the case, we need to update a type
    // associated with the register.
    SPIRVType *DefType = getSPIRVTypeForVReg(ResVReg);
    if (!DefType || DefType != BaseType)
      assignSPIRVTypeToVReg(BaseType, Reg, MF);
  }

  // If it's a global variable with name, output OpName for it.
  if (GVar && GVar->hasName())
    buildOpName(Reg, GVar->getName(), MIRBuilder);

  // Output decorations for the GV.
  // TODO: maybe move to GenerateDecorations pass.
  const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(MF.getSubtarget());
  if (IsConst && ST.isOpenCLEnv())
    buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});

  // An alignment of one is the default and needs no decoration.
  if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
    unsigned Alignment =
        static_cast<unsigned>(GVar->getAlign().valueOrOne().value());
    buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
  }

  if (HasLinkageTy)
    buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
                    {static_cast<uint32_t>(LinkageType)}, Name);

  SPIRV::BuiltIn::BuiltIn BuiltInId;
  if (getSpirvBuiltInIdByName(Name, BuiltInId))
    buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
                    {static_cast<uint32_t>(BuiltInId)});

  // If it's a global variable with "spirv.Decorations" metadata node
  // recognize it as a SPIR-V friendly LLVM IR and parse "spirv.Decorations"
  // arguments.
  if (GVar) {
    if (MDNode *GVarMD = GVar->getMetadata("spirv.Decorations"))
      buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD);
  }

  return Reg;
}
703
704
/// Emits an OpTypeArray instruction for an array of \p NumElems elements of
/// \p ElemType and returns the created instruction. The element count is
/// passed as a register holding an integer constant, which is built first.
SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
                                               SPIRVType *ElemType,
                                               MachineIRBuilder &MIRBuilder,
                                               bool EmitIR) {
  assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
         "Invalid array element type");
  Register LengthReg = buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);

  auto TypeInstr = MIRBuilder.buildInstr(SPIRV::OpTypeArray);
  TypeInstr.addDef(createTypeVReg(MIRBuilder));
  TypeInstr.addUse(getSPIRVTypeID(ElemType));
  TypeInstr.addUse(LengthReg);
  return TypeInstr;
}
718
719
/// Emits an OpTypeOpaque instruction carrying the name of the struct type
/// \p Ty, plus an OpName for the new type register, and returns the created
/// type instruction. \p Ty is expected to be named (asserted); an empty name
/// is used otherwise.
SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
                                                MachineIRBuilder &MIRBuilder) {
  assert(Ty->hasName());
  StringRef TypeName = "";
  if (Ty->hasName())
    TypeName = Ty->getName();

  Register TypeReg = createTypeVReg(MIRBuilder);
  auto TypeInstr = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(TypeReg);
  addStringImm(TypeName, TypeInstr);
  buildOpName(TypeReg, TypeName, MIRBuilder);
  return TypeInstr;
}
729
730
/// Emits an OpTypeStruct instruction for the struct type \p Ty and returns
/// it.
///
/// The SPIR-V types of all members are resolved first. A named struct also
/// gets an OpName, and a packed struct gets the CPacked decoration.
SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
                                                MachineIRBuilder &MIRBuilder,
                                                bool EmitIR) {
  // Resolve every member type before the struct itself is emitted.
  SmallVector<Register, 4> MemberTypeRegs;
  for (const auto &Member : Ty->elements()) {
    SPIRVType *MemberTy = findSPIRVType(toTypedPointer(Member), MIRBuilder);
    assert(MemberTy && MemberTy->getOpcode() != SPIRV::OpTypeVoid &&
           "Invalid struct element type");
    MemberTypeRegs.push_back(getSPIRVTypeID(MemberTy));
  }

  Register StructReg = createTypeVReg(MIRBuilder);
  auto StructInstr =
      MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(StructReg);
  for (const auto &MemberReg : MemberTypeRegs)
    StructInstr.addUse(MemberReg);

  if (Ty->hasName())
    buildOpName(StructReg, Ty->getName(), MIRBuilder);
  if (Ty->isPacked())
    buildOpDecorate(StructReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
  return StructInstr;
}
750
751
/// Lowers the special opaque builtin type \p Ty to its SPIR-V type by
/// delegating to SPIRV::lowerBuiltinType with access qualifier \p AccQual.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccQual) {
  assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type");
  SPIRVType *Lowered = SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this);
  return Lowered;
}
757
758
/// Emits an OpTypePointer instruction for a pointer to \p ElemType in
/// storage class \p SC and returns it. The type is defined in \p Reg when it
/// is valid, otherwise in a newly created type register.
SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
    SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
    MachineIRBuilder &MIRBuilder, Register Reg) {
  const Register DefReg = Reg.isValid() ? Reg : createTypeVReg(MIRBuilder);
  auto TypeInstr = MIRBuilder.buildInstr(SPIRV::OpTypePointer);
  TypeInstr.addDef(DefReg);
  TypeInstr.addImm(static_cast<uint32_t>(SC));
  TypeInstr.addUse(getSPIRVTypeID(ElemType));
  return TypeInstr;
}
768
769
/// Emits an OpTypeForwardPointer instruction for storage class \p SC and
/// returns it. The new type register is attached as a use operand, not as a
/// definition.
SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
    SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
  auto FwdInstr = MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer);
  FwdInstr.addUse(createTypeVReg(MIRBuilder));
  FwdInstr.addImm(static_cast<uint32_t>(SC));
  return FwdInstr;
}
775
776
/// Emits an OpTypeFunction instruction with return type \p RetType and the
/// parameter types \p ArgTypes (in order) and returns it.
SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
    SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
    MachineIRBuilder &MIRBuilder) {
  auto FuncTyInstr = MIRBuilder.buildInstr(SPIRV::OpTypeFunction);
  FuncTyInstr.addDef(createTypeVReg(MIRBuilder));
  FuncTyInstr.addUse(getSPIRVTypeID(RetType));
  for (const SPIRVType *ParamTy : ArgTypes)
    FuncTyInstr.addUse(getSPIRVTypeID(ParamTy));
  return FuncTyInstr;
}
786
787
/// Returns the SPIR-V function type registered for the LLVM type \p Ty,
/// creating it from \p RetType and \p ArgTypes and recording it in the
/// duplicate tracker if it does not exist yet.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
    const Type *Ty, SPIRVType *RetType,
    const SmallVectorImpl<SPIRVType *> &ArgTypes,
    MachineIRBuilder &MIRBuilder) {
  const Register Known = DT.find(Ty, &MIRBuilder.getMF());
  if (Known.isValid())
    return getSPIRVTypeForVReg(Known);

  SPIRVType *FuncTy = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
  DT.add(Ty, CurMF, getSPIRVTypeID(FuncTy));
  return finishCreatingSPIRVType(Ty, FuncTy);
}
798
799
/// Resolves the SPIR-V type for the LLVM type \p Ty.
///
/// After adjusting integer widths, the lookup order is: the duplicate
/// tracker, then the pending forward pointer types, and finally creation via
/// restOfCreateSPIRVType.
SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
  Ty = adjustIntTypeByWidth(Ty);

  const Register Known = DT.find(Ty, &MIRBuilder.getMF());
  if (Known.isValid())
    return getSPIRVTypeForVReg(Known);

  auto FwdIt = ForwardPointerTypes.find(Ty);
  if (FwdIt != ForwardPointerTypes.end())
    return FwdIt->second;

  return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
}
810
811
/// Returns the register that identifies \p SpirvType: the first use operand
/// for OpTypeForwardPointer, the first definition for every other type
/// instruction.
Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
  assert(SpirvType && "Attempting to get type id for nullptr type.");
  const bool IsForwardPtr =
      SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer;
  if (IsForwardPtr)
    return SpirvType->uses().begin()->getReg();
  return SpirvType->defs().begin()->getReg();
}
817
818
// We need to use a new LLVM integer type if there is a mismatch between
819
// number of bits in LLVM and SPIRV integer types to let DuplicateTracker
820
// ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
821
// such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
822
// same "OpTypeInt 8" type for a series of LLVM integer types with number of
823
// bits less than 8. This would lead to duplicate type definitions
824
// eventually due to the method that DuplicateTracker utilizes to reason
825
// about uniqueness of type records.
826
const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
827
if (auto IType = dyn_cast<IntegerType>(Ty)) {
828
unsigned SrcBitWidth = IType->getBitWidth();
829
if (SrcBitWidth > 1) {
830
unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
831
// Maybe change source LLVM type to keep DuplicateTracker consistent.
832
if (SrcBitWidth != BitWidth)
833
Ty = IntegerType::get(Ty->getContext(), BitWidth);
834
}
835
}
836
return Ty;
837
}
838
839
// Produces the SPIR-V type instruction for the LLVM type Ty, dispatching on
// the kind of Ty: special opaque types, integers (width 1 becomes a bool),
// floats, void, vectors, arrays, structs, functions and finally pointers.
// A type already recorded in DT for the current function is reused.
// Any type that matches none of these kinds is a fatal error.
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
  if (isSpecialOpaqueType(Ty))
    return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);

  // Reuse the type if DT already has it for this machine function.
  auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
  auto TypeIt = TypeToSPIRVTypeMap.find(Ty);
  if (TypeIt != TypeToSPIRVTypeMap.end()) {
    auto FuncIt = TypeIt->second.find(&MIRBuilder.getMF());
    if (FuncIt != TypeIt->second.end())
      return getSPIRVTypeForVReg(FuncIt->second);
  }

  if (auto *IntTy = dyn_cast<IntegerType>(Ty)) {
    const unsigned Width = IntTy->getBitWidth();
    if (Width == 1)
      return getOpTypeBool(MIRBuilder);
    return getOpTypeInt(Width, MIRBuilder, false);
  }
  if (Ty->isFloatingPointTy())
    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
  if (Ty->isVoidTy())
    return getOpTypeVoid(MIRBuilder);
  if (Ty->isVectorTy()) {
    const auto *VecTy = cast<FixedVectorType>(Ty);
    SPIRVType *ElemType = findSPIRVType(VecTy->getElementType(), MIRBuilder);
    return getOpTypeVector(VecTy->getNumElements(), ElemType, MIRBuilder);
  }
  if (Ty->isArrayTy()) {
    SPIRVType *ElemType = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
    return getOpTypeArray(Ty->getArrayNumElements(), ElemType, MIRBuilder,
                          EmitIR);
  }
  if (auto *StructTy = dyn_cast<StructType>(Ty)) {
    if (StructTy->isOpaque())
      return getOpTypeOpaque(StructTy, MIRBuilder);
    return getOpTypeStruct(StructTy, MIRBuilder, EmitIR);
  }
  if (auto *FuncTy = dyn_cast<FunctionType>(Ty)) {
    SPIRVType *RetTy = findSPIRVType(FuncTy->getReturnType(), MIRBuilder);
    SmallVector<SPIRVType *, 4> ParamTypes;
    for (const auto &ParamTy : FuncTy->params())
      ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder));
    return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
  }

  // Only pointer types are left; anything else cannot be converted.
  const auto *TypedPtrTy = dyn_cast<TypedPointerType>(Ty);
  unsigned AddrSpace = 0xFFFF;
  if (TypedPtrTy)
    AddrSpace = TypedPtrTy->getAddressSpace();
  else if (auto *PtrTy = dyn_cast<PointerType>(Ty))
    AddrSpace = PtrTy->getAddressSpace();
  else
    report_fatal_error("Unable to convert LLVM type to SPIRVType", true);

  // A typed pointer uses its own element type as the pointee; any other
  // pointer is given an 8-bit integer pointee.
  SPIRVType *SpvElementType = nullptr;
  if (TypedPtrTy)
    SpvElementType = getOrCreateSPIRVType(TypedPtrTy->getElementType(),
                                          MIRBuilder, AccQual, EmitIR);
  else
    SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);

  // Get access to information about available extensions
  const auto *ST =
      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
  auto SC = addressSpaceToStorageClass(AddrSpace, *ST);

  // Null pointer means we have a loop in type definitions, make and
  // return corresponding OpTypeForwardPointer.
  if (!SpvElementType) {
    if (!ForwardPointerTypes.contains(Ty))
      ForwardPointerTypes[Ty] = getOpTypeForwardPointer(SC, MIRBuilder);
    return ForwardPointerTypes[Ty];
  }

  // If we have forward pointer associated with this type, use its register
  // operand to create OpTypePointer.
  if (ForwardPointerTypes.contains(Ty)) {
    Register FwdReg = getSPIRVTypeID(ForwardPointerTypes[Ty]);
    return getOpTypePointer(SC, SpvElementType, MIRBuilder, FwdReg);
  }

  return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC);
}
919
920
// Creates the SPIR-V type for Ty and records it in the registry maps
// (VRegToTypeMap, SPIRVToLLVMType) and, where applicable, in DT.
// Returns nullptr without creating anything when Ty is a non-pointer type
// that is already being processed (i.e. it is reached recursively).
SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  if (!isPointerTy(Ty) && TypesInProcessing.count(Ty))
    return nullptr;

  TypesInProcessing.insert(Ty);
  SPIRVType *NewType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  TypesInProcessing.erase(Ty);

  auto *MF = &MIRBuilder.getMF();
  const Register TypeID = getSPIRVTypeID(NewType);
  VRegToTypeMap[MF][TypeID] = NewType;
  SPIRVToLLVMType[NewType] = unifyPtrType(Ty);

  const Register Known = DT.find(Ty, MF);
  // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
  // will be added later. For special types it is already added to DT.
  const bool SkipDT = NewType->getOpcode() == SPIRV::OpTypeForwardPointer ||
                      Known.isValid() || isSpecialOpaqueType(Ty);
  if (SkipDT)
    return NewType;

  // Pointers are keyed in DT by (pointee type, address space); pointers that
  // are not typed pointers use i8 as the pointee.
  if (!isPointerTy(Ty))
    DT.add(Ty, MF, TypeID);
  else if (isTypedPointerTy(Ty))
    DT.add(cast<TypedPointerType>(Ty)->getElementType(),
           getPointerAddressSpace(Ty), MF, TypeID);
  else
    DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
           getPointerAddressSpace(Ty), MF, TypeID);

  return NewType;
}
949
950
// Returns the SPIR-V type recorded for VReg in function MF (CurMF when MF is
// null), or nullptr when either the function or the register is unknown.
SPIRVType *
SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
                                         const MachineFunction *MF) const {
  auto FuncIt = VRegToTypeMap.find(MF ? MF : CurMF);
  if (FuncIt == VRegToTypeMap.end())
    return nullptr;

  auto RegIt = FuncIt->second.find(VReg);
  if (RegIt == FuncIt->second.end())
    return nullptr;
  return RegIt->second;
}
961
962
// Returns the SPIR-V type for Ty in the function MIRBuilder is working on,
// creating it when DT has no usable record. Pointer types are looked up in DT
// by (pointee type, address space), with i8 as the pointee for pointers that
// are not typed pointers. After creation, every forward pointer collected in
// ForwardPointerTypes is resolved to a normal type and the collection is
// cleared.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  auto *MF = &MIRBuilder.getMF();

  Register Reg;
  if (isPointerTy(Ty)) {
    if (isTypedPointerTy(Ty))
      Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
                    getPointerAddressSpace(Ty), MF);
    else
      Reg = DT.find(Type::getInt8Ty(MF->getFunction().getContext()),
                    getPointerAddressSpace(Ty), MF);
  } else {
    Ty = adjustIntTypeByWidth(Ty);
    Reg = DT.find(Ty, MF);
  }

  if (Reg.isValid() && !isSpecialOpaqueType(Ty))
    return getSPIRVTypeForVReg(Reg);

  TypesInProcessing.clear();
  SPIRVType *Result = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);

  // Create normal pointer types for the corresponding OpTypeForwardPointers.
  for (auto &Entry : ForwardPointerTypes) {
    const Type *FwdTy = Entry.first;
    SPIRVType *Resolved = nullptr;
    Reg = DT.find(FwdTy, MF);
    if (Reg.isValid())
      Resolved = getSPIRVTypeForVReg(Reg);
    else
      Resolved = restOfCreateSPIRVType(FwdTy, MIRBuilder, AccessQual, EmitIR);
    // The requested type itself may have been a forward pointer.
    if (Ty == FwdTy)
      Result = Resolved;
  }
  ForwardPointerTypes.clear();
  return Result;
}
996
997
bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
998
unsigned TypeOpcode) const {
999
SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1000
assert(Type && "isScalarOfType VReg has no type assigned");
1001
return Type->getOpcode() == TypeOpcode;
1002
}
1003
1004
bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
1005
unsigned TypeOpcode) const {
1006
SPIRVType *Type = getSPIRVTypeForVReg(VReg);
1007
assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
1008
if (Type->getOpcode() == TypeOpcode)
1009
return true;
1010
if (Type->getOpcode() == SPIRV::OpTypeVector) {
1011
Register ScalarTypeVReg = Type->getOperand(1).getReg();
1012
SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
1013
return ScalarType->getOpcode() == TypeOpcode;
1014
}
1015
return false;
1016
}
1017
1018
unsigned
1019
SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
1020
return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg));
1021
}
1022
1023
unsigned
1024
SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const {
1025
if (!Type)
1026
return 0;
1027
return Type->getOpcode() == SPIRV::OpTypeVector
1028
? static_cast<unsigned>(Type->getOperand(2).getImm())
1029
: 1;
1030
}
1031
1032
unsigned
1033
SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
1034
assert(Type && "Invalid Type pointer");
1035
if (Type->getOpcode() == SPIRV::OpTypeVector) {
1036
auto EleTypeReg = Type->getOperand(1).getReg();
1037
Type = getSPIRVTypeForVReg(EleTypeReg);
1038
}
1039
if (Type->getOpcode() == SPIRV::OpTypeInt ||
1040
Type->getOpcode() == SPIRV::OpTypeFloat)
1041
return Type->getOperand(1).getImm();
1042
if (Type->getOpcode() == SPIRV::OpTypeBool)
1043
return 1;
1044
llvm_unreachable("Attempting to get bit width of non-integer/float type.");
1045
}
1046
1047
unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
1048
const SPIRVType *Type) const {
1049
assert(Type && "Invalid Type pointer");
1050
unsigned NumElements = 1;
1051
if (Type->getOpcode() == SPIRV::OpTypeVector) {
1052
NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
1053
Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
1054
}
1055
return Type->getOpcode() == SPIRV::OpTypeInt ||
1056
Type->getOpcode() == SPIRV::OpTypeFloat
1057
? NumElements * Type->getOperand(1).getImm()
1058
: 0;
1059
}
1060
1061
// Returns Type itself when it is an OpTypeInt, the element type when Type is
// an OpTypeVector of OpTypeInt, and nullptr in every other case (including a
// null Type).
const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
    const SPIRVType *Type) const {
  // Step from a vector to its element type (operand 1).
  if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
  if (!Type || Type->getOpcode() != SPIRV::OpTypeInt)
    return nullptr;
  return Type;
}
1067
1068
bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
1069
const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
1070
return IntType && IntType->getOperand(2).getImm() != 0;
1071
}
1072
1073
// Returns the type referenced by operand 2 of PtrType when PtrType is an
// OpTypePointer; returns nullptr for a null or non-pointer type.
SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
  if (!PtrType || PtrType->getOpcode() != SPIRV::OpTypePointer)
    return nullptr;
  return getSPIRVTypeForVReg(PtrType->getOperand(2).getReg());
}
1078
1079
// Returns the opcode of the pointee type of the pointer type assigned to
// PtrReg, or 0 when no pointee type can be determined.
unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
  if (SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg)))
    return ElemType->getOpcode();
  return 0;
}
1083
1084
bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
1085
const SPIRVType *Type2) const {
1086
if (!Type1 || !Type2)
1087
return false;
1088
auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
1089
// Ignore difference between <1.5 and >=1.5 protocol versions:
1090
// it's valid if either Result Type or Operand is a pointer, and the other
1091
// is a pointer, an integer scalar, or an integer vector.
1092
if (Op1 == SPIRV::OpTypePointer &&
1093
(Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
1094
return true;
1095
if (Op2 == SPIRV::OpTypePointer &&
1096
(Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
1097
return true;
1098
unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1),
1099
Bits2 = getNumScalarOrVectorTotalBitWidth(Type2);
1100
return Bits1 > 0 && Bits1 == Bits2;
1101
}
1102
1103
// Returns the storage class stored as the immediate in operand 1 of the
// OpTypePointer assigned to VReg. VReg must carry a pointer type.
SPIRV::StorageClass::StorageClass
SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
  const SPIRVType *Type = getSPIRVTypeForVReg(VReg);
  assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
         Type->getOperand(1).isImm() && "Pointer type is expected");
  const auto StorageClassImm = Type->getOperand(1).getImm();
  return static_cast<SPIRV::StorageClass::StorageClass>(StorageClassImm);
}
1111
1112
// Returns the OpTypeImage matching the given image parameters in the current
// function, reusing the instruction recorded in DT for the same descriptor or
// emitting and registering a new one.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
    MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
    uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
    SPIRV::ImageFormat::ImageFormat ImageFormat,
    SPIRV::AccessQualifier::AccessQualifier AccessQual) {
  auto TD = SPIRV::make_descr_image(SPIRVToLLVMType.lookup(SampledType), Dim,
                                    Depth, Arrayed, Multisampled, Sampled,
                                    ImageFormat, AccessQual);
  if (auto *Existing = checkSpecialInstr(TD, MIRBuilder))
    return Existing;

  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);

  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeImage)
                 .addDef(ResVReg)
                 .addUse(getSPIRVTypeID(SampledType));
  MIB.addImm(Dim);
  MIB.addImm(Depth);        // Depth (whether or not it is a Depth image).
  MIB.addImm(Arrayed);      // Arrayed.
  MIB.addImm(Multisampled); // Multisampled (0 = only single-sample).
  MIB.addImm(Sampled);      // Sampled (0 = usage known at runtime).
  MIB.addImm(ImageFormat);
  MIB.addImm(AccessQual);
  return MIB;
}
1135
1136
// Returns the OpTypeSampler of the current function, reusing the instruction
// recorded in DT or emitting and registering a new one.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
  auto TD = SPIRV::make_descr_sampler();
  if (auto *Existing = checkSpecialInstr(TD, MIRBuilder))
    return Existing;

  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeSampler);
  MIB.addDef(ResVReg);
  return MIB;
}
1145
1146
// Returns the OpTypePipe with access qualifier AccessQual in the current
// function, reusing the instruction recorded in DT or emitting and
// registering a new one.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
    MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual) {
  auto TD = SPIRV::make_descr_pipe(AccessQual);
  if (auto *Existing = checkSpecialInstr(TD, MIRBuilder))
    return Existing;

  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypePipe);
  MIB.addDef(ResVReg);
  MIB.addImm(AccessQual);
  return MIB;
}
1158
1159
// Returns the OpTypeDeviceEvent of the current function, reusing the
// instruction recorded in DT or emitting and registering a new one.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
    MachineIRBuilder &MIRBuilder) {
  auto TD = SPIRV::make_descr_event();
  if (auto *Existing = checkSpecialInstr(TD, MIRBuilder))
    return Existing;

  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent);
  MIB.addDef(ResVReg);
  return MIB;
}
1168
1169
// Returns the OpTypeSampledImage built on top of ImageType, reusing the
// instruction recorded in DT or emitting and registering a new one. The DT
// descriptor is formed from ImageType and the LLVM type recorded for the
// instruction that defines ImageType's operand 1.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
    SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
  auto *SampledTypeDef = MIRBuilder.getMF().getRegInfo().getVRegDef(
      ImageType->getOperand(1).getReg());
  auto TD = SPIRV::make_descr_sampled_image(
      SPIRVToLLVMType.lookup(SampledTypeDef), ImageType);
  if (auto *Existing = checkSpecialInstr(TD, MIRBuilder))
    return Existing;

  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage);
  MIB.addDef(ResVReg);
  MIB.addUse(getSPIRVTypeID(ImageType));
  return MIB;
}
1183
1184
// Returns the OpTypeCooperativeMatrixKHR registered in DT for ExtensionType
// in the current function, or emits one. Scope, Rows, Columns and Use are
// passed to the new instruction as registers of integer constants built with
// buildConstantInt().
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
    MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
    const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
    uint32_t Use) {
  auto *MF = &MIRBuilder.getMF();
  Register ResVReg = DT.find(ExtensionType, MF);
  if (ResVReg.isValid())
    return MF->getRegInfo().getUniqueVRegDef(ResVReg);

  ResVReg = createTypeVReg(MIRBuilder);
  // The type instruction is created first and the constant operands are
  // built afterwards, one by one, in operand order.
  auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
                 .addDef(ResVReg)
                 .addUse(getSPIRVTypeID(ElemType));
  MIB.addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true));
  MIB.addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true));
  MIB.addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true));
  MIB.addUse(buildConstantInt(Use, MIRBuilder, nullptr, true));

  SPIRVType *SpirvTy = MIB;
  DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
  return SpirvTy;
}
1203
1204
// Returns the type instruction registered in DT for Ty in the current
// function, or emits an instruction with the given Opcode that has only a
// result register, and registers it for Ty.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
    const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
  auto *MF = &MIRBuilder.getMF();
  Register ResVReg = DT.find(Ty, MF);
  if (ResVReg.isValid())
    return MF->getRegInfo().getUniqueVRegDef(ResVReg);

  ResVReg = createTypeVReg(MIRBuilder);
  auto MIB = MIRBuilder.buildInstr(Opcode);
  MIB.addDef(ResVReg);
  SPIRVType *SpirvTy = MIB;
  DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
  return SpirvTy;
}
1214
1215
// Returns the instruction that defines the register recorded in DT for the
// special type descriptor TD in the current function, or nullptr when DT has
// no such record.
const MachineInstr *
SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
                                       MachineIRBuilder &MIRBuilder) {
  const Register Known = DT.find(TD, &MIRBuilder.getMF());
  if (!Known.isValid())
    return nullptr;
  return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Known);
}
1223
1224
// Returns nullptr if unable to recognize SPIRV type name
1225
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
1226
StringRef TypeStr, MachineIRBuilder &MIRBuilder,
1227
SPIRV::StorageClass::StorageClass SC,
1228
SPIRV::AccessQualifier::AccessQualifier AQ) {
1229
unsigned VecElts = 0;
1230
auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
1231
1232
// Parse strings representing either a SPIR-V or OpenCL builtin type.
1233
if (hasBuiltinTypePrefix(TypeStr))
1234
return getOrCreateSPIRVType(SPIRV::parseBuiltinTypeNameToTargetExtType(
1235
TypeStr.str(), MIRBuilder.getContext()),
1236
MIRBuilder, AQ);
1237
1238
// Parse type name in either "typeN" or "type vector[N]" format, where
1239
// N is the number of elements of the vector.
1240
Type *Ty;
1241
1242
Ty = parseBasicTypeName(TypeStr, Ctx);
1243
if (!Ty)
1244
// Unable to recognize SPIRV type name
1245
return nullptr;
1246
1247
auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);
1248
1249
// Handle "type*" or "type* vector[N]".
1250
if (TypeStr.starts_with("*")) {
1251
SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1252
TypeStr = TypeStr.substr(strlen("*"));
1253
}
1254
1255
// Handle "typeN*" or "type vector[N]*".
1256
bool IsPtrToVec = TypeStr.consume_back("*");
1257
1258
if (TypeStr.consume_front(" vector[")) {
1259
TypeStr = TypeStr.substr(0, TypeStr.find(']'));
1260
}
1261
TypeStr.getAsInteger(10, VecElts);
1262
if (VecElts > 0)
1263
SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);
1264
1265
if (IsPtrToVec)
1266
SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1267
1268
return SpirvTy;
1269
}
1270
1271
// Returns the SPIR-V type for the LLVM integer type of the given bit width in
// the function MIRBuilder is working on, creating it if needed.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
                                                 MachineIRBuilder &MIRBuilder) {
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
  return getOrCreateSPIRVType(IntegerType::get(Ctx, BitWidth), MIRBuilder);
}
1278
1279
// Records a freshly created SpirvType in the registry: maps its type register
// to the instruction for CurMF and remembers the (pointer-unified) LLVM type
// it was created from. SpirvType must belong to CurMF. Returns SpirvType.
SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
                                                        SPIRVType *SpirvType) {
  assert(CurMF == SpirvType->getMF());
  const Register TypeID = getSPIRVTypeID(SpirvType);
  auto &TypesOfCurMF = VRegToTypeMap[CurMF];
  TypesOfCurMF[TypeID] = SpirvType;
  SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
  return SpirvType;
}
1286
1287
// Returns the SPIR-V type registered in DT for LLVMTy in CurMF, or inserts a
// new type instruction with opcode SPIRVOPcode before I. The new instruction
// gets a fresh type register, BitWidth as an immediate and a trailing
// immediate 0, and is registered for LLVMTy.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
                                                     MachineInstr &I,
                                                     const SPIRVInstrInfo &TII,
                                                     unsigned SPIRVOPcode,
                                                     Type *LLVMTy) {
  const Register Known = DT.find(LLVMTy, CurMF);
  if (Known.isValid())
    return getSPIRVTypeForVReg(Known);

  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRVOPcode));
  MIB.addDef(createTypeVReg(CurMF->getRegInfo()));
  MIB.addImm(BitWidth);
  MIB.addImm(0);
  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}
1303
1304
// Returns the OpTypeInt for BitWidth in CurMF, inserting it before I when it
// does not exist yet.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
  // Maybe adjust bit width to keep DuplicateTracker consistent. Without
  // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
  // example, the same "OpTypeInt 8" type for a series of LLVM integer types
  // with number of bits less than 8, causing duplicate type definitions.
  const unsigned AdjustedWidth = adjustOpTypeIntWidth(BitWidth);
  auto &Ctx = CurMF->getFunction().getContext();
  Type *LLVMTy = IntegerType::get(Ctx, AdjustedWidth);
  return getOrCreateSPIRVType(AdjustedWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
}
1314
1315
// Returns the OpTypeFloat for BitWidth in CurMF, inserting it before I when
// it does not exist yet. Only widths 16, 32 and 64 are expected; they map to
// the LLVM half, float and double types respectively.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
  LLVMContext &Ctx = CurMF->getFunction().getContext();
  Type *LLVMTy = nullptr;
  if (BitWidth == 16)
    LLVMTy = Type::getHalfTy(Ctx);
  else if (BitWidth == 32)
    LLVMTy = Type::getFloatTy(Ctx);
  else if (BitWidth == 64)
    LLVMTy = Type::getDoubleTy(Ctx);
  else
    llvm_unreachable("Bit width is of unexpected size.");
  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
}
1334
1335
// Returns the SPIR-V type for the LLVM 1-bit integer type in the function
// MIRBuilder is working on, creating it if needed.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
  return getOrCreateSPIRVType(IntegerType::get(Ctx, 1), MIRBuilder);
}
1341
1342
// Returns the OpTypeBool registered in DT for the LLVM 1-bit integer type in
// CurMF, or inserts a new OpTypeBool before I and registers it.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
                                              const SPIRVInstrInfo &TII) {
  Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
  const Register Known = DT.find(LLVMTy, CurMF);
  if (Known.isValid())
    return getSPIRVTypeForVReg(Known);

  auto MIB =
      BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool));
  MIB.addDef(createTypeVReg(CurMF->getRegInfo()));
  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}
1355
1356
// Returns the SPIR-V type for a fixed vector of NumElements elements of
// BaseType in the function MIRBuilder is working on, creating it if needed.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
    SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
  Type *ElemTy = const_cast<Type *>(getTypeForSPIRVType(BaseType));
  return getOrCreateSPIRVType(FixedVectorType::get(ElemTy, NumElements),
                              MIRBuilder);
}
1363
1364
// Returns the OpTypeVector registered in DT for a fixed vector of NumElements
// elements of BaseType in CurMF, or inserts a new one before I. The new
// instruction takes BaseType's register as operand 1 and NumElements as an
// immediate in operand 2.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
    SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
    const SPIRVInstrInfo &TII) {
  Type *LLVMTy = FixedVectorType::get(
      const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
  const Register Known = DT.find(LLVMTy, CurMF);
  if (Known.isValid())
    return getSPIRVTypeForVReg(Known);

  auto MIB =
      BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector));
  MIB.addDef(createTypeVReg(CurMF->getRegInfo()));
  MIB.addUse(getSPIRVTypeID(BaseType));
  MIB.addImm(NumElements);
  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}
1380
1381
// Returns the OpTypeArray registered in DT for an array of NumElements
// elements of BaseType in CurMF, or inserts a new one before I. The array
// length is passed as the register of a 32-bit integer constant.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
    SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
    const SPIRVInstrInfo &TII) {
  Type *LLVMTy = ArrayType::get(
      const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
  const Register Known = DT.find(LLVMTy, CurMF);
  if (Known.isValid())
    return getSPIRVTypeForVReg(Known);

  // The length constant and its type are created before the array type.
  SPIRVType *LenType = getOrCreateSPIRVIntegerType(32, I, TII);
  Register Len = getOrCreateConstInt(NumElements, I, LenType, TII);

  auto MIB =
      BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray));
  MIB.addDef(createTypeVReg(CurMF->getRegInfo()));
  MIB.addUse(getSPIRVTypeID(BaseType));
  MIB.addUse(Len);
  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}
1399
1400
// Returns the OpTypePointer to BaseType in storage class SC. DT is queried by
// (pointee LLVM type, address space) for CurMF; when nothing is registered, a
// new OpTypePointer is inserted at MIRBuilder's insertion point and
// registered.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
    SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
    SPIRV::StorageClass::StorageClass SC) {
  const Type *PointerElementType = getTypeForSPIRVType(BaseType);
  const unsigned AddressSpace = storageClassToAddressSpace(SC);
  Type *LLVMTy = TypedPointerType::get(const_cast<Type *>(PointerElementType),
                                       AddressSpace);

  // check if this type is already available
  const Register Known = DT.find(PointerElementType, AddressSpace, CurMF);
  if (Known.isValid())
    return getSPIRVTypeForVReg(Known);

  // create a new type
  auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
                     MIRBuilder.getDebugLoc(),
                     MIRBuilder.getTII().get(SPIRV::OpTypePointer));
  MIB.addDef(createTypeVReg(CurMF->getRegInfo()));
  MIB.addImm(static_cast<uint32_t>(SC));
  MIB.addUse(getSPIRVTypeID(BaseType));
  DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}
1421
1422
// Overload taking an instruction: wraps I in a MachineIRBuilder and forwards
// to the MachineIRBuilder-based overload. The SPIRVInstrInfo argument is not
// used.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
    SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
    SPIRV::StorageClass::StorageClass SC) {
  MachineIRBuilder Builder(I);
  return getOrCreateSPIRVPointerType(BaseType, Builder, SC);
}
1428
1429
// Returns the register of an OpUndef of type SpvType in CurMF. The register
// recorded in DT for the matching LLVM UndefValue is reused when present;
// otherwise a new register is created and registered, and an OpUndef
// defining it is inserted before I.
Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
                                               SPIRVType *SpvType,
                                               const SPIRVInstrInfo &TII) {
  assert(SpvType);
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  assert(LLVMTy);

  // Find a constant in DT or build a new one.
  UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
  Register Res = DT.find(UV, CurMF);
  if (Res.isValid())
    return Res;

  auto &MRI = CurMF->getRegInfo();
  Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
  MRI.setRegClass(Res, &SPIRV::IDRegClass);
  assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
  DT.add(UV, CurMF, Res);

  auto MIB =
      BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef));
  MIB.addDef(Res);
  MIB.addUse(getSPIRVTypeID(SpvType));

  const auto &ST = CurMF->getSubtarget();
  constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
                                   *ST.getRegisterInfo(), *ST.getRegBankInfo());
  return Res;
}
1455
1456