Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/DirectX/DXILRootSignature.cpp
213799 views
1
//===- DXILRootSignature.cpp - DXIL Root Signature helper objects -------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
///
9
/// \file This file contains helper objects and APIs for working with DXIL
10
/// Root Signatures.
11
///
12
//===----------------------------------------------------------------------===//
13
#include "DXILRootSignature.h"
14
#include "DirectX.h"
15
#include "llvm/ADT/StringSwitch.h"
16
#include "llvm/ADT/Twine.h"
17
#include "llvm/Analysis/DXILMetadataAnalysis.h"
18
#include "llvm/BinaryFormat/DXContainer.h"
19
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
20
#include "llvm/IR/Constants.h"
21
#include "llvm/IR/DiagnosticInfo.h"
22
#include "llvm/IR/Function.h"
23
#include "llvm/IR/LLVMContext.h"
24
#include "llvm/IR/Metadata.h"
25
#include "llvm/IR/Module.h"
26
#include "llvm/InitializePasses.h"
27
#include "llvm/Pass.h"
28
#include "llvm/Support/Error.h"
29
#include "llvm/Support/ErrorHandling.h"
30
#include "llvm/Support/raw_ostream.h"
31
#include <cstdint>
32
#include <optional>
33
#include <utility>
34
35
using namespace llvm;
36
using namespace llvm::dxil;
37
38
static bool reportError(LLVMContext *Ctx, Twine Message,
39
DiagnosticSeverity Severity = DS_Error) {
40
Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
41
return true;
42
}
43
44
static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
45
uint32_t Value) {
46
Ctx->diagnose(DiagnosticInfoGeneric(
47
"Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
48
return true;
49
}
50
51
static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
52
unsigned int OpId) {
53
if (auto *CI =
54
mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
55
return CI->getZExtValue();
56
return std::nullopt;
57
}
58
59
static std::optional<float> extractMdFloatValue(MDNode *Node,
60
unsigned int OpId) {
61
if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
62
return CI->getValueAPF().convertToFloat();
63
return std::nullopt;
64
}
65
66
static std::optional<StringRef> extractMdStringValue(MDNode *Node,
67
unsigned int OpId) {
68
MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
69
if (NodeText == nullptr)
70
return std::nullopt;
71
return NodeText->getString();
72
}
73
74
static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
75
MDNode *RootFlagNode) {
76
77
if (RootFlagNode->getNumOperands() != 2)
78
return reportError(Ctx, "Invalid format for RootFlag Element");
79
80
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
81
RSD.Flags = *Val;
82
else
83
return reportError(Ctx, "Invalid value for RootFlag");
84
85
return false;
86
}
87
88
static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
89
MDNode *RootConstantNode) {
90
91
if (RootConstantNode->getNumOperands() != 5)
92
return reportError(Ctx, "Invalid format for RootConstants Element");
93
94
dxbc::RTS0::v1::RootParameterHeader Header;
95
// The parameter offset doesn't matter here - we recalculate it during
96
// serialization Header.ParameterOffset = 0;
97
Header.ParameterType =
98
llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
99
100
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
101
Header.ShaderVisibility = *Val;
102
else
103
return reportError(Ctx, "Invalid value for ShaderVisibility");
104
105
dxbc::RTS0::v1::RootConstants Constants;
106
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
107
Constants.ShaderRegister = *Val;
108
else
109
return reportError(Ctx, "Invalid value for ShaderRegister");
110
111
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
112
Constants.RegisterSpace = *Val;
113
else
114
return reportError(Ctx, "Invalid value for RegisterSpace");
115
116
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
117
Constants.Num32BitValues = *Val;
118
else
119
return reportError(Ctx, "Invalid value for Num32BitValues");
120
121
RSD.ParametersContainer.addParameter(Header, Constants);
122
123
return false;
124
}
125
126
static bool parseRootDescriptors(LLVMContext *Ctx,
127
mcdxbc::RootSignatureDesc &RSD,
128
MDNode *RootDescriptorNode,
129
RootSignatureElementKind ElementKind) {
130
assert(ElementKind == RootSignatureElementKind::SRV ||
131
ElementKind == RootSignatureElementKind::UAV ||
132
ElementKind == RootSignatureElementKind::CBV &&
133
"parseRootDescriptors should only be called with RootDescriptor "
134
"element kind.");
135
if (RootDescriptorNode->getNumOperands() != 5)
136
return reportError(Ctx, "Invalid format for Root Descriptor Element");
137
138
dxbc::RTS0::v1::RootParameterHeader Header;
139
switch (ElementKind) {
140
case RootSignatureElementKind::SRV:
141
Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
142
break;
143
case RootSignatureElementKind::UAV:
144
Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
145
break;
146
case RootSignatureElementKind::CBV:
147
Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
148
break;
149
default:
150
llvm_unreachable("invalid Root Descriptor kind");
151
break;
152
}
153
154
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
155
Header.ShaderVisibility = *Val;
156
else
157
return reportError(Ctx, "Invalid value for ShaderVisibility");
158
159
dxbc::RTS0::v2::RootDescriptor Descriptor;
160
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
161
Descriptor.ShaderRegister = *Val;
162
else
163
return reportError(Ctx, "Invalid value for ShaderRegister");
164
165
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
166
Descriptor.RegisterSpace = *Val;
167
else
168
return reportError(Ctx, "Invalid value for RegisterSpace");
169
170
if (RSD.Version == 1) {
171
RSD.ParametersContainer.addParameter(Header, Descriptor);
172
return false;
173
}
174
assert(RSD.Version > 1);
175
176
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
177
Descriptor.Flags = *Val;
178
else
179
return reportError(Ctx, "Invalid value for Root Descriptor Flags");
180
181
RSD.ParametersContainer.addParameter(Header, Descriptor);
182
return false;
183
}
184
185
static bool parseDescriptorRange(LLVMContext *Ctx,
186
mcdxbc::DescriptorTable &Table,
187
MDNode *RangeDescriptorNode) {
188
189
if (RangeDescriptorNode->getNumOperands() != 6)
190
return reportError(Ctx, "Invalid format for Descriptor Range");
191
192
dxbc::RTS0::v2::DescriptorRange Range;
193
194
std::optional<StringRef> ElementText =
195
extractMdStringValue(RangeDescriptorNode, 0);
196
197
if (!ElementText.has_value())
198
return reportError(Ctx, "Descriptor Range, first element is not a string.");
199
200
Range.RangeType =
201
StringSwitch<uint32_t>(*ElementText)
202
.Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV))
203
.Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV))
204
.Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV))
205
.Case("Sampler",
206
llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
207
.Default(~0U);
208
209
if (Range.RangeType == ~0U)
210
return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
211
212
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
213
Range.NumDescriptors = *Val;
214
else
215
return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
216
217
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
218
Range.BaseShaderRegister = *Val;
219
else
220
return reportError(Ctx, "Invalid value for BaseShaderRegister");
221
222
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
223
Range.RegisterSpace = *Val;
224
else
225
return reportError(Ctx, "Invalid value for RegisterSpace");
226
227
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
228
Range.OffsetInDescriptorsFromTableStart = *Val;
229
else
230
return reportError(Ctx,
231
"Invalid value for OffsetInDescriptorsFromTableStart");
232
233
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
234
Range.Flags = *Val;
235
else
236
return reportError(Ctx, "Invalid value for Descriptor Range Flags");
237
238
Table.Ranges.push_back(Range);
239
return false;
240
}
241
242
static bool parseDescriptorTable(LLVMContext *Ctx,
243
mcdxbc::RootSignatureDesc &RSD,
244
MDNode *DescriptorTableNode) {
245
const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
246
if (NumOperands < 2)
247
return reportError(Ctx, "Invalid format for Descriptor Table");
248
249
dxbc::RTS0::v1::RootParameterHeader Header;
250
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
251
Header.ShaderVisibility = *Val;
252
else
253
return reportError(Ctx, "Invalid value for ShaderVisibility");
254
255
mcdxbc::DescriptorTable Table;
256
Header.ParameterType =
257
llvm::to_underlying(dxbc::RootParameterType::DescriptorTable);
258
259
for (unsigned int I = 2; I < NumOperands; I++) {
260
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
261
if (Element == nullptr)
262
return reportError(Ctx, "Missing Root Element Metadata Node.");
263
264
if (parseDescriptorRange(Ctx, Table, Element))
265
return true;
266
}
267
268
RSD.ParametersContainer.addParameter(Header, Table);
269
return false;
270
}
271
272
static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
273
MDNode *StaticSamplerNode) {
274
if (StaticSamplerNode->getNumOperands() != 14)
275
return reportError(Ctx, "Invalid format for Static Sampler");
276
277
dxbc::RTS0::v1::StaticSampler Sampler;
278
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
279
Sampler.Filter = *Val;
280
else
281
return reportError(Ctx, "Invalid value for Filter");
282
283
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
284
Sampler.AddressU = *Val;
285
else
286
return reportError(Ctx, "Invalid value for AddressU");
287
288
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
289
Sampler.AddressV = *Val;
290
else
291
return reportError(Ctx, "Invalid value for AddressV");
292
293
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
294
Sampler.AddressW = *Val;
295
else
296
return reportError(Ctx, "Invalid value for AddressW");
297
298
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
299
Sampler.MipLODBias = *Val;
300
else
301
return reportError(Ctx, "Invalid value for MipLODBias");
302
303
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
304
Sampler.MaxAnisotropy = *Val;
305
else
306
return reportError(Ctx, "Invalid value for MaxAnisotropy");
307
308
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
309
Sampler.ComparisonFunc = *Val;
310
else
311
return reportError(Ctx, "Invalid value for ComparisonFunc ");
312
313
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
314
Sampler.BorderColor = *Val;
315
else
316
return reportError(Ctx, "Invalid value for ComparisonFunc ");
317
318
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
319
Sampler.MinLOD = *Val;
320
else
321
return reportError(Ctx, "Invalid value for MinLOD");
322
323
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
324
Sampler.MaxLOD = *Val;
325
else
326
return reportError(Ctx, "Invalid value for MaxLOD");
327
328
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
329
Sampler.ShaderRegister = *Val;
330
else
331
return reportError(Ctx, "Invalid value for ShaderRegister");
332
333
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
334
Sampler.RegisterSpace = *Val;
335
else
336
return reportError(Ctx, "Invalid value for RegisterSpace");
337
338
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
339
Sampler.ShaderVisibility = *Val;
340
else
341
return reportError(Ctx, "Invalid value for ShaderVisibility");
342
343
RSD.StaticSamplers.push_back(Sampler);
344
return false;
345
}
346
347
static bool parseRootSignatureElement(LLVMContext *Ctx,
348
mcdxbc::RootSignatureDesc &RSD,
349
MDNode *Element) {
350
std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
351
if (!ElementText.has_value())
352
return reportError(Ctx, "Invalid format for Root Element");
353
354
RootSignatureElementKind ElementKind =
355
StringSwitch<RootSignatureElementKind>(*ElementText)
356
.Case("RootFlags", RootSignatureElementKind::RootFlags)
357
.Case("RootConstants", RootSignatureElementKind::RootConstants)
358
.Case("RootCBV", RootSignatureElementKind::CBV)
359
.Case("RootSRV", RootSignatureElementKind::SRV)
360
.Case("RootUAV", RootSignatureElementKind::UAV)
361
.Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
362
.Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
363
.Default(RootSignatureElementKind::Error);
364
365
switch (ElementKind) {
366
367
case RootSignatureElementKind::RootFlags:
368
return parseRootFlags(Ctx, RSD, Element);
369
case RootSignatureElementKind::RootConstants:
370
return parseRootConstants(Ctx, RSD, Element);
371
case RootSignatureElementKind::CBV:
372
case RootSignatureElementKind::SRV:
373
case RootSignatureElementKind::UAV:
374
return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
375
case RootSignatureElementKind::DescriptorTable:
376
return parseDescriptorTable(Ctx, RSD, Element);
377
case RootSignatureElementKind::StaticSamplers:
378
return parseStaticSampler(Ctx, RSD, Element);
379
case RootSignatureElementKind::Error:
380
return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
381
}
382
383
llvm_unreachable("Unhandled RootSignatureElementKind enum.");
384
}
385
386
static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
387
MDNode *Node) {
388
bool HasError = false;
389
390
// Loop through the Root Elements of the root signature.
391
for (const auto &Operand : Node->operands()) {
392
MDNode *Element = dyn_cast<MDNode>(Operand);
393
if (Element == nullptr)
394
return reportError(Ctx, "Missing Root Element Metadata Node.");
395
396
HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element);
397
}
398
399
return HasError;
400
}
401
402
/// Check the values stored in \p RSD and diagnose the first invalid one.
///
/// Validation covers, in order: the version, the root flags, every root
/// parameter (shader visibility plus type-specific fields for root
/// descriptors and descriptor tables) and every static sampler.
///
/// \returns true if a diagnostic was emitted, false if everything is valid.
static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
  namespace rootsig = llvm::hlsl::rootsig;

  if (!rootsig::verifyVersion(RSD.Version))
    return reportValueError(Ctx, "Version", RSD.Version);

  if (!rootsig::verifyRootFlag(RSD.Flags))
    return reportValueError(Ctx, "RootFlags", RSD.Flags);

  for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
    const auto &Header = Info.Header;
    if (!dxbc::isValidShaderVisibility(Header.ShaderVisibility))
      return reportValueError(Ctx, "ShaderVisibility",
                              Header.ShaderVisibility);

    assert(dxbc::isValidParameterType(Header.ParameterType) &&
           "Invalid value for ParameterType");

    switch (Header.ParameterType) {
    case llvm::to_underlying(dxbc::RootParameterType::CBV):
    case llvm::to_underlying(dxbc::RootParameterType::UAV):
    case llvm::to_underlying(dxbc::RootParameterType::SRV): {
      const dxbc::RTS0::v2::RootDescriptor &Desc =
          RSD.ParametersContainer.getRootDescriptor(Info.Location);

      if (!rootsig::verifyRegisterValue(Desc.ShaderRegister))
        return reportValueError(Ctx, "ShaderRegister", Desc.ShaderRegister);

      if (!rootsig::verifyRegisterSpace(Desc.RegisterSpace))
        return reportValueError(Ctx, "RegisterSpace", Desc.RegisterSpace);

      // Descriptor flags are only checked from version 2 onwards.
      if (RSD.Version > 1 &&
          !rootsig::verifyRootDescriptorFlag(RSD.Version, Desc.Flags))
        return reportValueError(Ctx, "RootDescriptorFlag", Desc.Flags);
      break;
    }
    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
      const mcdxbc::DescriptorTable &Table =
          RSD.ParametersContainer.getDescriptorTable(Info.Location);

      for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
        if (!rootsig::verifyRangeType(Range.RangeType))
          return reportValueError(Ctx, "RangeType", Range.RangeType);

        if (!rootsig::verifyRegisterSpace(Range.RegisterSpace))
          return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);

        if (!rootsig::verifyNumDescriptors(Range.NumDescriptors))
          return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors);

        if (!rootsig::verifyDescriptorRangeFlag(RSD.Version, Range.RangeType,
                                                Range.Flags))
          return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
      }
      break;
    }
    default:
      // Any other parameter type has no type-specific checks here.
      break;
    }
  }

  // NOTE(review): MipLODBias, MinLOD and MaxLOD are filled from float
  // metadata, but reportValueError takes a uint32_t, so the value shown in
  // those diagnostics is an integer conversion of the float -- confirm this
  // is intended.
  for (const dxbc::RTS0::v1::StaticSampler &S : RSD.StaticSamplers) {
    if (!rootsig::verifySamplerFilter(S.Filter))
      return reportValueError(Ctx, "Filter", S.Filter);

    if (!rootsig::verifyAddress(S.AddressU))
      return reportValueError(Ctx, "AddressU", S.AddressU);

    if (!rootsig::verifyAddress(S.AddressV))
      return reportValueError(Ctx, "AddressV", S.AddressV);

    if (!rootsig::verifyAddress(S.AddressW))
      return reportValueError(Ctx, "AddressW", S.AddressW);

    if (!rootsig::verifyMipLODBias(S.MipLODBias))
      return reportValueError(Ctx, "MipLODBias", S.MipLODBias);

    if (!rootsig::verifyMaxAnisotropy(S.MaxAnisotropy))
      return reportValueError(Ctx, "MaxAnisotropy", S.MaxAnisotropy);

    if (!rootsig::verifyComparisonFunc(S.ComparisonFunc))
      return reportValueError(Ctx, "ComparisonFunc", S.ComparisonFunc);

    if (!rootsig::verifyBorderColor(S.BorderColor))
      return reportValueError(Ctx, "BorderColor", S.BorderColor);

    if (!rootsig::verifyLOD(S.MinLOD))
      return reportValueError(Ctx, "MinLOD", S.MinLOD);

    if (!rootsig::verifyLOD(S.MaxLOD))
      return reportValueError(Ctx, "MaxLOD", S.MaxLOD);

    if (!rootsig::verifyRegisterValue(S.ShaderRegister))
      return reportValueError(Ctx, "ShaderRegister", S.ShaderRegister);

    if (!rootsig::verifyRegisterSpace(S.RegisterSpace))
      return reportValueError(Ctx, "RegisterSpace", S.RegisterSpace);

    if (!dxbc::isValidShaderVisibility(S.ShaderVisibility))
      return reportValueError(Ctx, "ShaderVisibility", S.ShaderVisibility);
  }

  return false;
}
507
508
/// Build the function -> root signature map for module \p M from the
/// "dx.rootsignatures" named metadata.
///
/// Expected metadata layout (each definition has three operands):
///
///   !dx.rootsignatures = !{!2}   ; list of root signature definitions
///   !2 = !{ ptr @main, !3, i32 2 } ; function, root elements, version
///   !3 = !{ !4, !5, !6, !7 }     ; list of root signature elements
///
/// A malformed definition is diagnosed and skipped. A definition whose
/// elements fail to parse or validate stops the analysis and returns the
/// map built so far.
static SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>
analyzeModule(Module &M) {
  LLVMContext *Ctx = &M.getContext();
  SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Result;

  NamedMDNode *Definitions = M.getNamedMetadata("dx.rootsignatures");
  if (!Definitions)
    return Result;

  for (MDNode *RSDefNode : Definitions->operands()) {
    if (RSDefNode->getNumOperands() != 3) {
      reportError(Ctx, "Invalid Root Signature metadata - expected function, "
                       "signature, and version.");
      continue;
    }

    // Operand 0: the entry function. It is null when the function was pruned
    // during compilation.
    Metadata *FunctionMD = RSDefNode->getOperand(0).get();
    if (!FunctionMD) {
      reportError(
          Ctx, "Function associated with Root Signature definition is null.");
      continue;
    }

    auto *FunctionValue = llvm::dyn_cast<ValueAsMetadata>(FunctionMD);
    if (!FunctionValue) {
      reportError(Ctx, "First element of root signature is not a Value");
      continue;
    }

    auto *F = dyn_cast<Function>(FunctionValue->getValue());
    if (!F) {
      reportError(Ctx, "First element of root signature is not a Function");
      continue;
    }

    // Operand 1: the list of root elements.
    Metadata *ElementListMD = RSDefNode->getOperand(1).get();
    if (!ElementListMD) {
      reportError(Ctx, "Root Element mdnode is null.");
      continue;
    }

    auto *ElementList = dyn_cast<MDNode>(ElementListMD);
    if (!ElementList) {
      reportError(Ctx, "Root Element is not a metadata node.");
      continue;
    }

    // Operand 2: the root signature version.
    mcdxbc::RootSignatureDesc RSD;
    const std::optional<uint32_t> Version = extractMdIntValue(RSDefNode, 2);
    if (!Version) {
      reportError(Ctx, "Invalid RSDefNode value, expected constant int");
      continue;
    }
    RSD.Version = *Version;

    // Clang emits the root signature data in dxcontainer following a specific
    // sequence: first the header, then the root parameters. The root
    // parameter offset is therefore always equal to the header size.
    RSD.RootParameterOffset = sizeof(dxbc::RTS0::v1::RootSignatureHeader);

    // The static sampler offset is calculated when writing the dxcontainer.
    RSD.StaticSamplersOffset = 0u;

    // Diagnostics were already emitted by parse()/validate(); give up on the
    // rest of the module.
    if (parse(Ctx, RSD, ElementList) || validate(Ctx, RSD))
      return Result;

    Result.insert(std::make_pair(F, RSD));
  }

  return Result;
}
596
597
AnalysisKey RootSignatureAnalysis::Key;
598
599
RootSignatureAnalysis::Result
600
RootSignatureAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
601
return RootSignatureBindingInfo(analyzeModule(M));
602
}
603
604
//===----------------------------------------------------------------------===//
605
606
/// Print every root signature definition found in \p M to OS.
///
/// For each function with a root signature this writes the header fields
/// (flags, version, parameter offset, parameter count), each root parameter
/// with its type-specific fields, and finally the static sampler count and
/// offset.
///
/// \returns PreservedAnalyses::all(), since printing does not modify the IR.
PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
                                                    ModuleAnalysisManager &AM) {

  RootSignatureBindingInfo &RSDMap = AM.getResult<RootSignatureAnalysis>(M);

  OS << "Root Signature Definitions"
     << "\n";
  for (const Function &F : M) {
    auto It = RSDMap.find(&F);
    if (It == RSDMap.end())
      continue;
    const auto &RS = It->second;
    OS << "Definition for '" << F.getName() << "':\n";
    // start root signature header
    OS << "Flags: " << format_hex(RS.Flags, 8) << "\n"
       << "Version: " << RS.Version << "\n"
       << "RootParametersOffset: " << RS.RootParameterOffset << "\n"
       << "NumParameters: " << RS.ParametersContainer.size() << "\n";
    for (size_t I = 0; I < RS.ParametersContainer.size(); I++) {
      const auto &[Type, Loc] =
          RS.ParametersContainer.getTypeAndLocForParameter(I);
      const dxbc::RTS0::v1::RootParameterHeader Header =
          RS.ParametersContainer.getHeader(I);

      OS << "- Parameter Type: " << Type << "\n"
         << " Shader Visibility: " << Header.ShaderVisibility << "\n";

      switch (Type) {
      case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
        const dxbc::RTS0::v1::RootConstants &Constants =
            RS.ParametersContainer.getConstant(Loc);
        OS << " Register Space: " << Constants.RegisterSpace << "\n"
           << " Shader Register: " << Constants.ShaderRegister << "\n"
           << " Num 32 Bit Values: " << Constants.Num32BitValues << "\n";
        break;
      }
      case llvm::to_underlying(dxbc::RootParameterType::CBV):
      case llvm::to_underlying(dxbc::RootParameterType::UAV):
      case llvm::to_underlying(dxbc::RootParameterType::SRV): {
        const dxbc::RTS0::v2::RootDescriptor &Descriptor =
            RS.ParametersContainer.getRootDescriptor(Loc);
        OS << " Register Space: " << Descriptor.RegisterSpace << "\n"
           << " Shader Register: " << Descriptor.ShaderRegister << "\n";
        // Descriptor flags are only printed from version 2 onwards.
        if (RS.Version > 1)
          OS << " Flags: " << Descriptor.Flags << "\n";
        break;
      }
      case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
        const mcdxbc::DescriptorTable &Table =
            RS.ParametersContainer.getDescriptorTable(Loc);
        OS << " NumRanges: " << Table.Ranges.size() << "\n";

        // Iterate by reference; the ranges are only read.
        for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
          OS << " - Range Type: " << Range.RangeType << "\n"
             << " Register Space: " << Range.RegisterSpace << "\n"
             << " Base Shader Register: " << Range.BaseShaderRegister << "\n"
             << " Num Descriptors: " << Range.NumDescriptors << "\n"
             << " Offset In Descriptors From Table Start: "
             << Range.OffsetInDescriptorsFromTableStart << "\n";
          if (RS.Version > 1)
            OS << " Flags: " << Range.Flags << "\n";
        }
        break;
      }
      }
    }
    // Report the number of samplers actually parsed instead of a constant 0.
    OS << "NumStaticSamplers: " << RS.StaticSamplers.size() << "\n";
    OS << "StaticSamplersOffset: " << RS.StaticSamplersOffset << "\n";
  }
  return PreservedAnalyses::all();
}
677
678
//===----------------------------------------------------------------------===//
679
/// Legacy-pass-manager entry point: analyse \p M and store the result in
/// FuncToRsMap.
///
/// \returns false, because the module is not modified.
bool RootSignatureAnalysisWrapper::runOnModule(Module &M) {
  RootSignatureBindingInfo Info(analyzeModule(M));
  FuncToRsMap = std::make_unique<RootSignatureBindingInfo>(std::move(Info));
  return false;
}
684
685
/// Declare this pass's effect on other analyses: it preserves everything,
/// and DXILMetadataAnalysisWrapperPass is additionally listed explicitly.
void RootSignatureAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
  // This is a pure analysis; nothing in the module changes.
  AU.setPreservesAll();
  AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
}
689
690
char RootSignatureAnalysisWrapper::ID = 0;
691
692
INITIALIZE_PASS_BEGIN(RootSignatureAnalysisWrapper,
693
"dxil-root-signature-analysis",
694
"DXIL Root Signature Analysis", true, true)
695
INITIALIZE_PASS_END(RootSignatureAnalysisWrapper,
696
"dxil-root-signature-analysis",
697
"DXIL Root Signature Analysis", true, true)
698
699