Path: blob/main/contrib/llvm-project/clang/lib/CodeGen/CGHLSLBuiltins.cpp
213766 views
//===------- CGHLSLBuiltins.cpp - Emit LLVM Code for HLSL builtins --------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This contains code to emit HLSL Builtin calls as LLVM code.9//10//===----------------------------------------------------------------------===//1112#include "CGBuiltin.h"13#include "CGHLSLRuntime.h"14#include "CodeGenFunction.h"1516using namespace clang;17using namespace CodeGen;18using namespace llvm;1920static Value *handleAsDoubleBuiltin(CodeGenFunction &CGF, const CallExpr *E) {21assert((E->getArg(0)->getType()->hasUnsignedIntegerRepresentation() &&22E->getArg(1)->getType()->hasUnsignedIntegerRepresentation()) &&23"asdouble operands types mismatch");24Value *OpLowBits = CGF.EmitScalarExpr(E->getArg(0));25Value *OpHighBits = CGF.EmitScalarExpr(E->getArg(1));2627llvm::Type *ResultType = CGF.DoubleTy;28int N = 1;29if (auto *VTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) {30N = VTy->getNumElements();31ResultType = llvm::FixedVectorType::get(CGF.DoubleTy, N);32}3334if (CGF.CGM.getTarget().getTriple().isDXIL())35return CGF.Builder.CreateIntrinsic(36/*ReturnType=*/ResultType, Intrinsic::dx_asdouble,37{OpLowBits, OpHighBits}, nullptr, "hlsl.asdouble");3839if (!E->getArg(0)->getType()->isVectorType()) {40OpLowBits = CGF.Builder.CreateVectorSplat(1, OpLowBits);41OpHighBits = CGF.Builder.CreateVectorSplat(1, OpHighBits);42}4344llvm::SmallVector<int> Mask;45for (int i = 0; i < N; i++) {46Mask.push_back(i);47Mask.push_back(i + N);48}4950Value *BitVec = CGF.Builder.CreateShuffleVector(OpLowBits, OpHighBits, Mask);5152return CGF.Builder.CreateBitCast(BitVec, ResultType);53}5455static Value *handleHlslClip(const CallExpr *E, CodeGenFunction *CGF) {56Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));5758Constant *FZeroConst = ConstantFP::getZero(CGF->FloatTy);59Value *CMP;60Value *LastInstr;6162if (const auto *VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) {63FZeroConst = ConstantVector::getSplat(64ElementCount::getFixed(VecTy->getNumElements()), FZeroConst);65auto *FCompInst = CGF->Builder.CreateFCmpOLT(Op0, FZeroConst);66CMP = CGF->Builder.CreateIntrinsic(67CGF->Builder.getInt1Ty(), CGF->CGM.getHLSLRuntime().getAnyIntrinsic(),68{FCompInst});69} else {70CMP = CGF->Builder.CreateFCmpOLT(Op0, FZeroConst);71}7273if (CGF->CGM.getTarget().getTriple().isDXIL()) {74LastInstr = CGF->Builder.CreateIntrinsic(Intrinsic::dx_discard, {CMP});75} else if (CGF->CGM.getTarget().getTriple().isSPIRV()) {76BasicBlock *LT0 = CGF->createBasicBlock("lt0", CGF->CurFn);77BasicBlock *End = CGF->createBasicBlock("end", CGF->CurFn);7879CGF->Builder.CreateCondBr(CMP, LT0, End);8081CGF->Builder.SetInsertPoint(LT0);8283CGF->Builder.CreateIntrinsic(Intrinsic::spv_discard, {});8485LastInstr = CGF->Builder.CreateBr(End);86CGF->Builder.SetInsertPoint(End);87} else {88llvm_unreachable("Backend Codegen not supported.");89}9091return LastInstr;92}9394static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {95Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));96const auto *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));97const auto *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));9899CallArgList Args;100LValue Op1TmpLValue =101CGF->EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());102LValue Op2TmpLValue =103CGF->EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());104105if (CGF->getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee())106Args.reverseWritebacks();107108Value *LowBits = nullptr;109Value *HighBits = nullptr;110111if (CGF->CGM.getTarget().getTriple().isDXIL()) {112llvm::Type *RetElementTy = CGF->Int32Ty;113if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>())114RetElementTy = llvm::VectorType::get(115CGF->Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));116auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy);117118CallInst *CI = CGF->Builder.CreateIntrinsic(119RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble");120121LowBits = CGF->Builder.CreateExtractValue(CI, 0);122HighBits = CGF->Builder.CreateExtractValue(CI, 1);123} else {124// For Non DXIL targets we generate the instructions.125126if (!Op0->getType()->isVectorTy()) {127FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);128Value *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy);129130LowBits = CGF->Builder.CreateExtractElement(Bitcast, (uint64_t)0);131HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1);132} else {133int NumElements = 1;134if (const auto *VecTy =135E->getArg(0)->getType()->getAs<clang::VectorType>())136NumElements = VecTy->getNumElements();137138FixedVectorType *Uint32VecTy =139FixedVectorType::get(CGF->Int32Ty, NumElements * 2);140Value *Uint32Vec = CGF->Builder.CreateBitCast(Op0, Uint32VecTy);141if (NumElements == 1) {142LowBits = CGF->Builder.CreateExtractElement(Uint32Vec, (uint64_t)0);143HighBits = CGF->Builder.CreateExtractElement(Uint32Vec, 1);144} else {145SmallVector<int> EvenMask, OddMask;146for (int I = 0, E = NumElements; I != E; ++I) {147EvenMask.push_back(I * 2);148OddMask.push_back(I * 2 + 1);149}150LowBits = CGF->Builder.CreateShuffleVector(Uint32Vec, EvenMask);151HighBits = CGF->Builder.CreateShuffleVector(Uint32Vec, OddMask);152}153}154}155CGF->Builder.CreateStore(LowBits, Op1TmpLValue.getAddress());156auto *LastInst =157CGF->Builder.CreateStore(HighBits, Op2TmpLValue.getAddress());158CGF->EmitWritebacks(Args);159return LastInst;160}161162// Return dot product intrinsic that corresponds to the QT scalar type163static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {164if (QT->isFloatingType())165return RT.getFDotIntrinsic();166if (QT->isSignedIntegerType())167return RT.getSDotIntrinsic();168assert(QT->isUnsignedIntegerType());169return RT.getUDotIntrinsic();170}171172static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {173if (QT->hasSignedIntegerRepresentation()) {174return RT.getFirstBitSHighIntrinsic();175}176177assert(QT->hasUnsignedIntegerRepresentation());178return RT.getFirstBitUHighIntrinsic();179}180181// Return wave active sum that corresponds to the QT scalar type182static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,183CGHLSLRuntime &RT, QualType QT) {184switch (Arch) {185case llvm::Triple::spirv:186return Intrinsic::spv_wave_reduce_sum;187case llvm::Triple::dxil: {188if (QT->isUnsignedIntegerType())189return Intrinsic::dx_wave_reduce_usum;190return Intrinsic::dx_wave_reduce_sum;191}192default:193llvm_unreachable("Intrinsic WaveActiveSum"194" not supported by target architecture");195}196}197198// Return wave active sum that corresponds to the QT scalar type199static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,200CGHLSLRuntime &RT, QualType QT) {201switch (Arch) {202case llvm::Triple::spirv:203if (QT->isUnsignedIntegerType())204return Intrinsic::spv_wave_reduce_umax;205return Intrinsic::spv_wave_reduce_max;206case llvm::Triple::dxil: {207if (QT->isUnsignedIntegerType())208return Intrinsic::dx_wave_reduce_umax;209return Intrinsic::dx_wave_reduce_max;210}211default:212llvm_unreachable("Intrinsic WaveActiveMax"213" not supported by target architecture");214}215}216217// Returns the mangled name for a builtin function that the SPIR-V backend218// will expand into a spec Constant.219static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,220ASTContext &Context) {221// The parameter types for our conceptual intrinsic function.222QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};223224// Create a temporary FunctionDecl for the builtin fuction. It won't be225// added to the AST.226FunctionProtoType::ExtProtoInfo EPI;227QualType FnType =228Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);229DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");230FunctionDecl *FnDeclForMangling = FunctionDecl::Create(231Context, Context.getTranslationUnitDecl(), SourceLocation(),232SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);233234// Attach the created parameter declarations to the function declaration.235SmallVector<ParmVarDecl *, 2> ParamDecls;236for (QualType ParamType : ClangParamTypes) {237ParmVarDecl *PD = ParmVarDecl::Create(238Context, FnDeclForMangling, SourceLocation(), SourceLocation(),239/*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,240/*DefaultArg*/ nullptr);241ParamDecls.push_back(PD);242}243FnDeclForMangling->setParams(ParamDecls);244245// Get the mangled name.246std::string Name;247llvm::raw_string_ostream MangledNameStream(Name);248std::unique_ptr<MangleContext> Mangler(Context.createMangleContext());249Mangler->mangleName(FnDeclForMangling, MangledNameStream);250MangledNameStream.flush();251252return Name;253}254255Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,256const CallExpr *E,257ReturnValueSlot ReturnValue) {258if (!getLangOpts().HLSL)259return nullptr;260261switch (BuiltinID) {262case Builtin::BI__builtin_hlsl_adduint64: {263Value *OpA = EmitScalarExpr(E->getArg(0));264Value *OpB = EmitScalarExpr(E->getArg(1));265QualType Arg0Ty = E->getArg(0)->getType();266uint64_t NumElements = Arg0Ty->castAs<VectorType>()->getNumElements();267assert(Arg0Ty == E->getArg(1)->getType() &&268"AddUint64 operand types must match");269assert(Arg0Ty->hasIntegerRepresentation() &&270"AddUint64 operands must have an integer representation");271assert((NumElements == 2 || NumElements == 4) &&272"AddUint64 operands must have 2 or 4 elements");273274llvm::Value *LowA;275llvm::Value *HighA;276llvm::Value *LowB;277llvm::Value *HighB;278279// Obtain low and high words of inputs A and B280if (NumElements == 2) {281LowA = Builder.CreateExtractElement(OpA, (uint64_t)0, "LowA");282HighA = Builder.CreateExtractElement(OpA, (uint64_t)1, "HighA");283LowB = Builder.CreateExtractElement(OpB, (uint64_t)0, "LowB");284HighB = Builder.CreateExtractElement(OpB, (uint64_t)1, "HighB");285} else {286LowA = Builder.CreateShuffleVector(OpA, {0, 2}, "LowA");287HighA = Builder.CreateShuffleVector(OpA, {1, 3}, "HighA");288LowB = Builder.CreateShuffleVector(OpB, {0, 2}, "LowB");289HighB = Builder.CreateShuffleVector(OpB, {1, 3}, "HighB");290}291292// Use an uadd_with_overflow to compute the sum of low words and obtain a293// carry value294llvm::Value *Carry;295llvm::Value *LowSum = EmitOverflowIntrinsic(296*this, Intrinsic::uadd_with_overflow, LowA, LowB, Carry);297llvm::Value *ZExtCarry =298Builder.CreateZExt(Carry, HighA->getType(), "CarryZExt");299300// Sum the high words and the carry301llvm::Value *HighSum = Builder.CreateAdd(HighA, HighB, "HighSum");302llvm::Value *HighSumPlusCarry =303Builder.CreateAdd(HighSum, ZExtCarry, "HighSumPlusCarry");304305if (NumElements == 4) {306return Builder.CreateShuffleVector(LowSum, HighSumPlusCarry, {0, 2, 1, 3},307"hlsl.AddUint64");308}309310llvm::Value *Result = PoisonValue::get(OpA->getType());311Result = Builder.CreateInsertElement(Result, LowSum, (uint64_t)0,312"hlsl.AddUint64.upto0");313Result = Builder.CreateInsertElement(Result, HighSumPlusCarry, (uint64_t)1,314"hlsl.AddUint64");315return Result;316}317case Builtin::BI__builtin_hlsl_resource_getpointer: {318Value *HandleOp = EmitScalarExpr(E->getArg(0));319Value *IndexOp = EmitScalarExpr(E->getArg(1));320321llvm::Type *RetTy = ConvertType(E->getType());322return Builder.CreateIntrinsic(323RetTy, CGM.getHLSLRuntime().getCreateResourceGetPointerIntrinsic(),324ArrayRef<Value *>{HandleOp, IndexOp});325}326case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {327llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType());328return llvm::PoisonValue::get(HandleTy);329}330case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {331llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType());332Value *RegisterOp = EmitScalarExpr(E->getArg(1));333Value *SpaceOp = EmitScalarExpr(E->getArg(2));334Value *RangeOp = EmitScalarExpr(E->getArg(3));335Value *IndexOp = EmitScalarExpr(E->getArg(4));336Value *Name = EmitScalarExpr(E->getArg(5));337// FIXME: NonUniformResourceIndex bit is not yet implemented338// (llvm/llvm-project#135452)339Value *NonUniform =340llvm::ConstantInt::get(llvm::Type::getInt1Ty(getLLVMContext()), false);341342llvm::Intrinsic::ID IntrinsicID =343CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic();344SmallVector<Value *> Args{SpaceOp, RegisterOp, RangeOp,345IndexOp, NonUniform, Name};346return Builder.CreateIntrinsic(HandleTy, IntrinsicID, Args);347}348case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {349llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType());350Value *SpaceOp = EmitScalarExpr(E->getArg(1));351Value *RangeOp = EmitScalarExpr(E->getArg(2));352Value *IndexOp = EmitScalarExpr(E->getArg(3));353Value *OrderID = EmitScalarExpr(E->getArg(4));354Value *Name = EmitScalarExpr(E->getArg(5));355// FIXME: NonUniformResourceIndex bit is not yet implemented356// (llvm/llvm-project#135452)357Value *NonUniform =358llvm::ConstantInt::get(llvm::Type::getInt1Ty(getLLVMContext()), false);359360llvm::Intrinsic::ID IntrinsicID =361CGM.getHLSLRuntime().getCreateHandleFromImplicitBindingIntrinsic();362SmallVector<Value *> Args{OrderID, SpaceOp, RangeOp,363IndexOp, NonUniform, Name};364return Builder.CreateIntrinsic(HandleTy, IntrinsicID, Args);365}366case Builtin::BI__builtin_hlsl_all: {367Value *Op0 = EmitScalarExpr(E->getArg(0));368return Builder.CreateIntrinsic(369/*ReturnType=*/llvm::Type::getInt1Ty(getLLVMContext()),370CGM.getHLSLRuntime().getAllIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,371"hlsl.all");372}373case Builtin::BI__builtin_hlsl_and: {374Value *Op0 = EmitScalarExpr(E->getArg(0));375Value *Op1 = EmitScalarExpr(E->getArg(1));376return Builder.CreateAnd(Op0, Op1, "hlsl.and");377}378case Builtin::BI__builtin_hlsl_or: {379Value *Op0 = EmitScalarExpr(E->getArg(0));380Value *Op1 = EmitScalarExpr(E->getArg(1));381return Builder.CreateOr(Op0, Op1, "hlsl.or");382}383case Builtin::BI__builtin_hlsl_any: {384Value *Op0 = EmitScalarExpr(E->getArg(0));385return Builder.CreateIntrinsic(386/*ReturnType=*/llvm::Type::getInt1Ty(getLLVMContext()),387CGM.getHLSLRuntime().getAnyIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,388"hlsl.any");389}390case Builtin::BI__builtin_hlsl_asdouble:391return handleAsDoubleBuiltin(*this, E);392case Builtin::BI__builtin_hlsl_elementwise_clamp: {393Value *OpX = EmitScalarExpr(E->getArg(0));394Value *OpMin = EmitScalarExpr(E->getArg(1));395Value *OpMax = EmitScalarExpr(E->getArg(2));396397QualType Ty = E->getArg(0)->getType();398if (auto *VecTy = Ty->getAs<VectorType>())399Ty = VecTy->getElementType();400401Intrinsic::ID Intr;402if (Ty->isFloatingType()) {403Intr = CGM.getHLSLRuntime().getNClampIntrinsic();404} else if (Ty->isUnsignedIntegerType()) {405Intr = CGM.getHLSLRuntime().getUClampIntrinsic();406} else {407assert(Ty->isSignedIntegerType());408Intr = CGM.getHLSLRuntime().getSClampIntrinsic();409}410return Builder.CreateIntrinsic(411/*ReturnType=*/OpX->getType(), Intr,412ArrayRef<Value *>{OpX, OpMin, OpMax}, nullptr, "hlsl.clamp");413}414case Builtin::BI__builtin_hlsl_crossf16:415case Builtin::BI__builtin_hlsl_crossf32: {416Value *Op0 = EmitScalarExpr(E->getArg(0));417Value *Op1 = EmitScalarExpr(E->getArg(1));418assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&419E->getArg(1)->getType()->hasFloatingRepresentation() &&420"cross operands must have a float representation");421// make sure each vector has exactly 3 elements422assert(423E->getArg(0)->getType()->castAs<VectorType>()->getNumElements() == 3 &&424E->getArg(1)->getType()->castAs<VectorType>()->getNumElements() == 3 &&425"input vectors must have 3 elements each");426return Builder.CreateIntrinsic(427/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getCrossIntrinsic(),428ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.cross");429}430case Builtin::BI__builtin_hlsl_dot: {431Value *Op0 = EmitScalarExpr(E->getArg(0));432Value *Op1 = EmitScalarExpr(E->getArg(1));433llvm::Type *T0 = Op0->getType();434llvm::Type *T1 = Op1->getType();435436// If the arguments are scalars, just emit a multiply437if (!T0->isVectorTy() && !T1->isVectorTy()) {438if (T0->isFloatingPointTy())439return Builder.CreateFMul(Op0, Op1, "hlsl.dot");440441if (T0->isIntegerTy())442return Builder.CreateMul(Op0, Op1, "hlsl.dot");443444llvm_unreachable(445"Scalar dot product is only supported on ints and floats.");446}447// For vectors, validate types and emit the appropriate intrinsic448assert(CGM.getContext().hasSameUnqualifiedType(E->getArg(0)->getType(),449E->getArg(1)->getType()) &&450"Dot product operands must have the same type.");451452auto *VecTy0 = E->getArg(0)->getType()->castAs<VectorType>();453assert(VecTy0 && "Dot product argument must be a vector.");454455return Builder.CreateIntrinsic(456/*ReturnType=*/T0->getScalarType(),457getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),458ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");459}460case Builtin::BI__builtin_hlsl_dot4add_i8packed: {461Value *X = EmitScalarExpr(E->getArg(0));462Value *Y = EmitScalarExpr(E->getArg(1));463Value *Acc = EmitScalarExpr(E->getArg(2));464465Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddI8PackedIntrinsic();466// Note that the argument order disagrees between the builtin and the467// intrinsic here.468return Builder.CreateIntrinsic(469/*ReturnType=*/Acc->getType(), ID, ArrayRef<Value *>{Acc, X, Y},470nullptr, "hlsl.dot4add.i8packed");471}472case Builtin::BI__builtin_hlsl_dot4add_u8packed: {473Value *X = EmitScalarExpr(E->getArg(0));474Value *Y = EmitScalarExpr(E->getArg(1));475Value *Acc = EmitScalarExpr(E->getArg(2));476477Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddU8PackedIntrinsic();478// Note that the argument order disagrees between the builtin and the479// intrinsic here.480return Builder.CreateIntrinsic(481/*ReturnType=*/Acc->getType(), ID, ArrayRef<Value *>{Acc, X, Y},482nullptr, "hlsl.dot4add.u8packed");483}484case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {485Value *X = EmitScalarExpr(E->getArg(0));486487return Builder.CreateIntrinsic(488/*ReturnType=*/ConvertType(E->getType()),489getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),490ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");491}492case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {493Value *X = EmitScalarExpr(E->getArg(0));494495return Builder.CreateIntrinsic(496/*ReturnType=*/ConvertType(E->getType()),497CGM.getHLSLRuntime().getFirstBitLowIntrinsic(), ArrayRef<Value *>{X},498nullptr, "hlsl.firstbitlow");499}500case Builtin::BI__builtin_hlsl_lerp: {501Value *X = EmitScalarExpr(E->getArg(0));502Value *Y = EmitScalarExpr(E->getArg(1));503Value *S = EmitScalarExpr(E->getArg(2));504if (!E->getArg(0)->getType()->hasFloatingRepresentation())505llvm_unreachable("lerp operand must have a float representation");506return Builder.CreateIntrinsic(507/*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getLerpIntrinsic(),508ArrayRef<Value *>{X, Y, S}, nullptr, "hlsl.lerp");509}510case Builtin::BI__builtin_hlsl_normalize: {511Value *X = EmitScalarExpr(E->getArg(0));512513assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&514"normalize operand must have a float representation");515516return Builder.CreateIntrinsic(517/*ReturnType=*/X->getType(),518CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},519nullptr, "hlsl.normalize");520}521case Builtin::BI__builtin_hlsl_elementwise_degrees: {522Value *X = EmitScalarExpr(E->getArg(0));523524assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&525"degree operand must have a float representation");526527return Builder.CreateIntrinsic(528/*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getDegreesIntrinsic(),529ArrayRef<Value *>{X}, nullptr, "hlsl.degrees");530}531case Builtin::BI__builtin_hlsl_elementwise_frac: {532Value *Op0 = EmitScalarExpr(E->getArg(0));533if (!E->getArg(0)->getType()->hasFloatingRepresentation())534llvm_unreachable("frac operand must have a float representation");535return Builder.CreateIntrinsic(536/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getFracIntrinsic(),537ArrayRef<Value *>{Op0}, nullptr, "hlsl.frac");538}539case Builtin::BI__builtin_hlsl_elementwise_isinf: {540Value *Op0 = EmitScalarExpr(E->getArg(0));541llvm::Type *Xty = Op0->getType();542llvm::Type *retType = llvm::Type::getInt1Ty(this->getLLVMContext());543if (Xty->isVectorTy()) {544auto *XVecTy = E->getArg(0)->getType()->castAs<VectorType>();545retType = llvm::VectorType::get(546retType, ElementCount::getFixed(XVecTy->getNumElements()));547}548if (!E->getArg(0)->getType()->hasFloatingRepresentation())549llvm_unreachable("isinf operand must have a float representation");550return Builder.CreateIntrinsic(retType, Intrinsic::dx_isinf,551ArrayRef<Value *>{Op0}, nullptr, "dx.isinf");552}553case Builtin::BI__builtin_hlsl_mad: {554Value *M = EmitScalarExpr(E->getArg(0));555Value *A = EmitScalarExpr(E->getArg(1));556Value *B = EmitScalarExpr(E->getArg(2));557if (E->getArg(0)->getType()->hasFloatingRepresentation())558return Builder.CreateIntrinsic(559/*ReturnType*/ M->getType(), Intrinsic::fmuladd,560ArrayRef<Value *>{M, A, B}, nullptr, "hlsl.fmad");561562if (E->getArg(0)->getType()->hasSignedIntegerRepresentation()) {563if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil)564return Builder.CreateIntrinsic(565/*ReturnType*/ M->getType(), Intrinsic::dx_imad,566ArrayRef<Value *>{M, A, B}, nullptr, "dx.imad");567568Value *Mul = Builder.CreateNSWMul(M, A);569return Builder.CreateNSWAdd(Mul, B);570}571assert(E->getArg(0)->getType()->hasUnsignedIntegerRepresentation());572if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil)573return Builder.CreateIntrinsic(574/*ReturnType=*/M->getType(), Intrinsic::dx_umad,575ArrayRef<Value *>{M, A, B}, nullptr, "dx.umad");576577Value *Mul = Builder.CreateNUWMul(M, A);578return Builder.CreateNUWAdd(Mul, B);579}580case Builtin::BI__builtin_hlsl_elementwise_rcp: {581Value *Op0 = EmitScalarExpr(E->getArg(0));582if (!E->getArg(0)->getType()->hasFloatingRepresentation())583llvm_unreachable("rcp operand must have a float representation");584llvm::Type *Ty = Op0->getType();585llvm::Type *EltTy = Ty->getScalarType();586Constant *One = Ty->isVectorTy()587? ConstantVector::getSplat(588ElementCount::getFixed(589cast<FixedVectorType>(Ty)->getNumElements()),590ConstantFP::get(EltTy, 1.0))591: ConstantFP::get(EltTy, 1.0);592return Builder.CreateFDiv(One, Op0, "hlsl.rcp");593}594case Builtin::BI__builtin_hlsl_elementwise_rsqrt: {595Value *Op0 = EmitScalarExpr(E->getArg(0));596if (!E->getArg(0)->getType()->hasFloatingRepresentation())597llvm_unreachable("rsqrt operand must have a float representation");598return Builder.CreateIntrinsic(599/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getRsqrtIntrinsic(),600ArrayRef<Value *>{Op0}, nullptr, "hlsl.rsqrt");601}602case Builtin::BI__builtin_hlsl_elementwise_saturate: {603Value *Op0 = EmitScalarExpr(E->getArg(0));604assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&605"saturate operand must have a float representation");606return Builder.CreateIntrinsic(607/*ReturnType=*/Op0->getType(),608CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},609nullptr, "hlsl.saturate");610}611case Builtin::BI__builtin_hlsl_select: {612Value *OpCond = EmitScalarExpr(E->getArg(0));613RValue RValTrue = EmitAnyExpr(E->getArg(1));614Value *OpTrue =615RValTrue.isScalar()616? RValTrue.getScalarVal()617: RValTrue.getAggregatePointer(E->getArg(1)->getType(), *this);618RValue RValFalse = EmitAnyExpr(E->getArg(2));619Value *OpFalse =620RValFalse.isScalar()621? RValFalse.getScalarVal()622: RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);623if (auto *VTy = E->getType()->getAs<VectorType>()) {624if (!OpTrue->getType()->isVectorTy())625OpTrue =626Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat");627if (!OpFalse->getType()->isVectorTy())628OpFalse =629Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat");630}631632Value *SelectVal =633Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");634if (!RValTrue.isScalar())635Builder.CreateStore(SelectVal, ReturnValue.getAddress(),636ReturnValue.isVolatile());637638return SelectVal;639}640case Builtin::BI__builtin_hlsl_step: {641Value *Op0 = EmitScalarExpr(E->getArg(0));642Value *Op1 = EmitScalarExpr(E->getArg(1));643assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&644E->getArg(1)->getType()->hasFloatingRepresentation() &&645"step operands must have a float representation");646return Builder.CreateIntrinsic(647/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),648ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");649}650case Builtin::BI__builtin_hlsl_wave_active_all_true: {651Value *Op = EmitScalarExpr(E->getArg(0));652assert(Op->getType()->isIntegerTy(1) &&653"Intrinsic WaveActiveAllTrue operand must be a bool");654655Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllTrueIntrinsic();656return EmitRuntimeCall(657Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});658}659case Builtin::BI__builtin_hlsl_wave_active_any_true: {660Value *Op = EmitScalarExpr(E->getArg(0));661assert(Op->getType()->isIntegerTy(1) &&662"Intrinsic WaveActiveAnyTrue operand must be a bool");663664Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAnyTrueIntrinsic();665return EmitRuntimeCall(666Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});667}668case Builtin::BI__builtin_hlsl_wave_active_count_bits: {669Value *OpExpr = EmitScalarExpr(E->getArg(0));670Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();671return EmitRuntimeCall(672Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID),673ArrayRef{OpExpr});674}675case Builtin::BI__builtin_hlsl_wave_active_sum: {676// Due to the use of variadic arguments, explicitly retreive argument677Value *OpExpr = EmitScalarExpr(E->getArg(0));678Intrinsic::ID IID = getWaveActiveSumIntrinsic(679getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),680E->getArg(0)->getType());681682return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(683&CGM.getModule(), IID, {OpExpr->getType()}),684ArrayRef{OpExpr}, "hlsl.wave.active.sum");685}686case Builtin::BI__builtin_hlsl_wave_active_max: {687// Due to the use of variadic arguments, explicitly retreive argument688Value *OpExpr = EmitScalarExpr(E->getArg(0));689Intrinsic::ID IID = getWaveActiveMaxIntrinsic(690getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),691E->getArg(0)->getType());692693return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(694&CGM.getModule(), IID, {OpExpr->getType()}),695ArrayRef{OpExpr}, "hlsl.wave.active.max");696}697case Builtin::BI__builtin_hlsl_wave_get_lane_index: {698// We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in699// defined in SPIRVBuiltins.td. So instead we manually get the matching name700// for the DirectX intrinsic and the demangled builtin name701switch (CGM.getTarget().getTriple().getArch()) {702case llvm::Triple::dxil:703return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(704&CGM.getModule(), Intrinsic::dx_wave_getlaneindex));705case llvm::Triple::spirv:706return EmitRuntimeCall(CGM.CreateRuntimeFunction(707llvm::FunctionType::get(IntTy, {}, false),708"__hlsl_wave_get_lane_index", {}, false, true));709default:710llvm_unreachable(711"Intrinsic WaveGetLaneIndex not supported by target architecture");712}713}714case Builtin::BI__builtin_hlsl_wave_is_first_lane: {715Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();716return EmitRuntimeCall(717Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));718}719case Builtin::BI__builtin_hlsl_wave_get_lane_count: {720Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveGetLaneCountIntrinsic();721return EmitRuntimeCall(722Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));723}724case Builtin::BI__builtin_hlsl_wave_read_lane_at: {725// Due to the use of variadic arguments we must explicitly retreive them and726// create our function type.727Value *OpExpr = EmitScalarExpr(E->getArg(0));728Value *OpIndex = EmitScalarExpr(E->getArg(1));729return EmitRuntimeCall(730Intrinsic::getOrInsertDeclaration(731&CGM.getModule(), CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),732{OpExpr->getType()}),733ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");734}735case Builtin::BI__builtin_hlsl_elementwise_sign: {736auto *Arg0 = E->getArg(0);737Value *Op0 = EmitScalarExpr(Arg0);738llvm::Type *Xty = Op0->getType();739llvm::Type *retType = llvm::Type::getInt32Ty(this->getLLVMContext());740if (Xty->isVectorTy()) {741auto *XVecTy = Arg0->getType()->castAs<VectorType>();742retType = llvm::VectorType::get(743retType, ElementCount::getFixed(XVecTy->getNumElements()));744}745assert((Arg0->getType()->hasFloatingRepresentation() ||746Arg0->getType()->hasIntegerRepresentation()) &&747"sign operand must have a float or int representation");748749if (Arg0->getType()->hasUnsignedIntegerRepresentation()) {750Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::get(Xty, 0));751return Builder.CreateSelect(Cmp, ConstantInt::get(retType, 0),752ConstantInt::get(retType, 1), "hlsl.sign");753}754755return Builder.CreateIntrinsic(756retType, CGM.getHLSLRuntime().getSignIntrinsic(),757ArrayRef<Value *>{Op0}, nullptr, "hlsl.sign");758}759case Builtin::BI__builtin_hlsl_elementwise_radians: {760Value *Op0 = EmitScalarExpr(E->getArg(0));761assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&762"radians operand must have a float representation");763return Builder.CreateIntrinsic(764/*ReturnType=*/Op0->getType(),765CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},766nullptr, "hlsl.radians");767}768case Builtin::BI__builtin_hlsl_buffer_update_counter: {769Value *ResHandle = EmitScalarExpr(E->getArg(0));770Value *Offset = EmitScalarExpr(E->getArg(1));771Value *OffsetI8 = Builder.CreateIntCast(Offset, Int8Ty, true);772return Builder.CreateIntrinsic(773/*ReturnType=*/Offset->getType(),774CGM.getHLSLRuntime().getBufferUpdateCounterIntrinsic(),775ArrayRef<Value *>{ResHandle, OffsetI8}, nullptr);776}777case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {778779assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&780E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&781E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&782"asuint operands types mismatch");783return handleHlslSplitdouble(E, this);784}785case Builtin::BI__builtin_hlsl_elementwise_clip:786assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&787"clip operands types mismatch");788return handleHlslClip(E, this);789case Builtin::BI__builtin_hlsl_group_memory_barrier_with_group_sync: {790Intrinsic::ID ID =791CGM.getHLSLRuntime().getGroupMemoryBarrierWithGroupSyncIntrinsic();792return EmitRuntimeCall(793Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));794}795case Builtin::BI__builtin_get_spirv_spec_constant_bool:796case Builtin::BI__builtin_get_spirv_spec_constant_short:797case Builtin::BI__builtin_get_spirv_spec_constant_ushort:798case Builtin::BI__builtin_get_spirv_spec_constant_int:799case Builtin::BI__builtin_get_spirv_spec_constant_uint:800case Builtin::BI__builtin_get_spirv_spec_constant_longlong:801case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:802case Builtin::BI__builtin_get_spirv_spec_constant_half:803case Builtin::BI__builtin_get_spirv_spec_constant_float:804case Builtin::BI__builtin_get_spirv_spec_constant_double: {805llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());806llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));807llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));808llvm::Value *Args[] = {SpecId, DefaultVal};809return Builder.CreateCall(SpecConstantFn, Args);810}811}812return nullptr;813}814815llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(816const clang::QualType &SpecConstantType) {817818// Find or create the declaration for the function.819llvm::Module *M = &CGM.getModule();820std::string MangledName =821getSpecConstantFunctionName(SpecConstantType, getContext());822llvm::Function *SpecConstantFn = M->getFunction(MangledName);823824if (!SpecConstantFn) {825llvm::Type *IntType = ConvertType(getContext().IntTy);826llvm::Type *RetTy = ConvertType(SpecConstantType);827llvm::Type *ArgTypes[] = {IntType, RetTy};828llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);829SpecConstantFn = llvm::Function::Create(830FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);831}832return SpecConstantFn;833}834835836