GitHub Repository: PojavLauncherTeam/mesa
Path: blob/21.2-virgl/src/compiler/spirv/vtn_subgroup.c
/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

static struct vtn_ssa_value *
vtn_build_subgroup_instr(struct vtn_builder *b,
                         nir_intrinsic_op nir_op,
                         struct vtn_ssa_value *src0,
                         nir_ssa_def *index,
                         unsigned const_idx0,
                         unsigned const_idx1)
{
   /* Some of the subgroup operations take an index. SPIR-V allows this to be
    * any integer type. To make things simpler for drivers, we only support
    * 32-bit indices.
    */
   if (index && index->bit_size != 32)
      index = nir_u2u32(&b->nb, index);

   struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);

   vtn_assert(dst->type == src0->type);
   if (!glsl_type_is_vector_or_scalar(dst->type)) {
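      /* Composite (array, struct, or matrix) values don't map onto a single
       * NIR intrinsic, so recurse and build the operation per element.
       */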
      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
         dst->elems[i] =
            vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
                                     const_idx0, const_idx1);
      }
      return dst;
   }

   nir_intrinsic_instr *intrin =
      nir_intrinsic_instr_create(b->nb.shader, nir_op);
   nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                              dst->type, NULL);
   intrin->num_components = intrin->dest.ssa.num_components;

   intrin->src[0] = nir_src_for_ssa(src0->def);
   if (index)
      intrin->src[1] = nir_src_for_ssa(index);

   intrin->const_index[0] = const_idx0;
   intrin->const_index[1] = const_idx1;

   nir_builder_instr_insert(&b->nb, &intrin->instr);

   dst->def = &intrin->dest.ssa;

   return dst;
}
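
/* SPIR-V instruction words, as passed in through w: w[1] is the result
 * type id and w[2] is the result id, with operands starting at w[3].  The
 * Group and GroupNonUniform opcodes also take an execution Scope before
 * their data operands, which the older *KHR opcodes lack; that is what the
 * has_scope offsets below account for.
 */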
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_type *dest_type = vtn_get_type(b, w[1]);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      vtn_push_nir_ssa(b, w[2], &elect->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBallot:
   case SpvOpSubgroupBallotKHR: {
      bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
      vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
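      /* A SPIR-V ballot result is always a uvec4, whatever the real
       * subgroup size, so build the NIR ballot with four 32-bit components
       * to match.
       */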
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[3 + has_scope]));
      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      vtn_push_nir_ssa(b, w[2], &ballot->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      /* This one is just a BallotBitfieldExtract with subgroup invocation.
       * We could add a NIR intrinsic but it's easier to just lower it on the
       * spot.
       */
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader,
                                    nir_intrinsic_ballot_bitfield_extract);

      intrin->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[4]));
      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
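      /* BallotBitCount takes a GroupOperation in w[4] and the ballot in
       * w[5]; the other three take the ballot in w[4], with BitExtract
       * also taking an index in w[5].
       */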
      nir_ssa_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_get_nir_ssa(b, w[4]);
         src1 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBroadcastFirst:
   case SpvOpSubgroupFirstInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformBroadcast:
   case SpvOpGroupBroadcast:
   case SpvOpSubgroupReadInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
      break;
   }

   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual:
   case SpvOpGroupAll:
   case SpvOpGroupAny:
   case SpvOpSubgroupAllKHR:
   case SpvOpSubgroupAnyKHR:
   case SpvOpSubgroupAllEqualKHR: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
      case SpvOpGroupAll:
      case SpvOpSubgroupAllKHR:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
      case SpvOpGroupAny:
      case SpvOpSubgroupAnyKHR:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpSubgroupAllEqualKHR:
         op = nir_intrinsic_vote_ieq;
         break;
      case SpvOpGroupNonUniformAllEqual:
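         /* AllEqual has to respect the value's type: floating-point types
          * compare with vote_feq so floating-point equality rules apply;
          * integers and bools use the bitwise vote_ieq.
          */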
         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_FLOAT16:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT8:
         case GLSL_TYPE_INT8:
         case GLSL_TYPE_UINT16:
         case GLSL_TYPE_INT16:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_ssa_def *src0;
      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
          opcode == SpvOpGroupNonUniformAllEqual) {
         src0 = vtn_get_nir_ssa(b, w[4]);
      } else {
         src0 = vtn_get_nir_ssa(b, w[3]);
      }
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
      if (nir_intrinsic_infos[op].src_components[0] == 0)
         intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleINTEL:
   case SpvOpSubgroupShuffleXorINTEL: {
      nir_intrinsic_op op = opcode == SpvOpSubgroupShuffleINTEL ?
         nir_intrinsic_shuffle : nir_intrinsic_shuffle_xor;
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[3]),
                                  vtn_get_nir_ssa(b, w[4]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleUpINTEL:
   case SpvOpSubgroupShuffleDownINTEL: {
      /* TODO: Move this lower on the compiler stack, where we can move the
       * current/other data to adjacent registers to avoid doing a shuffle
       * twice.
       */

      nir_builder *nb = &b->nb;
      nir_ssa_def *size = nir_load_subgroup_size(nb);
      nir_ssa_def *delta = vtn_get_nir_ssa(b, w[5]);

      /* Rewrite UP in terms of DOWN.
       *
       *   UP(a, b, delta) == DOWN(a, b, size - delta)
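       *
       * e.g. with size == 8 and delta == 3, invocation 5 computes index
       * 5 + (8 - 3) == 10, fails the index < size test below, and reads
       * the second source at 10 - 8 == 2, which is exactly UP's b[5 - 3].
       */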
      if (opcode == SpvOpSubgroupShuffleUpINTEL)
         delta = nir_isub(nb, size, delta);

      nir_ssa_def *index = nir_iadd(nb, nir_load_subgroup_invocation(nb), delta);
      struct vtn_ssa_value *current =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[3]),
                                  index, 0, 0);

      struct vtn_ssa_value *next =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[4]),
                                  nir_isub(nb, index, size), 0, 0);

      nir_ssa_def *cond = nir_ilt(nb, index, size);
      vtn_push_nir_ssa(b, w[2], nir_bcsel(nb, cond, current->def, next->def));

      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                                  vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;

   case SpvOpGroupNonUniformQuadSwap: {
      unsigned direction = vtn_constant_uint(b, w[5]);
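      /* The swap direction must be a constant: 0 swaps horizontally within
       * the quad, 1 vertically, and 2 diagonally, per the SPIR-V spec.
       */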
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor:
   case SpvOpGroupIAdd:
   case SpvOpGroupFAdd:
   case SpvOpGroupFMin:
   case SpvOpGroupUMin:
   case SpvOpGroupSMin:
   case SpvOpGroupFMax:
   case SpvOpGroupUMax:
   case SpvOpGroupSMax:
   case SpvOpGroupIAddNonUniformAMD:
   case SpvOpGroupFAddNonUniformAMD:
   case SpvOpGroupFMinNonUniformAMD:
   case SpvOpGroupUMinNonUniformAMD:
   case SpvOpGroupSMinNonUniformAMD:
   case SpvOpGroupFMaxNonUniformAMD:
   case SpvOpGroupUMaxNonUniformAMD:
   case SpvOpGroupSMaxNonUniformAMD: {
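      /* Map the SPIR-V reduction to the NIR ALU opcode that combines two
       * invocations' values; vtn_build_subgroup_instr stores it in
       * const_index[0] of the reduce/scan intrinsic.
       */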
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
      case SpvOpGroupIAdd:
      case SpvOpGroupIAddNonUniformAMD:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
      case SpvOpGroupFAdd:
      case SpvOpGroupFAddNonUniformAMD:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
      case SpvOpGroupSMin:
      case SpvOpGroupSMinNonUniformAMD:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
      case SpvOpGroupUMin:
      case SpvOpGroupUMinNonUniformAMD:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
      case SpvOpGroupFMin:
      case SpvOpGroupFMinNonUniformAMD:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
      case SpvOpGroupSMax:
      case SpvOpGroupSMaxNonUniformAMD:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
      case SpvOpGroupUMax:
      case SpvOpGroupUMaxNonUniformAMD:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
      case SpvOpGroupFMax:
      case SpvOpGroupFMaxNonUniformAMD:
         reduction_op = nir_op_fmax;
         break;
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }
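
      /* Reduce combines across the whole subgroup, the scans are prefix
       * operations over lower invocations, and ClusteredReduce reduces
       * within clusters whose size is the extra constant operand w[6].
       */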
      nir_intrinsic_op op;
      unsigned cluster_size = 0;
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
                                  reduction_op, cluster_size));
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}