Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
35268 views
//===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This file contains the implementation of the SPIRVGlobalRegistry class,9// which is used to maintain rich type information required for SPIR-V even10// after lowering from LLVM IR to GMIR. It can convert an llvm::Type into11// an OpTypeXXX instruction, and map it to a virtual register. Also it builds12// and supports consistency of constants and global variables.13//14//===----------------------------------------------------------------------===//1516#include "SPIRVGlobalRegistry.h"17#include "SPIRV.h"18#include "SPIRVBuiltins.h"19#include "SPIRVSubtarget.h"20#include "SPIRVTargetMachine.h"21#include "SPIRVUtils.h"22#include "llvm/ADT/APInt.h"23#include "llvm/IR/Constants.h"24#include "llvm/IR/Type.h"25#include "llvm/Support/Casting.h"26#include <cassert>2728using namespace llvm;29SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)30: PointerSize(PointerSize), Bound(0) {}3132SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,33Register VReg,34MachineInstr &I,35const SPIRVInstrInfo &TII) {36SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);37assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);38return SpirvType;39}4041SPIRVType *42SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth, Register VReg,43MachineInstr &I,44const SPIRVInstrInfo &TII) {45SPIRVType *SpirvType = getOrCreateSPIRVFloatType(BitWidth, I, TII);46assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);47return SpirvType;48}4950SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(51SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,52const SPIRVInstrInfo &TII) 
{53SPIRVType *SpirvType =54getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);55assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);56return SpirvType;57}5859SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(60const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,61SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {62SPIRVType *SpirvType =63getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);64assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());65return SpirvType;66}6768void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,69Register VReg,70MachineFunction &MF) {71VRegToTypeMap[&MF][VReg] = SpirvType;72}7374static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {75auto &MRI = MIRBuilder.getMF().getRegInfo();76auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));77MRI.setRegClass(Res, &SPIRV::TYPERegClass);78return Res;79}8081static Register createTypeVReg(MachineRegisterInfo &MRI) {82auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));83MRI.setRegClass(Res, &SPIRV::TYPERegClass);84return Res;85}8687SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {88return MIRBuilder.buildInstr(SPIRV::OpTypeBool)89.addDef(createTypeVReg(MIRBuilder));90}9192unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {93if (Width > 64)94report_fatal_error("Unsupported integer width!");95const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());96if (ST.canUseExtension(97SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))98return Width;99if (Width <= 8)100Width = 8;101else if (Width <= 16)102Width = 16;103else if (Width <= 32)104Width = 32;105else106Width = 64;107return Width;108}109110SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,111MachineIRBuilder &MIRBuilder,112bool IsSigned) {113Width = adjustOpTypeIntWidth(Width);114const SPIRVSubtarget &ST =115cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());116if 
(ST.canUseExtension(117SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {118MIRBuilder.buildInstr(SPIRV::OpExtension)119.addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);120MIRBuilder.buildInstr(SPIRV::OpCapability)121.addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);122}123auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)124.addDef(createTypeVReg(MIRBuilder))125.addImm(Width)126.addImm(IsSigned ? 1 : 0);127return MIB;128}129130SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,131MachineIRBuilder &MIRBuilder) {132auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)133.addDef(createTypeVReg(MIRBuilder))134.addImm(Width);135return MIB;136}137138SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {139return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)140.addDef(createTypeVReg(MIRBuilder));141}142143SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,144SPIRVType *ElemType,145MachineIRBuilder &MIRBuilder) {146auto EleOpc = ElemType->getOpcode();147(void)EleOpc;148assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||149EleOpc == SPIRV::OpTypeBool) &&150"Invalid vector element type");151152auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)153.addDef(createTypeVReg(MIRBuilder))154.addUse(getSPIRVTypeID(ElemType))155.addImm(NumElems);156return MIB;157}158159std::tuple<Register, ConstantInt *, bool>160SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,161MachineIRBuilder *MIRBuilder,162MachineInstr *I,163const SPIRVInstrInfo *TII) {164const IntegerType *LLVMIntTy;165if (SpvType)166LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));167else168LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());169bool NewInstr = false;170// Find a constant in DT or build a new one.171ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);172Register Res = DT.find(CI, CurMF);173if (!Res.isValid()) {174unsigned 
BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;175// TODO: handle cases where the type is not 32bit wide176// TODO: https://github.com/llvm/llvm-project/issues/88129177LLT LLTy = LLT::scalar(32);178Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);179CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);180if (MIRBuilder)181assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);182else183assignIntTypeToVReg(BitWidth, Res, *I, *TII);184DT.add(CI, CurMF, Res);185NewInstr = true;186}187return std::make_tuple(Res, CI, NewInstr);188}189190std::tuple<Register, ConstantFP *, bool, unsigned>191SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,192MachineIRBuilder *MIRBuilder,193MachineInstr *I,194const SPIRVInstrInfo *TII) {195const Type *LLVMFloatTy;196LLVMContext &Ctx = CurMF->getFunction().getContext();197unsigned BitWidth = 32;198if (SpvType)199LLVMFloatTy = getTypeForSPIRVType(SpvType);200else {201LLVMFloatTy = Type::getFloatTy(Ctx);202if (MIRBuilder)203SpvType = getOrCreateSPIRVType(LLVMFloatTy, *MIRBuilder);204}205bool NewInstr = false;206// Find a constant in DT or build a new one.207auto *const CI = ConstantFP::get(Ctx, Val);208Register Res = DT.find(CI, CurMF);209if (!Res.isValid()) {210if (SpvType)211BitWidth = getScalarOrVectorBitWidth(SpvType);212// TODO: handle cases where the type is not 32bit wide213// TODO: https://github.com/llvm/llvm-project/issues/88129214LLT LLTy = LLT::scalar(32);215Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);216CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);217if (MIRBuilder)218assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);219else220assignFloatTypeToVReg(BitWidth, Res, *I, *TII);221DT.add(CI, CurMF, Res);222NewInstr = true;223}224return std::make_tuple(Res, CI, NewInstr, BitWidth);225}226227Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,228SPIRVType *SpvType,229const SPIRVInstrInfo &TII,230bool ZeroAsNull) 
{231assert(SpvType);232ConstantFP *CI;233Register Res;234bool New;235unsigned BitWidth;236std::tie(Res, CI, New, BitWidth) =237getOrCreateConstFloatReg(Val, SpvType, nullptr, &I, &TII);238// If we have found Res register which is defined by the passed G_CONSTANT239// machine instruction, a new constant instruction should be created.240if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))241return Res;242MachineInstrBuilder MIB;243MachineBasicBlock &BB = *I.getParent();244// In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)245if (Val.isPosZero() && ZeroAsNull) {246MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))247.addDef(Res)248.addUse(getSPIRVTypeID(SpvType));249} else {250MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantF))251.addDef(Res)252.addUse(getSPIRVTypeID(SpvType));253addNumImm(254APInt(BitWidth, CI->getValueAPF().bitcastToAPInt().getZExtValue()),255MIB);256}257const auto &ST = CurMF->getSubtarget();258constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),259*ST.getRegisterInfo(), *ST.getRegBankInfo());260return Res;261}262263Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,264SPIRVType *SpvType,265const SPIRVInstrInfo &TII,266bool ZeroAsNull) {267assert(SpvType);268ConstantInt *CI;269Register Res;270bool New;271std::tie(Res, CI, New) =272getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);273// If we have found Res register which is defined by the passed G_CONSTANT274// machine instruction, a new constant instruction should be created.275if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))276return Res;277MachineInstrBuilder MIB;278MachineBasicBlock &BB = *I.getParent();279if (Val || !ZeroAsNull) {280MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))281.addDef(Res)282.addUse(getSPIRVTypeID(SpvType));283addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);284} else {285MIB = BuildMI(BB, I, 
I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))286.addDef(Res)287.addUse(getSPIRVTypeID(SpvType));288}289const auto &ST = CurMF->getSubtarget();290constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),291*ST.getRegisterInfo(), *ST.getRegBankInfo());292return Res;293}294295Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,296MachineIRBuilder &MIRBuilder,297SPIRVType *SpvType,298bool EmitIR) {299auto &MF = MIRBuilder.getMF();300const IntegerType *LLVMIntTy;301if (SpvType)302LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));303else304LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());305// Find a constant in DT or build a new one.306const auto ConstInt =307ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);308Register Res = DT.find(ConstInt, &MF);309if (!Res.isValid()) {310unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;311LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);312Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);313MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);314assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,315SPIRV::AccessQualifier::ReadWrite, EmitIR);316DT.add(ConstInt, &MIRBuilder.getMF(), Res);317if (EmitIR) {318MIRBuilder.buildConstant(Res, *ConstInt);319} else {320if (!SpvType)321SpvType = getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);322MachineInstrBuilder MIB;323if (Val) {324MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)325.addDef(Res)326.addUse(getSPIRVTypeID(SpvType));327addNumImm(APInt(BitWidth, Val), MIB);328} else {329MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)330.addDef(Res)331.addUse(getSPIRVTypeID(SpvType));332}333const auto &Subtarget = CurMF->getSubtarget();334constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),335*Subtarget.getRegisterInfo(),336*Subtarget.getRegBankInfo());337}338}339return Res;340}341342Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,343MachineIRBuilder &MIRBuilder,344SPIRVType *SpvType) {345auto &MF = 
MIRBuilder.getMF();346auto &Ctx = MF.getFunction().getContext();347if (!SpvType) {348const Type *LLVMFPTy = Type::getFloatTy(Ctx);349SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder);350}351// Find a constant in DT or build a new one.352const auto ConstFP = ConstantFP::get(Ctx, Val);353Register Res = DT.find(ConstFP, &MF);354if (!Res.isValid()) {355Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32));356MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);357assignSPIRVTypeToVReg(SpvType, Res, MF);358DT.add(ConstFP, &MF, Res);359360MachineInstrBuilder MIB;361MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)362.addDef(Res)363.addUse(getSPIRVTypeID(SpvType));364addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);365}366367return Res;368}369370Register SPIRVGlobalRegistry::getOrCreateBaseRegister(Constant *Val,371MachineInstr &I,372SPIRVType *SpvType,373const SPIRVInstrInfo &TII,374unsigned BitWidth) {375SPIRVType *Type = SpvType;376if (SpvType->getOpcode() == SPIRV::OpTypeVector ||377SpvType->getOpcode() == SPIRV::OpTypeArray) {378auto EleTypeReg = SpvType->getOperand(1).getReg();379Type = getSPIRVTypeForVReg(EleTypeReg);380}381if (Type->getOpcode() == SPIRV::OpTypeFloat) {382SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);383return getOrCreateConstFP(dyn_cast<ConstantFP>(Val)->getValue(), I,384SpvBaseType, TII);385}386assert(Type->getOpcode() == SPIRV::OpTypeInt);387SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);388return getOrCreateConstInt(Val->getUniqueInteger().getSExtValue(), I,389SpvBaseType, TII);390}391392Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(393Constant *Val, MachineInstr &I, SPIRVType *SpvType,394const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,395unsigned ElemCnt, bool ZeroAsNull) {396// Find a constant vector or array in DT or build a new one.397Register Res = DT.find(CA, CurMF);398// If no values are attached, the composite is null constant.399bool IsNull = 
Val->isNullValue() && ZeroAsNull;400if (!Res.isValid()) {401// SpvScalConst should be created before SpvVecConst to avoid undefined ID402// error on validation.403// TODO: can moved below once sorting of types/consts/defs is implemented.404Register SpvScalConst;405if (!IsNull)406SpvScalConst = getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth);407408// TODO: handle cases where the type is not 32bit wide409// TODO: https://github.com/llvm/llvm-project/issues/88129410LLT LLTy = LLT::scalar(32);411Register SpvVecConst =412CurMF->getRegInfo().createGenericVirtualRegister(LLTy);413CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);414assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);415DT.add(CA, CurMF, SpvVecConst);416MachineInstrBuilder MIB;417MachineBasicBlock &BB = *I.getParent();418if (!IsNull) {419MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))420.addDef(SpvVecConst)421.addUse(getSPIRVTypeID(SpvType));422for (unsigned i = 0; i < ElemCnt; ++i)423MIB.addUse(SpvScalConst);424} else {425MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))426.addDef(SpvVecConst)427.addUse(getSPIRVTypeID(SpvType));428}429const auto &Subtarget = CurMF->getSubtarget();430constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),431*Subtarget.getRegisterInfo(),432*Subtarget.getRegBankInfo());433return SpvVecConst;434}435return Res;436}437438Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,439MachineInstr &I,440SPIRVType *SpvType,441const SPIRVInstrInfo &TII,442bool ZeroAsNull) {443const Type *LLVMTy = getTypeForSPIRVType(SpvType);444assert(LLVMTy->isVectorTy());445const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);446Type *LLVMBaseTy = LLVMVecTy->getElementType();447assert(LLVMBaseTy->isIntegerTy());448auto *ConstVal = ConstantInt::get(LLVMBaseTy, Val);449auto *ConstVec =450ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);451unsigned BW = 
getScalarOrVectorBitWidth(SpvType);452return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,453SpvType->getOperand(2).getImm(),454ZeroAsNull);455}456457Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,458MachineInstr &I,459SPIRVType *SpvType,460const SPIRVInstrInfo &TII,461bool ZeroAsNull) {462const Type *LLVMTy = getTypeForSPIRVType(SpvType);463assert(LLVMTy->isVectorTy());464const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);465Type *LLVMBaseTy = LLVMVecTy->getElementType();466assert(LLVMBaseTy->isFloatingPointTy());467auto *ConstVal = ConstantFP::get(LLVMBaseTy, Val);468auto *ConstVec =469ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);470unsigned BW = getScalarOrVectorBitWidth(SpvType);471return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,472SpvType->getOperand(2).getImm(),473ZeroAsNull);474}475476Register SPIRVGlobalRegistry::getOrCreateConstIntArray(477uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType,478const SPIRVInstrInfo &TII) {479const Type *LLVMTy = getTypeForSPIRVType(SpvType);480assert(LLVMTy->isArrayTy());481const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);482Type *LLVMBaseTy = LLVMArrTy->getElementType();483Constant *CI = ConstantInt::get(LLVMBaseTy, Val);484SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());485unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);486// The following is reasonably unique key that is better that [Val]. 
The naive487// alternative would be something along the lines of:488// SmallVector<Constant *> NumCI(Num, CI);489// Constant *UniqueKey =490// ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI);491// that would be a truly unique but dangerous key, because it could lead to492// the creation of constants of arbitrary length (that is, the parameter of493// memset) which were missing in the original module.494Constant *UniqueKey = ConstantStruct::getAnon(495{PoisonValue::get(const_cast<ArrayType *>(LLVMArrTy)),496ConstantInt::get(LLVMBaseTy, Val), ConstantInt::get(LLVMBaseTy, Num)});497return getOrCreateCompositeOrNull(CI, I, SpvType, TII, UniqueKey, BW,498LLVMArrTy->getNumElements());499}500501Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(502uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,503Constant *CA, unsigned BitWidth, unsigned ElemCnt) {504Register Res = DT.find(CA, CurMF);505if (!Res.isValid()) {506Register SpvScalConst;507if (Val || EmitIR) {508SPIRVType *SpvBaseType =509getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);510SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);511}512LLT LLTy = EmitIR ? 
LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);513Register SpvVecConst =514CurMF->getRegInfo().createGenericVirtualRegister(LLTy);515CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);516assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);517DT.add(CA, CurMF, SpvVecConst);518if (EmitIR) {519MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst);520} else {521if (Val) {522auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)523.addDef(SpvVecConst)524.addUse(getSPIRVTypeID(SpvType));525for (unsigned i = 0; i < ElemCnt; ++i)526MIB.addUse(SpvScalConst);527} else {528MIRBuilder.buildInstr(SPIRV::OpConstantNull)529.addDef(SpvVecConst)530.addUse(getSPIRVTypeID(SpvType));531}532}533return SpvVecConst;534}535return Res;536}537538Register539SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,540MachineIRBuilder &MIRBuilder,541SPIRVType *SpvType, bool EmitIR) {542const Type *LLVMTy = getTypeForSPIRVType(SpvType);543assert(LLVMTy->isVectorTy());544const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);545Type *LLVMBaseTy = LLVMVecTy->getElementType();546const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);547auto ConstVec =548ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);549unsigned BW = getScalarOrVectorBitWidth(SpvType);550return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,551ConstVec, BW,552SpvType->getOperand(2).getImm());553}554555Register556SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,557SPIRVType *SpvType) {558const Type *LLVMTy = getTypeForSPIRVType(SpvType);559const TypedPointerType *LLVMPtrTy = cast<TypedPointerType>(LLVMTy);560// Find a constant in DT or build a new one.561Constant *CP = ConstantPointerNull::get(PointerType::get(562LLVMPtrTy->getElementType(), LLVMPtrTy->getAddressSpace()));563Register Res = DT.find(CP, CurMF);564if (!Res.isValid()) {565LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);566Res = 
CurMF->getRegInfo().createGenericVirtualRegister(LLTy);567CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);568assignSPIRVTypeToVReg(SpvType, Res, *CurMF);569MIRBuilder.buildInstr(SPIRV::OpConstantNull)570.addDef(Res)571.addUse(getSPIRVTypeID(SpvType));572DT.add(CP, CurMF, Res);573}574return Res;575}576577Register SPIRVGlobalRegistry::buildConstantSampler(578Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode,579MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {580SPIRVType *SampTy;581if (SpvType)582SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);583else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t",584MIRBuilder)) == nullptr)585report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t");586587auto Sampler =588ResReg.isValid()589? ResReg590: MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);591auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler)592.addDef(Sampler)593.addUse(getSPIRVTypeID(SampTy))594.addImm(AddrMode)595.addImm(Param)596.addImm(FilerMode);597assert(Res->getOperand(0).isReg());598return Res->getOperand(0).getReg();599}600601Register SPIRVGlobalRegistry::buildGlobalVariable(602Register ResVReg, SPIRVType *BaseType, StringRef Name,603const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,604const MachineInstr *Init, bool IsConst, bool HasLinkageTy,605SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,606bool IsInstSelector) {607const GlobalVariable *GVar = nullptr;608if (GV)609GVar = cast<const GlobalVariable>(GV);610else {611// If GV is not passed explicitly, use the name to find or construct612// the global variable.613Module *M = MIRBuilder.getMF().getFunction().getParent();614GVar = M->getGlobalVariable(Name);615if (GVar == nullptr) {616const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.617// Module takes ownership of the global var.618GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), 
false,619GlobalValue::ExternalLinkage, nullptr,620Twine(Name));621}622GV = GVar;623}624Register Reg = DT.find(GVar, &MIRBuilder.getMF());625if (Reg.isValid()) {626if (Reg != ResVReg)627MIRBuilder.buildCopy(ResVReg, Reg);628return ResVReg;629}630631auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)632.addDef(ResVReg)633.addUse(getSPIRVTypeID(BaseType))634.addImm(static_cast<uint32_t>(Storage));635636if (Init != 0) {637MIB.addUse(Init->getOperand(0).getReg());638}639640// ISel may introduce a new register on this step, so we need to add it to641// DT and correct its type avoiding fails on the next stage.642if (IsInstSelector) {643const auto &Subtarget = CurMF->getSubtarget();644constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),645*Subtarget.getRegisterInfo(),646*Subtarget.getRegBankInfo());647}648Reg = MIB->getOperand(0).getReg();649DT.add(GVar, &MIRBuilder.getMF(), Reg);650651// Set to Reg the same type as ResVReg has.652auto MRI = MIRBuilder.getMRI();653assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");654if (Reg != ResVReg) {655LLT RegLLTy =656LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), getPointerSize());657MRI->setType(Reg, RegLLTy);658assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());659} else {660// Our knowledge about the type may be updated.661// If that's the case, we need to update a type662// associated with the register.663SPIRVType *DefType = getSPIRVTypeForVReg(ResVReg);664if (!DefType || DefType != BaseType)665assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());666}667668// If it's a global variable with name, output OpName for it.669if (GVar && GVar->hasName())670buildOpName(Reg, GVar->getName(), MIRBuilder);671672// Output decorations for the GV.673// TODO: maybe move to GenerateDecorations pass.674const SPIRVSubtarget &ST =675cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());676if (IsConst && ST.isOpenCLEnv())677buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, 
{});678679if (GVar && GVar->getAlign().valueOrOne().value() != 1) {680unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();681buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});682}683684if (HasLinkageTy)685buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,686{static_cast<uint32_t>(LinkageType)}, Name);687688SPIRV::BuiltIn::BuiltIn BuiltInId;689if (getSpirvBuiltInIdByName(Name, BuiltInId))690buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,691{static_cast<uint32_t>(BuiltInId)});692693// If it's a global variable with "spirv.Decorations" metadata node694// recognize it as a SPIR-V friendly LLVM IR and parse "spirv.Decorations"695// arguments.696MDNode *GVarMD = nullptr;697if (GVar && (GVarMD = GVar->getMetadata("spirv.Decorations")) != nullptr)698buildOpSpirvDecorations(Reg, MIRBuilder, GVarMD);699700return Reg;701}702703SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,704SPIRVType *ElemType,705MachineIRBuilder &MIRBuilder,706bool EmitIR) {707assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&708"Invalid array element type");709Register NumElementsVReg =710buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);711auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)712.addDef(createTypeVReg(MIRBuilder))713.addUse(getSPIRVTypeID(ElemType))714.addUse(NumElementsVReg);715return MIB;716}717718SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,719MachineIRBuilder &MIRBuilder) {720assert(Ty->hasName());721const StringRef Name = Ty->hasName() ? 
Ty->getName() : "";722Register ResVReg = createTypeVReg(MIRBuilder);723auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);724addStringImm(Name, MIB);725buildOpName(ResVReg, Name, MIRBuilder);726return MIB;727}728729SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,730MachineIRBuilder &MIRBuilder,731bool EmitIR) {732SmallVector<Register, 4> FieldTypes;733for (const auto &Elem : Ty->elements()) {734SPIRVType *ElemTy = findSPIRVType(toTypedPointer(Elem), MIRBuilder);735assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&736"Invalid struct element type");737FieldTypes.push_back(getSPIRVTypeID(ElemTy));738}739Register ResVReg = createTypeVReg(MIRBuilder);740auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);741for (const auto &Ty : FieldTypes)742MIB.addUse(Ty);743if (Ty->hasName())744buildOpName(ResVReg, Ty->getName(), MIRBuilder);745if (Ty->isPacked())746buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});747return MIB;748}749750SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(751const Type *Ty, MachineIRBuilder &MIRBuilder,752SPIRV::AccessQualifier::AccessQualifier AccQual) {753assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type");754return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this);755}756757SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(758SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,759MachineIRBuilder &MIRBuilder, Register Reg) {760if (!Reg.isValid())761Reg = createTypeVReg(MIRBuilder);762return MIRBuilder.buildInstr(SPIRV::OpTypePointer)763.addDef(Reg)764.addImm(static_cast<uint32_t>(SC))765.addUse(getSPIRVTypeID(ElemType));766}767768SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(769SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {770return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)771.addUse(createTypeVReg(MIRBuilder))772.addImm(static_cast<uint32_t>(SC));773}774775SPIRVType 
*SPIRVGlobalRegistry::getOpTypeFunction(776SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,777MachineIRBuilder &MIRBuilder) {778auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)779.addDef(createTypeVReg(MIRBuilder))780.addUse(getSPIRVTypeID(RetType));781for (const SPIRVType *ArgType : ArgTypes)782MIB.addUse(getSPIRVTypeID(ArgType));783return MIB;784}785786SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(787const Type *Ty, SPIRVType *RetType,788const SmallVectorImpl<SPIRVType *> &ArgTypes,789MachineIRBuilder &MIRBuilder) {790Register Reg = DT.find(Ty, &MIRBuilder.getMF());791if (Reg.isValid())792return getSPIRVTypeForVReg(Reg);793SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);794DT.add(Ty, CurMF, getSPIRVTypeID(SpirvType));795return finishCreatingSPIRVType(Ty, SpirvType);796}797798SPIRVType *SPIRVGlobalRegistry::findSPIRVType(799const Type *Ty, MachineIRBuilder &MIRBuilder,800SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {801Ty = adjustIntTypeByWidth(Ty);802Register Reg = DT.find(Ty, &MIRBuilder.getMF());803if (Reg.isValid())804return getSPIRVTypeForVReg(Reg);805if (ForwardPointerTypes.contains(Ty))806return ForwardPointerTypes[Ty];807return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);808}809810Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {811assert(SpirvType && "Attempting to get type id for nullptr type.");812if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer)813return SpirvType->uses().begin()->getReg();814return SpirvType->defs().begin()->getReg();815}816817// We need to use a new LLVM integer type if there is a mismatch between818// number of bits in LLVM and SPIRV integer types to let DuplicateTracker819// ensure uniqueness of a SPIRV type by the corresponding LLVM type. 
// Without such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create
// the same "OpTypeInt 8" type for a series of LLVM integer types with number
// of bits less than 8. This would lead to duplicate type definitions
// eventually due to the method that DuplicateTracker utilizes to reason
// about uniqueness of type records.
//
// Returns Ty unchanged unless it is an IntegerType wider than 1 bit whose
// width adjustOpTypeIntWidth() maps to a different value. In that case the
// IntegerType of the adjusted width, from the same LLVMContext, is returned.
// i1 is never adjusted; createSPIRVType() maps width 1 to OpTypeBool.
const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
  if (auto IType = dyn_cast<IntegerType>(Ty)) {
    unsigned SrcBitWidth = IType->getBitWidth();
    if (SrcBitWidth > 1) {
      unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
      // Maybe change source LLVM type to keep DuplicateTracker consistent.
      if (SrcBitWidth != BitWidth)
        Ty = IntegerType::get(Ty->getContext(), BitWidth);
    }
  }
  return Ty;
}

// Emits (or finds) the SPIR-V type instruction for the LLVM type Ty in the
// MachineFunction owned by MIRBuilder.
//  - Special opaque types are delegated to getOrCreateSpecialType().
//  - If DT already records Ty for this function, that type is returned.
//  - Otherwise the function dispatches on the LLVM type kind:
//    integer/bool, float, void, fixed vector, array, struct (opaque or not),
//    and function.
//  - Typed and opaque pointers fall through to the tail. Opaque pointers use
//    an i8 pointee.
//  - If the pointee lookup yields nullptr, an OpTypeForwardPointer is created
//    (or reused) and returned instead of an OpTypePointer.
//  - Any other type kind triggers report_fatal_error().
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
  if (isSpecialOpaqueType(Ty))
    return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
  // Fast path: Ty was already registered for this MachineFunction.
  auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
  auto t = TypeToSPIRVTypeMap.find(Ty);
  if (t != TypeToSPIRVTypeMap.end()) {
    auto tt = t->second.find(&MIRBuilder.getMF());
    if (tt != t->second.end())
      return getSPIRVTypeForVReg(tt->second);
  }

  if (auto IType = dyn_cast<IntegerType>(Ty)) {
    const unsigned Width = IType->getBitWidth();
    return Width == 1 ? getOpTypeBool(MIRBuilder)
                      : getOpTypeInt(Width, MIRBuilder, false);
  }
  if (Ty->isFloatingPointTy())
    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
  if (Ty->isVoidTy())
    return getOpTypeVoid(MIRBuilder);
  if (Ty->isVectorTy()) {
    SPIRVType *El =
        findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
    return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
                           MIRBuilder);
  }
  if (Ty->isArrayTy()) {
    SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
    return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
  }
  if (auto SType = dyn_cast<StructType>(Ty)) {
    if (SType->isOpaque())
      return getOpTypeOpaque(SType, MIRBuilder);
    return getOpTypeStruct(SType, MIRBuilder, EmitIR);
  }
  if (auto FType = dyn_cast<FunctionType>(Ty)) {
    SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
    SmallVector<SPIRVType *, 4> ParamTypes;
    for (const auto &t : FType->params()) {
      ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
    }
    return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
  }
  // Only pointer kinds remain; anything else is unsupported.
  unsigned AddrSpace = 0xFFFF;
  if (auto PType = dyn_cast<TypedPointerType>(Ty))
    AddrSpace = PType->getAddressSpace();
  else if (auto PType = dyn_cast<PointerType>(Ty))
    AddrSpace = PType->getAddressSpace();
  else
    report_fatal_error("Unable to convert LLVM type to SPIRVType", true);

  // Typed pointers resolve their real pointee; opaque pointers use i8.
  SPIRVType *SpvElementType = nullptr;
  if (auto PType = dyn_cast<TypedPointerType>(Ty))
    SpvElementType = getOrCreateSPIRVType(PType->getElementType(), MIRBuilder,
                                          AccQual, EmitIR);
  else
    SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);

  // Get access to information about available extensions
  const SPIRVSubtarget *ST =
      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
  auto SC = addressSpaceToStorageClass(AddrSpace, *ST);
  // Null pointer means we have a loop in type definitions, make and
  // return corresponding OpTypeForwardPointer.
  // NOTE(review): getOrCreateSPIRVType() clears TypesInProcessing before it
  // recurses, so it is not evident from this file alone when the pointee
  // lookup above can return nullptr; confirm.
  if (SpvElementType == nullptr) {
    if (!ForwardPointerTypes.contains(Ty))
      ForwardPointerTypes[Ty] = getOpTypeForwardPointer(SC, MIRBuilder);
    return ForwardPointerTypes[Ty];
  }
  // If we have forward pointer associated with this type, use its register
  // operand to create OpTypePointer.
  if (ForwardPointerTypes.contains(Ty)) {
    Register Reg = getSPIRVTypeID(ForwardPointerTypes[Ty]);
    return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
  }

  return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC);
}

// Wraps createSPIRVType() with a recursion guard and the bookkeeping that
// every newly created type needs.
//  - Returns nullptr if Ty is a non-pointer type already present in
//    TypesInProcessing, i.e. it is being created further up the call chain.
//  - Otherwise creates the type and records it in VRegToTypeMap and
//    SPIRVToLLVMType.
//  - It is registered in DT only if all of these hold: it is not an
//    OpTypeForwardPointer, DT has no entry for Ty yet, and Ty is not a
//    special opaque type.
//  - Pointers are keyed in DT by (pointee type, address space); opaque
//    pointers use i8 as the pointee key.
SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  if (TypesInProcessing.count(Ty) && !isPointerTy(Ty))
    return nullptr;
  TypesInProcessing.insert(Ty);
  SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  TypesInProcessing.erase(Ty);
  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
  SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
  Register Reg = DT.find(Ty, &MIRBuilder.getMF());
  // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
  // will be added later. For special types it is already added to DT.
  if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
      !isSpecialOpaqueType(Ty)) {
    if (!isPointerTy(Ty))
      DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
    else if (isTypedPointerTy(Ty))
      DT.add(cast<TypedPointerType>(Ty)->getElementType(),
             getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
             getSPIRVTypeID(SpirvType));
    else
      DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
             getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
             getSPIRVTypeID(SpirvType));
  }

  return SpirvType;
}

// Looks up the SPIR-V type assigned to VReg in MF's map (CurMF when MF is
// null). Returns nullptr if either the function or the vreg is unknown.
SPIRVType *
SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
                                         const MachineFunction *MF) const {
  auto t = VRegToTypeMap.find(MF ? MF : CurMF);
  if (t != VRegToTypeMap.end()) {
    auto tt = t->second.find(VReg);
    if (tt != t->second.end())
      return tt->second;
  }
  return nullptr;
}

// Main entry point for mapping an LLVM type to a SPIR-V type in MIRBuilder's
// function.
//  - Non-pointer types first go through adjustIntTypeByWidth(), then are
//    looked up in DT by type.
//  - Pointers are looked up by (pointee, address space), with i8 standing in
//    for the pointee of opaque pointers.
//  - On a hit the existing type is returned, except for special opaque types,
//    which always go back through creation.
//  - On a miss the type is created. Afterwards every OpTypeForwardPointer
//    collected along the way is resolved to a normal pointer type, and
//    ForwardPointerTypes is cleared.
//  - If Ty itself was one of those forward-declared types, the resolved
//    pointer type is returned.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  Register Reg;
  if (!isPointerTy(Ty)) {
    Ty = adjustIntTypeByWidth(Ty);
    Reg = DT.find(Ty, &MIRBuilder.getMF());
  } else if (isTypedPointerTy(Ty)) {
    Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
                  getPointerAddressSpace(Ty), &MIRBuilder.getMF());
  } else {
    Reg =
        DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
                getPointerAddressSpace(Ty), &MIRBuilder.getMF());
  }

  if (Reg.isValid() && !isSpecialOpaqueType(Ty))
    return getSPIRVTypeForVReg(Reg);
  TypesInProcessing.clear();
  SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  // Create normal pointer types for the corresponding OpTypeForwardPointers.
  for (auto &CU : ForwardPointerTypes) {
    const Type *Ty2 = CU.first;
    // The initial value is always overwritten by one of the branches below.
    SPIRVType *STy2 = CU.second;
    if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
      STy2 = getSPIRVTypeForVReg(Reg);
    else
      STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
    if (Ty == Ty2)
      STy = STy2;
  }
  ForwardPointerTypes.clear();
  return STy;
}

// True if the type assigned to VReg has opcode TypeOpcode. Asserts that VReg
// has a type.
bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
                                         unsigned TypeOpcode) const {
  SPIRVType *Type = getSPIRVTypeForVReg(VReg);
  assert(Type && "isScalarOfType VReg has no type assigned");
  return Type->getOpcode() == TypeOpcode;
}

// True if VReg's type has opcode TypeOpcode, or is an OpTypeVector whose
// component type (operand 1) has that opcode. Asserts that VReg has a type.
bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
                                                 unsigned TypeOpcode) const {
  SPIRVType *Type = getSPIRVTypeForVReg(VReg);
  assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
  if (Type->getOpcode() == TypeOpcode)
    return true;
  if (Type->getOpcode() == SPIRV::OpTypeVector) {
    Register ScalarTypeVReg = Type->getOperand(1).getReg();
    SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
    return ScalarType->getOpcode() == TypeOpcode;
  }
  return false;
}

// Component count of the type assigned to VReg (see the overload below).
unsigned
SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
  return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg));
}

// Returns 0 for a null type, the element count (operand 2) for an
// OpTypeVector, and 1 for any other type.
unsigned
SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const {
  if (!Type)
    return 0;
  return Type->getOpcode() == SPIRV::OpTypeVector
             ? static_cast<unsigned>(Type->getOperand(2).getImm())
             : 1;
}

// Bit width of a scalar type, or of a vector's component type.
//  - OpTypeInt / OpTypeFloat: their width operand (operand 1).
//  - OpTypeBool: 1.
//  - Any other type is unreachable.
unsigned
SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
  assert(Type && "Invalid Type pointer");
  if (Type->getOpcode() == SPIRV::OpTypeVector) {
    auto EleTypeReg = Type->getOperand(1).getReg();
    Type = getSPIRVTypeForVReg(EleTypeReg);
  }
  if (Type->getOpcode() == SPIRV::OpTypeInt ||
      Type->getOpcode() == SPIRV::OpTypeFloat)
    return Type->getOperand(1).getImm();
  if (Type->getOpcode() == SPIRV::OpTypeBool)
    return 1;
  llvm_unreachable("Attempting to get bit width of non-integer/float type.");
}

// Total bits of an int/float scalar or vector: element count times component
// width. Returns 0 for any other type, including OpTypeBool.
unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
    const SPIRVType *Type) const {
  assert(Type && "Invalid Type pointer");
  unsigned NumElements = 1;
  if (Type->getOpcode() == SPIRV::OpTypeVector) {
    NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
  }
  return Type->getOpcode() == SPIRV::OpTypeInt ||
                 Type->getOpcode() == SPIRV::OpTypeFloat
             ? NumElements * Type->getOperand(1).getImm()
             : 0;
}

// Returns the OpTypeInt for an integer scalar, or the component OpTypeInt of
// an integer vector. Returns nullptr otherwise, including for a null input.
const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
    const SPIRVType *Type) const {
  if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
  return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr;
}

// True if Type is an integer scalar/vector whose OpTypeInt signedness operand
// (operand 2) is non-zero.
bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
  const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
  return IntType && IntType->getOperand(2).getImm() != 0;
}

// Pointee type (operand 2) of an OpTypePointer. Returns nullptr for a null or
// non-pointer input.
SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
  return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
             ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
             : nullptr;
}

// Opcode of the pointee type of PtrReg's pointer type. Returns 0 when there
// is no such pointee.
unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
  SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg));
  return ElemType ? ElemType->getOpcode() : 0;
}

// Decides whether a bitcast between Type1 and Type2 is acceptable.
//  - Pointer <-> pointer and pointer <-> integer scalar/vector are always
//    accepted.
//  - Otherwise both types must be int/float scalars or vectors with the same
//    non-zero total bit width.
bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
                                              const SPIRVType *Type2) const {
  if (!Type1 || !Type2)
    return false;
  auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
  // Ignore difference between <1.5 and >=1.5 protocol versions:
  // it's valid if either Result Type or Operand is a pointer, and the other
  // is a pointer, an integer scalar, or an integer vector.
  if (Op1 == SPIRV::OpTypePointer &&
      (Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
    return true;
  if (Op2 == SPIRV::OpTypePointer &&
      (Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
    return true;
  unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1),
           Bits2 = getNumScalarOrVectorTotalBitWidth(Type2);
  return Bits1 > 0 && Bits1 == Bits2;
}

// Storage class (immediate operand 1) of the OpTypePointer assigned to VReg.
// Asserts that VReg has a pointer type.
SPIRV::StorageClass::StorageClass
SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
  SPIRVType *Type = getSPIRVTypeForVReg(VReg);
  assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
         Type->getOperand(1).isImm() && "Pointer type is expected");
  return static_cast<SPIRV::StorageClass::StorageClass>(
      Type->getOperand(1).getImm());
}

// Returns the OpTypeImage matching the given sampled type and image
// parameters, building it at MIRBuilder's insertion point if DT has no match.
// The new type is registered in DT before its instruction is built.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
    MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
    uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
    SPIRV::ImageFormat::ImageFormat ImageFormat,
    SPIRV::AccessQualifier::AccessQualifier AccessQual) {
  auto TD = SPIRV::make_descr_image(SPIRVToLLVMType.lookup(SampledType), Dim,
                                    Depth, Arrayed, Multisampled, Sampled,
                                    ImageFormat, AccessQual);
  if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
    return Res;
  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  return MIRBuilder.buildInstr(SPIRV::OpTypeImage)
      .addDef(ResVReg)
      .addUse(getSPIRVTypeID(SampledType))
      .addImm(Dim)
      .addImm(Depth)        // Depth (whether or not it is a Depth image).
      .addImm(Arrayed)      // Arrayed.
      .addImm(Multisampled) // Multisampled (0 = only single-sample).
      .addImm(Sampled)      // Sampled (0 = usage known at runtime).
      .addImm(ImageFormat)
      .addImm(AccessQual);
}

// Returns the function's unique OpTypeSampler, creating it on first use.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
  auto TD = SPIRV::make_descr_sampler();
  if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
    return Res;
  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg);
}

// Returns the OpTypePipe for AccessQual, creating it on first use.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
    MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual) {
  auto TD = SPIRV::make_descr_pipe(AccessQual);
  if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
    return Res;
  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  return MIRBuilder.buildInstr(SPIRV::OpTypePipe)
      .addDef(ResVReg)
      .addImm(AccessQual);
}

// Returns the function's unique OpTypeDeviceEvent, creating it on first use.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
    MachineIRBuilder &MIRBuilder) {
  auto TD = SPIRV::make_descr_event();
  if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
    return Res;
  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg);
}

// Returns the OpTypeSampledImage wrapping ImageType, creating it on first
// use. The DT descriptor is keyed on the LLVM type of the image's sampled
// type (the def of ImageType's operand 1) together with ImageType itself.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
    SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
  auto TD = SPIRV::make_descr_sampled_image(
      SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
          ImageType->getOperand(1).getReg())),
      ImageType);
  if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
    return Res;
  Register ResVReg = createTypeVReg(MIRBuilder);
  DT.add(TD, &MIRBuilder.getMF(), ResVReg);
  return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
      .addDef(ResVReg)
      .addUse(getSPIRVTypeID(ImageType));
}

// Returns the OpTypeCooperativeMatrixKHR registered in DT for ExtensionType,
// or builds one.
//  - The component type is ElemType.
//  - Scope, Rows, Columns and Use are passed as constants from
//    buildConstantInt(), in that operand order.
//  - Unlike the descriptor-based helpers above, DT registration happens after
//    the instruction is built.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
    MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
    const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
    uint32_t Use) {
  Register ResVReg = DT.find(ExtensionType, &MIRBuilder.getMF());
  if (ResVReg.isValid())
    return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
  ResVReg = createTypeVReg(MIRBuilder);
  SPIRVType *SpirvTy =
      MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
          .addDef(ResVReg)
          .addUse(getSPIRVTypeID(ElemType))
          .addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true))
          .addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true))
          .addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true))
          .addUse(buildConstantInt(Use, MIRBuilder, nullptr, true));
  DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
  return SpirvTy;
}

// Returns the operand-less type instruction with the given Opcode that DT
// associates with Ty, building and registering it if it is missing.
SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
    const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
  Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
  if (ResVReg.isValid())
    return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
  ResVReg = createTypeVReg(MIRBuilder);
  SPIRVType *SpirvTy = MIRBuilder.buildInstr(Opcode).addDef(ResVReg);
  DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
  return SpirvTy;
}

// Returns the defining instruction of the vreg DT records for the special
// type descriptor TD in MIRBuilder's function, or nullptr if none exists.
const MachineInstr *
SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
                                       MachineIRBuilder &MIRBuilder) {
  Register Reg = DT.find(TD, &MIRBuilder.getMF());
  if (Reg.isValid())
    return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg);
  return nullptr;
}

// Returns nullptr if unable to recognize SPIRV type name
//
// Builtin-prefixed names (SPIR-V / OpenCL builtins) are converted to a target
// extension type. Other names are parsed as a basic type followed by
// optional "*" and/or " vector[N]" / "N" suffixes.
// NOTE(review): the suffix handling below assumes parseBasicTypeName()
// consumes the recognized base-type prefix from TypeStr (passed by
// reference); verify against its definition.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
    StringRef TypeStr, MachineIRBuilder &MIRBuilder,
    SPIRV::StorageClass::StorageClass SC,
    SPIRV::AccessQualifier::AccessQualifier AQ) {
  unsigned VecElts = 0;
  auto &Ctx = MIRBuilder.getMF().getFunction().getContext();

  // Parse strings representing either a SPIR-V or OpenCL builtin type.
  if (hasBuiltinTypePrefix(TypeStr))
    return getOrCreateSPIRVType(SPIRV::parseBuiltinTypeNameToTargetExtType(
                                    TypeStr.str(), MIRBuilder.getContext()),
                                MIRBuilder, AQ);

  // Parse type name in either "typeN" or "type vector[N]" format, where
  // N is the number of elements of the vector.
  Type *Ty;

  Ty = parseBasicTypeName(TypeStr, Ctx);
  if (!Ty)
    // Unable to recognize SPIRV type name
    return nullptr;

  auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);

  // Handle "type*" or "type* vector[N]".
  if (TypeStr.starts_with("*")) {
    SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
    TypeStr = TypeStr.substr(strlen("*"));
  }

  // Handle "typeN*" or "type vector[N]*".
  bool IsPtrToVec = TypeStr.consume_back("*");

  if (TypeStr.consume_front(" vector[")) {
    TypeStr = TypeStr.substr(0, TypeStr.find(']'));
  }
  // If the remainder does not parse as an integer, VecElts stays 0 and no
  // vector type is formed.
  TypeStr.getAsInteger(10, VecElts);
  if (VecElts > 0)
    SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);

  if (IsPtrToVec)
    SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);

  return SpirvTy;
}

// MIRBuilder-based helper: the SPIR-V type for iBitWidth (i1 becomes
// OpTypeBool).
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
                                                 MachineIRBuilder &MIRBuilder) {
  return getOrCreateSPIRVType(
      IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
      MIRBuilder);
}

// Records a freshly built type instruction in VRegToTypeMap and
// SPIRVToLLVMType for CurMF, and returns it. Asserts that the instruction
// belongs to CurMF.
SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
                                                        SPIRVType *SpirvType) {
  assert(CurMF == SpirvType->getMF());
  VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
  SPIRVToLLVMType[SpirvType] = unifyPtrType(LLVMTy);
  return SpirvType;
}

// Shared helper for the MachineInstr-based scalar creators.
//  - Returns the type DT already has for LLVMTy in CurMF.
//  - Otherwise builds "SPIRVOPcode BitWidth 0" immediately before I,
//    registers it in DT and finishes it.
// NOTE(review): the trailing immediate 0 (the signedness operand for
// OpTypeInt) is appended for OpTypeFloat too, since
// getOrCreateSPIRVFloatType() routes through here; confirm the OpTypeFloat
// instruction definition expects it.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
                                                     MachineInstr &I,
                                                     const SPIRVInstrInfo &TII,
                                                     unsigned SPIRVOPcode,
                                                     Type *LLVMTy) {
  Register Reg = DT.find(LLVMTy, CurMF);
  if (Reg.isValid())
    return getSPIRVTypeForVReg(Reg);
  MachineBasicBlock &BB = *I.getParent();
  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRVOPcode))
                 .addDef(createTypeVReg(CurMF->getRegInfo()))
                 .addImm(BitWidth)
                 .addImm(0);
  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}

// MachineInstr-based helper: the unsigned OpTypeInt for BitWidth, after width
// adjustment, built before I if missing.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
  // Maybe adjust bit width to keep DuplicateTracker consistent. Without
  // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
  // example, the same "OpTypeInt 8" type for a series of LLVM integer types
  // with number of bits less than 8, causing duplicate type definitions.
  BitWidth = adjustOpTypeIntWidth(BitWidth);
  Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
}

// MachineInstr-based helper: the OpTypeFloat for half/float/double (16/32/64
// bits). Any other width is unreachable.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
  LLVMContext &Ctx = CurMF->getFunction().getContext();
  Type *LLVMTy;
  switch (BitWidth) {
  case 16:
    LLVMTy = Type::getHalfTy(Ctx);
    break;
  case 32:
    LLVMTy = Type::getFloatTy(Ctx);
    break;
  case 64:
    LLVMTy = Type::getDoubleTy(Ctx);
    break;
  default:
    llvm_unreachable("Bit width is of unexpected size.");
  }
  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
}

// MIRBuilder-based helper: the SPIR-V type for i1 (OpTypeBool).
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
  return getOrCreateSPIRVType(
      IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
      MIRBuilder);
}

// MachineInstr-based helper: OpTypeBool keyed on i1, built before I if
// missing.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
                                              const SPIRVInstrInfo &TII) {
  Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
  Register Reg = DT.find(LLVMTy, CurMF);
  if (Reg.isValid())
    return getSPIRVTypeForVReg(Reg);
  MachineBasicBlock &BB = *I.getParent();
  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
                 .addDef(createTypeVReg(CurMF->getRegInfo()));
  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}

// MIRBuilder-based helper: the vector of NumElements BaseType components,
// derived from the LLVM FixedVectorType.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
    SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
  return getOrCreateSPIRVType(
      FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
                           NumElements),
      MIRBuilder);
}

// MachineInstr-based helper: "OpTypeVector BaseType NumElements", built
// before I if DT has no entry for the corresponding FixedVectorType.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
    SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
    const SPIRVInstrInfo &TII) {
  Type *LLVMTy = FixedVectorType::get(
      const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
  Register Reg = DT.find(LLVMTy, CurMF);
  if (Reg.isValid())
    return getSPIRVTypeForVReg(Reg);
  MachineBasicBlock &BB = *I.getParent();
  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
                 .addDef(createTypeVReg(CurMF->getRegInfo()))
                 .addUse(getSPIRVTypeID(BaseType))
                 .addImm(NumElements);
  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}

// MachineInstr-based helper: "OpTypeArray BaseType Len", built before I if
// missing. Len is a 32-bit integer constant holding NumElements.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
    SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
    const SPIRVInstrInfo &TII) {
  Type *LLVMTy = ArrayType::get(
      const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
  Register Reg = DT.find(LLVMTy, CurMF);
  if (Reg.isValid())
    return getSPIRVTypeForVReg(Reg);
  MachineBasicBlock &BB = *I.getParent();
  SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII);
  Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII);
  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
                 .addDef(createTypeVReg(CurMF->getRegInfo()))
                 .addUse(getSPIRVTypeID(BaseType))
                 .addUse(Len);
  DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}

// Returns "OpTypePointer SC BaseType", keyed in DT by (pointee LLVM type,
// address space of SC).
//  - A missing type is built at MIRBuilder's insertion point.
//  - DT lookup/registration and vreg creation use CurMF, not
//    MIRBuilder.getMF(). finishCreatingSPIRVType() asserts that the two
//    agree.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
    SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
    SPIRV::StorageClass::StorageClass SC) {
  const Type *PointerElementType = getTypeForSPIRVType(BaseType);
  unsigned AddressSpace = storageClassToAddressSpace(SC);
  Type *LLVMTy = TypedPointerType::get(const_cast<Type *>(PointerElementType),
                                       AddressSpace);
  // check if this type is already available
  Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
  if (Reg.isValid())
    return getSPIRVTypeForVReg(Reg);
  // create a new type
  auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
                     MIRBuilder.getDebugLoc(),
                     MIRBuilder.getTII().get(SPIRV::OpTypePointer))
                 .addDef(createTypeVReg(CurMF->getRegInfo()))
                 .addImm(static_cast<uint32_t>(SC))
                 .addUse(getSPIRVTypeID(BaseType));
  DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
  return finishCreatingSPIRVType(LLVMTy, MIB);
}

// MachineInstr-based overload: forwards to the MIRBuilder version using a
// MachineIRBuilder constructed from I. The SPIRVInstrInfo argument is unused.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
    SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
    SPIRV::StorageClass::StorageClass SC) {
  MachineIRBuilder MIRBuilder(I);
  return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
}

// Returns a vreg holding OpUndef of SpvType, reusing the one DT has for the
// matching UndefValue.
//  - On a miss, creates a 32-bit scalar vreg in SPIRV::IDRegClass, assigns
//    it SpvType and registers it in DT.
//  - Then emits "OpUndef SpvType" before I and constrains the new
//    instruction's register operands.
Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
                                               SPIRVType *SpvType,
                                               const SPIRVInstrInfo &TII) {
  assert(SpvType);
  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
  assert(LLVMTy);
  // Find a constant in DT or build a new one.
  UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
  Register Res = DT.find(UV, CurMF);
  if (Res.isValid())
    return Res;
  LLT LLTy = LLT::scalar(32);
  Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
  CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
  assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
  DT.add(UV, CurMF, Res);

  MachineInstrBuilder MIB;
  MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
            .addDef(Res)
            .addUse(getSPIRVTypeID(SpvType));
  const auto &ST = CurMF->getSubtarget();
  constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
                                   *ST.getRegisterInfo(), *ST.getRegBankInfo());
  return Res;
}