GitHub Repository: stenzek/duckstation
Path: blob/master/dep/reshadefx/src/effect_codegen_spirv.cpp
1
/*
2
* Copyright (C) 2014 Patrick Mours
3
* SPDX-License-Identifier: BSD-3-Clause
4
*/
5
6
#include "effect_parser.hpp"
7
#include "effect_codegen.hpp"
8
#include <cassert>
9
#include <cstring> // std::memcmp
10
#include <charconv> // std::from_chars
11
#include <algorithm> // std::find_if, std::max, std::sort
12
#include <unordered_set>
13
14
// Use the C++ variant of the SPIR-V headers
15
#include <spirv.hpp>
16
namespace spv {
17
#include <GLSL.std.450.h>
18
}
19
20
using namespace reshadefx;
21
22
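// The bit trick below assumes 'alignment' is a power of two; e.g. align_up(20, 16) == 32 and align_up(12, 16) == 16.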
inline uint32_t align_up(uint32_t size, uint32_t alignment)
23
{
24
alignment -= 1;
25
return ((size + alignment) & ~alignment);
26
}
27
28
/// <summary>
29
/// A single instruction in a SPIR-V module
30
/// </summary>
31
struct spirv_instruction
32
{
33
spv::Op op;
34
spv::Id type;
35
spv::Id result;
36
std::vector<spv::Id> operands;
37
38
explicit spirv_instruction(spv::Op op = spv::OpNop) : op(op), type(0), result(0) {}
39
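// This overload stores the given id in the 'type' slot so that write() emits it directly after
// the opcode word; it is used for instructions whose first word after the opcode is a result id
// (e.g. OpExtInstImport, OpTypeStruct, OpTypePointer).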
spirv_instruction(spv::Op op, spv::Id result) : op(op), type(result), result(0) {}
40
spirv_instruction(spv::Op op, spv::Id type, spv::Id result) : op(op), type(type), result(result) {}
41
42
/// <summary>
43
/// Add a single operand to the instruction.
44
/// </summary>
45
spirv_instruction &add(spv::Id operand)
46
{
47
operands.push_back(operand);
48
return *this;
49
}
50
51
/// <summary>
52
/// Add a range of operands to the instruction.
53
/// </summary>
54
template <typename It>
55
spirv_instruction &add(It begin, It end)
56
{
57
operands.insert(operands.end(), begin, end);
58
return *this;
59
}
60
61
/// <summary>
62
/// Add a null-terminated literal UTF-8 string to the instruction.
63
/// </summary>
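/// Example (illustrative): on a little-endian host, add_string("main") appends two words,
/// 0x6E69616D ("main") followed by 0x00000000 containing the null terminator.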
64
spirv_instruction &add_string(const char *string)
65
{
66
uint32_t word;
67
do {
68
word = 0;
69
for (uint32_t i = 0; i < 4 && *string; ++i)
70
reinterpret_cast<uint8_t *>(&word)[i] = *string++;
71
add(word);
72
} while (*string || (word & 0xFF000000));
73
return *this;
74
}
75
76
/// <summary>
77
/// Write this instruction to a SPIR-V module.
78
/// </summary>
79
/// <param name="output">The output stream to append this instruction to.</param>
80
void write(std::basic_string<char> &output) const
81
{
82
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html
83
// 0 | Opcode: The 16 high-order bits are the WordCount of the instruction. The 16 low-order bits are the opcode enumerant.
84
// 1 | Optional instruction type <id>
85
// . | Optional instruction Result <id>
86
// . | Operand 1 (if needed)
87
// . | Operand 2 (if needed)
88
// ... | ...
89
// WordCount - 1 | Operand N (N is determined by WordCount minus the 1 to 3 words used for the opcode, instruction type <id>, and instruction Result <id>).
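// Worked example: an OpTypeFloat instruction declaring a 32-bit float occupies three words and is
// written as 0x00030016 (word count 3 in the high 16 bits, opcode 22 in the low 16 bits), followed
// by the result <id> and the literal width 32.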
90
91
const uint32_t word_count = 1 + (type != 0) + (result != 0) + static_cast<uint32_t>(operands.size());
92
write_word(output, (word_count << spv::WordCountShift) | op);
93
94
// Optional instruction type ID
95
if (type != 0)
96
write_word(output, type);
97
98
// Optional instruction result ID
99
if (result != 0)
100
write_word(output, result);
101
102
// Write out the operands
103
for (const uint32_t operand : operands)
104
write_word(output, operand);
105
}
106
107
static void write_word(std::basic_string<char> &output, uint32_t word)
108
{
109
output.insert(output.end(), reinterpret_cast<const char *>(&word), reinterpret_cast<const char *>(&word + 1));
110
}
111
112
operator uint32_t() const
113
{
114
assert(result != 0);
115
116
return result;
117
}
118
};
119
120
/// <summary>
121
/// A list of instructions forming a basic block in the SPIR-V module
122
/// </summary>
123
struct spirv_basic_block
124
{
125
std::vector<spirv_instruction> instructions;
126
127
/// <summary>
128
/// Append another basic block to the end of this one.
129
/// </summary>
130
void append(const spirv_basic_block &block)
131
{
132
instructions.insert(instructions.end(), block.instructions.begin(), block.instructions.end());
133
}
134
};
135
136
class codegen_spirv final : public codegen
137
{
138
static_assert(sizeof(id) == sizeof(spv::Id), "unexpected SPIR-V id type size");
139
140
public:
141
codegen_spirv(bool vulkan_semantics, bool debug_info, bool uniforms_to_spec_constants, bool enable_16bit_types, bool flip_vert_y, bool discard_is_demote) :
142
_debug_info(debug_info),
143
_vulkan_semantics(vulkan_semantics),
144
_uniforms_to_spec_constants(uniforms_to_spec_constants),
145
_enable_16bit_types(enable_16bit_types),
146
_flip_vert_y(flip_vert_y),
147
_discard_is_demote(discard_is_demote)
148
{
149
_glsl_ext = make_id();
150
}
151
152
private:
153
struct type_lookup
154
{
155
reshadefx::type type;
156
bool is_ptr;
157
uint32_t array_stride;
158
std::pair<spv::StorageClass, spv::ImageFormat> storage;
159
160
friend bool operator==(const type_lookup &lhs, const type_lookup &rhs)
161
{
162
return lhs.type == rhs.type && lhs.is_ptr == rhs.is_ptr && lhs.array_stride == rhs.array_stride && lhs.storage == rhs.storage;
163
}
164
};
165
struct function_blocks
166
{
167
spirv_basic_block declaration;
168
spirv_basic_block variables;
169
spirv_basic_block definition;
170
reshadefx::type return_type;
171
std::vector<reshadefx::type> param_types;
172
173
friend bool operator==(const function_blocks &lhs, const function_blocks &rhs)
174
{
175
if (lhs.param_types.size() != rhs.param_types.size())
176
return false;
177
for (size_t i = 0; i < lhs.param_types.size(); ++i)
178
if (!(lhs.param_types[i] == rhs.param_types[i]))
179
return false;
180
return lhs.return_type == rhs.return_type;
181
}
182
};
183
184
bool _debug_info = false;
185
bool _vulkan_semantics = false;
186
bool _uniforms_to_spec_constants = false;
187
bool _enable_16bit_types = false;
188
bool _flip_vert_y = false;
189
bool _discard_is_demote = false;
190
191
spirv_basic_block _entries;
192
spirv_basic_block _execution_modes;
193
spirv_basic_block _debug_a;
194
spirv_basic_block _debug_b;
195
spirv_basic_block _annotations;
196
spirv_basic_block _types_and_constants;
197
spirv_basic_block _variables;
198
199
std::vector<function_blocks> _functions_blocks;
200
std::unordered_map<id, spirv_basic_block> _block_data;
201
spirv_basic_block *_current_block_data = nullptr;
202
203
spv::Id _glsl_ext = 0;
204
spv::Id _global_ubo_type = 0;
205
spv::Id _global_ubo_variable = 0;
206
std::vector<spv::Id> _global_ubo_types;
207
function_blocks *_current_function_blocks = nullptr;
208
209
std::vector<std::pair<type_lookup, spv::Id>> _type_lookup;
210
std::vector<std::tuple<type, constant, spv::Id>> _constant_lookup;
211
std::vector<std::pair<function_blocks, spv::Id>> _function_type_lookup;
212
std::unordered_map<std::string, spv::Id> _string_lookup;
213
std::unordered_map<spv::Id, std::pair<spv::StorageClass, spv::ImageFormat>> _storage_lookup;
214
std::unordered_map<std::string, uint32_t> _semantic_to_location;
215
216
std::unordered_set<spv::Id> _spec_constants;
217
std::unordered_set<spv::Capability> _capabilities;
218
219
void add_location(const location &loc, spirv_basic_block &block)
220
{
221
if (loc.source.empty() || !_debug_info)
222
return;
223
224
spv::Id file;
225
226
if (const auto it = _string_lookup.find(loc.source);
227
it != _string_lookup.end())
228
{
229
file = it->second;
230
}
231
else
232
{
233
file =
234
add_instruction(spv::OpString, 0, _debug_a)
235
.add_string(loc.source.c_str());
236
_string_lookup.emplace(loc.source, file);
237
}
238
239
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpLine
240
add_instruction_without_result(spv::OpLine, block)
241
.add(file)
242
.add(loc.line)
243
.add(loc.column);
244
}
245
spirv_instruction &add_instruction(spv::Op op, spv::Id type = 0)
246
{
247
assert(is_in_function() && is_in_block());
248
249
return add_instruction(op, type, *_current_block_data);
250
}
251
spirv_instruction &add_instruction(spv::Op op, spv::Id type, spirv_basic_block &block)
252
{
253
spirv_instruction &instruction = add_instruction_without_result(op, block);
254
instruction.type = type;
255
instruction.result = make_id();
256
return instruction;
257
}
258
spirv_instruction &add_instruction_without_result(spv::Op op)
259
{
260
assert(is_in_function() && is_in_block());
261
262
return add_instruction_without_result(op, *_current_block_data);
263
}
264
spirv_instruction &add_instruction_without_result(spv::Op op, spirv_basic_block &block)
265
{
266
return block.instructions.emplace_back(op);
267
}
268
269
void finalize_header_section(std::basic_string<char> &spirv) const
270
{
271
// Write SPIRV header info
272
spirv_instruction::write_word(spirv, spv::MagicNumber);
273
spirv_instruction::write_word(spirv, 0x10300); // Force SPIR-V 1.3
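// (0x00010300 encodes version 1.3: the major version is stored in byte 2 and the minor version in byte 1.)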
274
spirv_instruction::write_word(spirv, 0u); // Generator magic number, see https://www.khronos.org/registry/spir-v/api/spir-v.xml
275
spirv_instruction::write_word(spirv, _next_id); // Maximum ID
276
spirv_instruction::write_word(spirv, 0u); // Reserved for instruction schema
277
278
// All capabilities
279
spirv_instruction(spv::OpCapability)
280
.add(spv::CapabilityShader) // Implicitly declares the Matrix capability too
281
.write(spirv);
282
283
for (const spv::Capability capability : _capabilities)
284
spirv_instruction(spv::OpCapability)
285
.add(capability)
286
.write(spirv);
287
288
// Optional extension instructions
289
spirv_instruction(spv::OpExtInstImport, _glsl_ext)
290
.add_string("GLSL.std.450") // Import GLSL extension
291
.write(spirv);
292
293
// Single required memory model instruction
294
spirv_instruction(spv::OpMemoryModel)
295
.add(spv::AddressingModelLogical)
296
.add(spv::MemoryModelGLSL450)
297
.write(spirv);
298
}
299
void finalize_debug_info_section(std::basic_string<char> &spirv) const
300
{
301
spirv_instruction(spv::OpSource)
302
.add(spv::SourceLanguageUnknown) // ReShade FX is not a reserved token at the moment
303
.add(0) // Language version, TODO: Maybe fill in ReShade version here?
304
.write(spirv);
305
306
if (_debug_info)
307
{
308
// All debug instructions
309
for (const spirv_instruction &inst : _debug_a.instructions)
310
inst.write(spirv);
311
}
312
}
313
void finalize_type_and_constants_section(std::basic_string<char> &spirv) const
314
{
315
// All type declarations
316
for (const spirv_instruction &inst : _types_and_constants.instructions)
317
inst.write(spirv);
318
319
// Initialize the UBO type now that all member types are known
320
if (_global_ubo_type == 0 || _global_ubo_variable == 0)
321
return;
322
323
const id global_ubo_type_ptr = _global_ubo_type + 1;
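// (This relies on 'define_uniform' reserving the pointer type id with a 'make_id()' call immediately after it creates '_global_ubo_type'.)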
324
325
spirv_instruction(spv::OpTypeStruct, _global_ubo_type)
326
.add(_global_ubo_types.begin(), _global_ubo_types.end())
327
.write(spirv);
328
spirv_instruction(spv::OpTypePointer, global_ubo_type_ptr)
329
.add(spv::StorageClassUniform)
330
.add(_global_ubo_type)
331
.write(spirv);
332
333
spirv_instruction(spv::OpVariable, global_ubo_type_ptr, _global_ubo_variable)
334
.add(spv::StorageClassUniform)
335
.write(spirv);
336
}
337
338
std::basic_string<char> finalize_code() const override
339
{
340
std::basic_string<char> spirv;
341
finalize_header_section(spirv);
342
343
// All entry point declarations
344
for (const spirv_instruction &inst : _entries.instructions)
345
inst.write(spirv);
346
347
// All execution mode declarations
348
for (const spirv_instruction &inst : _execution_modes.instructions)
349
inst.write(spirv);
350
351
finalize_debug_info_section(spirv);
352
353
for (const spirv_instruction& inst : _debug_b.instructions)
354
inst.write(spirv);
355
356
// All annotation instructions
357
for (const spirv_instruction &inst : _annotations.instructions)
358
inst.write(spirv);
359
360
finalize_type_and_constants_section(spirv);
361
362
for (const spirv_instruction &inst : _variables.instructions)
363
inst.write(spirv);
364
365
// All function definitions
366
for (const function_blocks &func : _functions_blocks)
367
{
368
if (func.definition.instructions.empty())
369
continue;
370
371
for (const spirv_instruction &inst : func.declaration.instructions)
372
inst.write(spirv);
373
374
// Grab first label and move it in front of variable declarations
375
func.definition.instructions.front().write(spirv);
376
assert(func.definition.instructions.front().op == spv::OpLabel);
377
378
for (const spirv_instruction &inst : func.variables.instructions)
379
inst.write(spirv);
380
for (auto inst_it = func.definition.instructions.begin() + 1; inst_it != func.definition.instructions.end(); ++inst_it)
381
inst_it->write(spirv);
382
}
383
384
return spirv;
385
}
386
std::basic_string<char> finalize_code_for_entry_point(const std::string &entry_point_name) const override
387
{
388
const auto entry_point_it = std::find_if(_functions.begin(), _functions.end(),
389
[&entry_point_name](const std::unique_ptr<function> &func) {
390
return func->unique_name == entry_point_name;
391
});
392
if (entry_point_it == _functions.end())
393
return {};
394
const function &entry_point = *entry_point_it->get();
395
396
const auto write_entry_point = [this](const spirv_instruction& oins, std::basic_string<char>& spirv) {
397
assert(oins.operands.size() > 2);
398
spirv_instruction nins(oins.op, oins.type, oins.result);
399
nins.add(oins.operands[0]);
400
nins.add(oins.operands[1]);
401
nins.add_string("main");
402
403
size_t param_start_index = 2;
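// OpEntryPoint operands: [0] execution model, [1] entry point function id, then the packed UTF-8 name
// string, then the interface variable ids. Skip over whole name words first: a non-zero top byte means
// the word is filled with four characters and the terminator has not been reached yet (little-endian packing, see add_string).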
404
while (param_start_index < oins.operands.size() && (oins.operands[param_start_index] & 0xFF000000) != 0)
405
param_start_index++;
406
407
// Skip the word containing the name's null terminator
408
param_start_index++;
409
410
for (size_t i = param_start_index; i < oins.operands.size(); i++)
411
nins.add(oins.operands[i]);
412
nins.write(spirv);
413
};
414
415
// Build list of IDs to remove
416
std::vector<spv::Id> variables_to_remove;
417
#if 1
418
std::vector<spv::Id> functions_to_remove;
419
#else
420
for (const sampler &info : _module.samplers)
421
if (std::find(entry_point.referenced_samplers.begin(), entry_point.referenced_samplers.end(), info.id) == entry_point.referenced_samplers.end())
422
variables_to_remove.push_back(info.id);
423
for (const storage &info : _module.storages)
424
if (std::find(entry_point.referenced_storages.begin(), entry_point.referenced_storages.end(), info.id) == entry_point.referenced_storages.end())
425
variables_to_remove.push_back(info.id);
426
#endif
427
428
std::basic_string<char> spirv;
429
finalize_header_section(spirv);
430
431
// The entry point and execution mode declaration
432
for (const spirv_instruction &inst : _entries.instructions)
433
{
434
assert(inst.op == spv::OpEntryPoint);
435
436
// Only add the matching entry point
437
if (inst.operands[1] == entry_point.id)
438
{
439
write_entry_point(inst, spirv);
440
}
441
else
442
{
443
#if 1
444
functions_to_remove.push_back(inst.operands[1]);
445
#endif
446
// Add interface variables to list of variables to remove
447
for (uint32_t k = 2 + static_cast<uint32_t>((std::strlen(reinterpret_cast<const char *>(&inst.operands[2])) + 4) / 4); k < inst.operands.size(); ++k)
448
variables_to_remove.push_back(inst.operands[k]);
449
}
450
}
451
452
for (const spirv_instruction &inst : _execution_modes.instructions)
453
{
454
assert(inst.op == spv::OpExecutionMode);
455
456
// Only add execution mode for the matching entry point
457
if (inst.operands[0] == entry_point.id)
458
{
459
inst.write(spirv);
460
}
461
}
462
463
finalize_debug_info_section(spirv);
464
465
for (const spirv_instruction &inst : _debug_b.instructions)
466
{
467
// Remove all names of interface variables and functions for non-matching entry points
468
if (std::find(variables_to_remove.begin(), variables_to_remove.end(), inst.operands[0]) != variables_to_remove.end() ||
469
std::find(functions_to_remove.begin(), functions_to_remove.end(), inst.operands[0]) != functions_to_remove.end())
470
continue;
471
472
inst.write(spirv);
473
}
474
475
// All annotation instructions
476
for (spirv_instruction inst : _annotations.instructions)
477
{
478
if (inst.op == spv::OpDecorate)
479
{
480
// Remove all decorations targeting any of the interface variables for non-matching entry points
481
if (std::find(variables_to_remove.begin(), variables_to_remove.end(), inst.operands[0]) != variables_to_remove.end())
482
continue;
483
484
// Replace bindings
485
if (inst.operands[1] == spv::DecorationBinding)
486
{
487
if (const auto referenced_sampler_it = std::find(entry_point.referenced_samplers.begin(), entry_point.referenced_samplers.end(), inst.operands[0]);
488
referenced_sampler_it != entry_point.referenced_samplers.end())
489
inst.operands[2] = static_cast<uint32_t>(std::distance(entry_point.referenced_samplers.begin(), referenced_sampler_it));
490
else
491
if (const auto referenced_storage_it = std::find(entry_point.referenced_storages.begin(), entry_point.referenced_storages.end(), inst.operands[0]);
492
referenced_storage_it != entry_point.referenced_storages.end())
493
inst.operands[2] = static_cast<uint32_t>(std::distance(entry_point.referenced_storages.begin(), referenced_storage_it));
494
}
495
}
496
497
inst.write(spirv);
498
}
499
500
finalize_type_and_constants_section(spirv);
501
502
for (const spirv_instruction &inst : _variables.instructions)
503
{
504
// Remove all declarations of the interface variables for non-matching entry points
505
if (inst.op == spv::OpVariable && std::find(variables_to_remove.begin(), variables_to_remove.end(), inst.result) != variables_to_remove.end())
506
continue;
507
508
inst.write(spirv);
509
}
510
511
// All referenced function definitions
512
for (const function_blocks &func : _functions_blocks)
513
{
514
if (func.definition.instructions.empty())
515
continue;
516
517
assert(func.declaration.instructions[func.declaration.instructions[0].op != spv::OpFunction ? 1 : 0].op == spv::OpFunction);
518
const spv::Id definition = func.declaration.instructions[func.declaration.instructions[0].op != spv::OpFunction ? 1 : 0].result;
519
520
#if 1
521
if (std::find(functions_to_remove.begin(), functions_to_remove.end(), definition) != functions_to_remove.end())
522
#else
523
if (struct_definition != entry_point.struct_definition &&
524
entry_point.referenced_functions.find(struct_definition) == entry_point.referenced_functions.end())
525
#endif
526
continue;
527
528
for (const spirv_instruction &inst : func.declaration.instructions)
529
inst.write(spirv);
530
531
// Grab first label and move it in front of variable declarations
532
func.definition.instructions.front().write(spirv);
533
assert(func.definition.instructions.front().op == spv::OpLabel);
534
535
for (const spirv_instruction &inst : func.variables.instructions)
536
inst.write(spirv);
537
for (auto inst_it = func.definition.instructions.begin() + 1; inst_it != func.definition.instructions.end(); ++inst_it)
538
inst_it->write(spirv);
539
}
540
541
return spirv;
542
}
543
544
spv::Id convert_type(type info, bool is_ptr = false, spv::StorageClass storage = spv::StorageClassFunction, spv::ImageFormat format = spv::ImageFormatUnknown, uint32_t array_stride = 0)
545
{
546
assert(array_stride == 0 || info.is_array());
547
548
// The storage class is only relevant for pointers, so ignore it for other types during lookup
549
if (is_ptr == false)
550
storage = spv::StorageClassFunction;
551
// There cannot be sampler variables that are local to a function, so always assume uniform storage for them
552
if (info.is_object())
553
storage = spv::StorageClassUniformConstant;
554
else
555
assert(format == spv::ImageFormatUnknown);
556
557
if (info.is_sampler() || info.is_storage())
558
info.rows = info.cols = 1;
559
560
// Fall back to 32-bit types and use relaxed precision decoration instead if 16-bit types are not enabled
561
if (!_enable_16bit_types && info.is_numeric() && info.precision() < 32)
562
info.base = static_cast<type::datatype>(info.base + 1); // min16int -> int, min16uint -> uint, min16float -> float
563
564
const type_lookup lookup { info, is_ptr, array_stride, { storage, format } };
565
566
if (const auto lookup_it = std::find_if(_type_lookup.begin(), _type_lookup.end(),
567
[&lookup](const std::pair<type_lookup, spv::Id> &lookup_entry) { return lookup_entry.first == lookup; });
568
lookup_it != _type_lookup.end())
569
return lookup_it->second;
570
571
spv::Id type_id, elem_type_id;
572
if (is_ptr)
573
{
574
elem_type_id = convert_type(info, false, storage, format, array_stride);
575
type_id =
576
add_instruction(spv::OpTypePointer, 0, _types_and_constants)
577
.add(storage)
578
.add(elem_type_id);
579
}
580
else if (info.is_array())
581
{
582
type elem_info = info;
583
elem_info.array_length = 0;
584
585
elem_type_id = convert_type(elem_info, false, storage, format);
586
587
// Make sure we don't get any dynamic arrays here
588
assert(info.is_bounded_array());
589
590
const spv::Id array_length_id = emit_constant(info.array_length);
591
592
type_id =
593
add_instruction(spv::OpTypeArray, 0, _types_and_constants)
594
.add(elem_type_id)
595
.add(array_length_id);
596
597
if (array_stride != 0)
598
add_decoration(type_id, spv::DecorationArrayStride, { array_stride });
599
}
600
else if (info.is_matrix())
601
{
602
// Convert MxN matrix to a SPIR-V matrix with M vectors with N elements
603
type elem_info = info;
604
elem_info.rows = info.cols;
605
elem_info.cols = 1;
606
607
elem_type_id = convert_type(elem_info, false, storage, format);
608
609
// Matrix types with just one row are interpreted as if they were a vector type
610
if (info.rows == 1)
611
return elem_type_id;
612
613
type_id =
614
add_instruction(spv::OpTypeMatrix, 0, _types_and_constants)
615
.add(elem_type_id)
616
.add(info.rows);
617
}
618
else if (info.is_vector())
619
{
620
type elem_info = info;
621
elem_info.rows = 1;
622
elem_info.cols = 1;
623
624
elem_type_id = convert_type(elem_info, false, storage, format);
625
type_id =
626
add_instruction(spv::OpTypeVector, 0, _types_and_constants)
627
.add(elem_type_id)
628
.add(info.rows);
629
}
630
else
631
{
632
switch (info.base)
633
{
634
case type::t_void:
635
assert(info.rows == 0 && info.cols == 0);
636
type_id = add_instruction(spv::OpTypeVoid, 0, _types_and_constants);
637
break;
638
case type::t_bool:
639
assert(info.rows == 1 && info.cols == 1);
640
type_id = add_instruction(spv::OpTypeBool, 0, _types_and_constants);
641
break;
642
case type::t_min16int:
643
assert(_enable_16bit_types && info.rows == 1 && info.cols == 1);
644
add_capability(spv::CapabilityInt16);
645
if (storage == spv::StorageClassInput || storage == spv::StorageClassOutput)
646
add_capability(spv::CapabilityStorageInputOutput16);
647
type_id =
648
add_instruction(spv::OpTypeInt, 0, _types_and_constants)
649
.add(16) // Width
650
.add(1); // Signedness
651
break;
652
case type::t_int:
653
assert(info.rows == 1 && info.cols == 1);
654
type_id =
655
add_instruction(spv::OpTypeInt, 0, _types_and_constants)
656
.add(32) // Width
657
.add(1); // Signedness
658
break;
659
case type::t_min16uint:
660
assert(_enable_16bit_types && info.rows == 1 && info.cols == 1);
661
add_capability(spv::CapabilityInt16);
662
if (storage == spv::StorageClassInput || storage == spv::StorageClassOutput)
663
add_capability(spv::CapabilityStorageInputOutput16);
664
type_id =
665
add_instruction(spv::OpTypeInt, 0, _types_and_constants)
666
.add(16) // Width
667
.add(0); // Signedness
668
break;
669
case type::t_uint:
670
assert(info.rows == 1 && info.cols == 1);
671
type_id =
672
add_instruction(spv::OpTypeInt, 0, _types_and_constants)
673
.add(32) // Width
674
.add(0); // Signedness
675
break;
676
case type::t_min16float:
677
assert(_enable_16bit_types && info.rows == 1 && info.cols == 1);
678
add_capability(spv::CapabilityFloat16);
679
if (storage == spv::StorageClassInput || storage == spv::StorageClassOutput)
680
add_capability(spv::CapabilityStorageInputOutput16);
681
type_id =
682
add_instruction(spv::OpTypeFloat, 0, _types_and_constants)
683
.add(16); // Width
684
break;
685
case type::t_float:
686
assert(info.rows == 1 && info.cols == 1);
687
type_id =
688
add_instruction(spv::OpTypeFloat, 0, _types_and_constants)
689
.add(32); // Width
690
break;
691
case type::t_struct:
692
assert(info.rows == 0 && info.cols == 0 && info.struct_definition != 0);
693
type_id = info.struct_definition;
694
break;
695
case type::t_sampler1d_int:
696
case type::t_sampler1d_uint:
697
case type::t_sampler1d_float:
698
add_capability(spv::CapabilitySampled1D);
699
[[fallthrough]];
700
case type::t_sampler2d_int:
701
case type::t_sampler2d_uint:
702
case type::t_sampler2d_float:
703
case type::t_sampler3d_int:
704
case type::t_sampler3d_uint:
705
case type::t_sampler3d_float:
706
elem_type_id = convert_image_type(info, format);
707
type_id =
708
add_instruction(spv::OpTypeSampledImage, 0, _types_and_constants)
709
.add(elem_type_id);
710
break;
711
case type::t_storage1d_int:
712
case type::t_storage1d_uint:
713
case type::t_storage1d_float:
714
add_capability(spv::CapabilityImage1D);
715
[[fallthrough]];
716
case type::t_storage2d_int:
717
case type::t_storage2d_uint:
718
case type::t_storage2d_float:
719
case type::t_storage3d_int:
720
case type::t_storage3d_uint:
721
case type::t_storage3d_float:
722
// No format specified for the storage image
723
if (format == spv::ImageFormatUnknown)
724
add_capability(spv::CapabilityStorageImageWriteWithoutFormat);
725
return convert_image_type(info, format);
726
default:
727
assert(false);
728
return 0;
729
}
730
}
731
732
_type_lookup.push_back({ lookup, type_id });
733
734
return type_id;
735
}
736
spv::Id convert_type(const function_blocks &info)
737
{
738
if (const auto lookup_it = std::find_if(_function_type_lookup.begin(), _function_type_lookup.end(),
739
[&lookup = info](const std::pair<function_blocks, spv::Id> &lookup_entry) { return lookup_entry.first == lookup; });
740
lookup_it != _function_type_lookup.end())
741
return lookup_it->second;
742
743
const spv::Id return_type_id = convert_type(info.return_type);
744
assert(return_type_id != 0);
745
746
std::vector<spv::Id> param_type_ids;
747
param_type_ids.reserve(info.param_types.size());
748
for (const type &param_type : info.param_types)
749
param_type_ids.push_back(convert_type(param_type, true));
750
751
spirv_instruction &inst = add_instruction(spv::OpTypeFunction, 0, _types_and_constants)
752
.add(return_type_id)
753
.add(param_type_ids.begin(), param_type_ids.end());
754
755
_function_type_lookup.push_back({ info, inst });
756
757
return inst;
758
}
759
spv::Id convert_image_type(type info, spv::ImageFormat format = spv::ImageFormatUnknown)
760
{
761
type elem_info = info;
762
elem_info.rows = 1;
763
elem_info.cols = 1;
764
765
if (!info.is_numeric())
766
{
767
if ((info.is_integral() && info.is_signed()) || (format >= spv::ImageFormatRgba32i && format <= spv::ImageFormatR8i))
768
elem_info.base = type::t_int;
769
else if ((info.is_integral() && info.is_unsigned()) || (format >= spv::ImageFormatRgba32ui && format <= spv::ImageFormatR8ui))
770
elem_info.base = type::t_uint;
771
else
772
elem_info.base = type::t_float;
773
}
774
775
type_lookup lookup { info, false, 0u, { spv::StorageClassUniformConstant, format } };
776
if (!info.is_storage())
777
{
778
lookup.type = elem_info;
779
lookup.type.base = static_cast<type::datatype>(type::t_texture1d + info.texture_dimension() - 1);
780
lookup.type.struct_definition = static_cast<uint32_t>(elem_info.base);
781
}
782
783
if (const auto lookup_it = std::find_if(_type_lookup.begin(), _type_lookup.end(),
784
[&lookup](const std::pair<type_lookup, spv::Id> &lookup_entry) { return lookup_entry.first == lookup; });
785
lookup_it != _type_lookup.end())
786
return lookup_it->second;
787
788
spv::Id type_id, elem_type_id = convert_type(elem_info, false, spv::StorageClassUniformConstant);
789
type_id =
790
add_instruction(spv::OpTypeImage, 0, _types_and_constants)
791
.add(elem_type_id) // Sampled Type (always a scalar type)
792
.add(spv::Dim1D + info.texture_dimension() - 1)
793
.add(0) // Not a depth image
794
.add(0) // Not an array
795
.add(0) // Not multi-sampled
796
.add(info.is_storage() ? 2 : 1) // Used with a sampler or as storage
797
.add(format);
798
799
_type_lookup.push_back({ lookup, type_id });
800
801
return type_id;
802
}
803
804
uint32_t semantic_to_location(const std::string &semantic, uint32_t max_attributes = 1)
805
{
806
if (const auto it = _semantic_to_location.find(semantic);
807
it != _semantic_to_location.end())
808
return it->second;
809
810
// Extract the semantic index from the semantic name (e.g. 2 for "TEXCOORD2")
811
size_t digit_index = semantic.size() - 1;
812
while (digit_index != 0 && semantic[digit_index] >= '0' && semantic[digit_index] <= '9')
813
digit_index--;
814
digit_index++;
815
816
const std::string semantic_base = semantic.substr(0, digit_index);
817
818
uint32_t semantic_digit = 0;
819
std::from_chars(semantic.c_str() + digit_index, semantic.c_str() + semantic.size(), semantic_digit);
820
821
if (semantic_base == "COLOR" || semantic_base == "SV_TARGET")
822
return semantic_digit;
823
824
uint32_t location = static_cast<uint32_t>(_semantic_to_location.size());
825
826
// Now create adjoining location indices for all possible semantic indices belonging to this semantic name
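// For example: a first call with semantic "TEXCOORD3" and max_attributes == 2 reserves TEXCOORD0 through
// TEXCOORD4 at consecutive locations and returns the location assigned to TEXCOORD3.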
827
for (uint32_t a = 0; a < semantic_digit + max_attributes; ++a)
828
{
829
const auto insert = _semantic_to_location.emplace(semantic_base + std::to_string(a), location + a);
830
if (!insert.second)
831
{
832
assert(a == 0 || (insert.first->second - a) == location);
833
834
// The semantic was already registered with a different location index, so remap to that one
835
location = insert.first->second - a;
836
}
837
}
838
839
return location + semantic_digit;
840
}
841
842
spv::BuiltIn semantic_to_builtin(const std::string &semantic, shader_type stype) const
843
{
844
if (semantic == "SV_POSITION")
845
return stype == shader_type::pixel ? spv::BuiltInFragCoord : spv::BuiltInPosition;
846
if (semantic == "SV_POINTSIZE")
847
return spv::BuiltInPointSize;
848
if (semantic == "SV_DEPTH")
849
return spv::BuiltInFragDepth;
850
if (semantic == "SV_VERTEXID")
851
return _vulkan_semantics ? spv::BuiltInVertexIndex : spv::BuiltInVertexId;
852
if (semantic == "SV_ISFRONTFACE")
853
return spv::BuiltInFrontFacing;
854
if (semantic == "SV_GROUPID")
855
return spv::BuiltInWorkgroupId;
856
if (semantic == "SV_GROUPINDEX")
857
return spv::BuiltInLocalInvocationIndex;
858
if (semantic == "SV_GROUPTHREADID")
859
return spv::BuiltInLocalInvocationId;
860
if (semantic == "SV_DISPATCHTHREADID")
861
return spv::BuiltInGlobalInvocationId;
862
return spv::BuiltInMax;
863
}
864
spv::ImageFormat format_to_image_format(texture_format format)
865
{
866
switch (format)
867
{
868
default:
869
assert(false);
870
[[fallthrough]];
871
case texture_format::unknown:
872
return spv::ImageFormatUnknown;
873
case texture_format::r8:
874
add_capability(spv::CapabilityStorageImageExtendedFormats);
875
return spv::ImageFormatR8;
876
case texture_format::r16:
877
add_capability(spv::CapabilityStorageImageExtendedFormats);
878
return spv::ImageFormatR16;
879
case texture_format::r16f:
880
add_capability(spv::CapabilityStorageImageExtendedFormats);
881
return spv::ImageFormatR16f;
882
case texture_format::r32i:
883
return spv::ImageFormatR32i;
884
case texture_format::r32u:
885
return spv::ImageFormatR32ui;
886
case texture_format::r32f:
887
return spv::ImageFormatR32f;
888
case texture_format::rg8:
889
add_capability(spv::CapabilityStorageImageExtendedFormats);
890
return spv::ImageFormatRg8;
891
case texture_format::rg16:
892
add_capability(spv::CapabilityStorageImageExtendedFormats);
893
return spv::ImageFormatRg16;
894
case texture_format::rg16f:
895
add_capability(spv::CapabilityStorageImageExtendedFormats);
896
return spv::ImageFormatRg16f;
897
case texture_format::rg32f:
898
add_capability(spv::CapabilityStorageImageExtendedFormats);
899
return spv::ImageFormatRg32f;
900
case texture_format::rgba8:
901
return spv::ImageFormatRgba8;
902
case texture_format::rgba16:
903
add_capability(spv::CapabilityStorageImageExtendedFormats);
904
return spv::ImageFormatRgba16;
905
case texture_format::rgba16f:
906
return spv::ImageFormatRgba16f;
907
case texture_format::rgba32f:
908
return spv::ImageFormatRgba32f;
909
case texture_format::rgb10a2:
910
add_capability(spv::CapabilityStorageImageExtendedFormats);
911
return spv::ImageFormatRgb10A2;
912
}
913
}
914
915
void add_name(id id, const char *name)
916
{
917
if (!_debug_info)
918
return;
919
920
assert(name != nullptr);
921
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpName
922
add_instruction_without_result(spv::OpName, _debug_b)
923
.add(id)
924
.add_string(name);
925
}
926
void add_builtin(id id, spv::BuiltIn builtin)
927
{
928
add_instruction_without_result(spv::OpDecorate, _annotations)
929
.add(id)
930
.add(spv::DecorationBuiltIn)
931
.add(builtin);
932
}
933
void add_decoration(id id, spv::Decoration decoration, std::initializer_list<uint32_t> values = {})
934
{
935
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpDecorate
936
add_instruction_without_result(spv::OpDecorate, _annotations)
937
.add(id)
938
.add(decoration)
939
.add(values.begin(), values.end());
940
}
941
void add_member_name(id id, uint32_t member_index, const char *name)
942
{
943
if (!_debug_info)
944
return;
945
946
assert(name != nullptr);
947
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpMemberName
948
add_instruction_without_result(spv::OpMemberName, _debug_b)
949
.add(id)
950
.add(member_index)
951
.add_string(name);
952
}
953
void add_member_builtin(id id, uint32_t member_index, spv::BuiltIn builtin)
954
{
955
add_instruction_without_result(spv::OpMemberDecorate, _annotations)
956
.add(id)
957
.add(member_index)
958
.add(spv::DecorationBuiltIn)
959
.add(builtin);
960
}
961
void add_member_decoration(id id, uint32_t member_index, spv::Decoration decoration, std::initializer_list<uint32_t> values = {})
962
{
963
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpMemberDecorate
964
add_instruction_without_result(spv::OpMemberDecorate, _annotations)
965
.add(id)
966
.add(member_index)
967
.add(decoration)
968
.add(values.begin(), values.end());
969
}
970
void add_capability(spv::Capability capability)
971
{
972
_capabilities.insert(capability);
973
}
974
975
id define_struct(const location &loc, struct_type &info) override
976
{
977
// First define all member types to make sure they are declared before the struct type references them
978
std::vector<spv::Id> member_types;
979
member_types.reserve(info.member_list.size());
980
for (const member_type &member : info.member_list)
981
member_types.push_back(convert_type(member.type));
982
983
// Afterwards define the actual struct type
984
add_location(loc, _types_and_constants);
985
986
const id res = info.id =
987
add_instruction(spv::OpTypeStruct, 0, _types_and_constants)
988
.add(member_types.begin(), member_types.end());
989
990
if (!info.unique_name.empty())
991
add_name(res, info.unique_name.c_str());
992
993
for (uint32_t index = 0; index < info.member_list.size(); ++index)
994
{
995
const member_type &member = info.member_list[index];
996
997
add_member_name(res, index, member.name.c_str());
998
999
if (!_enable_16bit_types && member.type.is_numeric() && member.type.precision() < 32)
1000
add_member_decoration(res, index, spv::DecorationRelaxedPrecision);
1001
}
1002
1003
_structs.push_back(info);
1004
1005
return res;
1006
}
1007
id define_texture(const location &, texture &info) override
1008
{
1009
const id res = info.id = make_id(); // Need to create a unique ID here too, so that the symbol lookup for textures works
1010
1011
_module.textures.push_back(info);
1012
1013
return res;
1014
}
1015
id define_sampler(const location &loc, const texture &, sampler &info) override
1016
{
1017
const id res = info.id = define_variable(loc, info.type, info.unique_name.c_str(), spv::StorageClassUniformConstant);
1018
1019
// Default to a binding index equivalent to the entry in the sampler list (this is later overwritten in 'finalize_code_for_entry_point' to a more optimal placement)
1020
const uint32_t default_binding = static_cast<uint32_t>(_module.samplers.size());
1021
add_decoration(res, spv::DecorationBinding, { default_binding });
1022
add_decoration(res, spv::DecorationDescriptorSet, { 1 });
1023
1024
_module.samplers.push_back(info);
1025
1026
return res;
1027
}
1028
id define_storage(const location &loc, const texture &tex_info, storage &info) override
1029
{
1030
const id res = info.id = define_variable(loc, info.type, info.unique_name.c_str(), spv::StorageClassUniformConstant, format_to_image_format(tex_info.format));
1031
1032
// Default to a binding index equivalent to the entry in the storage list (this is later overwritten in 'finalize_code_for_entry_point' to a more optimal placement)
1033
const uint32_t default_binding = static_cast<uint32_t>(_module.storages.size());
1034
add_decoration(res, spv::DecorationBinding, { default_binding });
1035
add_decoration(res, spv::DecorationDescriptorSet, { 2 });
1036
1037
_module.storages.push_back(info);
1038
1039
return res;
1040
}
1041
id define_uniform(const location &, uniform &info) override
1042
{
1043
if (_uniforms_to_spec_constants && info.has_initializer_value)
1044
{
1045
const id res = emit_constant(info.type, info.initializer_value, true);
1046
1047
add_name(res, info.name.c_str());
1048
1049
const auto add_spec_constant = [this](const spirv_instruction &inst, const uniform &info, const constant &initializer_value, size_t initializer_offset) {
1050
assert(inst.op == spv::OpSpecConstant || inst.op == spv::OpSpecConstantTrue || inst.op == spv::OpSpecConstantFalse);
1051
1052
const uint32_t spec_id = static_cast<uint32_t>(_module.spec_constants.size());
1053
add_decoration(inst, spv::DecorationSpecId, { spec_id });
1054
1055
uniform scalar_info = info;
1056
scalar_info.type.rows = 1;
1057
scalar_info.type.cols = 1;
1058
scalar_info.size = 4;
1059
scalar_info.offset = static_cast<uint32_t>(initializer_offset);
1060
scalar_info.initializer_value = {};
1061
scalar_info.initializer_value.as_uint[0] = initializer_value.as_uint[initializer_offset];
1062
1063
_module.spec_constants.push_back(std::move(scalar_info));
1064
};
1065
1066
const spirv_instruction &base_inst = _types_and_constants.instructions.back();
1067
assert(base_inst == res);
1068
1069
// External specialization constants need to be scalars
1070
if (info.type.is_scalar())
1071
{
1072
add_spec_constant(base_inst, info, info.initializer_value, 0);
1073
}
1074
else
1075
{
1076
assert(base_inst.op == spv::OpSpecConstantComposite);
1077
1078
// Add each individual scalar component of the constant as a separate external specialization constant
1079
for (size_t i = 0; i < (info.type.is_array() ? base_inst.operands.size() : 1); ++i)
1080
{
1081
constant initializer_value = info.initializer_value;
1082
spirv_instruction elem_inst = base_inst;
1083
1084
if (info.type.is_array())
1085
{
1086
elem_inst = *std::find_if(_types_and_constants.instructions.rbegin(), _types_and_constants.instructions.rend(),
1087
[operand_id = base_inst.operands[i]](const spirv_instruction &inst) { return inst == operand_id; });
1088
1089
assert(initializer_value.array_data.size() == base_inst.operands.size());
1090
initializer_value = initializer_value.array_data[i];
1091
}
1092
1093
for (size_t row = 0; row < elem_inst.operands.size(); ++row)
1094
{
1095
const spirv_instruction &row_inst = *std::find_if(_types_and_constants.instructions.rbegin(), _types_and_constants.instructions.rend(),
1096
[operand_id = elem_inst.operands[row]](const spirv_instruction &inst) { return inst == operand_id; });
1097
1098
if (row_inst.op != spv::OpSpecConstantComposite)
1099
{
1100
add_spec_constant(row_inst, info, initializer_value, row);
1101
continue;
1102
}
1103
1104
for (size_t col = 0; col < row_inst.operands.size(); ++col)
1105
{
1106
const spirv_instruction &col_inst = *std::find_if(_types_and_constants.instructions.rbegin(), _types_and_constants.instructions.rend(),
1107
[operand_id = row_inst.operands[col]](const spirv_instruction &inst) { return inst == operand_id; });
1108
1109
add_spec_constant(col_inst, info, initializer_value, row * info.type.cols + col);
1110
}
1111
}
1112
}
1113
}
1114
1115
return res;
1116
}
1117
else
1118
{
1119
// Create global uniform buffer variable on demand
1120
if (_global_ubo_type == 0)
1121
{
1122
_global_ubo_type = make_id();
1123
make_id(); // Pointer type for '_global_ubo_type'
1124
1125
add_decoration(_global_ubo_type, spv::DecorationBlock);
1126
}
1127
if (_global_ubo_variable == 0)
1128
{
1129
_global_ubo_variable = make_id();
1130
1131
add_decoration(_global_ubo_variable, spv::DecorationDescriptorSet, { 0 });
1132
add_decoration(_global_ubo_variable, spv::DecorationBinding, { 0 });
1133
}
1134
1135
uint32_t alignment = (info.type.rows == 3 ? 4 : info.type.rows) * 4;
1136
info.size = info.type.rows * 4;
1137
1138
uint32_t array_stride = 16;
1139
const uint32_t matrix_stride = 16;
1140
1141
if (info.type.is_matrix())
1142
{
1143
alignment = matrix_stride;
1144
info.size = info.type.rows * matrix_stride;
1145
}
1146
if (info.type.is_array())
1147
{
1148
alignment = array_stride;
1149
array_stride = align_up(info.size, array_stride);
1150
// Uniform block rules do not permit anything in the padding of an array
1151
info.size = array_stride * info.type.array_length;
1152
}
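// Illustrative examples of these layout rules: a 'float3' member gets 16-byte alignment and size 12,
// a 'float4x4' member gets size 4 * matrix_stride = 64, and a 'float2 x[3]' array is padded to a
// 16-byte element stride for a total size of 48.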
1153
1154
info.offset = _module.total_uniform_size;
1155
info.offset = align_up(info.offset, alignment);
1156
_module.total_uniform_size = info.offset + info.size;
1157
1158
type ubo_type = info.type;
1159
// Convert boolean uniform variables to integer type so that they have a defined size
1160
if (info.type.is_boolean())
1161
ubo_type.base = type::t_uint;
1162
1163
const uint32_t member_index = static_cast<uint32_t>(_global_ubo_types.size());
1164
1165
// Composite objects in the uniform storage class must be explicitly laid out, which includes array types requiring a stride decoration
1166
_global_ubo_types.push_back(
1167
convert_type(ubo_type, false, spv::StorageClassUniform, spv::ImageFormatUnknown, info.type.is_array() ? array_stride : 0u));
1168
1169
add_member_name(_global_ubo_type, member_index, info.name.c_str());
1170
1171
add_member_decoration(_global_ubo_type, member_index, spv::DecorationOffset, { info.offset });
1172
1173
if (info.type.is_matrix())
1174
{
1175
// Read matrices in column-major layout, even though they are actually row-major, to avoid transposing them on every access (since SPIR-V uses column-major matrices)
1176
// TODO: This technically only works with square matrices
1177
add_member_decoration(_global_ubo_type, member_index, spv::DecorationColMajor);
1178
add_member_decoration(_global_ubo_type, member_index, spv::DecorationMatrixStride, { matrix_stride });
1179
}
1180
1181
_module.uniforms.push_back(info);
1182
1183
return 0xF0000000 | member_index;
1184
}
1185
}
1186
id define_variable(const location &loc, const type &type, std::string name, bool global, id initializer_value) override
1187
{
1188
spv::StorageClass storage = spv::StorageClassFunction;
1189
if (type.has(type::q_groupshared))
1190
storage = spv::StorageClassWorkgroup;
1191
else if (global)
1192
storage = spv::StorageClassPrivate;
1193
1194
return define_variable(loc, type, name.c_str(), storage, spv::ImageFormatUnknown, initializer_value);
1195
}
1196
id define_variable(const location &loc, const type &type, const char *name, spv::StorageClass storage, spv::ImageFormat format = spv::ImageFormatUnknown, id initializer_value = 0)
1197
{
1198
assert(storage != spv::StorageClassFunction || (_current_function_blocks != nullptr && _current_function != nullptr && !_current_function->unique_name.empty() && (_current_function->unique_name[0] == 'F' || _current_function->unique_name[0] == 'E')));
1199
1200
spirv_basic_block &block = (storage != spv::StorageClassFunction) ?
1201
_variables : _current_function_blocks->variables;
1202
1203
add_location(loc, block);
1204
1205
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpVariable
1206
spirv_instruction &inst = add_instruction(spv::OpVariable, convert_type(type, true, storage, format), block);
1207
inst.add(storage);
1208
1209
const id res = inst.result;
1210
1211
if (initializer_value != 0)
1212
{
1213
if (storage != spv::StorageClassFunction || /* is_entry_point = */ _current_function->unique_name[0] == 'E')
1214
{
1215
// The initializer for variables must be a constant
1216
inst.add(initializer_value);
1217
}
1218
else
1219
{
1220
// Only use the variable initializer on global variables, since local variables (e.g. those declared in "for" statements) need to be assigned in their respective scope and not at their declaration
1221
expression variable;
1222
variable.reset_to_lvalue(loc, res, type);
1223
emit_store(variable, initializer_value);
1224
}
1225
}
1226
1227
if (name != nullptr && *name != '\0')
1228
add_name(res, name);
1229
1230
if (!_enable_16bit_types && type.is_numeric() && type.precision() < 32)
1231
add_decoration(res, spv::DecorationRelaxedPrecision);
1232
1233
_storage_lookup[res] = { storage, format };
1234
1235
return res;
1236
}
1237
id define_function(const location &loc, function &info) override
1238
{
1239
assert(!is_in_function());
1240
1241
function_blocks &func = _functions_blocks.emplace_back();
1242
func.return_type = info.return_type;
1243
1244
for (const member_type &param : info.parameter_list)
1245
func.param_types.push_back(param.type);
1246
1247
add_location(loc, func.declaration);
1248
1249
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpFunction
1250
const id res = info.id =
1251
add_instruction(spv::OpFunction, convert_type(info.return_type), func.declaration)
1252
.add(spv::FunctionControlMaskNone)
1253
.add(convert_type(func));
1254
1255
if (!info.name.empty())
1256
add_name(res, info.name.c_str());
1257
1258
for (member_type &param : info.parameter_list)
1259
{
1260
add_location(param.location, func.declaration);
1261
1262
param.id = add_instruction(spv::OpFunctionParameter, convert_type(param.type, true), func.declaration);
1263
1264
add_name(param.id, param.name.c_str());
1265
}
1266
1267
_functions.push_back(std::make_unique<function>(info));
1268
_current_function = _functions.back().get();
1269
_current_function_blocks = &func;
1270
1271
return res;
1272
}
1273
1274
void define_entry_point(function &func) override
1275
{
1276
assert(!func.unique_name.empty() && func.unique_name[0] == 'F');
1277
func.unique_name[0] = 'E';
1278
1279
// Modify the entry point name so that each thread configuration gets its own entry point
1280
if (func.type == shader_type::compute)
1281
func.unique_name +=
1282
'_' + std::to_string(func.num_threads[0]) +
1283
'_' + std::to_string(func.num_threads[1]) +
1284
'_' + std::to_string(func.num_threads[2]);
1285
1286
if (std::find_if(_module.entry_points.begin(), _module.entry_points.end(),
1287
[&func](const std::pair<std::string, shader_type> &entry_point) {
1288
return entry_point.first == func.unique_name;
1289
}) != _module.entry_points.end())
1290
return;
1291
1292
_module.entry_points.emplace_back(func.unique_name, func.type);
1293
1294
spv::Id position_variable = 0;
1295
spv::Id point_size_variable = 0;
1296
std::vector<spv::Id> inputs_and_outputs;
1297
std::vector<expression> call_params;
1298
1299
// Generate the glue entry point function
1300
function entry_point = func;
1301
entry_point.referenced_functions.push_back(func.id);
1302
1303
// Change function signature to 'void main()'
1304
entry_point.return_type = { type::t_void };
1305
entry_point.return_semantic.clear();
1306
entry_point.parameter_list.clear();
1307
1308
const id entry_point_definition = define_function({}, entry_point);
1309
enter_block(create_block());
1310
1311
const auto create_varying_param = [this, &call_params](const member_type &param) {
1312
// Initialize all output variables with zero
1313
const spv::Id variable = define_variable({}, param.type, nullptr, spv::StorageClassFunction, spv::ImageFormatUnknown, emit_constant(param.type, 0u));
1314
1315
expression &call_param = call_params.emplace_back();
1316
call_param.reset_to_lvalue({}, variable, param.type);
1317
1318
return variable;
1319
};
1320
1321
const auto create_varying_variable = [this, &inputs_and_outputs, &position_variable, &point_size_variable, stype = func.type](const type &param_type, const std::string &semantic, spv::StorageClass storage, int a = 0) {
1322
const spv::Id variable = define_variable({}, param_type, nullptr, storage);
1323
1324
if (const spv::BuiltIn builtin = semantic_to_builtin(semantic, stype);
1325
builtin != spv::BuiltInMax)
1326
{
1327
assert(a == 0); // Built-in variables cannot be arrays
1328
1329
add_builtin(variable, builtin);
1330
1331
if (builtin == spv::BuiltInPosition && storage == spv::StorageClassOutput)
1332
position_variable = variable;
1333
if (builtin == spv::BuiltInPointSize && storage == spv::StorageClassOutput)
1334
point_size_variable = variable;
1335
}
1336
else
1337
{
1338
assert(stype != shader_type::compute); // Compute shaders cannot have custom inputs or outputs
1339
1340
const uint32_t location = semantic_to_location(semantic, std::max(1u, param_type.array_length));
1341
add_decoration(variable, spv::DecorationLocation, { location + a });
1342
}
1343
1344
if (param_type.has(type::q_noperspective))
1345
add_decoration(variable, spv::DecorationNoPerspective);
1346
if (param_type.has(type::q_centroid))
1347
add_decoration(variable, spv::DecorationCentroid);
1348
if (param_type.has(type::q_nointerpolation))
1349
add_decoration(variable, spv::DecorationFlat);
1350
1351
inputs_and_outputs.push_back(variable);
1352
return variable;
1353
};
1354
1355
// Translate function parameters to input/output variables
1356
for (const member_type &param : func.parameter_list)
1357
{
1358
spv::Id param_var = create_varying_param(param);
1359
1360
// Create separate input/output variables for "inout" parameters
1361
if (param.type.has(type::q_in))
1362
{
1363
spv::Id param_value = 0;
1364
1365
// Flatten structure parameters
1366
if (param.type.is_struct())
1367
{
1368
const struct_type &struct_definition = get_struct(param.type.struct_definition);
1369
1370
type struct_type = param.type;
1371
const auto array_length = std::max(1u, param.type.array_length);
1372
struct_type.array_length = 0;
1373
1374
// Struct arrays need to be flattened into individual elements as well
1375
std::vector<spv::Id> array_element_ids;
1376
array_element_ids.reserve(array_length);
1377
for (unsigned int a = 0; a < array_length; a++)
1378
{
1379
std::vector<spv::Id> struct_element_ids;
1380
struct_element_ids.reserve(struct_definition.member_list.size());
1381
for (const member_type &member : struct_definition.member_list)
1382
{
1383
const spv::Id input_var = create_varying_variable(member.type, member.semantic, spv::StorageClassInput, a);
1384
1385
param_value =
1386
add_instruction(spv::OpLoad, convert_type(member.type))
1387
.add(input_var);
1388
struct_element_ids.push_back(param_value);
1389
}
1390
1391
param_value =
1392
add_instruction(spv::OpCompositeConstruct, convert_type(struct_type))
1393
.add(struct_element_ids.begin(), struct_element_ids.end());
1394
array_element_ids.push_back(param_value);
1395
}
1396
1397
if (param.type.is_array())
1398
{
1399
// Build the array from all constructed struct elements
1400
param_value =
1401
add_instruction(spv::OpCompositeConstruct, convert_type(param.type))
1402
.add(array_element_ids.begin(), array_element_ids.end());
1403
}
1404
}
1405
else
1406
{
1407
const spv::Id input_var = create_varying_variable(param.type, param.semantic, spv::StorageClassInput);
1408
1409
param_value =
1410
add_instruction(spv::OpLoad, convert_type(param.type))
1411
.add(input_var);
1412
}
1413
1414
add_instruction_without_result(spv::OpStore)
1415
.add(param_var)
1416
.add(param_value);
1417
}
1418
1419
if (param.type.has(type::q_out))
1420
{
1421
if (param.type.is_struct())
1422
{
1423
const struct_type &struct_definition = get_struct(param.type.struct_definition);
1424
1425
for (unsigned int a = 0, array_length = std::max(1u, param.type.array_length); a < array_length; a++)
1426
{
1427
for (const member_type &member : struct_definition.member_list)
1428
{
1429
create_varying_variable(member.type, member.semantic, spv::StorageClassOutput, a);
1430
}
1431
}
1432
}
1433
else
1434
{
1435
create_varying_variable(param.type, param.semantic, spv::StorageClassOutput);
1436
}
1437
}
1438
}
1439
1440
const id call_result = emit_call({}, func.id, func.return_type, call_params);
1441
1442
for (size_t i = 0, inputs_and_outputs_index = 0; i < func.parameter_list.size(); ++i)
1443
{
1444
const member_type &param = func.parameter_list[i];
1445
1446
if (param.type.has(type::q_out))
1447
{
1448
const spv::Id value =
1449
add_instruction(spv::OpLoad, convert_type(param.type))
1450
.add(call_params[i].base);
1451
1452
if (param.type.is_struct())
1453
{
1454
const struct_type &struct_definition = get_struct(param.type.struct_definition);
1455
1456
type struct_type = param.type;
1457
const auto array_length = std::max(1u, param.type.array_length);
1458
struct_type.array_length = 0;
1459
1460
// Skip input variables if this is an "inout" parameter
1461
if (param.type.has(type::q_in))
1462
inputs_and_outputs_index += struct_definition.member_list.size() * array_length;
1463
1464
// Split up struct array into individual struct elements again
1465
for (unsigned int a = 0; a < array_length; a++)
1466
{
1467
spv::Id element_value = value;
1468
if (param.type.is_array())
1469
{
1470
element_value =
1471
add_instruction(spv::OpCompositeExtract, convert_type(struct_type))
1472
.add(value)
1473
.add(a);
1474
}
1475
1476
// Split out struct fields into separate output variables again
1477
for (uint32_t member_index = 0; member_index < struct_definition.member_list.size(); ++member_index)
1478
{
1479
const spv::Id member_value =
1480
add_instruction(spv::OpCompositeExtract, convert_type(struct_definition.member_list[member_index].type))
1481
.add(element_value)
1482
.add(member_index);
1483
1484
add_instruction_without_result(spv::OpStore)
1485
.add(inputs_and_outputs[inputs_and_outputs_index++])
1486
.add(member_value);
1487
}
1488
}
1489
}
1490
else
1491
{
1492
// Skip input variable if this is an "inout" parameter (see loop above)
1493
if (param.type.has(type::q_in))
1494
inputs_and_outputs_index += 1;
1495
1496
add_instruction_without_result(spv::OpStore)
1497
.add(inputs_and_outputs[inputs_and_outputs_index++])
1498
.add(value);
1499
}
1500
}
1501
else
1502
{
1503
// Input parameters do not need to store anything, but increase the input/output variable index
1504
if (param.type.is_struct())
1505
{
1506
const struct_type &struct_definition = get_struct(param.type.struct_definition);
1507
inputs_and_outputs_index += struct_definition.member_list.size() * std::max(1u, param.type.array_length);
1508
}
1509
else
1510
{
1511
inputs_and_outputs_index += 1;
1512
}
1513
}
1514
}
1515
1516
if (func.return_type.is_struct())
1517
{
1518
const struct_type &struct_definition = get_struct(func.return_type.struct_definition);
1519
1520
for (uint32_t member_index = 0; member_index < struct_definition.member_list.size(); ++member_index)
1521
{
1522
const member_type &member = struct_definition.member_list[member_index];
1523
1524
const spv::Id result_var = create_varying_variable(member.type, member.semantic, spv::StorageClassOutput);
1525
1526
const spv::Id member_result =
1527
add_instruction(spv::OpCompositeExtract, convert_type(member.type))
1528
.add(call_result)
1529
.add(member_index);
1530
1531
add_instruction_without_result(spv::OpStore)
1532
.add(result_var)
1533
.add(member_result);
1534
}
1535
}
1536
else if (!func.return_type.is_void())
1537
{
1538
const spv::Id result_var = create_varying_variable(func.return_type, func.return_semantic, spv::StorageClassOutput);
1539
1540
add_instruction_without_result(spv::OpStore)
1541
.add(result_var)
1542
.add(call_result);
1543
}
1544
1545
// Add code to flip the output vertically
1546
if (_flip_vert_y && position_variable != 0 && func.type == shader_type::vertex)
1547
{
1548
expression position;
1549
position.reset_to_lvalue({}, position_variable, { type::t_float, 4, 1 });
1550
position.add_constant_index_access(1); // Y component
1551
1552
// gl_Position.y = -gl_Position.y
1553
emit_store(position,
1554
emit_unary_op({}, tokenid::minus, { type::t_float, 1, 1 },
1555
emit_load(position, false)));
1556
}
1557
1558
#if 0
1559
// Disabled because it breaks on macOS/Metal: point size should not be defined for a non-point primitive.
1560
// Add code that sets the point size to a default value (in case this vertex shader is used with point primitives)
1561
if (point_size_variable == 0 && func.type == shader_type::vertex)
1562
{
1563
create_varying_variable({ type::t_float, 1, 1 }, "SV_POINTSIZE", spv::StorageClassOutput);
1564
1565
expression point_size;
1566
point_size.reset_to_lvalue({}, point_size_variable, { type::t_float, 1, 1 });
1567
1568
// gl_PointSize = 1.0
1569
emit_store(point_size, emit_constant({ type::t_float, 1, 1 }, 1));
1570
}
1571
#endif
1572
1573
leave_block_and_return(0);
1574
leave_function();
1575
1576
spv::ExecutionModel model;
1577
switch (func.type)
1578
{
1579
case shader_type::vertex:
1580
model = spv::ExecutionModelVertex;
1581
break;
1582
case shader_type::pixel:
1583
model = spv::ExecutionModelFragment;
1584
add_instruction_without_result(spv::OpExecutionMode, _execution_modes)
1585
.add(entry_point_definition)
1586
.add(_vulkan_semantics ? spv::ExecutionModeOriginUpperLeft : spv::ExecutionModeOriginLowerLeft);
1587
break;
1588
case shader_type::compute:
1589
model = spv::ExecutionModelGLCompute;
1590
add_instruction_without_result(spv::OpExecutionMode, _execution_modes)
1591
.add(entry_point_definition)
1592
.add(spv::ExecutionModeLocalSize)
1593
.add(func.num_threads[0])
1594
.add(func.num_threads[1])
1595
.add(func.num_threads[2]);
1596
break;
1597
default:
1598
assert(false);
1599
return;
1600
}
1601
1602
add_instruction_without_result(spv::OpEntryPoint, _entries)
1603
.add(model)
1604
.add(entry_point_definition)
1605
.add_string(func.unique_name.c_str())
1606
.add(inputs_and_outputs.begin(), inputs_and_outputs.end());
1607
}
1608
1609
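/// <summary>
/// Load the r-value of the given expression, resolving any leading member/index operations with an
/// 'OpAccessChain' and applying the remaining casts, swizzles and extractions to the loaded value.
/// </summary>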
id emit_load(const expression &exp, bool) override
{
if (exp.is_constant) // Constant expressions do not have a complex access chain
return emit_constant(exp.type, exp.constant);

size_t i = 0;
spv::Id result = exp.base;
type base_type = exp.type;
bool is_uniform_bool = false;

if (exp.is_lvalue || !exp.chain.empty())
add_location(exp.location, *_current_block_data);

// If a variable is referenced, load the value first
if (exp.is_lvalue && _spec_constants.find(exp.base) == _spec_constants.end())
{
if (!exp.chain.empty())
base_type = exp.chain[0].from;

std::pair<spv::StorageClass, spv::ImageFormat> storage = { spv::StorageClassFunction, spv::ImageFormatUnknown };
if (const auto it = _storage_lookup.find(exp.base);
it != _storage_lookup.end())
storage = it->second;

spirv_instruction *access_chain = nullptr;

// Check if this is a uniform variable (see 'define_uniform' function above) and dereference it
if (result & 0xF0000000)
{
const uint32_t member_index = result ^ 0xF0000000;

storage.first = spv::StorageClassUniform;
is_uniform_bool = base_type.is_boolean();

if (is_uniform_bool)
base_type.base = type::t_uint;

access_chain = &add_instruction(spv::OpAccessChain)
.add(_global_ubo_variable)
.add(emit_constant(member_index));
}

// Any indexing expressions can be resolved during load with an 'OpAccessChain' already
if (!exp.chain.empty() && (
exp.chain[0].op == expression::operation::op_member ||
exp.chain[0].op == expression::operation::op_dynamic_index ||
exp.chain[0].op == expression::operation::op_constant_index))
{
// Ensure that 'access_chain' cannot get invalidated by calls to 'emit_constant' or 'convert_type'
assert(_current_block_data != &_types_and_constants);

// Use access chain from uniform if possible, otherwise create new one
if (access_chain == nullptr) access_chain =
&add_instruction(spv::OpAccessChain).add(result); // Base

// Ignore first index into 1xN matrices, since they were translated to a vector type in SPIR-V
if (exp.chain[0].from.rows == 1 && exp.chain[0].from.cols > 1)
i = 1;

for (; i < exp.chain.size() && (
exp.chain[i].op == expression::operation::op_member ||
exp.chain[i].op == expression::operation::op_dynamic_index ||
exp.chain[i].op == expression::operation::op_constant_index); ++i)
access_chain->add(exp.chain[i].op == expression::operation::op_dynamic_index ?
exp.chain[i].index :
emit_constant(exp.chain[i].index)); // Indexes

base_type = exp.chain[i - 1].to;
access_chain->type = convert_type(base_type, true, storage.first, storage.second); // Last type is the result
result = access_chain->result;
}
else if (access_chain != nullptr)
{
access_chain->type = convert_type(base_type, true, storage.first, storage.second, base_type.is_array() ? 16u : 0u);
result = access_chain->result;
}

result =
add_instruction(spv::OpLoad, convert_type(base_type, false, spv::StorageClassFunction, storage.second))
.add(result); // Pointer
}

// Need to convert boolean uniforms which are actually integers in SPIR-V
if (is_uniform_bool)
{
base_type.base = type::t_bool;

result =
add_instruction(spv::OpINotEqual, convert_type(base_type))
.add(result)
.add(emit_constant(0));
}

// Work through all remaining operations in the access chain and apply them to the value
for (; i < exp.chain.size(); ++i)
{
assert(result != 0);
const expression::operation &op = exp.chain[i];

switch (op.op)
{
case expression::operation::op_cast:
if (op.from.is_scalar() && !op.to.is_scalar())
{
type cast_type = op.to;
cast_type.base = op.from.base;

std::vector<expression> args;
args.reserve(op.to.components());
for (unsigned int c = 0; c < op.to.components(); ++c)
args.emplace_back().reset_to_rvalue(exp.location, result, op.from);

result = emit_construct(exp.location, cast_type, args);
}

if (op.from.is_boolean())
{
const spv::Id true_constant = emit_constant(op.to, 1);
const spv::Id false_constant = emit_constant(op.to, 0);

result =
add_instruction(spv::OpSelect, convert_type(op.to))
.add(result) // Condition
.add(true_constant)
.add(false_constant);
}
else
{
spv::Op spv_op = spv::OpNop;
switch (op.to.base)
{
case type::t_bool:
if (op.from.is_floating_point())
spv_op = spv::OpFOrdNotEqual;
else
spv_op = spv::OpINotEqual;
// Add instruction to compare value against zero instead of casting
result =
add_instruction(spv_op, convert_type(op.to))
.add(result)
.add(emit_constant(op.from, 0));
continue;
case type::t_min16int:
case type::t_int:
if (op.from.is_floating_point())
spv_op = spv::OpConvertFToS;
else if (op.from.precision() == op.to.precision())
spv_op = spv::OpBitcast;
else if (_enable_16bit_types)
spv_op = spv::OpSConvert;
else
continue; // Do not have to add conversion instruction between min16int/int if 16-bit types are not enabled
break;
case type::t_min16uint:
case type::t_uint:
if (op.from.is_floating_point())
spv_op = spv::OpConvertFToU;
else if (op.from.precision() == op.to.precision())
spv_op = spv::OpBitcast;
else if (_enable_16bit_types)
spv_op = spv::OpUConvert;
else
continue;
break;
case type::t_min16float:
case type::t_float:
if (op.from.is_floating_point() && !_enable_16bit_types)
continue; // Do not have to add conversion instruction between min16float/float if 16-bit types are not enabled
else if (op.from.is_floating_point())
spv_op = spv::OpFConvert;
else if (op.from.is_signed())
spv_op = spv::OpConvertSToF;
else
spv_op = spv::OpConvertUToF;
break;
default:
assert(false);
}

result =
add_instruction(spv_op, convert_type(op.to))
.add(result);
}
break;
case expression::operation::op_dynamic_index:
assert(op.from.is_vector() && op.to.is_scalar());
result =
add_instruction(spv::OpVectorExtractDynamic, convert_type(op.to))
.add(result) // Vector
.add(op.index); // Index
break;
case expression::operation::op_member: // In case of struct return values, which are r-values
case expression::operation::op_constant_index:
assert(op.from.is_vector() || op.from.is_matrix() || op.from.is_struct());
result =
add_instruction(spv::OpCompositeExtract, convert_type(op.to))
.add(result)
.add(op.index); // Literal Index
break;
case expression::operation::op_swizzle:
if (op.to.is_vector())
{
if (op.from.is_matrix())
{
spv::Id components[4];
for (int c = 0; c < 4 && op.swizzle[c] >= 0; ++c)
{
const unsigned int row = op.swizzle[c] / 4;
const unsigned int column = op.swizzle[c] - row * 4;

type scalar_type = op.to;
scalar_type.rows = 1;
scalar_type.cols = 1;

spirv_instruction &inst = add_instruction(spv::OpCompositeExtract, convert_type(scalar_type));
inst.add(result);
if (op.from.rows > 1) // Matrix types with a single row are actually vectors, so they don't need the extra index
inst.add(row);
inst.add(column);

components[c] = inst;
}

spirv_instruction &inst = add_instruction(spv::OpCompositeConstruct, convert_type(op.to));
for (int c = 0; c < 4 && op.swizzle[c] >= 0; ++c)
inst.add(components[c]);
result = inst;
}
else if (op.from.is_vector())
{
spirv_instruction &inst = add_instruction(spv::OpVectorShuffle, convert_type(op.to));
inst.add(result); // Vector 1
inst.add(result); // Vector 2
for (int c = 0; c < 4 && op.swizzle[c] >= 0; ++c)
inst.add(op.swizzle[c]);
result = inst;
}
else
{
spirv_instruction &inst = add_instruction(spv::OpCompositeConstruct, convert_type(op.to));
for (unsigned int c = 0; c < op.to.rows; ++c)
inst.add(result);
result = inst;
}
break;
}
else if (op.from.is_matrix() && op.to.is_scalar())
{
assert(op.swizzle[1] < 0);

spirv_instruction &inst = add_instruction(spv::OpCompositeExtract, convert_type(op.to));
inst.add(result); // Composite
if (op.from.rows > 1)
{
const unsigned int row = op.swizzle[0] / 4;
const unsigned int column = op.swizzle[0] - row * 4;
inst.add(row);
inst.add(column);
}
else
{
inst.add(op.swizzle[0]);
}
result = inst;
break;
}
else
{
assert(false);
break;
}
}
}

return result;
}
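/// <summary>
/// Store a value into the l-value described by the given expression. Member and index operations are
/// folded into an 'OpAccessChain', while swizzled targets are updated with 'OpVectorShuffle' or 'OpCompositeInsert'.
/// </summary>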
void emit_store(const expression &exp, id value) override
{
assert(value != 0 && exp.is_lvalue && !exp.is_constant && !exp.type.is_sampler());

add_location(exp.location, *_current_block_data);

size_t i = 0;
// Any indexing expressions can be resolved with an 'OpAccessChain' already
spv::Id target = emit_access_chain(exp, i);
type base_type = exp.chain.empty() ? exp.type : i == 0 ? exp.chain[0].from : exp.chain[i - 1].to;

// TODO: Complex access chains like float4x4[0].m00m10[0] = 0;
// Work through all remaining operations in the access chain and apply them to the value
for (; i < exp.chain.size(); ++i)
{
const expression::operation &op = exp.chain[i];
switch (op.op)
{
case expression::operation::op_cast:
case expression::operation::op_member:
// These should have been handled above already (and casting does not make sense for a store operation)
break;
case expression::operation::op_dynamic_index:
case expression::operation::op_constant_index:
assert(false);
break;
case expression::operation::op_swizzle:
{
spv::Id result =
add_instruction(spv::OpLoad, convert_type(base_type))
.add(target); // Pointer

if (base_type.is_vector())
{
spirv_instruction &inst = add_instruction(spv::OpVectorShuffle, convert_type(base_type));
inst.add(result); // Vector 1
inst.add(value); // Vector 2

unsigned int shuffle[4] = { 0, 1, 2, 3 };
for (unsigned int c = 0; c < base_type.rows; ++c)
if (op.swizzle[c] >= 0)
shuffle[op.swizzle[c]] = base_type.rows + c;
for (unsigned int c = 0; c < base_type.rows; ++c)
inst.add(shuffle[c]);

value = inst;
}
else if (op.to.is_scalar())
{
assert(op.swizzle[1] < 0);

spirv_instruction &inst = add_instruction(spv::OpCompositeInsert, convert_type(base_type));
inst.add(value); // Object
inst.add(result); // Composite

if (op.from.is_matrix() && op.from.rows > 1)
{
const unsigned int row = op.swizzle[0] / 4;
const unsigned int column = op.swizzle[0] - row * 4;
inst.add(row);
inst.add(column);
}
else
{
inst.add(op.swizzle[0]);
}

value = inst;
}
else
{
// TODO: Implement matrix to vector swizzles
assert(false);
}
break;
}
}
}

add_instruction_without_result(spv::OpStore)
.add(target)
.add(value);
}
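/// <summary>
/// Create an 'OpAccessChain' for the leading member/index operations of the expression chain and return
/// the resulting pointer ID. 'i' is set to the index of the first chain operation that was not consumed.
/// For example, a chain like 'array[i].member' is lowered to a single access chain with a dynamic index
/// followed by a constant member index.
/// </summary>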
id emit_access_chain(const expression &exp, size_t &i) override
{
// This function cannot create access chains for uniform variables
assert((exp.base & 0xF0000000) == 0);

i = 0;
if (exp.chain.empty() || (
exp.chain[0].op != expression::operation::op_member &&
exp.chain[0].op != expression::operation::op_dynamic_index &&
exp.chain[0].op != expression::operation::op_constant_index))
return exp.base;

std::pair<spv::StorageClass, spv::ImageFormat> storage = { spv::StorageClassFunction, spv::ImageFormatUnknown };
if (const auto it = _storage_lookup.find(exp.base);
it != _storage_lookup.end())
storage = it->second;

// Ensure that 'access_chain' cannot get invalidated by calls to 'emit_constant' or 'convert_type'
assert(_current_block_data != &_types_and_constants);

spirv_instruction *access_chain =
&add_instruction(spv::OpAccessChain).add(exp.base); // Base

// Ignore first index into 1xN matrices, since they were translated to a vector type in SPIR-V
if (exp.chain[0].from.rows == 1 && exp.chain[0].from.cols > 1)
i = 1;

for (; i < exp.chain.size() && (
exp.chain[i].op == expression::operation::op_member ||
exp.chain[i].op == expression::operation::op_dynamic_index ||
exp.chain[i].op == expression::operation::op_constant_index); ++i)
access_chain->add(exp.chain[i].op == expression::operation::op_dynamic_index ?
exp.chain[i].index :
emit_constant(exp.chain[i].index)); // Indexes

access_chain->type = convert_type(exp.chain[i - 1].to, true, storage.first, storage.second); // Last type is the result
return access_chain->result;
}

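/// <summary>
/// Helper overloads for emitting constants. Ordinary constants are de-duplicated via '_constant_lookup'
/// (e.g. requesting the scalar 0u twice yields the same result ID), whereas specialization constants always
/// get a fresh definition and are tracked in '_spec_constants'.
/// </summary>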
using codegen::emit_constant;
id emit_constant(uint32_t value)
{
return emit_constant({ type::t_uint, 1, 1 }, value);
}
id emit_constant(const type &data_type, const constant &data) override
{
return emit_constant(data_type, data, false);
}
id emit_constant(const type &data_type, const constant &data, bool spec_constant)
{
if (!spec_constant) // Specialization constants cannot reuse other constants
{
if (const auto it = std::find_if(_constant_lookup.begin(), _constant_lookup.end(),
[&data_type, &data](std::tuple<type, constant, spv::Id> &x) {
if (!(std::get<0>(x) == data_type && std::memcmp(&std::get<1>(x).as_uint[0], &data.as_uint[0], sizeof(uint32_t) * 16) == 0 && std::get<1>(x).array_data.size() == data.array_data.size()))
return false;
for (size_t i = 0; i < data.array_data.size(); ++i)
if (std::memcmp(&std::get<1>(x).array_data[i].as_uint[0], &data.array_data[i].as_uint[0], sizeof(uint32_t) * 16) != 0)
return false;
return true;
});
it != _constant_lookup.end())
return std::get<2>(*it); // Reuse existing constant instead of duplicating the definition
}

spv::Id result;
if (data_type.is_array())
{
assert(data_type.is_bounded_array()); // Unbounded arrays cannot be constants

type elem_type = data_type;
elem_type.array_length = 0;

std::vector<spv::Id> elements;
elements.reserve(data_type.array_length);

// Fill up elements with constant array data
for (const constant &elem : data.array_data)
elements.push_back(emit_constant(elem_type, elem, spec_constant));
// Fill up any remaining elements with a default value (when the array data did not specify them)
for (size_t i = elements.size(); i < static_cast<size_t>(data_type.array_length); ++i)
elements.push_back(emit_constant(elem_type, {}, spec_constant));

result =
add_instruction(spec_constant ? spv::OpSpecConstantComposite : spv::OpConstantComposite, convert_type(data_type), _types_and_constants)
.add(elements.begin(), elements.end());
}
else if (data_type.is_struct())
{
assert(!spec_constant); // Structures cannot be specialization constants

result = add_instruction(spv::OpConstantNull, convert_type(data_type), _types_and_constants);
}
else if (data_type.is_vector() || data_type.is_matrix())
{
type elem_type = data_type;
elem_type.rows = data_type.cols;
elem_type.cols = 1;

spv::Id rows[4] = {};

// Construct matrix constant out of row vector constants
// Construct vector constant out of scalar constants for each element
for (unsigned int i = 0; i < data_type.rows; ++i)
{
constant row_data = {};
for (unsigned int k = 0; k < data_type.cols; ++k)
row_data.as_uint[k] = data.as_uint[i * data_type.cols + k];

rows[i] = emit_constant(elem_type, row_data, spec_constant);
}

if (data_type.rows == 1)
{
result = rows[0];
}
else
{
spirv_instruction &inst = add_instruction(spec_constant ? spv::OpSpecConstantComposite : spv::OpConstantComposite, convert_type(data_type), _types_and_constants);
for (unsigned int i = 0; i < data_type.rows; ++i)
inst.add(rows[i]);
result = inst;
}
}
else if (data_type.is_boolean())
{
result = add_instruction(data.as_uint[0] ?
(spec_constant ? spv::OpSpecConstantTrue : spv::OpConstantTrue) :
(spec_constant ? spv::OpSpecConstantFalse : spv::OpConstantFalse), convert_type(data_type), _types_and_constants);
}
else
{
assert(data_type.is_scalar());

result =
add_instruction(spec_constant ? spv::OpSpecConstant : spv::OpConstant, convert_type(data_type), _types_and_constants)
.add(data.as_uint[0]);
}

if (spec_constant) // Keep track of all specialization constants
_spec_constants.insert(result);
else
_constant_lookup.push_back({ data_type, data, result });

return result;
}

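/// <summary>
/// Emit a unary operation, mapping the FX operator token to the corresponding SPIR-V opcode
/// (e.g. '-' becomes 'OpFNegate' or 'OpSNegate' depending on the result type).
/// </summary>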
id emit_unary_op(const location &loc, tokenid op, const type &res_type, id val) override
{
spv::Op spv_op = spv::OpNop;

switch (op)
{
case tokenid::minus:
spv_op = res_type.is_floating_point() ? spv::OpFNegate : spv::OpSNegate;
break;
case tokenid::tilde:
spv_op = spv::OpNot;
break;
case tokenid::exclaim:
spv_op = spv::OpLogicalNot;
break;
default:
return assert(false), 0;
}

add_location(loc, *_current_block_data);

spirv_instruction &inst = add_instruction(spv_op, convert_type(res_type));
inst.add(val); // Operand

return inst;
}
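/// <summary>
/// Emit a binary operation. Matrix operands are decomposed into their row vectors first, since most binary
/// opcodes in SPIR-V are only defined for scalars and vectors, and the per-row results are recombined with
/// 'OpCompositeConstruct'.
/// </summary>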
id emit_binary_op(const location &loc, tokenid op, const type &res_type, const type &exp_type, id lhs, id rhs) override
{
spv::Op spv_op = spv::OpNop;

switch (op)
{
case tokenid::plus:
case tokenid::plus_plus:
case tokenid::plus_equal:
spv_op = exp_type.is_floating_point() ? spv::OpFAdd : spv::OpIAdd;
break;
case tokenid::minus:
case tokenid::minus_minus:
case tokenid::minus_equal:
spv_op = exp_type.is_floating_point() ? spv::OpFSub : spv::OpISub;
break;
case tokenid::star:
case tokenid::star_equal:
spv_op = exp_type.is_floating_point() ? spv::OpFMul : spv::OpIMul;
break;
case tokenid::slash:
case tokenid::slash_equal:
spv_op = exp_type.is_floating_point() ? spv::OpFDiv : exp_type.is_signed() ? spv::OpSDiv : spv::OpUDiv;
break;
case tokenid::percent:
case tokenid::percent_equal:
spv_op = exp_type.is_floating_point() ? spv::OpFRem : exp_type.is_signed() ? spv::OpSRem : spv::OpUMod;
break;
case tokenid::caret:
case tokenid::caret_equal:
spv_op = spv::OpBitwiseXor;
break;
case tokenid::pipe:
case tokenid::pipe_equal:
spv_op = spv::OpBitwiseOr;
break;
case tokenid::ampersand:
case tokenid::ampersand_equal:
spv_op = spv::OpBitwiseAnd;
break;
case tokenid::less_less:
case tokenid::less_less_equal:
spv_op = spv::OpShiftLeftLogical;
break;
case tokenid::greater_greater:
case tokenid::greater_greater_equal:
spv_op = exp_type.is_signed() ? spv::OpShiftRightArithmetic : spv::OpShiftRightLogical;
break;
case tokenid::pipe_pipe:
spv_op = spv::OpLogicalOr;
break;
case tokenid::ampersand_ampersand:
spv_op = spv::OpLogicalAnd;
break;
case tokenid::less:
spv_op = exp_type.is_floating_point() ? spv::OpFOrdLessThan :
exp_type.is_signed() ? spv::OpSLessThan : spv::OpULessThan;
break;
case tokenid::less_equal:
spv_op = exp_type.is_floating_point() ? spv::OpFOrdLessThanEqual :
exp_type.is_signed() ? spv::OpSLessThanEqual : spv::OpULessThanEqual;
break;
case tokenid::greater:
spv_op = exp_type.is_floating_point() ? spv::OpFOrdGreaterThan :
exp_type.is_signed() ? spv::OpSGreaterThan : spv::OpUGreaterThan;
break;
case tokenid::greater_equal:
spv_op = exp_type.is_floating_point() ? spv::OpFOrdGreaterThanEqual :
exp_type.is_signed() ? spv::OpSGreaterThanEqual : spv::OpUGreaterThanEqual;
break;
case tokenid::equal_equal:
spv_op = exp_type.is_floating_point() ? spv::OpFOrdEqual :
exp_type.is_boolean() ? spv::OpLogicalEqual : spv::OpIEqual;
break;
case tokenid::exclaim_equal:
spv_op = exp_type.is_floating_point() ? spv::OpFOrdNotEqual :
exp_type.is_boolean() ? spv::OpLogicalNotEqual : spv::OpINotEqual;
break;
default:
return assert(false), 0;
}

add_location(loc, *_current_block_data);

// Binary operators generally only work on scalars and vectors in SPIR-V, so need to apply them to matrices component-wise
if (exp_type.is_matrix() && exp_type.rows != 1)
{
std::vector<spv::Id> ids;
ids.reserve(exp_type.cols);

type vector_type = exp_type;
vector_type.rows = exp_type.cols;
vector_type.cols = 1;

for (unsigned int row = 0; row < exp_type.rows; ++row)
{
const spv::Id lhs_elem = add_instruction(spv::OpCompositeExtract, convert_type(vector_type))
.add(lhs)
.add(row);
const spv::Id rhs_elem = add_instruction(spv::OpCompositeExtract, convert_type(vector_type))
.add(rhs)
.add(row);

spirv_instruction &inst = add_instruction(spv_op, convert_type(vector_type));
inst.add(lhs_elem); // Operand 1
inst.add(rhs_elem); // Operand 2

if (res_type.has(type::q_precise))
add_decoration(inst, spv::DecorationNoContraction);
if (!_enable_16bit_types && res_type.precision() < 32)
add_decoration(inst, spv::DecorationRelaxedPrecision);

ids.push_back(inst);
}

spirv_instruction &inst = add_instruction(spv::OpCompositeConstruct, convert_type(res_type));
inst.add(ids.begin(), ids.end());

return inst;
}
else
{
spirv_instruction &inst = add_instruction(spv_op, convert_type(res_type));
inst.add(lhs); // Operand 1
inst.add(rhs); // Operand 2

if (res_type.has(type::q_precise))
add_decoration(inst, spv::DecorationNoContraction);
if (!_enable_16bit_types && res_type.precision() < 32)
add_decoration(inst, spv::DecorationRelaxedPrecision);

return inst;
}
}
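/// <summary>
/// Emit a ternary '?:' operation using 'OpSelect'.
/// </summary>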
id emit_ternary_op(const location &loc, tokenid op, const type &res_type, id condition, id true_value, id false_value) override
{
if (op != tokenid::question)
return assert(false), 0;

add_location(loc, *_current_block_data);

spirv_instruction &inst = add_instruction(spv::OpSelect, convert_type(res_type));
inst.add(condition); // Condition
inst.add(true_value); // Object 1
inst.add(false_value); // Object 2

return inst;
}
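/// <summary>
/// Emit a call to a user-defined function via 'OpFunctionCall'. All arguments are expected to be simple
/// r-values without pending access chain operations.
/// </summary>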
id emit_call(const location &loc, id function, const type &res_type, const std::vector<expression> &args) override
{
#ifndef NDEBUG
for (const expression &arg : args)
assert(arg.chain.empty() && arg.base != 0);
#endif
add_location(loc, *_current_block_data);

// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpFunctionCall
spirv_instruction &inst = add_instruction(spv::OpFunctionCall, convert_type(res_type));
inst.add(function); // Function
for (const expression &arg : args)
inst.add(arg.base); // Arguments

return inst;
}
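/// <summary>
/// Emit a call to a built-in intrinsic. The per-intrinsic code generation snippets are pulled in from
/// 'effect_symbol_table_intrinsics.inl' via the 'IMPLEMENT_INTRINSIC_SPIRV' macro.
/// </summary>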
id emit_call_intrinsic(const location &loc, id intrinsic, const type &res_type, const std::vector<expression> &args) override
{
#ifndef NDEBUG
for (const expression &arg : args)
assert(arg.chain.empty() && arg.base != 0);
#endif
add_location(loc, *_current_block_data);

enum
{
#define IMPLEMENT_INTRINSIC_SPIRV(name, i, code) name##i,
#include "effect_symbol_table_intrinsics.inl"
};

switch (intrinsic)
{
#define IMPLEMENT_INTRINSIC_SPIRV(name, i, code) case name##i: code
#include "effect_symbol_table_intrinsics.inl"
default:
return assert(false), 0;
}
}
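/// <summary>
/// Construct a vector, matrix or array value from the given arguments with 'OpCompositeConstruct'.
/// For matrices the scalar arguments are first grouped into intermediate vectors.
/// </summary>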
id emit_construct(const location &loc, const type &res_type, const std::vector<expression> &args) override
{
#ifndef NDEBUG
for (const expression &arg : args)
assert((arg.type.is_scalar() || res_type.is_array()) && arg.chain.empty() && arg.base != 0);
#endif
add_location(loc, *_current_block_data);

std::vector<spv::Id> ids;
ids.reserve(args.size());

// There must be exactly one constituent for each top-level component of the result
if (res_type.is_matrix())
{
type vector_type = res_type;
vector_type.rows = res_type.cols;
vector_type.cols = 1;

// Turn the list of scalar arguments into a list of column vectors
for (size_t arg = 0; arg < args.size(); arg += vector_type.rows)
{
spirv_instruction &inst = add_instruction(spv::OpCompositeConstruct, convert_type(vector_type));
for (unsigned row = 0; row < vector_type.rows; ++row)
inst.add(args[arg + row].base);

ids.push_back(inst);
}
}
else
{
assert(res_type.is_vector() || res_type.is_array());

// The exception is that for constructing a vector, a contiguous subset of the scalars consumed can be represented by a vector operand instead
for (const expression &arg : args)
ids.push_back(arg.base);
}

spirv_instruction &inst = add_instruction(spv::OpCompositeConstruct, convert_type(res_type));
inst.add(ids.begin(), ids.end());

return inst;
}

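// The 'emit_if', 'emit_phi', 'emit_loop' and 'emit_switch' implementations below stitch the previously
// recorded basic blocks into structured control flow by re-ordering the per-block instruction lists and
// inserting the required 'OpSelectionMerge', 'OpLoopMerge' and 'OpPhi' instructions.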
void emit_if(const location &loc, id, id condition_block, id true_statement_block, id false_statement_block, unsigned int selection_control) override
{
spirv_instruction merge_label = _current_block_data->instructions.back();
assert(merge_label.op == spv::OpLabel);
_current_block_data->instructions.pop_back();

// Add previous block containing the condition value first
_current_block_data->append(_block_data[condition_block]);

spirv_instruction branch_inst = _current_block_data->instructions.back();
assert(branch_inst.op == spv::OpBranchConditional);
_current_block_data->instructions.pop_back();

// Add structured control flow instruction
add_location(loc, *_current_block_data);
add_instruction_without_result(spv::OpSelectionMerge)
.add(merge_label)
.add(selection_control & 0x3); // 'SelectionControl' happens to match the flags produced by the parser

// Append all blocks belonging to the branch
_current_block_data->instructions.push_back(branch_inst);
_current_block_data->append(_block_data[true_statement_block]);
_current_block_data->append(_block_data[false_statement_block]);

_current_block_data->instructions.push_back(merge_label);
}
id emit_phi(const location &loc, id, id condition_block, id true_value, id true_statement_block, id false_value, id false_statement_block, const type &res_type) override
{
spirv_instruction merge_label = _current_block_data->instructions.back();
assert(merge_label.op == spv::OpLabel);
_current_block_data->instructions.pop_back();

// Add previous block containing the condition value first
_current_block_data->append(_block_data[condition_block]);

if (true_statement_block != condition_block)
_current_block_data->append(_block_data[true_statement_block]);
if (false_statement_block != condition_block)
_current_block_data->append(_block_data[false_statement_block]);

_current_block_data->instructions.push_back(merge_label);

add_location(loc, *_current_block_data);

// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpPhi
spirv_instruction &inst = add_instruction(spv::OpPhi, convert_type(res_type))
.add(true_value) // Variable 0
.add(true_statement_block) // Parent 0
.add(false_value) // Variable 1
.add(false_statement_block); // Parent 1

return inst;
}
void emit_loop(const location &loc, id, id prev_block, id header_block, id condition_block, id loop_block, id continue_block, unsigned int loop_control) override
{
spirv_instruction merge_label = _current_block_data->instructions.back();
assert(merge_label.op == spv::OpLabel);
_current_block_data->instructions.pop_back();

// Add previous block first
_current_block_data->append(_block_data[prev_block]);

// Fill header block
assert(_block_data[header_block].instructions.size() == 2);
_current_block_data->instructions.push_back(_block_data[header_block].instructions[0]);
assert(_current_block_data->instructions.back().op == spv::OpLabel);

// Add structured control flow instruction
add_location(loc, *_current_block_data);
add_instruction_without_result(spv::OpLoopMerge)
.add(merge_label)
.add(continue_block)
.add(loop_control & 0x3); // 'LoopControl' happens to match the flags produced by the parser

_current_block_data->instructions.push_back(_block_data[header_block].instructions[1]);
assert(_current_block_data->instructions.back().op == spv::OpBranch);

// Add condition block if it exists
if (condition_block != 0)
_current_block_data->append(_block_data[condition_block]);

// Append loop body block before continue block
_current_block_data->append(_block_data[loop_block]);
_current_block_data->append(_block_data[continue_block]);

_current_block_data->instructions.push_back(merge_label);
}
void emit_switch(const location &loc, id, id selector_block, id default_label, id default_block, const std::vector<id> &case_literal_and_labels, const std::vector<id> &case_blocks, unsigned int selection_control) override
{
assert(case_blocks.size() == case_literal_and_labels.size() / 2);

spirv_instruction merge_label = _current_block_data->instructions.back();
assert(merge_label.op == spv::OpLabel);
_current_block_data->instructions.pop_back();

// Add previous block containing the selector value first
_current_block_data->append(_block_data[selector_block]);

spirv_instruction switch_inst = _current_block_data->instructions.back();
assert(switch_inst.op == spv::OpSwitch);
_current_block_data->instructions.pop_back();

// Add structured control flow instruction
add_location(loc, *_current_block_data);
add_instruction_without_result(spv::OpSelectionMerge)
.add(merge_label)
.add(selection_control & 0x3); // 'SelectionControl' happens to match the flags produced by the parser

// Update switch instruction to contain all case labels
switch_inst.operands[1] = default_label;
switch_inst.add(case_literal_and_labels.begin(), case_literal_and_labels.end());

// Append all blocks belonging to the switch
_current_block_data->instructions.push_back(switch_inst);

std::vector<id> blocks = case_blocks;
if (default_label != merge_label)
blocks.push_back(default_block);
// Eliminate duplicates (because of multiple case labels pointing to the same block)
std::sort(blocks.begin(), blocks.end());
blocks.erase(std::unique(blocks.begin(), blocks.end()), blocks.end());
for (const id case_block : blocks)
_current_block_data->append(_block_data[case_block]);

_current_block_data->instructions.push_back(merge_label);
}

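// Basic block management: '_current_block_data' collects instructions for the block that is currently being
// recorded, and the 'leave_block_*' helpers terminate it with the appropriate branch, return or kill
// instruction before switching back to block 0 (no active block).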
bool is_in_function() const { return _current_function_blocks != nullptr; }

id set_block(id id) override
{
_last_block = _current_block;
_current_block = id;
_current_block_data = &_block_data[id];

return _last_block;
}
void enter_block(id id) override
{
assert(id != 0);
// Can only use labels inside functions and should never be in another basic block if creating a new one
assert(is_in_function() && !is_in_block());

set_block(id);

add_instruction_without_result(spv::OpLabel).result = id;
}
id leave_block_and_kill() override
{
assert(is_in_function()); // Can only discard inside functions

if (!is_in_block())
return 0;

// DXC chokes when discarding inside a function. Return a null value and use demote instead, since that's
// what the HLSL discard instruction compiles to anyway.
if (!_discard_is_demote || _current_function_blocks->return_type.is_void())
{
add_instruction_without_result(spv::OpKill);
}
else
{
add_instruction_without_result(spv::OpDemoteToHelperInvocation);

const id return_id = emit_constant(_current_function_blocks->return_type, constant{}, false);
add_instruction_without_result(spv::OpReturnValue).add(return_id);
}

return set_block(0);
}
id leave_block_and_return(id value) override
{
assert(is_in_function()); // Can only return from inside functions

if (!is_in_block()) // Might already have left the last block in which case this has to be ignored
return 0;

if (_current_function_blocks->return_type.is_void())
{
add_instruction_without_result(spv::OpReturn);
}
else
{
if (0 == value) // The implicit return statement needs this
value = add_instruction(spv::OpUndef, convert_type(_current_function_blocks->return_type), _types_and_constants);

add_instruction_without_result(spv::OpReturnValue)
.add(value);
}

return set_block(0);
}
id leave_block_and_switch(id value, id default_target) override
{
assert(value != 0 && default_target != 0);
assert(is_in_function()); // Can only switch inside functions

if (!is_in_block())
return _last_block;

add_instruction_without_result(spv::OpSwitch)
.add(value)
.add(default_target);

return set_block(0);
}
id leave_block_and_branch(id target, unsigned int) override
{
assert(target != 0);
assert(is_in_function()); // Can only branch inside functions

if (!is_in_block())
return _last_block;

add_instruction_without_result(spv::OpBranch)
.add(target);

return set_block(0);
}
id leave_block_and_branch_conditional(id condition, id true_target, id false_target) override
{
assert(condition != 0 && true_target != 0 && false_target != 0);
assert(is_in_function()); // Can only branch inside functions

if (!is_in_block())
return _last_block;

add_instruction_without_result(spv::OpBranchConditional)
.add(condition)
.add(true_target)
.add(false_target);

return set_block(0);
}
void leave_function() override
{
assert(is_in_function()); // Can only leave if there was a function to begin with

_current_function_blocks->definition = _block_data[_last_block];

// Append function end instruction
add_instruction_without_result(spv::OpFunctionEnd, _current_function_blocks->definition);

_current_function = nullptr;
_current_function_blocks = nullptr;
}
};

codegen *reshadefx::create_codegen_spirv(bool vulkan_semantics, bool debug_info, bool uniforms_to_spec_constants, bool enable_16bit_types, bool flip_vert_y, bool discard_is_demote)
{
return new codegen_spirv(vulkan_semantics, debug_info, uniforms_to_spec_constants, enable_16bit_types, flip_vert_y, discard_is_demote);
}
