Path: blob/main/contrib/llvm-project/llvm/lib/Object/OffloadBundle.cpp
213764 views
//===- OffloadBundle.cpp - Utilities for offload bundles---*- C++ -*-===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------===//78#include "llvm/Object/OffloadBundle.h"9#include "llvm/BinaryFormat/Magic.h"10#include "llvm/IR/Module.h"11#include "llvm/IRReader/IRReader.h"12#include "llvm/MC/StringTableBuilder.h"13#include "llvm/Object/Archive.h"14#include "llvm/Object/Binary.h"15#include "llvm/Object/COFF.h"16#include "llvm/Object/ELFObjectFile.h"17#include "llvm/Object/Error.h"18#include "llvm/Object/IRObjectFile.h"19#include "llvm/Object/ObjectFile.h"20#include "llvm/Support/BinaryStreamReader.h"21#include "llvm/Support/SourceMgr.h"22#include "llvm/Support/Timer.h"2324using namespace llvm;25using namespace llvm::object;2627static llvm::TimerGroup28OffloadBundlerTimerGroup("Offload Bundler Timer Group",29"Timer group for offload bundler");3031// Extract an Offload bundle (usually a Offload Bundle) from a fat_bin32// section33Error extractOffloadBundle(MemoryBufferRef Contents, uint64_t SectionOffset,34StringRef FileName,35SmallVectorImpl<OffloadBundleFatBin> &Bundles) {3637size_t Offset = 0;38size_t NextbundleStart = 0;3940// There could be multiple offloading bundles stored at this section.41while (NextbundleStart != StringRef::npos) {42std::unique_ptr<MemoryBuffer> Buffer =43MemoryBuffer::getMemBuffer(Contents.getBuffer().drop_front(Offset), "",44/*RequiresNullTerminator=*/false);4546// Create the FatBinBindle object. This will also create the Bundle Entry47// list info.48auto FatBundleOrErr =49OffloadBundleFatBin::create(*Buffer, SectionOffset + Offset, FileName);50if (!FatBundleOrErr)51return FatBundleOrErr.takeError();5253// Add current Bundle to list.54Bundles.emplace_back(std::move(**FatBundleOrErr));5556// Find the next bundle by searching for the magic string57StringRef Str = Buffer->getBuffer();58NextbundleStart = Str.find(StringRef("__CLANG_OFFLOAD_BUNDLE__"), 24);5960if (NextbundleStart != StringRef::npos)61Offset += NextbundleStart;62}6364return Error::success();65}6667Error OffloadBundleFatBin::readEntries(StringRef Buffer,68uint64_t SectionOffset) {69uint64_t NumOfEntries = 0;7071BinaryStreamReader Reader(Buffer, llvm::endianness::little);7273// Read the Magic String first.74StringRef Magic;75if (auto EC = Reader.readFixedString(Magic, 24))76return errorCodeToError(object_error::parse_failed);7778// Read the number of Code Objects (Entries) in the current Bundle.79if (auto EC = Reader.readInteger(NumOfEntries))80return errorCodeToError(object_error::parse_failed);8182NumberOfEntries = NumOfEntries;8384// For each Bundle Entry (code object)85for (uint64_t I = 0; I < NumOfEntries; I++) {86uint64_t EntrySize;87uint64_t EntryOffset;88uint64_t EntryIDSize;89StringRef EntryID;9091if (auto EC = Reader.readInteger(EntryOffset))92return errorCodeToError(object_error::parse_failed);9394if (auto EC = Reader.readInteger(EntrySize))95return errorCodeToError(object_error::parse_failed);9697if (auto EC = Reader.readInteger(EntryIDSize))98return errorCodeToError(object_error::parse_failed);99100if (auto EC = Reader.readFixedString(EntryID, EntryIDSize))101return errorCodeToError(object_error::parse_failed);102103auto Entry = std::make_unique<OffloadBundleEntry>(104EntryOffset + SectionOffset, EntrySize, EntryIDSize, EntryID);105106Entries.push_back(*Entry);107}108109return Error::success();110}111112Expected<std::unique_ptr<OffloadBundleFatBin>>113OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset,114StringRef FileName) {115if (Buf.getBufferSize() < 24)116return errorCodeToError(object_error::parse_failed);117118// Check for magic bytes.119if (identify_magic(Buf.getBuffer()) != file_magic::offload_bundle)120return errorCodeToError(object_error::parse_failed);121122OffloadBundleFatBin *TheBundle = new OffloadBundleFatBin(Buf, FileName);123124// Read the Bundle Entries125Error Err = TheBundle->readEntries(Buf.getBuffer(), SectionOffset);126if (Err)127return errorCodeToError(object_error::parse_failed);128129return std::unique_ptr<OffloadBundleFatBin>(TheBundle);130}131132Error OffloadBundleFatBin::extractBundle(const ObjectFile &Source) {133// This will extract all entries in the Bundle134for (OffloadBundleEntry &Entry : Entries) {135136if (Entry.Size == 0)137continue;138139// create output file name. Which should be140// <fileName>-offset<Offset>-size<Size>.co"141std::string Str = getFileName().str() + "-offset" + itostr(Entry.Offset) +142"-size" + itostr(Entry.Size) + ".co";143if (Error Err = object::extractCodeObject(Source, Entry.Offset, Entry.Size,144StringRef(Str)))145return Err;146}147148return Error::success();149}150151Error object::extractOffloadBundleFatBinary(152const ObjectFile &Obj, SmallVectorImpl<OffloadBundleFatBin> &Bundles) {153assert((Obj.isELF() || Obj.isCOFF()) && "Invalid file type");154155// Iterate through Sections until we find an offload_bundle section.156for (SectionRef Sec : Obj.sections()) {157Expected<StringRef> Buffer = Sec.getContents();158if (!Buffer)159return Buffer.takeError();160161// If it does not start with the reserved suffix, just skip this section.162if ((llvm::identify_magic(*Buffer) == llvm::file_magic::offload_bundle) ||163(llvm::identify_magic(*Buffer) ==164llvm::file_magic::offload_bundle_compressed)) {165166uint64_t SectionOffset = 0;167if (Obj.isELF()) {168SectionOffset = ELFSectionRef(Sec).getOffset();169} else if (Obj.isCOFF()) // TODO: add COFF Support170return createStringError(object_error::parse_failed,171"COFF object files not supported.\n");172173MemoryBufferRef Contents(*Buffer, Obj.getFileName());174175if (llvm::identify_magic(*Buffer) ==176llvm::file_magic::offload_bundle_compressed) {177// Decompress the input if necessary.178Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =179CompressedOffloadBundle::decompress(Contents, false);180181if (!DecompressedBufferOrErr)182return createStringError(183inconvertibleErrorCode(),184"Failed to decompress input: " +185llvm::toString(DecompressedBufferOrErr.takeError()));186187MemoryBuffer &DecompressedInput = **DecompressedBufferOrErr;188if (Error Err = extractOffloadBundle(DecompressedInput, SectionOffset,189Obj.getFileName(), Bundles))190return Err;191} else {192if (Error Err = extractOffloadBundle(Contents, SectionOffset,193Obj.getFileName(), Bundles))194return Err;195}196}197}198return Error::success();199}200201Error object::extractCodeObject(const ObjectFile &Source, int64_t Offset,202int64_t Size, StringRef OutputFileName) {203Expected<std::unique_ptr<FileOutputBuffer>> BufferOrErr =204FileOutputBuffer::create(OutputFileName, Size);205206if (!BufferOrErr)207return BufferOrErr.takeError();208209Expected<MemoryBufferRef> InputBuffOrErr = Source.getMemoryBufferRef();210if (Error Err = InputBuffOrErr.takeError())211return Err;212213std::unique_ptr<FileOutputBuffer> Buf = std::move(*BufferOrErr);214std::copy(InputBuffOrErr->getBufferStart() + Offset,215InputBuffOrErr->getBufferStart() + Offset + Size,216Buf->getBufferStart());217if (Error E = Buf->commit())218return E;219220return Error::success();221}222223// given a file name, offset, and size, extract data into a code object file,224// into file <SourceFile>-offset<Offset>-size<Size>.co225Error object::extractOffloadBundleByURI(StringRef URIstr) {226// create a URI object227Expected<std::unique_ptr<OffloadBundleURI>> UriOrErr(228OffloadBundleURI::createOffloadBundleURI(URIstr, FILE_URI));229if (!UriOrErr)230return UriOrErr.takeError();231232OffloadBundleURI &Uri = **UriOrErr;233std::string OutputFile = Uri.FileName.str();234OutputFile +=235"-offset" + itostr(Uri.Offset) + "-size" + itostr(Uri.Size) + ".co";236237// Create an ObjectFile object from uri.file_uri238auto ObjOrErr = ObjectFile::createObjectFile(Uri.FileName);239if (!ObjOrErr)240return ObjOrErr.takeError();241242auto Obj = ObjOrErr->getBinary();243if (Error Err =244object::extractCodeObject(*Obj, Uri.Offset, Uri.Size, OutputFile))245return Err;246247return Error::success();248}249250// Utility function to format numbers with commas251static std::string formatWithCommas(unsigned long long Value) {252std::string Num = std::to_string(Value);253int InsertPosition = Num.length() - 3;254while (InsertPosition > 0) {255Num.insert(InsertPosition, ",");256InsertPosition -= 3;257}258return Num;259}260261llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>262CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input,263bool Verbose) {264StringRef Blob = Input.getBuffer();265266if (Blob.size() < V1HeaderSize)267return llvm::MemoryBuffer::getMemBufferCopy(Blob);268269if (llvm::identify_magic(Blob) !=270llvm::file_magic::offload_bundle_compressed) {271if (Verbose)272llvm::errs() << "Uncompressed bundle.\n";273return llvm::MemoryBuffer::getMemBufferCopy(Blob);274}275276size_t CurrentOffset = MagicSize;277278uint16_t ThisVersion;279memcpy(&ThisVersion, Blob.data() + CurrentOffset, sizeof(uint16_t));280CurrentOffset += VersionFieldSize;281282uint16_t CompressionMethod;283memcpy(&CompressionMethod, Blob.data() + CurrentOffset, sizeof(uint16_t));284CurrentOffset += MethodFieldSize;285286uint32_t TotalFileSize;287if (ThisVersion >= 2) {288if (Blob.size() < V2HeaderSize)289return createStringError(inconvertibleErrorCode(),290"Compressed bundle header size too small");291memcpy(&TotalFileSize, Blob.data() + CurrentOffset, sizeof(uint32_t));292CurrentOffset += FileSizeFieldSize;293}294295uint32_t UncompressedSize;296memcpy(&UncompressedSize, Blob.data() + CurrentOffset, sizeof(uint32_t));297CurrentOffset += UncompressedSizeFieldSize;298299uint64_t StoredHash;300memcpy(&StoredHash, Blob.data() + CurrentOffset, sizeof(uint64_t));301CurrentOffset += HashFieldSize;302303llvm::compression::Format CompressionFormat;304if (CompressionMethod ==305static_cast<uint16_t>(llvm::compression::Format::Zlib))306CompressionFormat = llvm::compression::Format::Zlib;307else if (CompressionMethod ==308static_cast<uint16_t>(llvm::compression::Format::Zstd))309CompressionFormat = llvm::compression::Format::Zstd;310else311return createStringError(inconvertibleErrorCode(),312"Unknown compressing method");313314llvm::Timer DecompressTimer("Decompression Timer", "Decompression time",315OffloadBundlerTimerGroup);316if (Verbose)317DecompressTimer.startTimer();318319SmallVector<uint8_t, 0> DecompressedData;320StringRef CompressedData = Blob.substr(CurrentOffset);321if (llvm::Error DecompressionError = llvm::compression::decompress(322CompressionFormat, llvm::arrayRefFromStringRef(CompressedData),323DecompressedData, UncompressedSize))324return createStringError(inconvertibleErrorCode(),325"Could not decompress embedded file contents: " +326llvm::toString(std::move(DecompressionError)));327328if (Verbose) {329DecompressTimer.stopTimer();330331double DecompressionTimeSeconds =332DecompressTimer.getTotalTime().getWallTime();333334// Recalculate MD5 hash for integrity check.335llvm::Timer HashRecalcTimer("Hash Recalculation Timer",336"Hash recalculation time",337OffloadBundlerTimerGroup);338HashRecalcTimer.startTimer();339llvm::MD5 Hash;340llvm::MD5::MD5Result Result;341Hash.update(llvm::ArrayRef<uint8_t>(DecompressedData));342Hash.final(Result);343uint64_t RecalculatedHash = Result.low();344HashRecalcTimer.stopTimer();345bool HashMatch = (StoredHash == RecalculatedHash);346347double CompressionRate =348static_cast<double>(UncompressedSize) / CompressedData.size();349double DecompressionSpeedMBs =350(UncompressedSize / (1024.0 * 1024.0)) / DecompressionTimeSeconds;351352llvm::errs() << "Compressed bundle format version: " << ThisVersion << "\n";353if (ThisVersion >= 2)354llvm::errs() << "Total file size (from header): "355<< formatWithCommas(TotalFileSize) << " bytes\n";356llvm::errs() << "Decompression method: "357<< (CompressionFormat == llvm::compression::Format::Zlib358? "zlib"359: "zstd")360<< "\n"361<< "Size before decompression: "362<< formatWithCommas(CompressedData.size()) << " bytes\n"363<< "Size after decompression: "364<< formatWithCommas(UncompressedSize) << " bytes\n"365<< "Compression rate: "366<< llvm::format("%.2lf", CompressionRate) << "\n"367<< "Compression ratio: "368<< llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"369<< "Decompression speed: "370<< llvm::format("%.2lf MB/s", DecompressionSpeedMBs) << "\n"371<< "Stored hash: " << llvm::format_hex(StoredHash, 16) << "\n"372<< "Recalculated hash: "373<< llvm::format_hex(RecalculatedHash, 16) << "\n"374<< "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n";375}376377return llvm::MemoryBuffer::getMemBufferCopy(378llvm::toStringRef(DecompressedData));379}380381llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>382CompressedOffloadBundle::compress(llvm::compression::Params P,383const llvm::MemoryBuffer &Input,384bool Verbose) {385if (!llvm::compression::zstd::isAvailable() &&386!llvm::compression::zlib::isAvailable())387return createStringError(llvm::inconvertibleErrorCode(),388"Compression not supported");389390llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time",391OffloadBundlerTimerGroup);392if (Verbose)393HashTimer.startTimer();394llvm::MD5 Hash;395llvm::MD5::MD5Result Result;396Hash.update(Input.getBuffer());397Hash.final(Result);398uint64_t TruncatedHash = Result.low();399if (Verbose)400HashTimer.stopTimer();401402SmallVector<uint8_t, 0> CompressedBuffer;403auto BufferUint8 = llvm::ArrayRef<uint8_t>(404reinterpret_cast<const uint8_t *>(Input.getBuffer().data()),405Input.getBuffer().size());406407llvm::Timer CompressTimer("Compression Timer", "Compression time",408OffloadBundlerTimerGroup);409if (Verbose)410CompressTimer.startTimer();411llvm::compression::compress(P, BufferUint8, CompressedBuffer);412if (Verbose)413CompressTimer.stopTimer();414415uint16_t CompressionMethod = static_cast<uint16_t>(P.format);416uint32_t UncompressedSize = Input.getBuffer().size();417uint32_t TotalFileSize = MagicNumber.size() + sizeof(TotalFileSize) +418sizeof(Version) + sizeof(CompressionMethod) +419sizeof(UncompressedSize) + sizeof(TruncatedHash) +420CompressedBuffer.size();421422SmallVector<char, 0> FinalBuffer;423llvm::raw_svector_ostream OS(FinalBuffer);424OS << MagicNumber;425OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version));426OS.write(reinterpret_cast<const char *>(&CompressionMethod),427sizeof(CompressionMethod));428OS.write(reinterpret_cast<const char *>(&TotalFileSize),429sizeof(TotalFileSize));430OS.write(reinterpret_cast<const char *>(&UncompressedSize),431sizeof(UncompressedSize));432OS.write(reinterpret_cast<const char *>(&TruncatedHash),433sizeof(TruncatedHash));434OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()),435CompressedBuffer.size());436437if (Verbose) {438auto MethodUsed =439P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib";440double CompressionRate =441static_cast<double>(UncompressedSize) / CompressedBuffer.size();442double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime();443double CompressionSpeedMBs =444(UncompressedSize / (1024.0 * 1024.0)) / CompressionTimeSeconds;445446llvm::errs() << "Compressed bundle format version: " << Version << "\n"447<< "Total file size (including headers): "448<< formatWithCommas(TotalFileSize) << " bytes\n"449<< "Compression method used: " << MethodUsed << "\n"450<< "Compression level: " << P.level << "\n"451<< "Binary size before compression: "452<< formatWithCommas(UncompressedSize) << " bytes\n"453<< "Binary size after compression: "454<< formatWithCommas(CompressedBuffer.size()) << " bytes\n"455<< "Compression rate: "456<< llvm::format("%.2lf", CompressionRate) << "\n"457<< "Compression ratio: "458<< llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"459<< "Compression speed: "460<< llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n"461<< "Truncated MD5 hash: "462<< llvm::format_hex(TruncatedHash, 16) << "\n";463}464return llvm::MemoryBuffer::getMemBufferCopy(465llvm::StringRef(FinalBuffer.data(), FinalBuffer.size()));466}467468469