Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Object/OffloadBundle.cpp
213764 views
1
//===- OffloadBundle.cpp - Utilities for offload bundles---*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------===//
8
9
#include "llvm/Object/OffloadBundle.h"
10
#include "llvm/BinaryFormat/Magic.h"
11
#include "llvm/IR/Module.h"
12
#include "llvm/IRReader/IRReader.h"
13
#include "llvm/MC/StringTableBuilder.h"
14
#include "llvm/Object/Archive.h"
15
#include "llvm/Object/Binary.h"
16
#include "llvm/Object/COFF.h"
17
#include "llvm/Object/ELFObjectFile.h"
18
#include "llvm/Object/Error.h"
19
#include "llvm/Object/IRObjectFile.h"
20
#include "llvm/Object/ObjectFile.h"
21
#include "llvm/Support/BinaryStreamReader.h"
22
#include "llvm/Support/SourceMgr.h"
23
#include "llvm/Support/Timer.h"
24
25
using namespace llvm;
26
using namespace llvm::object;
27
28
// Timer group shared by the hashing, compression and decompression timers
// created in this file.
static llvm::TimerGroup OffloadBundlerTimerGroup{
    "Offload Bundler Timer Group", "Timer group for offload bundler"};
31
32
// Extract an Offload bundle (usually a Offload Bundle) from a fat_bin
33
// section
34
Error extractOffloadBundle(MemoryBufferRef Contents, uint64_t SectionOffset,
35
StringRef FileName,
36
SmallVectorImpl<OffloadBundleFatBin> &Bundles) {
37
38
size_t Offset = 0;
39
size_t NextbundleStart = 0;
40
41
// There could be multiple offloading bundles stored at this section.
42
while (NextbundleStart != StringRef::npos) {
43
std::unique_ptr<MemoryBuffer> Buffer =
44
MemoryBuffer::getMemBuffer(Contents.getBuffer().drop_front(Offset), "",
45
/*RequiresNullTerminator=*/false);
46
47
// Create the FatBinBindle object. This will also create the Bundle Entry
48
// list info.
49
auto FatBundleOrErr =
50
OffloadBundleFatBin::create(*Buffer, SectionOffset + Offset, FileName);
51
if (!FatBundleOrErr)
52
return FatBundleOrErr.takeError();
53
54
// Add current Bundle to list.
55
Bundles.emplace_back(std::move(**FatBundleOrErr));
56
57
// Find the next bundle by searching for the magic string
58
StringRef Str = Buffer->getBuffer();
59
NextbundleStart = Str.find(StringRef("__CLANG_OFFLOAD_BUNDLE__"), 24);
60
61
if (NextbundleStart != StringRef::npos)
62
Offset += NextbundleStart;
63
}
64
65
return Error::success();
66
}
67
68
// Parse the header and the entry table of one uncompressed offload bundle.
//
// The data consumed from Buffer is, in order: a 24-byte magic string, a 64-bit
// little-endian entry count, and for each entry its offset, size and ID length
// (64-bit each) followed by the ID bytes. SectionOffset is added to every
// entry offset before the entry is appended to Entries. NumberOfEntries is set
// to the entry count read from the header.
//
// Returns object_error::parse_failed as soon as any read fails.
Error OffloadBundleFatBin::readEntries(StringRef Buffer,
                                       uint64_t SectionOffset) {
  // The reader's own error has to be consumed before a different error is
  // returned in its place: destroying an llvm::Error that still holds an
  // unhandled failure aborts in builds with error checking enabled.
  auto ParseFailed = [](Error ReadErr) {
    consumeError(std::move(ReadErr));
    return errorCodeToError(object_error::parse_failed);
  };

  BinaryStreamReader Reader(Buffer, llvm::endianness::little);

  // Read the Magic String first. Its value is not inspected here; the read
  // only advances the reader past it.
  StringRef Magic;
  if (Error Err = Reader.readFixedString(Magic, 24))
    return ParseFailed(std::move(Err));

  // Read the number of Code Objects (Entries) in the current Bundle.
  uint64_t NumOfEntries = 0;
  if (Error Err = Reader.readInteger(NumOfEntries))
    return ParseFailed(std::move(Err));

  NumberOfEntries = NumOfEntries;

  // For each Bundle Entry (code object)
  for (uint64_t I = 0; I < NumOfEntries; I++) {
    uint64_t EntryOffset = 0;
    uint64_t EntrySize = 0;
    uint64_t EntryIDSize = 0;
    StringRef EntryID;

    if (Error Err = Reader.readInteger(EntryOffset))
      return ParseFailed(std::move(Err));

    if (Error Err = Reader.readInteger(EntrySize))
      return ParseFailed(std::move(Err));

    if (Error Err = Reader.readInteger(EntryIDSize))
      return ParseFailed(std::move(Err));

    if (Error Err = Reader.readFixedString(EntryID, EntryIDSize))
      return ParseFailed(std::move(Err));

    // Construct the entry in place instead of heap-allocating a temporary and
    // copying it into the list.
    Entries.emplace_back(EntryOffset + SectionOffset, EntrySize, EntryIDSize,
                         EntryID);
  }

  return Error::success();
}
112
113
// Build an OffloadBundleFatBin from a buffer that starts with an uncompressed
// offload bundle.
//
// SectionOffset is handed to readEntries together with the buffer contents.
// Returns object_error::parse_failed when Buf is shorter than 24 bytes or is
// not identified as an offload bundle, and forwards the error reported by
// readEntries when the entry table cannot be read.
Expected<std::unique_ptr<OffloadBundleFatBin>>
OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset,
                            StringRef FileName) {
  if (Buf.getBufferSize() < 24)
    return errorCodeToError(object_error::parse_failed);

  // Check for magic bytes.
  if (identify_magic(Buf.getBuffer()) != file_magic::offload_bundle)
    return errorCodeToError(object_error::parse_failed);

  // Take ownership immediately so the object is released if reading the
  // entries fails below.
  std::unique_ptr<OffloadBundleFatBin> TheBundle(
      new OffloadBundleFatBin(Buf, FileName));

  // Read the Bundle Entries. The failure is moved into the result instead of
  // being replaced, so the Error is never destroyed while still unhandled.
  if (Error Err = TheBundle->readEntries(Buf.getBuffer(), SectionOffset))
    return std::move(Err);

  return std::move(TheBundle);
}
132
133
// Write every non-empty entry of this bundle to its own code object file.
//
// Each output file is named <fileName>-offset<Offset>-size<Size>.co and is
// produced by object::extractCodeObject from Source. Extraction stops at the
// first error, which is returned to the caller.
Error OffloadBundleFatBin::extractBundle(const ObjectFile &Source) {
  for (OffloadBundleEntry &CodeObject : Entries) {
    // Entries without any payload do not get an output file.
    const bool HasPayload = CodeObject.Size != 0;
    if (!HasPayload)
      continue;

    // Assemble "<fileName>-offset<Offset>-size<Size>.co".
    std::string OutputName = getFileName().str();
    OutputName += "-offset";
    OutputName += itostr(CodeObject.Offset);
    OutputName += "-size";
    OutputName += itostr(CodeObject.Size);
    OutputName += ".co";

    Error Err = object::extractCodeObject(Source, CodeObject.Offset,
                                          CodeObject.Size,
                                          StringRef(OutputName));
    if (Err)
      return Err;
  }

  return Error::success();
}
151
152
// Collect the offload bundles embedded in the sections of an object file.
//
// Every section of Obj whose contents are identified as an offload bundle
// (plain or compressed) is parsed with extractOffloadBundle, which appends the
// bundles it finds to Bundles. Compressed sections are decompressed first.
// For ELF inputs the section's file offset is passed along; a COFF object that
// contains a bundle section produces an error. The first error encountered is
// returned.
Error object::extractOffloadBundleFatBinary(
    const ObjectFile &Obj, SmallVectorImpl<OffloadBundleFatBin> &Bundles) {
  assert((Obj.isELF() || Obj.isCOFF()) && "Invalid file type");

  // Iterate through Sections looking for offload_bundle sections.
  for (SectionRef Sec : Obj.sections()) {
    Expected<StringRef> SectionData = Sec.getContents();
    if (!SectionData)
      return SectionData.takeError();

    // Sections that do not start with a bundle magic are skipped.
    const llvm::file_magic Magic = llvm::identify_magic(*SectionData);
    const bool IsCompressed =
        Magic == llvm::file_magic::offload_bundle_compressed;
    if (Magic != llvm::file_magic::offload_bundle && !IsCompressed)
      continue;

    uint64_t SectionOffset = 0;
    if (Obj.isELF())
      SectionOffset = ELFSectionRef(Sec).getOffset();
    else if (Obj.isCOFF()) // TODO: add COFF Support
      return createStringError(object_error::parse_failed,
                               "COFF object files not supported.\n");

    const StringRef ObjName = Obj.getFileName();
    MemoryBufferRef Contents(*SectionData, ObjName);

    if (!IsCompressed) {
      if (Error Err =
              extractOffloadBundle(Contents, SectionOffset, ObjName, Bundles))
        return Err;
      continue;
    }

    // Compressed bundle: decompress it before parsing.
    Expected<std::unique_ptr<MemoryBuffer>> DecompressedOrErr =
        CompressedOffloadBundle::decompress(Contents, false);
    if (!DecompressedOrErr)
      return createStringError(
          inconvertibleErrorCode(),
          "Failed to decompress input: " +
              llvm::toString(DecompressedOrErr.takeError()));

    // NOTE(review): the decompressed buffer is released at the end of this
    // iteration; confirm that the bundles appended below keep no references
    // into it.
    MemoryBuffer &DecompressedInput = **DecompressedOrErr;
    if (Error Err = extractOffloadBundle(DecompressedInput, SectionOffset,
                                         ObjName, Bundles))
      return Err;
  }

  return Error::success();
}
201
202
// Copy Size bytes, starting at byte Offset of Source's underlying memory
// buffer, into a newly created file named OutputFileName.
//
// Returns an error if Offset or Size is negative, if the requested range does
// not lie inside the source buffer, if the output file cannot be created, or
// if committing the output fails.
Error object::extractCodeObject(const ObjectFile &Source, int64_t Offset,
                                int64_t Size, StringRef OutputFileName) {
  Expected<MemoryBufferRef> InputBuffOrErr = Source.getMemoryBufferRef();
  if (Error Err = InputBuffOrErr.takeError())
    return Err;

  // Offset and Size originate from bundle headers or from a caller-supplied
  // URI, so reject any range that would read outside the input buffer. This is
  // done before the output file is created. The comparison is arranged so
  // that Offset + Size is never computed and therefore cannot overflow.
  const uint64_t InputSize = InputBuffOrErr->getBufferSize();
  if (Offset < 0 || Size < 0 || static_cast<uint64_t>(Offset) > InputSize ||
      static_cast<uint64_t>(Size) > InputSize - static_cast<uint64_t>(Offset))
    return createStringError(
        object_error::parse_failed,
        "code object range lies outside of the source file");

  Expected<std::unique_ptr<FileOutputBuffer>> BufferOrErr =
      FileOutputBuffer::create(OutputFileName, Size);
  if (!BufferOrErr)
    return BufferOrErr.takeError();

  std::unique_ptr<FileOutputBuffer> Buf = std::move(*BufferOrErr);
  const char *First = InputBuffOrErr->getBufferStart() + Offset;
  std::copy(First, First + Size, Buf->getBufferStart());

  // commit() reports whether the output could be finalized.
  return Buf->commit();
}
223
224
// given a file name, offset, and size, extract data into a code object file,
225
// into file <SourceFile>-offset<Offset>-size<Size>.co
226
Error object::extractOffloadBundleByURI(StringRef URIstr) {
227
// create a URI object
228
Expected<std::unique_ptr<OffloadBundleURI>> UriOrErr(
229
OffloadBundleURI::createOffloadBundleURI(URIstr, FILE_URI));
230
if (!UriOrErr)
231
return UriOrErr.takeError();
232
233
OffloadBundleURI &Uri = **UriOrErr;
234
std::string OutputFile = Uri.FileName.str();
235
OutputFile +=
236
"-offset" + itostr(Uri.Offset) + "-size" + itostr(Uri.Size) + ".co";
237
238
// Create an ObjectFile object from uri.file_uri
239
auto ObjOrErr = ObjectFile::createObjectFile(Uri.FileName);
240
if (!ObjOrErr)
241
return ObjOrErr.takeError();
242
243
auto Obj = ObjOrErr->getBinary();
244
if (Error Err =
245
object::extractCodeObject(*Obj, Uri.Offset, Uri.Size, OutputFile))
246
return Err;
247
248
return Error::success();
249
}
250
251
// Render Value in decimal with a comma between every group of three digits,
// counted from the right (for example 1234567 becomes "1,234,567").
static std::string formatWithCommas(unsigned long long Value) {
  const std::string Digits = std::to_string(Value);
  const std::size_t Count = Digits.size();

  std::string Grouped;
  Grouped.reserve(Count + Count / 3);

  for (std::size_t I = 0; I < Count; ++I) {
    // A separator goes in front of every digit that starts a full group of
    // three, except in front of the very first digit.
    if (I != 0 && (Count - I) % 3 == 0)
      Grouped.push_back(',');
    Grouped.push_back(Digits[I]);
  }

  return Grouped;
}
261
262
// Decompress a compressed offload bundle.
//
// An input shorter than a version 1 header, or one that is not identified as
// a compressed offload bundle, is returned as an unmodified copy. Otherwise
// the header fields following the magic are read in this order: version,
// compression method, total file size (version 2 and later only),
// uncompressed size and stored hash. The remaining bytes are decompressed with
// zlib or zstd and returned in a new buffer.
//
// When Verbose is set, timing, size statistics and a comparison between the
// stored hash and a recalculated MD5 hash are printed to llvm::errs(); a hash
// mismatch is only reported there and does not produce an error.
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input,
                                    bool Verbose) {
  const StringRef Blob = Input.getBuffer();

  const bool ShorterThanHeader = Blob.size() < V1HeaderSize;
  if (ShorterThanHeader)
    return llvm::MemoryBuffer::getMemBufferCopy(Blob);

  if (llvm::identify_magic(Blob) !=
      llvm::file_magic::offload_bundle_compressed) {
    if (Verbose)
      llvm::errs() << "Uncompressed bundle.\n";
    return llvm::MemoryBuffer::getMemBufferCopy(Blob);
  }

  // Read position inside the header; starts right after the magic.
  size_t Cursor = MagicSize;

  // Copies sizeof(Field) bytes found at the cursor into Field and then
  // advances the cursor by the given field width.
  auto ReadField = [&Blob, &Cursor](auto &Field, size_t FieldWidth) {
    memcpy(&Field, Blob.data() + Cursor, sizeof(Field));
    Cursor += FieldWidth;
  };

  uint16_t ThisVersion;
  ReadField(ThisVersion, VersionFieldSize);

  uint16_t CompressionMethod;
  ReadField(CompressionMethod, MethodFieldSize);

  // Only present in the header from version 2 on.
  uint32_t TotalFileSize;
  if (ThisVersion >= 2) {
    if (Blob.size() < V2HeaderSize)
      return createStringError(inconvertibleErrorCode(),
                               "Compressed bundle header size too small");
    ReadField(TotalFileSize, FileSizeFieldSize);
  }

  uint32_t UncompressedSize;
  ReadField(UncompressedSize, UncompressedSizeFieldSize);

  uint64_t StoredHash;
  ReadField(StoredHash, HashFieldSize);

  // Map the numeric method from the header onto a compression format.
  llvm::compression::Format CompressionFormat;
  if (CompressionMethod ==
      static_cast<uint16_t>(llvm::compression::Format::Zlib))
    CompressionFormat = llvm::compression::Format::Zlib;
  else if (CompressionMethod ==
           static_cast<uint16_t>(llvm::compression::Format::Zstd))
    CompressionFormat = llvm::compression::Format::Zstd;
  else
    return createStringError(inconvertibleErrorCode(),
                             "Unknown compressing method");

  llvm::Timer DecompressTimer("Decompression Timer", "Decompression time",
                              OffloadBundlerTimerGroup);
  if (Verbose)
    DecompressTimer.startTimer();

  // Everything after the header is the compressed payload.
  SmallVector<uint8_t, 0> DecompressedData;
  const StringRef CompressedData = Blob.substr(Cursor);
  if (llvm::Error DecompressionError = llvm::compression::decompress(
          CompressionFormat, llvm::arrayRefFromStringRef(CompressedData),
          DecompressedData, UncompressedSize))
    return createStringError(inconvertibleErrorCode(),
                             "Could not decompress embedded file contents: " +
                                 llvm::toString(std::move(DecompressionError)));

  if (Verbose) {
    DecompressTimer.stopTimer();
    const double DecompressionTimeSeconds =
        DecompressTimer.getTotalTime().getWallTime();

    // Recalculate MD5 hash for integrity check.
    llvm::Timer HashRecalcTimer("Hash Recalculation Timer",
                                "Hash recalculation time",
                                OffloadBundlerTimerGroup);
    HashRecalcTimer.startTimer();
    llvm::MD5 Hash;
    llvm::MD5::MD5Result Result;
    Hash.update(llvm::ArrayRef<uint8_t>(DecompressedData));
    Hash.final(Result);
    const uint64_t RecalculatedHash = Result.low();
    HashRecalcTimer.stopTimer();
    const bool HashMatch = (StoredHash == RecalculatedHash);

    const double CompressionRate =
        static_cast<double>(UncompressedSize) / CompressedData.size();
    const double DecompressionSpeedMBs =
        (UncompressedSize / (1024.0 * 1024.0)) / DecompressionTimeSeconds;
    const char *MethodName =
        CompressionFormat == llvm::compression::Format::Zlib ? "zlib" : "zstd";

    llvm::raw_ostream &Log = llvm::errs();
    Log << "Compressed bundle format version: " << ThisVersion << "\n";
    if (ThisVersion >= 2)
      Log << "Total file size (from header): "
          << formatWithCommas(TotalFileSize) << " bytes\n";
    Log << "Decompression method: " << MethodName << "\n";
    Log << "Size before decompression: "
        << formatWithCommas(CompressedData.size()) << " bytes\n";
    Log << "Size after decompression: " << formatWithCommas(UncompressedSize)
        << " bytes\n";
    Log << "Compression rate: " << llvm::format("%.2lf", CompressionRate)
        << "\n";
    Log << "Compression ratio: "
        << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n";
    Log << "Decompression speed: "
        << llvm::format("%.2lf MB/s", DecompressionSpeedMBs) << "\n";
    Log << "Stored hash: " << llvm::format_hex(StoredHash, 16) << "\n";
    Log << "Recalculated hash: " << llvm::format_hex(RecalculatedHash, 16)
        << "\n";
    Log << "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n";
  }

  return llvm::MemoryBuffer::getMemBufferCopy(
      llvm::toStringRef(DecompressedData));
}
381
382
// Compress Input into the compressed offload bundle container.
//
// The returned buffer holds, in order: the magic number, the format version,
// the compression method, the total file size, the uncompressed size, the
// truncated MD5 hash of Input (MD5Result::low()) and the compressed payload.
// Returns an error when neither zstd nor zlib is available. When Verbose is
// set, timing and size statistics are printed to llvm::errs().
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
CompressedOffloadBundle::compress(llvm::compression::Params P,
                                  const llvm::MemoryBuffer &Input,
                                  bool Verbose) {
  const bool HaveCompressor = llvm::compression::zstd::isAvailable() ||
                              llvm::compression::zlib::isAvailable();
  if (!HaveCompressor)
    return createStringError(llvm::inconvertibleErrorCode(),
                             "Compression not supported");

  const StringRef Plain = Input.getBuffer();

  // Hash the uncompressed input; only Result.low() is stored in the header.
  llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time",
                        OffloadBundlerTimerGroup);
  if (Verbose)
    HashTimer.startTimer();
  llvm::MD5 Hash;
  llvm::MD5::MD5Result Result;
  Hash.update(Plain);
  Hash.final(Result);
  uint64_t TruncatedHash = Result.low();
  if (Verbose)
    HashTimer.stopTimer();

  // Compress the payload.
  SmallVector<uint8_t, 0> CompressedBuffer;
  llvm::ArrayRef<uint8_t> PlainBytes = llvm::arrayRefFromStringRef(Plain);

  llvm::Timer CompressTimer("Compression Timer", "Compression time",
                            OffloadBundlerTimerGroup);
  if (Verbose)
    CompressTimer.startTimer();
  llvm::compression::compress(P, PlainBytes, CompressedBuffer);
  if (Verbose)
    CompressTimer.stopTimer();

  // Header fields.
  uint16_t CompressionMethod = static_cast<uint16_t>(P.format);
  uint32_t UncompressedSize = Plain.size();
  uint32_t TotalFileSize = MagicNumber.size() + sizeof(TotalFileSize) +
                           sizeof(Version) + sizeof(CompressionMethod) +
                           sizeof(UncompressedSize) + sizeof(TruncatedHash) +
                           CompressedBuffer.size();

  SmallVector<char, 0> FinalBuffer;
  llvm::raw_svector_ostream OS(FinalBuffer);

  // Appends the object representation of Field to the output stream.
  auto WriteField = [&OS](const auto &Field) {
    OS.write(reinterpret_cast<const char *>(&Field), sizeof(Field));
  };

  OS << MagicNumber;
  WriteField(Version);
  WriteField(CompressionMethod);
  WriteField(TotalFileSize);
  WriteField(UncompressedSize);
  WriteField(TruncatedHash);
  OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()),
           CompressedBuffer.size());

  if (Verbose) {
    const char *MethodUsed =
        P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib";
    const double CompressionRate =
        static_cast<double>(UncompressedSize) / CompressedBuffer.size();
    const double CompressionTimeSeconds =
        CompressTimer.getTotalTime().getWallTime();
    const double CompressionSpeedMBs =
        (UncompressedSize / (1024.0 * 1024.0)) / CompressionTimeSeconds;

    llvm::raw_ostream &Log = llvm::errs();
    Log << "Compressed bundle format version: " << Version << "\n";
    Log << "Total file size (including headers): "
        << formatWithCommas(TotalFileSize) << " bytes\n";
    Log << "Compression method used: " << MethodUsed << "\n";
    Log << "Compression level: " << P.level << "\n";
    Log << "Binary size before compression: "
        << formatWithCommas(UncompressedSize) << " bytes\n";
    Log << "Binary size after compression: "
        << formatWithCommas(CompressedBuffer.size()) << " bytes\n";
    Log << "Compression rate: " << llvm::format("%.2lf", CompressionRate)
        << "\n";
    Log << "Compression ratio: "
        << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n";
    Log << "Compression speed: "
        << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n";
    Log << "Truncated MD5 hash: " << llvm::format_hex(TruncatedHash, 16)
        << "\n";
  }

  return llvm::MemoryBuffer::getMemBufferCopy(
      llvm::StringRef(FinalBuffer.data(), FinalBuffer.size()));
}
468
469