Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
35266 views
//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// The analysis collects instructions that should be output at the module level9// and performs the global register numbering.10//11// The results of this analysis are used in AsmPrinter to rename registers12// globally and to output required instructions at the module level.13//14//===----------------------------------------------------------------------===//1516#include "SPIRVModuleAnalysis.h"17#include "MCTargetDesc/SPIRVBaseInfo.h"18#include "MCTargetDesc/SPIRVMCTargetDesc.h"19#include "SPIRV.h"20#include "SPIRVSubtarget.h"21#include "SPIRVTargetMachine.h"22#include "SPIRVUtils.h"23#include "TargetInfo/SPIRVTargetInfo.h"24#include "llvm/ADT/STLExtras.h"25#include "llvm/CodeGen/MachineModuleInfo.h"26#include "llvm/CodeGen/TargetPassConfig.h"2728using namespace llvm;2930#define DEBUG_TYPE "spirv-module-analysis"3132static cl::opt<bool>33SPVDumpDeps("spv-dump-deps",34cl::desc("Dump MIR with SPIR-V dependencies info"),35cl::Optional, cl::init(false));3637static cl::list<SPIRV::Capability::Capability>38AvoidCapabilities("avoid-spirv-capabilities",39cl::desc("SPIR-V capabilities to avoid if there are "40"other options enabling a feature"),41cl::ZeroOrMore, cl::Hidden,42cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",43"SPIR-V Shader capability")));44// Use sets instead of cl::list to check "if contains" condition45struct AvoidCapabilitiesSet {46SmallSet<SPIRV::Capability::Capability, 4> S;47AvoidCapabilitiesSet() {48for (auto Cap : AvoidCapabilities)49S.insert(Cap);50}51};5253char llvm::SPIRVModuleAnalysis::ID = 0;5455namespace llvm {56void initializeSPIRVModuleAnalysisPass(PassRegistry &);57} // namespace llvm5859INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,60true)6162// Retrieve an unsigned from an MDNode with a list of them as operands.63static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,64unsigned DefaultVal = 0) {65if (MdNode && OpIndex < MdNode->getNumOperands()) {66const auto &Op = MdNode->getOperand(OpIndex);67return mdconst::extract<ConstantInt>(Op)->getZExtValue();68}69return DefaultVal;70}7172static SPIRV::Requirements73getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,74unsigned i, const SPIRVSubtarget &ST,75SPIRV::RequirementHandler &Reqs) {76static AvoidCapabilitiesSet77AvoidCaps; // contains capabilities to avoid if there is another option7879VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);80VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);81VersionTuple SPIRVVersion = ST.getSPIRVVersion();82bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;83bool MaxVerOK =84ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;85CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);86ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);87if (ReqCaps.empty()) {88if (ReqExts.empty()) {89if (MinVerOK && MaxVerOK)90return {true, {}, {}, ReqMinVer, ReqMaxVer};91return {false, {}, {}, VersionTuple(), VersionTuple()};92}93} else if (MinVerOK && MaxVerOK) {94if (ReqCaps.size() == 1) {95auto Cap = ReqCaps[0];96if (Reqs.isCapabilityAvailable(Cap))97return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};98} else {99// By SPIR-V specification: "If an instruction, enumerant, or other100// feature specifies multiple enabling capabilities, only one such101// capability needs to be declared to use the feature." However, one102// capability may be preferred over another. We use command line103// argument(s) and AvoidCapabilities to avoid selection of certain104// capabilities if there are other options.105CapabilityList UseCaps;106for (auto Cap : ReqCaps)107if (Reqs.isCapabilityAvailable(Cap))108UseCaps.push_back(Cap);109for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {110auto Cap = UseCaps[i];111if (i == Sz - 1 || !AvoidCaps.S.contains(Cap))112return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};113}114}115}116// If there are no capabilities, or we can't satisfy the version or117// capability requirements, use the list of extensions (if the subtarget118// can handle them all).119if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {120return ST.canUseExtension(Ext);121})) {122return {true,123{},124ReqExts,125VersionTuple(),126VersionTuple()}; // TODO: add versions to extensions.127}128return {false, {}, {}, VersionTuple(), VersionTuple()};129}130131void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {132MAI.MaxID = 0;133for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)134MAI.MS[i].clear();135MAI.RegisterAliasTable.clear();136MAI.InstrsToDelete.clear();137MAI.FuncMap.clear();138MAI.GlobalVarList.clear();139MAI.ExtInstSetMap.clear();140MAI.Reqs.clear();141MAI.Reqs.initAvailableCapabilities(*ST);142143// TODO: determine memory model and source language from the configuratoin.144if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {145auto MemMD = MemModel->getOperand(0);146MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(147getMetadataUInt(MemMD, 0));148MAI.Mem =149static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));150} else {151// TODO: Add support for VulkanMemoryModel.152MAI.Mem = ST->isOpenCLEnv() ? SPIRV::MemoryModel::OpenCL153: SPIRV::MemoryModel::GLSL450;154if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {155unsigned PtrSize = ST->getPointerSize();156MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32157: PtrSize == 64 ? SPIRV::AddressingModel::Physical64158: SPIRV::AddressingModel::Logical;159} else {160// TODO: Add support for PhysicalStorageBufferAddress.161MAI.Addr = SPIRV::AddressingModel::Logical;162}163}164// Get the OpenCL version number from metadata.165// TODO: support other source languages.166if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {167MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;168// Construct version literal in accordance with SPIRV-LLVM-Translator.169// TODO: support multiple OCL version metadata.170assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");171auto VersionMD = VerNode->getOperand(0);172unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);173unsigned MinorNum = getMetadataUInt(VersionMD, 1);174unsigned RevNum = getMetadataUInt(VersionMD, 2);175// Prevent Major part of OpenCL version to be 0176MAI.SrcLangVersion =177(std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;178} else {179// If there is no information about OpenCL version we are forced to generate180// OpenCL 1.0 by default for the OpenCL environment to avoid puzzling181// run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV182// Translator avoids potential issues with run-times in a similar manner.183if (ST->isOpenCLEnv()) {184MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;185MAI.SrcLangVersion = 100000;186} else {187MAI.SrcLang = SPIRV::SourceLanguage::Unknown;188MAI.SrcLangVersion = 0;189}190}191192if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {193for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {194MDNode *MD = ExtNode->getOperand(I);195if (!MD || MD->getNumOperands() == 0)196continue;197for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)198MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());199}200}201202// Update required capabilities for this memory model, addressing model and203// source language.204MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,205MAI.Mem, *ST);206MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,207MAI.SrcLang, *ST);208MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,209MAI.Addr, *ST);210211if (ST->isOpenCLEnv()) {212// TODO: check if it's required by default.213MAI.ExtInstSetMap[static_cast<unsigned>(214SPIRV::InstructionSet::OpenCL_std)] =215Register::index2VirtReg(MAI.getNextID());216}217}218219// Collect MI which defines the register in the given machine function.220static void collectDefInstr(Register Reg, const MachineFunction *MF,221SPIRV::ModuleAnalysisInfo *MAI,222SPIRV::ModuleSectionType MSType,223bool DoInsert = true) {224assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");225MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);226assert(MI && "There should be an instruction that defines the register");227MAI->setSkipEmission(MI);228if (DoInsert)229MAI->MS[MSType].push_back(MI);230}231232void SPIRVModuleAnalysis::collectGlobalEntities(233const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,234SPIRV::ModuleSectionType MSType,235std::function<bool(const SPIRV::DTSortableEntry *)> Pred,236bool UsePreOrder = false) {237DenseSet<const SPIRV::DTSortableEntry *> Visited;238for (const auto *E : DepsGraph) {239std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;240// NOTE: here we prefer recursive approach over iterative because241// we don't expect depchains long enough to cause SO.242RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,243&RecHoistUtil](const SPIRV::DTSortableEntry *E) {244if (Visited.count(E) || !Pred(E))245return;246Visited.insert(E);247248// Traversing deps graph in post-order allows us to get rid of249// register aliases preprocessing.250// But pre-order is required for correct processing of function251// declaration and arguments processing.252if (!UsePreOrder)253for (auto *S : E->getDeps())254RecHoistUtil(S);255256Register GlobalReg = Register::index2VirtReg(MAI.getNextID());257bool IsFirst = true;258for (auto &U : *E) {259const MachineFunction *MF = U.first;260Register Reg = U.second;261MAI.setRegisterAlias(MF, Reg, GlobalReg);262if (!MF->getRegInfo().getUniqueVRegDef(Reg))263continue;264collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);265IsFirst = false;266if (E->getIsGV())267MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));268}269270if (UsePreOrder)271for (auto *S : E->getDeps())272RecHoistUtil(S);273};274RecHoistUtil(E);275}276}277278// The function initializes global register alias table for types, consts,279// global vars and func decls and collects these instruction for output280// at module level. Also it collects explicit OpExtension/OpCapability281// instructions.282void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {283std::vector<SPIRV::DTSortableEntry *> DepsGraph;284285GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);286287collectGlobalEntities(288DepsGraph, SPIRV::MB_TypeConstVars,289[](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });290291for (auto F = M.begin(), E = M.end(); F != E; ++F) {292MachineFunction *MF = MMI->getMachineFunction(*F);293if (!MF)294continue;295// Iterate through and collect OpExtension/OpCapability instructions.296for (MachineBasicBlock &MBB : *MF) {297for (MachineInstr &MI : MBB) {298if (MI.getOpcode() == SPIRV::OpExtension) {299// Here, OpExtension just has a single enum operand, not a string.300auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());301MAI.Reqs.addExtension(Ext);302MAI.setSkipEmission(&MI);303} else if (MI.getOpcode() == SPIRV::OpCapability) {304auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());305MAI.Reqs.addCapability(Cap);306MAI.setSkipEmission(&MI);307}308}309}310}311312collectGlobalEntities(313DepsGraph, SPIRV::MB_ExtFuncDecls,314[](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);315}316317// Look for IDs declared with Import linkage, and map the corresponding function318// to the register defining that variable (which will usually be the result of319// an OpFunction). This lets us call externally imported functions using320// the correct ID registers.321void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,322const Function *F) {323if (MI.getOpcode() == SPIRV::OpDecorate) {324// If it's got Import linkage.325auto Dec = MI.getOperand(1).getImm();326if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {327auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();328if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {329// Map imported function name to function ID register.330const Function *ImportedFunc =331F->getParent()->getFunction(getStringImm(MI, 2));332Register Target = MI.getOperand(0).getReg();333MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);334}335}336} else if (MI.getOpcode() == SPIRV::OpFunction) {337// Record all internal OpFunction declarations.338Register Reg = MI.defs().begin()->getReg();339Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);340assert(GlobalReg.isValid());341MAI.FuncMap[F] = GlobalReg;342}343}344345// References to a function via function pointers generate virtual346// registers without a definition. We are able to resolve this347// reference using Globar Register info into an OpFunction instruction348// and replace dummy operands by the corresponding global register references.349void SPIRVModuleAnalysis::collectFuncPtrs() {350for (auto &MI : MAI.MS[SPIRV::MB_TypeConstVars])351if (MI->getOpcode() == SPIRV::OpConstantFunctionPointerINTEL)352collectFuncPtrs(MI);353}354355void SPIRVModuleAnalysis::collectFuncPtrs(MachineInstr *MI) {356const MachineOperand *FunUse = &MI->getOperand(2);357if (const MachineOperand *FunDef = GR->getFunctionDefinitionByUse(FunUse)) {358const MachineInstr *FunDefMI = FunDef->getParent();359assert(FunDefMI->getOpcode() == SPIRV::OpFunction &&360"Constant function pointer must refer to function definition");361Register FunDefReg = FunDef->getReg();362Register GlobalFunDefReg =363MAI.getRegisterAlias(FunDefMI->getMF(), FunDefReg);364assert(GlobalFunDefReg.isValid() &&365"Function definition must refer to a global register");366Register FunPtrReg = FunUse->getReg();367MAI.setRegisterAlias(MI->getMF(), FunPtrReg, GlobalFunDefReg);368}369}370371using InstrSignature = SmallVector<size_t>;372using InstrTraces = std::set<InstrSignature>;373374// Returns a representation of an instruction as a vector of MachineOperand375// hash values, see llvm::hash_value(const MachineOperand &MO) for details.376// This creates a signature of the instruction with the same content377// that MachineOperand::isIdenticalTo uses for comparison.378static InstrSignature instrToSignature(MachineInstr &MI,379SPIRV::ModuleAnalysisInfo &MAI) {380InstrSignature Signature;381for (unsigned i = 0; i < MI.getNumOperands(); ++i) {382const MachineOperand &MO = MI.getOperand(i);383size_t h;384if (MO.isReg()) {385Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());386// mimic llvm::hash_value(const MachineOperand &MO)387h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),388MO.isDef());389} else {390h = hash_value(MO);391}392Signature.push_back(h);393}394return Signature;395}396397// Collect the given instruction in the specified MS. We assume global register398// numbering has already occurred by this point. We can directly compare reg399// arguments when detecting duplicates.400static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,401SPIRV::ModuleSectionType MSType, InstrTraces &IS,402bool Append = true) {403MAI.setSkipEmission(&MI);404InstrSignature MISign = instrToSignature(MI, MAI);405auto FoundMI = IS.insert(MISign);406if (!FoundMI.second)407return; // insert failed, so we found a duplicate; don't add it to MAI.MS408// No duplicates, so add it.409if (Append)410MAI.MS[MSType].push_back(&MI);411else412MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);413}414415// Some global instructions make reference to function-local ID regs, so cannot416// be correctly collected until these registers are globally numbered.417void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {418InstrTraces IS;419for (auto F = M.begin(), E = M.end(); F != E; ++F) {420if ((*F).isDeclaration())421continue;422MachineFunction *MF = MMI->getMachineFunction(*F);423assert(MF);424for (MachineBasicBlock &MBB : *MF)425for (MachineInstr &MI : MBB) {426if (MAI.getSkipEmission(&MI))427continue;428const unsigned OpCode = MI.getOpcode();429if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {430collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);431} else if (OpCode == SPIRV::OpEntryPoint) {432collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);433} else if (TII->isDecorationInstr(MI)) {434collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);435collectFuncNames(MI, &*F);436} else if (TII->isConstantInstr(MI)) {437// Now OpSpecConstant*s are not in DT,438// but they need to be collected anyway.439collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);440} else if (OpCode == SPIRV::OpFunction) {441collectFuncNames(MI, &*F);442} else if (OpCode == SPIRV::OpTypeForwardPointer) {443collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);444}445}446}447}448449// Number registers in all functions globally from 0 onwards and store450// the result in global register alias table. Some registers are already451// numbered in collectGlobalEntities.452void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {453for (auto F = M.begin(), E = M.end(); F != E; ++F) {454if ((*F).isDeclaration())455continue;456MachineFunction *MF = MMI->getMachineFunction(*F);457assert(MF);458for (MachineBasicBlock &MBB : *MF) {459for (MachineInstr &MI : MBB) {460for (MachineOperand &Op : MI.operands()) {461if (!Op.isReg())462continue;463Register Reg = Op.getReg();464if (MAI.hasRegisterAlias(MF, Reg))465continue;466Register NewReg = Register::index2VirtReg(MAI.getNextID());467MAI.setRegisterAlias(MF, Reg, NewReg);468}469if (MI.getOpcode() != SPIRV::OpExtInst)470continue;471auto Set = MI.getOperand(2).getImm();472if (!MAI.ExtInstSetMap.contains(Set))473MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());474}475}476}477}478479// RequirementHandler implementations.480void SPIRV::RequirementHandler::getAndAddRequirements(481SPIRV::OperandCategory::OperandCategory Category, uint32_t i,482const SPIRVSubtarget &ST) {483addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));484}485486void SPIRV::RequirementHandler::recursiveAddCapabilities(487const CapabilityList &ToPrune) {488for (const auto &Cap : ToPrune) {489AllCaps.insert(Cap);490CapabilityList ImplicitDecls =491getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);492recursiveAddCapabilities(ImplicitDecls);493}494}495496void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {497for (const auto &Cap : ToAdd) {498bool IsNewlyInserted = AllCaps.insert(Cap).second;499if (!IsNewlyInserted) // Don't re-add if it's already been declared.500continue;501CapabilityList ImplicitDecls =502getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);503recursiveAddCapabilities(ImplicitDecls);504MinimalCaps.push_back(Cap);505}506}507508void SPIRV::RequirementHandler::addRequirements(509const SPIRV::Requirements &Req) {510if (!Req.IsSatisfiable)511report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");512513if (Req.Cap.has_value())514addCapabilities({Req.Cap.value()});515516addExtensions(Req.Exts);517518if (!Req.MinVer.empty()) {519if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {520LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer521<< " and <= " << MaxVersion << "\n");522report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");523}524525if (MinVersion.empty() || Req.MinVer > MinVersion)526MinVersion = Req.MinVer;527}528529if (!Req.MaxVer.empty()) {530if (!MinVersion.empty() && Req.MaxVer < MinVersion) {531LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer532<< " and >= " << MinVersion << "\n");533report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");534}535536if (MaxVersion.empty() || Req.MaxVer < MaxVersion)537MaxVersion = Req.MaxVer;538}539}540541void SPIRV::RequirementHandler::checkSatisfiable(542const SPIRVSubtarget &ST) const {543// Report as many errors as possible before aborting the compilation.544bool IsSatisfiable = true;545auto TargetVer = ST.getSPIRVVersion();546547if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {548LLVM_DEBUG(549dbgs() << "Target SPIR-V version too high for required features\n"550<< "Required max version: " << MaxVersion << " target version "551<< TargetVer << "\n");552IsSatisfiable = false;553}554555if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {556LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"557<< "Required min version: " << MinVersion558<< " target version " << TargetVer << "\n");559IsSatisfiable = false;560}561562if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {563LLVM_DEBUG(564dbgs()565<< "Version is too low for some features and too high for others.\n"566<< "Required SPIR-V min version: " << MinVersion567<< " required SPIR-V max version " << MaxVersion << "\n");568IsSatisfiable = false;569}570571for (auto Cap : MinimalCaps) {572if (AvailableCaps.contains(Cap))573continue;574LLVM_DEBUG(dbgs() << "Capability not supported: "575<< getSymbolicOperandMnemonic(576OperandCategory::CapabilityOperand, Cap)577<< "\n");578IsSatisfiable = false;579}580581for (auto Ext : AllExtensions) {582if (ST.canUseExtension(Ext))583continue;584LLVM_DEBUG(dbgs() << "Extension not supported: "585<< getSymbolicOperandMnemonic(586OperandCategory::ExtensionOperand, Ext)587<< "\n");588IsSatisfiable = false;589}590591if (!IsSatisfiable)592report_fatal_error("Unable to meet SPIR-V requirements for this target.");593}594595// Add the given capabilities and all their implicitly defined capabilities too.596void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {597for (const auto Cap : ToAdd)598if (AvailableCaps.insert(Cap).second)599addAvailableCaps(getSymbolicOperandCapabilities(600SPIRV::OperandCategory::CapabilityOperand, Cap));601}602603void SPIRV::RequirementHandler::removeCapabilityIf(604const Capability::Capability ToRemove,605const Capability::Capability IfPresent) {606if (AllCaps.contains(IfPresent))607AllCaps.erase(ToRemove);608}609610namespace llvm {611namespace SPIRV {612void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {613if (ST.isOpenCLEnv()) {614initAvailableCapabilitiesForOpenCL(ST);615return;616}617618if (ST.isVulkanEnv()) {619initAvailableCapabilitiesForVulkan(ST);620return;621}622623report_fatal_error("Unimplemented environment for SPIR-V generation.");624}625626void RequirementHandler::initAvailableCapabilitiesForOpenCL(627const SPIRVSubtarget &ST) {628// Add the min requirements for different OpenCL and SPIR-V versions.629addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,630Capability::Int16, Capability::Int8, Capability::Kernel,631Capability::Linkage, Capability::Vector16,632Capability::Groups, Capability::GenericPointer,633Capability::Shader});634if (ST.hasOpenCLFullProfile())635addAvailableCaps({Capability::Int64, Capability::Int64Atomics});636if (ST.hasOpenCLImageSupport()) {637addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,638Capability::Image1D, Capability::SampledBuffer,639Capability::ImageBuffer});640if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))641addAvailableCaps({Capability::ImageReadWrite});642}643if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&644ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))645addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});646if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))647addAvailableCaps({Capability::GroupNonUniform,648Capability::GroupNonUniformVote,649Capability::GroupNonUniformArithmetic,650Capability::GroupNonUniformBallot,651Capability::GroupNonUniformClustered,652Capability::GroupNonUniformShuffle,653Capability::GroupNonUniformShuffleRelative});654if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))655addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,656Capability::SignedZeroInfNanPreserve,657Capability::RoundingModeRTE,658Capability::RoundingModeRTZ});659// TODO: verify if this needs some checks.660addAvailableCaps({Capability::Float16, Capability::Float64});661662// Add capabilities enabled by extensions.663for (auto Extension : ST.getAllAvailableExtensions()) {664CapabilityList EnabledCapabilities =665getCapabilitiesEnabledByExtension(Extension);666addAvailableCaps(EnabledCapabilities);667}668669// TODO: add OpenCL extensions.670}671672void RequirementHandler::initAvailableCapabilitiesForVulkan(673const SPIRVSubtarget &ST) {674addAvailableCaps({Capability::Shader, Capability::Linkage});675676// Provided by all supported Vulkan versions.677addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,678Capability::Float64, Capability::GroupNonUniform});679}680681} // namespace SPIRV682} // namespace llvm683684// Add the required capabilities from a decoration instruction (including685// BuiltIns).686static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,687SPIRV::RequirementHandler &Reqs,688const SPIRVSubtarget &ST) {689int64_t DecOp = MI.getOperand(DecIndex).getImm();690auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);691Reqs.addRequirements(getSymbolicOperandRequirements(692SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));693694if (Dec == SPIRV::Decoration::BuiltIn) {695int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();696auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);697Reqs.addRequirements(getSymbolicOperandRequirements(698SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));699} else if (Dec == SPIRV::Decoration::LinkageAttributes) {700int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();701SPIRV::LinkageType::LinkageType LnkType =702static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);703if (LnkType == SPIRV::LinkageType::LinkOnceODR)704Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);705} else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||706Dec == SPIRV::Decoration::CacheControlStoreINTEL) {707Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls);708} else if (Dec == SPIRV::Decoration::HostAccessINTEL) {709Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access);710} else if (Dec == SPIRV::Decoration::InitModeINTEL ||711Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {712Reqs.addExtension(713SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);714}715}716717// Add requirements for image handling.718static void addOpTypeImageReqs(const MachineInstr &MI,719SPIRV::RequirementHandler &Reqs,720const SPIRVSubtarget &ST) {721assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");722// The operand indices used here are based on the OpTypeImage layout, which723// the MachineInstr follows as well.724int64_t ImgFormatOp = MI.getOperand(7).getImm();725auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);726Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,727ImgFormat, ST);728729bool IsArrayed = MI.getOperand(4).getImm() == 1;730bool IsMultisampled = MI.getOperand(5).getImm() == 1;731bool NoSampler = MI.getOperand(6).getImm() == 2;732// Add dimension requirements.733assert(MI.getOperand(2).isImm());734switch (MI.getOperand(2).getImm()) {735case SPIRV::Dim::DIM_1D:736Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D737: SPIRV::Capability::Sampled1D);738break;739case SPIRV::Dim::DIM_2D:740if (IsMultisampled && NoSampler)741Reqs.addRequirements(SPIRV::Capability::ImageMSArray);742break;743case SPIRV::Dim::DIM_Cube:744Reqs.addRequirements(SPIRV::Capability::Shader);745if (IsArrayed)746Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray747: SPIRV::Capability::SampledCubeArray);748break;749case SPIRV::Dim::DIM_Rect:750Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect751: SPIRV::Capability::SampledRect);752break;753case SPIRV::Dim::DIM_Buffer:754Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer755: SPIRV::Capability::SampledBuffer);756break;757case SPIRV::Dim::DIM_SubpassData:758Reqs.addRequirements(SPIRV::Capability::InputAttachment);759break;760}761762// Has optional access qualifier.763// TODO: check if it's OpenCL's kernel.764if (MI.getNumOperands() > 8 &&765MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)766Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);767else768Reqs.addRequirements(SPIRV::Capability::ImageBasic);769}770771// Add requirements for handling atomic float instructions772#define ATOM_FLT_REQ_EXT_MSG(ExtName) \773"The atomic float instruction requires the following SPIR-V " \774"extension: SPV_EXT_shader_atomic_float" ExtName775static void AddAtomicFloatRequirements(const MachineInstr &MI,776SPIRV::RequirementHandler &Reqs,777const SPIRVSubtarget &ST) {778assert(MI.getOperand(1).isReg() &&779"Expect register operand in atomic float instruction");780Register TypeReg = MI.getOperand(1).getReg();781SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);782if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)783report_fatal_error("Result type of an atomic float instruction must be a "784"floating-point type scalar");785786unsigned BitWidth = TypeDef->getOperand(1).getImm();787unsigned Op = MI.getOpcode();788if (Op == SPIRV::OpAtomicFAddEXT) {789if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))790report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_add"), false);791Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);792switch (BitWidth) {793case 16:794if (!ST.canUseExtension(795SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))796report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);797Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);798Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);799break;800case 32:801Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);802break;803case 64:804Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);805break;806default:807report_fatal_error(808"Unexpected floating-point type width in atomic float instruction");809}810} else {811if (!ST.canUseExtension(812SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))813report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);814Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);815switch (BitWidth) {816case 16:817Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);818break;819case 32:820Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);821break;822case 64:823Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);824break;825default:826report_fatal_error(827"Unexpected floating-point type width in atomic float instruction");828}829}830}831832void addInstrRequirements(const MachineInstr &MI,833SPIRV::RequirementHandler &Reqs,834const SPIRVSubtarget &ST) {835switch (MI.getOpcode()) {836case SPIRV::OpMemoryModel: {837int64_t Addr = MI.getOperand(0).getImm();838Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,839Addr, ST);840int64_t Mem = MI.getOperand(1).getImm();841Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,842ST);843break;844}845case SPIRV::OpEntryPoint: {846int64_t Exe = MI.getOperand(0).getImm();847Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,848Exe, ST);849break;850}851case SPIRV::OpExecutionMode:852case SPIRV::OpExecutionModeId: {853int64_t Exe = MI.getOperand(1).getImm();854Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,855Exe, ST);856break;857}858case SPIRV::OpTypeMatrix:859Reqs.addCapability(SPIRV::Capability::Matrix);860break;861case SPIRV::OpTypeInt: {862unsigned BitWidth = MI.getOperand(1).getImm();863if (BitWidth == 64)864Reqs.addCapability(SPIRV::Capability::Int64);865else if (BitWidth == 16)866Reqs.addCapability(SPIRV::Capability::Int16);867else if (BitWidth == 8)868Reqs.addCapability(SPIRV::Capability::Int8);869break;870}871case SPIRV::OpTypeFloat: {872unsigned BitWidth = MI.getOperand(1).getImm();873if (BitWidth == 64)874Reqs.addCapability(SPIRV::Capability::Float64);875else if (BitWidth == 16)876Reqs.addCapability(SPIRV::Capability::Float16);877break;878}879case SPIRV::OpTypeVector: {880unsigned NumComponents = MI.getOperand(2).getImm();881if (NumComponents == 8 || NumComponents == 16)882Reqs.addCapability(SPIRV::Capability::Vector16);883break;884}885case SPIRV::OpTypePointer: {886auto SC = MI.getOperand(1).getImm();887Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,888ST);889// If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer890// capability.891if (!ST.isOpenCLEnv())892break;893assert(MI.getOperand(2).isReg());894const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();895SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());896if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&897TypeDef->getOperand(1).getImm() == 16)898Reqs.addCapability(SPIRV::Capability::Float16Buffer);899break;900}901case SPIRV::OpBitReverse:902case SPIRV::OpBitFieldInsert:903case SPIRV::OpBitFieldSExtract:904case SPIRV::OpBitFieldUExtract:905if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {906Reqs.addCapability(SPIRV::Capability::Shader);907break;908}909Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);910Reqs.addCapability(SPIRV::Capability::BitInstructions);911break;912case SPIRV::OpTypeRuntimeArray:913Reqs.addCapability(SPIRV::Capability::Shader);914break;915case SPIRV::OpTypeOpaque:916case SPIRV::OpTypeEvent:917Reqs.addCapability(SPIRV::Capability::Kernel);918break;919case SPIRV::OpTypePipe:920case SPIRV::OpTypeReserveId:921Reqs.addCapability(SPIRV::Capability::Pipes);922break;923case SPIRV::OpTypeDeviceEvent:924case SPIRV::OpTypeQueue:925case SPIRV::OpBuildNDRange:926Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);927break;928case SPIRV::OpDecorate:929case SPIRV::OpDecorateId:930case SPIRV::OpDecorateString:931addOpDecorateReqs(MI, 1, Reqs, ST);932break;933case SPIRV::OpMemberDecorate:934case SPIRV::OpMemberDecorateString:935addOpDecorateReqs(MI, 2, Reqs, ST);936break;937case SPIRV::OpInBoundsPtrAccessChain:938Reqs.addCapability(SPIRV::Capability::Addresses);939break;940case SPIRV::OpConstantSampler:941Reqs.addCapability(SPIRV::Capability::LiteralSampler);942break;943case SPIRV::OpTypeImage:944addOpTypeImageReqs(MI, Reqs, ST);945break;946case SPIRV::OpTypeSampler:947Reqs.addCapability(SPIRV::Capability::ImageBasic);948break;949case SPIRV::OpTypeForwardPointer:950// TODO: check if it's OpenCL's kernel.951Reqs.addCapability(SPIRV::Capability::Addresses);952break;953case SPIRV::OpAtomicFlagTestAndSet:954case SPIRV::OpAtomicLoad:955case SPIRV::OpAtomicStore:956case SPIRV::OpAtomicExchange:957case SPIRV::OpAtomicCompareExchange:958case SPIRV::OpAtomicIIncrement:959case SPIRV::OpAtomicIDecrement:960case SPIRV::OpAtomicIAdd:961case SPIRV::OpAtomicISub:962case SPIRV::OpAtomicUMin:963case SPIRV::OpAtomicUMax:964case SPIRV::OpAtomicSMin:965case SPIRV::OpAtomicSMax:966case SPIRV::OpAtomicAnd:967case SPIRV::OpAtomicOr:968case SPIRV::OpAtomicXor: {969const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();970const MachineInstr *InstrPtr = &MI;971if (MI.getOpcode() == SPIRV::OpAtomicStore) {972assert(MI.getOperand(3).isReg());973InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());974assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");975}976assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");977Register TypeReg = InstrPtr->getOperand(1).getReg();978SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);979if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {980unsigned BitWidth = TypeDef->getOperand(1).getImm();981if (BitWidth == 64)982Reqs.addCapability(SPIRV::Capability::Int64Atomics);983}984break;985}986case SPIRV::OpGroupNonUniformIAdd:987case SPIRV::OpGroupNonUniformFAdd:988case SPIRV::OpGroupNonUniformIMul:989case SPIRV::OpGroupNonUniformFMul:990case SPIRV::OpGroupNonUniformSMin:991case SPIRV::OpGroupNonUniformUMin:992case SPIRV::OpGroupNonUniformFMin:993case SPIRV::OpGroupNonUniformSMax:994case SPIRV::OpGroupNonUniformUMax:995case SPIRV::OpGroupNonUniformFMax:996case SPIRV::OpGroupNonUniformBitwiseAnd:997case SPIRV::OpGroupNonUniformBitwiseOr:998case SPIRV::OpGroupNonUniformBitwiseXor:999case SPIRV::OpGroupNonUniformLogicalAnd:1000case SPIRV::OpGroupNonUniformLogicalOr:1001case SPIRV::OpGroupNonUniformLogicalXor: {1002assert(MI.getOperand(3).isImm());1003int64_t GroupOp = MI.getOperand(3).getImm();1004switch (GroupOp) {1005case SPIRV::GroupOperation::Reduce:1006case SPIRV::GroupOperation::InclusiveScan:1007case SPIRV::GroupOperation::ExclusiveScan:1008Reqs.addCapability(SPIRV::Capability::Kernel);1009Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);1010Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);1011break;1012case SPIRV::GroupOperation::ClusteredReduce:1013Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);1014break;1015case SPIRV::GroupOperation::PartitionedReduceNV:1016case SPIRV::GroupOperation::PartitionedInclusiveScanNV:1017case SPIRV::GroupOperation::PartitionedExclusiveScanNV:1018Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);1019break;1020}1021break;1022}1023case SPIRV::OpGroupNonUniformShuffle:1024case SPIRV::OpGroupNonUniformShuffleXor:1025Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);1026break;1027case SPIRV::OpGroupNonUniformShuffleUp:1028case SPIRV::OpGroupNonUniformShuffleDown:1029Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);1030break;1031case SPIRV::OpGroupAll:1032case SPIRV::OpGroupAny:1033case SPIRV::OpGroupBroadcast:1034case SPIRV::OpGroupIAdd:1035case SPIRV::OpGroupFAdd:1036case SPIRV::OpGroupFMin:1037case SPIRV::OpGroupUMin:1038case SPIRV::OpGroupSMin:1039case SPIRV::OpGroupFMax:1040case SPIRV::OpGroupUMax:1041case SPIRV::OpGroupSMax:1042Reqs.addCapability(SPIRV::Capability::Groups);1043break;1044case SPIRV::OpGroupNonUniformElect:1045Reqs.addCapability(SPIRV::Capability::GroupNonUniform);1046break;1047case SPIRV::OpGroupNonUniformAll:1048case SPIRV::OpGroupNonUniformAny:1049case SPIRV::OpGroupNonUniformAllEqual:1050Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);1051break;1052case SPIRV::OpGroupNonUniformBroadcast:1053case SPIRV::OpGroupNonUniformBroadcastFirst:1054case SPIRV::OpGroupNonUniformBallot:1055case SPIRV::OpGroupNonUniformInverseBallot:1056case SPIRV::OpGroupNonUniformBallotBitExtract:1057case SPIRV::OpGroupNonUniformBallotBitCount:1058case SPIRV::OpGroupNonUniformBallotFindLSB:1059case SPIRV::OpGroupNonUniformBallotFindMSB:1060Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);1061break;1062case SPIRV::OpSubgroupShuffleINTEL:1063case SPIRV::OpSubgroupShuffleDownINTEL:1064case SPIRV::OpSubgroupShuffleUpINTEL:1065case SPIRV::OpSubgroupShuffleXorINTEL:1066if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {1067Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);1068Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);1069}1070break;1071case SPIRV::OpSubgroupBlockReadINTEL:1072case SPIRV::OpSubgroupBlockWriteINTEL:1073if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {1074Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);1075Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);1076}1077break;1078case SPIRV::OpSubgroupImageBlockReadINTEL:1079case SPIRV::OpSubgroupImageBlockWriteINTEL:1080if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {1081Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);1082Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);1083}1084break;1085case SPIRV::OpAssumeTrueKHR:1086case SPIRV::OpExpectKHR:1087if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {1088Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);1089Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);1090}1091break;1092case SPIRV::OpPtrCastToCrossWorkgroupINTEL:1093case SPIRV::OpCrossWorkgroupCastToPtrINTEL:1094if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {1095Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);1096Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);1097}1098break;1099case SPIRV::OpConstantFunctionPointerINTEL:1100if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {1101Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);1102Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);1103}1104break;1105case SPIRV::OpGroupNonUniformRotateKHR:1106if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))1107report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "1108"following SPIR-V extension: SPV_KHR_subgroup_rotate",1109false);1110Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);1111Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);1112Reqs.addCapability(SPIRV::Capability::GroupNonUniform);1113break;1114case SPIRV::OpGroupIMulKHR:1115case SPIRV::OpGroupFMulKHR:1116case SPIRV::OpGroupBitwiseAndKHR:1117case SPIRV::OpGroupBitwiseOrKHR:1118case SPIRV::OpGroupBitwiseXorKHR:1119case SPIRV::OpGroupLogicalAndKHR:1120case SPIRV::OpGroupLogicalOrKHR:1121case SPIRV::OpGroupLogicalXorKHR:1122if (ST.canUseExtension(1123SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {1124Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);1125Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);1126}1127break;1128case SPIRV::OpReadClockKHR:1129if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))1130report_fatal_error("OpReadClockKHR instruction requires the "1131"following SPIR-V extension: SPV_KHR_shader_clock",1132false);1133Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);1134Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);1135break;1136case SPIRV::OpFunctionPointerCallINTEL:1137if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {1138Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);1139Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);1140}1141break;1142case SPIRV::OpAtomicFAddEXT:1143case SPIRV::OpAtomicFMinEXT:1144case SPIRV::OpAtomicFMaxEXT:1145AddAtomicFloatRequirements(MI, Reqs, ST);1146break;1147case SPIRV::OpConvertBF16ToFINTEL:1148case SPIRV::OpConvertFToBF16INTEL:1149if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {1150Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);1151Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);1152}1153break;1154case SPIRV::OpVariableLengthArrayINTEL:1155case SPIRV::OpSaveMemoryINTEL:1156case SPIRV::OpRestoreMemoryINTEL:1157if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {1158Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);1159Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);1160}1161break;1162case SPIRV::OpAsmTargetINTEL:1163case SPIRV::OpAsmINTEL:1164case SPIRV::OpAsmCallINTEL:1165if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {1166Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);1167Reqs.addCapability(SPIRV::Capability::AsmINTEL);1168}1169break;1170case SPIRV::OpTypeCooperativeMatrixKHR:1171if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))1172report_fatal_error(1173"OpTypeCooperativeMatrixKHR type requires the "1174"following SPIR-V extension: SPV_KHR_cooperative_matrix",1175false);1176Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);1177Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);1178break;1179default:1180break;1181}11821183// If we require capability Shader, then we can remove the requirement for1184// the BitInstructions capability, since Shader is a superset capability1185// of BitInstructions.1186Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,1187SPIRV::Capability::Shader);1188}11891190static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,1191MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {1192// Collect requirements for existing instructions.1193for (auto F = M.begin(), E = M.end(); F != E; ++F) {1194MachineFunction *MF = MMI->getMachineFunction(*F);1195if (!MF)1196continue;1197for (const MachineBasicBlock &MBB : *MF)1198for (const MachineInstr &MI : MBB)1199addInstrRequirements(MI, MAI.Reqs, ST);1200}1201// Collect requirements for OpExecutionMode instructions.1202auto Node = M.getNamedMetadata("spirv.ExecutionMode");1203if (Node) {1204// SPV_KHR_float_controls is not available until v1.41205bool RequireFloatControls = false,1206VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));1207for (unsigned i = 0; i < Node->getNumOperands(); i++) {1208MDNode *MDN = cast<MDNode>(Node->getOperand(i));1209const MDOperand &MDOp = MDN->getOperand(1);1210if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {1211Constant *C = CMeta->getValue();1212if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {1213auto EM = Const->getZExtValue();1214MAI.Reqs.getAndAddRequirements(1215SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);1216// add SPV_KHR_float_controls if the version is too low1217switch (EM) {1218case SPIRV::ExecutionMode::DenormPreserve:1219case SPIRV::ExecutionMode::DenormFlushToZero:1220case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:1221case SPIRV::ExecutionMode::RoundingModeRTE:1222case SPIRV::ExecutionMode::RoundingModeRTZ:1223RequireFloatControls = VerLower14;1224break;1225}1226}1227}1228}1229if (RequireFloatControls &&1230ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))1231MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);1232}1233for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {1234const Function &F = *FI;1235if (F.isDeclaration())1236continue;1237if (F.getMetadata("reqd_work_group_size"))1238MAI.Reqs.getAndAddRequirements(1239SPIRV::OperandCategory::ExecutionModeOperand,1240SPIRV::ExecutionMode::LocalSize, ST);1241if (F.getFnAttribute("hlsl.numthreads").isValid()) {1242MAI.Reqs.getAndAddRequirements(1243SPIRV::OperandCategory::ExecutionModeOperand,1244SPIRV::ExecutionMode::LocalSize, ST);1245}1246if (F.getMetadata("work_group_size_hint"))1247MAI.Reqs.getAndAddRequirements(1248SPIRV::OperandCategory::ExecutionModeOperand,1249SPIRV::ExecutionMode::LocalSizeHint, ST);1250if (F.getMetadata("intel_reqd_sub_group_size"))1251MAI.Reqs.getAndAddRequirements(1252SPIRV::OperandCategory::ExecutionModeOperand,1253SPIRV::ExecutionMode::SubgroupSize, ST);1254if (F.getMetadata("vec_type_hint"))1255MAI.Reqs.getAndAddRequirements(1256SPIRV::OperandCategory::ExecutionModeOperand,1257SPIRV::ExecutionMode::VecTypeHint, ST);12581259if (F.hasOptNone() &&1260ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {1261// Output OpCapability OptNoneINTEL.1262MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);1263MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);1264}1265}1266}12671268static unsigned getFastMathFlags(const MachineInstr &I) {1269unsigned Flags = SPIRV::FPFastMathMode::None;1270if (I.getFlag(MachineInstr::MIFlag::FmNoNans))1271Flags |= SPIRV::FPFastMathMode::NotNaN;1272if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))1273Flags |= SPIRV::FPFastMathMode::NotInf;1274if (I.getFlag(MachineInstr::MIFlag::FmNsz))1275Flags |= SPIRV::FPFastMathMode::NSZ;1276if (I.getFlag(MachineInstr::MIFlag::FmArcp))1277Flags |= SPIRV::FPFastMathMode::AllowRecip;1278if (I.getFlag(MachineInstr::MIFlag::FmReassoc))1279Flags |= SPIRV::FPFastMathMode::Fast;1280return Flags;1281}12821283static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,1284const SPIRVInstrInfo &TII,1285SPIRV::RequirementHandler &Reqs) {1286if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&1287getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,1288SPIRV::Decoration::NoSignedWrap, ST, Reqs)1289.IsSatisfiable) {1290buildOpDecorate(I.getOperand(0).getReg(), I, TII,1291SPIRV::Decoration::NoSignedWrap, {});1292}1293if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&1294getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,1295SPIRV::Decoration::NoUnsignedWrap, ST,1296Reqs)1297.IsSatisfiable) {1298buildOpDecorate(I.getOperand(0).getReg(), I, TII,1299SPIRV::Decoration::NoUnsignedWrap, {});1300}1301if (!TII.canUseFastMathFlags(I))1302return;1303unsigned FMFlags = getFastMathFlags(I);1304if (FMFlags == SPIRV::FPFastMathMode::None)1305return;1306Register DstReg = I.getOperand(0).getReg();1307buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});1308}13091310// Walk all functions and add decorations related to MI flags.1311static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,1312MachineModuleInfo *MMI, const SPIRVSubtarget &ST,1313SPIRV::ModuleAnalysisInfo &MAI) {1314for (auto F = M.begin(), E = M.end(); F != E; ++F) {1315MachineFunction *MF = MMI->getMachineFunction(*F);1316if (!MF)1317continue;1318for (auto &MBB : *MF)1319for (auto &MI : MBB)1320handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);1321}1322}13231324struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;13251326void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {1327AU.addRequired<TargetPassConfig>();1328AU.addRequired<MachineModuleInfoWrapperPass>();1329}13301331bool SPIRVModuleAnalysis::runOnModule(Module &M) {1332SPIRVTargetMachine &TM =1333getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();1334ST = TM.getSubtargetImpl();1335GR = ST->getSPIRVGlobalRegistry();1336TII = ST->getInstrInfo();13371338MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();13391340setBaseInfo(M);13411342addDecorations(M, *TII, MMI, *ST, MAI);13431344collectReqs(M, MAI, MMI, *ST);13451346// Process type/const/global var/func decl instructions, number their1347// destination registers from 0 to N, collect Extensions and Capabilities.1348processDefInstrs(M);13491350// Number rest of registers from N+1 onwards.1351numberRegistersGlobally(M);13521353// Update references to OpFunction instructions to use Global Registers1354if (GR->hasConstFunPtr())1355collectFuncPtrs();13561357// Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.1358processOtherInstrs(M);13591360// If there are no entry points, we need the Linkage capability.1361if (MAI.MS[SPIRV::MB_EntryPoints].empty())1362MAI.Reqs.addCapability(SPIRV::Capability::Linkage);13631364// Set maximum ID used.1365GR->setBound(MAI.MaxID);13661367return false;1368}136913701371