Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Frontend/Offloading/OffloadWrapper.cpp
35271 views
1
//===- OffloadWrapper.cpp ---------------------------------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#include "llvm/Frontend/Offloading/OffloadWrapper.h"
10
#include "llvm/ADT/ArrayRef.h"
11
#include "llvm/BinaryFormat/Magic.h"
12
#include "llvm/Frontend/Offloading/Utility.h"
13
#include "llvm/IR/Constants.h"
14
#include "llvm/IR/GlobalVariable.h"
15
#include "llvm/IR/IRBuilder.h"
16
#include "llvm/IR/LLVMContext.h"
17
#include "llvm/IR/Module.h"
18
#include "llvm/Object/OffloadBinary.h"
19
#include "llvm/Support/Error.h"
20
#include "llvm/TargetParser/Triple.h"
21
#include "llvm/Transforms/Utils/ModuleUtils.h"
22
23
using namespace llvm;
24
using namespace llvm::offloading;
25
26
namespace {
27
/// Value written into the leading `magic` field of the fatbinary wrapper
/// struct when the wrapped image is a CUDA fatbinary.
constexpr unsigned CudaFatMagic = 0x466243B1u;
/// Value written into the leading `magic` field of the fatbinary wrapper
/// struct when the wrapped image is a HIP fatbinary. Read as big-endian bytes
/// the constant spells "HIPF" in ASCII.
constexpr unsigned HIPFatMagic = 0x48495046u;
30
31
/// Returns the pointer-sized integer type for the data layout of \p M. The
/// rest of this file uses it wherever the runtime interface expects `size_t`.
IntegerType *getSizeTTy(Module &M) {
  const auto &Layout = M.getDataLayout();
  LLVMContext &Ctx = M.getContext();
  return Layout.getIntPtrType(Ctx);
}
34
35
// struct __tgt_device_image {
36
// void *ImageStart;
37
// void *ImageEnd;
38
// __tgt_offload_entry *EntriesBegin;
39
// __tgt_offload_entry *EntriesEnd;
40
// };
41
StructType *getDeviceImageTy(Module &M) {
42
LLVMContext &C = M.getContext();
43
StructType *ImageTy = StructType::getTypeByName(C, "__tgt_device_image");
44
if (!ImageTy)
45
ImageTy =
46
StructType::create("__tgt_device_image", PointerType::getUnqual(C),
47
PointerType::getUnqual(C), PointerType::getUnqual(C),
48
PointerType::getUnqual(C));
49
return ImageTy;
50
}
51
52
/// Returns the unqualified pointer type used for `__tgt_device_image *`.
/// Looking up the pointee first makes sure the struct type exists.
PointerType *getDeviceImagePtrTy(Module &M) {
  StructType *ImageTy = getDeviceImageTy(M);
  return PointerType::getUnqual(ImageTy);
}
55
56
// struct __tgt_bin_desc {
57
// int32_t NumDeviceImages;
58
// __tgt_device_image *DeviceImages;
59
// __tgt_offload_entry *HostEntriesBegin;
60
// __tgt_offload_entry *HostEntriesEnd;
61
// };
62
StructType *getBinDescTy(Module &M) {
63
LLVMContext &C = M.getContext();
64
StructType *DescTy = StructType::getTypeByName(C, "__tgt_bin_desc");
65
if (!DescTy)
66
DescTy = StructType::create(
67
"__tgt_bin_desc", Type::getInt32Ty(C), getDeviceImagePtrTy(M),
68
PointerType::getUnqual(C), PointerType::getUnqual(C));
69
return DescTy;
70
}
71
72
/// Returns the unqualified pointer type used for `__tgt_bin_desc *`.
/// Looking up the pointee first makes sure the struct type exists.
PointerType *getBinDescPtrTy(Module &M) {
  StructType *DescTy = getBinDescTy(M);
  return PointerType::getUnqual(DescTy);
}
75
76
/// Creates binary descriptor for the given device images. Binary descriptor
77
/// is an object that is passed to the offloading runtime at program startup
78
/// and it describes all device images available in the executable or shared
79
/// library. It is defined as follows
80
///
81
/// __attribute__((visibility("hidden")))
82
/// extern __tgt_offload_entry *__start_omp_offloading_entries;
83
/// __attribute__((visibility("hidden")))
84
/// extern __tgt_offload_entry *__stop_omp_offloading_entries;
85
///
86
/// static const char Image0[] = { <Bufs.front() contents> };
87
/// ...
88
/// static const char ImageN[] = { <Bufs.back() contents> };
89
///
90
/// static const __tgt_device_image Images[] = {
91
/// {
92
/// Image0, /*ImageStart*/
93
/// Image0 + sizeof(Image0), /*ImageEnd*/
94
/// __start_omp_offloading_entries, /*EntriesBegin*/
95
/// __stop_omp_offloading_entries /*EntriesEnd*/
96
/// },
97
/// ...
98
/// {
99
/// ImageN, /*ImageStart*/
100
/// ImageN + sizeof(ImageN), /*ImageEnd*/
101
/// __start_omp_offloading_entries, /*EntriesBegin*/
102
/// __stop_omp_offloading_entries /*EntriesEnd*/
103
/// }
104
/// };
105
///
106
/// static const __tgt_bin_desc BinDesc = {
107
/// sizeof(Images) / sizeof(Images[0]), /*NumDeviceImages*/
108
/// Images, /*DeviceImages*/
109
/// __start_omp_offloading_entries, /*HostEntriesBegin*/
110
/// __stop_omp_offloading_entries /*HostEntriesEnd*/
111
/// };
112
///
113
/// Global variable that represents BinDesc is returned.
114
GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs,
115
EntryArrayTy EntryArray, StringRef Suffix,
116
bool Relocatable) {
117
LLVMContext &C = M.getContext();
118
auto [EntriesB, EntriesE] = EntryArray;
119
120
auto *Zero = ConstantInt::get(getSizeTTy(M), 0u);
121
Constant *ZeroZero[] = {Zero, Zero};
122
123
// Create initializer for the images array.
124
SmallVector<Constant *, 4u> ImagesInits;
125
ImagesInits.reserve(Bufs.size());
126
for (ArrayRef<char> Buf : Bufs) {
127
// We embed the full offloading entry so the binary utilities can parse it.
128
auto *Data = ConstantDataArray::get(C, Buf);
129
auto *Image = new GlobalVariable(M, Data->getType(), /*isConstant=*/true,
130
GlobalVariable::InternalLinkage, Data,
131
".omp_offloading.device_image" + Suffix);
132
Image->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
133
Image->setSection(Relocatable ? ".llvm.offloading.relocatable"
134
: ".llvm.offloading");
135
Image->setAlignment(Align(object::OffloadBinary::getAlignment()));
136
137
StringRef Binary(Buf.data(), Buf.size());
138
assert(identify_magic(Binary) == file_magic::offload_binary &&
139
"Invalid binary format");
140
141
// The device image struct contains the pointer to the beginning and end of
142
// the image stored inside of the offload binary. There should only be one
143
// of these for each buffer so we parse it out manually.
144
const auto *Header =
145
reinterpret_cast<const object::OffloadBinary::Header *>(
146
Binary.bytes_begin());
147
const auto *Entry = reinterpret_cast<const object::OffloadBinary::Entry *>(
148
Binary.bytes_begin() + Header->EntryOffset);
149
150
auto *Begin = ConstantInt::get(getSizeTTy(M), Entry->ImageOffset);
151
auto *Size =
152
ConstantInt::get(getSizeTTy(M), Entry->ImageOffset + Entry->ImageSize);
153
Constant *ZeroBegin[] = {Zero, Begin};
154
Constant *ZeroSize[] = {Zero, Size};
155
156
auto *ImageB =
157
ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroBegin);
158
auto *ImageE =
159
ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroSize);
160
161
ImagesInits.push_back(ConstantStruct::get(getDeviceImageTy(M), ImageB,
162
ImageE, EntriesB, EntriesE));
163
}
164
165
// Then create images array.
166
auto *ImagesData = ConstantArray::get(
167
ArrayType::get(getDeviceImageTy(M), ImagesInits.size()), ImagesInits);
168
169
auto *Images =
170
new GlobalVariable(M, ImagesData->getType(), /*isConstant*/ true,
171
GlobalValue::InternalLinkage, ImagesData,
172
".omp_offloading.device_images" + Suffix);
173
Images->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
174
175
auto *ImagesB =
176
ConstantExpr::getGetElementPtr(Images->getValueType(), Images, ZeroZero);
177
178
// And finally create the binary descriptor object.
179
auto *DescInit = ConstantStruct::get(
180
getBinDescTy(M),
181
ConstantInt::get(Type::getInt32Ty(C), ImagesInits.size()), ImagesB,
182
EntriesB, EntriesE);
183
184
return new GlobalVariable(M, DescInit->getType(), /*isConstant*/ true,
185
GlobalValue::InternalLinkage, DescInit,
186
".omp_offloading.descriptor" + Suffix);
187
}
188
189
/// Emits an internal `void()` function named
/// ".omp_offloading.descriptor_unreg" + \p Suffix into \p M whose body calls
/// `__tgt_unregister_lib(BinDesc)` and returns.
///
/// \returns the newly created function.
Function *createUnregisterFunction(Module &M, GlobalVariable *BinDesc,
                                   StringRef Suffix) {
  LLVMContext &C = M.getContext();
  Type *VoidTy = Type::getVoidTy(C);

  Function *Unreg = Function::Create(
      FunctionType::get(VoidTy, /*isVarArg=*/false),
      GlobalValue::InternalLinkage,
      ".omp_offloading.descriptor_unreg" + Suffix, &M);
  Unreg->setSection(".text.startup");

  // Declare (or reuse) void __tgt_unregister_lib(__tgt_bin_desc *).
  FunctionCallee UnregisterLib = M.getOrInsertFunction(
      "__tgt_unregister_lib",
      FunctionType::get(VoidTy, getBinDescPtrTy(M), /*isVarArg=*/false));

  // The body is a single call followed by a return.
  IRBuilder<> Builder(BasicBlock::Create(C, "entry", Unreg));
  Builder.CreateCall(UnregisterLib, BinDesc);
  Builder.CreateRetVoid();

  return Unreg;
}
211
212
void createRegisterFunction(Module &M, GlobalVariable *BinDesc,
213
StringRef Suffix) {
214
LLVMContext &C = M.getContext();
215
auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
216
auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
217
".omp_offloading.descriptor_reg" + Suffix, &M);
218
Func->setSection(".text.startup");
219
220
// Get __tgt_register_lib function declaration.
221
auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
222
/*isVarArg*/ false);
223
FunctionCallee RegFuncC =
224
M.getOrInsertFunction("__tgt_register_lib", RegFuncTy);
225
226
auto *AtExitTy = FunctionType::get(
227
Type::getInt32Ty(C), PointerType::getUnqual(C), /*isVarArg=*/false);
228
FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy);
229
230
Function *UnregFunc = createUnregisterFunction(M, BinDesc, Suffix);
231
232
// Construct function body
233
IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
234
235
Builder.CreateCall(RegFuncC, BinDesc);
236
237
// Register the destructors with 'atexit'. This is expected by the CUDA
238
// runtime and ensures that we clean up before dynamic objects are destroyed.
239
// This needs to be done after plugin initialization to ensure that it is
240
// called before the plugin runtime is destroyed.
241
Builder.CreateCall(AtExit, UnregFunc);
242
Builder.CreateRetVoid();
243
244
// Add this function to constructors.
245
appendToGlobalCtors(M, Func, /*Priority=*/101);
246
}
247
248
// struct fatbin_wrapper {
249
// int32_t magic;
250
// int32_t version;
251
// void *image;
252
// void *reserved;
253
//};
254
StructType *getFatbinWrapperTy(Module &M) {
255
LLVMContext &C = M.getContext();
256
StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper");
257
if (!FatbinTy)
258
FatbinTy = StructType::create(
259
"fatbin_wrapper", Type::getInt32Ty(C), Type::getInt32Ty(C),
260
PointerType::getUnqual(C), PointerType::getUnqual(C));
261
return FatbinTy;
262
}
263
264
/// Embed the image \p Image into the module \p M so it can be found by the
265
/// runtime.
266
GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP,
267
StringRef Suffix) {
268
LLVMContext &C = M.getContext();
269
llvm::Type *Int8PtrTy = PointerType::getUnqual(C);
270
llvm::Triple Triple = llvm::Triple(M.getTargetTriple());
271
272
// Create the global string containing the fatbinary.
273
StringRef FatbinConstantSection =
274
IsHIP ? ".hip_fatbin"
275
: (Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin");
276
auto *Data = ConstantDataArray::get(C, Image);
277
auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
278
GlobalVariable::InternalLinkage, Data,
279
".fatbin_image" + Suffix);
280
Fatbin->setSection(FatbinConstantSection);
281
282
// Create the fatbinary wrapper
283
StringRef FatbinWrapperSection = IsHIP ? ".hipFatBinSegment"
284
: Triple.isMacOSX() ? "__NV_CUDA,__fatbin"
285
: ".nvFatBinSegment";
286
Constant *FatbinWrapper[] = {
287
ConstantInt::get(Type::getInt32Ty(C), IsHIP ? HIPFatMagic : CudaFatMagic),
288
ConstantInt::get(Type::getInt32Ty(C), 1),
289
ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy),
290
ConstantPointerNull::get(PointerType::getUnqual(C))};
291
292
Constant *FatbinInitializer =
293
ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper);
294
295
auto *FatbinDesc =
296
new GlobalVariable(M, getFatbinWrapperTy(M),
297
/*isConstant*/ true, GlobalValue::InternalLinkage,
298
FatbinInitializer, ".fatbin_wrapper" + Suffix);
299
FatbinDesc->setSection(FatbinWrapperSection);
300
FatbinDesc->setAlignment(Align(8));
301
302
return FatbinDesc;
303
}
304
305
/// Create the register globals function. We will iterate all of the offloading
306
/// entries stored at the begin / end symbols and register them according to
307
/// their type. This creates the following function in IR:
308
///
309
/// extern struct __tgt_offload_entry __start_cuda_offloading_entries;
310
/// extern struct __tgt_offload_entry __stop_cuda_offloading_entries;
311
///
312
/// extern void __cudaRegisterFunction(void **, void *, void *, void *, int,
313
/// void *, void *, void *, void *, int *);
314
/// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t,
315
/// int64_t, int32_t, int32_t);
316
///
317
/// void __cudaRegisterTest(void **fatbinHandle) {
318
/// for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries;
319
/// entry != &__stop_cuda_offloading_entries; ++entry) {
320
/// if (!entry->size)
321
/// __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name,
322
/// entry->name, -1, 0, 0, 0, 0, 0);
323
/// else
324
/// __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name,
325
/// 0, entry->size, 0, 0);
326
/// }
327
/// }
328
Function *createRegisterGlobalsFunction(Module &M, bool IsHIP,
329
EntryArrayTy EntryArray,
330
StringRef Suffix,
331
bool EmitSurfacesAndTextures) {
332
LLVMContext &C = M.getContext();
333
auto [EntriesB, EntriesE] = EntryArray;
334
335
// Get the __cudaRegisterFunction function declaration.
336
PointerType *Int8PtrTy = PointerType::get(C, 0);
337
PointerType *Int8PtrPtrTy = PointerType::get(C, 0);
338
PointerType *Int32PtrTy = PointerType::get(C, 0);
339
auto *RegFuncTy = FunctionType::get(
340
Type::getInt32Ty(C),
341
{Int8PtrPtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy, Type::getInt32Ty(C),
342
Int8PtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy, Int32PtrTy},
343
/*isVarArg*/ false);
344
FunctionCallee RegFunc = M.getOrInsertFunction(
345
IsHIP ? "__hipRegisterFunction" : "__cudaRegisterFunction", RegFuncTy);
346
347
// Get the __cudaRegisterVar function declaration.
348
auto *RegVarTy = FunctionType::get(
349
Type::getVoidTy(C),
350
{Int8PtrPtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy, Type::getInt32Ty(C),
351
getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)},
352
/*isVarArg*/ false);
353
FunctionCallee RegVar = M.getOrInsertFunction(
354
IsHIP ? "__hipRegisterVar" : "__cudaRegisterVar", RegVarTy);
355
356
// Get the __cudaRegisterSurface function declaration.
357
FunctionType *RegSurfaceTy =
358
FunctionType::get(Type::getVoidTy(C),
359
{Int8PtrPtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy,
360
Type::getInt32Ty(C), Type::getInt32Ty(C)},
361
/*isVarArg=*/false);
362
FunctionCallee RegSurface = M.getOrInsertFunction(
363
IsHIP ? "__hipRegisterSurface" : "__cudaRegisterSurface", RegSurfaceTy);
364
365
// Get the __cudaRegisterTexture function declaration.
366
FunctionType *RegTextureTy = FunctionType::get(
367
Type::getVoidTy(C),
368
{Int8PtrPtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy, Type::getInt32Ty(C),
369
Type::getInt32Ty(C), Type::getInt32Ty(C)},
370
/*isVarArg=*/false);
371
FunctionCallee RegTexture = M.getOrInsertFunction(
372
IsHIP ? "__hipRegisterTexture" : "__cudaRegisterTexture", RegTextureTy);
373
374
auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C), Int8PtrPtrTy,
375
/*isVarArg*/ false);
376
auto *RegGlobalsFn =
377
Function::Create(RegGlobalsTy, GlobalValue::InternalLinkage,
378
IsHIP ? ".hip.globals_reg" : ".cuda.globals_reg", &M);
379
RegGlobalsFn->setSection(".text.startup");
380
381
// Create the loop to register all the entries.
382
IRBuilder<> Builder(BasicBlock::Create(C, "entry", RegGlobalsFn));
383
auto *EntryBB = BasicBlock::Create(C, "while.entry", RegGlobalsFn);
384
auto *IfThenBB = BasicBlock::Create(C, "if.then", RegGlobalsFn);
385
auto *IfElseBB = BasicBlock::Create(C, "if.else", RegGlobalsFn);
386
auto *SwGlobalBB = BasicBlock::Create(C, "sw.global", RegGlobalsFn);
387
auto *SwManagedBB = BasicBlock::Create(C, "sw.managed", RegGlobalsFn);
388
auto *SwSurfaceBB = BasicBlock::Create(C, "sw.surface", RegGlobalsFn);
389
auto *SwTextureBB = BasicBlock::Create(C, "sw.texture", RegGlobalsFn);
390
auto *IfEndBB = BasicBlock::Create(C, "if.end", RegGlobalsFn);
391
auto *ExitBB = BasicBlock::Create(C, "while.end", RegGlobalsFn);
392
393
auto *EntryCmp = Builder.CreateICmpNE(EntriesB, EntriesE);
394
Builder.CreateCondBr(EntryCmp, EntryBB, ExitBB);
395
Builder.SetInsertPoint(EntryBB);
396
auto *Entry = Builder.CreatePHI(PointerType::getUnqual(C), 2, "entry");
397
auto *AddrPtr =
398
Builder.CreateInBoundsGEP(offloading::getEntryTy(M), Entry,
399
{ConstantInt::get(getSizeTTy(M), 0),
400
ConstantInt::get(Type::getInt32Ty(C), 0)});
401
auto *Addr = Builder.CreateLoad(Int8PtrTy, AddrPtr, "addr");
402
auto *NamePtr =
403
Builder.CreateInBoundsGEP(offloading::getEntryTy(M), Entry,
404
{ConstantInt::get(getSizeTTy(M), 0),
405
ConstantInt::get(Type::getInt32Ty(C), 1)});
406
auto *Name = Builder.CreateLoad(Int8PtrTy, NamePtr, "name");
407
auto *SizePtr =
408
Builder.CreateInBoundsGEP(offloading::getEntryTy(M), Entry,
409
{ConstantInt::get(getSizeTTy(M), 0),
410
ConstantInt::get(Type::getInt32Ty(C), 2)});
411
auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size");
412
auto *FlagsPtr =
413
Builder.CreateInBoundsGEP(offloading::getEntryTy(M), Entry,
414
{ConstantInt::get(getSizeTTy(M), 0),
415
ConstantInt::get(Type::getInt32Ty(C), 3)});
416
auto *Flags = Builder.CreateLoad(Type::getInt32Ty(C), FlagsPtr, "flags");
417
auto *DataPtr =
418
Builder.CreateInBoundsGEP(offloading::getEntryTy(M), Entry,
419
{ConstantInt::get(getSizeTTy(M), 0),
420
ConstantInt::get(Type::getInt32Ty(C), 4)});
421
auto *Data = Builder.CreateLoad(Type::getInt32Ty(C), DataPtr, "textype");
422
auto *Kind = Builder.CreateAnd(
423
Flags, ConstantInt::get(Type::getInt32Ty(C), 0x7), "type");
424
425
// Extract the flags stored in the bit-field and convert them to C booleans.
426
auto *ExternBit = Builder.CreateAnd(
427
Flags, ConstantInt::get(Type::getInt32Ty(C),
428
llvm::offloading::OffloadGlobalExtern));
429
auto *Extern = Builder.CreateLShr(
430
ExternBit, ConstantInt::get(Type::getInt32Ty(C), 3), "extern");
431
auto *ConstantBit = Builder.CreateAnd(
432
Flags, ConstantInt::get(Type::getInt32Ty(C),
433
llvm::offloading::OffloadGlobalConstant));
434
auto *Const = Builder.CreateLShr(
435
ConstantBit, ConstantInt::get(Type::getInt32Ty(C), 4), "constant");
436
auto *NormalizedBit = Builder.CreateAnd(
437
Flags, ConstantInt::get(Type::getInt32Ty(C),
438
llvm::offloading::OffloadGlobalNormalized));
439
auto *Normalized = Builder.CreateLShr(
440
NormalizedBit, ConstantInt::get(Type::getInt32Ty(C), 5), "normalized");
441
auto *FnCond =
442
Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M)));
443
Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB);
444
445
// Create kernel registration code.
446
Builder.SetInsertPoint(IfThenBB);
447
Builder.CreateCall(RegFunc, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
448
ConstantInt::get(Type::getInt32Ty(C), -1),
449
ConstantPointerNull::get(Int8PtrTy),
450
ConstantPointerNull::get(Int8PtrTy),
451
ConstantPointerNull::get(Int8PtrTy),
452
ConstantPointerNull::get(Int8PtrTy),
453
ConstantPointerNull::get(Int32PtrTy)});
454
Builder.CreateBr(IfEndBB);
455
Builder.SetInsertPoint(IfElseBB);
456
457
auto *Switch = Builder.CreateSwitch(Kind, IfEndBB);
458
// Create global variable registration code.
459
Builder.SetInsertPoint(SwGlobalBB);
460
Builder.CreateCall(RegVar,
461
{RegGlobalsFn->arg_begin(), Addr, Name, Name, Extern, Size,
462
Const, ConstantInt::get(Type::getInt32Ty(C), 0)});
463
Builder.CreateBr(IfEndBB);
464
Switch->addCase(Builder.getInt32(llvm::offloading::OffloadGlobalEntry),
465
SwGlobalBB);
466
467
// Create managed variable registration code.
468
Builder.SetInsertPoint(SwManagedBB);
469
Builder.CreateBr(IfEndBB);
470
Switch->addCase(Builder.getInt32(llvm::offloading::OffloadGlobalManagedEntry),
471
SwManagedBB);
472
// Create surface variable registration code.
473
Builder.SetInsertPoint(SwSurfaceBB);
474
if (EmitSurfacesAndTextures)
475
Builder.CreateCall(RegSurface, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
476
Data, Extern});
477
Builder.CreateBr(IfEndBB);
478
Switch->addCase(Builder.getInt32(llvm::offloading::OffloadGlobalSurfaceEntry),
479
SwSurfaceBB);
480
481
// Create texture variable registration code.
482
Builder.SetInsertPoint(SwTextureBB);
483
if (EmitSurfacesAndTextures)
484
Builder.CreateCall(RegTexture, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
485
Data, Normalized, Extern});
486
Builder.CreateBr(IfEndBB);
487
Switch->addCase(Builder.getInt32(llvm::offloading::OffloadGlobalTextureEntry),
488
SwTextureBB);
489
490
Builder.SetInsertPoint(IfEndBB);
491
auto *NewEntry = Builder.CreateInBoundsGEP(
492
offloading::getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1));
493
auto *Cmp = Builder.CreateICmpEQ(
494
NewEntry,
495
ConstantExpr::getInBoundsGetElementPtr(
496
ArrayType::get(offloading::getEntryTy(M), 0), EntriesE,
497
ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
498
ConstantInt::get(getSizeTTy(M), 0)})));
499
Entry->addIncoming(
500
ConstantExpr::getInBoundsGetElementPtr(
501
ArrayType::get(offloading::getEntryTy(M), 0), EntriesB,
502
ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
503
ConstantInt::get(getSizeTTy(M), 0)})),
504
&RegGlobalsFn->getEntryBlock());
505
Entry->addIncoming(NewEntry, IfEndBB);
506
Builder.CreateCondBr(Cmp, ExitBB, EntryBB);
507
Builder.SetInsertPoint(ExitBB);
508
Builder.CreateRetVoid();
509
510
return RegGlobalsFn;
511
}
512
513
// Create the constructor and destructor to register the fatbinary with the CUDA
514
// runtime.
515
void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
516
bool IsHIP, EntryArrayTy EntryArray,
517
StringRef Suffix,
518
bool EmitSurfacesAndTextures) {
519
LLVMContext &C = M.getContext();
520
auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
521
auto *CtorFunc = Function::Create(
522
CtorFuncTy, GlobalValue::InternalLinkage,
523
(IsHIP ? ".hip.fatbin_reg" : ".cuda.fatbin_reg") + Suffix, &M);
524
CtorFunc->setSection(".text.startup");
525
526
auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
527
auto *DtorFunc = Function::Create(
528
DtorFuncTy, GlobalValue::InternalLinkage,
529
(IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg") + Suffix, &M);
530
DtorFunc->setSection(".text.startup");
531
532
auto *PtrTy = PointerType::getUnqual(C);
533
534
// Get the __cudaRegisterFatBinary function declaration.
535
auto *RegFatTy = FunctionType::get(PtrTy, PtrTy, /*isVarArg=*/false);
536
FunctionCallee RegFatbin = M.getOrInsertFunction(
537
IsHIP ? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary", RegFatTy);
538
// Get the __cudaRegisterFatBinaryEnd function declaration.
539
auto *RegFatEndTy =
540
FunctionType::get(Type::getVoidTy(C), PtrTy, /*isVarArg=*/false);
541
FunctionCallee RegFatbinEnd =
542
M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy);
543
// Get the __cudaUnregisterFatBinary function declaration.
544
auto *UnregFatTy =
545
FunctionType::get(Type::getVoidTy(C), PtrTy, /*isVarArg=*/false);
546
FunctionCallee UnregFatbin = M.getOrInsertFunction(
547
IsHIP ? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary",
548
UnregFatTy);
549
550
auto *AtExitTy =
551
FunctionType::get(Type::getInt32Ty(C), PtrTy, /*isVarArg=*/false);
552
FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy);
553
554
auto *BinaryHandleGlobal = new llvm::GlobalVariable(
555
M, PtrTy, false, llvm::GlobalValue::InternalLinkage,
556
llvm::ConstantPointerNull::get(PtrTy),
557
(IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle") + Suffix);
558
559
// Create the constructor to register this image with the runtime.
560
IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc));
561
CallInst *Handle = CtorBuilder.CreateCall(
562
RegFatbin,
563
ConstantExpr::getPointerBitCastOrAddrSpaceCast(FatbinDesc, PtrTy));
564
CtorBuilder.CreateAlignedStore(
565
Handle, BinaryHandleGlobal,
566
Align(M.getDataLayout().getPointerTypeSize(PtrTy)));
567
CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP, EntryArray,
568
Suffix,
569
EmitSurfacesAndTextures),
570
Handle);
571
if (!IsHIP)
572
CtorBuilder.CreateCall(RegFatbinEnd, Handle);
573
CtorBuilder.CreateCall(AtExit, DtorFunc);
574
CtorBuilder.CreateRetVoid();
575
576
// Create the destructor to unregister the image with the runtime. We cannot
577
// use a standard global destructor after CUDA 9.2 so this must be called by
578
// `atexit()` intead.
579
IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc));
580
LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad(
581
PtrTy, BinaryHandleGlobal,
582
Align(M.getDataLayout().getPointerTypeSize(PtrTy)));
583
DtorBuilder.CreateCall(UnregFatbin, BinaryHandle);
584
DtorBuilder.CreateRetVoid();
585
586
// Add this function to constructors.
587
appendToGlobalCtors(M, CtorFunc, /*Priority=*/101);
588
}
589
} // namespace
590
591
/// Wraps the OpenMP offload binaries in \p Images into \p M: emits the binary
/// descriptor for them and the constructor that registers it with the
/// runtime.
///
/// \returns success, or a string error if no descriptor was produced.
Error offloading::wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images,
                                     EntryArrayTy EntryArray,
                                     llvm::StringRef Suffix, bool Relocatable) {
  if (GlobalVariable *Desc =
          createBinDesc(M, Images, EntryArray, Suffix, Relocatable)) {
    createRegisterFunction(M, Desc, Suffix);
    return Error::success();
  }
  return createStringError(inconvertibleErrorCode(),
                           "No binary descriptors created.");
}
602
603
/// Wraps the CUDA fatbinary \p Image into \p M: emits the fatbinary wrapper
/// and the constructor / destructor pair that registers it with the runtime.
///
/// \returns success, or a string error if no wrapper global was produced.
Error offloading::wrapCudaBinary(Module &M, ArrayRef<char> Image,
                                 EntryArrayTy EntryArray,
                                 llvm::StringRef Suffix,
                                 bool EmitSurfacesAndTextures) {
  if (GlobalVariable *Desc =
          createFatbinDesc(M, Image, /*IsHIP=*/false, Suffix)) {
    createRegisterFatbinFunction(M, Desc, /*IsHIP=*/false, EntryArray, Suffix,
                                 EmitSurfacesAndTextures);
    return Error::success();
  }
  return createStringError(inconvertibleErrorCode(),
                           "No fatbin section created.");
}
616
617
/// Wraps the HIP fatbinary \p Image into \p M: emits the fatbinary wrapper
/// and the constructor / destructor pair that registers it with the runtime.
///
/// \returns success, or a string error if no wrapper global was produced.
Error offloading::wrapHIPBinary(Module &M, ArrayRef<char> Image,
                                EntryArrayTy EntryArray, llvm::StringRef Suffix,
                                bool EmitSurfacesAndTextures) {
  if (GlobalVariable *Desc =
          createFatbinDesc(M, Image, /*IsHIP=*/true, Suffix)) {
    createRegisterFatbinFunction(M, Desc, /*IsHIP=*/true, EntryArray, Suffix,
                                 EmitSurfacesAndTextures);
    return Error::success();
  }
  return createStringError(inconvertibleErrorCode(),
                           "No fatbin section created.");
}
629
630