Path: blob/21.2-virgl/src/amd/common/ac_nir_lower_ngg.c
/*
 * Copyright © 2021 Valve Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 */

#include "ac_nir.h"
#include "nir_builder.h"
#include "u_math.h"
#include "u_vector.h"

enum {
   nggc_passflag_used_by_pos = 1,
   nggc_passflag_used_by_other = 2,
   nggc_passflag_used_by_both = nggc_passflag_used_by_pos | nggc_passflag_used_by_other,
};

typedef struct
{
   nir_ssa_def *ssa;
   nir_variable *var;
} saved_uniform;

typedef struct
{
   nir_variable *position_value_var;
   nir_variable *prim_exp_arg_var;
   nir_variable *es_accepted_var;
   nir_variable *gs_accepted_var;

   struct u_vector saved_uniforms;

   bool passthrough;
   bool export_prim_id;
   bool early_prim_export;
   unsigned wave_size;
   unsigned max_num_waves;
   unsigned num_vertices_per_primitives;
   unsigned provoking_vtx_idx;
   unsigned max_es_num_vertices;
   unsigned total_lds_bytes;

   uint64_t inputs_needed_by_pos;
   uint64_t inputs_needed_by_others;
} lower_ngg_nogs_state;

typedef struct
{
   /* bitsize of this component (max 32), or 0 if it's never written at all */
   uint8_t bit_size : 6;
   /* output stream index */
   uint8_t stream : 2;
} gs_output_component_info;

typedef struct
{
   nir_variable *output_vars[VARYING_SLOT_MAX][4];
   nir_variable *current_clear_primflag_idx_var;
   int const_out_vtxcnt[4];
   int const_out_prmcnt[4];
   unsigned wave_size;
   unsigned max_num_waves;
   unsigned num_vertices_per_primitive;
   unsigned lds_addr_gs_out_vtx;
   unsigned lds_addr_gs_scratch;
   unsigned lds_bytes_per_gs_out_vertex;
   unsigned lds_offs_primflags;
   bool found_out_vtxcnt[4];
   bool output_compile_time_known;
   bool provoking_vertex_last;
   gs_output_component_info output_component_info[VARYING_SLOT_MAX][4];
} lower_ngg_gs_state;

typedef struct {
   nir_variable *pre_cull_position_value_var;
} remove_culling_shader_outputs_state;

typedef struct {
   nir_variable *pos_value_replacement;
} remove_extra_position_output_state;

typedef struct {
   nir_ssa_def *reduction_result;
   nir_ssa_def *excl_scan_result;
} wg_scan_result;

/* Per-vertex LDS layout of culling shaders */
enum {
   /* Position of the ES vertex (at the beginning for alignment reasons) */
   lds_es_pos_x = 0,
   lds_es_pos_y = 4,
   lds_es_pos_z = 8,
   lds_es_pos_w = 12,

   /* 1 when the vertex is accepted, 0 if it should be culled */
   lds_es_vertex_accepted = 16,
   /* ID of the thread which will export the current thread's vertex */
   lds_es_exporter_tid = 17,

   /* Repacked arguments - also listed separately for VS and TES */
   lds_es_arg_0 = 20,

   /* VS arguments which need to be repacked */
   lds_es_vs_vertex_id = 20,
   lds_es_vs_instance_id = 24,

   /* TES arguments which need to be repacked */
   lds_es_tes_u = 20,
   lds_es_tes_v = 24,
   lds_es_tes_rel_patch_id = 28,
   lds_es_tes_patch_id = 32,
};

typedef struct {
   nir_ssa_def *num_repacked_invocations;
   nir_ssa_def *repacked_invocation_index;
} wg_repack_result;

/**
 * Repacks invocations in the current workgroup to eliminate gaps between them.
 *
 * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave).
 * Assumes that all invocations in the workgroup are active (exec = -1).
 */
static wg_repack_result
repack_invocations_in_workgroup(nir_builder *b, nir_ssa_def *input_bool,
                                unsigned lds_addr_base, unsigned max_num_waves,
                                unsigned wave_size)
{
   /* Input boolean: 1 if the current invocation should survive the repack. */
   assert(input_bool->bit_size == 1);

   /* STEP 1. Count surviving invocations in the current wave.
    *
    * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
    */

   nir_ssa_def *input_mask = nir_build_ballot(b, 1, wave_size, input_bool);
   nir_ssa_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask);

   /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
   if (max_num_waves == 1) {
      wg_repack_result r = {
         .num_repacked_invocations = surviving_invocations_in_current_wave,
         .repacked_invocation_index = nir_build_mbcnt_amd(b, input_mask, nir_imm_int(b, 0)),
      };
      return r;
   }

   /* STEP 2. Waves tell each other their number of surviving invocations.
    *
    * Each wave activates only its first lane (exec = 1), which stores the number of surviving
    * invocations in that wave into the LDS, then reads the numbers from every wave.
    *
    * The workgroup size of NGG shaders is at most 256, which means
    * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
    * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
    */

   const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
   assert(num_lds_dwords <= 2);

   nir_ssa_def *wave_id = nir_build_load_subgroup_id(b);
   nir_ssa_def *dont_care = nir_ssa_undef(b, 1, num_lds_dwords * 32);
   nir_if *if_first_lane = nir_push_if(b, nir_build_elect(b, 1));

   nir_build_store_shared(b, nir_u2u8(b, surviving_invocations_in_current_wave), wave_id, .base = lds_addr_base, .align_mul = 1u, .write_mask = 0x1u);

   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);

   nir_ssa_def *packed_counts = nir_build_load_shared(b, 1, num_lds_dwords * 32, nir_imm_int(b, 0), .base = lds_addr_base, .align_mul = 8u);

   nir_pop_if(b, if_first_lane);

   packed_counts = nir_if_phi(b, packed_counts, dont_care);

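   /* Example (illustrative): if the workgroup has 4 waves whose surviving
    * invocation counts are 12, 7, 32 and 0, then packed_counts now holds
    * 0x0020070c - byte k of the loaded dword is the count of wave k.
    */
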
   /* STEP 3. Compute the repacked invocation index and the total number of surviving invocations.
    *
    * By now, every wave knows the number of surviving invocations in all waves.
    * Each number is 1 byte, and they are packed into up to 2 dwords.
    *
    * Each lane N will sum the number of surviving invocations from waves 0 to N-1.
    * If the workgroup has M waves, then each wave will use only its first M+1 lanes for this.
    * (Other lanes are not deactivated but their calculation is not used.)
    *
    * - We read the sum from the lane whose id is the current wave's id.
    *   Add the masked bitcount to this, and we get the repacked invocation index.
    * - We read the sum from the lane whose id is the number of waves in the workgroup.
    *   This is the total number of surviving invocations in the workgroup.
    */

   nir_ssa_def *num_waves = nir_build_load_num_subgroups(b);

   /* sel = 0x01010101 * lane_id + 0x03020100 */
   nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
   nir_ssa_def *packed_id = nir_build_byte_permute_amd(b, nir_imm_int(b, 0), lane_id, nir_imm_int(b, 0));
   nir_ssa_def *sel = nir_iadd_imm_nuw(b, packed_id, 0x03020100);
   nir_ssa_def *sum = NULL;

   if (num_lds_dwords == 1) {
      /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
      nir_ssa_def *packed_dw = nir_build_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));

      /* Use byte-permute to filter out the bytes not needed by the current lane. */
      nir_ssa_def *filtered_packed = nir_build_byte_permute_amd(b, packed_dw, nir_imm_int(b, 0), sel);

      /* Horizontally add the packed bytes. */
      sum = nir_sad_u8x4(b, filtered_packed, nir_imm_int(b, 0), nir_imm_int(b, 0));
   } else if (num_lds_dwords == 2) {
      /* Create selectors for the byte-permutes below. */
      nir_ssa_def *dw0_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x44443210), nir_imm_int(b, 0x4));
      nir_ssa_def *dw1_selector = nir_build_lane_permute_16_amd(b, sel, nir_imm_int(b, 0x32100000), nir_imm_int(b, 0x4));

      /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
      nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
      nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));

      /* Use byte-permute to filter out the bytes not needed by the current lane. */
      nir_ssa_def *filtered_packed_dw0 = nir_build_byte_permute_amd(b, packed_dw0, nir_imm_int(b, 0), dw0_selector);
      nir_ssa_def *filtered_packed_dw1 = nir_build_byte_permute_amd(b, packed_dw1, nir_imm_int(b, 0), dw1_selector);

      /* Horizontally add the packed bytes. */
      sum = nir_sad_u8x4(b, filtered_packed_dw0, nir_imm_int(b, 0), nir_imm_int(b, 0));
      sum = nir_sad_u8x4(b, filtered_packed_dw1, nir_imm_int(b, 0), sum);
   } else {
      unreachable("Unimplemented NGG wave count");
   }

   nir_ssa_def *wg_repacked_index_base = nir_build_read_invocation(b, sum, wave_id);
   nir_ssa_def *wg_num_repacked_invocations = nir_build_read_invocation(b, sum, num_waves);
   nir_ssa_def *wg_repacked_index = nir_build_mbcnt_amd(b, input_mask, wg_repacked_index_base);

   wg_repack_result r = {
      .num_repacked_invocations = wg_num_repacked_invocations,
      .repacked_invocation_index = wg_repacked_index,
   };

   return r;
}

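/* Illustrative sketch (hypothetical helper, not used by the lowering): a
 * scalar C model of the per-lane sum that the byte-permute + sad_u8x4
 * sequence above computes. Byte k of packed_counts is the surviving
 * invocation count of wave k; lane N needs the sum of waves 0..N-1 and lane
 * num_waves needs the workgroup total.
 */
static inline unsigned
model_repacked_index_base(uint64_t packed_counts, unsigned lane_id)
{
   unsigned sum = 0;
   for (unsigned wave = 0; wave < lane_id && wave < 8; ++wave)
      sum += (unsigned)((packed_counts >> (wave * 8)) & 0xffu);
   return sum;
}
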
static nir_ssa_def *
pervertex_lds_addr(nir_builder *b, nir_ssa_def *vertex_idx, unsigned per_vtx_bytes)
{
   return nir_imul_imm(b, vertex_idx, per_vtx_bytes);
}

static nir_ssa_def *
emit_pack_ngg_prim_exp_arg(nir_builder *b, unsigned num_vertices_per_primitives,
                           nir_ssa_def *vertex_indices[3], nir_ssa_def *is_null_prim)
{
   nir_ssa_def *arg = vertex_indices[0];

   for (unsigned i = 0; i < num_vertices_per_primitives; ++i) {
      assert(vertex_indices[i]);

      if (i)
         arg = nir_ior(b, arg, nir_ishl(b, vertex_indices[i], nir_imm_int(b, 10u * i)));

      if (b->shader->info.stage == MESA_SHADER_VERTEX) {
         nir_ssa_def *edgeflag = nir_build_load_initial_edgeflag_amd(b, 32, nir_imm_int(b, i));
         arg = nir_ior(b, arg, nir_ishl(b, edgeflag, nir_imm_int(b, 10u * i + 9u)));
      }
   }

   if (is_null_prim) {
      if (is_null_prim->bit_size == 1)
         is_null_prim = nir_b2i32(b, is_null_prim);
      assert(is_null_prim->bit_size == 32);
      arg = nir_ior(b, arg, nir_ishl(b, is_null_prim, nir_imm_int(b, 31u)));
   }

   return arg;
}

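/* Illustrative sketch (hypothetical helper, mirroring the packing above): the
 * NGG primitive export argument holds one 10-bit group per vertex, with the
 * vertex index in the low bits, the edge flag (VS only) in the top bit of the
 * group, and the "null primitive" flag in bit 31.
 */
static inline uint32_t
model_prim_exp_arg(const uint32_t vtx_idx[3], const uint32_t edgeflags[3],
                   unsigned num_vertices, bool is_null_prim)
{
   uint32_t arg = 0;
   for (unsigned i = 0; i < num_vertices; ++i) {
      arg |= vtx_idx[i] << (10u * i);               /* vertex index bits */
      arg |= (edgeflags[i] & 1u) << (10u * i + 9u); /* edge flag bit */
   }
   arg |= (uint32_t)is_null_prim << 31u;            /* null primitive flag */
   return arg;
}
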
static nir_ssa_def *
ngg_input_primitive_vertex_index(nir_builder *b, unsigned vertex)
{
   /* TODO: This is RADV specific. We'll need to refactor RADV and/or RadeonSI to match. */
   return nir_ubfe(b, nir_build_load_gs_vertex_offset_amd(b, .base = vertex / 2u * 2u),
                   nir_imm_int(b, (vertex % 2u) * 16u), nir_imm_int(b, 16u));
}

static nir_ssa_def *
emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *st)
{
   if (st->passthrough) {
      assert(!st->export_prim_id || b->shader->info.stage != MESA_SHADER_VERTEX);
      return nir_build_load_packed_passthrough_primitive_amd(b);
   } else {
      nir_ssa_def *vtx_idx[3] = {0};

      vtx_idx[0] = ngg_input_primitive_vertex_index(b, 0);
      vtx_idx[1] = st->num_vertices_per_primitives >= 2
                   ? ngg_input_primitive_vertex_index(b, 1)
                   : nir_imm_zero(b, 1, 32);
      vtx_idx[2] = st->num_vertices_per_primitives >= 3
                   ? ngg_input_primitive_vertex_index(b, 2)
                   : nir_imm_zero(b, 1, 32);

      return emit_pack_ngg_prim_exp_arg(b, st->num_vertices_per_primitives, vtx_idx, NULL);
   }
}

static void
emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *st, nir_ssa_def *arg)
{
   nir_if *if_gs_thread = nir_push_if(b, nir_build_has_input_primitive_amd(b));
   {
      if (!arg)
         arg = emit_ngg_nogs_prim_exp_arg(b, st);

      if (st->export_prim_id && b->shader->info.stage == MESA_SHADER_VERTEX) {
         /* Copy Primitive IDs from GS threads to the LDS address corresponding to the ES thread of the provoking vertex. */
         nir_ssa_def *prim_id = nir_build_load_primitive_id(b);
         nir_ssa_def *provoking_vtx_idx = ngg_input_primitive_vertex_index(b, st->provoking_vtx_idx);
         nir_ssa_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, 4u);

         nir_build_store_shared(b, prim_id, addr, .write_mask = 1u, .align_mul = 4u);
      }

      nir_build_export_primitive_amd(b, arg);
   }
   nir_pop_if(b, if_gs_thread);
}

static void
emit_store_ngg_nogs_es_primitive_id(nir_builder *b)
{
   nir_ssa_def *prim_id = NULL;

   if (b->shader->info.stage == MESA_SHADER_VERTEX) {
      /* Workgroup barrier - wait for GS threads to store primitive ID in LDS. */
      nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP, .memory_scope = NIR_SCOPE_WORKGROUP,
                            .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);

      /* LDS address where the primitive ID is stored */
      nir_ssa_def *thread_id_in_threadgroup = nir_build_load_local_invocation_index(b);
      nir_ssa_def *addr = pervertex_lds_addr(b, thread_id_in_threadgroup, 4u);

      /* Load primitive ID from LDS */
      prim_id = nir_build_load_shared(b, 1, 32, addr, .align_mul = 4u);
   } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
      /* Just use tess eval primitive ID, which is the same as the patch ID. */
      prim_id = nir_build_load_primitive_id(b);
   }

   nir_io_semantics io_sem = {
      .location = VARYING_SLOT_PRIMITIVE_ID,
      .num_slots = 1,
   };

   nir_build_store_output(b, prim_id, nir_imm_zero(b, 1, 32),
                          .base = io_sem.location,
                          .write_mask = 1u, .src_type = nir_type_uint32, .io_semantics = io_sem);
}

static bool
remove_culling_shader_output(nir_builder *b, nir_instr *instr, void *state)
{
   remove_culling_shader_outputs_state *s = (remove_culling_shader_outputs_state *) state;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   /* These are not allowed in VS / TES */
   assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
          intrin->intrinsic != nir_intrinsic_load_per_vertex_input);

   /* We are only interested in output stores now */
   if (intrin->intrinsic != nir_intrinsic_store_output)
      return false;

   b->cursor = nir_before_instr(instr);

   /* Position output - store the value to a variable, remove output store */
   nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
   if (io_sem.location == VARYING_SLOT_POS) {
      /* TODO: check if it's indirect, etc? */
      unsigned writemask = nir_intrinsic_write_mask(intrin);
      nir_ssa_def *store_val = intrin->src[0].ssa;
      nir_store_var(b, s->pre_cull_position_value_var, store_val, writemask);
   }

   /* Remove all output stores */
   nir_instr_remove(instr);
   return true;
}

static void
remove_culling_shader_outputs(nir_shader *culling_shader, lower_ngg_nogs_state *nogs_state, nir_variable *pre_cull_position_value_var)
{
   remove_culling_shader_outputs_state s = {
      .pre_cull_position_value_var = pre_cull_position_value_var,
   };

   nir_shader_instructions_pass(culling_shader, remove_culling_shader_output,
                                nir_metadata_block_index | nir_metadata_dominance, &s);

   /* Remove dead code resulting from the deleted outputs. */
   bool progress;
   do {
      progress = false;
      NIR_PASS(progress, culling_shader, nir_opt_dead_write_vars);
      NIR_PASS(progress, culling_shader, nir_opt_dce);
      NIR_PASS(progress, culling_shader, nir_opt_dead_cf);
   } while (progress);
}

static void
rewrite_uses_to_var(nir_builder *b, nir_ssa_def *old_def, nir_variable *replacement_var, unsigned replacement_var_channel)
{
   if (old_def->parent_instr->type == nir_instr_type_load_const)
      return;

   b->cursor = nir_after_instr(old_def->parent_instr);
   if (b->cursor.instr->type == nir_instr_type_phi)
      b->cursor = nir_after_phis(old_def->parent_instr->block);

   nir_ssa_def *pos_val_rep = nir_load_var(b, replacement_var);
   nir_ssa_def *replacement = nir_channel(b, pos_val_rep, replacement_var_channel);

   if (old_def->num_components > 1) {
      /* old_def uses a swizzled vector component.
       * There is no way to replace the uses of just a single vector component,
       * so instead create a new vector and replace all uses of the old vector.
       */
      nir_ssa_def *old_def_elements[NIR_MAX_VEC_COMPONENTS] = {0};
      for (unsigned j = 0; j < old_def->num_components; ++j)
         old_def_elements[j] = nir_channel(b, old_def, j);
      replacement = nir_vec(b, old_def_elements, old_def->num_components);
   }

   nir_ssa_def_rewrite_uses_after(old_def, replacement, replacement->parent_instr);
}

static bool
remove_extra_pos_output(nir_builder *b, nir_instr *instr, void *state)
{
   remove_extra_position_output_state *s = (remove_extra_position_output_state *) state;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   /* These are not allowed in VS / TES */
   assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
          intrin->intrinsic != nir_intrinsic_load_per_vertex_input);

   /* We are only interested in output stores now */
   if (intrin->intrinsic != nir_intrinsic_store_output)
      return false;

   nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
   if (io_sem.location != VARYING_SLOT_POS)
      return false;

   b->cursor = nir_before_instr(instr);

   /* In case other outputs use what we calculated for pos,
    * try to avoid calculating it again by rewriting the usages
    * of the store components here.
    */
   nir_ssa_def *store_val = intrin->src[0].ssa;
   unsigned store_pos_component = nir_intrinsic_component(intrin);

   nir_instr_remove(instr);

   if (store_val->parent_instr->type == nir_instr_type_alu) {
      nir_alu_instr *alu = nir_instr_as_alu(store_val->parent_instr);
      if (nir_op_is_vec(alu->op)) {
         /* Output store uses a vector, we can easily rewrite uses of each vector element. */
         unsigned num_vec_src = 0;
         if (alu->op == nir_op_mov)
            num_vec_src = 1;
         else if (alu->op == nir_op_vec2)
            num_vec_src = 2;
         else if (alu->op == nir_op_vec3)
            num_vec_src = 3;
         else if (alu->op == nir_op_vec4)
            num_vec_src = 4;
         assert(num_vec_src);

         /* Remember the current components whose uses we wish to replace.
          * This is needed because rewriting one source can affect the others too.
          */
         nir_ssa_def *vec_comps[NIR_MAX_VEC_COMPONENTS] = {0};
         for (unsigned i = 0; i < num_vec_src; i++)
            vec_comps[i] = alu->src[i].src.ssa;

         for (unsigned i = 0; i < num_vec_src; i++)
            rewrite_uses_to_var(b, vec_comps[i], s->pos_value_replacement, store_pos_component + i);
      } else {
         rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component);
      }
   } else {
      rewrite_uses_to_var(b, store_val, s->pos_value_replacement, store_pos_component);
   }

   return true;
}

static void
remove_extra_pos_outputs(nir_shader *shader, lower_ngg_nogs_state *nogs_state)
{
   remove_extra_position_output_state s = {
      .pos_value_replacement = nogs_state->position_value_var,
   };

   nir_shader_instructions_pass(shader, remove_extra_pos_output,
                                nir_metadata_block_index | nir_metadata_dominance, &s);
}

/**
 * Perform vertex compaction after culling.
 *
 * 1. Repack surviving ES invocations (this determines which lane will export which vertex)
 * 2. Surviving ES vertex invocations store their data to LDS
 * 3. Emit GS_ALLOC_REQ
 * 4. Repacked invocations load the vertex data from LDS
 * 5. GS threads update their vertex indices
 */
static void
compact_vertices_after_culling(nir_builder *b,
                               lower_ngg_nogs_state *nogs_state,
                               nir_variable *vertices_in_wave_var,
                               nir_variable *primitives_in_wave_var,
                               nir_variable **repacked_arg_vars,
                               nir_variable **gs_vtxaddr_vars,
                               nir_ssa_def *invocation_index,
                               nir_ssa_def *es_vertex_lds_addr,
                               unsigned ngg_scratch_lds_base_addr,
                               unsigned pervertex_lds_bytes,
                               unsigned max_exported_args)
{
   nir_variable *es_accepted_var = nogs_state->es_accepted_var;
   nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
   nir_variable *position_value_var = nogs_state->position_value_var;
   nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;

   nir_ssa_def *es_accepted = nir_load_var(b, es_accepted_var);

   /* Repack the vertices that survived the culling. */
   wg_repack_result rep = repack_invocations_in_workgroup(b, es_accepted, ngg_scratch_lds_base_addr,
                                                          nogs_state->max_num_waves, nogs_state->wave_size);
   nir_ssa_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
   nir_ssa_def *es_exporter_tid = rep.repacked_invocation_index;

   nir_if *if_es_accepted = nir_push_if(b, es_accepted);
   {
      nir_ssa_def *exporter_addr = pervertex_lds_addr(b, es_exporter_tid, pervertex_lds_bytes);

      /* Store the exporter thread's index to the LDS space of the current thread so GS threads can load it */
      nir_build_store_shared(b, nir_u2u8(b, es_exporter_tid), es_vertex_lds_addr, .base = lds_es_exporter_tid, .align_mul = 1u, .write_mask = 0x1u);

      /* Store the current thread's position output to the exporter thread's LDS space */
      nir_ssa_def *pos = nir_load_var(b, position_value_var);
      nir_build_store_shared(b, pos, exporter_addr, .base = lds_es_pos_x, .align_mul = 4u, .write_mask = 0xfu);

      /* Store the current thread's repackable arguments to the exporter thread's LDS space */
      for (unsigned i = 0; i < max_exported_args; ++i) {
         nir_ssa_def *arg_val = nir_load_var(b, repacked_arg_vars[i]);
         nir_build_store_shared(b, arg_val, exporter_addr, .base = lds_es_arg_0 + 4u * i, .align_mul = 4u, .write_mask = 0x1u);
      }
   }
   nir_pop_if(b, if_es_accepted);

   /* If all vertices are culled, set primitive count to 0 as well. */
   nir_ssa_def *num_exported_prims = nir_build_load_workgroup_num_input_primitives_amd(b);
   num_exported_prims = nir_bcsel(b, nir_ieq_imm(b, num_live_vertices_in_workgroup, 0u), nir_imm_int(b, 0u), num_exported_prims);

   nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
   {
      /* Tell the final vertex and primitive count to the HW.
       * We do this here to mask some of the latency of the LDS.
       */
      nir_build_alloc_vertices_and_primitives_amd(b, num_live_vertices_in_workgroup, num_exported_prims);
   }
   nir_pop_if(b, if_wave_0);

   /* Calculate the number of vertices and primitives left in the current wave */
   nir_ssa_def *has_vtx_after_culling = nir_ilt(b, invocation_index, num_live_vertices_in_workgroup);
   nir_ssa_def *has_prm_after_culling = nir_ilt(b, invocation_index, num_exported_prims);
   nir_ssa_def *vtx_cnt = nir_bit_count(b, nir_build_ballot(b, 1, nogs_state->wave_size, has_vtx_after_culling));
   nir_ssa_def *prm_cnt = nir_bit_count(b, nir_build_ballot(b, 1, nogs_state->wave_size, has_prm_after_culling));
   nir_store_var(b, vertices_in_wave_var, vtx_cnt, 0x1u);
   nir_store_var(b, primitives_in_wave_var, prm_cnt, 0x1u);

   /* TODO: Consider adding a shortcut exit.
    * Waves that have no vertices and primitives left can s_endpgm right here.
    */

   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);

   nir_if *if_packed_es_thread = nir_push_if(b, nir_ilt(b, invocation_index, num_live_vertices_in_workgroup));
   {
      /* Read position from the current ES thread's LDS space (written by the exported vertex's ES thread) */
      nir_ssa_def *exported_pos = nir_build_load_shared(b, 4, 32, es_vertex_lds_addr, .base = lds_es_pos_x, .align_mul = 4u);
      nir_store_var(b, position_value_var, exported_pos, 0xfu);

      /* Read the repacked arguments */
      for (unsigned i = 0; i < max_exported_args; ++i) {
         nir_ssa_def *arg_val = nir_build_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_arg_0 + 4u * i, .align_mul = 4u);
         nir_store_var(b, repacked_arg_vars[i], arg_val, 0x1u);
      }
   }
   nir_pop_if(b, if_packed_es_thread);

   nir_if *if_gs_accepted = nir_push_if(b, nir_load_var(b, gs_accepted_var));
   {
      nir_ssa_def *exporter_vtx_indices[3] = {0};

      /* Load the index of the ES threads that will export the current GS thread's vertices */
      for (unsigned v = 0; v < 3; ++v) {
         nir_ssa_def *vtx_addr = nir_load_var(b, gs_vtxaddr_vars[v]);
         nir_ssa_def *exporter_vtx_idx = nir_build_load_shared(b, 1, 8, vtx_addr, .base = lds_es_exporter_tid, .align_mul = 1u);
         exporter_vtx_indices[v] = nir_u2u32(b, exporter_vtx_idx);
      }

      nir_ssa_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, 3, exporter_vtx_indices, NULL);
      nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
   }
   nir_pop_if(b, if_gs_accepted);
}

static void
analyze_shader_before_culling_walk(nir_ssa_def *ssa,
                                   uint8_t flag,
                                   lower_ngg_nogs_state *nogs_state)
{
   nir_instr *instr = ssa->parent_instr;
   uint8_t old_pass_flags = instr->pass_flags;
   instr->pass_flags |= flag;

   if (instr->pass_flags == old_pass_flags)
      return; /* Already visited. */

   switch (instr->type) {
   case nir_instr_type_intrinsic: {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

      /* VS input loads and SSBO loads are actually VRAM reads on AMD HW. */
      switch (intrin->intrinsic) {
      case nir_intrinsic_load_input: {
         nir_io_semantics in_io_sem = nir_intrinsic_io_semantics(intrin);
         uint64_t in_mask = UINT64_C(1) << (uint64_t) in_io_sem.location;
         if (instr->pass_flags & nggc_passflag_used_by_pos)
            nogs_state->inputs_needed_by_pos |= in_mask;
         else if (instr->pass_flags & nggc_passflag_used_by_other)
            nogs_state->inputs_needed_by_others |= in_mask;
         break;
      }
      default:
         break;
      }

      break;
   }
   case nir_instr_type_alu: {
      nir_alu_instr *alu = nir_instr_as_alu(instr);
      unsigned num_srcs = nir_op_infos[alu->op].num_inputs;

      for (unsigned i = 0; i < num_srcs; ++i) {
         analyze_shader_before_culling_walk(alu->src[i].src.ssa, flag, nogs_state);
      }

      break;
   }
   case nir_instr_type_phi: {
      nir_phi_instr *phi = nir_instr_as_phi(instr);
      nir_foreach_phi_src_safe(phi_src, phi) {
         analyze_shader_before_culling_walk(phi_src->src.ssa, flag, nogs_state);
      }

      break;
   }
   default:
      break;
   }
}

static void
analyze_shader_before_culling(nir_shader *shader, lower_ngg_nogs_state *nogs_state)
{
   nir_foreach_function(func, shader) {
      nir_foreach_block(block, func->impl) {
         nir_foreach_instr(instr, block) {
            instr->pass_flags = 0;

            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            if (intrin->intrinsic != nir_intrinsic_store_output)
               continue;

            nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
            nir_ssa_def *store_val = intrin->src[0].ssa;
            uint8_t flag = io_sem.location == VARYING_SLOT_POS ? nggc_passflag_used_by_pos : nggc_passflag_used_by_other;
            analyze_shader_before_culling_walk(store_val, flag, nogs_state);
         }
      }
   }
}

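/* For example (illustrative): an SSA value that feeds both the position store
 * and another output store is visited twice by the walk above, ends up with
 * pass_flags == nggc_passflag_used_by_both, and thereby becomes a reuse
 * candidate for save_reusable_variables() below.
 */
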
/**
 * Save the reusable SSA definitions to variables so that the
 * bottom shader part can reuse them from the top part.
 *
 * 1. We create a new function temporary variable for reusables,
 *    and insert a store+load.
 * 2. The shader is cloned (the top part is created), then the
 *    control flow is reinserted (for the bottom part.)
 * 3. For reusables, we delete the variable stores from the
 *    bottom part. This will make them use the variables from
 *    the top part and DCE the redundant instructions.
 */
static void
save_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state)
{
   ASSERTED int vec_ok = u_vector_init(&nogs_state->saved_uniforms, sizeof(saved_uniform), 4 * sizeof(saved_uniform));
   assert(vec_ok);

   unsigned loop_depth = 0;

   nir_foreach_block_safe(block, b->impl) {
      /* Check whether we're in a loop. */
      nir_cf_node *next_cf_node = nir_cf_node_next(&block->cf_node);
      nir_cf_node *prev_cf_node = nir_cf_node_prev(&block->cf_node);
      if (next_cf_node && next_cf_node->type == nir_cf_node_loop)
         loop_depth++;
      if (prev_cf_node && prev_cf_node->type == nir_cf_node_loop)
         loop_depth--;

      /* The following code doesn't make sense in loops, so just skip it then. */
      if (loop_depth)
         continue;

      nir_foreach_instr_safe(instr, block) {
         /* Find instructions whose SSA definitions are used by both
          * the top and bottom parts of the shader. In this case, it
          * makes sense to try to reuse these from the top part.
          */
         if ((instr->pass_flags & nggc_passflag_used_by_both) != nggc_passflag_used_by_both)
            continue;

         nir_ssa_def *ssa = NULL;

         switch (instr->type) {
         case nir_instr_type_alu: {
            nir_alu_instr *alu = nir_instr_as_alu(instr);
            if (alu->dest.dest.ssa.divergent)
               continue;
            /* Ignore uniform floats because they regress VGPR usage too much */
            if (nir_op_infos[alu->op].output_type & nir_type_float)
               continue;
            ssa = &alu->dest.dest.ssa;
            break;
         }
         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            if (!nir_intrinsic_can_reorder(intrin) ||
                !nir_intrinsic_infos[intrin->intrinsic].has_dest ||
                intrin->dest.ssa.divergent)
               continue;
            ssa = &intrin->dest.ssa;
            break;
         }
         case nir_instr_type_phi: {
            nir_phi_instr *phi = nir_instr_as_phi(instr);
            if (phi->dest.ssa.divergent)
               continue;
            ssa = &phi->dest.ssa;
            break;
         }
         default:
            continue;
         }

         assert(ssa);

         enum glsl_base_type base_type = GLSL_TYPE_UINT;
         switch (ssa->bit_size) {
         case 8: base_type = GLSL_TYPE_UINT8; break;
         case 16: base_type = GLSL_TYPE_UINT16; break;
         case 32: base_type = GLSL_TYPE_UINT; break;
         case 64: base_type = GLSL_TYPE_UINT64; break;
         default: continue;
         }

         const struct glsl_type *t = ssa->num_components == 1
                                     ? glsl_scalar_type(base_type)
                                     : glsl_vector_type(base_type, ssa->num_components);

         saved_uniform *saved = (saved_uniform *) u_vector_add(&nogs_state->saved_uniforms);
         assert(saved);

         saved->var = nir_local_variable_create(b->impl, t, NULL);
         saved->ssa = ssa;

         b->cursor = instr->type == nir_instr_type_phi
                     ? nir_after_instr_and_phis(instr)
                     : nir_after_instr(instr);
         nir_store_var(b, saved->var, saved->ssa, BITFIELD_MASK(ssa->num_components));
         nir_ssa_def *reloaded = nir_load_var(b, saved->var);
         nir_ssa_def_rewrite_uses_after(ssa, reloaded, reloaded->parent_instr);
      }
   }
}

/**
 * Reuses suitable variables from the top part of the shader,
 * by deleting their stores from the bottom part.
 */
static void
apply_reusable_variables(nir_builder *b, lower_ngg_nogs_state *nogs_state)
{
   if (!u_vector_length(&nogs_state->saved_uniforms)) {
      u_vector_finish(&nogs_state->saved_uniforms);
      return;
   }

   nir_foreach_block_reverse_safe(block, b->impl) {
      nir_foreach_instr_reverse_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;
         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

         /* When we found any of these intrinsics, it means
          * we reached the top part and we must stop.
          */
         if (intrin->intrinsic == nir_intrinsic_overwrite_subgroup_num_vertices_and_primitives_amd ||
             intrin->intrinsic == nir_intrinsic_alloc_vertices_and_primitives_amd ||
             intrin->intrinsic == nir_intrinsic_export_primitive_amd)
            goto done;

         if (intrin->intrinsic != nir_intrinsic_store_deref)
            continue;
         nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
         if (deref->deref_type != nir_deref_type_var)
            continue;

         saved_uniform *saved;
         u_vector_foreach(saved, &nogs_state->saved_uniforms) {
            if (saved->var == deref->var) {
               nir_instr_remove(instr);
            }
         }
      }
   }

done:
   u_vector_finish(&nogs_state->saved_uniforms);
}

static void
add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_cf, lower_ngg_nogs_state *nogs_state)
{
   assert(b->shader->info.outputs_written & (1 << VARYING_SLOT_POS));

   bool uses_instance_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_INSTANCE_ID);
   bool uses_tess_primitive_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID);

   unsigned max_exported_args = b->shader->info.stage == MESA_SHADER_VERTEX ? 2 : 4;
   if (b->shader->info.stage == MESA_SHADER_VERTEX && !uses_instance_id)
      max_exported_args--;
   else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL && !uses_tess_primitive_id)
      max_exported_args--;

   unsigned pervertex_lds_bytes = lds_es_arg_0 + max_exported_args * 4u;
   unsigned total_es_lds_bytes = pervertex_lds_bytes * nogs_state->max_es_num_vertices;
   unsigned max_num_waves = nogs_state->max_num_waves;
   unsigned ngg_scratch_lds_base_addr = ALIGN(total_es_lds_bytes, 8u);
   unsigned ngg_scratch_lds_bytes = DIV_ROUND_UP(max_num_waves, 4u);
   nogs_state->total_lds_bytes = ngg_scratch_lds_base_addr + ngg_scratch_lds_bytes;

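   /* Example (illustrative): a culling VS that reads InstanceID has
    * max_exported_args = 2, so pervertex_lds_bytes = 20 + 2 * 4 = 28 bytes.
    * With 256 ES vertices and 8 waves that is 28 * 256 = 7168 bytes of vertex
    * data, followed by DIV_ROUND_UP(8, 4) = 2 bytes of repack scratch.
    */
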
   nir_function_impl *impl = nir_shader_get_entrypoint(b->shader);

   /* Create some helper variables. */
   nir_variable *position_value_var = nogs_state->position_value_var;
   nir_variable *prim_exp_arg_var = nogs_state->prim_exp_arg_var;
   nir_variable *gs_accepted_var = nogs_state->gs_accepted_var;
   nir_variable *es_accepted_var = nogs_state->es_accepted_var;
   nir_variable *vertices_in_wave_var = nir_local_variable_create(impl, glsl_uint_type(), "vertices_in_wave");
   nir_variable *primitives_in_wave_var = nir_local_variable_create(impl, glsl_uint_type(), "primitives_in_wave");
   nir_variable *gs_vtxaddr_vars[3] = {
      nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx0_addr"),
      nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx1_addr"),
      nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx2_addr"),
   };
   nir_variable *repacked_arg_vars[4] = {
      nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_0"),
      nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_1"),
      nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_2"),
      nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_3"),
   };

   /* Top part of the culling shader (aka. position shader part)
    *
    * We clone the full ES shader and emit it here, but we only really care
    * about its position output, so we delete every other output from this part.
    * The position output is stored into a temporary variable, and reloaded later.
    */

   b->cursor = nir_before_cf_list(&impl->body);

   nir_if *if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
   {
      /* Initialize the position output variable to zeroes, in case not all VS/TES invocations store the output.
       * The spec doesn't require it, but we use (0, 0, 0, 1) because some games rely on that.
       */
      nir_store_var(b, position_value_var, nir_imm_vec4(b, 0.0f, 0.0f, 0.0f, 1.0f), 0xfu);

      /* Now reinsert a clone of the shader code */
      struct hash_table *remap_table = _mesa_pointer_hash_table_create(NULL);
      nir_cf_list_clone_and_reinsert(original_extracted_cf, &if_es_thread->cf_node, b->cursor, remap_table);
      _mesa_hash_table_destroy(remap_table, NULL);
      b->cursor = nir_after_cf_list(&if_es_thread->then_list);

      /* Remember the current thread's shader arguments */
      if (b->shader->info.stage == MESA_SHADER_VERTEX) {
         nir_store_var(b, repacked_arg_vars[0], nir_build_load_vertex_id_zero_base(b), 0x1u);
         if (uses_instance_id)
            nir_store_var(b, repacked_arg_vars[1], nir_build_load_instance_id(b), 0x1u);
      } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
         nir_ssa_def *tess_coord = nir_build_load_tess_coord(b);
         nir_store_var(b, repacked_arg_vars[0], nir_channel(b, tess_coord, 0), 0x1u);
         nir_store_var(b, repacked_arg_vars[1], nir_channel(b, tess_coord, 1), 0x1u);
         nir_store_var(b, repacked_arg_vars[2], nir_build_load_tess_rel_patch_id_amd(b), 0x1u);
         if (uses_tess_primitive_id)
            nir_store_var(b, repacked_arg_vars[3], nir_build_load_primitive_id(b), 0x1u);
      } else {
         unreachable("Should be VS or TES.");
      }
   }
   nir_pop_if(b, if_es_thread);

   /* Remove all non-position outputs, and put the position output into the variable. */
   nir_metadata_preserve(impl, nir_metadata_none);
   remove_culling_shader_outputs(b->shader, nogs_state, position_value_var);
   b->cursor = nir_after_cf_list(&impl->body);

   /* Run culling algorithms if culling is enabled.
    *
    * NGG culling can be enabled or disabled at runtime.
    * This is determined by an SGPR shader argument which is accessed
    * by the following NIR intrinsic.
    */

   nir_if *if_cull_en = nir_push_if(b, nir_build_load_cull_any_enabled_amd(b));
   {
      nir_ssa_def *invocation_index = nir_build_load_local_invocation_index(b);
      nir_ssa_def *es_vertex_lds_addr = pervertex_lds_addr(b, invocation_index, pervertex_lds_bytes);

      /* ES invocations store their vertex data to LDS for GS threads to read. */
      if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
      {
         /* Store position components that are relevant to culling in LDS */
         nir_ssa_def *pre_cull_pos = nir_load_var(b, position_value_var);
         nir_ssa_def *pre_cull_w = nir_channel(b, pre_cull_pos, 3);
         nir_build_store_shared(b, pre_cull_w, es_vertex_lds_addr, .write_mask = 0x1u, .align_mul = 4, .base = lds_es_pos_w);
         nir_ssa_def *pre_cull_x_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 0), pre_cull_w);
         nir_ssa_def *pre_cull_y_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 1), pre_cull_w);
         nir_build_store_shared(b, nir_vec2(b, pre_cull_x_div_w, pre_cull_y_div_w), es_vertex_lds_addr, .write_mask = 0x3u, .align_mul = 4, .base = lds_es_pos_x);

         /* Clear out the ES accepted flag in LDS */
         nir_build_store_shared(b, nir_imm_zero(b, 1, 8), es_vertex_lds_addr, .write_mask = 0x1u, .align_mul = 4, .base = lds_es_vertex_accepted);
      }
      nir_pop_if(b, if_es_thread);

      nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
                            .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);

      nir_store_var(b, gs_accepted_var, nir_imm_bool(b, false), 0x1u);
      nir_store_var(b, prim_exp_arg_var, nir_imm_int(b, 1 << 31), 0x1u);

      /* GS invocations load the vertex data and perform the culling. */
      nir_if *if_gs_thread = nir_push_if(b, nir_build_has_input_primitive_amd(b));
      {
         /* Load vertex indices from input VGPRs */
         nir_ssa_def *vtx_idx[3] = {0};
         for (unsigned vertex = 0; vertex < 3; ++vertex)
            vtx_idx[vertex] = ngg_input_primitive_vertex_index(b, vertex);

         nir_ssa_def *vtx_addr[3] = {0};
         nir_ssa_def *pos[3][4] = {0};

         /* Load W positions of vertices first because the culling code will use these first */
         for (unsigned vtx = 0; vtx < 3; ++vtx) {
            vtx_addr[vtx] = pervertex_lds_addr(b, vtx_idx[vtx], pervertex_lds_bytes);
            pos[vtx][3] = nir_build_load_shared(b, 1, 32, vtx_addr[vtx], .align_mul = 4u, .base = lds_es_pos_w);
            nir_store_var(b, gs_vtxaddr_vars[vtx], vtx_addr[vtx], 0x1u);
         }

         /* Load the X/W, Y/W positions of vertices */
         for (unsigned vtx = 0; vtx < 3; ++vtx) {
            nir_ssa_def *xy = nir_build_load_shared(b, 2, 32, vtx_addr[vtx], .align_mul = 4u, .base = lds_es_pos_x);
            pos[vtx][0] = nir_channel(b, xy, 0);
            pos[vtx][1] = nir_channel(b, xy, 1);
         }

         /* See if the current primitive is accepted */
         nir_ssa_def *accepted = ac_nir_cull_triangle(b, nir_imm_bool(b, true), pos);
         nir_store_var(b, gs_accepted_var, accepted, 0x1u);

         nir_if *if_gs_accepted = nir_push_if(b, accepted);
         {
            /* Store the accepted state to LDS for ES threads */
            for (unsigned vtx = 0; vtx < 3; ++vtx)
               nir_build_store_shared(b, nir_imm_intN_t(b, 0xff, 8), vtx_addr[vtx], .base = lds_es_vertex_accepted, .align_mul = 4u, .write_mask = 0x1u);
         }
         nir_pop_if(b, if_gs_accepted);
      }
      nir_pop_if(b, if_gs_thread);

      nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
                            .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);

      nir_store_var(b, es_accepted_var, nir_imm_bool(b, false), 0x1u);

      /* ES invocations load their accepted flag from LDS. */
      if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
      {
         nir_ssa_def *accepted = nir_build_load_shared(b, 1, 8u, es_vertex_lds_addr, .base = lds_es_vertex_accepted, .align_mul = 4u);
         nir_ssa_def *accepted_bool = nir_ine(b, accepted, nir_imm_intN_t(b, 0, 8));
         nir_store_var(b, es_accepted_var, accepted_bool, 0x1u);
      }
      nir_pop_if(b, if_es_thread);

      /* Vertex compaction. */
      compact_vertices_after_culling(b, nogs_state,
                                     vertices_in_wave_var, primitives_in_wave_var,
                                     repacked_arg_vars, gs_vtxaddr_vars,
                                     invocation_index, es_vertex_lds_addr,
                                     ngg_scratch_lds_base_addr, pervertex_lds_bytes, max_exported_args);
   }
   nir_push_else(b, if_cull_en);
   {
      /* When culling is disabled, we do the same as we would without culling. */
      nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
      {
         nir_ssa_def *vtx_cnt = nir_build_load_workgroup_num_input_vertices_amd(b);
         nir_ssa_def *prim_cnt = nir_build_load_workgroup_num_input_primitives_amd(b);
         nir_build_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt);
      }
      nir_pop_if(b, if_wave_0);
      nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, nogs_state), 0x1u);

      nir_ssa_def *vtx_cnt = nir_bit_count(b, nir_build_ballot(b, 1, nogs_state->wave_size, nir_build_has_input_vertex_amd(b)));
      nir_ssa_def *prm_cnt = nir_bit_count(b, nir_build_ballot(b, 1, nogs_state->wave_size, nir_build_has_input_primitive_amd(b)));
      nir_store_var(b, vertices_in_wave_var, vtx_cnt, 0x1u);
      nir_store_var(b, primitives_in_wave_var, prm_cnt, 0x1u);
   }
   nir_pop_if(b, if_cull_en);

   /* Update shader arguments.
    *
    * The registers which hold information about the subgroup's
    * vertices and primitives are updated here, so the rest of the shader
    * doesn't need to worry about the culling.
    *
    * These "overwrite" intrinsics must be at top level control flow,
    * otherwise they can mess up the backend (eg. ACO's SSA).
    *
    * TODO:
    * A cleaner solution would be to simply replace all usages of these args
    * with the load of the variables.
    * However, this wouldn't work right now because the backend uses the arguments
    * for purposes not expressed in NIR, eg. VS input loads, etc.
    * This can change if VS input loads and other stuff are lowered to eg. load_buffer_amd.
    */

   if (b->shader->info.stage == MESA_SHADER_VERTEX)
      nir_build_overwrite_vs_arguments_amd(b,
         nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]));
   else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL)
      nir_build_overwrite_tes_arguments_amd(b,
         nir_load_var(b, repacked_arg_vars[0]), nir_load_var(b, repacked_arg_vars[1]),
         nir_load_var(b, repacked_arg_vars[2]), nir_load_var(b, repacked_arg_vars[3]));
   else
      unreachable("Should be VS or TES.");

   nir_ssa_def *vertices_in_wave = nir_load_var(b, vertices_in_wave_var);
   nir_ssa_def *primitives_in_wave = nir_load_var(b, primitives_in_wave_var);
   nir_build_overwrite_subgroup_num_vertices_and_primitives_amd(b, vertices_in_wave, primitives_in_wave);
}

static bool
can_use_deferred_attribute_culling(nir_shader *shader)
{
   /* When the shader writes memory, it is difficult to guarantee correctness.
    * Future work:
    * - if only write-only SSBOs are used
    * - if we can prove that non-position outputs don't rely on memory stores
    * then it may be okay to keep the memory stores in the 1st shader part, and delete them from the 2nd.
    */
   if (shader->info.writes_memory)
      return false;

   /* When the shader relies on the subgroup invocation ID, we'd break it, because the ID changes after the culling.
    * Future work: try to save this to LDS and reload, but it can still be broken in subtle ways.
    */
   if (BITSET_TEST(shader->info.system_values_read, SYSTEM_VALUE_SUBGROUP_INVOCATION))
      return false;

   return true;
}

ac_nir_ngg_config
ac_nir_lower_ngg_nogs(nir_shader *shader,
                      unsigned max_num_es_vertices,
                      unsigned num_vertices_per_primitives,
                      unsigned max_workgroup_size,
                      unsigned wave_size,
                      bool consider_culling,
                      bool consider_passthrough,
                      bool export_prim_id,
                      bool provoking_vtx_last)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);
   assert(max_num_es_vertices && max_workgroup_size && wave_size);

   bool can_cull = consider_culling && (num_vertices_per_primitives == 3) &&
                   can_use_deferred_attribute_culling(shader);
   bool passthrough = consider_passthrough && !can_cull &&
                      !(shader->info.stage == MESA_SHADER_VERTEX && export_prim_id);

   nir_variable *position_value_var = nir_local_variable_create(impl, glsl_vec4_type(), "position_value");
   nir_variable *prim_exp_arg_var = nir_local_variable_create(impl, glsl_uint_type(), "prim_exp_arg");
   nir_variable *es_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
   nir_variable *gs_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;

   lower_ngg_nogs_state state = {
      .passthrough = passthrough,
      .export_prim_id = export_prim_id,
      .early_prim_export = exec_list_is_singular(&impl->body),
      .num_vertices_per_primitives = num_vertices_per_primitives,
      .provoking_vtx_idx = provoking_vtx_last ? (num_vertices_per_primitives - 1) : 0,
      .position_value_var = position_value_var,
      .prim_exp_arg_var = prim_exp_arg_var,
      .es_accepted_var = es_accepted_var,
      .gs_accepted_var = gs_accepted_var,
      .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
      .max_es_num_vertices = max_num_es_vertices,
      .wave_size = wave_size,
   };

   /* We need LDS space when VS needs to export the primitive ID. */
   if (shader->info.stage == MESA_SHADER_VERTEX && export_prim_id)
      state.total_lds_bytes = max_num_es_vertices * 4u;

   /* The shader only needs this much LDS when culling is turned off. */
   unsigned lds_bytes_if_culling_off = state.total_lds_bytes;

   nir_builder builder;
   nir_builder *b = &builder; /* This is to avoid the & */
   nir_builder_init(b, impl);

   if (can_cull) {
      /* We need divergence info for culling shaders. */
      nir_divergence_analysis(shader);
      analyze_shader_before_culling(shader, &state);
      save_reusable_variables(b, &state);
   }

   nir_cf_list extracted;
   nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
   b->cursor = nir_before_cf_list(&impl->body);

   if (!can_cull) {
      /* Allocate export space on wave 0 - confirm to the HW that we want to use all possible space */
      nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_int(b, 0)));
      {
         nir_ssa_def *vtx_cnt = nir_build_load_workgroup_num_input_vertices_amd(b);
         nir_ssa_def *prim_cnt = nir_build_load_workgroup_num_input_primitives_amd(b);
         nir_build_alloc_vertices_and_primitives_amd(b, vtx_cnt, prim_cnt);
      }
      nir_pop_if(b, if_wave_0);

      /* Take care of early primitive export, otherwise just pack the primitive export argument */
      if (state.early_prim_export)
         emit_ngg_nogs_prim_export(b, &state, NULL);
      else
         nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, &state), 0x1u);
   } else {
      add_deferred_attribute_culling(b, &extracted, &state);
      b->cursor = nir_after_cf_list(&impl->body);

      if (state.early_prim_export)
         emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, state.prim_exp_arg_var));
   }

   nir_intrinsic_instr *export_vertex_instr;

   nir_if *if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
   {
      /* Run the actual shader */
      nir_cf_reinsert(&extracted, b->cursor);
      b->cursor = nir_after_cf_list(&if_es_thread->then_list);

      /* Export all vertex attributes (except primitive ID) */
      export_vertex_instr = nir_build_export_vertex_amd(b);

      /* Export primitive ID (in case of early primitive export or TES) */
      if (state.export_prim_id && (state.early_prim_export || shader->info.stage != MESA_SHADER_VERTEX))
         emit_store_ngg_nogs_es_primitive_id(b);
   }
   nir_pop_if(b, if_es_thread);

   /* Take care of late primitive export */
   if (!state.early_prim_export) {
      emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var));
      if (state.export_prim_id && shader->info.stage == MESA_SHADER_VERTEX) {
         if_es_thread = nir_push_if(b, nir_build_has_input_vertex_amd(b));
         emit_store_ngg_nogs_es_primitive_id(b);
         nir_pop_if(b, if_es_thread);
      }
   }

   if (can_cull) {
      /* Replace uniforms. */
      apply_reusable_variables(b, &state);

      /* Remove the redundant position output. */
      remove_extra_pos_outputs(shader, &state);

      /* After looking at the performance in apps eg. Doom Eternal, and The Witcher 3,
       * it seems that it's best to put the position export always at the end, and
       * then let ACO schedule it up (slightly) only when early prim export is used.
       */
      b->cursor = nir_before_instr(&export_vertex_instr->instr);

      nir_ssa_def *pos_val = nir_load_var(b, state.position_value_var);
      nir_io_semantics io_sem = { .location = VARYING_SLOT_POS, .num_slots = 1 };
      nir_build_store_output(b, pos_val, nir_imm_int(b, 0), .base = VARYING_SLOT_POS, .component = 0, .io_semantics = io_sem, .write_mask = 0xfu);
   }

   nir_metadata_preserve(impl, nir_metadata_none);
   nir_validate_shader(shader, "after emitting NGG VS/TES");

   /* Cleanup */
   nir_opt_dead_write_vars(shader);
   nir_lower_vars_to_ssa(shader);
   nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
   nir_lower_alu_to_scalar(shader, NULL, NULL);
   nir_lower_phis_to_scalar(shader, true);

   if (can_cull) {
      /* It's beneficial to redo these opts after splitting the shader. */
      nir_opt_sink(shader, nir_move_load_input | nir_move_const_undef | nir_move_copies);
      nir_opt_move(shader, nir_move_load_input | nir_move_copies | nir_move_const_undef);
   }

   bool progress;
   do {
      progress = false;
      NIR_PASS(progress, shader, nir_opt_undef);
      NIR_PASS(progress, shader, nir_opt_cse);
      NIR_PASS(progress, shader, nir_opt_dce);
      NIR_PASS(progress, shader, nir_opt_dead_cf);
   } while (progress);

   shader->info.shared_size = state.total_lds_bytes;

   ac_nir_ngg_config ret = {
      .lds_bytes_if_culling_off = lds_bytes_if_culling_off,
      .can_cull = can_cull,
      .passthrough = passthrough,
      .early_prim_export = state.early_prim_export,
      .nggc_inputs_read_by_pos = state.inputs_needed_by_pos,
      .nggc_inputs_read_by_others = state.inputs_needed_by_others,
   };

   return ret;
}

static nir_ssa_def *
ngg_gs_out_vertex_addr(nir_builder *b, nir_ssa_def *out_vtx_idx, lower_ngg_gs_state *s)
{
   unsigned write_stride_2exp = ffs(MAX2(b->shader->info.gs.vertices_out, 1)) - 1;

   /* gs_max_out_vertices = 2^(write_stride_2exp) * some odd number */
   if (write_stride_2exp) {
      nir_ssa_def *row = nir_ushr_imm(b, out_vtx_idx, 5);
      nir_ssa_def *swizzle = nir_iand_imm(b, row, (1u << write_stride_2exp) - 1u);
      out_vtx_idx = nir_ixor(b, out_vtx_idx, swizzle);
   }

   nir_ssa_def *out_vtx_offs = nir_imul_imm(b, out_vtx_idx, s->lds_bytes_per_gs_out_vertex);
   return nir_iadd_imm_nuw(b, out_vtx_offs, s->lds_addr_gs_out_vtx);
}

static nir_ssa_def *
ngg_gs_emit_vertex_addr(nir_builder *b, nir_ssa_def *gs_vtx_idx, lower_ngg_gs_state *s)
{
   nir_ssa_def *tid_in_tg = nir_build_load_local_invocation_index(b);
   nir_ssa_def *gs_out_vtx_base = nir_imul_imm(b, tid_in_tg, b->shader->info.gs.vertices_out);
   nir_ssa_def *out_vtx_idx = nir_iadd_nuw(b, gs_out_vtx_base, gs_vtx_idx);

   return ngg_gs_out_vertex_addr(b, out_vtx_idx, s);
}

static void
ngg_gs_clear_primflags(nir_builder *b, nir_ssa_def *num_vertices, unsigned stream, lower_ngg_gs_state *s)
{
   nir_ssa_def *zero_u8 = nir_imm_zero(b, 1, 8);
   nir_store_var(b, s->current_clear_primflag_idx_var, num_vertices, 0x1u);

   nir_loop *loop = nir_push_loop(b);
   {
      nir_ssa_def *current_clear_primflag_idx = nir_load_var(b, s->current_clear_primflag_idx_var);
      nir_if *if_break = nir_push_if(b, nir_uge(b, current_clear_primflag_idx, nir_imm_int(b, b->shader->info.gs.vertices_out)));
      {
         nir_jump(b, nir_jump_break);
      }
      nir_push_else(b, if_break);
      {
         nir_ssa_def *emit_vtx_addr = ngg_gs_emit_vertex_addr(b, current_clear_primflag_idx, s);
         nir_build_store_shared(b, zero_u8, emit_vtx_addr, .base = s->lds_offs_primflags + stream, .align_mul = 1, .write_mask = 0x1u);
         nir_store_var(b, s->current_clear_primflag_idx_var, nir_iadd_imm_nuw(b, current_clear_primflag_idx, 1), 0x1u);
      }
      nir_pop_if(b, if_break);
   }
   nir_pop_loop(b, loop);
}

static void
ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
{
   nir_if *if_shader_query = nir_push_if(b, nir_build_load_shader_query_enabled_amd(b));
   nir_ssa_def *num_prims_in_wave = NULL;

   /* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives.
    * GS emits points, line strips or triangle strips.
    * Real primitives are points, lines or triangles.
    */
   if (nir_src_is_const(intrin->src[0]) && nir_src_is_const(intrin->src[1])) {
      unsigned gs_vtx_cnt = nir_src_as_uint(intrin->src[0]);
      unsigned gs_prm_cnt = nir_src_as_uint(intrin->src[1]);
      unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u);
      nir_ssa_def *num_threads = nir_bit_count(b, nir_build_ballot(b, 1, s->wave_size, nir_imm_bool(b, true)));
      num_prims_in_wave = nir_imul_imm(b, num_threads, total_prm_cnt);
   } else {
      nir_ssa_def *gs_vtx_cnt = intrin->src[0].ssa;
      nir_ssa_def *prm_cnt = intrin->src[1].ssa;
      if (s->num_vertices_per_primitive > 1)
         prm_cnt = nir_iadd_nuw(b, nir_imul_imm(b, prm_cnt, -1u * (s->num_vertices_per_primitive - 1)), gs_vtx_cnt);
      num_prims_in_wave = nir_build_reduce(b, prm_cnt, .reduction_op = nir_op_iadd);
   }

   /* Store the query result to GDS using an atomic add. */
   nir_if *if_first_lane = nir_push_if(b, nir_build_elect(b, 1));
   nir_build_gds_atomic_add_amd(b, 32, num_prims_in_wave, nir_imm_int(b, 0), nir_imm_int(b, 0x100));
   nir_pop_if(b, if_first_lane);

   nir_pop_if(b, if_shader_query);
}

static bool
lower_ngg_gs_store_output(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
{
   assert(nir_src_is_const(intrin->src[1]));
   b->cursor = nir_before_instr(&intrin->instr);

   unsigned writemask = nir_intrinsic_write_mask(intrin);
   unsigned base = nir_intrinsic_base(intrin);
   unsigned component_offset = nir_intrinsic_component(intrin);
   unsigned base_offset = nir_src_as_uint(intrin->src[1]);
   nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);

   assert((base + base_offset) < VARYING_SLOT_MAX);

   nir_ssa_def *store_val = intrin->src[0].ssa;

   for (unsigned comp = 0; comp < 4; ++comp) {
      if (!(writemask & (1 << comp)))
         continue;
      unsigned stream = (io_sem.gs_streams >> (comp * 2)) & 0x3;
      if (!(b->shader->info.gs.active_stream_mask & (1 << stream)))
         continue;

      /* Small bitsize components consume the same amount of space as 32-bit components,
       * but 64-bit ones consume twice as many. (Vulkan spec 15.1.5)
       */
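      /* For example (illustrative): a 16-bit component still occupies one
       * 32-bit slot here, while a 64-bit component is split below into two
       * consecutive 32-bit elements before being stored to the output vars.
       */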
      unsigned num_consumed_components = MAX2(1, DIV_ROUND_UP(store_val->bit_size, 32));
      nir_ssa_def *element = nir_channel(b, store_val, comp);
      if (num_consumed_components > 1)
         element = nir_extract_bits(b, &element, 1, 0, num_consumed_components, 32);

      for (unsigned c = 0; c < num_consumed_components; ++c) {
         unsigned component_index = (comp * num_consumed_components) + c + component_offset;
         unsigned base_index = base + base_offset + component_index / 4;
         component_index %= 4;

         /* Save output usage info */
         gs_output_component_info *info = &s->output_component_info[base_index][component_index];
         info->bit_size = MAX2(info->bit_size, MIN2(store_val->bit_size, 32));
         info->stream = stream;

         /* Store the current component element */
         nir_ssa_def *component_element = element;
         if (num_consumed_components > 1)
            component_element = nir_channel(b, component_element, c);
         if (component_element->bit_size != 32)
            component_element = nir_u2u32(b, component_element);

         nir_store_var(b, s->output_vars[base_index][component_index], component_element, 0x1u);
      }
   }

   nir_instr_remove(&intrin->instr);
   return true;
}

static bool
lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
{
   b->cursor = nir_before_instr(&intrin->instr);

   unsigned stream = nir_intrinsic_stream_id(intrin);
   if (!(b->shader->info.gs.active_stream_mask & (1 << stream))) {
      nir_instr_remove(&intrin->instr);
      return true;
   }

   nir_ssa_def *gs_emit_vtx_idx = intrin->src[0].ssa;
   nir_ssa_def *current_vtx_per_prim = intrin->src[1].ssa;
   nir_ssa_def *gs_emit_vtx_addr = ngg_gs_emit_vertex_addr(b, gs_emit_vtx_idx, s);

   for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
      unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));

      for (unsigned comp = 0; comp < 4; ++comp) {
         gs_output_component_info *info = &s->output_component_info[slot][comp];
         if (info->stream != stream || !info->bit_size)
            continue;

         /* Store the output to LDS */
         nir_ssa_def *out_val = nir_load_var(b, s->output_vars[slot][comp]);
         if (info->bit_size != 32)
            out_val = nir_u2u(b, out_val, info->bit_size);

         nir_build_store_shared(b, out_val, gs_emit_vtx_addr, .base = packed_location * 16 + comp * 4, .align_mul = 4, .write_mask = 0x1u);

         /* Clear the variable that holds the output */
         nir_store_var(b, s->output_vars[slot][comp], nir_ssa_undef(b, 1, 32), 0x1u);
      }
   }

   /* Calculate and store per-vertex primitive flags based on vertex counts:
    * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
    * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
    * - bit 2: always 1 (so that we can use it for determining vertex liveness)
    */

   nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1));
   nir_ssa_def *prim_flag = nir_bcsel(b, completes_prim, nir_imm_int(b, 0b101u), nir_imm_int(b, 0b100u));

   if (s->num_vertices_per_primitive == 3) {
      nir_ssa_def *odd = nir_iand_imm(b, current_vtx_per_prim, 1);
      prim_flag = nir_iadd_nuw(b, prim_flag, nir_ishl(b, odd, nir_imm_int(b, 1)));
   }

   nir_build_store_shared(b, nir_u2u8(b, prim_flag), gs_emit_vtx_addr, .base = s->lds_offs_primflags + stream, .align_mul = 4u, .write_mask = 0x1u);
   nir_instr_remove(&intrin->instr);
   return true;
}

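/* Example (illustrative) of the primitive flags above for a triangle strip:
 * the 3rd emitted vertex (current_vtx_per_prim == 2) completes the first,
 * even triangle and gets 0b101; the 4th vertex completes an odd triangle and
 * gets 0b111; bit 2 is set for every emitted vertex to mark it as live.
 */
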
static bool
lower_ngg_gs_end_primitive_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, UNUSED lower_ngg_gs_state *s)
{
   b->cursor = nir_before_instr(&intrin->instr);

   /* These are not needed, we can simply remove them */
   nir_instr_remove(&intrin->instr);
   return true;
}

static bool
lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
{
   b->cursor = nir_before_instr(&intrin->instr);

   unsigned stream = nir_intrinsic_stream_id(intrin);
   if (stream > 0 && !(b->shader->info.gs.active_stream_mask & (1 << stream))) {
      nir_instr_remove(&intrin->instr);
      return true;
   }

   s->found_out_vtxcnt[stream] = true;

   /* Clear the primitive flags of non-emitted vertices */
   if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out)
      ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s);

   ngg_gs_shader_query(b, intrin, s);
   nir_instr_remove(&intrin->instr);
   return true;
}

static bool
lower_ngg_gs_intrinsic(nir_builder *b, nir_instr *instr, void *state)
{
   lower_ngg_gs_state *s = (lower_ngg_gs_state *) state;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   if (intrin->intrinsic == nir_intrinsic_store_output)
      return lower_ngg_gs_store_output(b, intrin, s);
   else if (intrin->intrinsic == nir_intrinsic_emit_vertex_with_counter)
      return lower_ngg_gs_emit_vertex_with_counter(b, intrin, s);
   else if (intrin->intrinsic == nir_intrinsic_end_primitive_with_counter)
      return lower_ngg_gs_end_primitive_with_counter(b, intrin, s);
   else if (intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count)
      return lower_ngg_gs_set_vertex_and_primitive_count(b, intrin, s);

   return false;
}

static void
lower_ngg_gs_intrinsics(nir_shader *shader, lower_ngg_gs_state *s)
{
   nir_shader_instructions_pass(shader, lower_ngg_gs_intrinsic, nir_metadata_none, s);
}

static void
ngg_gs_export_primitives(nir_builder *b, nir_ssa_def *max_num_out_prims, nir_ssa_def *tid_in_tg,
                         nir_ssa_def *exporter_tid_in_tg, nir_ssa_def *primflag_0,
                         lower_ngg_gs_state *s)
{
   nir_if *if_prim_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_prims));

   /* Only bit 0 matters here - set it to 1 when the primitive should be null */
   nir_ssa_def *is_null_prim = nir_ixor(b, primflag_0, nir_imm_int(b, -1u));

   nir_ssa_def *vtx_indices[3] = {0};
   vtx_indices[s->num_vertices_per_primitive - 1] = exporter_tid_in_tg;
   if (s->num_vertices_per_primitive >= 2)
      vtx_indices[s->num_vertices_per_primitive - 2] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 1));
   if (s->num_vertices_per_primitive == 3)
      vtx_indices[s->num_vertices_per_primitive - 3] = nir_isub(b, exporter_tid_in_tg, nir_imm_int(b, 2));

   if (s->num_vertices_per_primitive == 3) {
      /* API GS outputs triangle strips, but NGG HW understands triangles.
       * We already know the triangles due to how we set the primitive flags, but we need to
       * make sure the vertex order is such that the front/back is correct, and the provoking vertex is kept.
       */

      nir_ssa_def *is_odd = nir_ubfe(b, primflag_0, nir_imm_int(b, 1), nir_imm_int(b, 1));
      if (!s->provoking_vertex_last) {
         vtx_indices[1] = nir_iadd(b, vtx_indices[1], is_odd);
         vtx_indices[2] = nir_isub(b, vtx_indices[2], is_odd);
      } else {
         vtx_indices[0] = nir_iadd(b, vtx_indices[0], is_odd);
         vtx_indices[1] = nir_isub(b, vtx_indices[1], is_odd);
      }
   }

   nir_ssa_def *arg = emit_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices, is_null_prim);
   nir_build_export_primitive_amd(b, arg);
   nir_pop_if(b, if_prim_export_thread);
}
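
/* Illustrative note on the odd-triangle fixup above: for odd triangles the two
 * non-provoking vertices are swapped, which restores the winding that the strip
 * reverses. E.g. the triangle completed by the vertex at index 3 (primitive 1, odd)
 * starts from indices (1, 2, 3); with the provoking vertex first it becomes (1, 3, 2),
 * with the provoking vertex last it becomes (2, 1, 3).
 */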
static void
ngg_gs_export_vertices(nir_builder *b, nir_ssa_def *max_num_out_vtx, nir_ssa_def *tid_in_tg,
                       nir_ssa_def *out_vtx_lds_addr, lower_ngg_gs_state *s)
{
   nir_if *if_vtx_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
   nir_ssa_def *exported_out_vtx_lds_addr = out_vtx_lds_addr;

   if (!s->output_compile_time_known) {
      /* Vertex compaction.
       * The current thread will export a vertex that was live in another invocation.
       * Load the index of the vertex that the current thread will have to export.
       */
      nir_ssa_def *exported_vtx_idx = nir_build_load_shared(b, 1, 8, out_vtx_lds_addr, .base = s->lds_offs_primflags + 1, .align_mul = 1u);
      exported_out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, nir_u2u32(b, exported_vtx_idx), s);
   }

   for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
      if (!(b->shader->info.outputs_written & BITFIELD64_BIT(slot)))
         continue;

      unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
      nir_io_semantics io_sem = { .location = slot, .num_slots = 1 };

      for (unsigned comp = 0; comp < 4; ++comp) {
         gs_output_component_info *info = &s->output_component_info[slot][comp];
         if (info->stream != 0 || info->bit_size == 0)
            continue;

         nir_ssa_def *load = nir_build_load_shared(b, 1, info->bit_size, exported_out_vtx_lds_addr, .base = packed_location * 16u + comp * 4u, .align_mul = 4u);
         nir_build_store_output(b, load, nir_imm_int(b, 0), .write_mask = 0x1u, .base = slot, .component = comp, .io_semantics = io_sem);
      }
   }

   nir_build_export_vertex_amd(b);
   nir_pop_if(b, if_vtx_export_thread);
}

static void
ngg_gs_setup_vertex_compaction(nir_builder *b, nir_ssa_def *vertex_live, nir_ssa_def *tid_in_tg,
                               nir_ssa_def *exporter_tid_in_tg, lower_ngg_gs_state *s)
{
   assert(vertex_live->bit_size == 1);
   nir_if *if_vertex_live = nir_push_if(b, vertex_live);
   {
      /* Set up the vertex compaction.
       * Save the current thread's id for the thread which will export the current vertex.
       * We reuse stream 1 of the primitive flag of the other thread's vertex for storing this.
       */

      nir_ssa_def *exporter_lds_addr = ngg_gs_out_vertex_addr(b, exporter_tid_in_tg, s);
      nir_ssa_def *tid_in_tg_u8 = nir_u2u8(b, tid_in_tg);
      nir_build_store_shared(b, tid_in_tg_u8, exporter_lds_addr, .base = s->lds_offs_primflags + 1, .align_mul = 1u, .write_mask = 0x1u);
   }
   nir_pop_if(b, if_vertex_live);
}

static nir_ssa_def *
ngg_gs_load_out_vtx_primflag_0(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *vtx_lds_addr,
                               nir_ssa_def *max_num_out_vtx, lower_ngg_gs_state *s)
{
   nir_ssa_def *zero = nir_imm_int(b, 0);

   nir_if *if_outvtx_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
   nir_ssa_def *primflag_0 = nir_build_load_shared(b, 1, 8, vtx_lds_addr, .base = s->lds_offs_primflags, .align_mul = 4u);
   primflag_0 = nir_u2u32(b, primflag_0);
   nir_pop_if(b, if_outvtx_thread);

   return nir_if_phi(b, primflag_0, zero);
}
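
/* Summary of the per-vertex LDS layout assumed by the helpers above (derived from the code):
 *    - each output vertex starts at the base address returned by ngg_gs_out_vertex_addr
 *    - 16 bytes (4 dwords) per written output slot, addressed by packed_location * 16 + component * 4
 *    - at lds_offs_primflags: one primitive flag byte per vertex stream
 *    - the byte at lds_offs_primflags + 1 (stream 1's flag) is reused to store the
 *      exporter thread id when vertex compaction is needed
 */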
static void
ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
{
   nir_ssa_def *tid_in_tg = nir_build_load_local_invocation_index(b);
   nir_ssa_def *max_vtxcnt = nir_build_load_workgroup_num_input_vertices_amd(b);
   nir_ssa_def *max_prmcnt = max_vtxcnt; /* They are currently practically the same; both RADV and RadeonSI do this. */
   nir_ssa_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);

   if (s->output_compile_time_known) {
      /* When the output is compile-time known, the GS writes all possible vertices and primitives it can.
       * The gs_alloc_req needs to happen on one wave only, otherwise the HW hangs.
       */
      nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_zero(b, 1, 32)));
      nir_build_alloc_vertices_and_primitives_amd(b, max_vtxcnt, max_prmcnt);
      nir_pop_if(b, if_wave_0);
   }

   /* Workgroup barrier: wait for all GS threads to finish */
   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);

   nir_ssa_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag_0(b, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);

   if (s->output_compile_time_known) {
      ngg_gs_export_primitives(b, max_vtxcnt, tid_in_tg, tid_in_tg, out_vtx_primflag_0, s);
      ngg_gs_export_vertices(b, max_vtxcnt, tid_in_tg, out_vtx_lds_addr, s);
      return;
   }

   /* When the output vertex count is not known at compile time:
    * There may be gaps between invocations that have live vertices, but NGG hardware
    * requires that the invocations that export vertices are packed (i.e. compact).
    * To ensure this, we need to repack invocations that have a live vertex.
    */
   nir_ssa_def *vertex_live = nir_ine(b, out_vtx_primflag_0, nir_imm_zero(b, 1, out_vtx_primflag_0->bit_size));
   wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves, s->wave_size);

   nir_ssa_def *workgroup_num_vertices = rep.num_repacked_invocations;
   nir_ssa_def *exporter_tid_in_tg = rep.repacked_invocation_index;

   /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */
   nir_ssa_def *any_output = nir_ine(b, workgroup_num_vertices, nir_imm_int(b, 0));
   max_prmcnt = nir_bcsel(b, any_output, max_prmcnt, nir_imm_int(b, 0));

   /* Allocate export space. We currently don't compact primitives, just use the maximum number. */
   nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_build_load_subgroup_id(b), nir_imm_zero(b, 1, 32)));
   nir_build_alloc_vertices_and_primitives_amd(b, workgroup_num_vertices, max_prmcnt);
   nir_pop_if(b, if_wave_0);

   /* Vertex compaction. This makes sure there are no gaps between threads that export vertices. */
   ngg_gs_setup_vertex_compaction(b, vertex_live, tid_in_tg, exporter_tid_in_tg, s);

   /* Workgroup barrier: wait for all LDS stores to finish. */
   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);

   ngg_gs_export_primitives(b, max_prmcnt, tid_in_tg, exporter_tid_in_tg, out_vtx_primflag_0, s);
   ngg_gs_export_vertices(b, workgroup_num_vertices, tid_in_tg, out_vtx_lds_addr, s);
}
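
/* Worked example of the LDS layout set up below (the numbers are made up for illustration):
 * with esgs_ring_lds_bytes = 0x1000 and gs_out_vtx_bytes = 32, the GS output vertices start
 * at offset 0x1000, each vertex occupies 32 + 4 = 36 bytes (32 bytes of outputs followed by
 * 4 primitive flag bytes, one per stream), and the repack scratch area starts at
 * ALIGN(0x1000 + gs_total_out_vtx_bytes, 8).
 */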
void
ac_nir_lower_ngg_gs(nir_shader *shader,
                    unsigned wave_size,
                    unsigned max_workgroup_size,
                    unsigned esgs_ring_lds_bytes,
                    unsigned gs_out_vtx_bytes,
                    unsigned gs_total_out_vtx_bytes,
                    bool provoking_vertex_last)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   assert(impl);

   lower_ngg_gs_state state = {
      .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
      .wave_size = wave_size,
      .lds_addr_gs_out_vtx = esgs_ring_lds_bytes,
      .lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 8u /* for the repacking code */),
      .lds_offs_primflags = gs_out_vtx_bytes,
      .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
      .provoking_vertex_last = provoking_vertex_last,
   };

   unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
   unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
   shader->info.shared_size = total_lds_bytes;

   nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt, state.const_out_prmcnt, 4u);
   state.output_compile_time_known = state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
                                     state.const_out_prmcnt[0] != -1;

   if (!state.output_compile_time_known)
      state.current_clear_primflag_idx_var = nir_local_variable_create(impl, glsl_uint_type(), "current_clear_primflag_idx");

   if (shader->info.gs.output_primitive == GL_POINTS)
      state.num_vertices_per_primitive = 1;
   else if (shader->info.gs.output_primitive == GL_LINE_STRIP)
      state.num_vertices_per_primitive = 2;
   else if (shader->info.gs.output_primitive == GL_TRIANGLE_STRIP)
      state.num_vertices_per_primitive = 3;
   else
      unreachable("Invalid GS output primitive.");

   /* Extract the full control flow. It is going to be wrapped in an if statement. */
   nir_cf_list extracted;
   nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));

   nir_builder builder;
   nir_builder *b = &builder; /* This is to avoid the & */
   nir_builder_init(b, impl);
   b->cursor = nir_before_cf_list(&impl->body);

   /* Workgroup barrier: wait for ES threads */
   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);

   /* Wrap the GS control flow. */
   nir_if *if_gs_thread = nir_push_if(b, nir_build_has_input_primitive_amd(b));

   /* Create and initialize output variables */
   for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
      for (unsigned comp = 0; comp < 4; ++comp) {
         state.output_vars[slot][comp] = nir_local_variable_create(impl, glsl_uint_type(), "output");
      }
   }

   nir_cf_reinsert(&extracted, b->cursor);
   b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
   nir_pop_if(b, if_gs_thread);

   /* Lower the GS intrinsics */
   lower_ngg_gs_intrinsics(shader, &state);
   b->cursor = nir_after_cf_list(&impl->body);
   if (!state.found_out_vtxcnt[0]) {
      fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU.");
      abort();
   }

   /* Emit the finale sequence */
   ngg_gs_finale(b, &state);
   nir_validate_shader(shader, "after emitting NGG GS");

   /* Cleanup */
   nir_lower_vars_to_ssa(shader);
   nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
   nir_metadata_preserve(impl, nir_metadata_none);
}
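
/* Usage sketch (illustrative; the variable names are placeholders, not part of this file):
 * a driver calls the pass after it has carved out its LDS budget, e.g.
 *
 *    ac_nir_lower_ngg_gs(nir,
 *                        64,                        // wave_size
 *                        128,                       // max_workgroup_size
 *                        esgs_ring_lds_bytes,       // LDS used by the ES->GS ring
 *                        gs_out_vtx_bytes,          // bytes of outputs per GS output vertex
 *                        gs_total_out_vtx_bytes,    // bytes of the whole GS output vertex area
 *                        provoking_vertex_last);
 *
 * and afterwards reads shader->info.shared_size for the total LDS requirement.
 */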