Path: blob/main/contrib/llvm-project/clang/lib/Sema/SemaHLSL.cpp
35233 views
//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7// This implements Semantic Analysis for HLSL constructs.8//===----------------------------------------------------------------------===//910#include "clang/Sema/SemaHLSL.h"11#include "clang/AST/Decl.h"12#include "clang/AST/Expr.h"13#include "clang/AST/RecursiveASTVisitor.h"14#include "clang/Basic/DiagnosticSema.h"15#include "clang/Basic/LLVM.h"16#include "clang/Basic/TargetInfo.h"17#include "clang/Sema/ParsedAttr.h"18#include "clang/Sema/Sema.h"19#include "llvm/ADT/STLExtras.h"20#include "llvm/ADT/StringExtras.h"21#include "llvm/ADT/StringRef.h"22#include "llvm/Support/Casting.h"23#include "llvm/Support/ErrorHandling.h"24#include "llvm/TargetParser/Triple.h"25#include <iterator>2627using namespace clang;2829SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}3031Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,32SourceLocation KwLoc, IdentifierInfo *Ident,33SourceLocation IdentLoc,34SourceLocation LBrace) {35// For anonymous namespace, take the location of the left brace.36DeclContext *LexicalParent = SemaRef.getCurLexicalContext();37HLSLBufferDecl *Result = HLSLBufferDecl::Create(38getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);3940SemaRef.PushOnScopeChains(Result, BufferScope);41SemaRef.PushDeclContext(BufferScope, Result);4243return Result;44}4546// Calculate the size of a legacy cbuffer type based on47// https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules48static unsigned calculateLegacyCbufferSize(const ASTContext &Context,49QualType T) {50unsigned Size = 0;51constexpr unsigned CBufferAlign = 128;52if (const RecordType *RT = T->getAs<RecordType>()) {53const RecordDecl *RD = RT->getDecl();54for (const FieldDecl *Field : RD->fields()) {55QualType Ty = Field->getType();56unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);57unsigned FieldAlign = 32;58if (Ty->isAggregateType())59FieldAlign = CBufferAlign;60Size = llvm::alignTo(Size, FieldAlign);61Size += FieldSize;62}63} else if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {64if (unsigned ElementCount = AT->getSize().getZExtValue()) {65unsigned ElementSize =66calculateLegacyCbufferSize(Context, AT->getElementType());67unsigned AlignedElementSize = llvm::alignTo(ElementSize, CBufferAlign);68Size = AlignedElementSize * (ElementCount - 1) + ElementSize;69}70} else if (const VectorType *VT = T->getAs<VectorType>()) {71unsigned ElementCount = VT->getNumElements();72unsigned ElementSize =73calculateLegacyCbufferSize(Context, VT->getElementType());74Size = ElementSize * ElementCount;75} else {76Size = Context.getTypeSize(T);77}78return Size;79}8081void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {82auto *BufDecl = cast<HLSLBufferDecl>(Dcl);83BufDecl->setRBraceLoc(RBrace);8485// Validate packoffset.86llvm::SmallVector<std::pair<VarDecl *, HLSLPackOffsetAttr *>> PackOffsetVec;87bool HasPackOffset = false;88bool HasNonPackOffset = false;89for (auto *Field : BufDecl->decls()) {90VarDecl *Var = dyn_cast<VarDecl>(Field);91if (!Var)92continue;93if (Field->hasAttr<HLSLPackOffsetAttr>()) {94PackOffsetVec.emplace_back(Var, Field->getAttr<HLSLPackOffsetAttr>());95HasPackOffset = true;96} else {97HasNonPackOffset = true;98}99}100101if (HasPackOffset && HasNonPackOffset)102Diag(BufDecl->getLocation(), diag::warn_hlsl_packoffset_mix);103104if (HasPackOffset) {105ASTContext &Context = getASTContext();106// Make sure no overlap in packoffset.107// Sort PackOffsetVec by offset.108std::sort(PackOffsetVec.begin(), PackOffsetVec.end(),109[](const std::pair<VarDecl *, HLSLPackOffsetAttr *> &LHS,110const std::pair<VarDecl *, HLSLPackOffsetAttr *> &RHS) {111return LHS.second->getOffset() < RHS.second->getOffset();112});113114for (unsigned i = 0; i < PackOffsetVec.size() - 1; i++) {115VarDecl *Var = PackOffsetVec[i].first;116HLSLPackOffsetAttr *Attr = PackOffsetVec[i].second;117unsigned Size = calculateLegacyCbufferSize(Context, Var->getType());118unsigned Begin = Attr->getOffset() * 32;119unsigned End = Begin + Size;120unsigned NextBegin = PackOffsetVec[i + 1].second->getOffset() * 32;121if (End > NextBegin) {122VarDecl *NextVar = PackOffsetVec[i + 1].first;123Diag(NextVar->getLocation(), diag::err_hlsl_packoffset_overlap)124<< NextVar << Var;125}126}127}128129SemaRef.PopDeclContext();130}131132HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,133const AttributeCommonInfo &AL,134int X, int Y, int Z) {135if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {136if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {137Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;138Diag(AL.getLoc(), diag::note_conflicting_attribute);139}140return nullptr;141}142return ::new (getASTContext())143HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);144}145146HLSLShaderAttr *147SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,148llvm::Triple::EnvironmentType ShaderType) {149if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {150if (NT->getType() != ShaderType) {151Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;152Diag(AL.getLoc(), diag::note_conflicting_attribute);153}154return nullptr;155}156return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);157}158159HLSLParamModifierAttr *160SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,161HLSLParamModifierAttr::Spelling Spelling) {162// We can only merge an `in` attribute with an `out` attribute. All other163// combinations of duplicated attributes are ill-formed.164if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {165if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||166(PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {167D->dropAttr<HLSLParamModifierAttr>();168SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};169return HLSLParamModifierAttr::Create(170getASTContext(), /*MergedSpelling=*/true, AdjustedRange,171HLSLParamModifierAttr::Keyword_inout);172}173Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;174Diag(PA->getLocation(), diag::note_conflicting_attribute);175return nullptr;176}177return HLSLParamModifierAttr::Create(getASTContext(), AL);178}179180void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {181auto &TargetInfo = getASTContext().getTargetInfo();182183if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)184return;185186llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();187if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {188if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {189// The entry point is already annotated - check that it matches the190// triple.191if (Shader->getType() != Env) {192Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)193<< Shader;194FD->setInvalidDecl();195}196} else {197// Implicitly add the shader attribute if the entry function isn't198// explicitly annotated.199FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,200FD->getBeginLoc()));201}202} else {203switch (Env) {204case llvm::Triple::UnknownEnvironment:205case llvm::Triple::Library:206break;207default:208llvm_unreachable("Unhandled environment in triple");209}210}211}212213void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {214const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();215assert(ShaderAttr && "Entry point has no shader attribute");216llvm::Triple::EnvironmentType ST = ShaderAttr->getType();217218switch (ST) {219case llvm::Triple::Pixel:220case llvm::Triple::Vertex:221case llvm::Triple::Geometry:222case llvm::Triple::Hull:223case llvm::Triple::Domain:224case llvm::Triple::RayGeneration:225case llvm::Triple::Intersection:226case llvm::Triple::AnyHit:227case llvm::Triple::ClosestHit:228case llvm::Triple::Miss:229case llvm::Triple::Callable:230if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {231DiagnoseAttrStageMismatch(NT, ST,232{llvm::Triple::Compute,233llvm::Triple::Amplification,234llvm::Triple::Mesh});235FD->setInvalidDecl();236}237break;238239case llvm::Triple::Compute:240case llvm::Triple::Amplification:241case llvm::Triple::Mesh:242if (!FD->hasAttr<HLSLNumThreadsAttr>()) {243Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)244<< llvm::Triple::getEnvironmentTypeName(ST);245FD->setInvalidDecl();246}247break;248default:249llvm_unreachable("Unhandled environment in triple");250}251252for (ParmVarDecl *Param : FD->parameters()) {253if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {254CheckSemanticAnnotation(FD, Param, AnnotationAttr);255} else {256// FIXME: Handle struct parameters where annotations are on struct fields.257// See: https://github.com/llvm/llvm-project/issues/57875258Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);259Diag(Param->getLocation(), diag::note_previous_decl) << Param;260FD->setInvalidDecl();261}262}263// FIXME: Verify return type semantic annotation.264}265266void SemaHLSL::CheckSemanticAnnotation(267FunctionDecl *EntryPoint, const Decl *Param,268const HLSLAnnotationAttr *AnnotationAttr) {269auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();270assert(ShaderAttr && "Entry point has no shader attribute");271llvm::Triple::EnvironmentType ST = ShaderAttr->getType();272273switch (AnnotationAttr->getKind()) {274case attr::HLSLSV_DispatchThreadID:275case attr::HLSLSV_GroupIndex:276if (ST == llvm::Triple::Compute)277return;278DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});279break;280default:281llvm_unreachable("Unknown HLSLAnnotationAttr");282}283}284285void SemaHLSL::DiagnoseAttrStageMismatch(286const Attr *A, llvm::Triple::EnvironmentType Stage,287std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {288SmallVector<StringRef, 8> StageStrings;289llvm::transform(AllowedStages, std::back_inserter(StageStrings),290[](llvm::Triple::EnvironmentType ST) {291return StringRef(292HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));293});294Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)295<< A << llvm::Triple::getEnvironmentTypeName(Stage)296<< (AllowedStages.size() != 1) << join(StageStrings, ", ");297}298299void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {300llvm::VersionTuple SMVersion =301getASTContext().getTargetInfo().getTriple().getOSVersion();302uint32_t ZMax = 1024;303uint32_t ThreadMax = 1024;304if (SMVersion.getMajor() <= 4) {305ZMax = 1;306ThreadMax = 768;307} else if (SMVersion.getMajor() == 5) {308ZMax = 64;309ThreadMax = 1024;310}311312uint32_t X;313if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), X))314return;315if (X > 1024) {316Diag(AL.getArgAsExpr(0)->getExprLoc(),317diag::err_hlsl_numthreads_argument_oor)318<< 0 << 1024;319return;320}321uint32_t Y;322if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Y))323return;324if (Y > 1024) {325Diag(AL.getArgAsExpr(1)->getExprLoc(),326diag::err_hlsl_numthreads_argument_oor)327<< 1 << 1024;328return;329}330uint32_t Z;331if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Z))332return;333if (Z > ZMax) {334SemaRef.Diag(AL.getArgAsExpr(2)->getExprLoc(),335diag::err_hlsl_numthreads_argument_oor)336<< 2 << ZMax;337return;338}339340if (X * Y * Z > ThreadMax) {341Diag(AL.getLoc(), diag::err_hlsl_numthreads_invalid) << ThreadMax;342return;343}344345HLSLNumThreadsAttr *NewAttr = mergeNumThreadsAttr(D, AL, X, Y, Z);346if (NewAttr)347D->addAttr(NewAttr);348}349350static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {351if (!T->hasUnsignedIntegerRepresentation())352return false;353if (const auto *VT = T->getAs<VectorType>())354return VT->getNumElements() <= 3;355return true;356}357358void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {359auto *VD = cast<ValueDecl>(D);360if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) {361Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)362<< AL << "uint/uint2/uint3";363return;364}365366D->addAttr(::new (getASTContext())367HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));368}369370void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {371if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) {372Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)373<< AL << "shader constant in a constant buffer";374return;375}376377uint32_t SubComponent;378if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), SubComponent))379return;380uint32_t Component;381if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Component))382return;383384QualType T = cast<VarDecl>(D)->getType().getCanonicalType();385// Check if T is an array or struct type.386// TODO: mark matrix type as aggregate type.387bool IsAggregateTy = (T->isArrayType() || T->isStructureType());388389// Check Component is valid for T.390if (Component) {391unsigned Size = getASTContext().getTypeSize(T);392if (IsAggregateTy || Size > 128) {393Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);394return;395} else {396// Make sure Component + sizeof(T) <= 4.397if ((Component * 32 + Size) > 128) {398Diag(AL.getLoc(), diag::err_hlsl_packoffset_cross_reg_boundary);399return;400}401QualType EltTy = T;402if (const auto *VT = T->getAs<VectorType>())403EltTy = VT->getElementType();404unsigned Align = getASTContext().getTypeAlign(EltTy);405if (Align > 32 && Component == 1) {406// NOTE: Component 3 will hit err_hlsl_packoffset_cross_reg_boundary.407// So we only need to check Component 1 here.408Diag(AL.getLoc(), diag::err_hlsl_packoffset_alignment_mismatch)409<< Align << EltTy;410return;411}412}413}414415D->addAttr(::new (getASTContext()) HLSLPackOffsetAttr(416getASTContext(), AL, SubComponent, Component));417}418419void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {420StringRef Str;421SourceLocation ArgLoc;422if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))423return;424425llvm::Triple::EnvironmentType ShaderType;426if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {427Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)428<< AL << Str << ArgLoc;429return;430}431432// FIXME: check function match the shader stage.433434HLSLShaderAttr *NewAttr = mergeShaderAttr(D, AL, ShaderType);435if (NewAttr)436D->addAttr(NewAttr);437}438439void SemaHLSL::handleResourceClassAttr(Decl *D, const ParsedAttr &AL) {440if (!AL.isArgIdent(0)) {441Diag(AL.getLoc(), diag::err_attribute_argument_type)442<< AL << AANT_ArgumentIdentifier;443return;444}445446IdentifierLoc *Loc = AL.getArgAsIdent(0);447StringRef Identifier = Loc->Ident->getName();448SourceLocation ArgLoc = Loc->Loc;449450// Validate.451llvm::dxil::ResourceClass RC;452if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {453Diag(ArgLoc, diag::warn_attribute_type_not_supported)454<< "ResourceClass" << Identifier;455return;456}457458D->addAttr(HLSLResourceClassAttr::Create(getASTContext(), RC, ArgLoc));459}460461void SemaHLSL::handleResourceBindingAttr(Decl *D, const ParsedAttr &AL) {462StringRef Space = "space0";463StringRef Slot = "";464465if (!AL.isArgIdent(0)) {466Diag(AL.getLoc(), diag::err_attribute_argument_type)467<< AL << AANT_ArgumentIdentifier;468return;469}470471IdentifierLoc *Loc = AL.getArgAsIdent(0);472StringRef Str = Loc->Ident->getName();473SourceLocation ArgLoc = Loc->Loc;474475SourceLocation SpaceArgLoc;476if (AL.getNumArgs() == 2) {477Slot = Str;478if (!AL.isArgIdent(1)) {479Diag(AL.getLoc(), diag::err_attribute_argument_type)480<< AL << AANT_ArgumentIdentifier;481return;482}483484IdentifierLoc *Loc = AL.getArgAsIdent(1);485Space = Loc->Ident->getName();486SpaceArgLoc = Loc->Loc;487} else {488Slot = Str;489}490491// Validate.492if (!Slot.empty()) {493switch (Slot[0]) {494case 'u':495case 'b':496case 's':497case 't':498break;499default:500Diag(ArgLoc, diag::err_hlsl_unsupported_register_type)501<< Slot.substr(0, 1);502return;503}504505StringRef SlotNum = Slot.substr(1);506unsigned Num = 0;507if (SlotNum.getAsInteger(10, Num)) {508Diag(ArgLoc, diag::err_hlsl_unsupported_register_number);509return;510}511}512513if (!Space.starts_with("space")) {514Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;515return;516}517StringRef SpaceNum = Space.substr(5);518unsigned Num = 0;519if (SpaceNum.getAsInteger(10, Num)) {520Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space;521return;522}523524// FIXME: check reg type match decl. Issue525// https://github.com/llvm/llvm-project/issues/57886.526HLSLResourceBindingAttr *NewAttr =527HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL);528if (NewAttr)529D->addAttr(NewAttr);530}531532void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) {533HLSLParamModifierAttr *NewAttr = mergeParamModifierAttr(534D, AL,535static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));536if (NewAttr)537D->addAttr(NewAttr);538}539540namespace {541542/// This class implements HLSL availability diagnostics for default543/// and relaxed mode544///545/// The goal of this diagnostic is to emit an error or warning when an546/// unavailable API is found in code that is reachable from the shader547/// entry function or from an exported function (when compiling a shader548/// library).549///550/// This is done by traversing the AST of all shader entry point functions551/// and of all exported functions, and any functions that are referenced552/// from this AST. In other words, any functions that are reachable from553/// the entry points.554class DiagnoseHLSLAvailability555: public RecursiveASTVisitor<DiagnoseHLSLAvailability> {556557Sema &SemaRef;558559// Stack of functions to be scaned560llvm::SmallVector<const FunctionDecl *, 8> DeclsToScan;561562// Tracks which environments functions have been scanned in.563//564// Maps FunctionDecl to an unsigned number that represents the set of shader565// environments the function has been scanned for.566// The llvm::Triple::EnvironmentType enum values for shader stages guaranteed567// to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification568// (verified by static_asserts in Triple.cpp), we can use it to index569// individual bits in the set, as long as we shift the values to start with 0570// by subtracting the value of llvm::Triple::Pixel first.571//572// The N'th bit in the set will be set if the function has been scanned573// in shader environment whose llvm::Triple::EnvironmentType integer value574// equals (llvm::Triple::Pixel + N).575//576// For example, if a function has been scanned in compute and pixel stage577// environment, the value will be 0x21 (100001 binary) because:578//579// (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0580// (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5581//582// A FunctionDecl is mapped to 0 (or not included in the map) if it has not583// been scanned in any environment.584llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;585586// Do not access these directly, use the get/set methods below to make587// sure the values are in sync588llvm::Triple::EnvironmentType CurrentShaderEnvironment;589unsigned CurrentShaderStageBit;590591// True if scanning a function that was already scanned in a different592// shader stage context, and therefore we should not report issues that593// depend only on shader model version because they would be duplicate.594bool ReportOnlyShaderStageIssues;595596// Helper methods for dealing with current stage context / environment597void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {598static_assert(sizeof(unsigned) >= 4);599assert(HLSLShaderAttr::isValidShaderType(ShaderType));600assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&601"ShaderType is too big for this bitmap"); // 31 is reserved for602// "unknown"603604unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;605CurrentShaderEnvironment = ShaderType;606CurrentShaderStageBit = (1 << bitmapIndex);607}608609void SetUnknownShaderStageContext() {610CurrentShaderEnvironment = llvm::Triple::UnknownEnvironment;611CurrentShaderStageBit = (1 << 31);612}613614llvm::Triple::EnvironmentType GetCurrentShaderEnvironment() const {615return CurrentShaderEnvironment;616}617618bool InUnknownShaderStageContext() const {619return CurrentShaderEnvironment == llvm::Triple::UnknownEnvironment;620}621622// Helper methods for dealing with shader stage bitmap623void AddToScannedFunctions(const FunctionDecl *FD) {624unsigned &ScannedStages = ScannedDecls.getOrInsertDefault(FD);625ScannedStages |= CurrentShaderStageBit;626}627628unsigned GetScannedStages(const FunctionDecl *FD) {629return ScannedDecls.getOrInsertDefault(FD);630}631632bool WasAlreadyScannedInCurrentStage(const FunctionDecl *FD) {633return WasAlreadyScannedInCurrentStage(GetScannedStages(FD));634}635636bool WasAlreadyScannedInCurrentStage(unsigned ScannerStages) {637return ScannerStages & CurrentShaderStageBit;638}639640static bool NeverBeenScanned(unsigned ScannedStages) {641return ScannedStages == 0;642}643644// Scanning methods645void HandleFunctionOrMethodRef(FunctionDecl *FD, Expr *RefExpr);646void CheckDeclAvailability(NamedDecl *D, const AvailabilityAttr *AA,647SourceRange Range);648const AvailabilityAttr *FindAvailabilityAttr(const Decl *D);649bool HasMatchingEnvironmentOrNone(const AvailabilityAttr *AA);650651public:652DiagnoseHLSLAvailability(Sema &SemaRef) : SemaRef(SemaRef) {}653654// AST traversal methods655void RunOnTranslationUnit(const TranslationUnitDecl *TU);656void RunOnFunction(const FunctionDecl *FD);657658bool VisitDeclRefExpr(DeclRefExpr *DRE) {659FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(DRE->getDecl());660if (FD)661HandleFunctionOrMethodRef(FD, DRE);662return true;663}664665bool VisitMemberExpr(MemberExpr *ME) {666FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(ME->getMemberDecl());667if (FD)668HandleFunctionOrMethodRef(FD, ME);669return true;670}671};672673void DiagnoseHLSLAvailability::HandleFunctionOrMethodRef(FunctionDecl *FD,674Expr *RefExpr) {675assert((isa<DeclRefExpr>(RefExpr) || isa<MemberExpr>(RefExpr)) &&676"expected DeclRefExpr or MemberExpr");677678// has a definition -> add to stack to be scanned679const FunctionDecl *FDWithBody = nullptr;680if (FD->hasBody(FDWithBody)) {681if (!WasAlreadyScannedInCurrentStage(FDWithBody))682DeclsToScan.push_back(FDWithBody);683return;684}685686// no body -> diagnose availability687const AvailabilityAttr *AA = FindAvailabilityAttr(FD);688if (AA)689CheckDeclAvailability(690FD, AA, SourceRange(RefExpr->getBeginLoc(), RefExpr->getEndLoc()));691}692693void DiagnoseHLSLAvailability::RunOnTranslationUnit(694const TranslationUnitDecl *TU) {695696// Iterate over all shader entry functions and library exports, and for those697// that have a body (definiton), run diag scan on each, setting appropriate698// shader environment context based on whether it is a shader entry function699// or an exported function. Exported functions can be in namespaces and in700// export declarations so we need to scan those declaration contexts as well.701llvm::SmallVector<const DeclContext *, 8> DeclContextsToScan;702DeclContextsToScan.push_back(TU);703704while (!DeclContextsToScan.empty()) {705const DeclContext *DC = DeclContextsToScan.pop_back_val();706for (auto &D : DC->decls()) {707// do not scan implicit declaration generated by the implementation708if (D->isImplicit())709continue;710711// for namespace or export declaration add the context to the list to be712// scanned later713if (llvm::dyn_cast<NamespaceDecl>(D) || llvm::dyn_cast<ExportDecl>(D)) {714DeclContextsToScan.push_back(llvm::dyn_cast<DeclContext>(D));715continue;716}717718// skip over other decls or function decls without body719const FunctionDecl *FD = llvm::dyn_cast<FunctionDecl>(D);720if (!FD || !FD->isThisDeclarationADefinition())721continue;722723// shader entry point724if (HLSLShaderAttr *ShaderAttr = FD->getAttr<HLSLShaderAttr>()) {725SetShaderStageContext(ShaderAttr->getType());726RunOnFunction(FD);727continue;728}729// exported library function730// FIXME: replace this loop with external linkage check once issue #92071731// is resolved732bool isExport = FD->isInExportDeclContext();733if (!isExport) {734for (const auto *Redecl : FD->redecls()) {735if (Redecl->isInExportDeclContext()) {736isExport = true;737break;738}739}740}741if (isExport) {742SetUnknownShaderStageContext();743RunOnFunction(FD);744continue;745}746}747}748}749750void DiagnoseHLSLAvailability::RunOnFunction(const FunctionDecl *FD) {751assert(DeclsToScan.empty() && "DeclsToScan should be empty");752DeclsToScan.push_back(FD);753754while (!DeclsToScan.empty()) {755// Take one decl from the stack and check it by traversing its AST.756// For any CallExpr found during the traversal add it's callee to the top of757// the stack to be processed next. Functions already processed are stored in758// ScannedDecls.759const FunctionDecl *FD = DeclsToScan.pop_back_val();760761// Decl was already scanned762const unsigned ScannedStages = GetScannedStages(FD);763if (WasAlreadyScannedInCurrentStage(ScannedStages))764continue;765766ReportOnlyShaderStageIssues = !NeverBeenScanned(ScannedStages);767768AddToScannedFunctions(FD);769TraverseStmt(FD->getBody());770}771}772773bool DiagnoseHLSLAvailability::HasMatchingEnvironmentOrNone(774const AvailabilityAttr *AA) {775IdentifierInfo *IIEnvironment = AA->getEnvironment();776if (!IIEnvironment)777return true;778779llvm::Triple::EnvironmentType CurrentEnv = GetCurrentShaderEnvironment();780if (CurrentEnv == llvm::Triple::UnknownEnvironment)781return false;782783llvm::Triple::EnvironmentType AttrEnv =784AvailabilityAttr::getEnvironmentType(IIEnvironment->getName());785786return CurrentEnv == AttrEnv;787}788789const AvailabilityAttr *790DiagnoseHLSLAvailability::FindAvailabilityAttr(const Decl *D) {791AvailabilityAttr const *PartialMatch = nullptr;792// Check each AvailabilityAttr to find the one for this platform.793// For multiple attributes with the same platform try to find one for this794// environment.795for (const auto *A : D->attrs()) {796if (const auto *Avail = dyn_cast<AvailabilityAttr>(A)) {797StringRef AttrPlatform = Avail->getPlatform()->getName();798StringRef TargetPlatform =799SemaRef.getASTContext().getTargetInfo().getPlatformName();800801// Match the platform name.802if (AttrPlatform == TargetPlatform) {803// Find the best matching attribute for this environment804if (HasMatchingEnvironmentOrNone(Avail))805return Avail;806PartialMatch = Avail;807}808}809}810return PartialMatch;811}812813// Check availability against target shader model version and current shader814// stage and emit diagnostic815void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,816const AvailabilityAttr *AA,817SourceRange Range) {818819IdentifierInfo *IIEnv = AA->getEnvironment();820821if (!IIEnv) {822// The availability attribute does not have environment -> it depends only823// on shader model version and not on specific the shader stage.824825// Skip emitting the diagnostics if the diagnostic mode is set to826// strict (-fhlsl-strict-availability) because all relevant diagnostics827// were already emitted in the DiagnoseUnguardedAvailability scan828// (SemaAvailability.cpp).829if (SemaRef.getLangOpts().HLSLStrictAvailability)830return;831832// Do not report shader-stage-independent issues if scanning a function833// that was already scanned in a different shader stage context (they would834// be duplicate)835if (ReportOnlyShaderStageIssues)836return;837838} else {839// The availability attribute has environment -> we need to know840// the current stage context to property diagnose it.841if (InUnknownShaderStageContext())842return;843}844845// Check introduced version and if environment matches846bool EnvironmentMatches = HasMatchingEnvironmentOrNone(AA);847VersionTuple Introduced = AA->getIntroduced();848VersionTuple TargetVersion =849SemaRef.Context.getTargetInfo().getPlatformMinVersion();850851if (TargetVersion >= Introduced && EnvironmentMatches)852return;853854// Emit diagnostic message855const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();856llvm::StringRef PlatformName(857AvailabilityAttr::getPrettyPlatformName(TI.getPlatformName()));858859llvm::StringRef CurrentEnvStr =860llvm::Triple::getEnvironmentTypeName(GetCurrentShaderEnvironment());861862llvm::StringRef AttrEnvStr =863AA->getEnvironment() ? AA->getEnvironment()->getName() : "";864bool UseEnvironment = !AttrEnvStr.empty();865866if (EnvironmentMatches) {867SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability)868<< Range << D << PlatformName << Introduced.getAsString()869<< UseEnvironment << CurrentEnvStr;870} else {871SemaRef.Diag(Range.getBegin(), diag::warn_hlsl_availability_unavailable)872<< Range << D;873}874875SemaRef.Diag(D->getLocation(), diag::note_partial_availability_specified_here)876<< D << PlatformName << Introduced.getAsString()877<< SemaRef.Context.getTargetInfo().getPlatformMinVersion().getAsString()878<< UseEnvironment << AttrEnvStr << CurrentEnvStr;879}880881} // namespace882883void SemaHLSL::DiagnoseAvailabilityViolations(TranslationUnitDecl *TU) {884// Skip running the diagnostics scan if the diagnostic mode is885// strict (-fhlsl-strict-availability) and the target shader stage is known886// because all relevant diagnostics were already emitted in the887// DiagnoseUnguardedAvailability scan (SemaAvailability.cpp).888const TargetInfo &TI = SemaRef.getASTContext().getTargetInfo();889if (SemaRef.getLangOpts().HLSLStrictAvailability &&890TI.getTriple().getEnvironment() != llvm::Triple::EnvironmentType::Library)891return;892893DiagnoseHLSLAvailability(SemaRef).RunOnTranslationUnit(TU);894}895896// Helper function for CheckHLSLBuiltinFunctionCall897bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {898assert(TheCall->getNumArgs() > 1);899ExprResult A = TheCall->getArg(0);900901QualType ArgTyA = A.get()->getType();902903auto *VecTyA = ArgTyA->getAs<VectorType>();904SourceLocation BuiltinLoc = TheCall->getBeginLoc();905906for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) {907ExprResult B = TheCall->getArg(i);908QualType ArgTyB = B.get()->getType();909auto *VecTyB = ArgTyB->getAs<VectorType>();910if (VecTyA == nullptr && VecTyB == nullptr)911return false;912913if (VecTyA && VecTyB) {914bool retValue = false;915if (VecTyA->getElementType() != VecTyB->getElementType()) {916// Note: type promotion is intended to be handeled via the intrinsics917// and not the builtin itself.918S->Diag(TheCall->getBeginLoc(),919diag::err_vec_builtin_incompatible_vector)920<< TheCall->getDirectCallee() << /*useAllTerminology*/ true921<< SourceRange(A.get()->getBeginLoc(), B.get()->getEndLoc());922retValue = true;923}924if (VecTyA->getNumElements() != VecTyB->getNumElements()) {925// You should only be hitting this case if you are calling the builtin926// directly. HLSL intrinsics should avoid this case via a927// HLSLVectorTruncation.928S->Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)929<< TheCall->getDirectCallee() << /*useAllTerminology*/ true930<< SourceRange(TheCall->getArg(0)->getBeginLoc(),931TheCall->getArg(1)->getEndLoc());932retValue = true;933}934return retValue;935}936}937938// Note: if we get here one of the args is a scalar which939// requires a VectorSplat on Arg0 or Arg1940S->Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)941<< TheCall->getDirectCallee() << /*useAllTerminology*/ true942<< SourceRange(TheCall->getArg(0)->getBeginLoc(),943TheCall->getArg(1)->getEndLoc());944return true;945}946947bool CheckArgsTypesAreCorrect(948Sema *S, CallExpr *TheCall, QualType ExpectedType,949llvm::function_ref<bool(clang::QualType PassedType)> Check) {950for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {951QualType PassedType = TheCall->getArg(i)->getType();952if (Check(PassedType)) {953if (auto *VecTyA = PassedType->getAs<VectorType>())954ExpectedType = S->Context.getVectorType(955ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());956S->Diag(TheCall->getArg(0)->getBeginLoc(),957diag::err_typecheck_convert_incompatible)958<< PassedType << ExpectedType << 1 << 0 << 0;959return true;960}961}962return false;963}964965bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {966auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {967return !PassedType->hasFloatingRepresentation();968};969return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,970checkAllFloatTypes);971}972973bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {974auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {975clang::QualType BaseType =976PassedType->isVectorType()977? PassedType->getAs<clang::VectorType>()->getElementType()978: PassedType;979return !BaseType->isHalfType() && !BaseType->isFloat32Type();980};981return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,982checkFloatorHalf);983}984985bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {986auto checkDoubleVector = [](clang::QualType PassedType) -> bool {987if (const auto *VecTy = PassedType->getAs<VectorType>())988return VecTy->getElementType()->isDoubleType();989return false;990};991return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,992checkDoubleVector);993}994995bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) {996auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool {997return !PassedType->hasUnsignedIntegerRepresentation();998};999return CheckArgsTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,1000checkAllUnsignedTypes);1001}10021003void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,1004QualType ReturnType) {1005auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();1006if (VecTyA)1007ReturnType = S->Context.getVectorType(ReturnType, VecTyA->getNumElements(),1008VectorKind::Generic);1009TheCall->setType(ReturnType);1010}10111012// Note: returning true in this case results in CheckBuiltinFunctionCall1013// returning an ExprError1014bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {1015switch (BuiltinID) {1016case Builtin::BI__builtin_hlsl_elementwise_all:1017case Builtin::BI__builtin_hlsl_elementwise_any: {1018if (SemaRef.checkArgCount(TheCall, 1))1019return true;1020break;1021}1022case Builtin::BI__builtin_hlsl_elementwise_clamp: {1023if (SemaRef.checkArgCount(TheCall, 3))1024return true;1025if (CheckVectorElementCallArgs(&SemaRef, TheCall))1026return true;1027if (SemaRef.BuiltinElementwiseTernaryMath(1028TheCall, /*CheckForFloatArgs*/1029TheCall->getArg(0)->getType()->hasFloatingRepresentation()))1030return true;1031break;1032}1033case Builtin::BI__builtin_hlsl_dot: {1034if (SemaRef.checkArgCount(TheCall, 2))1035return true;1036if (CheckVectorElementCallArgs(&SemaRef, TheCall))1037return true;1038if (SemaRef.BuiltinVectorToScalarMath(TheCall))1039return true;1040if (CheckNoDoubleVectors(&SemaRef, TheCall))1041return true;1042break;1043}1044case Builtin::BI__builtin_hlsl_elementwise_rcp: {1045if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))1046return true;1047if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))1048return true;1049break;1050}1051case Builtin::BI__builtin_hlsl_elementwise_rsqrt:1052case Builtin::BI__builtin_hlsl_elementwise_frac: {1053if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))1054return true;1055if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))1056return true;1057break;1058}1059case Builtin::BI__builtin_hlsl_elementwise_isinf: {1060if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))1061return true;1062if (SemaRef.PrepareBuiltinElementwiseMathOneArgCall(TheCall))1063return true;1064SetElementTypeAsReturnType(&SemaRef, TheCall, getASTContext().BoolTy);1065break;1066}1067case Builtin::BI__builtin_hlsl_lerp: {1068if (SemaRef.checkArgCount(TheCall, 3))1069return true;1070if (CheckVectorElementCallArgs(&SemaRef, TheCall))1071return true;1072if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))1073return true;1074if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))1075return true;1076break;1077}1078case Builtin::BI__builtin_hlsl_mad: {1079if (SemaRef.checkArgCount(TheCall, 3))1080return true;1081if (CheckVectorElementCallArgs(&SemaRef, TheCall))1082return true;1083if (SemaRef.BuiltinElementwiseTernaryMath(1084TheCall, /*CheckForFloatArgs*/1085TheCall->getArg(0)->getType()->hasFloatingRepresentation()))1086return true;1087break;1088}1089// Note these are llvm builtins that we want to catch invalid intrinsic1090// generation. Normal handling of these builitns will occur elsewhere.1091case Builtin::BI__builtin_elementwise_bitreverse: {1092if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))1093return true;1094break;1095}1096case Builtin::BI__builtin_elementwise_acos:1097case Builtin::BI__builtin_elementwise_asin:1098case Builtin::BI__builtin_elementwise_atan:1099case Builtin::BI__builtin_elementwise_ceil:1100case Builtin::BI__builtin_elementwise_cos:1101case Builtin::BI__builtin_elementwise_cosh:1102case Builtin::BI__builtin_elementwise_exp:1103case Builtin::BI__builtin_elementwise_exp2:1104case Builtin::BI__builtin_elementwise_floor:1105case Builtin::BI__builtin_elementwise_log:1106case Builtin::BI__builtin_elementwise_log2:1107case Builtin::BI__builtin_elementwise_log10:1108case Builtin::BI__builtin_elementwise_pow:1109case Builtin::BI__builtin_elementwise_roundeven:1110case Builtin::BI__builtin_elementwise_sin:1111case Builtin::BI__builtin_elementwise_sinh:1112case Builtin::BI__builtin_elementwise_sqrt:1113case Builtin::BI__builtin_elementwise_tan:1114case Builtin::BI__builtin_elementwise_tanh:1115case Builtin::BI__builtin_elementwise_trunc: {1116if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))1117return true;1118break;1119}1120}1121return false;1122}112311241125