GitHub Repository: PojavLauncherTeam/mesa
Path: blob/21.2-virgl/src/compiler/spirv/vtn_alu.c
/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include <math.h>
#include "vtn_private.h"
#include "spirv_info.h"

/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want to do one routine for
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before we
 * send it off.
 */
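
/* Informal illustration: a vec4 operand of OpMatrixTimesVector is wrapped as
 * a "matrix" whose single column is the vector itself (elems[0] == val, so
 * src1_columns == 1 in matrix_multiply() below); unwrap_matrix() then hands
 * that single column back as the result vector.
 */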

static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = glsl_get_bare_type(val->type);
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}

static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{

   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                  src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
         for (int j = src0_columns - 2; j >= 0; j--) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}
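
/* Informal sketch of the general (non-transposed) case above: column i of the
 * result is a linear combination of src0's columns, weighted by the
 * components of src1's column i:
 *
 *    dest[i] = src0[0] * src1[i][0] + ... +
 *              src0[src0_columns - 1] * src1[i][src0_columns - 1]
 *
 * which is exactly the fmul/fadd chain emitted in the loop above.
 */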

static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

static struct vtn_ssa_value *
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      return dest;
   }

   case SpvOpFAdd: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpFSub: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpTranspose:
      return vtn_ssa_transpose(b, src0);

   case SpvOpMatrixTimesScalar:
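      /* Scaling commutes with transposition, so when only the transpose of
       * src0 is materialized we can scale that and transpose the result back.
       */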
      if (src0->transposed) {
         return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                      src1->def));
      } else {
         return mat_times_scalar(b, src0, src1->def);
      }
      break;

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
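      /* OpVectorTimesMatrix multiplies a row vector by a matrix; treating the
       * vector as a column, v * M == transpose(M) * v, so it can reuse the
       * same matrix_multiply() path by transposing src1 first.
       */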
      if (opcode == SpvOpVectorTimesMatrix) {
         return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         return matrix_multiply(b, src0, src1);
      }
      break;

   default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
   }
}

static nir_alu_type
convert_op_src_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertFToS:
   case SpvOpConvertFToU:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertSToF:
   case SpvOpSatConvertSToU:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertUToF:
   case SpvOpSatConvertUToS:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}

static nir_alu_type
convert_op_dst_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertFToS:
   case SpvOpSatConvertUToS:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpSatConvertSToU:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}

nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap, bool *exact,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped. This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   *exact = false;

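   /* A note on *swap (informal): SpvOpUGreaterThan, for example, maps to
    * nir_op_ult with *swap = true, so the caller emits ult(src1, src0),
    * which computes src0 > src1.
    */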
   switch (opcode) {
   case SpvOpSNegate: return nir_op_ineg;
   case SpvOpFNegate: return nir_op_fneg;
   case SpvOpNot: return nir_op_inot;
   case SpvOpIAdd: return nir_op_iadd;
   case SpvOpFAdd: return nir_op_fadd;
   case SpvOpISub: return nir_op_isub;
   case SpvOpFSub: return nir_op_fsub;
   case SpvOpIMul: return nir_op_imul;
   case SpvOpFMul: return nir_op_fmul;
   case SpvOpUDiv: return nir_op_udiv;
   case SpvOpSDiv: return nir_op_idiv;
   case SpvOpFDiv: return nir_op_fdiv;
   case SpvOpUMod: return nir_op_umod;
   case SpvOpSMod: return nir_op_imod;
   case SpvOpFMod: return nir_op_fmod;
   case SpvOpSRem: return nir_op_irem;
   case SpvOpFRem: return nir_op_frem;

   case SpvOpShiftRightLogical: return nir_op_ushr;
   case SpvOpShiftRightArithmetic: return nir_op_ishr;
   case SpvOpShiftLeftLogical: return nir_op_ishl;
   case SpvOpLogicalOr: return nir_op_ior;
   case SpvOpLogicalEqual: return nir_op_ieq;
   case SpvOpLogicalNotEqual: return nir_op_ine;
   case SpvOpLogicalAnd: return nir_op_iand;
   case SpvOpLogicalNot: return nir_op_inot;
   case SpvOpBitwiseOr: return nir_op_ior;
   case SpvOpBitwiseXor: return nir_op_ixor;
   case SpvOpBitwiseAnd: return nir_op_iand;
   case SpvOpSelect: return nir_op_bcsel;
   case SpvOpIEqual: return nir_op_ieq;

   case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
   case SpvOpBitReverse: return nir_op_bitfield_reverse;

   case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
   /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
   case SpvOpAbsISubINTEL: return nir_op_uabs_isub;
   case SpvOpAbsUSubINTEL: return nir_op_uabs_usub;
   case SpvOpIAddSatINTEL: return nir_op_iadd_sat;
   case SpvOpUAddSatINTEL: return nir_op_uadd_sat;
   case SpvOpIAverageINTEL: return nir_op_ihadd;
   case SpvOpUAverageINTEL: return nir_op_uhadd;
   case SpvOpIAverageRoundedINTEL: return nir_op_irhadd;
   case SpvOpUAverageRoundedINTEL: return nir_op_urhadd;
   case SpvOpISubSatINTEL: return nir_op_isub_sat;
   case SpvOpUSubSatINTEL: return nir_op_usub_sat;
   case SpvOpIMul32x16INTEL: return nir_op_imul_32x16;
   case SpvOpUMul32x16INTEL: return nir_op_umul_32x16;

   /* The ordered / unordered comparisons need special handling beyond the
    * NIR opcode chosen here, since they also have to check whether the
    * operands are ordered (i.e. neither one is NaN).
    */
   case SpvOpFOrdEqual: *exact = true; return nir_op_feq;
   case SpvOpFUnordEqual: *exact = true; return nir_op_feq;
   case SpvOpINotEqual: return nir_op_ine;
   case SpvOpLessOrGreater: /* Deprecated, use OrdNotEqual */
   case SpvOpFOrdNotEqual: *exact = true; return nir_op_fneu;
   case SpvOpFUnordNotEqual: *exact = true; return nir_op_fneu;
   case SpvOpULessThan: return nir_op_ult;
   case SpvOpSLessThan: return nir_op_ilt;
   case SpvOpFOrdLessThan: *exact = true; return nir_op_flt;
   case SpvOpFUnordLessThan: *exact = true; return nir_op_flt;
   case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
   case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
   case SpvOpFOrdGreaterThan: *swap = true; *exact = true; return nir_op_flt;
   case SpvOpFUnordGreaterThan: *swap = true; *exact = true; return nir_op_flt;
   case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
   case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
   case SpvOpFOrdLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
   case SpvOpFUnordLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
   case SpvOpUGreaterThanEqual: return nir_op_uge;
   case SpvOpSGreaterThanEqual: return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual: *exact = true; return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual: *exact = true; return nir_op_fge;

   /* Conversions: */
   case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }
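
   /* For example (informal): an SpvOpFConvert from a 32-bit to a 16-bit float
    * resolves to nir_op_f2f16 here; any FPRoundingMode decoration is applied
    * later, in the conversion handling in vtn_handle_alu().
    */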

   case SpvOpPtrCastToGeneric: return nir_op_mov;
   case SpvOpGenericCastToPtr: return nir_op_mov;

   /* Derivatives: */
   case SpvOpDPdx: return nir_op_fddx;
   case SpvOpDPdy: return nir_op_fddy;
   case SpvOpDPdxFine: return nir_op_fddx_fine;
   case SpvOpDPdyFine: return nir_op_fddy_fine;
   case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse: return nir_op_fddy_coarse;

   case SpvOpIsNormal: return nir_op_fisnormal;
   case SpvOpIsFinite: return nir_op_fisfinite;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}

static void
handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
                      const struct vtn_decoration *dec, void *_void)
{
   vtn_assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->nb.exact = true;
}

void
vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
{
   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
}

nir_rounding_mode
vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
{
   switch (mode) {
   case SpvFPRoundingModeRTE:
      return nir_rounding_mode_rtne;
   case SpvFPRoundingModeRTZ:
      return nir_rounding_mode_rtz;
   case SpvFPRoundingModeRTP:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTP is only supported in kernels");
      return nir_rounding_mode_ru;
   case SpvFPRoundingModeRTN:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTN is only supported in kernels");
      return nir_rounding_mode_rd;
   default:
      vtn_fail("Unsupported rounding mode: %s",
               spirv_fproundingmode_to_string(mode));
      break;
   }
}

struct conversion_opts {
   nir_rounding_mode rounding_mode;
   bool saturate;
};

static void
handle_conversion_opts(struct vtn_builder *b, struct vtn_value *val, int member,
                       const struct vtn_decoration *dec, void *_opts)
{
   struct conversion_opts *opts = _opts;

   switch (dec->decoration) {
   case SpvDecorationFPRoundingMode:
      opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
      break;

   case SpvDecorationSaturatedConversion:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "Saturated conversions are only allowed in kernels");
      opts->saturate = true;
      break;

   default:
      break;
   }
}

static void
handle_no_wrap(struct vtn_builder *b, struct vtn_value *val, int member,
               const struct vtn_decoration *dec, void *_alu)
{
   nir_alu_instr *alu = _alu;
   switch (dec->decoration) {
   case SpvDecorationNoSignedWrap:
      alu->no_signed_wrap = true;
      break;
   case SpvDecorationNoUnsignedWrap:
      alu->no_unsigned_wrap = true;
      break;
   default:
      /* Do nothing. */
      break;
   }
}

void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;

   vtn_handle_no_contraction(b, dest_val);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++)
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      vtn_push_ssa_value(b, w[2],
                         vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]));
      b->nb.exact = b->exact;
      return;
   }

   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
   case SpvOpAny:
      dest->def = nir_bany(&b->nb, src[0]);
      break;

   case SpvOpAll:
      dest->def = nir_ball(&b->nb, src[0]);
      break;

   case SpvOpOuterProduct: {
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         dest->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      dest->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

   case SpvOpIAddCarry:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
      break;
   }

   case SpvOpSMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
      break;
   }

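   /* fwidth(p) is defined as abs(dFdx(p)) + abs(dFdy(p)); the Fine/Coarse
    * variants below simply pick the matching derivative opcodes.
    */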
   case SpvOpFwidth:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      dest->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

   case SpvOpIsNan: {
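      /* NaN is the only value that compares unequal to itself, so an exact
       * (non-optimizable) src != src comparison is a NaN test.
       */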
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_fneu(&b->nb, src[0], src[0]);
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpOrdered: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1]));
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpUnordered: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
                          nir_fneu(&b->nb, src[1], src[1]));
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpIsInf: {
      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
      dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
      break;
   }

   case SpvOpFUnordEqual:
   case SpvOpFUnordNotEqual:
   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      bool unused_exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  &unused_exact,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

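      /* An unordered comparison is true whenever either operand is NaN, so OR
       * the plain comparison with per-operand NaN checks (x != x).
       */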
      dest->def =
         nir_ior(&b->nb,
                 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                 nir_ior(&b->nb,
                         nir_fneu(&b->nb, src[0], src[0]),
                         nir_fneu(&b->nb, src[1], src[1])));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpLessOrGreater:
   case SpvOpFOrdNotEqual: {
      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
       * from the ALU will probably already be false if the operands are not
       * ordered so we don’t need to handle it specially.
       */
      bool swap;
      bool exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
                                                  src_bit_size, dst_bit_size);

      assert(!swap);
      assert(exact);

      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      dest->def =
         nir_iand(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_iand(&b->nb,
                           nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1])));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert:
   case SpvOpSatConvertSToU:
   case SpvOpSatConvertUToS: {
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;

      struct conversion_opts opts = {
         .rounding_mode = nir_rounding_mode_undef,
         .saturate = false,
      };
      vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);

      if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
         opts.saturate = true;

      if (b->shader->info.stage == MESA_SHADER_KERNEL) {
         if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
            nir_op op = nir_type_conversion_op(src_type, dst_type,
                                               nir_rounding_mode_undef);
            dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
         } else {
            dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
                                              src_type, dst_type,
                                              opts.rounding_mode, opts.saturate);
         }
      } else {
         vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
                     dst_type != nir_type_float16,
                     "Rounding modes are only allowed on conversions to "
                     "16-bit float types");
         nir_op op = nir_type_conversion_op(src_type, dst_type,
                                            opts.rounding_mode);
         dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
      }
      break;
   }

   case SpvOpBitFieldInsert:
   case SpvOpBitFieldSExtract:
   case SpvOpBitFieldUExtract:
   case SpvOpShiftLeftLogical:
   case SpvOpShiftRightArithmetic:
   case SpvOpShiftRightLogical: {
      bool swap;
      bool exact;
      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
                                                  src0_bit_size, dst_bit_size);

      assert(!exact);

      assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
              op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
              op == nir_op_ibitfield_extract);

      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         unsigned src_bit_size =
            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
         if (src_bit_size == 0)
            continue;
         if (src_bit_size != src[i]->bit_size) {
            assert(src_bit_size == 32);
            /* Convert the Shift, Offset and Count operands to 32 bits, which is the bitsize
             * supported by the NIR instructions. See discussion here:
             *
             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
             */
            src[i] = nir_u2u32(&b->nb, src[i]);
         }
      }
      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   }

   case SpvOpSignBitSet:
      dest->def = nir_i2b(&b->nb,
         nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
      break;

   case SpvOpUCountTrailingZerosINTEL:
      dest->def = nir_umin(&b->nb,
                           nir_find_lsb(&b->nb, src[0]),
                           nir_imm_int(&b->nb, 32u));
      break;

   case SpvOpBitCount: {
      /* bit_count always returns int32, but the SPIR-V opcode just says the return
       * value needs to be big enough to store the number of bits.
       */
      dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
      break;
   }

   default: {
      bool swap;
      bool exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  &exact,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      switch (op) {
      case nir_op_ishl:
      case nir_op_ishr:
      case nir_op_ushr:
         if (src[1]->bit_size != 32)
            src[1] = nir_u2u32(&b->nb, src[1]);
         break;
      default:
         break;
      }

      const bool save_exact = b->nb.exact;

      if (exact)
         b->nb.exact = true;

      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);

      b->nb.exact = save_exact;
      break;
   } /* default */
   }

   switch (opcode) {
   case SpvOpIAdd:
   case SpvOpIMul:
   case SpvOpISub:
   case SpvOpShiftLeftLogical:
   case SpvOpSNegate: {
      nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
      vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
      break;
   }
   default:
      /* Do nothing. */
      break;
   }

   vtn_push_ssa_value(b, w[2], dest);

   b->nb.exact = b->exact;
}

void
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
{
   vtn_assert(count == 4);
   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
    *
    * "If Result Type has the same number of components as Operand, they
    * must also have the same component width, and results are computed per
    * component.
    *
    * If Result Type has a different number of components than Operand, the
    * total number of bits in Result Type must equal the total number of
    * bits in Operand. Let L be the type, either Result Type or Operand’s
    * type, that has the larger number of components. Let S be the other
    * type, with the smaller number of components. The number of components
    * in L must be an integer multiple of the number of components in S.
    * The first component (that is, the only or lowest-numbered component)
    * of S maps to the first components of L, and so on, up to the last
    * component of S mapping to the last components of L. Within this
    * mapping, any single component of S (mapping to multiple components of
    * L) maps its lower-ordered bits to the lower-numbered components of L."
    */
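
   /* Worked example (informal): bitcasting a uvec2 to a 64-bit scalar packs
    * component 0 into the low 32 bits and component 1 into the high 32 bits;
    * nir_bitcast_vector() below performs this repacking for any legal
    * component-count / bit-size combination.
    */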

   struct vtn_type *type = vtn_get_type(b, w[1]);
   struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]);

   vtn_fail_if(src->num_components * src->bit_size !=
               glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
               "Source and destination of OpBitcast must have the same "
               "total number of bits");
   nir_ssa_def *val =
      nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
   vtn_push_nir_ssa(b, w[2], val);
}