Path: blob/main/contrib/llvm-project/llvm/lib/Object/DXContainer.cpp
35232 views
//===- DXContainer.cpp - DXContainer object file implementation -----------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//78#include "llvm/Object/DXContainer.h"9#include "llvm/BinaryFormat/DXContainer.h"10#include "llvm/Object/Error.h"11#include "llvm/Support/Alignment.h"12#include "llvm/Support/FormatVariadic.h"1314using namespace llvm;15using namespace llvm::object;1617static Error parseFailed(const Twine &Msg) {18return make_error<GenericBinaryError>(Msg.str(), object_error::parse_failed);19}2021template <typename T>22static Error readStruct(StringRef Buffer, const char *Src, T &Struct) {23// Don't read before the beginning or past the end of the file24if (Src < Buffer.begin() || Src + sizeof(T) > Buffer.end())25return parseFailed("Reading structure out of file bounds");2627memcpy(&Struct, Src, sizeof(T));28// DXContainer is always little endian29if (sys::IsBigEndianHost)30Struct.swapBytes();31return Error::success();32}3334template <typename T>35static Error readInteger(StringRef Buffer, const char *Src, T &Val,36Twine Str = "structure") {37static_assert(std::is_integral_v<T>,38"Cannot call readInteger on non-integral type.");39// Don't read before the beginning or past the end of the file40if (Src < Buffer.begin() || Src + sizeof(T) > Buffer.end())41return parseFailed(Twine("Reading ") + Str + " out of file bounds");4243// The DXContainer offset table is comprised of uint32_t values but not padded44// to a 64-bit boundary. So Parts may start unaligned if there is an odd45// number of parts and part data itself is not required to be padded.46if (reinterpret_cast<uintptr_t>(Src) % alignof(T) != 0)47memcpy(reinterpret_cast<char *>(&Val), Src, sizeof(T));48else49Val = *reinterpret_cast<const T *>(Src);50// DXContainer is always little endian51if (sys::IsBigEndianHost)52sys::swapByteOrder(Val);53return Error::success();54}5556DXContainer::DXContainer(MemoryBufferRef O) : Data(O) {}5758Error DXContainer::parseHeader() {59return readStruct(Data.getBuffer(), Data.getBuffer().data(), Header);60}6162Error DXContainer::parseDXILHeader(StringRef Part) {63if (DXIL)64return parseFailed("More than one DXIL part is present in the file");65const char *Current = Part.begin();66dxbc::ProgramHeader Header;67if (Error Err = readStruct(Part, Current, Header))68return Err;69Current += offsetof(dxbc::ProgramHeader, Bitcode) + Header.Bitcode.Offset;70DXIL.emplace(std::make_pair(Header, Current));71return Error::success();72}7374Error DXContainer::parseShaderFeatureFlags(StringRef Part) {75if (ShaderFeatureFlags)76return parseFailed("More than one SFI0 part is present in the file");77uint64_t FlagValue = 0;78if (Error Err = readInteger(Part, Part.begin(), FlagValue))79return Err;80ShaderFeatureFlags = FlagValue;81return Error::success();82}8384Error DXContainer::parseHash(StringRef Part) {85if (Hash)86return parseFailed("More than one HASH part is present in the file");87dxbc::ShaderHash ReadHash;88if (Error Err = readStruct(Part, Part.begin(), ReadHash))89return Err;90Hash = ReadHash;91return Error::success();92}9394Error DXContainer::parsePSVInfo(StringRef Part) {95if (PSVInfo)96return parseFailed("More than one PSV0 part is present in the file");97PSVInfo = DirectX::PSVRuntimeInfo(Part);98// Parsing the PSVRuntime info occurs late because we need to read data from99// other parts first.100return Error::success();101}102103Error DirectX::Signature::initialize(StringRef Part) {104dxbc::ProgramSignatureHeader SigHeader;105if (Error Err = readStruct(Part, Part.begin(), SigHeader))106return Err;107size_t Size = sizeof(dxbc::ProgramSignatureElement) * SigHeader.ParamCount;108109if (Part.size() < Size + SigHeader.FirstParamOffset)110return parseFailed("Signature parameters extend beyond the part boundary");111112Parameters.Data = Part.substr(SigHeader.FirstParamOffset, Size);113114StringTableOffset = SigHeader.FirstParamOffset + static_cast<uint32_t>(Size);115StringTable = Part.substr(SigHeader.FirstParamOffset + Size);116117for (const auto &Param : Parameters) {118if (Param.NameOffset < StringTableOffset)119return parseFailed("Invalid parameter name offset: name starts before "120"the first name offset");121if (Param.NameOffset - StringTableOffset > StringTable.size())122return parseFailed("Invalid parameter name offset: name starts after the "123"end of the part data");124}125return Error::success();126}127128Error DXContainer::parsePartOffsets() {129uint32_t LastOffset =130sizeof(dxbc::Header) + (Header.PartCount * sizeof(uint32_t));131const char *Current = Data.getBuffer().data() + sizeof(dxbc::Header);132for (uint32_t Part = 0; Part < Header.PartCount; ++Part) {133uint32_t PartOffset;134if (Error Err = readInteger(Data.getBuffer(), Current, PartOffset))135return Err;136if (PartOffset < LastOffset)137return parseFailed(138formatv(139"Part offset for part {0} begins before the previous part ends",140Part)141.str());142Current += sizeof(uint32_t);143if (PartOffset >= Data.getBufferSize())144return parseFailed("Part offset points beyond boundary of the file");145// To prevent overflow when reading the part name, we subtract the part name146// size from the buffer size, rather than adding to the offset. Since the147// file header is larger than the part header we can't reach this code148// unless the buffer is at least as large as a part header, so this149// subtraction can't underflow.150if (PartOffset >= Data.getBufferSize() - sizeof(dxbc::PartHeader::Name))151return parseFailed("File not large enough to read part name");152PartOffsets.push_back(PartOffset);153154dxbc::PartType PT =155dxbc::parsePartType(Data.getBuffer().substr(PartOffset, 4));156uint32_t PartDataStart = PartOffset + sizeof(dxbc::PartHeader);157uint32_t PartSize;158if (Error Err = readInteger(Data.getBuffer(),159Data.getBufferStart() + PartOffset + 4,160PartSize, "part size"))161return Err;162StringRef PartData = Data.getBuffer().substr(PartDataStart, PartSize);163LastOffset = PartOffset + PartSize;164switch (PT) {165case dxbc::PartType::DXIL:166if (Error Err = parseDXILHeader(PartData))167return Err;168break;169case dxbc::PartType::SFI0:170if (Error Err = parseShaderFeatureFlags(PartData))171return Err;172break;173case dxbc::PartType::HASH:174if (Error Err = parseHash(PartData))175return Err;176break;177case dxbc::PartType::PSV0:178if (Error Err = parsePSVInfo(PartData))179return Err;180break;181case dxbc::PartType::ISG1:182if (Error Err = InputSignature.initialize(PartData))183return Err;184break;185case dxbc::PartType::OSG1:186if (Error Err = OutputSignature.initialize(PartData))187return Err;188break;189case dxbc::PartType::PSG1:190if (Error Err = PatchConstantSignature.initialize(PartData))191return Err;192break;193case dxbc::PartType::Unknown:194break;195}196}197198// Fully parsing the PSVInfo requires knowing the shader kind which we read199// out of the program header in the DXIL part.200if (PSVInfo) {201if (!DXIL)202return parseFailed("Cannot fully parse pipeline state validation "203"information without DXIL part.");204if (Error Err = PSVInfo->parse(DXIL->first.ShaderKind))205return Err;206}207return Error::success();208}209210Expected<DXContainer> DXContainer::create(MemoryBufferRef Object) {211DXContainer Container(Object);212if (Error Err = Container.parseHeader())213return std::move(Err);214if (Error Err = Container.parsePartOffsets())215return std::move(Err);216return Container;217}218219void DXContainer::PartIterator::updateIteratorImpl(const uint32_t Offset) {220StringRef Buffer = Container.Data.getBuffer();221const char *Current = Buffer.data() + Offset;222// Offsets are validated during parsing, so all offsets in the container are223// valid and contain enough readable data to read a header.224cantFail(readStruct(Buffer, Current, IteratorState.Part));225IteratorState.Data =226StringRef(Current + sizeof(dxbc::PartHeader), IteratorState.Part.Size);227IteratorState.Offset = Offset;228}229230Error DirectX::PSVRuntimeInfo::parse(uint16_t ShaderKind) {231Triple::EnvironmentType ShaderStage = dxbc::getShaderStage(ShaderKind);232233const char *Current = Data.begin();234if (Error Err = readInteger(Data, Current, Size))235return Err;236Current += sizeof(uint32_t);237238StringRef PSVInfoData = Data.substr(sizeof(uint32_t), Size);239240if (PSVInfoData.size() < Size)241return parseFailed(242"Pipeline state data extends beyond the bounds of the part");243244using namespace dxbc::PSV;245246const uint32_t PSVVersion = getVersion();247248// Detect the PSVVersion by looking at the size field.249if (PSVVersion == 3) {250v3::RuntimeInfo Info;251if (Error Err = readStruct(PSVInfoData, Current, Info))252return Err;253if (sys::IsBigEndianHost)254Info.swapBytes(ShaderStage);255BasicInfo = Info;256} else if (PSVVersion == 2) {257v2::RuntimeInfo Info;258if (Error Err = readStruct(PSVInfoData, Current, Info))259return Err;260if (sys::IsBigEndianHost)261Info.swapBytes(ShaderStage);262BasicInfo = Info;263} else if (PSVVersion == 1) {264v1::RuntimeInfo Info;265if (Error Err = readStruct(PSVInfoData, Current, Info))266return Err;267if (sys::IsBigEndianHost)268Info.swapBytes(ShaderStage);269BasicInfo = Info;270} else if (PSVVersion == 0) {271v0::RuntimeInfo Info;272if (Error Err = readStruct(PSVInfoData, Current, Info))273return Err;274if (sys::IsBigEndianHost)275Info.swapBytes(ShaderStage);276BasicInfo = Info;277} else278return parseFailed(279"Cannot read PSV Runtime Info, unsupported PSV version.");280281Current += Size;282283uint32_t ResourceCount = 0;284if (Error Err = readInteger(Data, Current, ResourceCount))285return Err;286Current += sizeof(uint32_t);287288if (ResourceCount > 0) {289if (Error Err = readInteger(Data, Current, Resources.Stride))290return Err;291Current += sizeof(uint32_t);292293size_t BindingDataSize = Resources.Stride * ResourceCount;294Resources.Data = Data.substr(Current - Data.begin(), BindingDataSize);295296if (Resources.Data.size() < BindingDataSize)297return parseFailed(298"Resource binding data extends beyond the bounds of the part");299300Current += BindingDataSize;301} else302Resources.Stride = sizeof(v2::ResourceBindInfo);303304// PSV version 0 ends after the resource bindings.305if (PSVVersion == 0)306return Error::success();307308// String table starts at a 4-byte offset.309Current = reinterpret_cast<const char *>(310alignTo<4>(reinterpret_cast<uintptr_t>(Current)));311312uint32_t StringTableSize = 0;313if (Error Err = readInteger(Data, Current, StringTableSize))314return Err;315if (StringTableSize % 4 != 0)316return parseFailed("String table misaligned");317Current += sizeof(uint32_t);318StringTable = StringRef(Current, StringTableSize);319320Current += StringTableSize;321322uint32_t SemanticIndexTableSize = 0;323if (Error Err = readInteger(Data, Current, SemanticIndexTableSize))324return Err;325Current += sizeof(uint32_t);326327SemanticIndexTable.reserve(SemanticIndexTableSize);328for (uint32_t I = 0; I < SemanticIndexTableSize; ++I) {329uint32_t Index = 0;330if (Error Err = readInteger(Data, Current, Index))331return Err;332Current += sizeof(uint32_t);333SemanticIndexTable.push_back(Index);334}335336uint8_t InputCount = getSigInputCount();337uint8_t OutputCount = getSigOutputCount();338uint8_t PatchOrPrimCount = getSigPatchOrPrimCount();339340uint32_t ElementCount = InputCount + OutputCount + PatchOrPrimCount;341342if (ElementCount > 0) {343if (Error Err = readInteger(Data, Current, SigInputElements.Stride))344return Err;345Current += sizeof(uint32_t);346// Assign the stride to all the arrays.347SigOutputElements.Stride = SigPatchOrPrimElements.Stride =348SigInputElements.Stride;349350if (Data.end() - Current <351(ptrdiff_t)(ElementCount * SigInputElements.Stride))352return parseFailed(353"Signature elements extend beyond the size of the part");354355size_t InputSize = SigInputElements.Stride * InputCount;356SigInputElements.Data = Data.substr(Current - Data.begin(), InputSize);357Current += InputSize;358359size_t OutputSize = SigOutputElements.Stride * OutputCount;360SigOutputElements.Data = Data.substr(Current - Data.begin(), OutputSize);361Current += OutputSize;362363size_t PSize = SigPatchOrPrimElements.Stride * PatchOrPrimCount;364SigPatchOrPrimElements.Data = Data.substr(Current - Data.begin(), PSize);365Current += PSize;366}367368ArrayRef<uint8_t> OutputVectorCounts = getOutputVectorCounts();369uint8_t PatchConstOrPrimVectorCount = getPatchConstOrPrimVectorCount();370uint8_t InputVectorCount = getInputVectorCount();371372auto maskDwordSize = [](uint8_t Vector) {373return (static_cast<uint32_t>(Vector) + 7) >> 3;374};375376auto mapTableSize = [maskDwordSize](uint8_t X, uint8_t Y) {377return maskDwordSize(Y) * X * 4;378};379380if (usesViewID()) {381for (uint32_t I = 0; I < OutputVectorCounts.size(); ++I) {382// The vector mask is one bit per component and 4 components per vector.383// We can compute the number of dwords required by rounding up to the next384// multiple of 8.385uint32_t NumDwords =386maskDwordSize(static_cast<uint32_t>(OutputVectorCounts[I]));387size_t NumBytes = NumDwords * sizeof(uint32_t);388OutputVectorMasks[I].Data = Data.substr(Current - Data.begin(), NumBytes);389Current += NumBytes;390}391392if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0) {393uint32_t NumDwords = maskDwordSize(PatchConstOrPrimVectorCount);394size_t NumBytes = NumDwords * sizeof(uint32_t);395PatchOrPrimMasks.Data = Data.substr(Current - Data.begin(), NumBytes);396Current += NumBytes;397}398}399400// Input/Output mapping table401for (uint32_t I = 0; I < OutputVectorCounts.size(); ++I) {402if (InputVectorCount == 0 || OutputVectorCounts[I] == 0)403continue;404uint32_t NumDwords = mapTableSize(InputVectorCount, OutputVectorCounts[I]);405size_t NumBytes = NumDwords * sizeof(uint32_t);406InputOutputMap[I].Data = Data.substr(Current - Data.begin(), NumBytes);407Current += NumBytes;408}409410// Hull shader: Input/Patch mapping table411if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0 &&412InputVectorCount > 0) {413uint32_t NumDwords =414mapTableSize(InputVectorCount, PatchConstOrPrimVectorCount);415size_t NumBytes = NumDwords * sizeof(uint32_t);416InputPatchMap.Data = Data.substr(Current - Data.begin(), NumBytes);417Current += NumBytes;418}419420// Domain Shader: Patch/Output mapping table421if (ShaderStage == Triple::Domain && PatchConstOrPrimVectorCount > 0 &&422OutputVectorCounts[0] > 0) {423uint32_t NumDwords =424mapTableSize(PatchConstOrPrimVectorCount, OutputVectorCounts[0]);425size_t NumBytes = NumDwords * sizeof(uint32_t);426PatchOutputMap.Data = Data.substr(Current - Data.begin(), NumBytes);427Current += NumBytes;428}429430return Error::success();431}432433uint8_t DirectX::PSVRuntimeInfo::getSigInputCount() const {434if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo))435return P->SigInputElements;436if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))437return P->SigInputElements;438if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))439return P->SigInputElements;440return 0;441}442443uint8_t DirectX::PSVRuntimeInfo::getSigOutputCount() const {444if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo))445return P->SigOutputElements;446if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))447return P->SigOutputElements;448if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))449return P->SigOutputElements;450return 0;451}452453uint8_t DirectX::PSVRuntimeInfo::getSigPatchOrPrimCount() const {454if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo))455return P->SigPatchOrPrimElements;456if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))457return P->SigPatchOrPrimElements;458if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))459return P->SigPatchOrPrimElements;460return 0;461}462463464