Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
PojavLauncherTeam
GitHub Repository: PojavLauncherTeam/mesa
Path: blob/21.2-virgl/src/compiler/nir/nir_conversion_builder.h
4545 views
1
/*
2
* Copyright © 2020 Collabora Ltd.
3
*
4
* Permission is hereby granted, free of charge, to any person obtaining a
5
* copy of this software and associated documentation files (the "Software"),
6
* to deal in the Software without restriction, including without limitation
7
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
* and/or sell copies of the Software, and to permit persons to whom the
9
* Software is furnished to do so, subject to the following conditions:
10
*
11
* The above copyright notice and this permission notice (including the next
12
* paragraph) shall be included in all copies or substantial portions of the
13
* Software.
14
*
15
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21
* IN THE SOFTWARE.
22
*/
23
24
#ifndef NIR_CONVERSION_BUILDER_H
25
#define NIR_CONVERSION_BUILDER_H
26
27
#include "util/u_math.h"
28
#include "nir_builder.h"
29
#include "nir_builtin_builder.h"
30
31
#ifdef __cplusplus
32
extern "C" {
33
#endif
34
35
static inline nir_ssa_def *
36
nir_round_float_to_int(nir_builder *b, nir_ssa_def *src,
37
nir_rounding_mode round)
38
{
39
switch (round) {
40
case nir_rounding_mode_ru:
41
return nir_fceil(b, src);
42
43
case nir_rounding_mode_rd:
44
return nir_ffloor(b, src);
45
46
case nir_rounding_mode_rtne:
47
return nir_fround_even(b, src);
48
49
case nir_rounding_mode_undef:
50
case nir_rounding_mode_rtz:
51
break;
52
}
53
unreachable("unexpected rounding mode");
54
}
55
56
static inline nir_ssa_def *
57
nir_round_float_to_float(nir_builder *b, nir_ssa_def *src,
58
unsigned dest_bit_size,
59
nir_rounding_mode round)
60
{
61
unsigned src_bit_size = src->bit_size;
62
if (dest_bit_size > src_bit_size)
63
return src; /* No rounding is needed for an up-convert */
64
65
nir_op low_conv = nir_type_conversion_op(nir_type_float | src_bit_size,
66
nir_type_float | dest_bit_size,
67
nir_rounding_mode_undef);
68
nir_op high_conv = nir_type_conversion_op(nir_type_float | dest_bit_size,
69
nir_type_float | src_bit_size,
70
nir_rounding_mode_undef);
71
72
switch (round) {
73
case nir_rounding_mode_ru: {
74
/* If lower-precision conversion results in a lower value, push it
75
* up one ULP. */
76
nir_ssa_def *lower_prec =
77
nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
78
nir_ssa_def *roundtrip =
79
nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
80
nir_ssa_def *cmp = nir_flt(b, roundtrip, src);
81
nir_ssa_def *inf = nir_imm_floatN_t(b, INFINITY, dest_bit_size);
82
return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, inf), lower_prec);
83
}
84
case nir_rounding_mode_rd: {
85
/* If lower-precision conversion results in a higher value, push it
86
* down one ULP. */
87
nir_ssa_def *lower_prec =
88
nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
89
nir_ssa_def *roundtrip =
90
nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
91
nir_ssa_def *cmp = nir_flt(b, src, roundtrip);
92
nir_ssa_def *neg_inf = nir_imm_floatN_t(b, -INFINITY, dest_bit_size);
93
return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, neg_inf), lower_prec);
94
}
95
case nir_rounding_mode_rtz:
96
return nir_bcsel(b, nir_flt(b, src, nir_imm_zero(b, 1, src->bit_size)),
97
nir_round_float_to_float(b, src, dest_bit_size,
98
nir_rounding_mode_ru),
99
nir_round_float_to_float(b, src, dest_bit_size,
100
nir_rounding_mode_rd));
101
case nir_rounding_mode_rtne:
102
case nir_rounding_mode_undef:
103
break;
104
}
105
unreachable("unexpected rounding mode");
106
}
107
108
static inline nir_ssa_def *
109
nir_round_int_to_float(nir_builder *b, nir_ssa_def *src,
110
nir_alu_type src_type,
111
unsigned dest_bit_size,
112
nir_rounding_mode round)
113
{
114
/* We only care whether or not its signed */
115
src_type = nir_alu_type_get_base_type(src_type);
116
117
unsigned mantissa_bits;
118
switch (dest_bit_size) {
119
case 16:
120
mantissa_bits = 10;
121
break;
122
case 32:
123
mantissa_bits = 23;
124
break;
125
case 64:
126
mantissa_bits = 52;
127
break;
128
default: unreachable("Unsupported bit size");
129
}
130
131
if (src->bit_size < mantissa_bits)
132
return src;
133
134
if (src_type == nir_type_int) {
135
nir_ssa_def *sign =
136
nir_i2b1(b, nir_ishr(b, src, nir_imm_int(b, src->bit_size - 1)));
137
nir_ssa_def *abs = nir_iabs(b, src);
138
nir_ssa_def *positive_rounded =
139
nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, round);
140
nir_ssa_def *max_positive =
141
nir_imm_intN_t(b, (1ull << (src->bit_size - 1)) - 1, src->bit_size);
142
switch (round) {
143
case nir_rounding_mode_rtz:
144
return nir_bcsel(b, sign, nir_ineg(b, positive_rounded),
145
positive_rounded);
146
break;
147
case nir_rounding_mode_ru:
148
return nir_bcsel(b, sign,
149
nir_ineg(b, nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_rd)),
150
nir_umin(b, positive_rounded, max_positive));
151
break;
152
case nir_rounding_mode_rd:
153
return nir_bcsel(b, sign,
154
nir_ineg(b,
155
nir_umin(b, max_positive,
156
nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_ru))),
157
positive_rounded);
158
case nir_rounding_mode_rtne:
159
case nir_rounding_mode_undef:
160
break;
161
}
162
unreachable("unexpected rounding mode");
163
} else {
164
nir_ssa_def *mantissa_bit_size = nir_imm_int(b, mantissa_bits);
165
nir_ssa_def *msb = nir_imax(b, nir_ufind_msb(b, src), mantissa_bit_size);
166
nir_ssa_def *bits_to_lose = nir_isub(b, msb, mantissa_bit_size);
167
nir_ssa_def *one = nir_imm_intN_t(b, 1, src->bit_size);
168
nir_ssa_def *adjust = nir_ishl(b, one, bits_to_lose);
169
nir_ssa_def *mask = nir_inot(b, nir_isub(b, adjust, one));
170
nir_ssa_def *truncated = nir_iand(b, src, mask);
171
switch (round) {
172
case nir_rounding_mode_rtz:
173
case nir_rounding_mode_rd:
174
return truncated;
175
break;
176
case nir_rounding_mode_ru:
177
return nir_bcsel(b, nir_ieq(b, src, truncated),
178
src, nir_uadd_sat(b, truncated, adjust));
179
case nir_rounding_mode_rtne:
180
case nir_rounding_mode_undef:
181
break;
182
}
183
unreachable("unexpected rounding mode");
184
}
185
}
186
187
/** Returns true if the representable range of a contains the representable
188
* range of b.
189
*/
190
static inline bool
191
nir_alu_type_range_contains_type_range(nir_alu_type a, nir_alu_type b)
192
{
193
/* Split types from bit sizes */
194
nir_alu_type a_base_type = nir_alu_type_get_base_type(a);
195
nir_alu_type b_base_type = nir_alu_type_get_base_type(b);
196
unsigned a_bit_size = nir_alu_type_get_type_size(a);
197
unsigned b_bit_size = nir_alu_type_get_type_size(b);
198
199
/* This requires sized types */
200
assert(a_bit_size > 0 && b_bit_size > 0);
201
202
if (a_base_type == b_base_type && a_bit_size >= b_bit_size)
203
return true;
204
205
if (a_base_type == nir_type_int && b_base_type == nir_type_uint &&
206
a_bit_size > b_bit_size)
207
return true;
208
209
/* 16-bit floats fit in 32-bit integers */
210
if (a_base_type == nir_type_int && a_bit_size >= 32 &&
211
b == nir_type_float16)
212
return true;
213
214
/* All signed or unsigned ints can fit in float or above. A uint8 can fit
215
* in a float16.
216
*/
217
if (a_base_type == nir_type_float && b_base_type != nir_type_float &&
218
(a_bit_size >= 32 || b_bit_size == 8))
219
return true;
220
221
return false;
222
}
223
224
/**
225
* Retrieves limits used for clamping a value of the src type into
226
* the widest representable range of the dst type via cmp + bcsel
227
*/
228
static inline void
229
nir_get_clamp_limits(nir_builder *b,
230
nir_alu_type src_type,
231
nir_alu_type dest_type,
232
nir_ssa_def **low, nir_ssa_def **high)
233
{
234
/* Split types from bit sizes */
235
nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
236
nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
237
unsigned src_bit_size = nir_alu_type_get_type_size(src_type);
238
unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
239
assert(dest_bit_size != 0 && src_bit_size != 0);
240
241
*low = NULL;
242
*high = NULL;
243
244
/* limits of the destination type, expressed in the source type */
245
switch (dest_base_type) {
246
case nir_type_int: {
247
int64_t ilow, ihigh;
248
if (dest_bit_size == 64) {
249
ilow = INT64_MIN;
250
ihigh = INT64_MAX;
251
} else {
252
ilow = -(1ll << (dest_bit_size - 1));
253
ihigh = (1ll << (dest_bit_size - 1)) - 1;
254
}
255
256
if (src_base_type == nir_type_int) {
257
*low = nir_imm_intN_t(b, ilow, src_bit_size);
258
*high = nir_imm_intN_t(b, ihigh, src_bit_size);
259
} else if (src_base_type == nir_type_uint) {
260
assert(src_bit_size >= dest_bit_size);
261
*high = nir_imm_intN_t(b, ihigh, src_bit_size);
262
} else {
263
*low = nir_imm_floatN_t(b, ilow, src_bit_size);
264
*high = nir_imm_floatN_t(b, ihigh, src_bit_size);
265
}
266
break;
267
}
268
case nir_type_uint: {
269
uint64_t uhigh = dest_bit_size == 64 ?
270
~0ull : (1ull << dest_bit_size) - 1;
271
if (src_base_type != nir_type_float) {
272
*low = nir_imm_intN_t(b, 0, src_bit_size);
273
if (src_base_type == nir_type_uint || src_bit_size > dest_bit_size)
274
*high = nir_imm_intN_t(b, uhigh, src_bit_size);
275
} else {
276
*low = nir_imm_floatN_t(b, 0.0f, src_bit_size);
277
*high = nir_imm_floatN_t(b, uhigh, src_bit_size);
278
}
279
break;
280
}
281
case nir_type_float: {
282
double flow, fhigh;
283
switch (dest_bit_size) {
284
case 16:
285
flow = -65504.0f;
286
fhigh = 65504.0f;
287
break;
288
case 32:
289
flow = -FLT_MAX;
290
fhigh = FLT_MAX;
291
break;
292
case 64:
293
flow = -DBL_MAX;
294
fhigh = DBL_MAX;
295
break;
296
default:
297
unreachable("Unhandled bit size");
298
}
299
300
switch (src_base_type) {
301
case nir_type_int: {
302
int64_t src_ilow, src_ihigh;
303
if (src_bit_size == 64) {
304
src_ilow = INT64_MIN;
305
src_ihigh = INT64_MAX;
306
} else {
307
src_ilow = -(1ll << (src_bit_size - 1));
308
src_ihigh = (1ll << (src_bit_size - 1)) - 1;
309
}
310
if (src_ilow < flow)
311
*low = nir_imm_intN_t(b, flow, src_bit_size);
312
if (src_ihigh > fhigh)
313
*high = nir_imm_intN_t(b, fhigh, src_bit_size);
314
break;
315
}
316
case nir_type_uint: {
317
uint64_t src_uhigh = src_bit_size == 64 ?
318
~0ull : (1ull << src_bit_size) - 1;
319
if (src_uhigh > fhigh)
320
*high = nir_imm_intN_t(b, fhigh, src_bit_size);
321
break;
322
}
323
case nir_type_float:
324
*low = nir_imm_floatN_t(b, flow, src_bit_size);
325
*high = nir_imm_floatN_t(b, fhigh, src_bit_size);
326
break;
327
default:
328
unreachable("Clamping from unknown type");
329
}
330
break;
331
}
332
default:
333
unreachable("clamping to unknown type");
334
break;
335
}
336
}
337
338
/**
339
* Clamp the value into the widest representatble range of the
340
* destination type with cmp + bcsel.
341
*
342
* val/val_type: The variables used for bcsel
343
* src/src_type: The variables used for comparison
344
* dest_type: The type which determines the range used for comparison
345
*/
346
static inline nir_ssa_def *
347
nir_clamp_to_type_range(nir_builder *b,
348
nir_ssa_def *val, nir_alu_type val_type,
349
nir_ssa_def *src, nir_alu_type src_type,
350
nir_alu_type dest_type)
351
{
352
assert(nir_alu_type_get_type_size(src_type) == 0 ||
353
nir_alu_type_get_type_size(src_type) == src->bit_size);
354
src_type |= src->bit_size;
355
if (nir_alu_type_range_contains_type_range(dest_type, src_type))
356
return val;
357
358
/* limits of the destination type, expressed in the source type */
359
nir_ssa_def *low = NULL, *high = NULL;
360
nir_get_clamp_limits(b, src_type, dest_type, &low, &high);
361
362
nir_ssa_def *low_cond = NULL, *high_cond = NULL;
363
switch (nir_alu_type_get_base_type(src_type)) {
364
case nir_type_int:
365
low_cond = low ? nir_ilt(b, src, low) : NULL;
366
high_cond = high ? nir_ilt(b, high, src) : NULL;
367
break;
368
case nir_type_uint:
369
low_cond = low ? nir_ult(b, src, low) : NULL;
370
high_cond = high ? nir_ult(b, high, src) : NULL;
371
break;
372
case nir_type_float:
373
low_cond = low ? nir_fge(b, low, src) : NULL;
374
high_cond = high ? nir_fge(b, src, high) : NULL;
375
break;
376
default:
377
unreachable("clamping from unknown type");
378
}
379
380
nir_ssa_def *val_low = low, *val_high = high;
381
if (val_type != src_type) {
382
nir_get_clamp_limits(b, val_type, dest_type, &val_low, &val_high);
383
}
384
385
nir_ssa_def *res = val;
386
if (low_cond && val_low)
387
res = nir_bcsel(b, low_cond, val_low, res);
388
if (high_cond && val_high)
389
res = nir_bcsel(b, high_cond, val_high, res);
390
391
return res;
392
}
393
394
static inline nir_rounding_mode
395
nir_simplify_conversion_rounding(nir_alu_type src_type,
396
nir_alu_type dest_type,
397
nir_rounding_mode rounding)
398
{
399
nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
400
nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
401
unsigned src_bit_size = nir_alu_type_get_type_size(src_type);
402
unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
403
assert(src_bit_size > 0 && dest_bit_size > 0);
404
405
if (rounding == nir_rounding_mode_undef)
406
return rounding;
407
408
/* Pure integer conversion doesn't have any rounding */
409
if (src_base_type != nir_type_float &&
410
dest_base_type != nir_type_float)
411
return nir_rounding_mode_undef;
412
413
/* Float down-casts don't round */
414
if (src_base_type == nir_type_float &&
415
dest_base_type == nir_type_float &&
416
dest_bit_size >= src_bit_size)
417
return nir_rounding_mode_undef;
418
419
/* Regular float to int conversions are RTZ */
420
if (src_base_type == nir_type_float &&
421
dest_base_type != nir_type_float &&
422
rounding == nir_rounding_mode_rtz)
423
return nir_rounding_mode_undef;
424
425
/* The CL spec requires regular conversions to float to be RTNE */
426
if (dest_base_type == nir_type_float &&
427
rounding == nir_rounding_mode_rtne)
428
return nir_rounding_mode_undef;
429
430
/* Couldn't simplify */
431
return rounding;
432
}
433
434
static inline nir_ssa_def *
435
nir_convert_with_rounding(nir_builder *b,
436
nir_ssa_def *src, nir_alu_type src_type,
437
nir_alu_type dest_type,
438
nir_rounding_mode round,
439
bool clamp)
440
{
441
/* Some stuff wants sized types */
442
assert(nir_alu_type_get_type_size(src_type) == 0 ||
443
nir_alu_type_get_type_size(src_type) == src->bit_size);
444
src_type |= src->bit_size;
445
446
/* Split types from bit sizes */
447
nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
448
nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
449
unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
450
451
/* Try to simplify the conversion if we can */
452
clamp = clamp &&
453
!nir_alu_type_range_contains_type_range(dest_type, src_type);
454
round = nir_simplify_conversion_rounding(src_type, dest_type, round);
455
456
/* For float -> int/uint conversions, we might not be able to represent
457
* the destination range in the source float accurately. For these cases,
458
* do the comparison in float range, but the bcsel in the destination range.
459
*/
460
bool clamp_after_conversion = clamp &&
461
src_base_type == nir_type_float &&
462
dest_base_type != nir_type_float;
463
464
/*
465
* If we don't care about rounding and clamping, we can just use NIR's
466
* built-in ops. There is also a special case for SPIR-V in shaders, where
467
* f32/f64 -> f16 conversions can have one of two rounding modes applied,
468
* which NIR has built-in opcodes for.
469
*
470
* For the rest, we have our own implementation of rounding and clamping.
471
*/
472
bool trivial_convert;
473
if (!clamp && round == nir_rounding_mode_undef) {
474
trivial_convert = true;
475
} else if (!clamp && src_type == nir_type_float32 &&
476
dest_type == nir_type_float16 &&
477
(round == nir_rounding_mode_rtne ||
478
round == nir_rounding_mode_rtz)) {
479
trivial_convert = true;
480
} else {
481
trivial_convert = false;
482
}
483
if (trivial_convert) {
484
nir_op op = nir_type_conversion_op(src_type, dest_type, round);
485
return nir_build_alu(b, op, src, NULL, NULL, NULL);
486
}
487
488
nir_ssa_def *dest = src;
489
490
/* clamp the result into range */
491
if (clamp && !clamp_after_conversion)
492
dest = nir_clamp_to_type_range(b, src, src_type, src, src_type, dest_type);
493
494
/* round with selected rounding mode */
495
if (!trivial_convert && round != nir_rounding_mode_undef) {
496
if (src_base_type == nir_type_float) {
497
if (dest_base_type == nir_type_float) {
498
dest = nir_round_float_to_float(b, dest, dest_bit_size, round);
499
} else {
500
dest = nir_round_float_to_int(b, dest, round);
501
}
502
} else {
503
dest = nir_round_int_to_float(b, dest, src_type, dest_bit_size, round);
504
}
505
506
round = nir_rounding_mode_undef;
507
}
508
509
/* now we can convert the value */
510
nir_op op = nir_type_conversion_op(src_type, dest_type, round);
511
dest = nir_build_alu(b, op, dest, NULL, NULL, NULL);
512
513
if (clamp_after_conversion)
514
dest = nir_clamp_to_type_range(b, dest, dest_type, src, src_type, dest_type);
515
516
return dest;
517
}
518
519
#ifdef __cplusplus
520
}
521
#endif
522
523
#endif /* NIR_CONVERSION_BUILDER_H */
524
525