Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
35266 views
1
//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// The analysis collects instructions that should be output at the module level
10
// and performs the global register numbering.
11
//
12
// The results of this analysis are used in AsmPrinter to rename registers
13
// globally and to output required instructions at the module level.
14
//
15
//===----------------------------------------------------------------------===//
16
17
#include "SPIRVModuleAnalysis.h"
18
#include "MCTargetDesc/SPIRVBaseInfo.h"
19
#include "MCTargetDesc/SPIRVMCTargetDesc.h"
20
#include "SPIRV.h"
21
#include "SPIRVSubtarget.h"
22
#include "SPIRVTargetMachine.h"
23
#include "SPIRVUtils.h"
24
#include "TargetInfo/SPIRVTargetInfo.h"
25
#include "llvm/ADT/STLExtras.h"
26
#include "llvm/CodeGen/MachineModuleInfo.h"
27
#include "llvm/CodeGen/TargetPassConfig.h"
28
29
using namespace llvm;
30
31
#define DEBUG_TYPE "spirv-module-analysis"
32
33
static cl::opt<bool>
34
SPVDumpDeps("spv-dump-deps",
35
cl::desc("Dump MIR with SPIR-V dependencies info"),
36
cl::Optional, cl::init(false));
37
38
static cl::list<SPIRV::Capability::Capability>
39
AvoidCapabilities("avoid-spirv-capabilities",
40
cl::desc("SPIR-V capabilities to avoid if there are "
41
"other options enabling a feature"),
42
cl::ZeroOrMore, cl::Hidden,
43
cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
44
"SPIR-V Shader capability")));
45
// Use sets instead of cl::list to check "if contains" condition
46
struct AvoidCapabilitiesSet {
47
SmallSet<SPIRV::Capability::Capability, 4> S;
48
AvoidCapabilitiesSet() {
49
for (auto Cap : AvoidCapabilities)
50
S.insert(Cap);
51
}
52
};
53
54
char llvm::SPIRVModuleAnalysis::ID = 0;
55
56
namespace llvm {
57
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
58
} // namespace llvm
59
60
INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
61
true)
62
63
// Retrieve an unsigned from an MDNode with a list of them as operands.
64
static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
65
unsigned DefaultVal = 0) {
66
if (MdNode && OpIndex < MdNode->getNumOperands()) {
67
const auto &Op = MdNode->getOperand(OpIndex);
68
return mdconst::extract<ConstantInt>(Op)->getZExtValue();
69
}
70
return DefaultVal;
71
}
72
73
// Work out what is needed to use value `i` of operand category `Category` on
// subtarget ST. The result carries:
//   - at most one capability: the first available one not listed in
//     -avoid-spirv-capabilities, or else the last available one;
//   - the operand's extensions;
//   - its [min, max] SPIR-V version bounds.
// When no capability can be chosen, the operand is still accepted if the
// subtarget can use all of its extensions.
//
// NOTE(review): if capabilities are required, none of them can be used, and
// the operand lists no extensions, llvm::all_of over the empty extension
// list is true, so the operand is reported satisfiable with no capability.
// Confirm this is intended.
static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
                               unsigned i, const SPIRVSubtarget &ST,
                               SPIRV::RequirementHandler &Reqs) {
  // Built from the command line once, on first use.
  static AvoidCapabilitiesSet AvoidCaps;

  const VersionTuple TargetVer = ST.getSPIRVVersion();
  const VersionTuple MinVer = getSymbolicOperandMinVersion(Category, i);
  const VersionTuple MaxVer = getSymbolicOperandMaxVersion(Category, i);
  // An unspecified target version is compatible with any bounds, and an
  // empty max version means there is no upper bound.
  const bool InVersionRange =
      TargetVer.empty() ||
      (TargetVer >= MinVer && (MaxVer.empty() || TargetVer <= MaxVer));

  CapabilityList Caps = getSymbolicOperandCapabilities(Category, i);
  ExtensionList Exts = getSymbolicOperandExtensions(Category, i);

  // Nothing to enable: only the version bounds decide.
  if (Caps.empty() && Exts.empty()) {
    if (InVersionRange)
      return {true, {}, {}, MinVer, MaxVer};
    return {false, {}, {}, VersionTuple(), VersionTuple()};
  }

  // SPIR-V spec: "If an instruction, enumerant, or other feature specifies
  // multiple enabling capabilities, only one such capability needs to be
  // declared to use the feature." Prefer one not on the avoid list.
  if (!Caps.empty() && InVersionRange) {
    CapabilityList Usable;
    for (auto Cap : Caps)
      if (Reqs.isCapabilityAvailable(Cap))
        Usable.push_back(Cap);

    if (!Usable.empty()) {
      auto Chosen = Usable.back();
      for (auto Cap : Usable) {
        if (!AvoidCaps.S.contains(Cap)) {
          Chosen = Cap;
          break;
        }
      }
      return {true, {Chosen}, Exts, MinVer, MaxVer};
    }
  }

  // Fallback when no capability is required or none could be selected
  // (version out of range or nothing available): rely on the extensions,
  // provided the subtarget can use every one of them.
  const bool ExtsUsable =
      llvm::all_of(Exts, [&ST](const SPIRV::Extension::Extension &Ext) {
        return ST.canUseExtension(Ext);
      });
  if (ExtsUsable)
    return {true,
            {},
            Exts,
            VersionTuple(),
            VersionTuple()}; // TODO: add versions to extensions.
  return {false, {}, {}, VersionTuple(), VersionTuple()};
}
131
132
void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
133
MAI.MaxID = 0;
134
for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
135
MAI.MS[i].clear();
136
MAI.RegisterAliasTable.clear();
137
MAI.InstrsToDelete.clear();
138
MAI.FuncMap.clear();
139
MAI.GlobalVarList.clear();
140
MAI.ExtInstSetMap.clear();
141
MAI.Reqs.clear();
142
MAI.Reqs.initAvailableCapabilities(*ST);
143
144
// TODO: determine memory model and source language from the configuratoin.
145
if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
146
auto MemMD = MemModel->getOperand(0);
147
MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
148
getMetadataUInt(MemMD, 0));
149
MAI.Mem =
150
static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
151
} else {
152
// TODO: Add support for VulkanMemoryModel.
153
MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL
154
: SPIRV::MemoryModel::GLSL450;
155
if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
156
unsigned PtrSize = ST->getPointerSize();
157
MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32
158
: PtrSize == 64 ? SPIRV::AddressingModel::Physical64
159
: SPIRV::AddressingModel::Logical;
160
} else {
161
// TODO: Add support for PhysicalStorageBufferAddress.
162
MAI.Addr = SPIRV::AddressingModel::Logical;
163
}
164
}
165
// Get the OpenCL version number from metadata.
166
// TODO: support other source languages.
167
if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
168
MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
169
// Construct version literal in accordance with SPIRV-LLVM-Translator.
170
// TODO: support multiple OCL version metadata.
171
assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
172
auto VersionMD = VerNode->getOperand(0);
173
unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
174
unsigned MinorNum = getMetadataUInt(VersionMD, 1);
175
unsigned RevNum = getMetadataUInt(VersionMD, 2);
176
// Prevent Major part of OpenCL version to be 0
177
MAI.SrcLangVersion =
178
(std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
179
} else {
180
// If there is no information about OpenCL version we are forced to generate
181
// OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
182
// run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
183
// Translator avoids potential issues with run-times in a similar manner.
184
if (ST->isOpenCLEnv()) {
185
MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
186
MAI.SrcLangVersion = 100000;
187
} else {
188
MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
189
MAI.SrcLangVersion = 0;
190
}
191
}
192
193
if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
194
for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
195
MDNode *MD = ExtNode->getOperand(I);
196
if (!MD || MD->getNumOperands() == 0)
197
continue;
198
for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
199
MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
200
}
201
}
202
203
// Update required capabilities for this memory model, addressing model and
204
// source language.
205
MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
206
MAI.Mem, *ST);
207
MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
208
MAI.SrcLang, *ST);
209
MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
210
MAI.Addr, *ST);
211
212
if (ST->isOpenCLEnv()) {
213
// TODO: check if it's required by default.
214
MAI.ExtInstSetMap[static_cast<unsigned>(
215
SPIRV::InstructionSet::OpenCL_std)] =
216
Register::index2VirtReg(MAI.getNextID());
217
}
218
}
219
220
// Collect MI which defines the register in the given machine function.
221
static void collectDefInstr(Register Reg, const MachineFunction *MF,
222
SPIRV::ModuleAnalysisInfo *MAI,
223
SPIRV::ModuleSectionType MSType,
224
bool DoInsert = true) {
225
assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
226
MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
227
assert(MI && "There should be an instruction that defines the register");
228
MAI->setSkipEmission(MI);
229
if (DoInsert)
230
MAI->MS[MSType].push_back(MI);
231
}
232
233
void SPIRVModuleAnalysis::collectGlobalEntities(
234
const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
235
SPIRV::ModuleSectionType MSType,
236
std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
237
bool UsePreOrder = false) {
238
DenseSet<const SPIRV::DTSortableEntry *> Visited;
239
for (const auto *E : DepsGraph) {
240
std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
241
// NOTE: here we prefer recursive approach over iterative because
242
// we don't expect depchains long enough to cause SO.
243
RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
244
&RecHoistUtil](const SPIRV::DTSortableEntry *E) {
245
if (Visited.count(E) || !Pred(E))
246
return;
247
Visited.insert(E);
248
249
// Traversing deps graph in post-order allows us to get rid of
250
// register aliases preprocessing.
251
// But pre-order is required for correct processing of function
252
// declaration and arguments processing.
253
if (!UsePreOrder)
254
for (auto *S : E->getDeps())
255
RecHoistUtil(S);
256
257
Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
258
bool IsFirst = true;
259
for (auto &U : *E) {
260
const MachineFunction *MF = U.first;
261
Register Reg = U.second;
262
MAI.setRegisterAlias(MF, Reg, GlobalReg);
263
if (!MF->getRegInfo().getUniqueVRegDef(Reg))
264
continue;
265
collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
266
IsFirst = false;
267
if (E->getIsGV())
268
MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
269
}
270
271
if (UsePreOrder)
272
for (auto *S : E->getDeps())
273
RecHoistUtil(S);
274
};
275
RecHoistUtil(E);
276
}
277
}
278
279
// The function initializes global register alias table for types, consts,
280
// global vars and func decls and collects these instruction for output
281
// at module level. Also it collects explicit OpExtension/OpCapability
282
// instructions.
283
void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
284
std::vector<SPIRV::DTSortableEntry *> DepsGraph;
285
286
GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);
287
288
collectGlobalEntities(
289
DepsGraph, SPIRV::MB_TypeConstVars,
290
[](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });
291
292
for (auto F = M.begin(), E = M.end(); F != E; ++F) {
293
MachineFunction *MF = MMI->getMachineFunction(*F);
294
if (!MF)
295
continue;
296
// Iterate through and collect OpExtension/OpCapability instructions.
297
for (MachineBasicBlock &MBB : *MF) {
298
for (MachineInstr &MI : MBB) {
299
if (MI.getOpcode() == SPIRV::OpExtension) {
300
// Here, OpExtension just has a single enum operand, not a string.
301
auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
302
MAI.Reqs.addExtension(Ext);
303
MAI.setSkipEmission(&MI);
304
} else if (MI.getOpcode() == SPIRV::OpCapability) {
305
auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
306
MAI.Reqs.addCapability(Cap);
307
MAI.setSkipEmission(&MI);
308
}
309
}
310
}
311
}
312
313
collectGlobalEntities(
314
DepsGraph, SPIRV::MB_ExtFuncDecls,
315
[](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
316
}
317
318
// Look for IDs declared with Import linkage, and map the corresponding function
319
// to the register defining that variable (which will usually be the result of
320
// an OpFunction). This lets us call externally imported functions using
321
// the correct ID registers.
322
void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
323
const Function *F) {
324
if (MI.getOpcode() == SPIRV::OpDecorate) {
325
// If it's got Import linkage.
326
auto Dec = MI.getOperand(1).getImm();
327
if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
328
auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
329
if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
330
// Map imported function name to function ID register.
331
const Function *ImportedFunc =
332
F->getParent()->getFunction(getStringImm(MI, 2));
333
Register Target = MI.getOperand(0).getReg();
334
MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
335
}
336
}
337
} else if (MI.getOpcode() == SPIRV::OpFunction) {
338
// Record all internal OpFunction declarations.
339
Register Reg = MI.defs().begin()->getReg();
340
Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
341
assert(GlobalReg.isValid());
342
MAI.FuncMap[F] = GlobalReg;
343
}
344
}
345
346
// References to a function via function pointers generate virtual
347
// registers without a definition. We are able to resolve this
348
// reference using Globar Register info into an OpFunction instruction
349
// and replace dummy operands by the corresponding global register references.
350
void SPIRVModuleAnalysis::collectFuncPtrs() {
351
for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars])
352
if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL)
353
collectFuncPtrs(MI);
354
}
355
356
// Alias the function-pointer operand (operand 2) of MI to the global register
// of the OpFunction it refers to. Nothing happens if the global registry
// cannot find a definition for that use.
void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) {
  const MachineOperand *FunUse = &MI->getOperand(2);
  const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse);
  if (!FunDef)
    return;

  const MachineInstr *FunDefMI = FunDef->getParent();
  assert(FunDefMI->getOpcode() == SPIRV::OpFunction &&
         "Constant function pointer must refer to function definition");
  Register GlobalFunDefReg =
      MAI.getRegisterAlias(FunDefMI->getMF(), FunDef->getReg());
  assert(GlobalFunDefReg.isValid() &&
         "Function definition must refer to a global register");
  MAI.setRegisterAlias(MI->getMF(), FunUse->getReg(), GlobalFunDefReg);
}
371
372
using InstrSignature = SmallVector<size_t>;
373
using InstrTraces = std::set<InstrSignature>;
374
375
// Returns a representation of an instruction as a vector of MachineOperand
376
// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
377
// This creates a signature of the instruction with the same content
378
// that MachineOperand::isIdenticalTo uses for comparison.
379
static InstrSignature instrToSignature(MachineInstr &MI,
380
SPIRV::ModuleAnalysisInfo &MAI) {
381
InstrSignature Signature;
382
for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
383
const MachineOperand &MO = MI.getOperand(i);
384
size_t h;
385
if (MO.isReg()) {
386
Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
387
// mimic llvm::hash_value(const MachineOperand &MO)
388
h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
389
MO.isDef());
390
} else {
391
h = hash_value(MO);
392
}
393
Signature.push_back(h);
394
}
395
return Signature;
396
}
397
398
// Collect the given instruction in the specified MS. We assume global register
399
// numbering has already occurred by this point. We can directly compare reg
400
// arguments when detecting duplicates.
401
static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
402
SPIRV::ModuleSectionType MSType, InstrTraces &IS,
403
bool Append = true) {
404
MAI.setSkipEmission(&MI);
405
InstrSignature MISign = instrToSignature(MI, MAI);
406
auto FoundMI = IS.insert(MISign);
407
if (!FoundMI.second)
408
return; // insert failed, so we found a duplicate; don't add it to MAI.MS
409
// No duplicates, so add it.
410
if (Append)
411
MAI.MS[MSType].push_back(&MI);
412
else
413
MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
414
}
415
416
// Some global instructions make reference to function-local ID regs, so cannot
417
// be correctly collected until these registers are globally numbered.
418
void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
419
InstrTraces IS;
420
for (auto F = M.begin(), E = M.end(); F != E; ++F) {
421
if ((*F).isDeclaration())
422
continue;
423
MachineFunction *MF = MMI->getMachineFunction(*F);
424
assert(MF);
425
for (MachineBasicBlock &MBB : *MF)
426
for (MachineInstr &MI : MBB) {
427
if (MAI.getSkipEmission(&MI))
428
continue;
429
const unsigned OpCode = MI.getOpcode();
430
if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
431
collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
432
} else if (OpCode == SPIRV::OpEntryPoint) {
433
collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
434
} else if (TII->isDecorationInstr(MI)) {
435
collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
436
collectFuncNames(MI, &*F);
437
} else if (TII->isConstantInstr(MI)) {
438
// Now OpSpecConstant*s are not in DT,
439
// but they need to be collected anyway.
440
collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
441
} else if (OpCode == SPIRV::OpFunction) {
442
collectFuncNames(MI, &*F);
443
} else if (OpCode == SPIRV::OpTypeForwardPointer) {
444
collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
445
}
446
}
447
}
448
}
449
450
// Number registers in all functions globally from 0 onwards and store
451
// the result in global register alias table. Some registers are already
452
// numbered in collectGlobalEntities.
453
void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
454
for (auto F = M.begin(), E = M.end(); F != E; ++F) {
455
if ((*F).isDeclaration())
456
continue;
457
MachineFunction *MF = MMI->getMachineFunction(*F);
458
assert(MF);
459
for (MachineBasicBlock &MBB : *MF) {
460
for (MachineInstr &MI : MBB) {
461
for (MachineOperand &Op : MI.operands()) {
462
if (!Op.isReg())
463
continue;
464
Register Reg = Op.getReg();
465
if (MAI.hasRegisterAlias(MF, Reg))
466
continue;
467
Register NewReg = Register::index2VirtReg(MAI.getNextID());
468
MAI.setRegisterAlias(MF, Reg, NewReg);
469
}
470
if (MI.getOpcode() != SPIRV::OpExtInst)
471
continue;
472
auto Set = MI.getOperand(2).getImm();
473
if (!MAI.ExtInstSetMap.contains(Set))
474
MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
475
}
476
}
477
}
478
}
479
480
// RequirementHandler implementations.
481
void SPIRV::RequirementHandler::getAndAddRequirements(
482
SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
483
const SPIRVSubtarget &ST) {
484
addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
485
}
486
487
void SPIRV::RequirementHandler::recursiveAddCapabilities(
488
const CapabilityList &ToPrune) {
489
for (const auto &Cap : ToPrune) {
490
AllCaps.insert(Cap);
491
CapabilityList ImplicitDecls =
492
getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
493
recursiveAddCapabilities(ImplicitDecls);
494
}
495
}
496
497
void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
498
for (const auto &Cap : ToAdd) {
499
bool IsNewlyInserted = AllCaps.insert(Cap).second;
500
if (!IsNewlyInserted) // Don't re-add if it's already been declared.
501
continue;
502
CapabilityList ImplicitDecls =
503
getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
504
recursiveAddCapabilities(ImplicitDecls);
505
MinimalCaps.push_back(Cap);
506
}
507
}
508
509
void SPIRV::RequirementHandler::addRequirements(
510
const SPIRV::Requirements &Req) {
511
if (!Req.IsSatisfiable)
512
report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
513
514
if (Req.Cap.has_value())
515
addCapabilities({Req.Cap.value()});
516
517
addExtensions(Req.Exts);
518
519
if (!Req.MinVer.empty()) {
520
if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
521
LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
522
<< " and <= " << MaxVersion << "\n");
523
report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
524
}
525
526
if (MinVersion.empty() || Req.MinVer > MinVersion)
527
MinVersion = Req.MinVer;
528
}
529
530
if (!Req.MaxVer.empty()) {
531
if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
532
LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
533
<< " and >= " << MinVersion << "\n");
534
report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
535
}
536
537
if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
538
MaxVersion = Req.MaxVer;
539
}
540
}
541
542
void SPIRV::RequirementHandler::checkSatisfiable(
543
const SPIRVSubtarget &ST) const {
544
// Report as many errors as possible before aborting the compilation.
545
bool IsSatisfiable = true;
546
auto TargetVer = ST.getSPIRVVersion();
547
548
if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
549
LLVM_DEBUG(
550
dbgs() << "Target SPIR-V version too high for required features\n"
551
<< "Required max version: " << MaxVersion << " target version "
552
<< TargetVer << "\n");
553
IsSatisfiable = false;
554
}
555
556
if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
557
LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
558
<< "Required min version: " << MinVersion
559
<< " target version " << TargetVer << "\n");
560
IsSatisfiable = false;
561
}
562
563
if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
564
LLVM_DEBUG(
565
dbgs()
566
<< "Version is too low for some features and too high for others.\n"
567
<< "Required SPIR-V min version: " << MinVersion
568
<< " required SPIR-V max version " << MaxVersion << "\n");
569
IsSatisfiable = false;
570
}
571
572
for (auto Cap : MinimalCaps) {
573
if (AvailableCaps.contains(Cap))
574
continue;
575
LLVM_DEBUG(dbgs() << "Capability not supported: "
576
<< getSymbolicOperandMnemonic(
577
OperandCategory::CapabilityOperand, Cap)
578
<< "\n");
579
IsSatisfiable = false;
580
}
581
582
for (auto Ext : AllExtensions) {
583
if (ST.canUseExtension(Ext))
584
continue;
585
LLVM_DEBUG(dbgs() << "Extension not supported: "
586
<< getSymbolicOperandMnemonic(
587
OperandCategory::ExtensionOperand, Ext)
588
<< "\n");
589
IsSatisfiable = false;
590
}
591
592
if (!IsSatisfiable)
593
report_fatal_error("Unable to meet SPIR-V requirements for this target.");
594
}
595
596
// Add the given capabilities and all their implicitly defined capabilities too.
597
void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
598
for (const auto Cap : ToAdd)
599
if (AvailableCaps.insert(Cap).second)
600
addAvailableCaps(getSymbolicOperandCapabilities(
601
SPIRV::OperandCategory::CapabilityOperand, Cap));
602
}
603
604
void SPIRV::RequirementHandler::removeCapabilityIf(
605
const Capability::Capability ToRemove,
606
const Capability::Capability IfPresent) {
607
if (AllCaps.contains(IfPresent))
608
AllCaps.erase(ToRemove);
609
}
610
611
namespace llvm {
612
namespace SPIRV {
613
void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
614
if (ST.isOpenCLEnv()) {
615
initAvailableCapabilitiesForOpenCL(ST);
616
return;
617
}
618
619
if (ST.isVulkanEnv()) {
620
initAvailableCapabilitiesForVulkan(ST);
621
return;
622
}
623
624
report_fatal_error("Unimplemented environment for SPIR-V generation.");
625
}
626
627
void RequirementHandler::initAvailableCapabilitiesForOpenCL(
628
const SPIRVSubtarget &ST) {
629
// Add the min requirements for different OpenCL and SPIR-V versions.
630
addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
631
Capability::Int16, Capability::Int8, Capability::Kernel,
632
Capability::Linkage, Capability::Vector16,
633
Capability::Groups, Capability::GenericPointer,
634
Capability::Shader});
635
if (ST.hasOpenCLFullProfile())
636
addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
637
if (ST.hasOpenCLImageSupport()) {
638
addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
639
Capability::Image1D, Capability::SampledBuffer,
640
Capability::ImageBuffer});
641
if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))
642
addAvailableCaps({Capability::ImageReadWrite});
643
}
644
if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
645
ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
646
addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
647
if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
648
addAvailableCaps({Capability::GroupNonUniform,
649
Capability::GroupNonUniformVote,
650
Capability::GroupNonUniformArithmetic,
651
Capability::GroupNonUniformBallot,
652
Capability::GroupNonUniformClustered,
653
Capability::GroupNonUniformShuffle,
654
Capability::GroupNonUniformShuffleRelative});
655
if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
656
addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
657
Capability::SignedZeroInfNanPreserve,
658
Capability::RoundingModeRTE,
659
Capability::RoundingModeRTZ});
660
// TODO: verify if this needs some checks.
661
addAvailableCaps({Capability::Float16, Capability::Float64});
662
663
// Add capabilities enabled by extensions.
664
for (auto Extension : ST.getAllAvailableExtensions()) {
665
CapabilityList EnabledCapabilities =
666
getCapabilitiesEnabledByExtension(Extension);
667
addAvailableCaps(EnabledCapabilities);
668
}
669
670
// TODO: add OpenCL extensions.
671
}
672
673
void RequirementHandler::initAvailableCapabilitiesForVulkan(
674
const SPIRVSubtarget &ST) {
675
addAvailableCaps({Capability::Shader, Capability::Linkage});
676
677
// Provided by all supported Vulkan versions.
678
addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
679
Capability::Float64, Capability::GroupNonUniform});
680
}
681
682
} // namespace SPIRV
683
} // namespace llvm
684
685
// Add the required capabilities from a decoration instruction (including
686
// BuiltIns).
687
static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
688
SPIRV::RequirementHandler &Reqs,
689
const SPIRVSubtarget &ST) {
690
int64_t DecOp = MI.getOperand(DecIndex).getImm();
691
auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
692
Reqs.addRequirements(getSymbolicOperandRequirements(
693
SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
694
695
if (Dec == SPIRV::Decoration::BuiltIn) {
696
int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
697
auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
698
Reqs.addRequirements(getSymbolicOperandRequirements(
699
SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
700
} else if (Dec == SPIRV::Decoration::LinkageAttributes) {
701
int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
702
SPIRV::LinkageType::LinkageType LnkType =
703
static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
704
if (LnkType == SPIRV::LinkageType::LinkOnceODR)
705
Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
706
} else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
707
Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
708
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls);
709
} else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
710
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access);
711
} else if (Dec == SPIRV::Decoration::InitModeINTEL ||
712
Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
713
Reqs.addExtension(
714
SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
715
}
716
}
717
718
// Add requirements for image handling.
719
static void addOpTypeImageReqs(const MachineInstr &MI,
720
SPIRV::RequirementHandler &Reqs,
721
const SPIRVSubtarget &ST) {
722
assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
723
// The operand indices used here are based on the OpTypeImage layout, which
724
// the MachineInstr follows as well.
725
int64_t ImgFormatOp = MI.getOperand(7).getImm();
726
auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
727
Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
728
ImgFormat, ST);
729
730
bool IsArrayed = MI.getOperand(4).getImm() == 1;
731
bool IsMultisampled = MI.getOperand(5).getImm() == 1;
732
bool NoSampler = MI.getOperand(6).getImm() == 2;
733
// Add dimension requirements.
734
assert(MI.getOperand(2).isImm());
735
switch (MI.getOperand(2).getImm()) {
736
case SPIRV::Dim::DIM_1D:
737
Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
738
: SPIRV::Capability::Sampled1D);
739
break;
740
case SPIRV::Dim::DIM_2D:
741
if (IsMultisampled && NoSampler)
742
Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
743
break;
744
case SPIRV::Dim::DIM_Cube:
745
Reqs.addRequirements(SPIRV::Capability::Shader);
746
if (IsArrayed)
747
Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
748
: SPIRV::Capability::SampledCubeArray);
749
break;
750
case SPIRV::Dim::DIM_Rect:
751
Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
752
: SPIRV::Capability::SampledRect);
753
break;
754
case SPIRV::Dim::DIM_Buffer:
755
Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
756
: SPIRV::Capability::SampledBuffer);
757
break;
758
case SPIRV::Dim::DIM_SubpassData:
759
Reqs.addRequirements(SPIRV::Capability::InputAttachment);
760
break;
761
}
762
763
// Has optional access qualifier.
764
// TODO: check if it's OpenCL's kernel.
765
if (MI.getNumOperands() > 8 &&
766
MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
767
Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
768
else
769
Reqs.addRequirements(SPIRV::Capability::ImageBasic);
770
}
771
772
// Add requirements for handling atomic float instructions
773
#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
774
"The atomic float instruction requires the following SPIR-V " \
775
"extension: SPV_EXT_shader_atomic_float" ExtName
776
static void AddAtomicFloatRequirements(const MachineInstr &MI,
777
SPIRV::RequirementHandler &Reqs,
778
const SPIRVSubtarget &ST) {
779
assert(MI.getOperand(1).isReg() &&
780
"Expect register operand in atomic float instruction");
781
Register TypeReg = MI.getOperand(1).getReg();
782
SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
783
if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
784
report_fatal_error("Result type of an atomic float instruction must be a "
785
"floating-point type scalar");
786
787
unsigned BitWidth = TypeDef->getOperand(1).getImm();
788
unsigned Op = MI.getOpcode();
789
if (Op == SPIRV::OpAtomicFAddEXT) {
790
if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
791
report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);
792
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
793
switch (BitWidth) {
794
case 16:
795
if (!ST.canUseExtension(
796
SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
797
report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
798
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
799
Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
800
break;
801
case 32:
802
Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
803
break;
804
case 64:
805
Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
806
break;
807
default:
808
report_fatal_error(
809
"Unexpected floating-point type width in atomic float instruction");
810
}
811
} else {
812
if (!ST.canUseExtension(
813
SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
814
report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
815
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
816
switch (BitWidth) {
817
case 16:
818
Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
819
break;
820
case 32:
821
Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
822
break;
823
case 64:
824
Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
825
break;
826
default:
827
report_fatal_error(
828
"Unexpected floating-point type width in atomic float instruction");
829
}
830
}
831
}
832
833
void addInstrRequirements(const MachineInstr &MI,
834
SPIRV::RequirementHandler &Reqs,
835
const SPIRVSubtarget &ST) {
836
switch (MI.getOpcode()) {
837
case SPIRV::OpMemoryModel: {
838
int64_t Addr = MI.getOperand(0).getImm();
839
Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
840
Addr, ST);
841
int64_t Mem = MI.getOperand(1).getImm();
842
Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
843
ST);
844
break;
845
}
846
case SPIRV::OpEntryPoint: {
847
int64_t Exe = MI.getOperand(0).getImm();
848
Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
849
Exe, ST);
850
break;
851
}
852
case SPIRV::OpExecutionMode:
853
case SPIRV::OpExecutionModeId: {
854
int64_t Exe = MI.getOperand(1).getImm();
855
Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
856
Exe, ST);
857
break;
858
}
859
case SPIRV::OpTypeMatrix:
860
Reqs.addCapability(SPIRV::Capability::Matrix);
861
break;
862
case SPIRV::OpTypeInt: {
863
unsigned BitWidth = MI.getOperand(1).getImm();
864
if (BitWidth == 64)
865
Reqs.addCapability(SPIRV::Capability::Int64);
866
else if (BitWidth == 16)
867
Reqs.addCapability(SPIRV::Capability::Int16);
868
else if (BitWidth == 8)
869
Reqs.addCapability(SPIRV::Capability::Int8);
870
break;
871
}
872
case SPIRV::OpTypeFloat: {
873
unsigned BitWidth = MI.getOperand(1).getImm();
874
if (BitWidth == 64)
875
Reqs.addCapability(SPIRV::Capability::Float64);
876
else if (BitWidth == 16)
877
Reqs.addCapability(SPIRV::Capability::Float16);
878
break;
879
}
880
case SPIRV::OpTypeVector: {
881
unsigned NumComponents = MI.getOperand(2).getImm();
882
if (NumComponents == 8 || NumComponents == 16)
883
Reqs.addCapability(SPIRV::Capability::Vector16);
884
break;
885
}
886
case SPIRV::OpTypePointer: {
887
auto SC = MI.getOperand(1).getImm();
888
Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
889
ST);
890
// If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
891
// capability.
892
if (!ST.isOpenCLEnv())
893
break;
894
assert(MI.getOperand(2).isReg());
895
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
896
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
897
if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
898
TypeDef->getOperand(1).getImm() == 16)
899
Reqs.addCapability(SPIRV::Capability::Float16Buffer);
900
break;
901
}
902
case SPIRV::OpBitReverse:
903
case SPIRV::OpBitFieldInsert:
904
case SPIRV::OpBitFieldSExtract:
905
case SPIRV::OpBitFieldUExtract:
906
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
907
Reqs.addCapability(SPIRV::Capability::Shader);
908
break;
909
}
910
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
911
Reqs.addCapability(SPIRV::Capability::BitInstructions);
912
break;
913
case SPIRV::OpTypeRuntimeArray:
914
Reqs.addCapability(SPIRV::Capability::Shader);
915
break;
916
case SPIRV::OpTypeOpaque:
917
case SPIRV::OpTypeEvent:
918
Reqs.addCapability(SPIRV::Capability::Kernel);
919
break;
920
case SPIRV::OpTypePipe:
921
case SPIRV::OpTypeReserveId:
922
Reqs.addCapability(SPIRV::Capability::Pipes);
923
break;
924
case SPIRV::OpTypeDeviceEvent:
925
case SPIRV::OpTypeQueue:
926
case SPIRV::OpBuildNDRange:
927
Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
928
break;
929
case SPIRV::OpDecorate:
930
case SPIRV::OpDecorateId:
931
case SPIRV::OpDecorateString:
932
addOpDecorateReqs(MI, 1, Reqs, ST);
933
break;
934
case SPIRV::OpMemberDecorate:
935
case SPIRV::OpMemberDecorateString:
936
addOpDecorateReqs(MI, 2, Reqs, ST);
937
break;
938
case SPIRV::OpInBoundsPtrAccessChain:
939
Reqs.addCapability(SPIRV::Capability::Addresses);
940
break;
941
case SPIRV::OpConstantSampler:
942
Reqs.addCapability(SPIRV::Capability::LiteralSampler);
943
break;
944
case SPIRV::OpTypeImage:
945
addOpTypeImageReqs(MI, Reqs, ST);
946
break;
947
case SPIRV::OpTypeSampler:
948
Reqs.addCapability(SPIRV::Capability::ImageBasic);
949
break;
950
case SPIRV::OpTypeForwardPointer:
951
// TODO: check if it's OpenCL's kernel.
952
Reqs.addCapability(SPIRV::Capability::Addresses);
953
break;
954
case SPIRV::OpAtomicFlagTestAndSet:
955
case SPIRV::OpAtomicLoad:
956
case SPIRV::OpAtomicStore:
957
case SPIRV::OpAtomicExchange:
958
case SPIRV::OpAtomicCompareExchange:
959
case SPIRV::OpAtomicIIncrement:
960
case SPIRV::OpAtomicIDecrement:
961
case SPIRV::OpAtomicIAdd:
962
case SPIRV::OpAtomicISub:
963
case SPIRV::OpAtomicUMin:
964
case SPIRV::OpAtomicUMax:
965
case SPIRV::OpAtomicSMin:
966
case SPIRV::OpAtomicSMax:
967
case SPIRV::OpAtomicAnd:
968
case SPIRV::OpAtomicOr:
969
case SPIRV::OpAtomicXor: {
970
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
971
const MachineInstr *InstrPtr = &MI;
972
if (MI.getOpcode() == SPIRV::OpAtomicStore) {
973
assert(MI.getOperand(3).isReg());
974
InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
975
assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
976
}
977
assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
978
Register TypeReg = InstrPtr->getOperand(1).getReg();
979
SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
980
if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
981
unsigned BitWidth = TypeDef->getOperand(1).getImm();
982
if (BitWidth == 64)
983
Reqs.addCapability(SPIRV::Capability::Int64Atomics);
984
}
985
break;
986
}
987
case SPIRV::OpGroupNonUniformIAdd:
988
case SPIRV::OpGroupNonUniformFAdd:
989
case SPIRV::OpGroupNonUniformIMul:
990
case SPIRV::OpGroupNonUniformFMul:
991
case SPIRV::OpGroupNonUniformSMin:
992
case SPIRV::OpGroupNonUniformUMin:
993
case SPIRV::OpGroupNonUniformFMin:
994
case SPIRV::OpGroupNonUniformSMax:
995
case SPIRV::OpGroupNonUniformUMax:
996
case SPIRV::OpGroupNonUniformFMax:
997
case SPIRV::OpGroupNonUniformBitwiseAnd:
998
case SPIRV::OpGroupNonUniformBitwiseOr:
999
case SPIRV::OpGroupNonUniformBitwiseXor:
1000
case SPIRV::OpGroupNonUniformLogicalAnd:
1001
case SPIRV::OpGroupNonUniformLogicalOr:
1002
case SPIRV::OpGroupNonUniformLogicalXor: {
1003
assert(MI.getOperand(3).isImm());
1004
int64_t GroupOp = MI.getOperand(3).getImm();
1005
switch (GroupOp) {
1006
case SPIRV::GroupOperation::Reduce:
1007
case SPIRV::GroupOperation::InclusiveScan:
1008
case SPIRV::GroupOperation::ExclusiveScan:
1009
Reqs.addCapability(SPIRV::Capability::Kernel);
1010
Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
1011
Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1012
break;
1013
case SPIRV::GroupOperation::ClusteredReduce:
1014
Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
1015
break;
1016
case SPIRV::GroupOperation::PartitionedReduceNV:
1017
case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1018
case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1019
Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
1020
break;
1021
}
1022
break;
1023
}
1024
case SPIRV::OpGroupNonUniformShuffle:
1025
case SPIRV::OpGroupNonUniformShuffleXor:
1026
Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
1027
break;
1028
case SPIRV::OpGroupNonUniformShuffleUp:
1029
case SPIRV::OpGroupNonUniformShuffleDown:
1030
Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
1031
break;
1032
case SPIRV::OpGroupAll:
1033
case SPIRV::OpGroupAny:
1034
case SPIRV::OpGroupBroadcast:
1035
case SPIRV::OpGroupIAdd:
1036
case SPIRV::OpGroupFAdd:
1037
case SPIRV::OpGroupFMin:
1038
case SPIRV::OpGroupUMin:
1039
case SPIRV::OpGroupSMin:
1040
case SPIRV::OpGroupFMax:
1041
case SPIRV::OpGroupUMax:
1042
case SPIRV::OpGroupSMax:
1043
Reqs.addCapability(SPIRV::Capability::Groups);
1044
break;
1045
case SPIRV::OpGroupNonUniformElect:
1046
Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1047
break;
1048
case SPIRV::OpGroupNonUniformAll:
1049
case SPIRV::OpGroupNonUniformAny:
1050
case SPIRV::OpGroupNonUniformAllEqual:
1051
Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
1052
break;
1053
case SPIRV::OpGroupNonUniformBroadcast:
1054
case SPIRV::OpGroupNonUniformBroadcastFirst:
1055
case SPIRV::OpGroupNonUniformBallot:
1056
case SPIRV::OpGroupNonUniformInverseBallot:
1057
case SPIRV::OpGroupNonUniformBallotBitExtract:
1058
case SPIRV::OpGroupNonUniformBallotBitCount:
1059
case SPIRV::OpGroupNonUniformBallotFindLSB:
1060
case SPIRV::OpGroupNonUniformBallotFindMSB:
1061
Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1062
break;
1063
case SPIRV::OpSubgroupShuffleINTEL:
1064
case SPIRV::OpSubgroupShuffleDownINTEL:
1065
case SPIRV::OpSubgroupShuffleUpINTEL:
1066
case SPIRV::OpSubgroupShuffleXorINTEL:
1067
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1068
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1069
Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
1070
}
1071
break;
1072
case SPIRV::OpSubgroupBlockReadINTEL:
1073
case SPIRV::OpSubgroupBlockWriteINTEL:
1074
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1075
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1076
Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1077
}
1078
break;
1079
case SPIRV::OpSubgroupImageBlockReadINTEL:
1080
case SPIRV::OpSubgroupImageBlockWriteINTEL:
1081
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1082
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1083
Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
1084
}
1085
break;
1086
case SPIRV::OpAssumeTrueKHR:
1087
case SPIRV::OpExpectKHR:
1088
if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
1089
Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
1090
Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
1091
}
1092
break;
1093
case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1094
case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1095
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1096
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1097
Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
1098
}
1099
break;
1100
case SPIRV::OpConstantFunctionPointerINTEL:
1101
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1102
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1103
Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1104
}
1105
break;
1106
case SPIRV::OpGroupNonUniformRotateKHR:
1107
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
1108
report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
1109
"following SPIR-V extension: SPV_KHR_subgroup_rotate",
1110
false);
1111
Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
1112
Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
1113
Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1114
break;
1115
case SPIRV::OpGroupIMulKHR:
1116
case SPIRV::OpGroupFMulKHR:
1117
case SPIRV::OpGroupBitwiseAndKHR:
1118
case SPIRV::OpGroupBitwiseOrKHR:
1119
case SPIRV::OpGroupBitwiseXorKHR:
1120
case SPIRV::OpGroupLogicalAndKHR:
1121
case SPIRV::OpGroupLogicalOrKHR:
1122
case SPIRV::OpGroupLogicalXorKHR:
1123
if (ST.canUseExtension(
1124
SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1125
Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1126
Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
1127
}
1128
break;
1129
case SPIRV::OpReadClockKHR:
1130
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
1131
report_fatal_error("OpReadClockKHR instruction requires the "
1132
"following SPIR-V extension: SPV_KHR_shader_clock",
1133
false);
1134
Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
1135
Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
1136
break;
1137
case SPIRV::OpFunctionPointerCallINTEL:
1138
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1139
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1140
Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1141
}
1142
break;
1143
case SPIRV::OpAtomicFAddEXT:
1144
case SPIRV::OpAtomicFMinEXT:
1145
case SPIRV::OpAtomicFMaxEXT:
1146
AddAtomicFloatRequirements(MI, Reqs, ST);
1147
break;
1148
case SPIRV::OpConvertBF16ToFINTEL:
1149
case SPIRV::OpConvertFToBF16INTEL:
1150
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1151
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1152
Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
1153
}
1154
break;
1155
case SPIRV::OpVariableLengthArrayINTEL:
1156
case SPIRV::OpSaveMemoryINTEL:
1157
case SPIRV::OpRestoreMemoryINTEL:
1158
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
1159
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
1160
Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
1161
}
1162
break;
1163
case SPIRV::OpAsmTargetINTEL:
1164
case SPIRV::OpAsmINTEL:
1165
case SPIRV::OpAsmCallINTEL:
1166
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {
1167
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);
1168
Reqs.addCapability(SPIRV::Capability::AsmINTEL);
1169
}
1170
break;
1171
case SPIRV::OpTypeCooperativeMatrixKHR:
1172
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1173
report_fatal_error(
1174
"OpTypeCooperativeMatrixKHR type requires the "
1175
"following SPIR-V extension: SPV_KHR_cooperative_matrix",
1176
false);
1177
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1178
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1179
break;
1180
default:
1181
break;
1182
}
1183
1184
// If we require capability Shader, then we can remove the requirement for
1185
// the BitInstructions capability, since Shader is a superset capability
1186
// of BitInstructions.
1187
Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
1188
SPIRV::Capability::Shader);
1189
}
1190
1191
static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
1192
MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
1193
// Collect requirements for existing instructions.
1194
for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1195
MachineFunction *MF = MMI->getMachineFunction(*F);
1196
if (!MF)
1197
continue;
1198
for (const MachineBasicBlock &MBB : *MF)
1199
for (const MachineInstr &MI : MBB)
1200
addInstrRequirements(MI, MAI.Reqs, ST);
1201
}
1202
// Collect requirements for OpExecutionMode instructions.
1203
auto Node = M.getNamedMetadata("spirv.ExecutionMode");
1204
if (Node) {
1205
// SPV_KHR_float_controls is not available until v1.4
1206
bool RequireFloatControls = false,
1207
VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
1208
for (unsigned i = 0; i < Node->getNumOperands(); i++) {
1209
MDNode *MDN = cast<MDNode>(Node->getOperand(i));
1210
const MDOperand &MDOp = MDN->getOperand(1);
1211
if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
1212
Constant *C = CMeta->getValue();
1213
if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
1214
auto EM = Const->getZExtValue();
1215
MAI.Reqs.getAndAddRequirements(
1216
SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
1217
// add SPV_KHR_float_controls if the version is too low
1218
switch (EM) {
1219
case SPIRV::ExecutionMode::DenormPreserve:
1220
case SPIRV::ExecutionMode::DenormFlushToZero:
1221
case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
1222
case SPIRV::ExecutionMode::RoundingModeRTE:
1223
case SPIRV::ExecutionMode::RoundingModeRTZ:
1224
RequireFloatControls = VerLower14;
1225
break;
1226
}
1227
}
1228
}
1229
}
1230
if (RequireFloatControls &&
1231
ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
1232
MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
1233
}
1234
for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
1235
const Function &F = *FI;
1236
if (F.isDeclaration())
1237
continue;
1238
if (F.getMetadata("reqd_work_group_size"))
1239
MAI.Reqs.getAndAddRequirements(
1240
SPIRV::OperandCategory::ExecutionModeOperand,
1241
SPIRV::ExecutionMode::LocalSize, ST);
1242
if (F.getFnAttribute("hlsl.numthreads").isValid()) {
1243
MAI.Reqs.getAndAddRequirements(
1244
SPIRV::OperandCategory::ExecutionModeOperand,
1245
SPIRV::ExecutionMode::LocalSize, ST);
1246
}
1247
if (F.getMetadata("work_group_size_hint"))
1248
MAI.Reqs.getAndAddRequirements(
1249
SPIRV::OperandCategory::ExecutionModeOperand,
1250
SPIRV::ExecutionMode::LocalSizeHint, ST);
1251
if (F.getMetadata("intel_reqd_sub_group_size"))
1252
MAI.Reqs.getAndAddRequirements(
1253
SPIRV::OperandCategory::ExecutionModeOperand,
1254
SPIRV::ExecutionMode::SubgroupSize, ST);
1255
if (F.getMetadata("vec_type_hint"))
1256
MAI.Reqs.getAndAddRequirements(
1257
SPIRV::OperandCategory::ExecutionModeOperand,
1258
SPIRV::ExecutionMode::VecTypeHint, ST);
1259
1260
if (F.hasOptNone() &&
1261
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
1262
// Output OpCapability OptNoneINTEL.
1263
MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
1264
MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
1265
}
1266
}
1267
}
1268
1269
static unsigned getFastMathFlags(const MachineInstr &I) {
1270
unsigned Flags = SPIRV::FPFastMathMode::None;
1271
if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
1272
Flags |= SPIRV::FPFastMathMode::NotNaN;
1273
if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
1274
Flags |= SPIRV::FPFastMathMode::NotInf;
1275
if (I.getFlag(MachineInstr::MIFlag::FmNsz))
1276
Flags |= SPIRV::FPFastMathMode::NSZ;
1277
if (I.getFlag(MachineInstr::MIFlag::FmArcp))
1278
Flags |= SPIRV::FPFastMathMode::AllowRecip;
1279
if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
1280
Flags |= SPIRV::FPFastMathMode::Fast;
1281
return Flags;
1282
}
1283
1284
static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
1285
const SPIRVInstrInfo &TII,
1286
SPIRV::RequirementHandler &Reqs) {
1287
if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
1288
getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1289
SPIRV::Decoration::NoSignedWrap, ST, Reqs)
1290
.IsSatisfiable) {
1291
buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1292
SPIRV::Decoration::NoSignedWrap, {});
1293
}
1294
if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
1295
getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
1296
SPIRV::Decoration::NoUnsignedWrap, ST,
1297
Reqs)
1298
.IsSatisfiable) {
1299
buildOpDecorate(I.getOperand(0).getReg(), I, TII,
1300
SPIRV::Decoration::NoUnsignedWrap, {});
1301
}
1302
if (!TII.canUseFastMathFlags(I))
1303
return;
1304
unsigned FMFlags = getFastMathFlags(I);
1305
if (FMFlags == SPIRV::FPFastMathMode::None)
1306
return;
1307
Register DstReg = I.getOperand(0).getReg();
1308
buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
1309
}
1310
1311
// Walk all functions and add decorations related to MI flags.
1312
static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
1313
MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
1314
SPIRV::ModuleAnalysisInfo &MAI) {
1315
for (auto F = M.begin(), E = M.end(); F != E; ++F) {
1316
MachineFunction *MF = MMI->getMachineFunction(*F);
1317
if (!MF)
1318
continue;
1319
for (auto &MBB : *MF)
1320
for (auto &MI : MBB)
1321
handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
1322
}
1323
}
1324
1325
struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
1326
1327
void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
1328
AU.addRequired<TargetPassConfig>();
1329
AU.addRequired<MachineModuleInfoWrapperPass>();
1330
}
1331
1332
/// Run the analysis over \p M.
/// Caches the subtarget, global registry, instruction info and
/// MachineModuleInfo. It then adds MI-flag decorations, collects
/// requirements, processes definition instructions, numbers registers
/// globally and collects the remaining module-level instructions into MAI.
/// Always returns false.
bool SPIRVModuleAnalysis::runOnModule(Module &M) {
  TargetPassConfig &PassConfig = getAnalysis<TargetPassConfig>();
  SPIRVTargetMachine &TM = PassConfig.getTM<SPIRVTargetMachine>();
  ST = TM.getSubtargetImpl();
  GR = ST->getSPIRVGlobalRegistry();
  TII = ST->getInstrInfo();

  MachineModuleInfoWrapperPass &MMIWrapper =
      getAnalysis<MachineModuleInfoWrapperPass>();
  MMI = &MMIWrapper.getMMI();

  setBaseInfo(M);

  // Decorate instructions from their MI flags, then gather requirements.
  addDecorations(M, *TII, MMI, *ST, MAI);
  collectReqs(M, MAI, MMI, *ST);

  // Process type/const/global var/func decl instructions, number their
  // destination registers from 0 to N, collect Extensions and Capabilities.
  processDefInstrs(M);

  // Number rest of registers from N+1 onwards.
  numberRegistersGlobally(M);

  // Update references to OpFunction instructions to use Global Registers.
  if (GR->hasConstFunPtr())
    collectFuncPtrs();

  // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
  processOtherInstrs(M);

  // If there are no entry points, we need the Linkage capability.
  const bool HasEntryPoints = !MAI.MS[SPIRV::MB_EntryPoints].empty();
  if (!HasEntryPoints)
    MAI.Reqs.addCapability(SPIRV::Capability::Linkage);

  // Set maximum ID used.
  GR->setBound(MAI.MaxID);

  return false;
}
1370
1371