Path: blob/21.2-virgl/src/gallium/drivers/swr/rasterizer/jitter/builder_misc.cpp
4574 views
/****************************************************************************1* Copyright (C) 2014-2015 Intel Corporation. All Rights Reserved.2*3* Permission is hereby granted, free of charge, to any person obtaining a4* copy of this software and associated documentation files (the "Software"),5* to deal in the Software without restriction, including without limitation6* the rights to use, copy, modify, merge, publish, distribute, sublicense,7* and/or sell copies of the Software, and to permit persons to whom the8* Software is furnished to do so, subject to the following conditions:9*10* The above copyright notice and this permission notice (including the next11* paragraph) shall be included in all copies or substantial portions of the12* Software.13*14* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR15* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,16* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL17* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER18* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING19* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS20* IN THE SOFTWARE.21*22* @file builder_misc.cpp23*24* @brief Implementation for miscellaneous builder functions25*26* Notes:27*28******************************************************************************/29#include "jit_pch.hpp"30#include "builder.h"31#include "common/rdtsc_buckets.h"3233#include <cstdarg>3435extern "C" void CallPrint(const char* fmt, ...);3637namespace SwrJit38{39//////////////////////////////////////////////////////////////////////////40/// @brief Convert an IEEE 754 32-bit single precision float to an41/// 16 bit float with 5 exponent bits and a variable42/// number of mantissa bits.43/// @param val - 32-bit float44/// @todo Maybe move this outside of this file into a header?45static uint16_t ConvertFloat32ToFloat16(float val)46{47uint32_t sign, exp, mant;48uint32_t roundBits;4950// Extract the sign, exponent, and mantissa51uint32_t uf = *(uint32_t*)&val;52sign = (uf & 0x80000000) >> 31;53exp = (uf & 0x7F800000) >> 23;54mant = uf & 0x007FFFFF;5556// Check for out of range57if (std::isnan(val))58{59exp = 0x1F;60mant = 0x200;61sign = 1; // set the sign bit for NANs62}63else if (std::isinf(val))64{65exp = 0x1f;66mant = 0x0;67}68else if (exp > (0x70 + 0x1E)) // Too big to represent -> max representable value69{70exp = 0x1E;71mant = 0x3FF;72}73else if ((exp <= 0x70) && (exp >= 0x66)) // It's a denorm74{75mant |= 0x00800000;76for (; exp <= 0x70; mant >>= 1, exp++)77;78exp = 0;79mant = mant >> 13;80}81else if (exp < 0x66) // Too small to represent -> Zero82{83exp = 0;84mant = 0;85}86else87{88// Saves bits that will be shifted off for rounding89roundBits = mant & 0x1FFFu;90// convert exponent and mantissa to 16 bit format91exp = exp - 0x70;92mant = mant >> 13;9394// Essentially RTZ, but round up if off by only 1 lsb95if (roundBits == 0x1FFFu)96{97mant++;98// check for overflow99if ((mant & 0xC00u) != 0)100exp++;101// make sure only the needed bits are used102mant &= 0x3FF;103}104}105106uint32_t tmpVal = (sign << 15) | (exp << 10) | mant;107return (uint16_t)tmpVal;108}109110Constant* Builder::C(bool i) { return ConstantInt::get(IRB()->getInt1Ty(), (i ? 1 : 0)); }111112Constant* Builder::C(char i) { return ConstantInt::get(IRB()->getInt8Ty(), i); }113114Constant* Builder::C(uint8_t i) { return ConstantInt::get(IRB()->getInt8Ty(), i); }115116Constant* Builder::C(int i) { return ConstantInt::get(IRB()->getInt32Ty(), i); }117118Constant* Builder::C(int64_t i) { return ConstantInt::get(IRB()->getInt64Ty(), i); }119120Constant* Builder::C(uint16_t i) { return ConstantInt::get(mInt16Ty, i); }121122Constant* Builder::C(uint32_t i) { return ConstantInt::get(IRB()->getInt32Ty(), i); }123124Constant* Builder::C(uint64_t i) { return ConstantInt::get(IRB()->getInt64Ty(), i); }125126Constant* Builder::C(float i) { return ConstantFP::get(IRB()->getFloatTy(), i); }127128Constant* Builder::PRED(bool pred)129{130return ConstantInt::get(IRB()->getInt1Ty(), (pred ? 1 : 0));131}132133Value* Builder::VIMMED1(uint64_t i)134{135#if LLVM_VERSION_MAJOR <= 10136return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));137#elif LLVM_VERSION_MAJOR == 11138return ConstantVector::getSplat(ElementCount(mVWidth, false), cast<ConstantInt>(C(i)));139#else140return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantInt>(C(i)));141#endif142}143144Value* Builder::VIMMED1_16(uint64_t i)145{146#if LLVM_VERSION_MAJOR <= 10147return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));148#elif LLVM_VERSION_MAJOR == 11149return ConstantVector::getSplat(ElementCount(mVWidth16, false), cast<ConstantInt>(C(i)));150#else151return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantInt>(C(i)));152#endif153}154155Value* Builder::VIMMED1(int i)156{157#if LLVM_VERSION_MAJOR <= 10158return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));159#elif LLVM_VERSION_MAJOR == 11160return ConstantVector::getSplat(ElementCount(mVWidth, false), cast<ConstantInt>(C(i)));161#else162return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantInt>(C(i)));163#endif164}165166Value* Builder::VIMMED1_16(int i)167{168#if LLVM_VERSION_MAJOR <= 10169return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));170#elif LLVM_VERSION_MAJOR == 11171return ConstantVector::getSplat(ElementCount(mVWidth16, false), cast<ConstantInt>(C(i)));172#else173return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantInt>(C(i)));174#endif175}176177Value* Builder::VIMMED1(uint32_t i)178{179#if LLVM_VERSION_MAJOR <= 10180return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));181#elif LLVM_VERSION_MAJOR == 11182return ConstantVector::getSplat(ElementCount(mVWidth, false), cast<ConstantInt>(C(i)));183#else184return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantInt>(C(i)));185#endif186}187188Value* Builder::VIMMED1_16(uint32_t i)189{190#if LLVM_VERSION_MAJOR <= 10191return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));192#elif LLVM_VERSION_MAJOR == 11193return ConstantVector::getSplat(ElementCount(mVWidth16, false), cast<ConstantInt>(C(i)));194#else195return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantInt>(C(i)));196#endif197}198199Value* Builder::VIMMED1(float i)200{201#if LLVM_VERSION_MAJOR <= 10202return ConstantVector::getSplat(mVWidth, cast<ConstantFP>(C(i)));203#elif LLVM_VERSION_MAJOR == 11204return ConstantVector::getSplat(ElementCount(mVWidth, false), cast<ConstantFP>(C(i)));205#else206return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantFP>(C(i)));207#endif208}209210Value* Builder::VIMMED1_16(float i)211{212#if LLVM_VERSION_MAJOR <= 10213return ConstantVector::getSplat(mVWidth16, cast<ConstantFP>(C(i)));214#elif LLVM_VERSION_MAJOR == 11215return ConstantVector::getSplat(ElementCount(mVWidth16, false), cast<ConstantFP>(C(i)));216#else217return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantFP>(C(i)));218#endif219}220221Value* Builder::VIMMED1(bool i)222{223#if LLVM_VERSION_MAJOR <= 10224return ConstantVector::getSplat(mVWidth, cast<ConstantInt>(C(i)));225#elif LLVM_VERSION_MAJOR == 11226return ConstantVector::getSplat(ElementCount(mVWidth, false), cast<ConstantInt>(C(i)));227#else228return ConstantVector::getSplat(ElementCount::get(mVWidth, false), cast<ConstantInt>(C(i)));229#endif230}231232Value* Builder::VIMMED1_16(bool i)233{234#if LLVM_VERSION_MAJOR <= 10235return ConstantVector::getSplat(mVWidth16, cast<ConstantInt>(C(i)));236#elif LLVM_VERSION_MAJOR == 11237return ConstantVector::getSplat(ElementCount(mVWidth16, false), cast<ConstantInt>(C(i)));238#else239return ConstantVector::getSplat(ElementCount::get(mVWidth16, false), cast<ConstantInt>(C(i)));240#endif241}242243Value* Builder::VUNDEF_IPTR() { return UndefValue::get(getVectorType(mInt32PtrTy, mVWidth)); }244245Value* Builder::VUNDEF(Type* t) { return UndefValue::get(getVectorType(t, mVWidth)); }246247Value* Builder::VUNDEF_I() { return UndefValue::get(getVectorType(mInt32Ty, mVWidth)); }248249Value* Builder::VUNDEF_I_16() { return UndefValue::get(getVectorType(mInt32Ty, mVWidth16)); }250251Value* Builder::VUNDEF_F() { return UndefValue::get(getVectorType(mFP32Ty, mVWidth)); }252253Value* Builder::VUNDEF_F_16() { return UndefValue::get(getVectorType(mFP32Ty, mVWidth16)); }254255Value* Builder::VUNDEF(Type* ty, uint32_t size)256{257return UndefValue::get(getVectorType(ty, size));258}259260Value* Builder::VBROADCAST(Value* src, const llvm::Twine& name)261{262// check if src is already a vector263if (src->getType()->isVectorTy())264{265return src;266}267268return VECTOR_SPLAT(mVWidth, src, name);269}270271Value* Builder::VBROADCAST_16(Value* src)272{273// check if src is already a vector274if (src->getType()->isVectorTy())275{276return src;277}278279return VECTOR_SPLAT(mVWidth16, src);280}281282uint32_t Builder::IMMED(Value* v)283{284SWR_ASSERT(isa<ConstantInt>(v));285ConstantInt* pValConst = cast<ConstantInt>(v);286return pValConst->getZExtValue();287}288289int32_t Builder::S_IMMED(Value* v)290{291SWR_ASSERT(isa<ConstantInt>(v));292ConstantInt* pValConst = cast<ConstantInt>(v);293return pValConst->getSExtValue();294}295296CallInst* Builder::CALL(Value* Callee,297const std::initializer_list<Value*>& argsList,298const llvm::Twine& name)299{300std::vector<Value*> args;301for (auto arg : argsList)302args.push_back(arg);303#if LLVM_VERSION_MAJOR >= 11304// see comment to CALLA(Callee) function in the header305return CALLA(FunctionCallee(cast<Function>(Callee)), args, name);306#else307return CALLA(Callee, args, name);308#endif309}310311CallInst* Builder::CALL(Value* Callee, Value* arg)312{313std::vector<Value*> args;314args.push_back(arg);315#if LLVM_VERSION_MAJOR >= 11316// see comment to CALLA(Callee) function in the header317return CALLA(FunctionCallee(cast<Function>(Callee)), args);318#else319return CALLA(Callee, args);320#endif321}322323CallInst* Builder::CALL2(Value* Callee, Value* arg1, Value* arg2)324{325std::vector<Value*> args;326args.push_back(arg1);327args.push_back(arg2);328#if LLVM_VERSION_MAJOR >= 11329// see comment to CALLA(Callee) function in the header330return CALLA(FunctionCallee(cast<Function>(Callee)), args);331#else332return CALLA(Callee, args);333#endif334}335336CallInst* Builder::CALL3(Value* Callee, Value* arg1, Value* arg2, Value* arg3)337{338std::vector<Value*> args;339args.push_back(arg1);340args.push_back(arg2);341args.push_back(arg3);342#if LLVM_VERSION_MAJOR >= 11343// see comment to CALLA(Callee) function in the header344return CALLA(FunctionCallee(cast<Function>(Callee)), args);345#else346return CALLA(Callee, args);347#endif348}349350Value* Builder::VRCP(Value* va, const llvm::Twine& name)351{352return FDIV(VIMMED1(1.0f), va, name); // 1 / a353}354355Value* Builder::VPLANEPS(Value* vA, Value* vB, Value* vC, Value*& vX, Value*& vY)356{357Value* vOut = FMADDPS(vA, vX, vC);358vOut = FMADDPS(vB, vY, vOut);359return vOut;360}361362//////////////////////////////////////////////////////////////////////////363/// @brief insert a JIT call to CallPrint364/// - outputs formatted string to both stdout and VS output window365/// - DEBUG builds only366/// Usage example:367/// PRINT("index %d = 0x%p\n",{C(lane), pIndex});368/// where C(lane) creates a constant value to print, and pIndex is the Value*369/// result from a GEP, printing out the pointer to memory370/// @param printStr - constant string to print, which includes format specifiers371/// @param printArgs - initializer list of Value*'s to print to std out372CallInst* Builder::PRINT(const std::string& printStr,373const std::initializer_list<Value*>& printArgs)374{375// push the arguments to CallPrint into a vector376std::vector<Value*> printCallArgs;377// save room for the format string. we still need to modify it for vectors378printCallArgs.resize(1);379380// search through the format string for special processing381size_t pos = 0;382std::string tempStr(printStr);383pos = tempStr.find('%', pos);384auto v = printArgs.begin();385386while ((pos != std::string::npos) && (v != printArgs.end()))387{388Value* pArg = *v;389Type* pType = pArg->getType();390391if (pType->isVectorTy())392{393Type* pContainedType = pType->getContainedType(0);394#if LLVM_VERSION_MAJOR >= 12395FixedVectorType* pVectorType = cast<FixedVectorType>(pType);396#elif LLVM_VERSION_MAJOR >= 11397VectorType* pVectorType = cast<VectorType>(pType);398#endif399if (toupper(tempStr[pos + 1]) == 'X')400{401tempStr[pos] = '0';402tempStr[pos + 1] = 'x';403tempStr.insert(pos + 2, "%08X ");404pos += 7;405406printCallArgs.push_back(VEXTRACT(pArg, C(0)));407408std::string vectorFormatStr;409#if LLVM_VERSION_MAJOR >= 11410for (uint32_t i = 1; i < pVectorType->getNumElements(); ++i)411#else412for (uint32_t i = 1; i < pType->getVectorNumElements(); ++i)413#endif414{415vectorFormatStr += "0x%08X ";416printCallArgs.push_back(VEXTRACT(pArg, C(i)));417}418419tempStr.insert(pos, vectorFormatStr);420pos += vectorFormatStr.size();421}422else if ((tempStr[pos + 1] == 'f') && (pContainedType->isFloatTy()))423{424uint32_t i = 0;425#if LLVM_VERSION_MAJOR >= 11426for (; i < pVectorType->getNumElements() - 1; i++)427#else428for (; i < pType->getVectorNumElements() - 1; i++)429#endif430{431tempStr.insert(pos, std::string("%f "));432pos += 3;433printCallArgs.push_back(434FP_EXT(VEXTRACT(pArg, C(i)), Type::getDoubleTy(JM()->mContext)));435}436printCallArgs.push_back(437FP_EXT(VEXTRACT(pArg, C(i)), Type::getDoubleTy(JM()->mContext)));438}439else if ((tempStr[pos + 1] == 'd') && (pContainedType->isIntegerTy()))440{441uint32_t i = 0;442#if LLVM_VERSION_MAJOR >= 11443for (; i < pVectorType->getNumElements() - 1; i++)444#else445for (; i < pType->getVectorNumElements() - 1; i++)446#endif447{448tempStr.insert(pos, std::string("%d "));449pos += 3;450printCallArgs.push_back(451S_EXT(VEXTRACT(pArg, C(i)), Type::getInt32Ty(JM()->mContext)));452}453printCallArgs.push_back(454S_EXT(VEXTRACT(pArg, C(i)), Type::getInt32Ty(JM()->mContext)));455}456else if ((tempStr[pos + 1] == 'u') && (pContainedType->isIntegerTy()))457{458uint32_t i = 0;459#if LLVM_VERSION_MAJOR >= 11460for (; i < pVectorType->getNumElements() - 1; i++)461#else462for (; i < pType->getVectorNumElements() - 1; i++)463#endif464{465tempStr.insert(pos, std::string("%d "));466pos += 3;467printCallArgs.push_back(468Z_EXT(VEXTRACT(pArg, C(i)), Type::getInt32Ty(JM()->mContext)));469}470printCallArgs.push_back(471Z_EXT(VEXTRACT(pArg, C(i)), Type::getInt32Ty(JM()->mContext)));472}473}474else475{476if (toupper(tempStr[pos + 1]) == 'X')477{478tempStr[pos] = '0';479tempStr.insert(pos + 1, "x%08");480printCallArgs.push_back(pArg);481pos += 3;482}483// for %f we need to cast float Values to doubles so that they print out correctly484else if ((tempStr[pos + 1] == 'f') && (pType->isFloatTy()))485{486printCallArgs.push_back(FP_EXT(pArg, Type::getDoubleTy(JM()->mContext)));487pos++;488}489else490{491printCallArgs.push_back(pArg);492}493}494495// advance to the next argument496v++;497pos = tempStr.find('%', ++pos);498}499500// create global variable constant string501Constant* constString = ConstantDataArray::getString(JM()->mContext, tempStr, true);502GlobalVariable* gvPtr = new GlobalVariable(503constString->getType(), true, GlobalValue::InternalLinkage, constString, "printStr");504JM()->mpCurrentModule->getGlobalList().push_back(gvPtr);505506// get a pointer to the first character in the constant string array507std::vector<Constant*> geplist{C(0), C(0)};508Constant* strGEP = ConstantExpr::getGetElementPtr(nullptr, gvPtr, geplist, false);509510// insert the pointer to the format string in the argument vector511printCallArgs[0] = strGEP;512513// get pointer to CallPrint function and insert decl into the module if needed514std::vector<Type*> args;515args.push_back(PointerType::get(mInt8Ty, 0));516FunctionType* callPrintTy = FunctionType::get(Type::getVoidTy(JM()->mContext), args, true);517Function* callPrintFn =518#if LLVM_VERSION_MAJOR >= 9519cast<Function>(JM()->mpCurrentModule->getOrInsertFunction("CallPrint", callPrintTy).getCallee());520#else521cast<Function>(JM()->mpCurrentModule->getOrInsertFunction("CallPrint", callPrintTy));522#endif523524// if we haven't yet added the symbol to the symbol table525if ((sys::DynamicLibrary::SearchForAddressOfSymbol("CallPrint")) == nullptr)526{527sys::DynamicLibrary::AddSymbol("CallPrint", (void*)&CallPrint);528}529530// insert a call to CallPrint531return CALLA(callPrintFn, printCallArgs);532}533534//////////////////////////////////////////////////////////////////////////535/// @brief Wrapper around PRINT with initializer list.536CallInst* Builder::PRINT(const std::string& printStr) { return PRINT(printStr, {}); }537538Value* Builder::EXTRACT_16(Value* x, uint32_t imm)539{540if (imm == 0)541{542return VSHUFFLE(x, UndefValue::get(x->getType()), {0, 1, 2, 3, 4, 5, 6, 7});543}544else545{546return VSHUFFLE(x, UndefValue::get(x->getType()), {8, 9, 10, 11, 12, 13, 14, 15});547}548}549550Value* Builder::JOIN_16(Value* a, Value* b)551{552return VSHUFFLE(a, b, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});553}554555//////////////////////////////////////////////////////////////////////////556/// @brief convert x86 <N x float> mask to llvm <N x i1> mask557Value* Builder::MASK(Value* vmask)558{559Value* src = BITCAST(vmask, mSimdInt32Ty);560return ICMP_SLT(src, VIMMED1(0));561}562563Value* Builder::MASK_16(Value* vmask)564{565Value* src = BITCAST(vmask, mSimd16Int32Ty);566return ICMP_SLT(src, VIMMED1_16(0));567}568569//////////////////////////////////////////////////////////////////////////570/// @brief convert llvm <N x i1> mask to x86 <N x i32> mask571Value* Builder::VMASK(Value* mask) { return S_EXT(mask, mSimdInt32Ty); }572573Value* Builder::VMASK_16(Value* mask) { return S_EXT(mask, mSimd16Int32Ty); }574575/// @brief Convert <Nxi1> llvm mask to integer576Value* Builder::VMOVMSK(Value* mask)577{578#if LLVM_VERSION_MAJOR >= 11579#if LLVM_VERSION_MAJOR >= 12580FixedVectorType* pVectorType = cast<FixedVectorType>(mask->getType());581#else582VectorType* pVectorType = cast<VectorType>(mask->getType());583#endif584SWR_ASSERT(pVectorType->getElementType() == mInt1Ty);585uint32_t numLanes = pVectorType->getNumElements();586#else587SWR_ASSERT(mask->getType()->getVectorElementType() == mInt1Ty);588uint32_t numLanes = mask->getType()->getVectorNumElements();589#endif590Value* i32Result;591if (numLanes == 8)592{593i32Result = BITCAST(mask, mInt8Ty);594}595else if (numLanes == 16)596{597i32Result = BITCAST(mask, mInt16Ty);598}599else600{601SWR_ASSERT("Unsupported vector width");602i32Result = BITCAST(mask, mInt8Ty);603}604return Z_EXT(i32Result, mInt32Ty);605}606607//////////////////////////////////////////////////////////////////////////608/// @brief Generate a VPSHUFB operation in LLVM IR. If not609/// supported on the underlying platform, emulate it610/// @param a - 256bit SIMD(32x8bit) of 8bit integer values611/// @param b - 256bit SIMD(32x8bit) of 8bit integer mask values612/// Byte masks in lower 128 lane of b selects 8 bit values from lower613/// 128bits of a, and vice versa for the upper lanes. If the mask614/// value is negative, '0' is inserted.615Value* Builder::PSHUFB(Value* a, Value* b)616{617Value* res;618// use avx2 pshufb instruction if available619if (JM()->mArch.AVX2())620{621res = VPSHUFB(a, b);622}623else624{625Constant* cB = dyn_cast<Constant>(b);626assert(cB != nullptr);627// number of 8 bit elements in b628#if LLVM_VERSION_MAJOR >= 12629uint32_t numElms = cast<FixedVectorType>(cB->getType())->getNumElements();630#else631uint32_t numElms = cast<VectorType>(cB->getType())->getNumElements();632#endif633// output vector634Value* vShuf = UndefValue::get(getVectorType(mInt8Ty, numElms));635636// insert an 8 bit value from the high and low lanes of a per loop iteration637numElms /= 2;638for (uint32_t i = 0; i < numElms; i++)639{640ConstantInt* cLow128b = cast<ConstantInt>(cB->getAggregateElement(i));641ConstantInt* cHigh128b = cast<ConstantInt>(cB->getAggregateElement(i + numElms));642643// extract values from constant mask644char valLow128bLane = (char)(cLow128b->getSExtValue());645char valHigh128bLane = (char)(cHigh128b->getSExtValue());646647Value* insertValLow128b;648Value* insertValHigh128b;649650// if the mask value is negative, insert a '0' in the respective output position651// otherwise, lookup the value at mask position (bits 3..0 of the respective mask652// byte) in a and insert in output vector653insertValLow128b =654(valLow128bLane < 0) ? C((char)0) : VEXTRACT(a, C((valLow128bLane & 0xF)));655insertValHigh128b = (valHigh128bLane < 0)656? C((char)0)657: VEXTRACT(a, C((valHigh128bLane & 0xF) + numElms));658659vShuf = VINSERT(vShuf, insertValLow128b, i);660vShuf = VINSERT(vShuf, insertValHigh128b, (i + numElms));661}662res = vShuf;663}664return res;665}666667//////////////////////////////////////////////////////////////////////////668/// @brief Generate a VPSHUFB operation (sign extend 8 8bit values to 32669/// bits)in LLVM IR. If not supported on the underlying platform, emulate it670/// @param a - 128bit SIMD lane(16x8bit) of 8bit integer values. Only671/// lower 8 values are used.672Value* Builder::PMOVSXBD(Value* a)673{674// VPMOVSXBD output type675Type* v8x32Ty = getVectorType(mInt32Ty, 8);676// Extract 8 values from 128bit lane and sign extend677return S_EXT(VSHUFFLE(a, a, C<int>({0, 1, 2, 3, 4, 5, 6, 7})), v8x32Ty);678}679680//////////////////////////////////////////////////////////////////////////681/// @brief Generate a VPSHUFB operation (sign extend 8 16bit values to 32682/// bits)in LLVM IR. If not supported on the underlying platform, emulate it683/// @param a - 128bit SIMD lane(8x16bit) of 16bit integer values.684Value* Builder::PMOVSXWD(Value* a)685{686// VPMOVSXWD output type687Type* v8x32Ty = getVectorType(mInt32Ty, 8);688// Extract 8 values from 128bit lane and sign extend689return S_EXT(VSHUFFLE(a, a, C<int>({0, 1, 2, 3, 4, 5, 6, 7})), v8x32Ty);690}691692//////////////////////////////////////////////////////////////////////////693/// @brief Generate a VCVTPH2PS operation (float16->float32 conversion)694/// in LLVM IR. If not supported on the underlying platform, emulate it695/// @param a - 128bit SIMD lane(8x16bit) of float16 in int16 format.696Value* Builder::CVTPH2PS(Value* a, const llvm::Twine& name)697{698// Bitcast Nxint16 to Nxhalf699#if LLVM_VERSION_MAJOR >= 12700uint32_t numElems = cast<FixedVectorType>(a->getType())->getNumElements();701#elif LLVM_VERSION_MAJOR >= 11702uint32_t numElems = cast<VectorType>(a->getType())->getNumElements();703#else704uint32_t numElems = a->getType()->getVectorNumElements();705#endif706Value* input = BITCAST(a, getVectorType(mFP16Ty, numElems));707708return FP_EXT(input, getVectorType(mFP32Ty, numElems), name);709}710711//////////////////////////////////////////////////////////////////////////712/// @brief Generate a VCVTPS2PH operation (float32->float16 conversion)713/// in LLVM IR. If not supported on the underlying platform, emulate it714/// @param a - 128bit SIMD lane(8x16bit) of float16 in int16 format.715Value* Builder::CVTPS2PH(Value* a, Value* rounding)716{717if (JM()->mArch.F16C())718{719return VCVTPS2PH(a, rounding);720}721else722{723// call scalar C function for now724FunctionType* pFuncTy = FunctionType::get(mInt16Ty, mFP32Ty);725Function* pCvtPs2Ph = cast<Function>(726#if LLVM_VERSION_MAJOR >= 9727JM()->mpCurrentModule->getOrInsertFunction("ConvertFloat32ToFloat16", pFuncTy).getCallee());728#else729JM()->mpCurrentModule->getOrInsertFunction("ConvertFloat32ToFloat16", pFuncTy));730#endif731732if (sys::DynamicLibrary::SearchForAddressOfSymbol("ConvertFloat32ToFloat16") == nullptr)733{734sys::DynamicLibrary::AddSymbol("ConvertFloat32ToFloat16",735(void*)&ConvertFloat32ToFloat16);736}737738Value* pResult = UndefValue::get(mSimdInt16Ty);739for (uint32_t i = 0; i < mVWidth; ++i)740{741Value* pSrc = VEXTRACT(a, C(i));742Value* pConv = CALL(pCvtPs2Ph, std::initializer_list<Value*>{pSrc});743pResult = VINSERT(pResult, pConv, C(i));744}745746return pResult;747}748}749750Value* Builder::PMAXSD(Value* a, Value* b)751{752Value* cmp = ICMP_SGT(a, b);753return SELECT(cmp, a, b);754}755756Value* Builder::PMINSD(Value* a, Value* b)757{758Value* cmp = ICMP_SLT(a, b);759return SELECT(cmp, a, b);760}761762Value* Builder::PMAXUD(Value* a, Value* b)763{764Value* cmp = ICMP_UGT(a, b);765return SELECT(cmp, a, b);766}767768Value* Builder::PMINUD(Value* a, Value* b)769{770Value* cmp = ICMP_ULT(a, b);771return SELECT(cmp, a, b);772}773774// Helper function to create alloca in entry block of function775Value* Builder::CreateEntryAlloca(Function* pFunc, Type* pType)776{777auto saveIP = IRB()->saveIP();778IRB()->SetInsertPoint(&pFunc->getEntryBlock(), pFunc->getEntryBlock().begin());779Value* pAlloca = ALLOCA(pType);780if (saveIP.isSet())781IRB()->restoreIP(saveIP);782return pAlloca;783}784785Value* Builder::CreateEntryAlloca(Function* pFunc, Type* pType, Value* pArraySize)786{787auto saveIP = IRB()->saveIP();788IRB()->SetInsertPoint(&pFunc->getEntryBlock(), pFunc->getEntryBlock().begin());789Value* pAlloca = ALLOCA(pType, pArraySize);790if (saveIP.isSet())791IRB()->restoreIP(saveIP);792return pAlloca;793}794795Value* Builder::VABSPS(Value* a)796{797Value* asInt = BITCAST(a, mSimdInt32Ty);798Value* result = BITCAST(AND(asInt, VIMMED1(0x7fffffff)), mSimdFP32Ty);799return result;800}801802Value* Builder::ICLAMP(Value* src, Value* low, Value* high, const llvm::Twine& name)803{804Value* lowCmp = ICMP_SLT(src, low);805Value* ret = SELECT(lowCmp, low, src);806807Value* highCmp = ICMP_SGT(ret, high);808ret = SELECT(highCmp, high, ret, name);809810return ret;811}812813Value* Builder::FCLAMP(Value* src, Value* low, Value* high)814{815Value* lowCmp = FCMP_OLT(src, low);816Value* ret = SELECT(lowCmp, low, src);817818Value* highCmp = FCMP_OGT(ret, high);819ret = SELECT(highCmp, high, ret);820821return ret;822}823824Value* Builder::FCLAMP(Value* src, float low, float high)825{826Value* result = VMAXPS(src, VIMMED1(low));827result = VMINPS(result, VIMMED1(high));828829return result;830}831832Value* Builder::FMADDPS(Value* a, Value* b, Value* c)833{834Value* vOut;835// This maps to LLVM fmuladd intrinsic836vOut = VFMADDPS(a, b, c);837return vOut;838}839840//////////////////////////////////////////////////////////////////////////841/// @brief pop count on vector mask (e.g. <8 x i1>)842Value* Builder::VPOPCNT(Value* a) { return POPCNT(VMOVMSK(a)); }843844//////////////////////////////////////////////////////////////////////////845/// @brief Float / Fixed-point conversions846//////////////////////////////////////////////////////////////////////////847Value* Builder::VCVT_F32_FIXED_SI(Value* vFloat,848uint32_t numIntBits,849uint32_t numFracBits,850const llvm::Twine& name)851{852SWR_ASSERT((numIntBits + numFracBits) <= 32, "Can only handle 32-bit fixed-point values");853Value* fixed = nullptr;854855#if 0 // This doesn't work for negative numbers!!856{857fixed = FP_TO_SI(VROUND(FMUL(vFloat, VIMMED1(float(1 << numFracBits))),858C(_MM_FROUND_TO_NEAREST_INT)),859mSimdInt32Ty);860}861else862#endif863{864// Do round to nearest int on fractional bits first865// Not entirely perfect for negative numbers, but close enough866vFloat = VROUND(FMUL(vFloat, VIMMED1(float(1 << numFracBits))),867C(_MM_FROUND_TO_NEAREST_INT));868vFloat = FMUL(vFloat, VIMMED1(1.0f / float(1 << numFracBits)));869870// TODO: Handle INF, NAN, overflow / underflow, etc.871872Value* vSgn = FCMP_OLT(vFloat, VIMMED1(0.0f));873Value* vFloatInt = BITCAST(vFloat, mSimdInt32Ty);874Value* vFixed = AND(vFloatInt, VIMMED1((1 << 23) - 1));875vFixed = OR(vFixed, VIMMED1(1 << 23));876vFixed = SELECT(vSgn, NEG(vFixed), vFixed);877878Value* vExp = LSHR(SHL(vFloatInt, VIMMED1(1)), VIMMED1(24));879vExp = SUB(vExp, VIMMED1(127));880881Value* vExtraBits = SUB(VIMMED1(23 - numFracBits), vExp);882883fixed = ASHR(vFixed, vExtraBits, name);884}885886return fixed;887}888889Value* Builder::VCVT_FIXED_SI_F32(Value* vFixed,890uint32_t numIntBits,891uint32_t numFracBits,892const llvm::Twine& name)893{894SWR_ASSERT((numIntBits + numFracBits) <= 32, "Can only handle 32-bit fixed-point values");895uint32_t extraBits = 32 - numIntBits - numFracBits;896if (numIntBits && extraBits)897{898// Sign extend899Value* shftAmt = VIMMED1(extraBits);900vFixed = ASHR(SHL(vFixed, shftAmt), shftAmt);901}902903Value* fVal = VIMMED1(0.0f);904Value* fFrac = VIMMED1(0.0f);905if (numIntBits)906{907fVal = SI_TO_FP(ASHR(vFixed, VIMMED1(numFracBits)), mSimdFP32Ty, name);908}909910if (numFracBits)911{912fFrac = UI_TO_FP(AND(vFixed, VIMMED1((1 << numFracBits) - 1)), mSimdFP32Ty);913fFrac = FDIV(fFrac, VIMMED1(float(1 << numFracBits)), name);914}915916return FADD(fVal, fFrac, name);917}918919Value* Builder::VCVT_F32_FIXED_UI(Value* vFloat,920uint32_t numIntBits,921uint32_t numFracBits,922const llvm::Twine& name)923{924SWR_ASSERT((numIntBits + numFracBits) <= 32, "Can only handle 32-bit fixed-point values");925Value* fixed = nullptr;926#if 1 // KNOB_SIM_FAST_MATH? Below works correctly from a precision927// standpoint...928{929fixed = FP_TO_UI(VROUND(FMUL(vFloat, VIMMED1(float(1 << numFracBits))),930C(_MM_FROUND_TO_NEAREST_INT)),931mSimdInt32Ty);932}933#else934{935// Do round to nearest int on fractional bits first936vFloat = VROUND(FMUL(vFloat, VIMMED1(float(1 << numFracBits))),937C(_MM_FROUND_TO_NEAREST_INT));938vFloat = FMUL(vFloat, VIMMED1(1.0f / float(1 << numFracBits)));939940// TODO: Handle INF, NAN, overflow / underflow, etc.941942Value* vSgn = FCMP_OLT(vFloat, VIMMED1(0.0f));943Value* vFloatInt = BITCAST(vFloat, mSimdInt32Ty);944Value* vFixed = AND(vFloatInt, VIMMED1((1 << 23) - 1));945vFixed = OR(vFixed, VIMMED1(1 << 23));946947Value* vExp = LSHR(SHL(vFloatInt, VIMMED1(1)), VIMMED1(24));948vExp = SUB(vExp, VIMMED1(127));949950Value* vExtraBits = SUB(VIMMED1(23 - numFracBits), vExp);951952fixed = LSHR(vFixed, vExtraBits, name);953}954#endif955return fixed;956}957958Value* Builder::VCVT_FIXED_UI_F32(Value* vFixed,959uint32_t numIntBits,960uint32_t numFracBits,961const llvm::Twine& name)962{963SWR_ASSERT((numIntBits + numFracBits) <= 32, "Can only handle 32-bit fixed-point values");964uint32_t extraBits = 32 - numIntBits - numFracBits;965if (numIntBits && extraBits)966{967// Sign extend968Value* shftAmt = VIMMED1(extraBits);969vFixed = ASHR(SHL(vFixed, shftAmt), shftAmt);970}971972Value* fVal = VIMMED1(0.0f);973Value* fFrac = VIMMED1(0.0f);974if (numIntBits)975{976fVal = UI_TO_FP(LSHR(vFixed, VIMMED1(numFracBits)), mSimdFP32Ty, name);977}978979if (numFracBits)980{981fFrac = UI_TO_FP(AND(vFixed, VIMMED1((1 << numFracBits) - 1)), mSimdFP32Ty);982fFrac = FDIV(fFrac, VIMMED1(float(1 << numFracBits)), name);983}984985return FADD(fVal, fFrac, name);986}987988//////////////////////////////////////////////////////////////////////////989/// @brief C functions called by LLVM IR990//////////////////////////////////////////////////////////////////////////991992Value* Builder::VEXTRACTI128(Value* a, Constant* imm8)993{994bool flag = !imm8->isZeroValue();995SmallVector<Constant*, 8> idx;996for (unsigned i = 0; i < mVWidth / 2; i++)997{998idx.push_back(C(flag ? i + mVWidth / 2 : i));999}1000return VSHUFFLE(a, VUNDEF_I(), ConstantVector::get(idx));1001}10021003Value* Builder::VINSERTI128(Value* a, Value* b, Constant* imm8)1004{1005bool flag = !imm8->isZeroValue();1006SmallVector<Constant*, 8> idx;1007for (unsigned i = 0; i < mVWidth; i++)1008{1009idx.push_back(C(i));1010}1011Value* inter = VSHUFFLE(b, VUNDEF_I(), ConstantVector::get(idx));10121013SmallVector<Constant*, 8> idx2;1014for (unsigned i = 0; i < mVWidth / 2; i++)1015{1016idx2.push_back(C(flag ? i : i + mVWidth));1017}1018for (unsigned i = mVWidth / 2; i < mVWidth; i++)1019{1020idx2.push_back(C(flag ? i + mVWidth / 2 : i));1021}1022return VSHUFFLE(a, inter, ConstantVector::get(idx2));1023}10241025// rdtsc buckets macros1026void Builder::RDTSC_START(Value* pBucketMgr, Value* pId)1027{1028// @todo due to an issue with thread local storage propagation in llvm, we can only safely1029// call into buckets framework when single threaded1030if (KNOB_SINGLE_THREADED)1031{1032std::vector<Type*> args{1033PointerType::get(mInt32Ty, 0), // pBucketMgr1034mInt32Ty // id1035};10361037FunctionType* pFuncTy = FunctionType::get(Type::getVoidTy(JM()->mContext), args, false);1038Function* pFunc = cast<Function>(1039#if LLVM_VERSION_MAJOR >= 91040JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StartBucket", pFuncTy).getCallee());1041#else1042JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StartBucket", pFuncTy));1043#endif1044if (sys::DynamicLibrary::SearchForAddressOfSymbol("BucketManager_StartBucket") ==1045nullptr)1046{1047sys::DynamicLibrary::AddSymbol("BucketManager_StartBucket",1048(void*)&BucketManager_StartBucket);1049}10501051CALL(pFunc, {pBucketMgr, pId});1052}1053}10541055void Builder::RDTSC_STOP(Value* pBucketMgr, Value* pId)1056{1057// @todo due to an issue with thread local storage propagation in llvm, we can only safely1058// call into buckets framework when single threaded1059if (KNOB_SINGLE_THREADED)1060{1061std::vector<Type*> args{1062PointerType::get(mInt32Ty, 0), // pBucketMgr1063mInt32Ty // id1064};10651066FunctionType* pFuncTy = FunctionType::get(Type::getVoidTy(JM()->mContext), args, false);1067Function* pFunc = cast<Function>(1068#if LLVM_VERSION_MAJOR >= 91069JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StopBucket", pFuncTy).getCallee());1070#else1071JM()->mpCurrentModule->getOrInsertFunction("BucketManager_StopBucket", pFuncTy));1072#endif1073if (sys::DynamicLibrary::SearchForAddressOfSymbol("BucketManager_StopBucket") ==1074nullptr)1075{1076sys::DynamicLibrary::AddSymbol("BucketManager_StopBucket",1077(void*)&BucketManager_StopBucket);1078}10791080CALL(pFunc, {pBucketMgr, pId});1081}1082}10831084uint32_t Builder::GetTypeSize(Type* pType)1085{1086if (pType->isStructTy())1087{1088uint32_t numElems = pType->getStructNumElements();1089Type* pElemTy = pType->getStructElementType(0);1090return numElems * GetTypeSize(pElemTy);1091}10921093if (pType->isArrayTy())1094{1095uint32_t numElems = pType->getArrayNumElements();1096Type* pElemTy = pType->getArrayElementType();1097return numElems * GetTypeSize(pElemTy);1098}10991100if (pType->isIntegerTy())1101{1102uint32_t bitSize = pType->getIntegerBitWidth();1103return bitSize / 8;1104}11051106if (pType->isFloatTy())1107{1108return 4;1109}11101111if (pType->isHalfTy())1112{1113return 2;1114}11151116if (pType->isDoubleTy())1117{1118return 8;1119}11201121SWR_ASSERT(false, "Unimplemented type.");1122return 0;1123}1124} // namespace SwrJit112511261127