// Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
// 35294 views
//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the SPIRVTargetLowering class.
//
//===----------------------------------------------------------------------===//

#include "SPIRVISelLowering.h"
#include "SPIRV.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVRegisterBankInfo.h"
#include "SPIRVRegisterInfo.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/IR/IntrinsicsSPIRV.h"

#define DEBUG_TYPE "spirv-lower"

using namespace llvm;

unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
    LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
  // This code avoids CallLowering fail inside getVectorTypeBreakdown
  // on v3i1 arguments. Maybe we need to return 1 for all types.
  // TODO: remove it once this case is supported by the default implementation.
  if (VT.isVector() && VT.getVectorNumElements() == 3 &&
      (VT.getVectorElementType() == MVT::i1 ||
       VT.getVectorElementType() == MVT::i8))
    return 1;
  if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
    return 1;
  return getNumRegisters(Context, VT);
}

MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
                                                       CallingConv::ID CC,
                                                       EVT VT) const {
  // This code avoids CallLowering fail inside getVectorTypeBreakdown
  // on v3i1 arguments. Maybe we need to return i32 for all types.
  // TODO: remove it once this case is supported by the default implementation.
  if (VT.isVector() && VT.getVectorNumElements() == 3) {
    if (VT.getVectorElementType() == MVT::i1)
      return MVT::v4i1;
    else if (VT.getVectorElementType() == MVT::i8)
      return MVT::v4i8;
  }
  return getRegisterType(Context, VT);
}

bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
                                             const CallInst &I,
                                             MachineFunction &MF,
                                             unsigned Intrinsic) const {
  // Alignment is operand 3 for spv_store, operand 2 for spv_load (no value
  // operand); the flags immediate always immediately precedes the alignment.
  unsigned AlignIdx = 3;
  switch (Intrinsic) {
  case Intrinsic::spv_load:
    AlignIdx = 2;
    [[fallthrough]];
  case Intrinsic::spv_store: {
    if (I.getNumOperands() >= AlignIdx + 1) {
      auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
      Info.align = Align(AlignOp->getZExtValue());
    }
    Info.flags = static_cast<MachineMemOperand::Flags>(
        cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
    Info.memVT = MVT::i64;
    // TODO: take into account opaque pointers (don't use getElementType).
    // MVT::getVT(PtrTy->getElementType());
    return true;
    break;
  }
  default:
    break;
  }
  return false;
}

std::pair<unsigned, const TargetRegisterClass *>
SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
                                                  StringRef Constraint,
                                                  MVT VT) const {
  const TargetRegisterClass *RC = nullptr;
  // Physical-register constraints ("{...}") are not supported; return no
  // register class.
  if (Constraint.starts_with("{"))
    return std::make_pair(0u, RC);

  if (VT.isFloatingPoint())
    RC = VT.isVector() ? &SPIRV::vfIDRegClass
                       : (VT.getScalarSizeInBits() > 32 ? &SPIRV::fID64RegClass
                                                        : &SPIRV::fIDRegClass);
  else if (VT.isInteger())
    RC = VT.isVector() ? &SPIRV::vIDRegClass
                       : (VT.getScalarSizeInBits() > 32 ? &SPIRV::ID64RegClass
                                                        : &SPIRV::IDRegClass);
  else
    RC = &SPIRV::IDRegClass;

  return std::make_pair(0u, RC);
}

// Return the register to use for a type lookup of OpReg: when OpReg is
// defined by an OpFunctionParameter, its type is recorded as operand 1 of
// that instruction; otherwise the operand register itself is used.
inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
  SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
  return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
             ? TypeInst->getOperand(1).getReg()
             : OpReg;
}

// Insert "NewReg = OpBitcast NewPtrType, OpReg" immediately before I and
// rewrite I's operand OpIdx to use NewReg. Reports a fatal error if the new
// instruction cannot be constrained to the target register classes/banks.
static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
                            SPIRVGlobalRegistry &GR, MachineInstr &I,
                            Register OpReg, unsigned OpIdx,
                            SPIRVType *NewPtrType) {
  Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
  MachineIRBuilder MIB(I);
  bool Res = MIB.buildInstr(SPIRV::OpBitcast)
                 .addDef(NewReg)
                 .addUse(GR.getSPIRVTypeID(NewPtrType))
                 .addUse(OpReg)
                 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
                                   *STI.getRegBankInfo());
  if (!Res)
    report_fatal_error("insert validation bitcast: cannot constrain all uses");
  MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
  GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
  I.getOperand(OpIdx).setReg(NewReg);
}

// Create a pointer type with the same storage class as OpType whose pointee
// is either ResType (when ReuseType is set) or the SPIR-V type built for the
// LLVM type ResTy.
static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
                                   SPIRVType *OpType, bool ReuseType,
                                   bool EmitIR, SPIRVType *ResType,
                                   const Type *ResTy) {
  SPIRV::StorageClass::StorageClass SC =
      static_cast<SPIRV::StorageClass::StorageClass>(
          OpType->getOperand(1).getImm());
  MachineIRBuilder MIB(I);
  SPIRVType *NewBaseType =
      ReuseType ? ResType
                : GR.getOrCreateSPIRVType(
                      ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
  return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
}

// Insert a bitcast before the instruction to keep SPIR-V code valid
// when there is a type mismatch between results and operand types.
static void validatePtrTypes(const SPIRVSubtarget &STI,
                             MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
                             MachineInstr &I, unsigned OpIdx,
                             SPIRVType *ResType, const Type *ResTy = nullptr) {
  // Get operand type
  MachineFunction *MF = I.getParent()->getParent();
  Register OpReg = I.getOperand(OpIdx).getReg();
  Register OpTypeReg = getTypeReg(MRI, OpReg);
  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
  if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
    return;
  // Get operand's pointee type
  Register ElemTypeReg = OpType->getOperand(2).getReg();
  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
  if (!ElemType)
    return;
  // Check if we need a bitcast to make a statement valid. When ResType comes
  // from another machine function (a forward call), compare via the
  // registry's LLVM type instead of SPIRVType instruction identity.
  bool IsSameMF = MF == ResType->getParent()->getParent();
  bool IsEqualTypes = IsSameMF ? ElemType == ResType
                               : GR.getTypeForSPIRVType(ElemType) == ResTy;
  if (IsEqualTypes)
    return;
  // There is a type mismatch between results and operand types
  // and we insert a bitcast before the instruction to keep SPIR-V code valid
  SPIRVType *NewPtrType =
      createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy);
  if (!GR.isBitcastCompatible(NewPtrType, OpType))
    report_fatal_error(
        "insert validation bitcast: incompatible result and operand types");
  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}

// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
// that doesn't point to OpTypeEvent.
static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
                                       MachineRegisterInfo *MRI,
                                       SPIRVGlobalRegistry &GR,
                                       MachineInstr &I) {
  constexpr unsigned OpIdx = 2;
  MachineFunction *MF = I.getParent()->getParent();
  Register OpReg = I.getOperand(OpIdx).getReg();
  Register OpTypeReg = getTypeReg(MRI, OpReg);
  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
  if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
    return;
  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
  if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
    return;
  // Insert a bitcast before the instruction to keep SPIR-V code valid.
  LLVMContext &Context = MF->getFunction().getContext();
  SPIRVType *NewPtrType =
      createNewPtrType(GR, I, OpType, false, true, nullptr,
                       TargetExtType::get(Context, "spirv.Event"));
  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}

// Insert a bitcast before an OpGroupAsyncCopy pointer argument when it points
// to a single-member struct wrapping a scalar/vector type, so the operand's
// pointee becomes the wrapped member type.
static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
                                      MachineRegisterInfo *MRI,
                                      SPIRVGlobalRegistry &GR, MachineInstr &I,
                                      unsigned OpIdx) {
  MachineFunction *MF = I.getParent()->getParent();
  Register OpReg = I.getOperand(OpIdx).getReg();
  Register OpTypeReg = getTypeReg(MRI, OpReg);
  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
  if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
    return;
  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
  if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
      ElemType->getNumOperands() != 2)
    return;
  // It's a structure-wrapper around another type with a single member field.
  SPIRVType *MemberType =
      GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
  if (!MemberType)
    return;
  unsigned MemberTypeOp = MemberType->getOpcode();
  if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
      MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
    return;
  // It's a structure-wrapper around a valid type. Insert a bitcast before the
  // instruction to keep SPIR-V code valid.
  SPIRV::StorageClass::StorageClass SC =
      static_cast<SPIRV::StorageClass::StorageClass>(
          OpType->getOperand(1).getImm());
  MachineIRBuilder MIB(I);
  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}

// Insert a bitcast before the function call instruction to keep SPIR-V code
// valid when there is a type mismatch between actual and expected types of an
// argument:
// %formal = OpFunctionParameter %formal_type
// ...
// %res = OpFunctionCall %ty %fun %actual ...
// implies that %actual is of %formal_type, and in case of opaque pointers.
// We may need to insert a bitcast to ensure this.
void validateFunCallMachineDef(const SPIRVSubtarget &STI,
                               MachineRegisterInfo *DefMRI,
                               MachineRegisterInfo *CallMRI,
                               SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
                               MachineInstr *FunDef) {
  if (FunDef->getOpcode() != SPIRV::OpFunction)
    return;
  // Walk the OpFunctionParameter list in lock-step with the call's actual
  // argument operands (actuals start at operand index 3 of OpFunctionCall).
  unsigned OpIdx = 3;
  for (FunDef = FunDef->getNextNode();
       FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
       OpIdx < FunCall.getNumOperands();
       FunDef = FunDef->getNextNode(), OpIdx++) {
    SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
    SPIRVType *DefElemType =
        DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
                                     DefPtrType->getParent()->getParent())
            : nullptr;
    if (DefElemType) {
      const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
      // validatePtrTypes() works in the context of the call site.
      // When we process historical records about forward calls
      // we need to switch context to the (forward) call site and
      // then restore it back to the current machine function.
      MachineFunction *CurMF =
          GR.setCurrentFunc(*FunCall.getParent()->getParent());
      validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
                       DefElemTy);
      GR.setCurrentFunc(*CurMF);
    }
  }
}

// Ensure there is no mismatch between actual and expected arg types: calls
// with a processed definition. Return Function pointer if it's a forward
// call (ahead of definition), and nullptr otherwise.
const Function *validateFunCall(const SPIRVSubtarget &STI,
                                MachineRegisterInfo *CallMRI,
                                SPIRVGlobalRegistry &GR,
                                MachineInstr &FunCall) {
  const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
  const Function *F = dyn_cast<Function>(GV);
  MachineInstr *FunDef =
      const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
  if (!FunDef)
    return F;
  MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
  validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
  return nullptr;
}

// Ensure there is no mismatch between actual and expected arg types: calls
// ahead of a processed definition.
void validateForwardCalls(const SPIRVSubtarget &STI,
                          MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
                          MachineInstr &FunDef) {
  const Function *F = GR.getFunctionByDefinition(&FunDef);
  if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
    for (MachineInstr *FunCall : *FwdCalls) {
      MachineRegisterInfo *CallMRI =
          &FunCall->getParent()->getParent()->getRegInfo();
      validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
    }
}

// Validation of an access chain.
void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
                         SPIRVGlobalRegistry &GR, MachineInstr &I) {
  SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
  if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
    SPIRVType *BaseElemType =
        GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
    validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
  }
}

// TODO: the logic of inserting additional bitcast's is to be moved
// to pre-IRTranslation passes eventually
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
  // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
  // We'd like to avoid the needless second processing pass.
  if (ProcessedMF.find(&MF) != ProcessedMF.end())
    return;

  MachineRegisterInfo *MRI = &MF.getRegInfo();
  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
  GR.setCurrentFunc(MF);
  for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
    MachineBasicBlock *MBB = &*I;
    // Note: the iterator is advanced before the instruction is visited
    // because validation may insert new instructions before MI.
    for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
         MBBI != MBBE;) {
      MachineInstr &MI = *MBBI++;
      switch (MI.getOpcode()) {
      case SPIRV::OpAtomicLoad:
      case SPIRV::OpAtomicExchange:
      case SPIRV::OpAtomicCompareExchange:
      case SPIRV::OpAtomicCompareExchangeWeak:
      case SPIRV::OpAtomicIIncrement:
      case SPIRV::OpAtomicIDecrement:
      case SPIRV::OpAtomicIAdd:
      case SPIRV::OpAtomicISub:
      case SPIRV::OpAtomicSMin:
      case SPIRV::OpAtomicUMin:
      case SPIRV::OpAtomicSMax:
      case SPIRV::OpAtomicUMax:
      case SPIRV::OpAtomicAnd:
      case SPIRV::OpAtomicOr:
      case SPIRV::OpAtomicXor:
        // for the above listed instructions
        // OpAtomicXXX <ResType>, ptr %Op, ...
        // implies that %Op is a pointer to <ResType>
      case SPIRV::OpLoad:
        // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
        validatePtrTypes(STI, MRI, GR, MI, 2,
                         GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
        break;
      case SPIRV::OpAtomicStore:
        // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
        // implies that %Op points to the <Obj>'s type
        validatePtrTypes(STI, MRI, GR, MI, 0,
                         GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
        break;
      case SPIRV::OpStore:
        // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
        validatePtrTypes(STI, MRI, GR, MI, 0,
                         GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
        break;
      case SPIRV::OpPtrCastToGeneric:
      case SPIRV::OpGenericCastToPtr:
        validateAccessChain(STI, MRI, GR, MI);
        break;
      case SPIRV::OpInBoundsPtrAccessChain:
        if (MI.getNumOperands() == 4)
          validateAccessChain(STI, MRI, GR, MI);
        break;

      case SPIRV::OpFunctionCall:
        // ensure there is no mismatch between actual and expected arg types:
        // calls with a processed definition
        if (MI.getNumOperands() > 3)
          if (const Function *F = validateFunCall(STI, MRI, GR, MI))
            GR.addForwardCall(F, &MI);
        break;
      case SPIRV::OpFunction:
        // ensure there is no mismatch between actual and expected arg types:
        // calls ahead of a processed definition
        validateForwardCalls(STI, MRI, GR, MI);
        break;

      // ensure that LLVM IR bitwise instructions result in logical SPIR-V
      // instructions when applied to bool type
      case SPIRV::OpBitwiseOrS:
      case SPIRV::OpBitwiseOrV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
        break;
      case SPIRV::OpBitwiseAndS:
      case SPIRV::OpBitwiseAndV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
        break;
      case SPIRV::OpBitwiseXorS:
      case SPIRV::OpBitwiseXorV:
        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
                                      SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
        break;
      case SPIRV::OpGroupAsyncCopy:
        validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
        validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
        break;
      case SPIRV::OpGroupWaitEvents:
        // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
        validateGroupWaitEventsPtr(STI, MRI, GR, MI);
        break;
      case SPIRV::OpConstantI: {
        SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
        if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
            MI.getOperand(2).getImm() == 0) {
          // Validate the null constant of a target extension type
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
          for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
            MI.removeOperand(i);
        }
      } break;
      }
    }
  }
  ProcessedMF.insert(&MF);
  TargetLowering::finalizeLowering(MF);
}