Path: blob/21.2-virgl/src/compiler/nir/nir_conversion_builder.h
4545 views
/*1* Copyright © 2020 Collabora Ltd.2*3* Permission is hereby granted, free of charge, to any person obtaining a4* copy of this software and associated documentation files (the "Software"),5* to deal in the Software without restriction, including without limitation6* the rights to use, copy, modify, merge, publish, distribute, sublicense,7* and/or sell copies of the Software, and to permit persons to whom the8* Software is furnished to do so, subject to the following conditions:9*10* The above copyright notice and this permission notice (including the next11* paragraph) shall be included in all copies or substantial portions of the12* Software.13*14* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR15* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,16* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL17* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER18* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING19* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS20* IN THE SOFTWARE.21*/2223#ifndef NIR_CONVERSION_BUILDER_H24#define NIR_CONVERSION_BUILDER_H2526#include "util/u_math.h"27#include "nir_builder.h"28#include "nir_builtin_builder.h"2930#ifdef __cplusplus31extern "C" {32#endif3334static inline nir_ssa_def *35nir_round_float_to_int(nir_builder *b, nir_ssa_def *src,36nir_rounding_mode round)37{38switch (round) {39case nir_rounding_mode_ru:40return nir_fceil(b, src);4142case nir_rounding_mode_rd:43return nir_ffloor(b, src);4445case nir_rounding_mode_rtne:46return nir_fround_even(b, src);4748case nir_rounding_mode_undef:49case nir_rounding_mode_rtz:50break;51}52unreachable("unexpected rounding mode");53}5455static inline nir_ssa_def *56nir_round_float_to_float(nir_builder *b, nir_ssa_def *src,57unsigned dest_bit_size,58nir_rounding_mode round)59{60unsigned src_bit_size = src->bit_size;61if (dest_bit_size > src_bit_size)62return src; /* No rounding is needed for an up-convert */6364nir_op low_conv = nir_type_conversion_op(nir_type_float | src_bit_size,65nir_type_float | dest_bit_size,66nir_rounding_mode_undef);67nir_op high_conv = nir_type_conversion_op(nir_type_float | dest_bit_size,68nir_type_float | src_bit_size,69nir_rounding_mode_undef);7071switch (round) {72case nir_rounding_mode_ru: {73/* If lower-precision conversion results in a lower value, push it74* up one ULP. */75nir_ssa_def *lower_prec =76nir_build_alu(b, low_conv, src, NULL, NULL, NULL);77nir_ssa_def *roundtrip =78nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);79nir_ssa_def *cmp = nir_flt(b, roundtrip, src);80nir_ssa_def *inf = nir_imm_floatN_t(b, INFINITY, dest_bit_size);81return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, inf), lower_prec);82}83case nir_rounding_mode_rd: {84/* If lower-precision conversion results in a higher value, push it85* down one ULP. */86nir_ssa_def *lower_prec =87nir_build_alu(b, low_conv, src, NULL, NULL, NULL);88nir_ssa_def *roundtrip =89nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);90nir_ssa_def *cmp = nir_flt(b, src, roundtrip);91nir_ssa_def *neg_inf = nir_imm_floatN_t(b, -INFINITY, dest_bit_size);92return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, neg_inf), lower_prec);93}94case nir_rounding_mode_rtz:95return nir_bcsel(b, nir_flt(b, src, nir_imm_zero(b, 1, src->bit_size)),96nir_round_float_to_float(b, src, dest_bit_size,97nir_rounding_mode_ru),98nir_round_float_to_float(b, src, dest_bit_size,99nir_rounding_mode_rd));100case nir_rounding_mode_rtne:101case nir_rounding_mode_undef:102break;103}104unreachable("unexpected rounding mode");105}106107static inline nir_ssa_def *108nir_round_int_to_float(nir_builder *b, nir_ssa_def *src,109nir_alu_type src_type,110unsigned dest_bit_size,111nir_rounding_mode round)112{113/* We only care whether or not its signed */114src_type = nir_alu_type_get_base_type(src_type);115116unsigned mantissa_bits;117switch (dest_bit_size) {118case 16:119mantissa_bits = 10;120break;121case 32:122mantissa_bits = 23;123break;124case 64:125mantissa_bits = 52;126break;127default: unreachable("Unsupported bit size");128}129130if (src->bit_size < mantissa_bits)131return src;132133if (src_type == nir_type_int) {134nir_ssa_def *sign =135nir_i2b1(b, nir_ishr(b, src, nir_imm_int(b, src->bit_size - 1)));136nir_ssa_def *abs = nir_iabs(b, src);137nir_ssa_def *positive_rounded =138nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, round);139nir_ssa_def *max_positive =140nir_imm_intN_t(b, (1ull << (src->bit_size - 1)) - 1, src->bit_size);141switch (round) {142case nir_rounding_mode_rtz:143return nir_bcsel(b, sign, nir_ineg(b, positive_rounded),144positive_rounded);145break;146case nir_rounding_mode_ru:147return nir_bcsel(b, sign,148nir_ineg(b, nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_rd)),149nir_umin(b, positive_rounded, max_positive));150break;151case nir_rounding_mode_rd:152return nir_bcsel(b, sign,153nir_ineg(b,154nir_umin(b, max_positive,155nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_ru))),156positive_rounded);157case nir_rounding_mode_rtne:158case nir_rounding_mode_undef:159break;160}161unreachable("unexpected rounding mode");162} else {163nir_ssa_def *mantissa_bit_size = nir_imm_int(b, mantissa_bits);164nir_ssa_def *msb = nir_imax(b, nir_ufind_msb(b, src), mantissa_bit_size);165nir_ssa_def *bits_to_lose = nir_isub(b, msb, mantissa_bit_size);166nir_ssa_def *one = nir_imm_intN_t(b, 1, src->bit_size);167nir_ssa_def *adjust = nir_ishl(b, one, bits_to_lose);168nir_ssa_def *mask = nir_inot(b, nir_isub(b, adjust, one));169nir_ssa_def *truncated = nir_iand(b, src, mask);170switch (round) {171case nir_rounding_mode_rtz:172case nir_rounding_mode_rd:173return truncated;174break;175case nir_rounding_mode_ru:176return nir_bcsel(b, nir_ieq(b, src, truncated),177src, nir_uadd_sat(b, truncated, adjust));178case nir_rounding_mode_rtne:179case nir_rounding_mode_undef:180break;181}182unreachable("unexpected rounding mode");183}184}185186/** Returns true if the representable range of a contains the representable187* range of b.188*/189static inline bool190nir_alu_type_range_contains_type_range(nir_alu_type a, nir_alu_type b)191{192/* Split types from bit sizes */193nir_alu_type a_base_type = nir_alu_type_get_base_type(a);194nir_alu_type b_base_type = nir_alu_type_get_base_type(b);195unsigned a_bit_size = nir_alu_type_get_type_size(a);196unsigned b_bit_size = nir_alu_type_get_type_size(b);197198/* This requires sized types */199assert(a_bit_size > 0 && b_bit_size > 0);200201if (a_base_type == b_base_type && a_bit_size >= b_bit_size)202return true;203204if (a_base_type == nir_type_int && b_base_type == nir_type_uint &&205a_bit_size > b_bit_size)206return true;207208/* 16-bit floats fit in 32-bit integers */209if (a_base_type == nir_type_int && a_bit_size >= 32 &&210b == nir_type_float16)211return true;212213/* All signed or unsigned ints can fit in float or above. A uint8 can fit214* in a float16.215*/216if (a_base_type == nir_type_float && b_base_type != nir_type_float &&217(a_bit_size >= 32 || b_bit_size == 8))218return true;219220return false;221}222223/**224* Retrieves limits used for clamping a value of the src type into225* the widest representable range of the dst type via cmp + bcsel226*/227static inline void228nir_get_clamp_limits(nir_builder *b,229nir_alu_type src_type,230nir_alu_type dest_type,231nir_ssa_def **low, nir_ssa_def **high)232{233/* Split types from bit sizes */234nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);235nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);236unsigned src_bit_size = nir_alu_type_get_type_size(src_type);237unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);238assert(dest_bit_size != 0 && src_bit_size != 0);239240*low = NULL;241*high = NULL;242243/* limits of the destination type, expressed in the source type */244switch (dest_base_type) {245case nir_type_int: {246int64_t ilow, ihigh;247if (dest_bit_size == 64) {248ilow = INT64_MIN;249ihigh = INT64_MAX;250} else {251ilow = -(1ll << (dest_bit_size - 1));252ihigh = (1ll << (dest_bit_size - 1)) - 1;253}254255if (src_base_type == nir_type_int) {256*low = nir_imm_intN_t(b, ilow, src_bit_size);257*high = nir_imm_intN_t(b, ihigh, src_bit_size);258} else if (src_base_type == nir_type_uint) {259assert(src_bit_size >= dest_bit_size);260*high = nir_imm_intN_t(b, ihigh, src_bit_size);261} else {262*low = nir_imm_floatN_t(b, ilow, src_bit_size);263*high = nir_imm_floatN_t(b, ihigh, src_bit_size);264}265break;266}267case nir_type_uint: {268uint64_t uhigh = dest_bit_size == 64 ?269~0ull : (1ull << dest_bit_size) - 1;270if (src_base_type != nir_type_float) {271*low = nir_imm_intN_t(b, 0, src_bit_size);272if (src_base_type == nir_type_uint || src_bit_size > dest_bit_size)273*high = nir_imm_intN_t(b, uhigh, src_bit_size);274} else {275*low = nir_imm_floatN_t(b, 0.0f, src_bit_size);276*high = nir_imm_floatN_t(b, uhigh, src_bit_size);277}278break;279}280case nir_type_float: {281double flow, fhigh;282switch (dest_bit_size) {283case 16:284flow = -65504.0f;285fhigh = 65504.0f;286break;287case 32:288flow = -FLT_MAX;289fhigh = FLT_MAX;290break;291case 64:292flow = -DBL_MAX;293fhigh = DBL_MAX;294break;295default:296unreachable("Unhandled bit size");297}298299switch (src_base_type) {300case nir_type_int: {301int64_t src_ilow, src_ihigh;302if (src_bit_size == 64) {303src_ilow = INT64_MIN;304src_ihigh = INT64_MAX;305} else {306src_ilow = -(1ll << (src_bit_size - 1));307src_ihigh = (1ll << (src_bit_size - 1)) - 1;308}309if (src_ilow < flow)310*low = nir_imm_intN_t(b, flow, src_bit_size);311if (src_ihigh > fhigh)312*high = nir_imm_intN_t(b, fhigh, src_bit_size);313break;314}315case nir_type_uint: {316uint64_t src_uhigh = src_bit_size == 64 ?317~0ull : (1ull << src_bit_size) - 1;318if (src_uhigh > fhigh)319*high = nir_imm_intN_t(b, fhigh, src_bit_size);320break;321}322case nir_type_float:323*low = nir_imm_floatN_t(b, flow, src_bit_size);324*high = nir_imm_floatN_t(b, fhigh, src_bit_size);325break;326default:327unreachable("Clamping from unknown type");328}329break;330}331default:332unreachable("clamping to unknown type");333break;334}335}336337/**338* Clamp the value into the widest representatble range of the339* destination type with cmp + bcsel.340*341* val/val_type: The variables used for bcsel342* src/src_type: The variables used for comparison343* dest_type: The type which determines the range used for comparison344*/345static inline nir_ssa_def *346nir_clamp_to_type_range(nir_builder *b,347nir_ssa_def *val, nir_alu_type val_type,348nir_ssa_def *src, nir_alu_type src_type,349nir_alu_type dest_type)350{351assert(nir_alu_type_get_type_size(src_type) == 0 ||352nir_alu_type_get_type_size(src_type) == src->bit_size);353src_type |= src->bit_size;354if (nir_alu_type_range_contains_type_range(dest_type, src_type))355return val;356357/* limits of the destination type, expressed in the source type */358nir_ssa_def *low = NULL, *high = NULL;359nir_get_clamp_limits(b, src_type, dest_type, &low, &high);360361nir_ssa_def *low_cond = NULL, *high_cond = NULL;362switch (nir_alu_type_get_base_type(src_type)) {363case nir_type_int:364low_cond = low ? nir_ilt(b, src, low) : NULL;365high_cond = high ? nir_ilt(b, high, src) : NULL;366break;367case nir_type_uint:368low_cond = low ? nir_ult(b, src, low) : NULL;369high_cond = high ? nir_ult(b, high, src) : NULL;370break;371case nir_type_float:372low_cond = low ? nir_fge(b, low, src) : NULL;373high_cond = high ? nir_fge(b, src, high) : NULL;374break;375default:376unreachable("clamping from unknown type");377}378379nir_ssa_def *val_low = low, *val_high = high;380if (val_type != src_type) {381nir_get_clamp_limits(b, val_type, dest_type, &val_low, &val_high);382}383384nir_ssa_def *res = val;385if (low_cond && val_low)386res = nir_bcsel(b, low_cond, val_low, res);387if (high_cond && val_high)388res = nir_bcsel(b, high_cond, val_high, res);389390return res;391}392393static inline nir_rounding_mode394nir_simplify_conversion_rounding(nir_alu_type src_type,395nir_alu_type dest_type,396nir_rounding_mode rounding)397{398nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);399nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);400unsigned src_bit_size = nir_alu_type_get_type_size(src_type);401unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);402assert(src_bit_size > 0 && dest_bit_size > 0);403404if (rounding == nir_rounding_mode_undef)405return rounding;406407/* Pure integer conversion doesn't have any rounding */408if (src_base_type != nir_type_float &&409dest_base_type != nir_type_float)410return nir_rounding_mode_undef;411412/* Float down-casts don't round */413if (src_base_type == nir_type_float &&414dest_base_type == nir_type_float &&415dest_bit_size >= src_bit_size)416return nir_rounding_mode_undef;417418/* Regular float to int conversions are RTZ */419if (src_base_type == nir_type_float &&420dest_base_type != nir_type_float &&421rounding == nir_rounding_mode_rtz)422return nir_rounding_mode_undef;423424/* The CL spec requires regular conversions to float to be RTNE */425if (dest_base_type == nir_type_float &&426rounding == nir_rounding_mode_rtne)427return nir_rounding_mode_undef;428429/* Couldn't simplify */430return rounding;431}432433static inline nir_ssa_def *434nir_convert_with_rounding(nir_builder *b,435nir_ssa_def *src, nir_alu_type src_type,436nir_alu_type dest_type,437nir_rounding_mode round,438bool clamp)439{440/* Some stuff wants sized types */441assert(nir_alu_type_get_type_size(src_type) == 0 ||442nir_alu_type_get_type_size(src_type) == src->bit_size);443src_type |= src->bit_size;444445/* Split types from bit sizes */446nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);447nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);448unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);449450/* Try to simplify the conversion if we can */451clamp = clamp &&452!nir_alu_type_range_contains_type_range(dest_type, src_type);453round = nir_simplify_conversion_rounding(src_type, dest_type, round);454455/* For float -> int/uint conversions, we might not be able to represent456* the destination range in the source float accurately. For these cases,457* do the comparison in float range, but the bcsel in the destination range.458*/459bool clamp_after_conversion = clamp &&460src_base_type == nir_type_float &&461dest_base_type != nir_type_float;462463/*464* If we don't care about rounding and clamping, we can just use NIR's465* built-in ops. There is also a special case for SPIR-V in shaders, where466* f32/f64 -> f16 conversions can have one of two rounding modes applied,467* which NIR has built-in opcodes for.468*469* For the rest, we have our own implementation of rounding and clamping.470*/471bool trivial_convert;472if (!clamp && round == nir_rounding_mode_undef) {473trivial_convert = true;474} else if (!clamp && src_type == nir_type_float32 &&475dest_type == nir_type_float16 &&476(round == nir_rounding_mode_rtne ||477round == nir_rounding_mode_rtz)) {478trivial_convert = true;479} else {480trivial_convert = false;481}482if (trivial_convert) {483nir_op op = nir_type_conversion_op(src_type, dest_type, round);484return nir_build_alu(b, op, src, NULL, NULL, NULL);485}486487nir_ssa_def *dest = src;488489/* clamp the result into range */490if (clamp && !clamp_after_conversion)491dest = nir_clamp_to_type_range(b, src, src_type, src, src_type, dest_type);492493/* round with selected rounding mode */494if (!trivial_convert && round != nir_rounding_mode_undef) {495if (src_base_type == nir_type_float) {496if (dest_base_type == nir_type_float) {497dest = nir_round_float_to_float(b, dest, dest_bit_size, round);498} else {499dest = nir_round_float_to_int(b, dest, round);500}501} else {502dest = nir_round_int_to_float(b, dest, src_type, dest_bit_size, round);503}504505round = nir_rounding_mode_undef;506}507508/* now we can convert the value */509nir_op op = nir_type_conversion_op(src_type, dest_type, round);510dest = nir_build_alu(b, op, dest, NULL, NULL, NULL);511512if (clamp_after_conversion)513dest = nir_clamp_to_type_range(b, dest, dest_type, src, src_type, dest_type);514515return dest;516}517518#ifdef __cplusplus519}520#endif521522#endif /* NIR_CONVERSION_BUILDER_H */523524525