Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/clang/lib/CodeGen/CGHLSLRuntime.cpp
35233 views
1
//===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This provides an abstract class for HLSL code generation. Concrete
10
// subclasses of this implement code generation for specific HLSL
11
// runtime libraries.
12
//
13
//===----------------------------------------------------------------------===//
14
15
#include "CGHLSLRuntime.h"
16
#include "CGDebugInfo.h"
17
#include "CodeGenModule.h"
18
#include "clang/AST/Decl.h"
19
#include "clang/Basic/TargetOptions.h"
20
#include "llvm/IR/Metadata.h"
21
#include "llvm/IR/Module.h"
22
#include "llvm/Support/FormatVariadic.h"
23
24
using namespace clang;
25
using namespace CodeGen;
26
using namespace clang::hlsl;
27
using namespace llvm;
28
29
namespace {
30
31
/// Record the DXIL validator version in \p M as an operand of the "dx.valver"
/// named metadata, in the form !{i32 Major, i32 Minor}.
///
/// The validation of \p ValVersionStr is done at HLSLToolChain::TranslateArgs,
/// so it is assumed to be legal here. Nothing is emitted when the string does
/// not parse, carries a subminor or build component, or has no minor
/// component.
void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
  VersionTuple ValVer;
  // A true result from tryParse() takes the bail-out path.
  if (ValVer.tryParse(ValVersionStr))
    return;
  // Only a plain "major.minor" version is accepted.
  if (ValVer.getBuild() || ValVer.getSubminor() || !ValVer.getMinor())
    return;

  const uint64_t Major = ValVer.getMajor();
  const uint64_t Minor = *ValVer.getMinor();

  LLVMContext &Ctx = M.getContext();
  IRBuilder<> Builder(Ctx);
  Metadata *VersionOps[] = {ConstantAsMetadata::get(Builder.getInt32(Major)),
                            ConstantAsMetadata::get(Builder.getInt32(Minor))};
  MDNode *VersionNode = MDNode::get(Ctx, VersionOps);
  M.getOrInsertNamedMetadata("dx.valver")->addOperand(VersionNode);
}
51
/// Set the "dx.disable_optimizations" module flag on \p M to 1, using the
/// Override flag behavior.
void addDisableOptimizations(llvm::Module &M) {
  M.addModuleFlag(llvm::Module::ModFlagBehavior::Override,
                  "dx.disable_optimizations", 1);
}
55
// A cbuffer is translated into a global variable in a special address space.
// Expressed in C, the HLSL
//
//   cbuffer A {
//     float a;
//     float b;
//   }
//   float foo() { return a + b; }
//
// is translated into
//
//   struct A {
//     float a;
//     float b;
//   } cbuffer_A __attribute__((address_space(4)));
//   float foo() { return cbuffer_A.a + cbuffer_A.b; }
//
// layoutBuffer creates the struct A type.
// replaceBuffer replaces uses of the global variables a and b with
// cbuffer_A.a and cbuffer_A.b.
//
75
/// Build the layout struct type of \p Buf from its constants.
///
/// Every constant is assigned its field index within the struct (written to
/// the second member of its entry in Buf.Constants); the struct's element
/// types are the value types of the constants, in order. A buffer without
/// constants is left untouched, i.e. Buf.LayoutStruct is not assigned.
/// \p DL is currently unused.
void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
  if (Buf.Constants.empty())
    return;

  std::vector<llvm::Type *> FieldTypes;
  for (auto &[ConstGV, FieldIndex] : Buf.Constants) {
    // The next free slot is the number of fields collected so far.
    FieldIndex = FieldTypes.size();
    FieldTypes.push_back(ConstGV->getValueType());
  }

  llvm::LLVMContext &Ctx = FieldTypes.front()->getContext();
  Buf.LayoutStruct = llvm::StructType::get(Ctx, FieldTypes);
}
88
89
/// Create the global variable that stands for \p Buf as a whole and redirect
/// every use of the buffer's individual constants to the matching field of
/// that variable. The per-constant globals are erased afterwards.
///
/// The new global is named "<Buf.Name>.cb." for a cbuffer and
/// "<Buf.Name>.tb." otherwise. Buf.LayoutStruct is used as its value type.
///
/// \returns the newly created global; it is not inserted into a module here.
GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
  auto *LayoutTy = Buf.LayoutStruct;

  // Create the global variable for the buffer.
  auto *BufferGV = new GlobalVariable(
      LayoutTy, /*isConstant*/ true,
      GlobalValue::LinkageTypes::ExternalLinkage, /*Initializer=*/nullptr,
      llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
      GlobalValue::NotThreadLocal);

  IRBuilder<> Builder(BufferGV->getContext());
  Value *Zero = Builder.getInt32(0);

  // Rewrite each constant as an access to its field of the buffer global.
  for (auto &[ConstGV, FieldIdx] : Buf.Constants) {
    Value *FieldAddr = Builder.CreateGEP(LayoutTy, BufferGV,
                                         {Zero, Builder.getInt32(FieldIdx)});

    assert(LayoutTy->getElementType(FieldIdx) == ConstGV->getValueType() &&
           "constant type mismatch");

    ConstGV->replaceAllUsesWith(FieldAddr);
    // Drop leftover dead constant users, then the now-unused global itself.
    ConstGV->removeDeadConstantUsers();
    ConstGV->eraseFromParent();
  }
  return BufferGV;
}
115
116
} // namespace
117
118
/// Return the architecture component of the target's triple.
llvm::Triple::ArchType CGHLSLRuntime::getArch() {
  const llvm::Triple &TargetTriple = CGM.getTarget().getTriple();
  return TargetTriple.getArch();
}
121
122
/// Register the variable \p D, declared inside an HLSL buffer, as a constant
/// of \p CB.
///
/// Static variables are not added to the buffer; they are emitted as ordinary
/// globals instead. For every other variable the backing global is looked up,
/// debug info is emitted for it when the debug-info level is at least
/// LimitedDebugInfo, and the global is appended to CB.Constants.
void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
  // A static inside a cbuffer is taken as a global static and is not added to
  // the cbuffer.
  if (D->getStorageClass() == SC_Static) {
    CGM.EmitGlobal(D);
    return;
  }

  auto *ConstGV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));

  // Add debug info for the constant.
  CGDebugInfo *DebugInfo = CGM.getModuleDebugInfo();
  if (DebugInfo && CGM.getCodeGenOpts().getDebugInfo() >=
                       codegenoptions::DebugInfoKind::LimitedDebugInfo)
    DebugInfo->EmitGlobalVariable(ConstGV, D);

  // FIXME: support packoffset.
  // See https://github.com/llvm/llvm-project/issues/57914.
  // Until then there is never a user-provided offset, so the recorded lower
  // bound is always UINT_MAX.
  const uint32_t Offset = 0;
  const bool HasUserOffset = false;

  unsigned LowerBound = UINT_MAX;
  if (HasUserOffset)
    LowerBound = Offset;
  CB.Constants.emplace_back(ConstGV, LowerBound);
}
145
146
void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
147
for (Decl *it : DC->decls()) {
148
if (auto *ConstDecl = dyn_cast<VarDecl>(it)) {
149
addConstant(ConstDecl, CB);
150
} else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
151
// Nothing to do for this declaration.
152
} else if (isa<FunctionDecl>(it)) {
153
// A function within an cbuffer is effectively a top-level function,
154
// as it only refers to globally scoped declarations.
155
CGM.EmitTopLevelDecl(it);
156
}
157
}
158
}
159
160
void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) {
161
Buffers.emplace_back(Buffer(D));
162
addBufferDecls(D, Buffers.back());
163
}
164
165
void CGHLSLRuntime::finishCodeGen() {
166
auto &TargetOpts = CGM.getTarget().getTargetOpts();
167
llvm::Module &M = CGM.getModule();
168
Triple T(M.getTargetTriple());
169
if (T.getArch() == Triple::ArchType::dxil)
170
addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
171
172
generateGlobalCtorDtorCalls();
173
if (CGM.getCodeGenOpts().OptimizationLevel == 0)
174
addDisableOptimizations(M);
175
176
const DataLayout &DL = M.getDataLayout();
177
178
for (auto &Buf : Buffers) {
179
layoutBuffer(Buf, DL);
180
GlobalVariable *GV = replaceBuffer(Buf);
181
M.insertGlobalVariable(GV);
182
llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
183
? llvm::hlsl::ResourceClass::CBuffer
184
: llvm::hlsl::ResourceClass::SRV;
185
llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
186
? llvm::hlsl::ResourceKind::CBuffer
187
: llvm::hlsl::ResourceKind::TBuffer;
188
addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false,
189
llvm::hlsl::ElementType::Invalid, Buf.Binding);
190
}
191
}
192
193
/// Describe the HLSL buffer declaration \p D: capture its name, whether it is
/// a cbuffer, and its resource binding as derived from the
/// HLSLResourceBindingAttr attached to \p D, if any.
CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
    : Name(D->getName()),
      IsCBuffer(D->isCBuffer()),
      Binding(D->getAttr<HLSLResourceBindingAttr>()) {
}
196
197
void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
198
llvm::hlsl::ResourceClass RC,
199
llvm::hlsl::ResourceKind RK,
200
bool IsROV,
201
llvm::hlsl::ElementType ET,
202
BufferResBinding &Binding) {
203
llvm::Module &M = CGM.getModule();
204
205
NamedMDNode *ResourceMD = nullptr;
206
switch (RC) {
207
case llvm::hlsl::ResourceClass::UAV:
208
ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
209
break;
210
case llvm::hlsl::ResourceClass::SRV:
211
ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
212
break;
213
case llvm::hlsl::ResourceClass::CBuffer:
214
ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
215
break;
216
default:
217
assert(false && "Unsupported buffer type!");
218
return;
219
}
220
assert(ResourceMD != nullptr &&
221
"ResourceMD must have been set by the switch above.");
222
223
llvm::hlsl::FrontendResource Res(
224
GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
225
ResourceMD->addOperand(Res.getMetadata());
226
}
227
228
/// Map the element type of the resource type \p ResourceTy to the matching
/// llvm::hlsl::ElementType.
///
/// The element type is taken from the first template argument of the resource
/// type; for a vector type, the vector's element type is used. Signed and
/// unsigned integers of 16, 32 and 64 bits, as well as half, float and double,
/// are mapped. Anything else is treated as unreachable.
static llvm::hlsl::ElementType
calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy) {
  using llvm::hlsl::ElementType;

  // TODO: We may need to update this when we add things like ByteAddressBuffer
  // that don't have a template parameter (or, indeed, an element type).
  const auto *TST = ResourceTy->getAs<TemplateSpecializationType>();
  assert(TST && "Resource types must be template specializations");
  ArrayRef<TemplateArgument> Args = TST->template_arguments();
  assert(!Args.empty() && "Resource has no element type");

  // At this point we have a resource with an element type, so we can assume
  // that it's valid or we would have diagnosed the error earlier.
  QualType ElTy = Args[0].getAsType();

  // We should either have a basic type or a vector of a basic type; for a
  // vector, classify its element type.
  if (const auto *VecTy = ElTy->getAs<clang::VectorType>())
    ElTy = VecTy->getElementType();

  const bool IsSigned = ElTy->isSignedIntegerType();
  if (IsSigned || ElTy->isUnsignedIntegerType()) {
    // Integer widths other than 16/32/64 fall through to the unreachable
    // below.
    switch (Context.getTypeSize(ElTy)) {
    case 16:
      return IsSigned ? ElementType::I16 : ElementType::U16;
    case 32:
      return IsSigned ? ElementType::I32 : ElementType::U32;
    case 64:
      return IsSigned ? ElementType::I64 : ElementType::U64;
    }
  } else {
    if (ElTy->isSpecificBuiltinType(BuiltinType::Half))
      return ElementType::F16;
    if (ElTy->isSpecificBuiltinType(BuiltinType::Float))
      return ElementType::F32;
    if (ElTy->isSpecificBuiltinType(BuiltinType::Double))
      return ElementType::F64;
  }

  // TODO: We need to handle unorm/snorm float types here once we support them
  llvm_unreachable("Invalid element type for resource");
}
275
276
void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
277
const Type *Ty = D->getType()->getPointeeOrArrayElementType();
278
if (!Ty)
279
return;
280
const auto *RD = Ty->getAsCXXRecordDecl();
281
if (!RD)
282
return;
283
const auto *HLSLResAttr = RD->getAttr<HLSLResourceAttr>();
284
const auto *HLSLResClassAttr = RD->getAttr<HLSLResourceClassAttr>();
285
if (!HLSLResAttr || !HLSLResClassAttr)
286
return;
287
288
llvm::hlsl::ResourceClass RC = HLSLResClassAttr->getResourceClass();
289
llvm::hlsl::ResourceKind RK = HLSLResAttr->getResourceKind();
290
bool IsROV = HLSLResAttr->getIsROV();
291
llvm::hlsl::ElementType ET = calculateElementType(CGM.getContext(), Ty);
292
293
BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
294
addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);
295
}
296
297
/// Derive the register and space of a resource from \p Binding.
///
/// Without an attribute, Space is set to 0 and Reg is not assigned. With one,
/// Reg is parsed as a decimal number from the slot string after its first
/// character, and Space from the space string after its first five
/// characters (presumably the "space" prefix -- confirm against Sema). The
/// results of getAsInteger() are not checked.
CGHLSLRuntime::BufferResBinding::BufferResBinding(
    HLSLResourceBindingAttr *Binding) {
  if (!Binding) {
    Space = 0;
    return;
  }

  llvm::APInt RegInt(64, 0);
  StringRef RegDigits = Binding->getSlot().substr(1);
  RegDigits.getAsInteger(10, RegInt);
  Reg = RegInt.getLimitedValue();

  llvm::APInt SpaceInt(64, 0);
  StringRef SpaceDigits = Binding->getSpace().substr(5);
  SpaceDigits.getAsInteger(10, SpaceInt);
  Space = SpaceInt.getLimitedValue();
}
310
311
void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
312
const FunctionDecl *FD, llvm::Function *Fn) {
313
const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
314
assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
315
const StringRef ShaderAttrKindStr = "hlsl.shader";
316
Fn->addFnAttr(ShaderAttrKindStr,
317
llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
318
if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
319
const StringRef NumThreadsKindStr = "hlsl.numthreads";
320
std::string NumThreadsStr =
321
formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
322
NumThreadsAttr->getZ());
323
Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
324
}
325
}
326
327
/// Build a value of type \p Ty from calls to \p F.
///
/// For a fixed vector type, \p F is called once per element with the element
/// index as an i32 argument, and the results are inserted one by one into a
/// vector that starts out as poison. For any other type the result is a
/// single call to \p F with the argument 0.
static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
  const auto *VecTy = dyn_cast<FixedVectorType>(Ty);
  if (!VecTy)
    return B.CreateCall(F, {B.getInt32(0)});

  Value *Vec = PoisonValue::get(Ty);
  for (unsigned Idx = 0; Idx < VecTy->getNumElements(); ++Idx) {
    Value *Component = B.CreateCall(F, {B.getInt32(Idx)});
    Vec = B.CreateInsertElement(Vec, Component, Idx);
  }
  return Vec;
}
338
339
/// Emit, through \p B, the value of the entry-point parameter \p D according
/// to its semantic attribute.
///
/// SV_GroupIndex yields a call to the dx_flattened_thread_id_in_group
/// intrinsic. SV_DispatchThreadID yields a value of type \p Ty built from the
/// intrinsic returned by getThreadIdIntrinsic(). Any other parameter trips an
/// assertion and, in builds without assertions, yields nullptr.
llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
                                              const ParmVarDecl &D,
                                              llvm::Type *Ty) {
  assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");

  if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
    llvm::Function *GroupIndexFn =
        CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
    return B.CreateCall(FunctionCallee(GroupIndexFn));
  }

  if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
    llvm::Function *ThreadIdFn = CGM.getIntrinsic(getThreadIdIntrinsic());
    return buildVectorInput(B, ThreadIdFn, Ty);
  }

  assert(false && "Unhandled parameter attribute");
  return nullptr;
}
356
357
void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
358
llvm::Function *Fn) {
359
llvm::Module &M = CGM.getModule();
360
llvm::LLVMContext &Ctx = M.getContext();
361
auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
362
Function *EntryFn =
363
Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
364
365
// Copy function attributes over, we have no argument or return attributes
366
// that can be valid on the real entry.
367
AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
368
Fn->getAttributes().getFnAttrs());
369
EntryFn->setAttributes(NewAttrs);
370
setHLSLEntryAttributes(FD, EntryFn);
371
372
// Set the called function as internal linkage.
373
Fn->setLinkage(GlobalValue::InternalLinkage);
374
375
BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
376
IRBuilder<> B(BB);
377
llvm::SmallVector<Value *> Args;
378
// FIXME: support struct parameters where semantics are on members.
379
// See: https://github.com/llvm/llvm-project/issues/57874
380
unsigned SRetOffset = 0;
381
for (const auto &Param : Fn->args()) {
382
if (Param.hasStructRetAttr()) {
383
// FIXME: support output.
384
// See: https://github.com/llvm/llvm-project/issues/57874
385
SRetOffset = 1;
386
Args.emplace_back(PoisonValue::get(Param.getType()));
387
continue;
388
}
389
const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
390
Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
391
}
392
393
CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
394
(void)CI;
395
// FIXME: Handle codegen for return type semantics.
396
// See: https://github.com/llvm/llvm-project/issues/57875
397
B.CreateRetVoid();
398
}
399
400
static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
401
bool CtorOrDtor) {
402
const auto *GV =
403
M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
404
if (!GV)
405
return;
406
const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
407
if (!CA)
408
return;
409
// The global_ctor array elements are a struct [Priority, Fn *, COMDat].
410
// HLSL neither supports priorities or COMDat values, so we will check those
411
// in an assert but not handle them.
412
413
llvm::SmallVector<Function *> CtorFns;
414
for (const auto &Ctor : CA->operands()) {
415
if (isa<ConstantAggregateZero>(Ctor))
416
continue;
417
ConstantStruct *CS = cast<ConstantStruct>(Ctor);
418
419
assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
420
"HLSL doesn't support setting priority for global ctors.");
421
assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
422
"HLSL doesn't support COMDat for global ctors.");
423
Fns.push_back(cast<Function>(CS->getOperand(1)));
424
}
425
}
426
427
void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
428
llvm::Module &M = CGM.getModule();
429
SmallVector<Function *> CtorFns;
430
SmallVector<Function *> DtorFns;
431
gatherFunctions(CtorFns, M, true);
432
gatherFunctions(DtorFns, M, false);
433
434
// Insert a call to the global constructor at the beginning of the entry block
435
// to externally exported functions. This is a bit of a hack, but HLSL allows
436
// global constructors, but doesn't support driver initialization of globals.
437
for (auto &F : M.functions()) {
438
if (!F.hasFnAttribute("hlsl.shader"))
439
continue;
440
IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
441
for (auto *Fn : CtorFns)
442
B.CreateCall(FunctionCallee(Fn));
443
444
// Insert global dtors before the terminator of the last instruction
445
B.SetInsertPoint(F.back().getTerminator());
446
for (auto *Fn : DtorFns)
447
B.CreateCall(FunctionCallee(Fn));
448
}
449
450
// No need to keep global ctors/dtors for non-lib profile after call to
451
// ctors/dtors added for entry.
452
Triple T(M.getTargetTriple());
453
if (T.getEnvironment() != Triple::EnvironmentType::Library) {
454
if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
455
GV->eraseFromParent();
456
if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
457
GV->eraseFromParent();
458
}
459
}
460
461