Path: blob/21.2-virgl/src/compiler/spirv/vtn_alu.c
/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include <math.h>
#include "vtn_private.h"
#include "spirv_info.h"

/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want to do one routine for
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before we
 * send it off.
 */

static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = glsl_get_bare_type(val->type);
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}
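
/* Illustrative note (not part of the original source): with this wrapping, an
 * OpMatrixTimesVector of a mat4 by a vec4 is treated as a 4x4 matrix times a
 * one-column matrix whose single column is the vec4, so matrix_multiply()
 * below only ever has to reason about matrix-times-matrix.
 */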

static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{
   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                          src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
         for (int j = src0_columns - 2; j >= 0; j--) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}

static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

static struct vtn_ssa_value *
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      return dest;
   }

   case SpvOpFAdd: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpFSub: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpTranspose:
      return vtn_ssa_transpose(b, src0);

   case SpvOpMatrixTimesScalar:
      if (src0->transposed) {
         return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                      src1->def));
      } else {
         return mat_times_scalar(b, src0, src1->def);
      }
      break;

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
      if (opcode == SpvOpVectorTimesMatrix) {
         return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         return matrix_multiply(b, src0, src1);
      }
      break;

   default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
   }
}
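
/* Illustrative note (not part of the original source): SpvOpVectorTimesMatrix
 * above is folded into the matrix-times-vector path via the identity
 * v * M == transpose(M) * v.  For a mat2 M with columns c0 and c1 and a
 * vec2 v, both sides expand to vec2(dot(c0, v), dot(c1, v)).
 */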

static nir_alu_type
convert_op_src_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertFToS:
   case SpvOpConvertFToU:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertSToF:
   case SpvOpSatConvertSToU:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertUToF:
   case SpvOpSatConvertUToS:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}

static nir_alu_type
convert_op_dst_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertFToS:
   case SpvOpSatConvertUToS:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpSatConvertSToU:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}
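
/* Illustrative note (not part of the original source): these two helpers are
 * combined with the operand bit sizes to select a NIR conversion opcode.  For
 * example, SpvOpConvertFToS from a 32-bit source to a 16-bit destination
 * yields nir_type_float32 -> nir_type_int16, for which
 * nir_type_conversion_op() returns nir_op_f2i16.
 */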

nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap, bool *exact,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped.  This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   *exact = false;

   switch (opcode) {
   case SpvOpSNegate:            return nir_op_ineg;
   case SpvOpFNegate:            return nir_op_fneg;
   case SpvOpNot:                return nir_op_inot;
   case SpvOpIAdd:               return nir_op_iadd;
   case SpvOpFAdd:               return nir_op_fadd;
   case SpvOpISub:               return nir_op_isub;
   case SpvOpFSub:               return nir_op_fsub;
   case SpvOpIMul:               return nir_op_imul;
   case SpvOpFMul:               return nir_op_fmul;
   case SpvOpUDiv:               return nir_op_udiv;
   case SpvOpSDiv:               return nir_op_idiv;
   case SpvOpFDiv:               return nir_op_fdiv;
   case SpvOpUMod:               return nir_op_umod;
   case SpvOpSMod:               return nir_op_imod;
   case SpvOpFMod:               return nir_op_fmod;
   case SpvOpSRem:               return nir_op_irem;
   case SpvOpFRem:               return nir_op_frem;

   case SpvOpShiftRightLogical:     return nir_op_ushr;
   case SpvOpShiftRightArithmetic:  return nir_op_ishr;
   case SpvOpShiftLeftLogical:      return nir_op_ishl;
   case SpvOpLogicalOr:             return nir_op_ior;
   case SpvOpLogicalEqual:          return nir_op_ieq;
   case SpvOpLogicalNotEqual:       return nir_op_ine;
   case SpvOpLogicalAnd:            return nir_op_iand;
   case SpvOpLogicalNot:            return nir_op_inot;
   case SpvOpBitwiseOr:             return nir_op_ior;
   case SpvOpBitwiseXor:            return nir_op_ixor;
   case SpvOpBitwiseAnd:            return nir_op_iand;
   case SpvOpSelect:                return nir_op_bcsel;
   case SpvOpIEqual:                return nir_op_ieq;

   case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
   case SpvOpBitReverse:            return nir_op_bitfield_reverse;

   case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
   /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
   case SpvOpAbsISubINTEL:          return nir_op_uabs_isub;
   case SpvOpAbsUSubINTEL:          return nir_op_uabs_usub;
   case SpvOpIAddSatINTEL:          return nir_op_iadd_sat;
   case SpvOpUAddSatINTEL:          return nir_op_uadd_sat;
   case SpvOpIAverageINTEL:         return nir_op_ihadd;
   case SpvOpUAverageINTEL:         return nir_op_uhadd;
   case SpvOpIAverageRoundedINTEL:  return nir_op_irhadd;
   case SpvOpUAverageRoundedINTEL:  return nir_op_urhadd;
   case SpvOpISubSatINTEL:          return nir_op_isub_sat;
   case SpvOpUSubSatINTEL:          return nir_op_usub_sat;
   case SpvOpIMul32x16INTEL:        return nir_op_imul_32x16;
   case SpvOpUMul32x16INTEL:        return nir_op_umul_32x16;

   /* The ordered / unordered comparisons need a special implementation on
    * top of the logical operator below, since they also have to check
    * whether the operands are ordered.
    */
   case SpvOpFOrdEqual:                            *exact = true;  return nir_op_feq;
   case SpvOpFUnordEqual:                          *exact = true;  return nir_op_feq;
   case SpvOpINotEqual:                                            return nir_op_ine;
   case SpvOpLessOrGreater:                        /* Deprecated, use OrdNotEqual */
   case SpvOpFOrdNotEqual:                         *exact = true;  return nir_op_fneu;
   case SpvOpFUnordNotEqual:                       *exact = true;  return nir_op_fneu;
   case SpvOpULessThan:                                            return nir_op_ult;
   case SpvOpSLessThan:                                            return nir_op_ilt;
   case SpvOpFOrdLessThan:                         *exact = true;  return nir_op_flt;
   case SpvOpFUnordLessThan:                       *exact = true;  return nir_op_flt;
   case SpvOpUGreaterThan:          *swap = true;                  return nir_op_ult;
   case SpvOpSGreaterThan:          *swap = true;                  return nir_op_ilt;
   case SpvOpFOrdGreaterThan:       *swap = true;  *exact = true;  return nir_op_flt;
   case SpvOpFUnordGreaterThan:     *swap = true;  *exact = true;  return nir_op_flt;
   case SpvOpULessThanEqual:        *swap = true;                  return nir_op_uge;
   case SpvOpSLessThanEqual:        *swap = true;                  return nir_op_ige;
   case SpvOpFOrdLessThanEqual:     *swap = true;  *exact = true;  return nir_op_fge;
   case SpvOpFUnordLessThanEqual:   *swap = true;  *exact = true;  return nir_op_fge;
   case SpvOpUGreaterThanEqual:                                    return nir_op_uge;
   case SpvOpSGreaterThanEqual:                                    return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual:                 *exact = true;  return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual:               *exact = true;  return nir_op_fge;

   /* Conversions: */
   case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }

   case SpvOpPtrCastToGeneric:   return nir_op_mov;
   case SpvOpGenericCastToPtr:   return nir_op_mov;

   /* Derivatives: */
   case SpvOpDPdx:         return nir_op_fddx;
   case SpvOpDPdy:         return nir_op_fddy;
   case SpvOpDPdxFine:     return nir_op_fddx_fine;
   case SpvOpDPdyFine:     return nir_op_fddy_fine;
   case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;

   case SpvOpIsNormal:     return nir_op_fisnormal;
   case SpvOpIsFinite:     return nir_op_fisfinite;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}
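
/* Illustrative note (not part of the original source): NIR has no
 * greater-than opcodes, so the *swap cases above reverse the operands
 * instead; e.g. SpvOpUGreaterThan maps to nir_op_ult with *swap = true and
 * the caller emits b < a rather than a > b.
 */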

static void
handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
                      const struct vtn_decoration *dec, void *_void)
{
   vtn_assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->nb.exact = true;
}

void
vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
{
   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
}

nir_rounding_mode
vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
{
   switch (mode) {
   case SpvFPRoundingModeRTE:
      return nir_rounding_mode_rtne;
   case SpvFPRoundingModeRTZ:
      return nir_rounding_mode_rtz;
   case SpvFPRoundingModeRTP:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTP is only supported in kernels");
      return nir_rounding_mode_ru;
   case SpvFPRoundingModeRTN:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTN is only supported in kernels");
      return nir_rounding_mode_rd;
   default:
      vtn_fail("Unsupported rounding mode: %s",
               spirv_fproundingmode_to_string(mode));
      break;
   }
}

struct conversion_opts {
   nir_rounding_mode rounding_mode;
   bool saturate;
};

static void
handle_conversion_opts(struct vtn_builder *b, struct vtn_value *val, int member,
                       const struct vtn_decoration *dec, void *_opts)
{
   struct conversion_opts *opts = _opts;

   switch (dec->decoration) {
   case SpvDecorationFPRoundingMode:
      opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
      break;

   case SpvDecorationSaturatedConversion:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "Saturated conversions are only allowed in kernels");
      opts->saturate = true;
      break;

   default:
      break;
   }
}
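
/* Illustrative note (not part of the original source): a conversion result
 * decorated with both FPRoundingMode RTZ and SaturatedConversion reaches
 * vtn_handle_alu() as opts = { nir_rounding_mode_rtz, true }, which the
 * kernel path there lowers through nir_convert_alu_types().
 */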

static void
handle_no_wrap(struct vtn_builder *b, struct vtn_value *val, int member,
               const struct vtn_decoration *dec, void *_alu)
{
   nir_alu_instr *alu = _alu;
   switch (dec->decoration) {
   case SpvDecorationNoSignedWrap:
      alu->no_signed_wrap = true;
      break;
   case SpvDecorationNoUnsignedWrap:
      alu->no_unsigned_wrap = true;
      break;
   default:
      /* Do nothing. */
      break;
   }
}

void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;

   vtn_handle_no_contraction(b, dest_val);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++)
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      vtn_push_ssa_value(b, w[2],
         vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]));
      b->nb.exact = b->exact;
      return;
   }

   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
   case SpvOpAny:
      dest->def = nir_bany(&b->nb, src[0]);
      break;

   case SpvOpAll:
      dest->def = nir_ball(&b->nb, src[0]);
      break;

   case SpvOpOuterProduct: {
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         dest->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      dest->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

   case SpvOpIAddCarry:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
      break;
   }

   case SpvOpSMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
      break;
   }

   case SpvOpFwidth:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;
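
   /* Illustrative note (not part of the original source): the three cases
    * above implement the GLSL definition fwidth(p) = abs(dFdx(p)) +
    * abs(dFdy(p)), differing only in which derivative precision they
    * request.
    */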

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      dest->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

   case SpvOpIsNan: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_fneu(&b->nb, src[0], src[0]);
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpOrdered: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
                                   nir_feq(&b->nb, src[1], src[1]));
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpUnordered: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
                                  nir_fneu(&b->nb, src[1], src[1]));
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpIsInf: {
      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
      dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
      break;
   }
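
   /* Illustrative note (not part of the original source): SpvOpIsNan above
    * relies on IEEE semantics where only NaN satisfies x != x, and
    * SpvOpOrdered / SpvOpUnordered are built from the same self-comparison.
    * b->nb.exact is forced on so the optimizer cannot fold x != x to false.
    */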

   case SpvOpFUnordEqual:
   case SpvOpFUnordNotEqual:
   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      bool unused_exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  &unused_exact,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      dest->def =
         nir_ior(&b->nb,
                 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                 nir_ior(&b->nb,
                         nir_fneu(&b->nb, src[0], src[0]),
                         nir_fneu(&b->nb, src[1], src[1])));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpLessOrGreater:
   case SpvOpFOrdNotEqual: {
      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
       * from the ALU will probably already be false if the operands are not
       * ordered so we don’t need to handle it specially.
       */
      bool swap;
      bool exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
                                                  src_bit_size, dst_bit_size);

      assert(!swap);
      assert(exact);

      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      dest->def =
         nir_iand(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_iand(&b->nb,
                           nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1])));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert:
   case SpvOpSatConvertSToU:
   case SpvOpSatConvertUToS: {
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;

      struct conversion_opts opts = {
         .rounding_mode = nir_rounding_mode_undef,
         .saturate = false,
      };
      vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);

      if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
         opts.saturate = true;

      if (b->shader->info.stage == MESA_SHADER_KERNEL) {
         if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
            nir_op op = nir_type_conversion_op(src_type, dst_type,
                                               nir_rounding_mode_undef);
            dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
         } else {
            dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
                                              src_type, dst_type,
                                              opts.rounding_mode, opts.saturate);
         }
      } else {
         vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
                     dst_type != nir_type_float16,
                     "Rounding modes are only allowed on conversions to "
                     "16-bit float types");
         nir_op op = nir_type_conversion_op(src_type, dst_type,
                                            opts.rounding_mode);
         dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
      }
      break;
   }

   case SpvOpBitFieldInsert:
   case SpvOpBitFieldSExtract:
   case SpvOpBitFieldUExtract:
   case SpvOpShiftLeftLogical:
   case SpvOpShiftRightArithmetic:
   case SpvOpShiftRightLogical: {
      bool swap;
      bool exact;
      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
                                                  src0_bit_size, dst_bit_size);

      assert(!exact);

      assert(op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
             op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
             op == nir_op_ibitfield_extract);

      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         unsigned src_bit_size =
            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
         if (src_bit_size == 0)
            continue;
         if (src_bit_size != src[i]->bit_size) {
            assert(src_bit_size == 32);
            /* Convert the Shift, Offset and Count operands to 32 bits, which
             * is the bitsize supported by the NIR instructions.  See
             * discussion here:
             *
             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
             */
            src[i] = nir_u2u32(&b->nb, src[i]);
         }
      }
      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   }

   case SpvOpSignBitSet:
      dest->def = nir_i2b(&b->nb,
         nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
      break;

   case SpvOpUCountTrailingZerosINTEL:
      dest->def = nir_umin(&b->nb,
                           nir_find_lsb(&b->nb, src[0]),
                           nir_imm_int(&b->nb, 32u));
      break;

   case SpvOpBitCount: {
      /* bit_count always returns int32, but the SPIR-V opcode just says the
       * return value needs to be big enough to store the number of bits.
       */
      dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]),
                          glsl_get_bit_size(dest_type));
      break;
   }

   default: {
      bool swap;
      bool exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  &exact,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      switch (op) {
      case nir_op_ishl:
      case nir_op_ishr:
      case nir_op_ushr:
         if (src[1]->bit_size != 32)
            src[1] = nir_u2u32(&b->nb, src[1]);
         break;
      default:
         break;
      }

      const bool save_exact = b->nb.exact;

      if (exact)
         b->nb.exact = true;

      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);

      b->nb.exact = save_exact;
      break;
   } /* default */
   }
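
   /* Illustrative note (not part of the original source): the switch below
    * tags integer results whose SPIR-V instruction carried NoSignedWrap /
    * NoUnsignedWrap; e.g. an OpIAdd decorated NoSignedWrap becomes an iadd
    * with no_signed_wrap set, so later passes may assume it cannot overflow.
    */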

   switch (opcode) {
   case SpvOpIAdd:
   case SpvOpIMul:
   case SpvOpISub:
   case SpvOpShiftLeftLogical:
   case SpvOpSNegate: {
      nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
      vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
      break;
   }
   default:
      /* Do nothing. */
      break;
   }

   vtn_push_ssa_value(b, w[2], dest);

   b->nb.exact = b->exact;
}

void
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
{
   vtn_assert(count == 4);
   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
    *
    *    "If Result Type has the same number of components as Operand, they
    *    must also have the same component width, and results are computed per
    *    component.
    *
    *    If Result Type has a different number of components than Operand, the
    *    total number of bits in Result Type must equal the total number of
    *    bits in Operand.  Let L be the type, either Result Type or Operand’s
    *    type, that has the larger number of components.  Let S be the other
    *    type, with the smaller number of components.  The number of components
    *    in L must be an integer multiple of the number of components in S.
    *    The first component (that is, the only or lowest-numbered component)
    *    of S maps to the first components of L, and so on, up to the last
    *    component of S mapping to the last components of L.  Within this
    *    mapping, any single component of S (mapping to multiple components of
    *    L) maps its lower-ordered bits to the lower-numbered components of L."
    */

   struct vtn_type *type = vtn_get_type(b, w[1]);
   struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]);

   vtn_fail_if(src->num_components * src->bit_size !=
               glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
               "Source and destination of OpBitcast must have the same "
               "total number of bits");
   nir_ssa_def *val =
      nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
   vtn_push_nir_ssa(b, w[2], val);
}
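
/* Illustrative note (not part of the original source): per the rule quoted in
 * vtn_handle_bitcast(), an OpBitcast between a uvec2 of 32-bit components and
 * a single 64-bit integer is valid (2 * 32 == 1 * 64 bits), and component 0
 * of the uvec2 corresponds to the low-order 32 bits of the 64-bit value.
 */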