GitHub Repository: PojavLauncherTeam/mesa
Path: blob/21.2-virgl/src/amd/common/ac_nir_lower_ngg.c
1
/*
2
* Copyright © 2021 Valve Corporation
3
*
4
* Permission is hereby granted, free of charge, to any person obtaining a
5
* copy of this software and associated documentation files (the "Software"),
6
* to deal in the Software without restriction, including without limitation
7
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
8
* and/or sell copies of the Software, and to permit persons to whom the
9
* Software is furnished to do so, subject to the following conditions:
10
*
11
* The above copyright notice and this permission notice (including the next
12
* paragraph) shall be included in all copies or substantial portions of the
13
* Software.
14
*
15
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21
* IN THE SOFTWARE.
22
*
23
*/
24
25
#include "ac_nir.h"
26
#include "nir_builder.h"
27
#include "u_math.h"
28
#include "u_vector.h"
29
30
enum {
31
nggc_passflag_used_by_pos = 1,
32
nggc_passflag_used_by_other = 2,
33
nggc_passflag_used_by_both = nggc_passflag_used_by_pos | nggc_passflag_used_by_other,
34
};
35
36
typedef struct
37
{
38
nir_ssa_def *ssa;
39
nir_variable *var;
40
} saved_uniform;
41
42
typedef struct
43
{
44
nir_variable *position_value_var;
45
nir_variable *prim_exp_arg_var;
46
nir_variable *es_accepted_var;
47
nir_variable *gs_accepted_var;
48
49
struct u_vector saved_uniforms;
50
51
bool passthrough;
52
bool export_prim_id;
53
bool early_prim_export;
54
unsigned wave_size;
55
unsigned max_num_waves;
56
unsigned num_vertices_per_primitives;
57
unsigned provoking_vtx_idx;
58
unsigned max_es_num_vertices;
59
unsigned total_lds_bytes;
60
61
uint64_t inputs_needed_by_pos;
62
uint64_t inputs_needed_by_others;
63
} lower_ngg_nogs_state;
64
65
typedef struct
66
{
67
/* bitsize of this component (max 32), or 0 if it's never written at all */
68
uint8_t bit_size : 6;
69
/* output stream index */
70
uint8_t stream : 2;
71
} gs_output_component_info;
72
73
typedef struct
74
{
75
nir_variable *output_vars[VARYING_SLOT_MAX][4];
76
nir_variable *current_clear_primflag_idx_var;
77
int const_out_vtxcnt[4];
78
int const_out_prmcnt[4];
79
unsigned wave_size;
80
unsigned max_num_waves;
81
unsigned num_vertices_per_primitive;
82
unsigned lds_addr_gs_out_vtx;
83
unsigned lds_addr_gs_scratch;
84
unsigned lds_bytes_per_gs_out_vertex;
85
unsigned lds_offs_primflags;
86
bool found_out_vtxcnt[4];
87
bool output_compile_time_known;
88
bool provoking_vertex_last;
89
gs_output_component_info output_component_info[VARYING_SLOT_MAX][4];
90
} lower_ngg_gs_state;
91
92
typedef struct {
93
nir_variable *pre_cull_position_value_var;
94
} remove_culling_shader_outputs_state;
95
96
typedef struct {
97
nir_variable *pos_value_replacement;
98
} remove_extra_position_output_state;
99
100
typedef struct {
101
nir_ssa_def *reduction_result;
102
nir_ssa_def *excl_scan_result;
103
} wg_scan_result;
104
105
/* Per-vertex LDS layout of culling shaders */
106
enum {
107
/* Position of the ES vertex (at the beginning for alignment reasons) */
108
lds_es_pos_x = 0,
109
lds_es_pos_y = 4,
110
lds_es_pos_z = 8,
111
lds_es_pos_w = 12,
112
113
/* 1 when the vertex is accepted, 0 if it should be culled */
114
lds_es_vertex_accepted = 16,
115
/* ID of the thread which will export the current thread's vertex */
116
lds_es_exporter_tid = 17,
117
118
/* Repacked arguments - also listed separately for VS and TES */
119
lds_es_arg_0 = 20,
120
121
/* VS arguments which need to be repacked */
122
lds_es_vs_vertex_id = 20,
123
lds_es_vs_instance_id = 24,
124
125
/* TES arguments which need to be repacked */
126
lds_es_tes_u = 20,
127
lds_es_tes_v = 24,
128
lds_es_tes_rel_patch_id = 28,
129
lds_es_tes_patch_id = 32,
130
};
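/* With the repacked arguments, the per-vertex stride used by the culling shader is
* lds_es_arg_0 + 4 * max_exported_args bytes (see add_deferred_attribute_culling),
* e.g. 28 bytes for a VS that needs the instance ID and 36 bytes for a TES that
* needs the patch ID.
*/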
131
132
typedef struct {
133
nir_ssa_def *num_repacked_invocations;
134
nir_ssa_def *repacked_invocation_index;
135
} wg_repack_result;
136
137
/**
138
* Repacks invocations in the current workgroup to eliminate gaps between them.
139
*
140
* Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave).
141
* Assumes that all invocations in the workgroup are active (exec = -1).
142
*/
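/* For example, if wave 0 has 5 surviving invocations and wave 1 has 7, the
* survivors of wave 0 receive repacked indices 0-4, those of wave 1 receive 5-11,
* and every invocation sees 12 as the total number of repacked invocations.
*/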
143
static wg_repack_result
144
repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool,
145
unsigned lds_addr_base, unsigned max_num_waves,
146
unsigned wave_size)
147
{
148
/* Input boolean: 1 if the current invocation should survive the repack. */
149
assert(input_bool->bit_size == 1);
150
151
/* STEP 1. Count surviving invocations in the current wave.
152
*
153
* Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
154
*/
155
156
nir_ssa_def *input_mask = nir_build_ballot(b, 1, wave_size, input_bool);
157
nir_ssa_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask);
158
159
/* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
160
if (max_num_waves == 1) {
161
wg_repack_result r = {
162
.num_repacked_invocations = surviving_invocations_in_current_wave,
163
.repacked_invocation_index = nir_build_mbcnt_amd(b, input_mask, nir_imm_int(b, 0)),
164
};
165
return r;
166
}
167
168
/* STEP 2. Waves tell each other their number of surviving invocations.
169
*
170
* Each wave activates only its first lane (exec = 1), which stores the number of surviving
171
* invocations in that wave into the LDS, then reads the numbers from every wave.
172
*
173
* The workgroup size of NGG shaders is at most 256, which means
174
* the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
175
* Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
176
*/
177
178
const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
179
assert(num_lds_dwords <= 2);
180
181
nir_ssa_def *wave_id = nir_build_load_subgroup_id(b);
182
nir_ssa_def *dont_care = nir_ssa_undef(b, 1, num_lds_dwords * 32);
183
nir_if *if_first_lane = nir_push_if(b, nir_build_elect(b, 1));
184
185
nir_build_store_shared(b, nir_u2u8(b, surviving_invocations_in_current_wave), wave_id, .base = lds_addr_base, .align_mul = 1u, .write_mask = 0x1u);
186
187
nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
188
.memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
189
190
nir_ssa_def *packed_counts = nir_build_load_shared(b, 1, num_lds_dwords * 32, nir_imm_int(b, 0), .base = lds_addr_base, .align_mul = 8u);
191
192
nir_pop_if(b, if_first_lane);
193
194
packed_counts = nir_if_phi(b, packed_counts, dont_care);
195
196
/* STEP 3. Compute the repacked invocation index and the total number of surviving invocations.
197
*
198
* By now, every wave knows the number of surviving invocations in all waves.
199
* Each number is 1 byte, and they are packed into up to 2 dwords.
200
*
201
* Each lane N will sum the number of surviving invocations from waves 0 to N-1.
202
* If the workgroup has M waves, then each wave will use only its first M+1 lanes for this.
203
* (Other lanes are not deactivated but their calculation is not used.)
204
*
205
* - We read the sum from the lane whose id is the current wave's id.
206
* Add the masked bitcount to this, and we get the repacked invocation index.
207
* - We read the sum from the lane whose id is the number of waves in the workgroup.
208
* This is the total number of surviving invocations in the workgroup.
209
*/
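/* Example of the scan below: with 3 waves whose surviving counts are {4, 6, 2},
* lane 0 computes 0, lane 1 computes 4, lane 2 computes 10 and lane 3 computes 12.
* Wave 1 then reads its base index (4) from lane 1, and all waves read the
* workgroup total (12) from lane 3 (= num_waves).
*/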
210
211
nir_ssa_def *num_waves = nir_build_load_num_subgroups(b);
212
213
/* sel = 0x01010101 * lane_id + 0x03020100 */
214
nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
215
nir_ssa_def *packed_id = nir_build_byte_permute_amd(b, nir_imm_int(b, 0), lane_id, nir_imm_int(b, 0));
216
nir_ssa_def *sel = nir_iadd_imm_nuw(b, packed_id, 0x03020100);
217
nir_ssa_def *sum = NULL;
218
219
if (num_lds_dwords == 1) {
220
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
221
nir_ssa_def *packed_dw = nir_build_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
222
223
/* Use byte-permute to filter out the bytes not needed by the current lane. */
224
nir_ssa_def *filtered_packed = nir_build_byte_permute_amd(b, packed_dw, nir_imm_int(b, 0), sel);
225
226
/* Horizontally add the packed bytes. */
227
sum = nir_sad_u8x4(b, filtered_packed, nir_imm_int(b, 0), nir_imm_int(b, 0));
228
} else if (num_lds_dwords == 2) {
229
/* Create selectors for the byte-permutes below. */
230
nir_ssa_def *dw0_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x44443210), nir_imm_int(b, 0x4));
231
nir_ssa_def *dw1_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x32100000), nir_imm_int(b, 0x4));
232
233
/* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
234
nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
235
nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
236
237
/* Use byte-permute to filter out the bytes not needed by the current lane. */
238
nir_ssa_def *filtered_packed_dw0 = nir_build_byte_permute_amd(b, packed_dw0, nir_imm_int(b, 0), dw0_selector);
239
nir_ssa_def *filtered_packed_dw1 = nir_build_byte_permute_amd(b, packed_dw1, nir_imm_int(b, 0), dw1_selector);
240
241
/* Horizontally add the packed bytes. */
242
sum = nir_sad_u8x4(b, filtered_packed_dw0, nir_imm_int(b, 0), nir_imm_int(b, 0));
243
sum = nir_sad_u8x4(b, filtered_packed_dw1, nir_imm_int(b, 0), sum);
244
} else {
245
unreachable("Unimplemented NGG wave count");
246
}
247
248
nir_ssa_def *wg_repacked_index_base = nir_build_read_invocation(b, sum, wave_id);
249
nir_ssa_def *wg_num_repacked_invocations = nir_build_read_invocation(b, sum, num_waves);
250
nir_ssa_def *wg_repacked_index = nir_build_mbcnt_amd(b, input_mask, wg_repacked_index_base);
251
252
wg_repack_result r = {
253
.num_repacked_invocations = wg_num_repacked_invocations,
254
.repacked_invocation_index = wg_repacked_index,
255
};
256
257
return r;
258
}
259
260
static nir_ssa_def *
261
pervertex_lds_addr(nir_builder *b, nir_ssa_def *vertex_idx, unsigned per_vtx_bytes)
262
{
263
return nir_imul_imm(b, vertex_idx, per_vtx_bytes);
264
}
265
266
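/* Packs the NGG primitive export argument:
* vertex index i occupies bits [10*i + 8 : 10*i], the initial edge flag of
* vertex i (VS only) is placed at bit 10*i + 9, and bit 31 is set for null
* (culled) primitives.
*/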
static nir_ssa_def *
267
emit_pack_ngg_prim_exp_arg(nir_builder *b, unsigned num_vertices_per_primitives,
268
nir_ssa_def *vertex_indices[3], nir_ssa_def *is_null_prim)
269
{
270
nir_ssa_def *arg = vertex_indices[0];
271
272
for (unsigned i = 0; i < num_vertices_per_primitives; ++i) {
273
assert(vertex_indices[i]);
274
275
if (i)
276
arg = nir_ior(b, arg, nir_ishl(b, vertex_indices[i], nir_imm_int(b, 10u * i)));
277
278
if (b->shader->info.stage == MESA_SHADER_VERTEX) {
279
nir_ssa_def *edgeflag = nir_build_load_initial_edgeflag_amd(b, 32, nir_imm_int(b, i));
280
arg = nir_ior(b, arg, nir_ishl(b, edgeflag, nir_imm_int(b, 10u * i + 9u)));
281
}
282
}
283
284
if (is_null_prim) {
285
if (is_null_prim->bit_size == 1)
286
is_null_prim = nir_b2i32(b, is_null_prim);
287
assert(is_null_prim->bit_size == 32);
288
arg = nir_ior(b, arg, nir_ishl(b, is_null_prim, nir_imm_int(b, 31u)));
289
}
290
291
return arg;
292
}
293
294
static nir_ssa_def *
295
ngg_input_primitive_vertex_index(nir_builder *b, unsigned vertex)
296
{
297
/* TODO: This is RADV specific. We'll need to refactor RADV and/or RadeonSI to match. */
298
return nir_ubfe(b, nir_build_load_gs_vertex_offset_amd(b, .base = vertex / 2u * 2u),
299
nir_imm_int(b, (vertex % 2u) * 16u), nir_imm_int(b, 16u));
300
}
301
302
static nir_ssa_def *
303
emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *st)
304
{
305
if (st->passthrough) {
306
assert(!st->export_prim_id || b->shader->info.stage != MESA_SHADER_VERTEX);
307
return nir_build_load_packed_passthrough_primitive_amd(b);
308
} else {
309
nir_ssa_def *vtx_idx[3] = {0};
310
311
vtx_idx[0] = ngg_input_primitive_vertex_index(b, 0);
312
vtx_idx[1] = st->num_vertices_per_primitives >= 2
313
? ngg_input_primitive_vertex_index(b, 1)
314
: nir_imm_zero(b, 1, 32);
315
vtx_idx[2] = st->num_vertices_per_primitives >= 3
316
? ngg_input_primitive_vertex_index(b, 2)
317
: nir_imm_zero(b, 1, 32);
318
319
return emit_pack_ngg_prim_exp_arg(b, st->num_vertices_per_primitives, vtx_idx, NULL);
320
}
321
}
322
323
static void
324
emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *st, nir_ssa_def *arg)
325
{
326
nir_if *if_gs_thread = nir_push_if(b, nir_build_has_input_primitive_amd(b));
327
{
328
if (!arg)
329
arg = emit_ngg_nogs_prim_exp_arg(b, st);
330
331
if (st->export_prim_id && b->shader->info.stage == MESA_SHADER_VERTEX) {
332
/* Copy Primitive IDs from GS threads to the LDS address corresponding to the ES thread of the provoking vertex. */
333
nir_ssa_def *prim_id = nir_build_load_primitive_id(b);
334
nir_ssa_def *provoking_vtx_idx = ngg_input_primitive_vertex_index(b, st->provoking_vtx_idx);
335
nir_ssa_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, 4u);
336
337
nir_build_store_shared(b, prim_id, addr, .write_mask = 1u, .align_mul = 4u);
338
}
339
340
nir_build_export_primitive_amd(b, arg);
341
}
342
nir_pop_if(b, if_gs_thread);
343
}
344
345
static void
346
emit_store_ngg_nogs_es_primitive_id(nir_builder *b)
347
{
348
nir_ssa_def *prim_id = NULL;
349
350
if (b->shader->info.stage == MESA_SHADER_VERTEX) {
351
/* Workgroup barrier - wait for GS threads to store primitive ID in LDS. */
352
nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, .memory_scope = NIR_SCOPE_WORKGROUP,
353
.memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
354
355
/* LDS address where the primitive ID is stored */
356
nir_ssa_def *thread_id_in_threadgroup = nir_build_load_local_invocation_index(b);
357
nir_ssa_def *addr = pervertex_lds_addr(b, thread_id_in_threadgroup, 4u);
358
359
/* Load primitive ID from LDS */
360
prim_id = nir_build_load_shared(b, 1, 32, addr, .align_mul = 4u);
361
} else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
362
/* Just use tess eval primitive ID, which is the same as the patch ID. */
363
prim_id = nir_build_load_primitive_id(b);
364
}
365
366
nir_io_semantics io_sem = {
367
.location = VARYING_SLOT_PRIMITIVE_ID,
368
.num_slots = 1,
369
};
370
371
nir_build_store_output(b, prim_id, nir_imm_zero(b, 1, 32),
372
.base = io_sem.location,
373
.write_mask = 1u, .src_type = nir_type_uint32, .io_semantics = io_sem);
374
}
375
376
static bool
377
remove_culling_shader_output(nir_builder *b, nir_instr *instr, void *state)
378
{
379
remove_culling_shader_outputs_state *s = (remove_culling_shader_outputs_state *) state;
380
381
if (instr->type != nir_instr_type_intrinsic)
382
return false;
383
384
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
385
386
/* These are not allowed in VS / TES */
387
assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
388
intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
389
390
/* We are only interested in output stores now */
391
if (intrin->intrinsic != nir_intrinsic_store_output)
392
return false;
393
394
b->cursor = nir_before_instr(instr);
395
396
/* Position output - store the value to a variable, remove output store */
397
nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
398
if (io_sem.location == VARYING_SLOT_POS) {
399
/* TODO: check if it's indirect, etc? */
400
unsigned writemask = nir_intrinsic_write_mask(intrin);
401
nir_ssa_def *store_val = intrin->src[0].ssa;
402
nir_store_var(b, s->pre_cull_position_value_var, store_val, writemask);
403
}
404
405
/* Remove all output stores */
406
nir_instr_remove(instr);
407
return true;
408
}
409
410
static void
411
remove_culling_shader_outputs(nir_shader *culling_shader, lower_ngg_nogs_state *nogs_state, nir_variable *pre_cull_position_value_var)
412
{
413
remove_culling_shader_outputs_state s = {
414
.pre_cull_position_value_var = pre_cull_position_value_var,
415
};
416
417
nir_shader_instructions_pass(culling_shader, remove_culling_shader_output,
418
nir_metadata_block_index | nir_metadata_dominance, &s);
419
420
/* Remove dead code resulting from the deleted outputs. */
421
bool progress;
422
do {
423
progress = false;
424
NIR_PASS(progress, culling_shader, nir_opt_dead_write_vars);
425
NIR_PASS(progress, culling_shader, nir_opt_dce);
426
NIR_PASS(progress, culling_shader, nir_opt_dead_cf);
427
} while (progress);
428
}
429
430
static void
431
rewrite_uses_to_var(nir_builder *b, nir_ssa_def *old_def, nir_variable *replacement_var, unsigned replacement_var_channel)
432
{
433
if (old_def->parent_instr->type == nir_instr_type_load_const)
434
return;
435
436
b->cursor = nir_after_instr(old_def->parent_instr);
437
if (b->cursor.instr->type == nir_instr_type_phi)
438
b->cursor = nir_after_phis(old_def->parent_instr->block);
439
440
nir_ssa_def *pos_val_rep = nir_load_var(b, replacement_var);
441
nir_ssa_def *replacement = nir_channel(b, pos_val_rep, replacement_var_channel);
442
443
if (old_def->num_components > 1) {
444
/* old_def uses a swizzled vector component.
445
* There is no way to replace the uses of just a single vector component,
446
* so instead create a new vector and replace all uses of the old vector.
447
*/
448
nir_ssa_def *old_def_elements[NIR_MAX_VEC_COMPONENTS] = {0};
449
for (unsigned j = 0; j < old_def->num_components; ++j)
450
old_def_elements[j] = nir_channel(b, old_def, j);
451
replacement = nir_vec(b, old_def_elements, old_def->num_components);
452
}
453
454
nir_ssa_def_rewrite_uses_after(old_def, replacement, replacement->parent_instr);
455
}
456
457
static bool
458
remove_extra_pos_output(nir_builder *b, nir_instr *instr, void *state)
459
{
460
remove_extra_position_output_state *s = (remove_extra_position_output_state *) state;
461
462
if (instr->type != nir_instr_type_intrinsic)
463
return false;
464
465
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
466
467
/* These are not allowed in VS / TES */
468
assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
469
intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
470
471
/* We are only interested in output stores now */
472
if (intrin->intrinsic != nir_intrinsic_store_output)
473
return false;
474
475
nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
476
if (io_sem.location != VARYING_SLOT_POS)
477
return false;
478
479
b->cursor = nir_before_instr(instr);
480
481
/* In case other outputs use what we calculated for pos,
482
* try to avoid calculating it again by rewriting the usages
483
* of the store components here.
484
*/
485
nir_ssa_def *store_val = intrin->src[0].ssa;
486
unsigned store_pos_component = nir_intrinsic_component(intrin);
487
488
nir_instr_remove(instr);
489
490
if (store_val->parent_instr->type == nir_instr_type_alu) {
491
nir_alu_instr *alu = nir_instr_as_alu(store_val->parent_instr);
492
if (nir_op_is_vec(alu->op)) {
493
/* Output store uses a vector, we can easily rewrite uses of each vector element. */
494
495
unsigned num_vec_src = 0;
496
if (alu->op == nir_op_mov)
497
num_vec_src = 1;
498
else if (alu->op == nir_op_vec2)
499
num_vec_src = 2;
500
else if (alu->op == nir_op_vec3)
501
num_vec_src = 3;
502
else if (alu->op == nir_op_vec4)
503
num_vec_src = 4;
504
assert(num_vec_src);
505
506
/* Remember the current components whose uses we wish to replace.
507
* This is needed because rewriting one source can affect the others too.
508
*/
509
nir_ssa_def *vec_comps[NIR_MAX_VEC_COMPONENTS] = {0};
510
for (unsigned i = 0; i < num_vec_src; i++)
511
vec_comps[i] = alu->src[i].src.ssa;
512
513
for (unsigned i = 0; i < num_vec_src; i++)
514
rewrite_uses_to_var(b, vec_comps[i], s->pos_value_replacement, store_pos_component + i);
515
} else {
516
rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component);
517
}
518
} else {
519
rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component);
520
}
521
522
return true;
523
}
524
525
static void
526
remove_extra_pos_outputs(nir_shader *shader, lower_ngg_nogs_state *nogs_state)
527
{
528
remove_extra_position_output_state s = {
529
.pos_value_replacement = nogs_state->position_value_var,
530
};
531
532
nir_shader_instructions_pass(shader, remove_extra_pos_output,
533
nir_metadata_block_index | nir_metadata_dominance, &s);
534
}
535
536
/**
537
* Perform vertex compaction after culling.
538
*
539
* 1. Repack surviving ES invocations (this determines which lane will export which vertex)
540
* 2. Surviving ES vertex invocations store their data to LDS
541
* 3. Emit GS_ALLOC_REQ
542
* 4. Repacked invocations load the vertex data from LDS
543
* 5. GS threads update their vertex indices
544
*/
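/* Note: each surviving ES thread writes its outputs to the LDS slot of its
* exporter thread (its repacked index), and stores that exporter index into its
* own slot at lds_es_exporter_tid, which is what allows GS threads to translate
* the old vertex indices into the post-compaction ones.
*/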
545
static void
546
compact_vertices_after_culling(nir_builder *b,
547
lower_ngg_nogs_state *nogs_state,
548
nir_variable *vertices_in_wave_var,
549
nir_variable *primitives_in_wave_var,
550
nir_variable **repacked_arg_vars,
551
nir_variable **gs_vtxaddr_vars,
552
nir_ssa_def *invocation_index,
553
nir_ssa_def *es_vertex_lds_addr,
554
unsigned ngg_scratch_lds_base_addr,
555
unsigned pervertex_lds_bytes,
556
unsigned max_exported_args)
557
{
558
nir_variable *es_accepted_var = nogs_state->es_accepted_var;
559
nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
560
nir_variable *position_value_var = nogs_state->position_value_var;
561
nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;
562
563
nir_ssa_def *es_accepted = nir_load_var(b, es_accepted_var);
564
565
/* Repack the vertices that survived the culling. */
566
wg_repack_result rep = repack_invocations_in_workgroup(b, es_accepted, ngg_scratch_lds_base_addr,
567
nogs_state->max_num_waves, nogs_state->wave_size);
568
nir_ssa_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
569
nir_ssa_def *es_exporter_tid = rep.repacked_invocation_index;
570
571
nir_if *if_es_accepted = nir_push_if(b, es_accepted);
572
{
573
nir_ssa_def *exporter_addr = pervertex_lds_addr(b, es_exporter_tid, pervertex_lds_bytes);
574
575
/* Store the exporter thread's index to the LDS space of the current thread so GS threads can load it */
576
nir_build_store_shared(b, nir_u2u8(b, es_exporter_tid), es_vertex_lds_addr, .base = lds_es_exporter_tid, .align_mul = 1u, .write_mask = 0x1u);
577
578
/* Store the current thread's position output to the exporter thread's LDS space */
579
nir_ssa_def *pos = nir_load_var(b, position_value_var);
580
nir_build_store_shared(b, pos, exporter_addr, .base = lds_es_pos_x, .align_mul = 4u, .write_mask = 0xfu);
581
582
/* Store the current thread's repackable arguments to the exporter thread's LDS space */
583
for (unsigned i = 0; i < max_exported_args; ++i) {
584
nir_ssa_def *arg_val = nir_load_var(b, repacked_arg_vars[i]);
585
nir_build_store_shared(b, arg_val, exporter_addr, .base = lds_es_arg_0 + 4u * i, .align_mul = 4u, .write_mask = 0x1u);
586
}
587
}
588
nir_pop_if(b, if_es_accepted);
589
590
/* If all vertices are culled, set primitive count to 0 as well. */
591
nir_ssa_def *num_exported_prims = nir_build_load_workgroup_num_input_primitives_amd(b);
592
num_exported_prims = nir_bcsel(b, nir_ieq_imm(b, num_live_vertices_in_workgroup, 0u), nir_imm_int(b, 0u), num_exported_prims);
593
594
nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
595
{
596
/* Tell the final vertex and primitive count to the HW.
597
* We do this here to mask some of the latency of the LDS.
598
*/
599
nir_build_alloc_vertices_and_primitives_amd(b, num_live_vertices_in_workgroup, num_exported_prims);
600
}
601
nir_pop_if(b, if_wave_0);
602
603
/* Calculate the number of vertices and primitives left in the current wave */
604
nir_ssa_def *has_vtx_after_culling = nir_ilt(b, invocation_index, num_live_vertices_in_workgroup);
605
nir_ssa_def *has_prm_after_culling = nir_ilt(b, invocation_index, num_exported_prims);
606
nir_ssa_def *vtx_cnt = nir_bit_count(b, nir_build_ballot(b, 1, nogs_state->wave_size, has_vtx_after_culling));
607
nir_ssa_def *prm_cnt = nir_bit_count(b, nir_build_ballot(b, 1, nogs_state->wave_size, has_prm_after_culling));
608
nir_store_var(b, vertices_in_wave_var, vtx_cnt, 0x1u);
609
nir_store_var(b, primitives_in_wave_var, prm_cnt, 0x1u);
610
611
/* TODO: Consider adding a shortcut exit.
612
* Waves that have no vertices and primitives left can s_endpgm right here.
613
*/
614
615
nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
616
.memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
617
618
nir_if *if_packed_es_thread = nir_push_if(b, nir_ilt(b, invocation_index, num_live_vertices_in_workgroup));
619
{
620
/* Read position from the current ES thread's LDS space (written by the exported vertex's ES thread) */
621
nir_ssa_def *exported_pos = nir_build_load_shared(b, 4, 32, es_vertex_lds_addr, .base = lds_es_pos_x, .align_mul = 4u);
622
nir_store_var(b, position_value_var, exported_pos, 0xfu);
623
624
/* Read the repacked arguments */
625
for (unsigned i = 0; i < max_exported_args; ++i) {
626
nir_ssa_def *arg_val = nir_build_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_arg_0 + 4u * i, .align_mul = 4u);
627
nir_store_var(b, repacked_arg_vars[i], arg_val, 0x1u);
628
}
629
}
630
nir_pop_if(b, if_packed_es_thread);
631
632
nir_if *if_gs_accepted = nir_push_if(b, nir_load_var(b, gs_accepted_var));
633
{
634
nir_ssa_def *exporter_vtx_indices[3] = {0};
635
636
/* Load the index of the ES threads that will export the current GS thread's vertices */
637
for (unsigned v = 0; v < 3; ++v) {
638
nir_ssa_def *vtx_addr = nir_load_var(b, gs_vtxaddr_vars[v]);
639
nir_ssa_def *exporter_vtx_idx = nir_build_load_shared(b, 1, 8, vtx_addr, .base = lds_es_exporter_tid, .align_mul = 1u);
640
exporter_vtx_indices[v] = nir_u2u32(b, exporter_vtx_idx);
641
}
642
643
nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, 3, exporter_vtx_indices, NULL);
644
nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
645
}
646
nir_pop_if(b, if_gs_accepted);
647
}
648
649
static void
650
analyze_shader_before_culling_walk(nir_ssa_def *ssa,
651
uint8_t flag,
652
lower_ngg_nogs_state *nogs_state)
653
{
654
nir_instr *instr = ssa->parent_instr;
655
uint8_t old_pass_flags = instr->pass_flags;
656
instr->pass_flags |= flag;
657
658
if (instr->pass_flags == old_pass_flags)
659
return; /* Already visited. */
660
661
switch (instr->type) {
662
case nir_instr_type_intrinsic: {
663
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
664
665
/* VS input loads and SSBO loads are actually VRAM reads on AMD HW. */
666
switch (intrin->intrinsic) {
667
case nir_intrinsic_load_input: {
668
nir_io_semantics in_io_sem = nir_intrinsic_io_semantics(intrin);
669
uint64_t in_mask = UINT64_C(1) << (uint64_t) in_io_sem.location;
670
if (instr->pass_flags & nggc_passflag_used_by_pos)
671
nogs_state->inputs_needed_by_pos |= in_mask;
672
else if (instr->pass_flags & nggc_passflag_used_by_other)
673
nogs_state->inputs_needed_by_others |= in_mask;
674
break;
675
}
676
default:
677
break;
678
}
679
680
break;
681
}
682
case nir_instr_type_alu: {
683
nir_alu_instr *alu = nir_instr_as_alu(instr);
684
unsigned num_srcs = nir_op_infos[alu->op].num_inputs;
685
686
for (unsigned i = 0; i < num_srcs; ++i) {
687
analyze_shader_before_culling_walk(alu->src[i].src.ssa, flag, nogs_state);
688
}
689
690
break;
691
}
692
case nir_instr_type_phi: {
693
nir_phi_instr *phi = nir_instr_as_phi(instr);
694
nir_foreach_phi_src_safe(phi_src, phi) {
695
analyze_shader_before_culling_walk(phi_src->src.ssa, flag, nogs_state);
696
}
697
698
break;
699
}
700
default:
701
break;
702
}
703
}
704
705
static void
706
analyze_shader_before_culling(nir_shader *shader, lower_ngg_nogs_state *nogs_state)
707
{
708
nir_foreach_function(func, shader) {
709
nir_foreach_block(block, func->impl) {
710
nir_foreach_instr(instr, block) {
711
instr->pass_flags = 0;
712
713
if (instr->type != nir_instr_type_intrinsic)
714
continue;
715
716
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
717
if (intrin->intrinsic != nir_intrinsic_store_output)
718
continue;
719
720
nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
721
nir_ssa_def *store_val = intrin->src[0].ssa;
722
uint8_t flag = io_sem.location == VARYING_SLOT_POS ? nggc_passflag_used_by_pos : nggc_passflag_used_by_other;
723
analyze_shader_before_culling_walk(store_val, flag, nogs_state);
724
}
725
}
726
}
727
}
728
729
/**
730
* Save the reusable SSA definitions to variables so that the
731
* bottom shader part can reuse them from the top part.
732
*
733
* 1. We create a new function temporary variable for reusables,
734
* and insert a store+load.
735
* 2. The shader is cloned (the top part is created), then the
736
* control flow is reinserted (for the bottom part.)
737
* 3. For reusables, we delete the variable stores from the
738
* bottom part. This will make them use the variables from
739
* the top part and DCE the redundant instructions.
740
*/
741
static void
742
save_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state)
743
{
744
ASSERTED int vec_ok = u_vector_init(&nogs_state->saved_uniforms, sizeof(saved_uniform), 4 * sizeof(saved_uniform));
745
assert(vec_ok);
746
747
unsigned loop_depth = 0;
748
749
nir_foreach_block_safe(block, b->impl) {
750
/* Check whether we're in a loop. */
751
nir_cf_node *next_cf_node = nir_cf_node_next(&block->cf_node);
752
nir_cf_node *prev_cf_node = nir_cf_node_prev(&block->cf_node);
753
if (next_cf_node && next_cf_node->type == nir_cf_node_loop)
754
loop_depth++;
755
if (prev_cf_node && prev_cf_node->type == nir_cf_node_loop)
756
loop_depth--;
757
758
/* The following code doesn't make sense in loops, so just skip it then. */
759
if (loop_depth)
760
continue;
761
762
nir_foreach_instr_safe(instr, block) {
763
/* Find instructions whose SSA definitions are used by both
764
* the top and bottom parts of the shader. In this case, it
765
* makes sense to try to reuse these from the top part.
766
*/
767
if ((instr->pass_flags & nggc_passflag_used_by_both) != nggc_passflag_used_by_both)
768
continue;
769
770
nir_ssa_def *ssa = NULL;
771
772
switch (instr->type) {
773
case nir_instr_type_alu: {
774
nir_alu_instr *alu = nir_instr_as_alu(instr);
775
if (alu->dest.dest.ssa.divergent)
776
continue;
777
/* Ignore uniform floats because they regress VGPR usage too much */
778
if (nir_op_infos[alu->op].output_type & nir_type_float)
779
continue;
780
ssa = &alu->dest.dest.ssa;
781
break;
782
}
783
case nir_instr_type_intrinsic: {
784
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
785
if (!nir_intrinsic_can_reorder(intrin) ||
786
!nir_intrinsic_infos[intrin->intrinsic].has_dest ||
787
intrin->dest.ssa.divergent)
788
continue;
789
ssa = &intrin->dest.ssa;
790
break;
791
}
792
case nir_instr_type_phi: {
793
nir_phi_instr *phi = nir_instr_as_phi(instr);
794
if (phi->dest.ssa.divergent)
795
continue;
796
ssa = &phi->dest.ssa;
797
break;
798
}
799
default:
800
continue;
801
}
802
803
assert(ssa);
804
805
enum glsl_base_type base_type = GLSL_TYPE_UINT;
806
switch (ssa->bit_size) {
807
case 8: base_type = GLSL_TYPE_UINT8; break;
808
case 16: base_type = GLSL_TYPE_UINT16; break;
809
case 32: base_type = GLSL_TYPE_UINT; break;
810
case 64: base_type = GLSL_TYPE_UINT64; break;
811
default: continue;
812
}
813
814
const struct glsl_type *t = ssa->num_components == 1
815
? glsl_scalar_type(base_type)
816
: glsl_vector_type(base_type, ssa->num_components);
817
818
saved_uniform *saved = (saved_uniform *) u_vector_add(&nogs_state->saved_uniforms);
819
assert(saved);
820
821
saved->var = nir_local_variable_create(b->impl, t, NULL);
822
saved->ssa = ssa;
823
824
b->cursor = instr->type == nir_instr_type_phi
825
? nir_after_instr_and_phis(instr)
826
: nir_after_instr(instr);
827
nir_store_var(b, saved->var, saved->ssa, BITFIELD_MASK(ssa->num_components));
828
nir_ssa_def *reloaded = nir_load_var(b, saved->var);
829
nir_ssa_def_rewrite_uses_after(ssa, reloaded, reloaded->parent_instr);
830
}
831
}
832
}
833
834
/**
835
* Reuses suitable variables from the top part of the shader,
836
* by deleting their stores from the bottom part.
837
*/
838
static void
839
apply_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state)
840
{
841
if (!u_vector_length(&nogs_state->saved_uniforms)) {
842
u_vector_finish(&nogs_state->saved_uniforms);
843
return;
844
}
845
846
nir_foreach_block_reverse_safe(block, b->impl) {
847
nir_foreach_instr_reverse_safe(instr, block) {
848
if (instr->type != nir_instr_type_intrinsic)
849
continue;
850
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
851
852
/* When we found any of these intrinsics, it means
853
* we reached the top part and we must stop.
854
*/
855
if (intrin->intrinsic == nir_intrinsic_overwrite_subgroup_num_vertices_and_primitives_amd ||
856
intrin->intrinsic == nir_intrinsic_alloc_vertices_and_primitives_amd ||
857
intrin->intrinsic == nir_intrinsic_export_primitive_amd)
858
goto done;
859
860
if (intrin->intrinsic != nir_intrinsic_store_deref)
861
continue;
862
nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
863
if (deref->deref_type != nir_deref_type_var)
864
continue;
865
866
saved_uniform *saved;
867
u_vector_foreach(saved, &nogs_state->saved_uniforms) {
868
if (saved->var == deref->var) {
869
nir_instr_remove(instr);
870
}
871
}
872
}
873
}
874
875
done:
876
u_vector_finish(&nogs_state->saved_uniforms);
877
}
878
879
static void
880
add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_cf, lower_ngg_nogs_state *nogs_state)
881
{
882
assert(b->shader->info.outputs_written & (1 << VARYING_SLOT_POS));
883
884
bool uses_instance_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_INSTANCE_ID);
885
bool uses_tess_primitive_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID);
886
887
unsigned max_exported_args = b->shader->info.stage == MESA_SHADER_VERTEX ? 2 : 4;
888
if (b->shader->info.stage == MESA_SHADER_VERTEX && !uses_instance_id)
889
max_exported_args--;
890
else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL && !uses_tess_primitive_id)
891
max_exported_args--;
892
893
unsigned pervertex_lds_bytes = lds_es_arg_0 + max_exported_args * 4u;
894
unsigned total_es_lds_bytes = pervertex_lds_bytes * nogs_state->max_es_num_vertices;
895
unsigned max_num_waves = nogs_state->max_num_waves;
896
unsigned ngg_scratch_lds_base_addr = ALIGN(total_es_lds_bytes, 8u);
897
unsigned ngg_scratch_lds_bytes = DIV_ROUND_UP(max_num_waves, 4u) * 4u;
898
nogs_state->total_lds_bytes = ngg_scratch_lds_base_addr + ngg_scratch_lds_bytes;
899
900
nir_function_impl *impl = nir_shader_get_entrypoint(b->shader);
901
902
/* Create some helper variables. */
903
nir_variable *position_value_var = nogs_state->position_value_var;
904
nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;
905
nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
906
nir_variable *es_accepted_var = nogs_state->es_accepted_var;
907
nir_variable *vertices_in_wave_var = nir_local_variable_create(impl, glsl_uint_type(), "vertices_in_wave");
908
nir_variable *primitives_in_wave_var = nir_local_variable_create(impl, glsl_uint_type(), "primitives_in_wave");
909
nir_variable *gs_vtxaddr_vars[3] = {
910
nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx0_addr"),
911
nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx1_addr"),
912
nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx2_addr"),
913
};
914
nir_variable *repacked_arg_vars[4] = {
915
nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_0"),
916
nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_1"),
917
nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_2"),
918
nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_3"),
919
};
920
921
/* Top part of the culling shader (aka. position shader part)
922
*
923
* We clone the full ES shader and emit it here, but we only really care
924
* about its position output, so we delete every other output from this part.
925
* The position output is stored into a temporary variable, and reloaded later.
926
*/
927
928
b->cursor = nir_before_cf_list(&impl->body);
929
930
nir_if *if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
931
{
932
/* Initialize the position output variable in case not all VS/TES invocations store the output.
933
* The spec doesn't require it, but we use (0, 0, 0, 1) because some games rely on that.
934
*/
935
nir_store_var(b, position_value_var, nir_imm_vec4(b, 0.0f, 0.0f, 0.0f, 1.0f), 0xfu);
936
937
/* Now reinsert a clone of the shader code */
938
struct hash_table *remap_table = _mesa_pointer_hash_table_create(NULL);
939
nir_cf_list_clone_and_reinsert(original_extracted_cf, &if_es_thread->cf_node, b->cursor, remap_table);
940
_mesa_hash_table_destroy(remap_table, NULL);
941
b->cursor = nir_after_cf_list(&if_es_thread->then_list);
942
943
/* Remember the current thread's shader arguments */
944
if (b->shader->info.stage == MESA_SHADER_VERTEX) {
945
nir_store_var(b, repacked_arg_vars[0], nir_build_load_vertex_id_zero_base(b), 0x1u);
946
if (uses_instance_id)
947
nir_store_var(b, repacked_arg_vars[1], nir_build_load_instance_id(b), 0x1u);
948
} else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
949
nir_ssa_def *tess_coord = nir_build_load_tess_coord(b);
950
nir_store_var(b, repacked_arg_vars[0], nir_channel(b, tess_coord, 0), 0x1u);
951
nir_store_var(b, repacked_arg_vars[1], nir_channel(b, tess_coord, 1), 0x1u);
952
nir_store_var(b, repacked_arg_vars[2], nir_build_load_tess_rel_patch_id_amd(b), 0x1u);
953
if (uses_tess_primitive_id)
954
nir_store_var(b, repacked_arg_vars[3], nir_build_load_primitive_id(b), 0x1u);
955
} else {
956
unreachable("Should be VS or TES.");
957
}
958
}
959
nir_pop_if(b, if_es_thread);
960
961
/* Remove all non-position outputs, and put the position output into the variable. */
962
nir_metadata_preserve(impl, nir_metadata_none);
963
remove_culling_shader_outputs(b->shader, nogs_state, position_value_var);
964
b->cursor = nir_after_cf_list(&impl->body);
965
966
/* Run culling algorithms if culling is enabled.
967
*
968
* NGG culling can be enabled or disabled in runtime.
969
* This is determined by an SGPR shader argument which is accessed
970
* by the following NIR intrinsic.
971
*/
972
973
nir_if *if_cull_en = nir_push_if(b, nir_build_load_cull_any_enabled_amd(b));
974
{
975
nir_ssa_def *invocation_index = nir_build_load_local_invocation_index(b);
976
nir_ssa_def *es_vertex_lds_addr = pervertex_lds_addr(b, invocation_index, pervertex_lds_bytes);
977
978
/* ES invocations store their vertex data to LDS for GS threads to read. */
979
if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
980
{
981
/* Store position components that are relevant to culling in LDS */
982
nir_ssa_def *pre_cull_pos = nir_load_var(b, position_value_var);
983
nir_ssa_def *pre_cull_w = nir_channel(b, pre_cull_pos, 3);
984
nir_build_store_shared(b, pre_cull_w, es_vertex_lds_addr, .write_mask = 0x1u, .align_mul = 4, .base = lds_es_pos_w);
985
nir_ssa_def *pre_cull_x_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 0), pre_cull_w);
986
nir_ssa_def *pre_cull_y_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 1), pre_cull_w);
987
nir_build_store_shared(b, nir_vec2(b, pre_cull_x_div_w, pre_cull_y_div_w), es_vertex_lds_addr, .write_mask = 0x3u, .align_mul = 4, .base = lds_es_pos_x);
988
989
/* Clear out the ES accepted flag in LDS */
990
nir_build_store_shared(b, nir_imm_zero(b, 1, 8), es_vertex_lds_addr, .write_mask = 0x1u, .align_mul = 4, .base = lds_es_vertex_accepted);
991
}
992
nir_pop_if(b, if_es_thread);
993
994
nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
995
.memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
996
997
nir_store_var(b, gs_accepted_var, nir_imm_bool(b, false), 0x1u);
998
nir_store_var(b, prim_exp_arg_var, nir_imm_int(b, 1 << 31), 0x1u);
999
1000
/* GS invocations load the vertex data and perform the culling. */
1001
nir_if *if_gs_thread = nir_push_if(b, nir_build_has_input_primitive_amd(b));
1002
{
1003
/* Load vertex indices from input VGPRs */
1004
nir_ssa_def *vtx_idx[3] = {0};
1005
for (unsigned vertex = 0; vertex < 3; ++vertex)
1006
vtx_idx[vertex] = ngg_input_primitive_vertex_index(b, vertex);
1007
1008
nir_ssa_def *vtx_addr[3] = {0};
1009
nir_ssa_def *pos[3][4] = {0};
1010
1011
/* Load W positions of vertices first because the culling code will use these first */
1012
for (unsigned vtx = 0; vtx < 3; ++vtx) {
1013
vtx_addr[vtx] = pervertex_lds_addr(b, vtx_idx[vtx], pervertex_lds_bytes);
1014
pos[vtx][3] = nir_build_load_shared(b, 1, 32, vtx_addr[vtx], .align_mul = 4u, .base = lds_es_pos_w);
1015
nir_store_var(b, gs_vtxaddr_vars[vtx], vtx_addr[vtx], 0x1u);
1016
}
1017
1018
/* Load the X/W, Y/W positions of vertices */
1019
for (unsigned vtx = 0; vtx < 3; ++vtx) {
1020
nir_ssa_def *xy = nir_build_load_shared(b, 2, 32, vtx_addr[vtx], .align_mul = 4u, .base = lds_es_pos_x);
1021
pos[vtx][0] = nir_channel(b, xy, 0);
1022
pos[vtx][1] = nir_channel(b, xy, 1);
1023
}
1024
1025
/* See if the current primitive is accepted */
1026
nir_ssa_def *accepted = ac_nir_cull_triangle(b, nir_imm_bool(b, true), pos);
1027
nir_store_var(b, gs_accepted_var, accepted, 0x1u);
1028
1029
nir_if *if_gs_accepted = nir_push_if(b, accepted);
1030
{
1031
/* Store the accepted state to LDS for ES threads */
1032
for (unsigned vtx = 0; vtx < 3; ++vtx)
1033
nir_build_store_shared(b, nir_imm_intN_t(b, 0xff, 8), vtx_addr[vtx], .base = lds_es_vertex_accepted, .align_mul = 4u, .write_mask = 0x1u);
1034
}
1035
nir_pop_if(b, if_gs_accepted);
1036
}
1037
nir_pop_if(b, if_gs_thread);
1038
1039
nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
1040
.memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1041
1042
nir_store_var(b, es_accepted_var, nir_imm_bool(b, false), 0x1u);
1043
1044
/* ES invocations load their accepted flag from LDS. */
1045
if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
1046
{
1047
nir_ssa_def *accepted = nir_build_load_shared(b, 1, 8u, es_vertex_lds_addr, .base = lds_es_vertex_accepted, .align_mul = 4u);
1048
nir_ssa_def *accepted_bool = nir_ine(b, accepted, nir_imm_intN_t(b, 0, 8));
1049
nir_store_var(b, es_accepted_var, accepted_bool, 0x1u);
1050
}
1051
nir_pop_if(b, if_es_thread);
1052
1053
/* Vertex compaction. */
1054
compact_vertices_after_culling(b, nogs_state,
1055
vertices_in_wave_var, primitives_in_wave_var,
1056
repacked_arg_vars, gs_vtxaddr_vars,
1057
invocation_index, es_vertex_lds_addr,
1058
ngg_scratch_lds_base_addr, pervertex_lds_bytes, max_exported_args);
1059
}
1060
nir_push_else(b, if_cull_en);
1061
{
1062
/* When culling is disabled, we do the same as we would without culling. */
1063
nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
1064
{
1065
nir_ssa_def *vtx_cnt = nir_build_load_workgroup_num_input_vertices_amd(b);
1066
nir_ssa_def *prim_cnt = nir_build_load_workgroup_num_input_primitives_amd(b);
1067
nir_build_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt);
1068
}
1069
nir_pop_if(b, if_wave_0);
1070
nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, nogs_state), 0x1u);
1071
1072
nir_ssa_def *vtx_cnt = nir_bit_count(b, nir_build_ballot(b, 1, nogs_state->wave_size, nir_build_has_input_vertex_amd(b)));
1073
nir_ssa_def *prm_cnt = nir_bit_count(b, nir_build_ballot(b, 1, nogs_state->wave_size, nir_build_has_input_primitive_amd(b)));
1074
nir_store_var(b, vertices_in_wave_var, vtx_cnt, 0x1u);
1075
nir_store_var(b, primitives_in_wave_var, prm_cnt, 0x1u);
1076
}
1077
nir_pop_if(b, if_cull_en);
1078
1079
/* Update shader arguments.
1080
*
1081
* The registers which hold information about the subgroup's
1082
* vertices and primitives are updated here, so the rest of the shader
1083
* doesn't need to worry about the culling.
1084
*
1085
* These "overwrite" intrinsics must be at top level control flow,
1086
* otherwise they can mess up the backend (eg. ACO's SSA).
1087
*
1088
* TODO:
1089
* A cleaner solution would be to simply replace all usages of these args
1090
* with the load of the variables.
1091
* However, this wouldn't work right now because the backend uses the arguments
1092
* for purposes not expressed in NIR, eg. VS input loads, etc.
1093
* This can change if VS input loads and other stuff are lowered to eg. load_buffer_amd.
1094
*/
1095
1096
if (b->shader->info.stage == MESA_SHADER_VERTEX)
1097
nir_build_overwrite_vs_arguments_amd(b,
1098
nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]));
1099
else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL)
1100
nir_build_overwrite_tes_arguments_amd(b,
1101
nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]),
1102
nir_load_var(b, repacked_arg_vars[2]), nir_load_var(b, repacked_arg_vars[3]));
1103
else
1104
unreachable("Should be VS or TES.");
1105
1106
nir_ssa_def *vertices_in_wave = nir_load_var(b, vertices_in_wave_var);
1107
nir_ssa_def *primitives_in_wave = nir_load_var(b, primitives_in_wave_var);
1108
nir_build_overwrite_subgroup_num_vertices_and_primitives_amd(b, vertices_in_wave, primitives_in_wave);
1109
}
1110
1111
static bool
1112
can_use_deferred_attribute_culling(nir_shader *shader)
1113
{
1114
/* When the shader writes memory, it is difficult to guarantee correctness.
1115
* Future work:
1116
* - if only write-only SSBOs are used
1117
* - if we can prove that non-position outputs don't rely on memory stores
1118
* then it may be okay to keep the memory stores in the 1st shader part, and delete them from the 2nd.
1119
*/
1120
if (shader->info.writes_memory)
1121
return false;
1122
1123
/* When the shader relies on the subgroup invocation ID, we'd break it, because the ID changes after the culling.
1124
* Future work: try to save this to LDS and reload, but it can still be broken in subtle ways.
1125
*/
1126
if (BITSET_TEST(shader->info.system_values_read, SYSTEM_VALUE_SUBGROUP_INVOCATION))
1127
return false;
1128
1129
return true;
1130
}
1131
1132
ac_nir_ngg_config
1133
ac_nir_lower_ngg_nogs(nir_shader *shader,
1134
unsigned max_num_es_vertices,
1135
unsigned num_vertices_per_primitives,
1136
unsigned max_workgroup_size,
1137
unsigned wave_size,
1138
bool consider_culling,
1139
bool consider_passthrough,
1140
bool export_prim_id,
1141
bool provoking_vtx_last)
1142
{
1143
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
1144
assert(impl);
1145
assert(max_num_es_vertices && max_workgroup_size && wave_size);
1146
1147
bool can_cull = consider_culling && (num_vertices_per_primitives == 3) &&
1148
can_use_deferred_attribute_culling(shader);
1149
bool passthrough = consider_passthrough && !can_cull &&
1150
!(shader->info.stage == MESA_SHADER_VERTEX && export_prim_id);
1151
1152
nir_variable *position_value_var = nir_local_variable_create(impl, glsl_vec4_type(), "position_value");
1153
nir_variable *prim_exp_arg_var = nir_local_variable_create(impl, glsl_uint_type(), "prim_exp_arg");
1154
nir_variable *es_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
1155
nir_variable *gs_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
1156
1157
lower_ngg_nogs_state state = {
1158
.passthrough = passthrough,
1159
.export_prim_id = export_prim_id,
1160
.early_prim_export = exec_list_is_singular(&impl->body),
1161
.num_vertices_per_primitives = num_vertices_per_primitives,
1162
.provoking_vtx_idx = provoking_vtx_last ? (num_vertices_per_primitives - 1) : 0,
1163
.position_value_var = position_value_var,
1164
.prim_exp_arg_var = prim_exp_arg_var,
1165
.es_accepted_var = es_accepted_var,
1166
.gs_accepted_var = gs_accepted_var,
1167
.max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
1168
.max_es_num_vertices = max_num_es_vertices,
1169
.wave_size = wave_size,
1170
};
1171
1172
/* We need LDS space when VS needs to export the primitive ID. */
1173
if (shader->info.stage == MESA_SHADER_VERTEX && export_prim_id)
1174
state.total_lds_bytes = max_num_es_vertices * 4u;
1175
1176
/* The shader only needs this much LDS when culling is turned off. */
1177
unsigned lds_bytes_if_culling_off = state.total_lds_bytes;
1178
1179
nir_builder builder;
1180
nir_builder *b = &builder; /* This is to avoid the & */
1181
nir_builder_init(b, impl);
1182
1183
if (can_cull) {
1184
/* We need divergence info for culling shaders. */
1185
nir_divergence_analysis(shader);
1186
analyze_shader_before_culling(shader, &state);
1187
save_reusable_variables(b, &state);
1188
}
1189
1190
nir_cf_list extracted;
1191
nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
1192
b->cursor = nir_before_cf_list(&impl->body);
1193
1194
if (!can_cull) {
1195
/* Allocate export space on wave 0 - confirm to the HW that we want to use all possible space */
1196
nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
1197
{
1198
nir_ssa_def *vtx_cnt = nir_build_load_workgroup_num_input_vertices_amd(b);
1199
nir_ssa_def *prim_cnt = nir_build_load_workgroup_num_input_primitives_amd(b);
1200
nir_build_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt);
1201
}
1202
nir_pop_if(b, if_wave_0);
1203
1204
/* Take care of early primitive export, otherwise just pack the primitive export argument */
1205
if (state.early_prim_export)
1206
emit_ngg_nogs_prim_export(b, &state, NULL);
1207
else
1208
nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, &state), 0x1u);
1209
} else {
1210
add_deferred_attribute_culling(b, &extracted, &state);
1211
b->cursor = nir_after_cf_list(&impl->body);
1212
1213
if (state.early_prim_export)
1214
emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, state.prim_exp_arg_var));
1215
}
1216
1217
nir_intrinsic_instr *export_vertex_instr;
1218
1219
nir_if *if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
1220
{
1221
/* Run the actual shader */
1222
nir_cf_reinsert(&extracted, b->cursor);
1223
b->cursor = nir_after_cf_list(&if_es_thread->then_list);
1224
1225
/* Export all vertex attributes (except primitive ID) */
1226
export_vertex_instr = nir_build_export_vertex_amd(b);
1227
1228
/* Export primitive ID (in case of early primitive export or TES) */
1229
if (state.export_prim_id && (state.early_prim_export || shader->info.stage != MESA_SHADER_VERTEX))
1230
emit_store_ngg_nogs_es_primitive_id(b);
1231
}
1232
nir_pop_if(b, if_es_thread);
1233
1234
/* Take care of late primitive export */
1235
if (!state.early_prim_export) {
1236
emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var));
1237
if (state.export_prim_id && shader->info.stage == MESA_SHADER_VERTEX) {
1238
if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
1239
emit_store_ngg_nogs_es_primitive_id(b);
1240
nir_pop_if(b, if_es_thread);
1241
}
1242
}
1243
1244
if (can_cull) {
1245
/* Replace uniforms. */
1246
apply_reusable_variables(b, &state);
1247
1248
/* Remove the redundant position output. */
1249
remove_extra_pos_outputs(shader, &state);
1250
1251
* After looking at the performance in apps, e.g. Doom Eternal and The Witcher 3,
1252
* it seems that it's best to put the position export always at the end, and
1253
* then let ACO schedule it up (slightly) only when early prim export is used.
1254
*/
1255
b->cursor = nir_before_instr(&export_vertex_instr->instr);
1256
1257
nir_ssa_def *pos_val = nir_load_var(b, state.position_value_var);
1258
nir_io_semantics io_sem = { .location = VARYING_SLOT_POS, .num_slots = 1 };
1259
nir_build_store_output(b, pos_val, nir_imm_int(b, 0), .base = VARYING_SLOT_POS, .component = 0, .io_semantics = io_sem, .write_mask = 0xfu);
1260
}
1261
1262
nir_metadata_preserve(impl, nir_metadata_none);
1263
nir_validate_shader(shader, "after emitting NGG VS/TES");
1264
1265
/* Cleanup */
1266
nir_opt_dead_write_vars(shader);
1267
nir_lower_vars_to_ssa(shader);
1268
nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
1269
nir_lower_alu_to_scalar(shader, NULL, NULL);
1270
nir_lower_phis_to_scalar(shader, true);
1271
1272
if (can_cull) {
1273
/* It's beneficial to redo these opts after splitting the shader. */
1274
nir_opt_sink(shader, nir_move_load_input | nir_move_const_undef | nir_move_copies);
1275
nir_opt_move(shader, nir_move_load_input | nir_move_copies | nir_move_const_undef);
1276
}
1277
1278
bool progress;
1279
do {
1280
progress = false;
1281
NIR_PASS(progress, shader, nir_opt_undef);
1282
NIR_PASS(progress, shader, nir_opt_cse);
1283
NIR_PASS(progress, shader, nir_opt_dce);
1284
NIR_PASS(progress, shader, nir_opt_dead_cf);
1285
} while (progress);
1286
1287
shader->info.shared_size = state.total_lds_bytes;
1288
1289
ac_nir_ngg_config ret = {
1290
.lds_bytes_if_culling_off = lds_bytes_if_culling_off,
1291
.can_cull = can_cull,
1292
.passthrough = passthrough,
1293
.early_prim_export = state.early_prim_export,
1294
.nggc_inputs_read_by_pos = state.inputs_needed_by_pos,
1295
.nggc_inputs_read_by_others = state.inputs_needed_by_others,
1296
};
1297
1298
return ret;
1299
}
1300
1301
static nir_ssa_def *
1302
ngg_gs_out_vertex_addr(nir_builder *b, nir_ssa_def *out_vtx_idx, lower_ngg_gs_state *s)
1303
{
1304
unsigned write_stride_2exp = ffs(MAX2(b->shader->info.gs.vertices_out, 1)) - 1;
1305
1306
/* gs_max_out_vertices = 2^(write_stride_2exp) * some odd number */
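/* Example: for gs.vertices_out = 4 the write stride exponent is 2, so the index is
* XORed with ((idx >> 5) & 3). Consecutive groups of 32 output vertices therefore
* start at different offsets, which helps avoid LDS bank conflicts when waves write
* with a power-of-two stride.
*/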
1307
if (write_stride_2exp) {
1308
nir_ssa_def *row = nir_ushr_imm(b, out_vtx_idx, 5);
1309
nir_ssa_def *swizzle = nir_iand_imm(b, row, (1u << write_stride_2exp) - 1u);
1310
out_vtx_idx = nir_ixor(b, out_vtx_idx, swizzle);
1311
}
1312
1313
nir_ssa_def *out_vtx_offs = nir_imul_imm(b, out_vtx_idx, s->lds_bytes_per_gs_out_vertex);
1314
return nir_iadd_imm_nuw(b, out_vtx_offs, s->lds_addr_gs_out_vtx);
1315
}
1316
1317
static nir_ssa_def *
1318
ngg_gs_emit_vertex_addr(nir_builder *b, nir_ssa_def *gs_vtx_idx, lower_ngg_gs_state *s)
1319
{
1320
nir_ssa_def *tid_in_tg = nir_build_load_local_invocation_index(b);
1321
nir_ssa_def *gs_out_vtx_base = nir_imul_imm(b, tid_in_tg, b->shader->info.gs.vertices_out);
1322
nir_ssa_def *out_vtx_idx = nir_iadd_nuw(b, gs_out_vtx_base, gs_vtx_idx);
1323
1324
return ngg_gs_out_vertex_addr(b, out_vtx_idx, s);
1325
}
1326
1327
static void
1328
ngg_gs_clear_primflags(nir_builder *b, nir_ssa_def *num_vertices, unsigned stream, lower_ngg_gs_state *s)
1329
{
1330
nir_ssa_def *zero_u8 = nir_imm_zero(b, 1, 8);
1331
nir_store_var(b, s->current_clear_primflag_idx_var, num_vertices, 0x1u);
1332
1333
nir_loop *loop = nir_push_loop(b);
1334
{
1335
nir_ssa_def *current_clear_primflag_idx = nir_load_var(b, s->current_clear_primflag_idx_var);
1336
nir_if *if_break = nir_push_if(b, nir_uge(b, current_clear_primflag_idx, nir_imm_int(b, b->shader->info.gs.vertices_out)));
1337
{
1338
nir_jump(b, nir_jump_break);
1339
}
1340
nir_push_else(b, if_break);
1341
{
1342
nir_ssa_def *emit_vtx_addr = ngg_gs_emit_vertex_addr(b, current_clear_primflag_idx, s);
1343
nir_build_store_shared(b, zero_u8, emit_vtx_addr, .base = s->lds_offs_primflags + stream, .align_mul = 1, .write_mask = 0x1u);
1344
nir_store_var(b, s->current_clear_primflag_idx_var, nir_iadd_imm_nuw(b, current_clear_primflag_idx, 1), 0x1u);
1345
}
1346
nir_pop_if(b, if_break);
1347
}
1348
nir_pop_loop(b, loop);
1349
}
1350
1351
static void
1352
ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
1353
{
1354
nir_if *if_shader_query = nir_push_if(b, nir_build_load_shader_query_enabled_amd(b));
1355
nir_ssa_def *num_prims_in_wave = NULL;
1356
1357
/* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives.
1358
* GS emits points, line strips or triangle strips.
1359
* Real primitives are points, lines or triangles.
1360
*/
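/* For example, a GS that emits one triangle strip of 5 vertices reports
* gs_vtx_cnt = 5 and gs_prm_cnt = 1, which yields 5 - 1 * (3 - 1) = 3 triangles.
*/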
1361
if (nir_src_is_const(intrin->src[0]) && nir_src_is_const(intrin->src[1])) {
1362
unsigned gs_vtx_cnt = nir_src_as_uint(intrin->src[0]);
1363
unsigned gs_prm_cnt = nir_src_as_uint(intrin->src[1]);
1364
unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u);
1365
nir_ssa_def *num_threads = nir_bit_count(b, nir_build_ballot(b, 1, s->wave_size, nir_imm_bool(b, true)));
1366
num_prims_in_wave = nir_imul_imm(b, num_threads, total_prm_cnt);
1367
} else {
1368
nir_ssa_def *gs_vtx_cnt = intrin->src[0].ssa;
1369
nir_ssa_def *prm_cnt = intrin->src[1].ssa;
1370
if (s->num_vertices_per_primitive > 1)
1371
prm_cnt = nir_iadd_nuw(b, nir_imul_imm(b, prm_cnt, -1u * (s->num_vertices_per_primitive - 1)), gs_vtx_cnt);
1372
num_prims_in_wave = nir_build_reduce(b, prm_cnt, .reduction_op = nir_op_iadd);
1373
}
1374
1375
/* Store the query result to GDS using an atomic add. */
1376
nir_if *if_first_lane = nir_push_if(b, nir_build_elect(b, 1));
1377
nir_build_gds_atomic_add_amd(b, 32, num_prims_in_wave, nir_imm_int(b, 0), nir_imm_int(b, 0x100));
1378
nir_pop_if(b, if_first_lane);
1379
1380
nir_pop_if(b, if_shader_query);
1381
}
1382
1383
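/* Lowers store_output in the GS.
 * Instead of writing outputs directly, each written component is saved into a temporary
 * variable (one per slot/component) and its bit size and stream are recorded, so that
 * emit_vertex can later flush the whole vertex to LDS in one place. 64-bit components
 * are split into two 32-bit components here; smaller ones are zero-extended to 32 bits.
 */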
static bool
lower_ngg_gs_store_output(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
{
   assert(nir_src_is_const(intrin->src[1]));
   b->cursor = nir_before_instr(&intrin->instr);

   unsigned writemask = nir_intrinsic_write_mask(intrin);
   unsigned base = nir_intrinsic_base(intrin);
   unsigned component_offset = nir_intrinsic_component(intrin);
   unsigned base_offset = nir_src_as_uint(intrin->src[1]);
   nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);

   assert((base + base_offset) < VARYING_SLOT_MAX);

   nir_ssa_def *store_val = intrin->src[0].ssa;

   for (unsigned comp = 0; comp < 4; ++comp) {
      if (!(writemask & (1 << comp)))
         continue;
      unsigned stream = (io_sem.gs_streams >> (comp * 2)) & 0x3;
      if (!(b->shader->info.gs.active_stream_mask & (1 << stream)))
         continue;

      /* Small bitsize components consume the same amount of space as 32-bit components,
       * but 64-bit ones consume twice as many. (Vulkan spec 15.1.5)
       */
      unsigned num_consumed_components = DIV_ROUND_UP(store_val->bit_size, 32);
      nir_ssa_def *element = nir_channel(b, store_val, comp);
      if (num_consumed_components > 1)
         element = nir_extract_bits(b, &element, 1, 0, num_consumed_components, 32);

      for (unsigned c = 0; c < num_consumed_components; ++c) {
         unsigned component_index = (comp * num_consumed_components) + c + component_offset;
         unsigned base_index = base + base_offset + component_index / 4;
         component_index %= 4;

         /* Save output usage info */
         gs_output_component_info *info = &s->output_component_info[base_index][component_index];
         info->bit_size = MAX2(info->bit_size, MIN2(store_val->bit_size, 32));
         info->stream = stream;

         /* Store the current component element */
         nir_ssa_def *component_element = element;
         if (num_consumed_components > 1)
            component_element = nir_channel(b, component_element, c);
         if (component_element->bit_size != 32)
            component_element = nir_u2u32(b, component_element);

         nir_store_var(b, s->output_vars[base_index][component_index], component_element, 0x1u);
      }
   }

   nir_instr_remove(&intrin->instr);
   return true;
}

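/* Lowers emit_vertex_with_counter.
 * Flushes the saved output components of the current vertex to its LDS slot
 * (at packed_location * 16 + component * 4 within the slot) and writes the per-stream
 * primitive flag byte at lds_offs_primflags + stream. For a triangle strip, e.g. the
 * 3rd emitted vertex (counter = 2) completes an even triangle, so its flag would be
 * 0b101: bit 0 = completes a primitive, bit 1 = odd parity (0 here), bit 2 = vertex live.
 */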
static bool
lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
{
   b->cursor = nir_before_instr(&intrin->instr);

   unsigned stream = nir_intrinsic_stream_id(intrin);
   if (!(b->shader->info.gs.active_stream_mask & (1 << stream))) {
      nir_instr_remove(&intrin->instr);
      return true;
   }

   nir_ssa_def *gs_emit_vtx_idx = intrin->src[0].ssa;
   nir_ssa_def *current_vtx_per_prim = intrin->src[1].ssa;
   nir_ssa_def *gs_emit_vtx_addr = ngg_gs_emit_vertex_addr(b, gs_emit_vtx_idx, s);

   for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
      unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));

      for (unsigned comp = 0; comp < 4; ++comp) {
         gs_output_component_info *info = &s->output_component_info[slot][comp];
         if (info->stream != stream || !info->bit_size)
            continue;

         /* Store the output to LDS */
         nir_ssa_def *out_val = nir_load_var(b, s->output_vars[slot][comp]);
         if (info->bit_size != 32)
            out_val = nir_u2u(b, out_val, info->bit_size);

         nir_build_store_shared(b, out_val, gs_emit_vtx_addr, .base = packed_location * 16 + comp * 4, .align_mul = 4, .write_mask = 0x1u);

         /* Clear the variable that holds the output */
         nir_store_var(b, s->output_vars[slot][comp], nir_ssa_undef(b, 1, 32), 0x1u);
      }
   }

   /* Calculate and store per-vertex primitive flags based on vertex counts:
    * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
    * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
    * - bit 2: always 1 (so that we can use it for determining vertex liveness)
    */

   nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1));
   nir_ssa_def *prim_flag = nir_bcsel(b, completes_prim, nir_imm_int(b, 0b101u), nir_imm_int(b, 0b100u));

   if (s->num_vertices_per_primitive == 3) {
      nir_ssa_def *odd = nir_iand_imm(b, current_vtx_per_prim, 1);
      prim_flag = nir_iadd_nuw(b, prim_flag, nir_ishl(b, odd, nir_imm_int(b, 1)));
   }

   nir_build_store_shared(b, nir_u2u8(b, prim_flag), gs_emit_vtx_addr, .base = s->lds_offs_primflags + stream, .align_mul = 4u, .write_mask = 0x1u);
   nir_instr_remove(&intrin->instr);
   return true;
}

static bool
lower_ngg_gs_end_primitive_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, UNUSED lower_ngg_gs_state *s)
{
   b->cursor = nir_before_instr(&intrin->instr);

   /* These are not needed, we can simply remove them */
   nir_instr_remove(&intrin->instr);
   return true;
}

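/* Lowers set_vertex_and_primitive_count.
 * Marks that the vertex count intrinsic was seen for this stream, clears the primitive
 * flags of any vertex slots that were not emitted (unless the shader provably emits the
 * maximum), and emits the shader query code before removing the intrinsic.
 */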
static bool
lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
{
   b->cursor = nir_before_instr(&intrin->instr);

   unsigned stream = nir_intrinsic_stream_id(intrin);
   if (stream > 0 && !(b->shader->info.gs.active_stream_mask & (1 << stream))) {
      nir_instr_remove(&intrin->instr);
      return true;
   }

   s->found_out_vtxcnt[stream] = true;

   /* Clear the primitive flags of non-emitted vertices */
   if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out)
      ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s);

   ngg_gs_shader_query(b, intrin, s);
   nir_instr_remove(&intrin->instr);
   return true;
}

static bool
lower_ngg_gs_intrinsic(nir_builder *b, nir_instr *instr, void *state)
{
   lower_ngg_gs_state *s = (lower_ngg_gs_state *) state;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   if (intrin->intrinsic == nir_intrinsic_store_output)
      return lower_ngg_gs_store_output(b, intrin, s);
   else if (intrin->intrinsic == nir_intrinsic_emit_vertex_with_counter)
      return lower_ngg_gs_emit_vertex_with_counter(b, intrin, s);
   else if (intrin->intrinsic == nir_intrinsic_end_primitive_with_counter)
      return lower_ngg_gs_end_primitive_with_counter(b, intrin, s);
   else if (intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count)
      return lower_ngg_gs_set_vertex_and_primitive_count(b, intrin, s);

   return false;
}

static void
lower_ngg_gs_intrinsics(nir_shader *shader, lower_ngg_gs_state *s)
{
   nir_shader_instructions_pass(shader, lower_ngg_gs_intrinsic, nir_metadata_none, s);
}

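/* Exports primitives on the primitive-export threads.
 * Thread i exports the primitive whose last vertex is vertex i, so the vertex indices
 * are (i - 2, i - 1, i) for triangles. For odd triangles of a strip, two of the indices
 * are swapped to restore the front-face winding while keeping the provoking vertex;
 * which pair is swapped depends on the provoking vertex convention.
 * When bit 0 of the primitive flag is not set, a null primitive is exported instead.
 */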
static void
ngg_gs_export_primitives(nir_builder *b, nir_ssa_def *max_num_out_prims, nir_ssa_def *tid_in_tg,
                         nir_ssa_def *exporter_tid_in_tg, nir_ssa_def *primflag_0,
                         lower_ngg_gs_state *s)
{
   nir_if *if_prim_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_prims));

   /* Only bit 0 matters here - set it to 1 when the primitive should be null */
   nir_ssa_def *is_null_prim = nir_ixor(b, primflag_0, nir_imm_int(b, -1u));

   nir_ssa_def *vtx_indices[3] = {0};
   vtx_indices[s->num_vertices_per_primitive - 1] = exporter_tid_in_tg;
   if (s->num_vertices_per_primitive >= 2)
      vtx_indices[s->num_vertices_per_primitive - 2] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 1));
   if (s->num_vertices_per_primitive == 3)
      vtx_indices[s->num_vertices_per_primitive - 3] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 2));

   if (s->num_vertices_per_primitive == 3) {
      /* API GS outputs triangle strips, but NGG HW understands triangles.
       * We already know the triangles due to how we set the primitive flags, but we need to
       * make sure the vertex order is such that the front/back face is correct and the
       * provoking vertex is kept.
       */

      nir_ssa_def *is_odd = nir_ubfe(b, primflag_0, nir_imm_int(b, 1), nir_imm_int(b, 1));
      if (!s->provoking_vertex_last) {
         vtx_indices[1] = nir_iadd(b, vtx_indices[1], is_odd);
         vtx_indices[2] = nir_isub(b, vtx_indices[2], is_odd);
      } else {
         vtx_indices[0] = nir_iadd(b, vtx_indices[0], is_odd);
         vtx_indices[1] = nir_isub(b, vtx_indices[1], is_odd);
      }
   }

   nir_ssa_def *arg = emit_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices, is_null_prim);
   nir_build_export_primitive_amd(b, arg);
   nir_pop_if(b, if_prim_export_thread);
}

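/* Exports vertex attributes on the vertex-export threads.
 * With vertex compaction, an export thread may export a vertex that was emitted by a
 * different invocation: the original vertex index was stored (by
 * ngg_gs_setup_vertex_compaction) in the stream-1 primitive flag byte, so it is loaded
 * from LDS first and used to recompute the LDS address of the actual vertex data.
 * Only stream-0 outputs are exported to the hardware.
 */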
static void
ngg_gs_export_vertices(nir_builder *b, nir_ssa_def *max_num_out_vtx, nir_ssa_def *tid_in_tg,
                       nir_ssa_def *out_vtx_lds_addr, lower_ngg_gs_state *s)
{
   nir_if *if_vtx_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
   nir_ssa_def *exported_out_vtx_lds_addr = out_vtx_lds_addr;

   if (!s->output_compile_time_known) {
      /* Vertex compaction.
       * The current thread will export a vertex that was live in another invocation.
       * Load the index of the vertex that the current thread will have to export.
       */
      nir_ssa_def *exported_vtx_idx = nir_build_load_shared(b, 1, 8, out_vtx_lds_addr, .base = s->lds_offs_primflags + 1, .align_mul = 1u);
      exported_out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, nir_u2u32(b, exported_vtx_idx), s);
   }

   for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
      if (!(b->shader->info.outputs_written & BITFIELD64_BIT(slot)))
         continue;

      unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
      nir_io_semantics io_sem = { .location = slot, .num_slots = 1 };

      for (unsigned comp = 0; comp < 4; ++comp) {
         gs_output_component_info *info = &s->output_component_info[slot][comp];
         if (info->stream != 0 || info->bit_size == 0)
            continue;

         nir_ssa_def *load = nir_build_load_shared(b, 1, info->bit_size, exported_out_vtx_lds_addr, .base = packed_location * 16u + comp * 4u, .align_mul = 4u);
         nir_build_store_output(b, load, nir_imm_int(b, 0), .write_mask = 0x1u, .base = slot, .component = comp, .io_semantics = io_sem);
      }
   }

   nir_build_export_vertex_amd(b);
   nir_pop_if(b, if_vtx_export_thread);
}

static void
ngg_gs_setup_vertex_compaction(nir_builder *b, nir_ssa_def *vertex_live, nir_ssa_def *tid_in_tg,
                               nir_ssa_def *exporter_tid_in_tg, lower_ngg_gs_state *s)
{
   assert(vertex_live->bit_size == 1);
   nir_if *if_vertex_live = nir_push_if(b, vertex_live);
   {
      /* Setup the vertex compaction.
       * Save the current thread's id for the thread which will export the current vertex.
       * We reuse stream 1 of the primitive flag of the other thread's vertex for storing this.
       */

      nir_ssa_def *exporter_lds_addr = ngg_gs_out_vertex_addr(b, exporter_tid_in_tg, s);
      nir_ssa_def *tid_in_tg_u8 = nir_u2u8(b, tid_in_tg);
      nir_build_store_shared(b, tid_in_tg_u8, exporter_lds_addr, .base = s->lds_offs_primflags + 1, .align_mul = 1u, .write_mask = 0x1u);
   }
   nir_pop_if(b, if_vertex_live);
}

static nir_ssa_def *
ngg_gs_load_out_vtx_primflag_0(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *vtx_lds_addr,
                               nir_ssa_def *max_num_out_vtx, lower_ngg_gs_state *s)
{
   nir_ssa_def *zero = nir_imm_int(b, 0);

   nir_if *if_outvtx_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
   nir_ssa_def *primflag_0 = nir_build_load_shared(b, 1, 8, vtx_lds_addr, .base = s->lds_offs_primflags, .align_mul = 4u);
   primflag_0 = nir_u2u32(b, primflag_0);
   nir_pop_if(b, if_outvtx_thread);

   return nir_if_phi(b, primflag_0, zero);
}

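/* Epilogue that runs after all GS invocations have finished emitting:
 * 1. If the output count is compile-time known, wave 0 allocates the export space and
 *    every thread exports its own vertex and primitive directly.
 * 2. Otherwise, live vertices are counted and repacked across the workgroup, wave 0
 *    allocates space for the repacked vertex count (and 0 primitives if nothing was
 *    emitted, to avoid a hang), compaction info is written to LDS, and the packed
 *    threads export the primitives and vertices.
 */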
static void
ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
{
   nir_ssa_def *tid_in_tg = nir_build_load_local_invocation_index(b);
   nir_ssa_def *max_vtxcnt = nir_build_load_workgroup_num_input_vertices_amd(b);
   nir_ssa_def *max_prmcnt = max_vtxcnt; /* They are currently practically the same; both RADV and RadeonSI do this. */
   nir_ssa_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);

   if (s->output_compile_time_known) {
      /* When the output is compile-time known, the GS writes all possible vertices and primitives it can.
       * The gs_alloc_req needs to happen on one wave only, otherwise the HW hangs.
       */
      nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_zero(b, 1, 32)));
      nir_build_alloc_vertices_and_primitives_amd(b, max_vtxcnt, max_prmcnt);
      nir_pop_if(b, if_wave_0);
   }

   /* Workgroup barrier: wait for all GS threads to finish */
   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);

   nir_ssa_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag_0(b, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);

   if (s->output_compile_time_known) {
      ngg_gs_export_primitives(b, max_vtxcnt, tid_in_tg, tid_in_tg, out_vtx_primflag_0, s);
      ngg_gs_export_vertices(b, max_vtxcnt, tid_in_tg, out_vtx_lds_addr, s);
      return;
   }

   /* When the output vertex count is not known at compile time:
    * There may be gaps between invocations that have live vertices, but NGG hardware
    * requires that the invocations that export vertices are packed (ie. compact).
    * To ensure this, we need to repack invocations that have a live vertex.
    */
   nir_ssa_def *vertex_live = nir_ine(b, out_vtx_primflag_0, nir_imm_zero(b, 1, out_vtx_primflag_0->bit_size));
   wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves, s->wave_size);

   nir_ssa_def *workgroup_num_vertices = rep.num_repacked_invocations;
   nir_ssa_def *exporter_tid_in_tg = rep.repacked_invocation_index;

   /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */
   nir_ssa_def *any_output = nir_ine(b, workgroup_num_vertices, nir_imm_int(b, 0));
   max_prmcnt = nir_bcsel(b, any_output, max_prmcnt, nir_imm_int(b, 0));

   /* Allocate export space. We currently don't compact primitives, just use the maximum number. */
   nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_zero(b, 1, 32)));
   nir_build_alloc_vertices_and_primitives_amd(b, workgroup_num_vertices, max_prmcnt);
   nir_pop_if(b, if_wave_0);

   /* Vertex compaction. This makes sure there are no gaps between threads that export vertices. */
   ngg_gs_setup_vertex_compaction(b, vertex_live, tid_in_tg, exporter_tid_in_tg, s);

   /* Workgroup barrier: wait for all LDS stores to finish. */
   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);

   ngg_gs_export_primitives(b, max_prmcnt, tid_in_tg, exporter_tid_in_tg, out_vtx_primflag_0, s);
   ngg_gs_export_vertices(b, workgroup_num_vertices, tid_in_tg, out_vtx_lds_addr, s);
}

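/* Lowers a geometry shader for NGG execution.
 *
 * The expected LDS layout (as implied by the addresses set up below) is roughly:
 *    [0, esgs_ring_lds_bytes)      ES -> GS ring
 *    [lds_addr_gs_out_vtx, ...)    GS output vertices, one slot per possible vertex:
 *                                  gs_out_vtx_bytes of outputs + 4 primflag bytes
 *                                  (one byte per stream)
 *    [lds_addr_gs_scratch, ...)    scratch space for workgroup repacking
 * The GS body is wrapped in an "if (has input primitive)", its intrinsics are lowered to
 * LDS stores, and ngg_gs_finale emits the export code at the end of the shader.
 */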
void
ac_nir_lower_ngg_gs(nir_shader *shader,
                    unsigned wave_size,
                    unsigned max_workgroup_size,
                    unsigned esgs_ring_lds_bytes,
                    unsigned gs_out_vtx_bytes,
                    unsigned gs_total_out_vtx_bytes,
                    bool provoking_vertex_last)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   lower_ngg_gs_state state = {
      .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
      .wave_size = wave_size,
      .lds_addr_gs_out_vtx = esgs_ring_lds_bytes,
      .lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 8u /* for the repacking code */),
      .lds_offs_primflags = gs_out_vtx_bytes,
      .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
      .provoking_vertex_last = provoking_vertex_last,
   };

   unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
   unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
   shader->info.shared_size = total_lds_bytes;

   nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt, state.const_out_prmcnt, 4u);
   state.output_compile_time_known = state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
                                     state.const_out_prmcnt[0] != -1;

   if (!state.output_compile_time_known)
      state.current_clear_primflag_idx_var = nir_local_variable_create(impl, glsl_uint_type(), "current_clear_primflag_idx");

   if (shader->info.gs.output_primitive == GL_POINTS)
      state.num_vertices_per_primitive = 1;
   else if (shader->info.gs.output_primitive == GL_LINE_STRIP)
      state.num_vertices_per_primitive = 2;
   else if (shader->info.gs.output_primitive == GL_TRIANGLE_STRIP)
      state.num_vertices_per_primitive = 3;
   else
      unreachable("Invalid GS output primitive.");

   /* Extract the full control flow. It is going to be wrapped in an if statement. */
   nir_cf_list extracted;
   nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));

   nir_builder builder;
   nir_builder *b = &builder; /* This is to avoid the & */
   nir_builder_init(b, impl);
   b->cursor = nir_before_cf_list(&impl->body);

   /* Workgroup barrier: wait for ES threads */
   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);

   /* Wrap the GS control flow. */
   nir_if *if_gs_thread = nir_push_if(b, nir_build_has_input_primitive_amd(b));

   /* Create and initialize output variables */
   for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
      for (unsigned comp = 0; comp < 4; ++comp) {
         state.output_vars[slot][comp] = nir_local_variable_create(impl, glsl_uint_type(), "output");
      }
   }

   nir_cf_reinsert(&extracted, b->cursor);
   b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
   nir_pop_if(b, if_gs_thread);

   /* Lower the GS intrinsics */
   lower_ngg_gs_intrinsics(shader, &state);
   b->cursor = nir_after_cf_list(&impl->body);

   if (!state.found_out_vtxcnt[0]) {
      fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU.\n");
      abort();
   }

   /* Emit the finale sequence */
   ngg_gs_finale(b, &state);
   nir_validate_shader(shader, "after emitting NGG GS");

   /* Cleanup */
   nir_lower_vars_to_ssa(shader);
   nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
   nir_metadata_preserve(impl, nir_metadata_none);
}