Path: blob/main/contrib/llvm-project/clang/lib/CodeGen/CGHLSLRuntime.cpp
35233 views
//===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This provides an abstract class for HLSL code generation. Concrete9// subclasses of this implement code generation for specific HLSL10// runtime libraries.11//12//===----------------------------------------------------------------------===//1314#include "CGHLSLRuntime.h"15#include "CGDebugInfo.h"16#include "CodeGenModule.h"17#include "clang/AST/Decl.h"18#include "clang/Basic/TargetOptions.h"19#include "llvm/IR/Metadata.h"20#include "llvm/IR/Module.h"21#include "llvm/Support/FormatVariadic.h"2223using namespace clang;24using namespace CodeGen;25using namespace clang::hlsl;26using namespace llvm;2728namespace {2930void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {31// The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.32// Assume ValVersionStr is legal here.33VersionTuple Version;34if (Version.tryParse(ValVersionStr) || Version.getBuild() ||35Version.getSubminor() || !Version.getMinor()) {36return;37}3839uint64_t Major = Version.getMajor();40uint64_t Minor = *Version.getMinor();4142auto &Ctx = M.getContext();43IRBuilder<> B(M.getContext());44MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),45ConstantAsMetadata::get(B.getInt32(Minor))});46StringRef DXILValKey = "dx.valver";47auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);48DXILValMD->addOperand(Val);49}50void addDisableOptimizations(llvm::Module &M) {51StringRef Key = "dx.disable_optimizations";52M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);53}54// cbuffer will be translated into global variable in special address space.55// If translate into C,56// cbuffer A {57// float a;58// float b;59// }60// float foo() { return a + b; }61//62// will be translated into63//64// struct A {65// float a;66// float b;67// } cbuffer_A __attribute__((address_space(4)));68// float foo() { return cbuffer_A.a + cbuffer_A.b; }69//70// layoutBuffer will create the struct A type.71// replaceBuffer will replace use of global variable a and b with cbuffer_A.a72// and cbuffer_A.b.73//74void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {75if (Buf.Constants.empty())76return;7778std::vector<llvm::Type *> EltTys;79for (auto &Const : Buf.Constants) {80GlobalVariable *GV = Const.first;81Const.second = EltTys.size();82llvm::Type *Ty = GV->getValueType();83EltTys.emplace_back(Ty);84}85Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);86}8788GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {89// Create global variable for CB.90GlobalVariable *CBGV = new GlobalVariable(91Buf.LayoutStruct, /*isConstant*/ true,92GlobalValue::LinkageTypes::ExternalLinkage, nullptr,93llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),94GlobalValue::NotThreadLocal);9596IRBuilder<> B(CBGV->getContext());97Value *ZeroIdx = B.getInt32(0);98// Replace Const use with CB use.99for (auto &[GV, Offset] : Buf.Constants) {100Value *GEP =101B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});102103assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&104"constant type mismatch");105106// Replace.107GV->replaceAllUsesWith(GEP);108// Erase GV.109GV->removeDeadConstantUsers();110GV->eraseFromParent();111}112return CBGV;113}114115} // namespace116117llvm::Triple::ArchType CGHLSLRuntime::getArch() {118return CGM.getTarget().getTriple().getArch();119}120121void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {122if (D->getStorageClass() == SC_Static) {123// For static inside cbuffer, take as global static.124// Don't add to cbuffer.125CGM.EmitGlobal(D);126return;127}128129auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));130// Add debug info for constVal.131if (CGDebugInfo *DI = CGM.getModuleDebugInfo())132if (CGM.getCodeGenOpts().getDebugInfo() >=133codegenoptions::DebugInfoKind::LimitedDebugInfo)134DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D);135136// FIXME: support packoffset.137// See https://github.com/llvm/llvm-project/issues/57914.138uint32_t Offset = 0;139bool HasUserOffset = false;140141unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;142CB.Constants.emplace_back(std::make_pair(GV, LowerBound));143}144145void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {146for (Decl *it : DC->decls()) {147if (auto *ConstDecl = dyn_cast<VarDecl>(it)) {148addConstant(ConstDecl, CB);149} else if (isa<CXXRecordDecl, EmptyDecl>(it)) {150// Nothing to do for this declaration.151} else if (isa<FunctionDecl>(it)) {152// A function within an cbuffer is effectively a top-level function,153// as it only refers to globally scoped declarations.154CGM.EmitTopLevelDecl(it);155}156}157}158159void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) {160Buffers.emplace_back(Buffer(D));161addBufferDecls(D, Buffers.back());162}163164void CGHLSLRuntime::finishCodeGen() {165auto &TargetOpts = CGM.getTarget().getTargetOpts();166llvm::Module &M = CGM.getModule();167Triple T(M.getTargetTriple());168if (T.getArch() == Triple::ArchType::dxil)169addDxilValVersion(TargetOpts.DxilValidatorVersion, M);170171generateGlobalCtorDtorCalls();172if (CGM.getCodeGenOpts().OptimizationLevel == 0)173addDisableOptimizations(M);174175const DataLayout &DL = M.getDataLayout();176177for (auto &Buf : Buffers) {178layoutBuffer(Buf, DL);179GlobalVariable *GV = replaceBuffer(Buf);180M.insertGlobalVariable(GV);181llvm::hlsl::ResourceClass RC = Buf.IsCBuffer182? llvm::hlsl::ResourceClass::CBuffer183: llvm::hlsl::ResourceClass::SRV;184llvm::hlsl::ResourceKind RK = Buf.IsCBuffer185? llvm::hlsl::ResourceKind::CBuffer186: llvm::hlsl::ResourceKind::TBuffer;187addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false,188llvm::hlsl::ElementType::Invalid, Buf.Binding);189}190}191192CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)193: Name(D->getName()), IsCBuffer(D->isCBuffer()),194Binding(D->getAttr<HLSLResourceBindingAttr>()) {}195196void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,197llvm::hlsl::ResourceClass RC,198llvm::hlsl::ResourceKind RK,199bool IsROV,200llvm::hlsl::ElementType ET,201BufferResBinding &Binding) {202llvm::Module &M = CGM.getModule();203204NamedMDNode *ResourceMD = nullptr;205switch (RC) {206case llvm::hlsl::ResourceClass::UAV:207ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");208break;209case llvm::hlsl::ResourceClass::SRV:210ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");211break;212case llvm::hlsl::ResourceClass::CBuffer:213ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");214break;215default:216assert(false && "Unsupported buffer type!");217return;218}219assert(ResourceMD != nullptr &&220"ResourceMD must have been set by the switch above.");221222llvm::hlsl::FrontendResource Res(223GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);224ResourceMD->addOperand(Res.getMetadata());225}226227static llvm::hlsl::ElementType228calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy) {229using llvm::hlsl::ElementType;230231// TODO: We may need to update this when we add things like ByteAddressBuffer232// that don't have a template parameter (or, indeed, an element type).233const auto *TST = ResourceTy->getAs<TemplateSpecializationType>();234assert(TST && "Resource types must be template specializations");235ArrayRef<TemplateArgument> Args = TST->template_arguments();236assert(!Args.empty() && "Resource has no element type");237238// At this point we have a resource with an element type, so we can assume239// that it's valid or we would have diagnosed the error earlier.240QualType ElTy = Args[0].getAsType();241242// We should either have a basic type or a vector of a basic type.243if (const auto *VecTy = ElTy->getAs<clang::VectorType>())244ElTy = VecTy->getElementType();245246if (ElTy->isSignedIntegerType()) {247switch (Context.getTypeSize(ElTy)) {248case 16:249return ElementType::I16;250case 32:251return ElementType::I32;252case 64:253return ElementType::I64;254}255} else if (ElTy->isUnsignedIntegerType()) {256switch (Context.getTypeSize(ElTy)) {257case 16:258return ElementType::U16;259case 32:260return ElementType::U32;261case 64:262return ElementType::U64;263}264} else if (ElTy->isSpecificBuiltinType(BuiltinType::Half))265return ElementType::F16;266else if (ElTy->isSpecificBuiltinType(BuiltinType::Float))267return ElementType::F32;268else if (ElTy->isSpecificBuiltinType(BuiltinType::Double))269return ElementType::F64;270271// TODO: We need to handle unorm/snorm float types here once we support them272llvm_unreachable("Invalid element type for resource");273}274275void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {276const Type *Ty = D->getType()->getPointeeOrArrayElementType();277if (!Ty)278return;279const auto *RD = Ty->getAsCXXRecordDecl();280if (!RD)281return;282const auto *HLSLResAttr = RD->getAttr<HLSLResourceAttr>();283const auto *HLSLResClassAttr = RD->getAttr<HLSLResourceClassAttr>();284if (!HLSLResAttr || !HLSLResClassAttr)285return;286287llvm::hlsl::ResourceClass RC = HLSLResClassAttr->getResourceClass();288llvm::hlsl::ResourceKind RK = HLSLResAttr->getResourceKind();289bool IsROV = HLSLResAttr->getIsROV();290llvm::hlsl::ElementType ET = calculateElementType(CGM.getContext(), Ty);291292BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());293addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);294}295296CGHLSLRuntime::BufferResBinding::BufferResBinding(297HLSLResourceBindingAttr *Binding) {298if (Binding) {299llvm::APInt RegInt(64, 0);300Binding->getSlot().substr(1).getAsInteger(10, RegInt);301Reg = RegInt.getLimitedValue();302llvm::APInt SpaceInt(64, 0);303Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);304Space = SpaceInt.getLimitedValue();305} else {306Space = 0;307}308}309310void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(311const FunctionDecl *FD, llvm::Function *Fn) {312const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();313assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");314const StringRef ShaderAttrKindStr = "hlsl.shader";315Fn->addFnAttr(ShaderAttrKindStr,316llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));317if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {318const StringRef NumThreadsKindStr = "hlsl.numthreads";319std::string NumThreadsStr =320formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),321NumThreadsAttr->getZ());322Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);323}324}325326static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {327if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {328Value *Result = PoisonValue::get(Ty);329for (unsigned I = 0; I < VT->getNumElements(); ++I) {330Value *Elt = B.CreateCall(F, {B.getInt32(I)});331Result = B.CreateInsertElement(Result, Elt, I);332}333return Result;334}335return B.CreateCall(F, {B.getInt32(0)});336}337338llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,339const ParmVarDecl &D,340llvm::Type *Ty) {341assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");342if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {343llvm::Function *DxGroupIndex =344CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);345return B.CreateCall(FunctionCallee(DxGroupIndex));346}347if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {348llvm::Function *ThreadIDIntrinsic =349CGM.getIntrinsic(getThreadIdIntrinsic());350return buildVectorInput(B, ThreadIDIntrinsic, Ty);351}352assert(false && "Unhandled parameter attribute");353return nullptr;354}355356void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,357llvm::Function *Fn) {358llvm::Module &M = CGM.getModule();359llvm::LLVMContext &Ctx = M.getContext();360auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);361Function *EntryFn =362Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);363364// Copy function attributes over, we have no argument or return attributes365// that can be valid on the real entry.366AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,367Fn->getAttributes().getFnAttrs());368EntryFn->setAttributes(NewAttrs);369setHLSLEntryAttributes(FD, EntryFn);370371// Set the called function as internal linkage.372Fn->setLinkage(GlobalValue::InternalLinkage);373374BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);375IRBuilder<> B(BB);376llvm::SmallVector<Value *> Args;377// FIXME: support struct parameters where semantics are on members.378// See: https://github.com/llvm/llvm-project/issues/57874379unsigned SRetOffset = 0;380for (const auto &Param : Fn->args()) {381if (Param.hasStructRetAttr()) {382// FIXME: support output.383// See: https://github.com/llvm/llvm-project/issues/57874384SRetOffset = 1;385Args.emplace_back(PoisonValue::get(Param.getType()));386continue;387}388const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);389Args.push_back(emitInputSemantic(B, *PD, Param.getType()));390}391392CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);393(void)CI;394// FIXME: Handle codegen for return type semantics.395// See: https://github.com/llvm/llvm-project/issues/57875396B.CreateRetVoid();397}398399static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,400bool CtorOrDtor) {401const auto *GV =402M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");403if (!GV)404return;405const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());406if (!CA)407return;408// The global_ctor array elements are a struct [Priority, Fn *, COMDat].409// HLSL neither supports priorities or COMDat values, so we will check those410// in an assert but not handle them.411412llvm::SmallVector<Function *> CtorFns;413for (const auto &Ctor : CA->operands()) {414if (isa<ConstantAggregateZero>(Ctor))415continue;416ConstantStruct *CS = cast<ConstantStruct>(Ctor);417418assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&419"HLSL doesn't support setting priority for global ctors.");420assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&421"HLSL doesn't support COMDat for global ctors.");422Fns.push_back(cast<Function>(CS->getOperand(1)));423}424}425426void CGHLSLRuntime::generateGlobalCtorDtorCalls() {427llvm::Module &M = CGM.getModule();428SmallVector<Function *> CtorFns;429SmallVector<Function *> DtorFns;430gatherFunctions(CtorFns, M, true);431gatherFunctions(DtorFns, M, false);432433// Insert a call to the global constructor at the beginning of the entry block434// to externally exported functions. This is a bit of a hack, but HLSL allows435// global constructors, but doesn't support driver initialization of globals.436for (auto &F : M.functions()) {437if (!F.hasFnAttribute("hlsl.shader"))438continue;439IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());440for (auto *Fn : CtorFns)441B.CreateCall(FunctionCallee(Fn));442443// Insert global dtors before the terminator of the last instruction444B.SetInsertPoint(F.back().getTerminator());445for (auto *Fn : DtorFns)446B.CreateCall(FunctionCallee(Fn));447}448449// No need to keep global ctors/dtors for non-lib profile after call to450// ctors/dtors added for entry.451Triple T(M.getTargetTriple());452if (T.getEnvironment() != Triple::EnvironmentType::Library) {453if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))454GV->eraseFromParent();455if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))456GV->eraseFromParent();457}458}459460461