Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
stenzek
GitHub Repository: stenzek/duckstation
Path: blob/master/dep/reshadefx/src/effect_codegen_hlsl.cpp
4246 views
1
/*
2
* Copyright (C) 2014 Patrick Mours
3
* SPDX-License-Identifier: BSD-3-Clause
4
*/
5
6
#include "effect_parser.hpp"
7
#include "effect_codegen.hpp"
8
#include <cmath> // std::isinf, std::isnan, std::signbit
9
#include <cctype> // std::tolower
10
#include <cassert>
11
#include <cstring> // stricmp, std::memcmp
12
#include <charconv> // std::from_chars, std::to_chars
13
#include <algorithm> // std::equal, std::find, std::find_if, std::max
14
#include <locale>
15
#include <sstream>
16
17
using namespace reshadefx;
18
19
// Converts a single decimal digit (0-9) into its ASCII character ('0'-'9').
inline char to_digit(unsigned int value)
{
	// Only one-character digits are representable
	assert(value < 10);
	return static_cast<char>('0' + value);
}
24
25
// Computes the total byte size of 'elements' consecutive items of 'size' bytes, where every item except the
// last is padded up to a multiple of 'alignment' (the mask arithmetic requires 'alignment' to be a power of two).
inline uint32_t align_up(uint32_t size, uint32_t alignment, uint32_t elements)
{
	const uint32_t mask = alignment - 1;
	const uint32_t padded_stride = (size + mask) & ~mask;
	// The final element does not need trailing padding
	return padded_stride * (elements - 1) + size;
}
30
31
class codegen_hlsl final : public codegen
32
{
33
public:
34
codegen_hlsl(unsigned int shader_model, bool debug_info, bool uniforms_to_spec_constants) :
35
_shader_model(shader_model),
36
_debug_info(debug_info),
37
_uniforms_to_spec_constants(uniforms_to_spec_constants)
38
{
39
// Create default block and reserve a memory block to avoid frequent reallocations
40
std::string &block = _blocks.emplace(0, std::string()).first->second;
41
block.reserve(8192);
42
}
43
44
private:
45
enum class naming
46
{
47
// Name should already be unique, so no additional steps are taken
48
unique,
49
// Will be numbered when clashing with another name
50
general,
51
// Replace name with a code snippet
52
expression,
53
};
54
55
unsigned int _shader_model = 0;
56
bool _debug_info = false;
57
bool _uniforms_to_spec_constants = false;
58
59
std::unordered_map<id, std::string> _names;
60
std::unordered_map<id, std::string> _blocks;
61
std::string _cbuffer_block;
62
std::string _current_location;
63
std::string _current_function_declaration;
64
65
std::string _remapped_semantics[15];
66
std::vector<std::tuple<type, constant, id>> _constant_lookup;
67
#if 0
68
std::vector<sampler_binding> _sampler_lookup;
69
#endif
70
71
// Only write compatibility intrinsics to result if they are actually in use
72
bool _uses_bitwise_cast = false;
73
bool _uses_bitwise_intrinsics = false;
74
75
void optimize_bindings() override
76
{
77
codegen::optimize_bindings();
78
79
#if 0
80
if (_shader_model < 40)
81
return;
82
83
_module.num_sampler_bindings = static_cast<uint32_t>(_sampler_lookup.size());
84
85
for (technique &tech : _module.techniques)
86
for (pass &pass : tech.passes)
87
pass.sampler_bindings.assign(_sampler_lookup.begin(), _sampler_lookup.end());
88
#endif
89
}
90
91
std::string finalize_preamble() const
92
{
93
std::string preamble;
94
95
#define IMPLEMENT_INTRINSIC_FALLBACK_ASINT(n) \
96
"int" #n " __asint(float" #n " v) {" \
97
"float" #n " e = 0;" \
98
"float" #n " f = frexp(v, e) * 2 - 1;" /* frexp does not include sign bit in HLSL, so can use as is */ \
99
"float" #n " m = ldexp(f, 23);" \
100
"return (v == 0) ? 0 : (v < 0 ? 2147483648 : 0) + (" /* Zero (does not handle negative zero) */ \
101
/* isnan(v) ? 2147483647 : */ /* NaN */ \
102
/* isinf(v) ? 2139095040 : */ /* Infinity */ \
103
"ldexp(e + 126, 23) + m);" \
104
"}"
105
#define IMPLEMENT_INTRINSIC_FALLBACK_ASUINT(n) \
106
"int" #n " __asuint(float" #n " v) { return __asint(v); }"
107
#define IMPLEMENT_INTRINSIC_FALLBACK_ASFLOAT(n) \
108
"float" #n " __asfloat(int" #n " v) {" \
109
"float" #n " m = v % exp2(23);" \
110
"float" #n " f = ldexp(m, -23);" \
111
"float" #n " e = floor(ldexp(v, -23) % 256);" \
112
"return (v > 2147483647 ? -1 : 1) * (" \
113
/* e == 0 ? ldexp(f, -126) : */ /* Denormalized */ \
114
/* e == 255 ? (m == 0 ? 1.#INF : -1.#IND) : */ /* Infinity and NaN */ \
115
"ldexp(1 + f, e - 127));" \
116
"}"
117
118
// See https://graphics.stanford.edu/%7Eseander/bithacks.html#CountBitsSetParallel
119
#define IMPLEMENT_INTRINSIC_FALLBACK_COUNTBITS(n) \
120
"uint" #n " __countbits(uint" #n " v) {" \
121
"v = v - ((v >> 1) & 0x55555555);" \
122
"v = (v & 0x33333333) + ((v >> 2) & 0x33333333);" \
123
"v = (v + (v >> 4)) & 0x0F0F0F0F;" \
124
"v *= 0x01010101;" \
125
"return v >> 24;" \
126
"}"
127
#define IMPLEMENT_INTRINSIC_FALLBACK_COUNTBITS_LOOP(n) \
128
"uint" #n " __countbits(uint" #n " v) {" \
129
"uint" #n " c = 0;" \
130
"while (any(v > 0)) {" \
131
"c += v % 2;" \
132
"v /= 2;" \
133
"}" \
134
"return c;" \
135
"}"
136
137
// See https://graphics.stanford.edu/%7Eseander/bithacks.html#ReverseParallel
138
#define IMPLEMENT_INTRINSIC_FALLBACK_REVERSEBITS(n) \
139
"uint" #n " __reversebits(uint" #n " v) {" \
140
"v = ((v >> 1) & 0x55555555) | ((v & 0x55555555) << 1);" \
141
"v = ((v >> 2) & 0x33333333) | ((v & 0x33333333) << 2);" \
142
"v = ((v >> 4) & 0x0F0F0F0F) | ((v & 0x0F0F0F0F) << 4);" \
143
"v = ((v >> 8) & 0x00FF00FF) | ((v & 0x00FF00FF) << 8);" \
144
"return (v >> 16) | (v << 16);" \
145
"}"
146
#define IMPLEMENT_INTRINSIC_FALLBACK_REVERSEBITS_LOOP(n) \
147
"uint" #n " __reversebits(uint" #n " v) {" \
148
"uint" #n " r = 0;" \
149
"for (int i = 0; i < 32; i++) {" \
150
"r *= 2;" \
151
"r += floor(x % 2);" \
152
"v /= 2;" \
153
"}" \
154
"return r;" \
155
"}"
156
157
// See https://graphics.stanford.edu/%7Eseander/bithacks.html#ZerosOnRightParallel
158
#define IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITLOW(n) \
159
"uint" #n " __firstbitlow(uint" #n " v) {" \
160
"uint" #n " c = (v != 0) ? 31 : 32;" \
161
"v &= -int" #n "(v);" \
162
"c = (v & 0x0000FFFF) ? c - 16 : c;" \
163
"c = (v & 0x00FF00FF) ? c - 8 : c;" \
164
"c = (v & 0x0F0F0F0F) ? c - 4 : c;" \
165
"c = (v & 0x33333333) ? c - 2 : c;" \
166
"c = (v & 0x55555555) ? c - 1 : c;" \
167
"return c;" \
168
"}"
169
#define IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITLOW_LOOP(n) \
170
"uint" #n " __firstbitlow(uint" #n " v) {" \
171
"uint" #n " c = (v != 0) ? 31 : 32;" \
172
"for (int i = 0; i < 32; i++) {" \
173
"c = c > i && (v % 2) != 0 ? i : c;" \
174
"v /= 2;" \
175
"}" \
176
"return c;" \
177
"}"
178
179
180
#define IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITHIGH(n) \
181
"uint" #n " __firstbithigh(uint" #n " v) { return __firstbitlow(__reversebits(v)); }"
182
183
if (_shader_model >= 40)
184
{
185
preamble +=
186
"struct __sampler1D_int { Texture1D<int> t; SamplerState s; };\n"
187
"struct __sampler2D_int { Texture2D<int> t; SamplerState s; };\n"
188
"struct __sampler3D_int { Texture3D<int> t; SamplerState s; };\n"
189
"struct __sampler1D_uint { Texture1D<uint> t; SamplerState s; };\n"
190
"struct __sampler2D_uint { Texture2D<uint> t; SamplerState s; };\n"
191
"struct __sampler3D_uint { Texture3D<uint> t; SamplerState s; };\n"
192
"struct __sampler1D_float { Texture1D<float> t; SamplerState s; };\n"
193
"struct __sampler2D_float { Texture2D<float> t; SamplerState s; };\n"
194
"struct __sampler3D_float { Texture3D<float> t; SamplerState s; };\n"
195
"struct __sampler1D_float4 { Texture1D<float4> t; SamplerState s; };\n"
196
"struct __sampler2D_float4 { Texture2D<float4> t; SamplerState s; };\n"
197
"struct __sampler3D_float4 { Texture3D<float4> t; SamplerState s; };\n";
198
199
if (_uses_bitwise_intrinsics && _shader_model < 50)
200
preamble +=
201
IMPLEMENT_INTRINSIC_FALLBACK_COUNTBITS(1) "\n"
202
IMPLEMENT_INTRINSIC_FALLBACK_COUNTBITS(2) "\n"
203
IMPLEMENT_INTRINSIC_FALLBACK_COUNTBITS(3) "\n"
204
IMPLEMENT_INTRINSIC_FALLBACK_COUNTBITS(4) "\n"
205
206
IMPLEMENT_INTRINSIC_FALLBACK_REVERSEBITS(1) "\n"
207
IMPLEMENT_INTRINSIC_FALLBACK_REVERSEBITS(2) "\n"
208
IMPLEMENT_INTRINSIC_FALLBACK_REVERSEBITS(3) "\n"
209
IMPLEMENT_INTRINSIC_FALLBACK_REVERSEBITS(4) "\n"
210
211
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITLOW(1) "\n"
212
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITLOW(2) "\n"
213
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITLOW(3) "\n"
214
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITLOW(4) "\n"
215
216
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITHIGH(1) "\n"
217
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITHIGH(2) "\n"
218
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITHIGH(3) "\n"
219
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITHIGH(4) "\n";
220
221
if (!_cbuffer_block.empty())
222
{
223
#if 0
224
if (_shader_model >= 60)
225
preamble += "[[vk::binding(0, 0)]] "; // Descriptor set 0
226
#endif
227
228
preamble += "cbuffer _Globals {\n" + _cbuffer_block + "};\n";
229
}
230
}
231
else
232
{
233
preamble +=
234
"struct __sampler1D { sampler1D s; float1 pixelsize; };\n"
235
"struct __sampler2D { sampler2D s; float2 pixelsize; };\n"
236
"struct __sampler3D { sampler3D s; float3 pixelsize; };\n"
237
"uniform float2 __TEXEL_SIZE__ : register(c255);\n";
238
239
if (_uses_bitwise_cast)
240
preamble +=
241
IMPLEMENT_INTRINSIC_FALLBACK_ASINT(1) "\n"
242
IMPLEMENT_INTRINSIC_FALLBACK_ASINT(2) "\n"
243
IMPLEMENT_INTRINSIC_FALLBACK_ASINT(3) "\n"
244
IMPLEMENT_INTRINSIC_FALLBACK_ASINT(4) "\n"
245
246
IMPLEMENT_INTRINSIC_FALLBACK_ASUINT(1) "\n"
247
IMPLEMENT_INTRINSIC_FALLBACK_ASUINT(2) "\n"
248
IMPLEMENT_INTRINSIC_FALLBACK_ASUINT(3) "\n"
249
IMPLEMENT_INTRINSIC_FALLBACK_ASUINT(4) "\n"
250
251
IMPLEMENT_INTRINSIC_FALLBACK_ASFLOAT(1) "\n"
252
IMPLEMENT_INTRINSIC_FALLBACK_ASFLOAT(2) "\n"
253
IMPLEMENT_INTRINSIC_FALLBACK_ASFLOAT(3) "\n"
254
IMPLEMENT_INTRINSIC_FALLBACK_ASFLOAT(4) "\n";
255
256
if (_uses_bitwise_intrinsics)
257
preamble +=
258
IMPLEMENT_INTRINSIC_FALLBACK_COUNTBITS_LOOP(1) "\n"
259
IMPLEMENT_INTRINSIC_FALLBACK_COUNTBITS_LOOP(2) "\n"
260
IMPLEMENT_INTRINSIC_FALLBACK_COUNTBITS_LOOP(3) "\n"
261
IMPLEMENT_INTRINSIC_FALLBACK_COUNTBITS_LOOP(4) "\n"
262
263
IMPLEMENT_INTRINSIC_FALLBACK_REVERSEBITS_LOOP(1) "\n"
264
IMPLEMENT_INTRINSIC_FALLBACK_REVERSEBITS_LOOP(2) "\n"
265
IMPLEMENT_INTRINSIC_FALLBACK_REVERSEBITS_LOOP(3) "\n"
266
IMPLEMENT_INTRINSIC_FALLBACK_REVERSEBITS_LOOP(4) "\n"
267
268
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITLOW_LOOP(1) "\n"
269
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITLOW_LOOP(2) "\n"
270
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITLOW_LOOP(3) "\n"
271
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITLOW_LOOP(4) "\n"
272
273
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITHIGH(1) "\n"
274
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITHIGH(2) "\n"
275
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITHIGH(3) "\n"
276
IMPLEMENT_INTRINSIC_FALLBACK_FIRSTBITHIGH(4) "\n";
277
278
if (!_cbuffer_block.empty())
279
{
280
preamble += _cbuffer_block;
281
}
282
}
283
284
return preamble;
285
}
286
287
std::string finalize_code() const override
288
{
289
std::string code = finalize_preamble();
290
291
// Add global definitions (struct types, global variables, sampler state declarations, ...)
292
code += _blocks.at(0);
293
294
// Add texture and sampler definitions
295
for (const sampler &info : _module.samplers)
296
code += _blocks.at(info.id);
297
298
// Add storage definitions
299
for (const storage &info : _module.storages)
300
code += _blocks.at(info.id);
301
302
// Add function definitions
303
for (const std::unique_ptr<function> &func : _functions)
304
code += _blocks.at(func->id);
305
306
return code;
307
}
308
std::string finalize_code_for_entry_point(const std::string &entry_point_name) const override
309
{
310
const auto entry_point_it = std::find_if(_functions.begin(), _functions.end(),
311
[&entry_point_name](const std::unique_ptr<function> &func) {
312
return func->unique_name == entry_point_name;
313
});
314
if (entry_point_it == _functions.end())
315
return {};
316
const function &entry_point = *entry_point_it->get();
317
318
std::string code = finalize_preamble();
319
320
if (_shader_model < 40 && entry_point.type == shader_type::pixel)
321
// Overwrite position semantic in pixel shaders
322
code += "#define POSITION VPOS\n";
323
324
// Add global definitions (struct types, global variables, sampler state declarations, ...)
325
code += _blocks.at(0);
326
327
const auto replace_binding =
328
[](std::string &code, uint32_t binding) {
329
for (size_t start = 0;;)
330
{
331
const size_t pos = code.find(": register(", start);
332
if (pos == std::string::npos)
333
break;
334
const size_t beg = pos + 12;
335
const size_t end = code.find(')', beg);
336
const std::string replacement = std::to_string(binding);
337
code.replace(beg, end - beg, replacement);
338
start = beg + replacement.length();
339
}
340
};
341
342
// Add referenced texture and sampler definitions
343
for (uint32_t binding = 0; binding < entry_point.referenced_samplers.size(); ++binding)
344
{
345
if (entry_point.referenced_samplers[binding] == 0)
346
continue;
347
348
std::string block_code = _blocks.at(entry_point.referenced_samplers[binding]);
349
replace_binding(block_code, binding);
350
code += block_code;
351
}
352
353
// Add referenced storage definitions
354
for (uint32_t binding = 0; binding < entry_point.referenced_storages.size(); ++binding)
355
{
356
if (entry_point.referenced_storages[binding] == 0)
357
continue;
358
359
std::string block_code = _blocks.at(entry_point.referenced_storages[binding]);
360
replace_binding(block_code, binding);
361
code += block_code;
362
}
363
364
// Add referenced function definitions
365
for (const std::unique_ptr<function> &func : _functions)
366
{
367
if (func->id != entry_point.id &&
368
std::find(entry_point.referenced_functions.begin(), entry_point.referenced_functions.end(), func->id) == entry_point.referenced_functions.end())
369
continue;
370
371
code += _blocks.at(func->id);
372
}
373
374
return code;
375
}
376
377
template <bool is_param = false, bool is_decl = true>
378
void write_type(std::string &s, const type &type) const
379
{
380
if constexpr (is_decl)
381
{
382
if (type.has(type::q_static))
383
s += "static ";
384
if (type.has(type::q_precise))
385
s += "precise ";
386
if (type.has(type::q_groupshared))
387
s += "groupshared ";
388
}
389
390
if constexpr (is_param)
391
{
392
if (type.has(type::q_linear))
393
s += "linear ";
394
if (type.has(type::q_noperspective))
395
s += "noperspective ";
396
if (type.has(type::q_centroid))
397
s += "centroid ";
398
if (type.has(type::q_nointerpolation))
399
s += "nointerpolation ";
400
401
if (type.has(type::q_inout))
402
s += "inout ";
403
else if (type.has(type::q_in))
404
s += "in ";
405
else if (type.has(type::q_out))
406
s += "out ";
407
}
408
409
switch (type.base)
410
{
411
case type::t_void:
412
s += "void";
413
return;
414
case type::t_bool:
415
s += "bool";
416
break;
417
case type::t_min16int:
418
// Minimum precision types are only supported in shader model 4 and up
419
// Real 16-bit types were added in shader model 6.2
420
s += _shader_model >= 62 ? "int16_t" : _shader_model >= 40 ? "min16int" : "int";
421
break;
422
case type::t_int:
423
s += "int";
424
break;
425
case type::t_min16uint:
426
s += _shader_model >= 62 ? "uint16_t" : _shader_model >= 40 ? "min16uint" : "int";
427
break;
428
case type::t_uint:
429
// In shader model 3, uints can only be used with known-positive values, so use ints instead
430
s += _shader_model >= 40 ? "uint" : "int";
431
break;
432
case type::t_min16float:
433
s += _shader_model >= 62 ? "float16_t" : _shader_model >= 40 ? "min16float" : "float";
434
break;
435
case type::t_float:
436
s += "float";
437
break;
438
case type::t_struct:
439
s += id_to_name(type.struct_definition);
440
return;
441
case type::t_sampler1d_int:
442
case type::t_sampler2d_int:
443
case type::t_sampler3d_int:
444
s += "__sampler";
445
s += to_digit(type.texture_dimension());
446
s += 'D';
447
if (_shader_model >= 40)
448
{
449
s += "_int";
450
if (type.rows > 1)
451
s += to_digit(type.rows);
452
}
453
return;
454
case type::t_sampler1d_uint:
455
case type::t_sampler2d_uint:
456
case type::t_sampler3d_uint:
457
s += "__sampler";
458
s += to_digit(type.texture_dimension());
459
s += 'D';
460
if (_shader_model >= 40)
461
{
462
s += "_uint";
463
if (type.rows > 1)
464
s += to_digit(type.rows);
465
}
466
return;
467
case type::t_sampler1d_float:
468
case type::t_sampler2d_float:
469
case type::t_sampler3d_float:
470
s += "__sampler";
471
s += to_digit(type.texture_dimension());
472
s += 'D';
473
if (_shader_model >= 40)
474
{
475
s += "_float";
476
if (type.rows > 1)
477
s += to_digit(type.rows);
478
}
479
return;
480
case type::t_storage1d_int:
481
case type::t_storage2d_int:
482
case type::t_storage3d_int:
483
s += "RWTexture";
484
s += to_digit(type.texture_dimension());
485
s += "D<";
486
s += "int";
487
if (type.rows > 1)
488
s += to_digit(type.rows);
489
s += '>';
490
return;
491
case type::t_storage1d_uint:
492
case type::t_storage2d_uint:
493
case type::t_storage3d_uint:
494
s += "RWTexture";
495
s += to_digit(type.texture_dimension());
496
s += "D<";
497
s += "uint";
498
if (type.rows > 1)
499
s += to_digit(type.rows);
500
s += '>';
501
return;
502
case type::t_storage1d_float:
503
case type::t_storage2d_float:
504
case type::t_storage3d_float:
505
s += "RWTexture";
506
s += to_digit(type.texture_dimension());
507
s += "D<";
508
s += "float";
509
if (type.rows > 1)
510
s += to_digit(type.rows);
511
s += '>';
512
return;
513
default:
514
assert(false);
515
return;
516
}
517
518
if (type.rows > 1)
519
s += to_digit(type.rows);
520
if (type.cols > 1)
521
s += 'x', s += to_digit(type.cols);
522
}
523
void write_constant(std::string &s, const type &data_type, const constant &data) const
524
{
525
if (data_type.is_array())
526
{
527
assert(data_type.is_bounded_array());
528
529
type elem_type = data_type;
530
elem_type.array_length = 0;
531
532
s += "{ ";
533
534
for (unsigned int a = 0; a < data_type.array_length; ++a)
535
{
536
write_constant(s, elem_type, a < static_cast<unsigned int>(data.array_data.size()) ? data.array_data[a] : constant {});
537
s += ", ";
538
}
539
540
// Remove trailing ", "
541
s.erase(s.size() - 2);
542
543
s += " }";
544
return;
545
}
546
547
if (data_type.is_struct())
548
{
549
// The can only be zero initializer struct constants
550
assert(data.as_uint[0] == 0);
551
552
s += '(' + id_to_name(data_type.struct_definition) + ")0";
553
return;
554
}
555
556
// There can only be numeric constants
557
assert(data_type.is_numeric());
558
559
if (!data_type.is_scalar())
560
write_type<false, false>(s, data_type), s += '(';
561
562
for (unsigned int i = 0; i < data_type.components(); ++i)
563
{
564
switch (data_type.base)
565
{
566
case type::t_bool:
567
s += data.as_uint[i] ? "true" : "false";
568
break;
569
case type::t_min16int:
570
case type::t_int:
571
s += std::to_string(data.as_int[i]);
572
break;
573
case type::t_min16uint:
574
case type::t_uint:
575
s += std::to_string(data.as_uint[i]);
576
break;
577
case type::t_min16float:
578
case type::t_float:
579
if (std::isnan(data.as_float[i])) {
580
s += "-1.#IND";
581
break;
582
}
583
if (std::isinf(data.as_float[i])) {
584
s += std::signbit(data.as_float[i]) ? "1.#INF" : "-1.#INF";
585
break;
586
}
587
{
588
#ifdef _MSC_VER
589
char temp[64];
590
const std::to_chars_result res = std::to_chars(temp, temp + sizeof(temp), data.as_float[i], std::chars_format::scientific, 8);
591
if (res.ec == std::errc())
592
s.append(temp, res.ptr);
593
else
594
assert(false);
595
#else
596
std::ostringstream ss;
597
ss.imbue(std::locale::classic());
598
ss << data.as_float[i];
599
s += ss.str();
600
#endif
601
}
602
break;
603
default:
604
assert(false);
605
}
606
607
s += ", ";
608
}
609
610
// Remove trailing ", "
611
s.erase(s.size() - 2);
612
613
if (!data_type.is_scalar())
614
s += ')';
615
}
616
template <bool force_source = false>
617
void write_location(std::string &s, const location &loc)
618
{
619
if (loc.source.empty() || !_debug_info)
620
return;
621
622
s += "#line " + std::to_string(loc.line);
623
624
size_t offset = s.size();
625
626
// Avoid writing the file name every time to reduce output text size
627
if constexpr (force_source)
628
{
629
s += " \"" + loc.source + '\"';
630
}
631
else if (loc.source != _current_location)
632
{
633
s += " \"" + loc.source + '\"';
634
635
_current_location = loc.source;
636
}
637
638
// Need to escape string for new DirectX Shader Compiler (dxc)
639
if (_shader_model >= 60)
640
{
641
for (; (offset = s.find('\\', offset)) != std::string::npos; offset += 2)
642
s.insert(offset, "\\", 1);
643
}
644
645
s += '\n';
646
}
647
// Appends the HLSL texture element type for 'format': "int"/"uint" for 32-bit integer formats,
// "float4" for every other format (unexpected values assert, then fall back to "float4").
void write_texture_format(std::string &s, texture_format format)
{
	switch (format)
	{
	case texture_format::r32i:
		s += "int";
		return;
	case texture_format::r32u:
		s += "uint";
		return;
	case texture_format::unknown:
	case texture_format::r8:
	case texture_format::r16:
	case texture_format::r16f:
	case texture_format::r32f:
	case texture_format::rg8:
	case texture_format::rg16:
	case texture_format::rg16f:
	case texture_format::rg32f:
	case texture_format::rgba8:
	case texture_format::rgba16:
	case texture_format::rgba16f:
	case texture_format::rgba32f:
	case texture_format::rgb10a2:
		break;
	default:
		assert(false);
		break;
	}

	s += "float4";
}
678
679
// Returns the registered name for 'id', or the generated fallback "_<id>" when none was defined.
std::string id_to_name(id id) const
{
	assert(id != 0);
	const auto found = _names.find(id);
	return found != _names.end() ? found->second : '_' + std::to_string(id);
}
687
688
template <naming naming_type = naming::general>
689
void define_name(const id id, std::string name)
690
{
691
assert(!name.empty());
692
if constexpr (naming_type != naming::expression)
693
if (name[0] == '_')
694
return; // Filter out names that may clash with automatic ones
695
name = escape_name(std::move(name));
696
if constexpr (naming_type == naming::general)
697
if (std::find_if(_names.begin(), _names.end(),
698
[&name](const auto &names_it) { return names_it.second == name; }) != _names.end())
699
name += '_' + std::to_string(id); // Append a numbered suffix if the name already exists
700
_names[id] = std::move(name);
701
}
702
703
std::string convert_semantic(const std::string &semantic, uint32_t max_attributes = 1)
704
{
705
if (_shader_model < 40)
706
{
707
if (semantic == "SV_POSITION")
708
return "POSITION"; // For pixel shaders this has to be "VPOS", so need to redefine that in post
709
if (semantic == "VPOS")
710
return "VPOS";
711
if (semantic == "SV_POINTSIZE")
712
return "PSIZE";
713
if (semantic.compare(0, 9, "SV_TARGET") == 0)
714
return "COLOR" + semantic.substr(9);
715
if (semantic == "SV_DEPTH")
716
return "DEPTH";
717
if (semantic == "SV_VERTEXID")
718
return "TEXCOORD0 /* VERTEXID */";
719
if (semantic == "SV_ISFRONTFACE")
720
return "VFACE";
721
722
size_t digit_index = semantic.size() - 1;
723
while (digit_index != 0 && semantic[digit_index] >= '0' && semantic[digit_index] <= '9')
724
digit_index--;
725
digit_index++;
726
727
const std::string semantic_base = semantic.substr(0, digit_index);
728
729
uint32_t semantic_digit = 0;
730
std::from_chars(semantic.c_str() + digit_index, semantic.c_str() + semantic.size(), semantic_digit);
731
732
if (semantic_base == "TEXCOORD")
733
{
734
if (semantic_digit < 15)
735
{
736
assert(_remapped_semantics[semantic_digit].empty() || _remapped_semantics[semantic_digit] == semantic); // Mixing custom semantic names and multiple TEXCOORD indices is not supported
737
_remapped_semantics[semantic_digit] = semantic;
738
}
739
}
740
// Shader model 3 only supports a selected list of semantic names, so need to remap custom ones to that
741
else if (
742
semantic_base != "COLOR" &&
743
semantic_base != "NORMAL" &&
744
semantic_base != "TANGENT" &&
745
semantic_base != "BINORMAL")
746
{
747
// Legal semantic indices are between 0 and 15, but skip first entry in case both custom semantic names and the common TEXCOORD0 exist
748
for (int i = 1; i < 15; ++i)
749
{
750
if (_remapped_semantics[i].empty() || _remapped_semantics[i] == semantic)
751
{
752
for (uint32_t a = 0; a < max_attributes && i + a < 15; ++a)
753
_remapped_semantics[i + a] = semantic_base + std::to_string(semantic_digit + a);
754
755
return "TEXCOORD" + std::to_string(i) + " /* " + semantic + " */";
756
}
757
}
758
}
759
}
760
else
761
{
762
if (semantic.compare(0, 5, "COLOR") == 0)
763
return "SV_TARGET" + semantic.substr(5);
764
}
765
766
return semantic;
767
}
768
769
// Prefixes 'name' with an underscore if it matches (case-insensitively) an identifier the HLSL compiler
// complains about, e.g. "technique" and "pass" in strict mode. Other names are returned unchanged.
static std::string escape_name(std::string name)
{
	static const char *const reserved_names[] = {
		"line", "pass", "technique", "point", "export", "extern", "compile", "discard",
		"half", "in", "lineadj", "matrix", "sample", "sampler", "shared", "precise",
		"register", "texture", "unorm", "triangle", "triangleadj", "out", "vector",
	};

	// Case-insensitive comparison. Characters go through 'unsigned char' before 'std::tolower', because passing
	// a negative 'char' (e.g. a non-ASCII byte) to it is undefined behavior.
	const auto equals_ignore_case = [&name](const char *keyword) {
		const auto fold = [](char c) { return std::tolower(static_cast<unsigned char>(c)); };
		return std::equal(name.begin(), name.end(), keyword, keyword + std::strlen(keyword),
			[&fold](char a, char b) { return fold(a) == fold(b); });
	};

	for (const char *keyword : reserved_names)
	{
		if (equals_ignore_case(keyword))
		{
			// This is guaranteed to not clash with user defined names, since those starting with an underscore are filtered out in 'define_name'
			name.insert(name.begin(), '_');
			break;
		}
	}

	return name;
}
808
809
// Indents 'block' by one more tab: a tab is added at the very start, and an extra tab after every newline
// that is already followed by a tab. Empty blocks are left untouched.
static void increase_indentation_level(std::string &block)
{
	if (block.empty())
		return;

	std::string indented;
	indented.reserve(block.size() + 1);
	indented += '\t';

	for (size_t i = 0; i < block.size(); ++i)
	{
		indented += block[i];
		const bool newline_then_tab = block[i] == '\n' && i + 1 < block.size() && block[i + 1] == '\t';
		if (newline_then_tab)
			indented += '\t';
	}

	block = std::move(indented);
}
819
820
// Allocates an id for the struct, records it in '_structs' and writes its HLSL definition into the current block.
id define_struct(const location &loc, struct_type &info) override
{
	info.id = make_id();
	const id struct_id = info.id;
	define_name<naming::unique>(struct_id, info.unique_name);

	_structs.push_back(info);

	std::string &out = _blocks.at(_current_block);

	write_location(out, loc);

	out += "struct ";
	out += id_to_name(struct_id);
	out += "\n{\n";

	for (const member_type &member : info.member_list)
	{
		// HLSL allows interpolation attributes on struct members, so handle this like a parameter
		out += '\t';
		write_type<true>(out, member.type);
		out += ' ';
		out += member.name;

		if (member.type.is_array())
		{
			out += '[';
			out += std::to_string(member.type.array_length);
			out += ']';
		}

		if (!member.semantic.empty())
		{
			// One attribute slot per four components, times the array length
			const uint32_t attribute_count = std::max(1u, member.type.components() / 4) * std::max(1u, member.type.array_length);
			out += " : ";
			out += convert_semantic(member.semantic, attribute_count);
		}

		out += ";\n";
	}

	out += "};\n";

	return struct_id;
}
852
// Allocates an id for the texture and records it in the module. No code is emitted here; the HLSL texture
// objects are declared by 'define_sampler'.
id define_texture(const location &, texture &info) override
{
	info.id = make_id();
	_module.textures.push_back(info);
	return info.id;
}
860
// Creates a code block for the sampler containing its texture/sampler declarations and a static struct
// instance combining them (one of the '__sampler*' preamble structs), then records it in the module.
id define_sampler(const location &loc, const texture &tex_info, sampler &info) override
{
	info.id = create_block();
	const id res = info.id;
	define_name<naming::unique>(res, info.unique_name);

	std::string &code = _blocks.at(res);

	// Default to a register index equivalent to the entry in the sampler list (this is later overwritten in 'finalize_code_for_entry_point' to a more optimal placement)
	const uint32_t default_binding = static_cast<uint32_t>(_module.samplers.size());
	const std::string binding_index = std::to_string(default_binding);
	const std::string &base_name = info.unique_name;

	if (_shader_model >= 40)
	{
#if 0
		uint32_t sampler_state_binding = 0;

		// Try and reuse a sampler binding with the same sampler description
		const auto existing_sampler_it = std::find_if(_sampler_lookup.begin(), _sampler_lookup.end(),
			[&info](const sampler_desc &existing_info) {
				return
					existing_info.filter == info.filter &&
					existing_info.address_u == info.address_u &&
					existing_info.address_v == info.address_v &&
					existing_info.address_w == info.address_w &&
					existing_info.min_lod == info.min_lod &&
					existing_info.max_lod == info.max_lod &&
					existing_info.lod_bias == info.lod_bias;
			});
		if (existing_sampler_it != _sampler_lookup.end())
		{
			sampler_state_binding = existing_sampler_it->binding;
		}
		else
		{
			sampler_state_binding = static_cast<uint32_t>(_sampler_lookup.size());

			sampler_binding s;
			s.filter = info.filter;
			s.address_u = info.address_u;
			s.address_v = info.address_v;
			s.address_w = info.address_w;
			s.min_lod = info.min_lod;
			s.max_lod = info.max_lod;
			s.lod_bias = info.lod_bias;
			s.binding = sampler_state_binding;
			_sampler_lookup.push_back(std::move(s));

			if (_shader_model >= 60)
				_blocks.at(0) += "[[vk::binding(" + std::to_string(sampler_state_binding) + ", 1)]] "; // Descriptor set 1

			_blocks.at(0) += "SamplerState __s" + std::to_string(sampler_state_binding) + " : register(s" + std::to_string(sampler_state_binding) + ");\n";
		}

		if (_shader_model >= 60)
			code += "[[vk::binding(" + std::to_string(default_binding) + ", 2)]] "; // Descriptor set 2

		code += "Texture";
		code += to_digit(static_cast<unsigned int>(tex_info.type));
		code += "D<";
		write_texture_format(code, tex_info.format);
		code += "> __" + info.unique_name + "_t : register(t" + std::to_string(default_binding) + "); \n";

		write_location(code, loc);

		code += "static const ";
		write_type(code, info.type);
		code += ' ' + id_to_name(res) + " = { __" + info.unique_name + "_t, __s" + std::to_string(sampler_state_binding) + " };\n";
#else
		// Separate texture and sampler state objects, sharing the same register index
		code += "Texture";
		code += to_digit(static_cast<unsigned int>(tex_info.type));
		code += "D<";
		write_texture_format(code, tex_info.format);
		code += "> __" + base_name + "_t : register(t" + binding_index + "); \n";

		code += "SamplerState __" + base_name + "_s : register(s" + binding_index + "); \n";

		write_location(code, loc);

		code += "static const ";
		write_type(code, info.type);
		code += ' ';
		code += id_to_name(res);
		code += " = { __" + base_name + "_t, __" + base_name + "_s };\n";
#endif
	}
	else
	{
		// Shader model 3: combined sampler object plus the inverse texture size
		const unsigned int texture_dimension = info.type.texture_dimension();

		code += "sampler";
		code += to_digit(texture_dimension);
		code += "D __" + base_name + "_s : register(s" + binding_index + ");\n";

		write_location(code, loc);

		code += "static const ";
		write_type(code, info.type);
		code += ' ';
		code += id_to_name(res);
		code += " = { __" + base_name + "_s, float";
		code += to_digit(texture_dimension);
		code += '(';

		if (!tex_info.semantic.empty())
		{
			// Expect application to set inverse texture size via a define if it is not known here
			code += tex_info.semantic + "_PIXEL_SIZE";
		}
		else
		{
			code += "1.0 / " + std::to_string(tex_info.width);
			if (texture_dimension >= 2)
				code += ", 1.0 / " + std::to_string(tex_info.height);
			if (texture_dimension >= 3)
				code += ", 1.0 / " + std::to_string(tex_info.depth);
		}

		code += ") }; \n";
	}

	_module.samplers.push_back(info);

	return res;
}
976
// Creates a code block for the storage and, on shader model 5 and up, declares its RWTexture object there.
// The storage is recorded in the module regardless of shader model.
id define_storage(const location &loc, const texture &, storage &info) override
{
	info.id = create_block();
	const id res = info.id;
	define_name<naming::unique>(res, info.unique_name);

	// Default to a register index equivalent to the entry in the storage list (this is later overwritten in 'finalize_code_for_entry_point' to a more optimal placement)
	const uint32_t default_binding = static_cast<uint32_t>(_module.storages.size());

	const bool supports_storage = _shader_model >= 50;
	if (supports_storage)
	{
		std::string &out = _blocks.at(res);

		write_location(out, loc);

#if 0
		if (_shader_model >= 60)
			out += "[[vk::binding(" + std::to_string(default_binding) + ", 3)]] "; // Descriptor set 3
#endif

		write_type(out, info.type);
		out += ' ';
		out += info.unique_name;
		out += " : register(u" + std::to_string(default_binding) + ");\n";
	}

	_module.storages.push_back(info);

	return res;
}
1003
id define_uniform(const location &loc, uniform &info) override
{
	const id uniform_id = make_id();
	define_name<naming::unique>(uniform_id, info.name);

	// Uniforms with an initializer can be turned into specialization constants: the generated code
	// references "SPEC_CONSTANT_<name>" (presumably supplied as a define by the application — confirm)
	if (_uniforms_to_spec_constants && info.has_initializer_value)
	{
		info.size = info.type.components() * 4;
		if (info.type.is_array())
			info.size *= info.type.array_length;

		std::string &code = _blocks.at(_current_block);

		write_location(code, loc);

		assert(!info.type.has(type::q_static) && !info.type.has(type::q_const));

		code += "static const ";
		write_type(code, info.type);
		code += ' ' + id_to_name(uniform_id) + " = ";
		if (!info.type.is_scalar())
			write_type<false, false>(code, info.type);
		code += "(SPEC_CONSTANT_" + info.name + ");\n";

		_module.spec_constants.push_back(info);

		return uniform_id;
	}

	const bool is_sm3 = _shader_model < 40;

	// Byte size: vectors are column major (1xN), matrices are row major (NxM) with every row padded to 16 bytes
	if (info.type.is_matrix())
		info.size = align_up(info.type.cols * 4, 16, info.type.rows);
	else
		info.size = info.type.rows * 4;
	// Arrays are not packed in HLSL by default, each element is stored in a four-component vector (16 bytes)
	if (info.type.is_array())
		info.size = align_up(info.size, 16, info.type.array_length);

	// Shader model 3 keeps the running total scaled up by 4 (see below), so undo that before computing the offset
	if (is_sm3)
		_module.total_uniform_size /= 4;

	// Data is packed into 4-byte boundaries (see https://docs.microsoft.com/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules)
	// This is already guaranteed, since all types are at least 4-byte in size; a value additionally must not cross a 16-byte boundary
	info.offset = _module.total_uniform_size;
	const uint32_t bytes_left_in_register = 16 - (info.offset & 15);
	if (bytes_left_in_register != 16 && info.size > bytes_left_in_register)
		info.offset += bytes_left_in_register;
	_module.total_uniform_size = info.offset + info.size;

	type declared_type = info.type;
	if (is_sm3)
	{
		// The HLSL compiler tries to evaluate boolean values with temporary registers, which breaks branches, so force it to use constant float registers
		if (declared_type.is_boolean())
			declared_type.base = type::t_float;

		// Simply put each uniform into a separate constant register in shader model 3 for now
		info.offset *= 4;
		_module.total_uniform_size *= 4;
	}

	write_location<true>(_cbuffer_block, loc);

	if (!is_sm3)
		_cbuffer_block += '\t';
	if (info.type.is_matrix())
		_cbuffer_block += "row_major "; // Force row major matrices

	write_type(_cbuffer_block, declared_type);
	_cbuffer_block += ' ' + id_to_name(uniform_id);

	if (info.type.is_array())
		_cbuffer_block += '[' + std::to_string(info.type.array_length) + ']';

	// Every constant register is 16 bytes wide, so the byte offset divided by 16 is the register index
	// (all uniforms are floating-point in shader model 3, even if the uniform type says different)
	if (is_sm3)
		_cbuffer_block += " : register(c" + std::to_string(info.offset / 16) + ')';

	_cbuffer_block += ";\n";

	_module.uniforms.push_back(info);

	return uniform_id;
}
1090
id define_variable(const location &loc, const type &type, std::string name, bool global, id initializer_value) override
{
	// A const variable initialized from a known constant can simply alias that constant,
	// since it can never be modified (saves an unnecessary assignment)
	if (initializer_value != 0 && type.has(type::q_const))
	{
		const bool initializer_is_known_constant = std::any_of(_constant_lookup.begin(), _constant_lookup.end(),
			[initializer_value](const auto &entry) {
				return std::get<2>(entry) == initializer_value;
			});
		if (initializer_is_known_constant)
			return initializer_value;
	}

	const id variable_id = make_id();

	if (!name.empty())
		define_name<naming::general>(variable_id, name);

	std::string &code = _blocks.at(_current_block);

	write_location(code, loc);

	// Locals are indented one level, globals are not
	if (!global)
		code += '\t';

	write_type(code, type);
	code += ' ';
	code += id_to_name(variable_id);

	if (type.is_array())
		code += '[' + std::to_string(type.array_length) + ']';

	if (initializer_value != 0)
		code += " = " + id_to_name(initializer_value);

	code += ";\n";

	return variable_id;
}
1125
id define_function(const location &loc, function &info) override
{
	info.id = make_id();
	const id function_id = info.id;
	define_name<naming::unique>(function_id, info.unique_name);

	// Must be called outside any function body (block 0); the declaration may only be pre-filled
	// for typed functions (e.g. the "numthreads" attribute written by define_entry_point)
	assert(_current_block == 0 && (_current_function_declaration.empty() || info.type != shader_type::unknown));

	std::string &declaration = _current_function_declaration;

	write_location(declaration, loc);
	write_type(declaration, info.return_type);
	declaration += ' ' + id_to_name(function_id) + '(';

	// One parameter per line, separated by commas
	for (size_t index = 0; index < info.parameter_list.size(); ++index)
	{
		member_type &param = info.parameter_list[index];

		param.id = make_id();
		define_name<naming::unique>(param.id, param.name);

		if (index != 0)
			declaration += ',';
		declaration += '\n';

		write_location(declaration, param.location);
		declaration += '\t';
		write_type<true>(declaration, param.type);
		declaration += ' ' + id_to_name(param.id);

		if (param.type.is_array())
			declaration += '[' + std::to_string(param.type.array_length) + ']';

		// Second argument is derived from column count and array length — presumably the number of semantic slots the parameter occupies
		if (!param.semantic.empty())
			declaration += " : " + convert_semantic(param.semantic, std::max(1u, param.type.cols / 4u) * std::max(1u, param.type.array_length));
	}

	declaration += ')';

	if (!info.return_semantic.empty())
		declaration += " : " + convert_semantic(info.return_semantic);

	declaration += '\n';

	_functions.push_back(std::make_unique<function>(info));
	_current_function = _functions.back().get();

	return function_id;
}
1174
1175
void define_entry_point(function &func) override
{
	// Modify entry point name since a new wrapper function is created for it below
	assert(!func.unique_name.empty() && func.unique_name[0] == 'F');

	const bool is_compute = func.type == shader_type::compute;
	const bool is_vertex = func.type == shader_type::vertex;

	if (is_compute || _shader_model < 40)
		func.unique_name[0] = 'E';

	// Compute entry points are suffixed with their thread group dimensions
	if (is_compute)
		for (int dimension = 0; dimension < 3; ++dimension)
			func.unique_name += '_' + std::to_string(func.num_threads[dimension]);

	const bool already_registered = std::any_of(_module.entry_points.begin(), _module.entry_points.end(),
		[&func](const std::pair<std::string, shader_type> &entry_point) {
			return entry_point.first == func.unique_name;
		});
	if (already_registered)
		return;

	_module.entry_points.emplace_back(func.unique_name, func.type);

	// Only have to rewrite the entry point function signature in shader model 3 and for compute (to write "numthreads" attribute)
	if (_shader_model >= 40 && !is_compute)
		return;

	function entry_point = func;
	entry_point.referenced_functions.push_back(func.id);

	const auto is_color_semantic = [](const std::string &semantic) {
		return semantic.compare(0, 9, "SV_TARGET") == 0 || semantic.compare(0, 5, "COLOR") == 0;
	};
	const auto is_position_semantic = [](const std::string &semantic) {
		return semantic == "SV_POSITION" || semantic == "POSITION";
	};

	const id result_id = make_id();
	define_name<naming::general>(result_id, "ret");

	// Name of the vertex position variable that receives the half-pixel offset at the end
	std::string position_variable_name;

	// Records the (last) struct member of 'value_id' that carries a position semantic
	const auto track_struct_position = [&](const type &value_type, id value_id) {
		for (const member_type &member : get_struct(value_type.struct_definition).member_list)
			if (is_position_semantic(member.semantic))
				position_variable_name = id_to_name(value_id) + '.' + member.name;
	};

	if (is_vertex && func.return_type.is_struct())
		track_struct_position(func.return_type, result_id);

	// The COLOR output semantic has to be a four-component vector in shader model 3, so enforce that
	if (is_color_semantic(func.return_semantic))
		entry_point.return_type.rows = 4;
	if (is_vertex && is_position_semantic(func.return_semantic))
		position_variable_name = id_to_name(result_id);

	for (member_type &param : entry_point.parameter_list)
	{
		if (is_vertex && param.type.is_struct())
			track_struct_position(param.type, param.id);

		if (is_color_semantic(param.semantic))
			param.type.rows = 4;

		if (!is_position_semantic(param.semantic))
			continue;

		if (is_vertex)
			position_variable_name = id_to_name(param.id); // Keep track of the position output variable
		else if (func.type == shader_type::pixel)
			param.semantic = "VPOS"; // Change the position input semantic in pixel shaders
	}

	assert(_current_function_declaration.empty());
	if (is_compute)
		_current_function_declaration +=
			"[numthreads(" + std::to_string(func.num_threads[0]) + ", " + std::to_string(func.num_threads[1]) + ", " + std::to_string(func.num_threads[2]) + ")]\n";

	define_function({}, entry_point);
	enter_block(create_block());

	std::string &code = _blocks.at(_current_block);

	// Clear all color output parameters so no component is left uninitialized
	for (const member_type &param : entry_point.parameter_list)
		if (is_color_semantic(param.semantic))
			code += '\t' + id_to_name(param.id) + " = float4(0.0, 0.0, 0.0, 0.0);\n";

	const bool returns_color = is_color_semantic(func.return_semantic);

	code += '\t';
	if (returns_color)
	{
		code += "const float4 " + id_to_name(result_id) + " = float4(";
	}
	else if (!func.return_type.is_void())
	{
		write_type(code, func.return_type);
		code += ' ' + id_to_name(result_id) + " = ";
	}

	// Call the function this entry point refers to, narrowing widened color parameters back to their declared size
	code += id_to_name(func.id);
	code += '(';
	for (size_t i = 0; i < func.parameter_list.size(); ++i)
	{
		if (i != 0)
			code += ", ";
		code += id_to_name(entry_point.parameter_list[i].id);

		const member_type &original_param = func.parameter_list[i];
		if (is_color_semantic(original_param.semantic))
		{
			code += '.';
			code.append("xyzw", original_param.type.rows);
		}
	}
	code += ')';

	// Pad the returned value up to a four-component vector
	if (returns_color)
	{
		for (unsigned int c = 0; c < (4 - func.return_type.rows); c++)
			code += ", 0.0";
		code += ')';
	}

	code += ";\n";

	// Shift everything by half a viewport pixel to workaround the different half-pixel offset in D3D9 (https://aras-p.info/blog/2016/04/08/solving-dx9-half-pixel-offset/)
	if (is_vertex && !position_variable_name.empty())
		code += '\t' + position_variable_name + ".xy += __TEXEL_SIZE__ * " + position_variable_name + ".ww;\n";

	leave_block_and_return(func.return_type.is_void() ? 0 : result_id);
	leave_function();
}
1329
1330
id emit_load(const expression &exp, bool force_new_id) override
{
	if (exp.is_constant)
		return emit_constant(exp.type, exp.constant);
	// Can refer to values without access chain directly, unless a distinct id was requested
	if (exp.chain.empty() && !force_new_id)
		return exp.base;

	const id res = make_id();

	static const char s_matrix_swizzles[16][5] = {
		"_m00", "_m01", "_m02", "_m03",
		"_m10", "_m11", "_m12", "_m13",
		"_m20", "_m21", "_m22", "_m23",
		"_m30", "_m31", "_m32", "_m33"
	};

	// Apply the access chain to the base expression, one operation at a time
	std::string access = id_to_name(exp.base);
	for (const expression::operation &op : exp.chain)
	{
		switch (op.op)
		{
		case expression::operation::op_cast:
		{
			std::string cast_type;
			write_type<false, false>(cast_type, op.to);
			// Cast is in parentheses so that a subsequent operation operates on the casted value
			access = "((" + cast_type + ')' + access + ')';
			break;
		}
		case expression::operation::op_member:
			access += '.';
			access += get_struct(op.from.struct_definition).member_list[op.index].name;
			break;
		case expression::operation::op_dynamic_index:
			access += '[' + id_to_name(op.index) + ']';
			break;
		case expression::operation::op_constant_index:
			if (op.from.is_vector() && !op.from.is_array())
			{
				access += '.';
				access += "xyzw"[op.index];
			}
			else
			{
				access += '[' + std::to_string(op.index) + ']';
			}
			break;
		case expression::operation::op_swizzle:
			access += '.';
			for (int i = 0; i < 4 && op.swizzle[i] >= 0; ++i)
			{
				if (op.from.is_matrix())
					access += s_matrix_swizzles[op.swizzle[i]];
				else
					access += "xyzw"[op.swizzle[i]];
			}
			break;
		}
	}

	if (!force_new_id)
	{
		// Avoid excessive variable definitions by instancing simple load operations in code every time
		define_name<naming::expression>(res, std::move(access));
		return res;
	}

	// Need to store value in a new variable to comply with request for a new ID
	std::string &code = _blocks.at(_current_block);
	code += '\t';
	write_type(code, exp.type);
	code += ' ' + id_to_name(res) + " = " + access + ";\n";

	return res;
}
1400
void emit_store(const expression &exp, id value) override
1401
{
1402
std::string &code = _blocks.at(_current_block);
1403
1404
write_location(code, exp.location);
1405
1406
code += '\t' + id_to_name(exp.base);
1407
1408
static const char s_matrix_swizzles[16][5] = {
1409
"_m00", "_m01", "_m02", "_m03",
1410
"_m10", "_m11", "_m12", "_m13",
1411
"_m20", "_m21", "_m22", "_m23",
1412
"_m30", "_m31", "_m32", "_m33"
1413
};
1414
1415
for (const expression::operation &op : exp.chain)
1416
{
1417
switch (op.op)
1418
{
1419
case expression::operation::op_member:
1420
code += '.';
1421
code += get_struct(op.from.struct_definition).member_list[op.index].name;
1422
break;
1423
case expression::operation::op_dynamic_index:
1424
code += '[' + id_to_name(op.index) + ']';
1425
break;
1426
case expression::operation::op_constant_index:
1427
code += '[' + std::to_string(op.index) + ']';
1428
break;
1429
case expression::operation::op_swizzle:
1430
code += '.';
1431
for (int i = 0; i < 4 && op.swizzle[i] >= 0; ++i)
1432
if (op.from.is_matrix())
1433
code += s_matrix_swizzles[op.swizzle[i]];
1434
else
1435
code += "xyzw"[op.swizzle[i]];
1436
break;
1437
}
1438
}
1439
1440
code += " = " + id_to_name(value) + ";\n";
1441
}
1442
1443
id emit_constant(const type &data_type, const constant &data) override
{
	const id res = make_id();

	// Non-array constants are written in-place as literal expressions wherever they are referenced
	if (!data_type.is_array())
	{
		std::string literal;
		write_constant(literal, data_type, data);
		define_name<naming::expression>(res, std::move(literal));
		return res;
	}

	assert(data_type.has(type::q_const));

	// Two array constants are identical when type, scalar data and every array element match bit-for-bit
	const auto matches_constant = [&data_type, &data](const std::tuple<type, constant, id> &entry) {
		const constant &candidate = std::get<1>(entry);
		if (!(std::get<0>(entry) == data_type))
			return false;
		if (std::memcmp(&candidate.as_uint[0], &data.as_uint[0], sizeof(uint32_t) * 16) != 0)
			return false;
		if (candidate.array_data.size() != data.array_data.size())
			return false;
		for (size_t i = 0; i < data.array_data.size(); ++i)
			if (std::memcmp(&candidate.array_data[i].as_uint[0], &data.array_data[i].as_uint[0], sizeof(uint32_t) * 16) != 0)
				return false;
		return true;
	};

	// Reuse existing constant instead of duplicating the definition
	const auto existing = std::find_if(_constant_lookup.begin(), _constant_lookup.end(), matches_constant);
	if (existing != _constant_lookup.end())
		return std::get<2>(*existing);

	_constant_lookup.push_back({ data_type, data, res });

	// Array constants cannot be used in-place, so they become a static const variable in global scope (block 0) where every block can reuse them
	std::string &code = _blocks.at(0);

	code += "static const ";
	write_type<false, false>(code, data_type);
	code += ' ' + id_to_name(res) + '[' + std::to_string(data_type.array_length) + "] = ";
	write_constant(code, data_type, data);
	code += ";\n";

	return res;
}
1485
1486
id emit_unary_op(const location &loc, tokenid op, const type &res_type, id val) override
{
	const id res = make_id();

	std::string &code = _blocks.at(_current_block);

	write_location(code, loc);

	code += '\t';
	write_type(code, res_type);
	code += ' ' + id_to_name(res) + " = ";

	// Single-character operators are emitted from their token value; bitwise not is emulated on shader model 3
	const bool emulate_bitwise_not = _shader_model < 40 && op == tokenid::tilde;
	if (emulate_bitwise_not)
		code += "0xFFFFFFFF - ";
	else
		code += static_cast<char>(op);

	code += id_to_name(val);
	code += ";\n";

	return res;
}
1507
id emit_binary_op(const location &loc, tokenid op, const type &res_type, const type &, id lhs, id rhs) override
{
	const id res = make_id();

	std::string &code = _blocks.at(_current_block);

	write_location(code, loc);

	code += '\t';
	write_type(code, res_type);
	code += ' ' + id_to_name(res) + " = ";

	// Shader model 3 has no integer shift operators, so they are emulated with floating-point math:
	//   a << b  ->  (a ) * exp2( b)
	//   a >> b  ->  floor((a ) / exp2( b))
	// The floor must wrap the whole division (a right shift is floor(a / 2^b)), not just the left operand.
	const bool is_left_shift = op == tokenid::less_less || op == tokenid::less_less_equal;
	const bool is_right_shift = op == tokenid::greater_greater || op == tokenid::greater_greater_equal;
	const bool emulate_shift = _shader_model < 40 && (is_left_shift || is_right_shift);

	if (emulate_shift)
		code += is_right_shift ? "floor((" : "(";

	code += id_to_name(lhs) + ' ';

	switch (op)
	{
	case tokenid::plus:
	case tokenid::plus_plus:
	case tokenid::plus_equal:
		code += '+';
		break;
	case tokenid::minus:
	case tokenid::minus_minus:
	case tokenid::minus_equal:
		code += '-';
		break;
	case tokenid::star:
	case tokenid::star_equal:
		code += '*';
		break;
	case tokenid::slash:
	case tokenid::slash_equal:
		code += '/';
		break;
	case tokenid::percent:
	case tokenid::percent_equal:
		code += '%';
		break;
	case tokenid::caret:
	case tokenid::caret_equal:
		code += '^';
		break;
	case tokenid::pipe:
	case tokenid::pipe_equal:
		code += '|';
		break;
	case tokenid::ampersand:
	case tokenid::ampersand_equal:
		code += '&';
		break;
	case tokenid::less_less:
	case tokenid::less_less_equal:
		code += _shader_model >= 40 ? "<<" : ") * exp2("; // Emulate bitwise shift operators on shader model 3
		break;
	case tokenid::greater_greater:
	case tokenid::greater_greater_equal:
		code += _shader_model >= 40 ? ">>" : ") / exp2(";
		break;
	case tokenid::pipe_pipe:
		code += "||";
		break;
	case tokenid::ampersand_ampersand:
		code += "&&";
		break;
	case tokenid::less:
		code += '<';
		break;
	case tokenid::less_equal:
		code += "<=";
		break;
	case tokenid::greater:
		code += '>';
		break;
	case tokenid::greater_equal:
		code += ">=";
		break;
	case tokenid::equal_equal:
		code += "==";
		break;
	case tokenid::exclaim_equal:
		code += "!=";
		break;
	default:
		assert(false);
	}

	code += ' ' + id_to_name(rhs);

	// Close the parentheses opened by the shift emulation above ("exp2(" plus, for right shifts, "floor(")
	if (emulate_shift)
		code += is_right_shift ? "))" : ")";

	code += ";\n";

	return res;
}
1616
id emit_ternary_op(const location &loc, tokenid op, const type &res_type, id condition, id true_value, id false_value) override
{
	// Should never happen, since '?:' is the only ternary operator currently supported
	if (op != tokenid::question)
	{
		assert(false);
		return 0;
	}

	const id res = make_id();

	std::string &code = _blocks.at(_current_block);

	write_location(code, loc);

	code += '\t';
	write_type(code, res_type);
	code += ' ';
	code += id_to_name(res);

	if (res_type.is_array())
		code += '[' + std::to_string(res_type.array_length) + ']';

	code += " = " + id_to_name(condition);
	code += " ? " + id_to_name(true_value);
	code += " : " + id_to_name(false_value);
	code += ";\n";

	return res;
}
1638
id emit_call(const location &loc, id function, const type &res_type, const std::vector<expression> &args) override
{
	// Arguments must already be loaded into plain values (no pending access chain)
	assert(std::all_of(args.begin(), args.end(), [](const expression &arg) { return arg.chain.empty() && arg.base != 0; }));

	const id res = make_id();

	std::string &code = _blocks.at(_current_block);

	write_location(code, loc);

	code += '\t';

	// Only non-void calls store their result in a new variable
	if (!res_type.is_void())
	{
		write_type(code, res_type);
		code += ' ' + id_to_name(res);

		if (res_type.is_array())
			code += '[' + std::to_string(res_type.array_length) + ']';

		code += " = ";
	}

	code += id_to_name(function);
	code += '(';
	for (size_t i = 0; i < args.size(); ++i)
	{
		if (i != 0)
			code += ", ";
		code += id_to_name(args[i].base);
	}
	code += ");\n";

	return res;
}
1680
id emit_call_intrinsic(const location &loc, id intrinsic, const type &res_type, const std::vector<expression> &args) override
1681
{
1682
// Emits a call to a built-in intrinsic. The HLSL text for each intrinsic comes from the X-macro list in
// "effect_symbol_table_intrinsics.inl", which is included twice below with different macro definitions.
// Arguments must already be plain values (no pending access chain).
#ifndef NDEBUG
1683
for (const expression &arg : args)
1684
assert(arg.chain.empty() && arg.base != 0);
1685
#endif
1686
1687
const id res = make_id();
1688
1689
std::string &code = _blocks.at(_current_block);
1690
1691
// First expansion: build an enumerator name##i for every intrinsic overload, so 'intrinsic' can be switched on below.
// NOTE(review): the macro is redefined further down, so the .inl presumably #undefs it at its end — confirm.
enum
1692
{
1693
#define IMPLEMENT_INTRINSIC_HLSL(name, i, code) name##i,
1694
#include "effect_symbol_table_intrinsics.inl"
1695
};
1696
1697
write_location(code, loc);
1698
1699
code += '\t';
1700
1701
// Non-void intrinsics assign their result to a new variable
if (!res_type.is_void())
1702
{
1703
write_type(code, res_type);
1704
code += ' ' + id_to_name(res) + " = ";
1705
}
1706
1707
// Second expansion: each case pastes the intrinsic's emission snippet (the 'code' macro argument).
// Presumably those snippets append to the local 'code' string using 'args' — verify in the .inl.
switch (intrinsic)
1708
{
1709
#define IMPLEMENT_INTRINSIC_HLSL(name, i, code) case name##i: code break;
1710
#include "effect_symbol_table_intrinsics.inl"
1711
default:
1712
assert(false);
1713
}
1714
1715
code += ";\n";
1716
1717
return res;
1718
}
1719
id emit_construct(const location &loc, const type &res_type, const std::vector<expression> &args) override
{
	// Arguments are plain scalar values (or whole elements when constructing an array)
	assert(std::all_of(args.begin(), args.end(), [&res_type](const expression &arg) {
		return (arg.type.is_scalar() || res_type.is_array()) && arg.chain.empty() && arg.base != 0;
	}));

	const id res = make_id();

	std::string &code = _blocks.at(_current_block);

	write_location(code, loc);

	code += '\t';
	write_type(code, res_type);
	code += ' ' + id_to_name(res);

	// Arrays use an initializer list, everything else a constructor call like "float3(...)"
	const bool is_array = res_type.is_array();
	if (is_array)
	{
		code += '[' + std::to_string(res_type.array_length) + ']';
		code += " = { ";
	}
	else
	{
		code += " = ";
		write_type<false, false>(code, res_type);
		code += '(';
	}

	for (size_t i = 0; i < args.size(); ++i)
	{
		if (i != 0)
			code += ", ";
		code += id_to_name(args[i].base);
	}

	code += is_array ? " }" : ")";
	code += ";\n";

	return res;
}
1765
1766
void emit_if(const location &loc, id condition_value, id condition_block, id true_statement_block, id false_statement_block, unsigned int flags) override
1767
{
1768
assert(condition_value != 0 && condition_block != 0 && true_statement_block != 0 && false_statement_block != 0);
1769
1770
std::string &code = _blocks.at(_current_block);
1771
1772
std::string &true_statement_data = _blocks.at(true_statement_block);
1773
std::string &false_statement_data = _blocks.at(false_statement_block);
1774
1775
increase_indentation_level(true_statement_data);
1776
increase_indentation_level(false_statement_data);
1777
1778
code += _blocks.at(condition_block);
1779
1780
write_location(code, loc);
1781
1782
code += '\t';
1783
1784
if (flags & 0x1) code += "[flatten] ";
1785
if (flags & 0x2) code += "[branch] ";
1786
1787
code += "if (" + id_to_name(condition_value) + ")\n\t{\n";
1788
code += true_statement_data;
1789
code += "\t}\n";
1790
1791
if (!false_statement_data.empty())
1792
{
1793
code += "\telse\n\t{\n";
1794
code += false_statement_data;
1795
code += "\t}\n";
1796
}
1797
1798
// Remove consumed blocks to save memory
1799
_blocks.erase(condition_block);
1800
_blocks.erase(true_statement_block);
1801
_blocks.erase(false_statement_block);
1802
}
1803
id emit_phi(const location &loc, id condition_value, id condition_block, id true_value, id true_statement_block, id false_value, id false_statement_block, const type &res_type) override
{
	assert(condition_value != 0 && condition_block != 0 && true_value != 0 && true_statement_block != 0 && false_value != 0 && false_statement_block != 0);

	std::string &code = _blocks.at(_current_block);

	// A statement block may alias the condition block (no code of its own on that side). Its text is emitted
	// once above the branch as condition code, so it must neither be spliced into the branch nor re-indented.
	const bool has_true_statements = true_statement_block != condition_block;
	const bool has_false_statements = false_statement_block != condition_block;

	std::string &true_statement_data = _blocks.at(true_statement_block);
	std::string &false_statement_data = _blocks.at(false_statement_block);

	// Indent each spliced block exactly once (guard against both sides sharing a block)
	if (has_true_statements)
		increase_indentation_level(true_statement_data);
	if (has_false_statements && false_statement_block != true_statement_block)
		increase_indentation_level(false_statement_data);

	const id res = make_id();

	code += _blocks.at(condition_block);

	// Declare the result variable up front; each branch assigns it
	code += '\t';
	write_type(code, res_type);
	code += ' ' + id_to_name(res) + ";\n";

	write_location(code, loc);

	code += "\tif (" + id_to_name(condition_value) + ")\n\t{\n";
	if (has_true_statements)
		code += true_statement_data;
	code += "\t\t" + id_to_name(res) + " = " + id_to_name(true_value) + ";\n";
	code += "\t}\n\telse\n\t{\n";
	if (has_false_statements)
		code += false_statement_data;
	code += "\t\t" + id_to_name(res) + " = " + id_to_name(false_value) + ";\n";
	code += "\t}\n";

	// Remove consumed blocks to save memory
	_blocks.erase(condition_block);
	_blocks.erase(true_statement_block);
	_blocks.erase(false_statement_block);

	return res;
}
1840
void emit_loop(const location &loc, id condition_value, id prev_block, id header_block, id condition_block, id loop_block, id continue_block, unsigned int flags) override
1841
{
1842
assert(prev_block != 0 && header_block != 0 && loop_block != 0 && continue_block != 0);
1843
1844
std::string &code = _blocks.at(_current_block);
1845
1846
std::string &loop_data = _blocks.at(loop_block);
1847
std::string &continue_data = _blocks.at(continue_block);
1848
1849
increase_indentation_level(loop_data);
1850
increase_indentation_level(loop_data);
1851
increase_indentation_level(continue_data);
1852
1853
code += _blocks.at(prev_block);
1854
1855
std::string attributes;
1856
if (flags & 0x1)
1857
attributes += "[unroll] ";
1858
if (flags & 0x2)
1859
attributes += _shader_model >= 40 ? "[fastopt] " : "[loop] ";
1860
1861
// Condition value can be missing in infinite loop constructs like "for (;;)"
1862
std::string condition_name = condition_value != 0 ? id_to_name(condition_value) : "true";
1863
1864
if (condition_block == 0)
1865
{
1866
// Convert the last SSA variable initializer to an assignment statement
1867
const size_t pos_assign = continue_data.rfind(condition_name);
1868
const size_t pos_prev_assign = continue_data.rfind('\t', pos_assign);
1869
continue_data.erase(pos_prev_assign + 1, pos_assign - pos_prev_assign - 1);
1870
1871
// We need to add the continue block to all "continue" statements as well
1872
const std::string continue_id = "__CONTINUE__" + std::to_string(continue_block);
1873
for (size_t offset = 0; (offset = loop_data.find(continue_id, offset)) != std::string::npos; offset += continue_data.size())
1874
loop_data.replace(offset, continue_id.size(), continue_data);
1875
1876
code += "\tbool " + condition_name + ";\n";
1877
1878
write_location(code, loc);
1879
1880
code += '\t' + attributes;
1881
code += "do\n\t{\n\t\t{\n";
1882
code += loop_data; // Encapsulate loop body into another scope, so not to confuse any local variables with the current iteration variable accessed in the continue block below
1883
code += "\t\t}\n";
1884
code += continue_data;
1885
code += "\t}\n\twhile (" + condition_name + ");\n";
1886
}
1887
else
1888
{
1889
std::string &condition_data = _blocks.at(condition_block);
1890
1891
// Work around D3DCompiler putting uniform variables that are used as the loop count register into integer registers (only in SM3)
1892
// Only applies to dynamic loops with uniform variables in the condition, where it generates a loop instruction like "rep i0", but then expects the "i0" register to be set externally
1893
// Moving the loop condition into the loop body forces it to move the uniform variable into a constant register instead and geneates a fixed number of loop iterations with "defi i0, 255, ..."
1894
// Check 'condition_name' instead of 'condition_value' here to also catch cases where a constant boolean expression was passed in as loop condition
1895
bool use_break_statement_for_condition = (_shader_model < 40 && condition_name != "true") &&
1896
std::find_if(_module.uniforms.begin(), _module.uniforms.end(),
1897
[&](const uniform &info) {
1898
return condition_data.find(info.name) != std::string::npos || condition_name.find(info.name) != std::string::npos;
1899
}) != _module.uniforms.end();
1900
1901
// If the condition data is just a single line, then it is a simple expression, which we can just put into the loop condition as-is
1902
if (!use_break_statement_for_condition && std::count(condition_data.begin(), condition_data.end(), '\n') == 1)
1903
{
1904
// Convert SSA variable initializer back to a condition expression
1905
const size_t pos_assign = condition_data.find('=');
1906
condition_data.erase(0, pos_assign + 2);
1907
const size_t pos_semicolon = condition_data.rfind(';');
1908
condition_data.erase(pos_semicolon);
1909
1910
condition_name = std::move(condition_data);
1911
assert(condition_data.empty());
1912
}
1913
else
1914
{
1915
code += condition_data;
1916
1917
increase_indentation_level(condition_data);
1918
1919
// Convert the last SSA variable initializer to an assignment statement
1920
const size_t pos_assign = condition_data.rfind(condition_name);
1921
const size_t pos_prev_assign = condition_data.rfind('\t', pos_assign);
1922
condition_data.erase(pos_prev_assign + 1, pos_assign - pos_prev_assign - 1);
1923
}
1924
1925
const std::string continue_id = "__CONTINUE__" + std::to_string(continue_block);
1926
for (size_t offset = 0; (offset = loop_data.find(continue_id, offset)) != std::string::npos; offset += continue_data.size())
1927
loop_data.replace(offset, continue_id.size(), continue_data + condition_data);
1928
1929
write_location(code, loc);
1930
1931
code += '\t' + attributes;
1932
if (use_break_statement_for_condition)
1933
code += "while (true)\n\t{\n\t\tif (" + condition_name + ")\n\t\t{\n";
1934
else
1935
code += "while (" + condition_name + ")\n\t{\n\t\t{\n";
1936
code += loop_data;
1937
code += "\t\t}\n";
1938
if (use_break_statement_for_condition)
1939
code += "\t\telse break;\n";
1940
code += continue_data;
1941
code += condition_data;
1942
code += "\t}\n";
1943
1944
_blocks.erase(condition_block);
1945
}
1946
1947
// Remove consumed blocks to save memory
1948
_blocks.erase(prev_block);
1949
_blocks.erase(header_block);
1950
_blocks.erase(loop_block);
1951
_blocks.erase(continue_block);
1952
}
1953
void emit_switch(const location &loc, id selector_value, id selector_block, id default_label, id default_block, const std::vector<id> &case_literal_and_labels, const std::vector<id> &case_blocks, unsigned int flags) override
1954
{
1955
assert(selector_value != 0 && selector_block != 0 && default_label != 0 && default_block != 0);
1956
assert(case_blocks.size() == case_literal_and_labels.size() / 2);
1957
1958
std::string &code = _blocks.at(_current_block);
1959
1960
code += _blocks.at(selector_block);
1961
1962
if (_shader_model >= 40)
1963
{
1964
write_location(code, loc);
1965
1966
code += '\t';
1967
1968
if (flags & 0x1) code += "[flatten] ";
1969
if (flags & 0x2) code += "[branch] ";
1970
if (flags & 0x4) code += "[forcecase] ";
1971
if (flags & 0x8) code += "[call] ";
1972
1973
code += "switch (" + id_to_name(selector_value) + ")\n\t{\n";
1974
1975
std::vector<id> labels = case_literal_and_labels;
1976
for (size_t i = 0; i < labels.size(); i += 2)
1977
{
1978
if (labels[i + 1] == 0)
1979
continue; // Happens if a case was already handled, see below
1980
1981
code += "\tcase " + std::to_string(labels[i]) + ": ";
1982
1983
if (labels[i + 1] == default_label)
1984
{
1985
code += "default: ";
1986
default_label = 0;
1987
}
1988
else
1989
{
1990
for (size_t k = i + 2; k < labels.size(); k += 2)
1991
{
1992
if (labels[k + 1] == 0 || labels[k + 1] != labels[i + 1])
1993
continue;
1994
1995
code += "case " + std::to_string(labels[k]) + ": ";
1996
labels[k + 1] = 0;
1997
}
1998
}
1999
2000
assert(case_blocks[i / 2] != 0);
2001
std::string &case_data = _blocks.at(case_blocks[i / 2]);
2002
2003
increase_indentation_level(case_data);
2004
2005
code += "{\n";
2006
code += case_data;
2007
code += "\t}\n";
2008
}
2009
2010
if (default_label != 0 && default_block != _current_block)
2011
{
2012
std::string &default_data = _blocks.at(default_block);
2013
2014
increase_indentation_level(default_data);
2015
2016
code += "\tdefault: {\n";
2017
code += default_data;
2018
code += "\t}\n";
2019
2020
_blocks.erase(default_block);
2021
}
2022
2023
code += "\t}\n";
2024
}
2025
else // Switch statements do not work correctly in SM3 if a constant is used as selector value (this is a D3DCompiler bug), so replace them with if statements
2026
{
2027
write_location(code, loc);
2028
2029
code += "\t[unroll] do { "; // This dummy loop makes "break" statements work
2030
2031
if (flags & 0x1) code += "[flatten] ";
2032
if (flags & 0x2) code += "[branch] ";
2033
2034
std::vector<id> labels = case_literal_and_labels;
2035
for (size_t i = 0; i < labels.size(); i += 2)
2036
{
2037
if (labels[i + 1] == 0)
2038
continue; // Happens if a case was already handled, see below
2039
2040
code += "if (" + id_to_name(selector_value) + " == " + std::to_string(labels[i]);
2041
2042
for (size_t k = i + 2; k < labels.size(); k += 2)
2043
{
2044
if (labels[k + 1] == 0 || labels[k + 1] != labels[i + 1])
2045
continue;
2046
2047
code += " || " + id_to_name(selector_value) + " == " + std::to_string(labels[k]);
2048
labels[k + 1] = 0;
2049
}
2050
2051
assert(case_blocks[i / 2] != 0);
2052
std::string &case_data = _blocks.at(case_blocks[i / 2]);
2053
2054
increase_indentation_level(case_data);
2055
2056
code += ")\n\t{\n";
2057
code += case_data;
2058
code += "\t}\n\telse\n\t";
2059
}
2060
2061
code += "{\n";
2062
2063
if (default_block != _current_block)
2064
{
2065
std::string &default_data = _blocks.at(default_block);
2066
2067
increase_indentation_level(default_data);
2068
2069
code += default_data;
2070
2071
_blocks.erase(default_block);
2072
}
2073
2074
code += "\t} } while (false);\n";
2075
}
2076
2077
// Remove consumed blocks to save memory
2078
_blocks.erase(selector_block);
2079
for (const id case_block : case_blocks)
2080
_blocks.erase(case_block);
2081
}
2082
2083
id create_block() override
{
	// Allocates a new block id and registers an empty code buffer for it.
	// The buffer is pre-sized so that appending generated code rarely reallocates.
	const id new_block = make_id();

	const auto insertion = _blocks.emplace(new_block, std::string());
	insertion.first->second.reserve(4096);

	return new_block;
}
2093
id set_block(id target) override
{
	// Makes 'target' the current block. The previously current block is remembered in _last_block and also returned.
	const id previous = _current_block;

	_current_block = target;
	_last_block = previous;

	return previous;
}
2100
void enter_block(id target) override
{
	// Redirects code emission to 'target'. Unlike set_block, _last_block is left untouched.
	_current_block = target;
}
2104
id leave_block_and_kill() override
{
	// Terminates the current block with a "discard" statement and leaves it.
	// Returns 0 when no block is active, otherwise the block that was left.
	if (!is_in_block())
		return 0;

	std::string &out = _blocks.at(_current_block);
	out += "\tdiscard;\n";

	// HLSL compiler doesn't handle discard like a shader kill, so a function returning a value also gets
	// "return <default constant>;" in case discard is the last control flow statement.
	// See https://docs.microsoft.com/windows/win32/direct3dhlsl/discard--sm4---asm-
	if (const type &ret_type = _current_function->return_type; !ret_type.is_void())
	{
		out += "\treturn ";
		write_constant(out, ret_type, constant());
		out += ";\n";
	}

	return set_block(0);
}
2126
id leave_block_and_return(id value) override
{
	// Terminates the current block with a return statement (optionally returning 'value') and leaves it.
	// Returns 0 when no block is active, otherwise the block that was left.
	if (!is_in_block())
		return 0;

	// A non-void function with no value is an implicit return: no statement is written for it
	const bool implicit_return = !_current_function->return_type.is_void() && value == 0;
	if (!implicit_return)
	{
		std::string &out = _blocks.at(_current_block);

		if (value == 0)
			out += "\treturn;\n";
		else
			out += "\treturn " + id_to_name(value) + ";\n";
	}

	return set_block(0);
}
2146
id leave_block_and_switch(id, id) override
{
	// No code is written here because the switch statement itself is generated separately (see emit_switch).
	// If a block is active it is left and returned; otherwise the last left block is returned.
	return is_in_block() ? set_block(0) : _last_block;
}
2153
id leave_block_and_branch(id target, unsigned int loop_flow) override
{
	// Leaves the current block with an unconditional branch.
	//   loop_flow == 1 -> "break", loop_flow == 2 -> "continue"; any other value writes nothing.
	if (!is_in_block())
		return _last_block;

	std::string &out = _blocks.at(_current_block);

	if (loop_flow == 1)
	{
		out += "\tbreak;\n";
	}
	else if (loop_flow == 2)
	{
		// Placeholder naming the continue target block. emit_loop later replaces "__CONTINUE__<id>" with that block's code.
		out += "__CONTINUE__" + std::to_string(target) + "\tcontinue;\n";
	}

	return set_block(0);
}
2172
id leave_block_and_branch_conditional(id, id, id) override
{
	// No code is written for conditional branches here. An active block is left and returned;
	// otherwise the last left block is returned.
	return is_in_block() ? set_block(0) : _last_block;
}
2179
void leave_function() override
2180
{
2181
assert(_current_function != nullptr && _last_block != 0);
2182
2183
_blocks.emplace(_current_function->id, _current_function_declaration + "{\n" + _blocks.at(_last_block) + "}\n");
2184
2185
_current_function = nullptr;
2186
_current_function_declaration.clear();
2187
}
2188
};
2189
2190
codegen *reshadefx::create_codegen_hlsl(unsigned int shader_model, bool debug_info, bool uniforms_to_spec_constants)
{
	// Factory for the HLSL backend. The generator is allocated with new, so the caller is responsible for deleting it.
	codegen *const generator = new codegen_hlsl(shader_model, debug_info, uniforms_to_spec_constants);
	return generator;
}
2194
2195