Path: blob/21.2-virgl/src/compiler/spirv/vtn_subgroup.c
4545 views
/*1* Copyright © 2016 Intel Corporation2*3* Permission is hereby granted, free of charge, to any person obtaining a4* copy of this software and associated documentation files (the "Software"),5* to deal in the Software without restriction, including without limitation6* the rights to use, copy, modify, merge, publish, distribute, sublicense,7* and/or sell copies of the Software, and to permit persons to whom the8* Software is furnished to do so, subject to the following conditions:9*10* The above copyright notice and this permission notice (including the next11* paragraph) shall be included in all copies or substantial portions of the12* Software.13*14* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR15* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,16* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL17* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER18* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING19* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS20* IN THE SOFTWARE.21*/2223#include "vtn_private.h"2425static struct vtn_ssa_value *26vtn_build_subgroup_instr(struct vtn_builder *b,27nir_intrinsic_op nir_op,28struct vtn_ssa_value *src0,29nir_ssa_def *index,30unsigned const_idx0,31unsigned const_idx1)32{33/* Some of the subgroup operations take an index. SPIR-V allows this to be34* any integer type. To make things simpler for drivers, we only support35* 32-bit indices.36*/37if (index && index->bit_size != 32)38index = nir_u2u32(&b->nb, index);3940struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);4142vtn_assert(dst->type == src0->type);43if (!glsl_type_is_vector_or_scalar(dst->type)) {44for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {45dst->elems[0] =46vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,47const_idx0, const_idx1);48}49return dst;50}5152nir_intrinsic_instr *intrin =53nir_intrinsic_instr_create(b->nb.shader, nir_op);54nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,55dst->type, NULL);56intrin->num_components = intrin->dest.ssa.num_components;5758intrin->src[0] = nir_src_for_ssa(src0->def);59if (index)60intrin->src[1] = nir_src_for_ssa(index);6162intrin->const_index[0] = const_idx0;63intrin->const_index[1] = const_idx1;6465nir_builder_instr_insert(&b->nb, &intrin->instr);6667dst->def = &intrin->dest.ssa;6869return dst;70}7172void73vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,74const uint32_t *w, unsigned count)75{76struct vtn_type *dest_type = vtn_get_type(b, w[1]);7778switch (opcode) {79case SpvOpGroupNonUniformElect: {80vtn_fail_if(dest_type->type != glsl_bool_type(),81"OpGroupNonUniformElect must return a Bool");82nir_intrinsic_instr *elect =83nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);84nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,85dest_type->type, NULL);86nir_builder_instr_insert(&b->nb, &elect->instr);87vtn_push_nir_ssa(b, w[2], &elect->dest.ssa);88break;89}9091case SpvOpGroupNonUniformBallot:92case SpvOpSubgroupBallotKHR: {93bool has_scope = (opcode != SpvOpSubgroupBallotKHR);94vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),95"OpGroupNonUniformBallot must return a uvec4");96nir_intrinsic_instr *ballot =97nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);98ballot->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[3 + has_scope]));99nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);100ballot->num_components = 4;101nir_builder_instr_insert(&b->nb, &ballot->instr);102vtn_push_nir_ssa(b, w[2], &ballot->dest.ssa);103break;104}105106case SpvOpGroupNonUniformInverseBallot: {107/* This one is just a BallotBitfieldExtract with subgroup invocation.108* We could add a NIR intrinsic but it's easier to just lower it on the109* spot.110*/111nir_intrinsic_instr *intrin =112nir_intrinsic_instr_create(b->nb.shader,113nir_intrinsic_ballot_bitfield_extract);114115intrin->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[4]));116intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));117118nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,119dest_type->type, NULL);120nir_builder_instr_insert(&b->nb, &intrin->instr);121122vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);123break;124}125126case SpvOpGroupNonUniformBallotBitExtract:127case SpvOpGroupNonUniformBallotBitCount:128case SpvOpGroupNonUniformBallotFindLSB:129case SpvOpGroupNonUniformBallotFindMSB: {130nir_ssa_def *src0, *src1 = NULL;131nir_intrinsic_op op;132switch (opcode) {133case SpvOpGroupNonUniformBallotBitExtract:134op = nir_intrinsic_ballot_bitfield_extract;135src0 = vtn_get_nir_ssa(b, w[4]);136src1 = vtn_get_nir_ssa(b, w[5]);137break;138case SpvOpGroupNonUniformBallotBitCount:139switch ((SpvGroupOperation)w[4]) {140case SpvGroupOperationReduce:141op = nir_intrinsic_ballot_bit_count_reduce;142break;143case SpvGroupOperationInclusiveScan:144op = nir_intrinsic_ballot_bit_count_inclusive;145break;146case SpvGroupOperationExclusiveScan:147op = nir_intrinsic_ballot_bit_count_exclusive;148break;149default:150unreachable("Invalid group operation");151}152src0 = vtn_get_nir_ssa(b, w[5]);153break;154case SpvOpGroupNonUniformBallotFindLSB:155op = nir_intrinsic_ballot_find_lsb;156src0 = vtn_get_nir_ssa(b, w[4]);157break;158case SpvOpGroupNonUniformBallotFindMSB:159op = nir_intrinsic_ballot_find_msb;160src0 = vtn_get_nir_ssa(b, w[4]);161break;162default:163unreachable("Unhandled opcode");164}165166nir_intrinsic_instr *intrin =167nir_intrinsic_instr_create(b->nb.shader, op);168169intrin->src[0] = nir_src_for_ssa(src0);170if (src1)171intrin->src[1] = nir_src_for_ssa(src1);172173nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,174dest_type->type, NULL);175nir_builder_instr_insert(&b->nb, &intrin->instr);176177vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);178break;179}180181case SpvOpGroupNonUniformBroadcastFirst:182case SpvOpSubgroupFirstInvocationKHR: {183bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);184vtn_push_ssa_value(b, w[2],185vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,186vtn_ssa_value(b, w[3 + has_scope]),187NULL, 0, 0));188break;189}190191case SpvOpGroupNonUniformBroadcast:192case SpvOpGroupBroadcast:193case SpvOpSubgroupReadInvocationKHR: {194bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);195vtn_push_ssa_value(b, w[2],196vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,197vtn_ssa_value(b, w[3 + has_scope]),198vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));199break;200}201202case SpvOpGroupNonUniformAll:203case SpvOpGroupNonUniformAny:204case SpvOpGroupNonUniformAllEqual:205case SpvOpGroupAll:206case SpvOpGroupAny:207case SpvOpSubgroupAllKHR:208case SpvOpSubgroupAnyKHR:209case SpvOpSubgroupAllEqualKHR: {210vtn_fail_if(dest_type->type != glsl_bool_type(),211"OpGroupNonUniform(All|Any|AllEqual) must return a bool");212nir_intrinsic_op op;213switch (opcode) {214case SpvOpGroupNonUniformAll:215case SpvOpGroupAll:216case SpvOpSubgroupAllKHR:217op = nir_intrinsic_vote_all;218break;219case SpvOpGroupNonUniformAny:220case SpvOpGroupAny:221case SpvOpSubgroupAnyKHR:222op = nir_intrinsic_vote_any;223break;224case SpvOpSubgroupAllEqualKHR:225op = nir_intrinsic_vote_ieq;226break;227case SpvOpGroupNonUniformAllEqual:228switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {229case GLSL_TYPE_FLOAT:230case GLSL_TYPE_FLOAT16:231case GLSL_TYPE_DOUBLE:232op = nir_intrinsic_vote_feq;233break;234case GLSL_TYPE_UINT:235case GLSL_TYPE_INT:236case GLSL_TYPE_UINT8:237case GLSL_TYPE_INT8:238case GLSL_TYPE_UINT16:239case GLSL_TYPE_INT16:240case GLSL_TYPE_UINT64:241case GLSL_TYPE_INT64:242case GLSL_TYPE_BOOL:243op = nir_intrinsic_vote_ieq;244break;245default:246unreachable("Unhandled type");247}248break;249default:250unreachable("Unhandled opcode");251}252253nir_ssa_def *src0;254if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||255opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||256opcode == SpvOpGroupNonUniformAllEqual) {257src0 = vtn_get_nir_ssa(b, w[4]);258} else {259src0 = vtn_get_nir_ssa(b, w[3]);260}261nir_intrinsic_instr *intrin =262nir_intrinsic_instr_create(b->nb.shader, op);263if (nir_intrinsic_infos[op].src_components[0] == 0)264intrin->num_components = src0->num_components;265intrin->src[0] = nir_src_for_ssa(src0);266nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,267dest_type->type, NULL);268nir_builder_instr_insert(&b->nb, &intrin->instr);269270vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);271break;272}273274case SpvOpGroupNonUniformShuffle:275case SpvOpGroupNonUniformShuffleXor:276case SpvOpGroupNonUniformShuffleUp:277case SpvOpGroupNonUniformShuffleDown: {278nir_intrinsic_op op;279switch (opcode) {280case SpvOpGroupNonUniformShuffle:281op = nir_intrinsic_shuffle;282break;283case SpvOpGroupNonUniformShuffleXor:284op = nir_intrinsic_shuffle_xor;285break;286case SpvOpGroupNonUniformShuffleUp:287op = nir_intrinsic_shuffle_up;288break;289case SpvOpGroupNonUniformShuffleDown:290op = nir_intrinsic_shuffle_down;291break;292default:293unreachable("Invalid opcode");294}295vtn_push_ssa_value(b, w[2],296vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),297vtn_get_nir_ssa(b, w[5]), 0, 0));298break;299}300301case SpvOpSubgroupShuffleINTEL:302case SpvOpSubgroupShuffleXorINTEL: {303nir_intrinsic_op op = opcode == SpvOpSubgroupShuffleINTEL ?304nir_intrinsic_shuffle : nir_intrinsic_shuffle_xor;305vtn_push_ssa_value(b, w[2],306vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[3]),307vtn_get_nir_ssa(b, w[4]), 0, 0));308break;309}310311case SpvOpSubgroupShuffleUpINTEL:312case SpvOpSubgroupShuffleDownINTEL: {313/* TODO: Move this lower on the compiler stack, where we can move the314* current/other data to adjacent registers to avoid doing a shuffle315* twice.316*/317318nir_builder *nb = &b->nb;319nir_ssa_def *size = nir_load_subgroup_size(nb);320nir_ssa_def *delta = vtn_get_nir_ssa(b, w[5]);321322/* Rewrite UP in terms of DOWN.323*324* UP(a, b, delta) == DOWN(a, b, size - delta)325*/326if (opcode == SpvOpSubgroupShuffleUpINTEL)327delta = nir_isub(nb, size, delta);328329nir_ssa_def *index = nir_iadd(nb, nir_load_subgroup_invocation(nb), delta);330struct vtn_ssa_value *current =331vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[3]),332index, 0, 0);333334struct vtn_ssa_value *next =335vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[4]),336nir_isub(nb, index, size), 0, 0);337338nir_ssa_def *cond = nir_ilt(nb, index, size);339vtn_push_nir_ssa(b, w[2], nir_bcsel(nb, cond, current->def, next->def));340341break;342}343344case SpvOpGroupNonUniformQuadBroadcast:345vtn_push_ssa_value(b, w[2],346vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,347vtn_ssa_value(b, w[4]),348vtn_get_nir_ssa(b, w[5]), 0, 0));349break;350351case SpvOpGroupNonUniformQuadSwap: {352unsigned direction = vtn_constant_uint(b, w[5]);353nir_intrinsic_op op;354switch (direction) {355case 0:356op = nir_intrinsic_quad_swap_horizontal;357break;358case 1:359op = nir_intrinsic_quad_swap_vertical;360break;361case 2:362op = nir_intrinsic_quad_swap_diagonal;363break;364default:365vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");366}367vtn_push_ssa_value(b, w[2],368vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));369break;370}371372case SpvOpGroupNonUniformIAdd:373case SpvOpGroupNonUniformFAdd:374case SpvOpGroupNonUniformIMul:375case SpvOpGroupNonUniformFMul:376case SpvOpGroupNonUniformSMin:377case SpvOpGroupNonUniformUMin:378case SpvOpGroupNonUniformFMin:379case SpvOpGroupNonUniformSMax:380case SpvOpGroupNonUniformUMax:381case SpvOpGroupNonUniformFMax:382case SpvOpGroupNonUniformBitwiseAnd:383case SpvOpGroupNonUniformBitwiseOr:384case SpvOpGroupNonUniformBitwiseXor:385case SpvOpGroupNonUniformLogicalAnd:386case SpvOpGroupNonUniformLogicalOr:387case SpvOpGroupNonUniformLogicalXor:388case SpvOpGroupIAdd:389case SpvOpGroupFAdd:390case SpvOpGroupFMin:391case SpvOpGroupUMin:392case SpvOpGroupSMin:393case SpvOpGroupFMax:394case SpvOpGroupUMax:395case SpvOpGroupSMax:396case SpvOpGroupIAddNonUniformAMD:397case SpvOpGroupFAddNonUniformAMD:398case SpvOpGroupFMinNonUniformAMD:399case SpvOpGroupUMinNonUniformAMD:400case SpvOpGroupSMinNonUniformAMD:401case SpvOpGroupFMaxNonUniformAMD:402case SpvOpGroupUMaxNonUniformAMD:403case SpvOpGroupSMaxNonUniformAMD: {404nir_op reduction_op;405switch (opcode) {406case SpvOpGroupNonUniformIAdd:407case SpvOpGroupIAdd:408case SpvOpGroupIAddNonUniformAMD:409reduction_op = nir_op_iadd;410break;411case SpvOpGroupNonUniformFAdd:412case SpvOpGroupFAdd:413case SpvOpGroupFAddNonUniformAMD:414reduction_op = nir_op_fadd;415break;416case SpvOpGroupNonUniformIMul:417reduction_op = nir_op_imul;418break;419case SpvOpGroupNonUniformFMul:420reduction_op = nir_op_fmul;421break;422case SpvOpGroupNonUniformSMin:423case SpvOpGroupSMin:424case SpvOpGroupSMinNonUniformAMD:425reduction_op = nir_op_imin;426break;427case SpvOpGroupNonUniformUMin:428case SpvOpGroupUMin:429case SpvOpGroupUMinNonUniformAMD:430reduction_op = nir_op_umin;431break;432case SpvOpGroupNonUniformFMin:433case SpvOpGroupFMin:434case SpvOpGroupFMinNonUniformAMD:435reduction_op = nir_op_fmin;436break;437case SpvOpGroupNonUniformSMax:438case SpvOpGroupSMax:439case SpvOpGroupSMaxNonUniformAMD:440reduction_op = nir_op_imax;441break;442case SpvOpGroupNonUniformUMax:443case SpvOpGroupUMax:444case SpvOpGroupUMaxNonUniformAMD:445reduction_op = nir_op_umax;446break;447case SpvOpGroupNonUniformFMax:448case SpvOpGroupFMax:449case SpvOpGroupFMaxNonUniformAMD:450reduction_op = nir_op_fmax;451break;452case SpvOpGroupNonUniformBitwiseAnd:453case SpvOpGroupNonUniformLogicalAnd:454reduction_op = nir_op_iand;455break;456case SpvOpGroupNonUniformBitwiseOr:457case SpvOpGroupNonUniformLogicalOr:458reduction_op = nir_op_ior;459break;460case SpvOpGroupNonUniformBitwiseXor:461case SpvOpGroupNonUniformLogicalXor:462reduction_op = nir_op_ixor;463break;464default:465unreachable("Invalid reduction operation");466}467468nir_intrinsic_op op;469unsigned cluster_size = 0;470switch ((SpvGroupOperation)w[4]) {471case SpvGroupOperationReduce:472op = nir_intrinsic_reduce;473break;474case SpvGroupOperationInclusiveScan:475op = nir_intrinsic_inclusive_scan;476break;477case SpvGroupOperationExclusiveScan:478op = nir_intrinsic_exclusive_scan;479break;480case SpvGroupOperationClusteredReduce:481op = nir_intrinsic_reduce;482assert(count == 7);483cluster_size = vtn_constant_uint(b, w[6]);484break;485default:486unreachable("Invalid group operation");487}488489vtn_push_ssa_value(b, w[2],490vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,491reduction_op, cluster_size));492break;493}494495default:496unreachable("Invalid SPIR-V opcode");497}498}499500501