Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
213799 views
//===-- SPIRVLegalizePointerCast.cpp ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The LLVM IR has multiple legal patterns we cannot lower to Logical SPIR-V.
// This pass rewrites such loads (and stores) into IR that we can lower
// directly to valid logical SPIR-V.
// OpenCL can avoid this because it relies on ptrcast, which logical SPIR-V
// does not support.
//
// This pass relies on the assign_ptr_type intrinsic to deduce the type of the
// pointed-to values, and must replace all occurrences of `ptrcast`. This is why
// unhandled cases are reported as unreachable: we MUST cover all cases.
//
// 1. Loading the first element of an array
//
//        %array = alloca [10 x i32]
//        %value = load i32, ptr %array
//
//    LLVM can skip the GEP instruction, and only request loading the first 4
//    bytes. In logical SPIR-V, we need an OpAccessChain to access the first
//    element. This pass will add a getelementptr instruction before the load.
//
//
// 2. Implicit downcast from load
//
//        %1 = getelementptr <4 x i32>, ptr %vec4, i64 0
//        %2 = load <3 x i32>, ptr %1
//
//    The pointer in the GEP instruction is only used for offset computations,
//    but it doesn't NEED to match the pointed-to type. OpAccessChain however
//    requires this. Also, LLVM loads define the bitwidth of the load, not the
//    pointer. In this example, we can guess %vec4 is a vec4 thanks to the GEP
//    instruction basetype, but we only want to load the first 3 elements, hence
//    do a partial load. In logical SPIR-V, this is not legal.
What we must do39// is load the full vector (basetype), extract 3 elements, and recombine them40// to form a 3-element vector.41//42//===----------------------------------------------------------------------===//4344#include "SPIRV.h"45#include "SPIRVSubtarget.h"46#include "SPIRVTargetMachine.h"47#include "SPIRVUtils.h"48#include "llvm/CodeGen/IntrinsicLowering.h"49#include "llvm/IR/IRBuilder.h"50#include "llvm/IR/IntrinsicInst.h"51#include "llvm/IR/Intrinsics.h"52#include "llvm/IR/IntrinsicsSPIRV.h"53#include "llvm/Transforms/Utils/Cloning.h"54#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"5556using namespace llvm;5758namespace {59class SPIRVLegalizePointerCast : public FunctionPass {6061// Builds the `spv_assign_type` assigning |Ty| to |Value| at the current62// builder position.63void buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg) {64Value *OfType = PoisonValue::get(Ty);65CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type,66{Arg->getType()}, OfType, Arg, {}, B);67GR->addAssignPtrTypeInstr(Arg, AssignCI);68}6970// Loads parts of the vector of type |SourceType| from the pointer |Source|71// and create a new vector of type |TargetType|. 
|TargetType| must be a vector72// type, and element types of |TargetType| and |SourceType| must match.73// Returns the loaded value.74Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,75FixedVectorType *TargetType, Value *Source) {76// We expect the codegen to avoid doing implicit bitcast from a load.77assert(TargetType->getElementType() == SourceType->getElementType());78assert(TargetType->getNumElements() < SourceType->getNumElements());7980LoadInst *NewLoad = B.CreateLoad(SourceType, Source);81buildAssignType(B, SourceType, NewLoad);8283SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());84for (unsigned I = 0; I < TargetType->getNumElements(); ++I)85Mask[I] = I;86Value *Output = B.CreateShuffleVector(NewLoad, NewLoad, Mask);87buildAssignType(B, TargetType, Output);88return Output;89}9091// Loads the first value in an aggregate pointed by |Source| of containing92// elements of type |ElementType|. Load flags will be copied from |BadLoad|,93// which should be the load being legalized. 
Returns the loaded value.94Value *loadFirstValueFromAggregate(IRBuilder<> &B, Type *ElementType,95Value *Source, LoadInst *BadLoad) {96SmallVector<Type *, 2> Types = {BadLoad->getPointerOperandType(),97BadLoad->getPointerOperandType()};98SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(false), Source,99B.getInt32(0), B.getInt32(0)};100auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});101GR->buildAssignPtr(B, ElementType, GEP);102103LoadInst *LI = B.CreateLoad(ElementType, GEP);104LI->setAlignment(BadLoad->getAlign());105buildAssignType(B, ElementType, LI);106return LI;107}108109// Replaces the load instruction to get rid of the ptrcast used as source110// operand.111void transformLoad(IRBuilder<> &B, LoadInst *LI, Value *CastedOperand,112Value *OriginalOperand) {113Type *FromTy = GR->findDeducedElementType(OriginalOperand);114Type *ToTy = GR->findDeducedElementType(CastedOperand);115Value *Output = nullptr;116117auto *SAT = dyn_cast<ArrayType>(FromTy);118auto *SVT = dyn_cast<FixedVectorType>(FromTy);119auto *SST = dyn_cast<StructType>(FromTy);120auto *DVT = dyn_cast<FixedVectorType>(ToTy);121122B.SetInsertPoint(LI);123124// Destination is the element type of Source, and source is an array ->125// Loading 1st element.126// - float a = array[0];127if (SAT && SAT->getElementType() == ToTy)128Output = loadFirstValueFromAggregate(B, SAT->getElementType(),129OriginalOperand, LI);130// Destination is the element type of Source, and source is a vector ->131// Vector to scalar.132// - float a = vector.x;133else if (!DVT && SVT && SVT->getElementType() == ToTy) {134Output = loadFirstValueFromAggregate(B, SVT->getElementType(),135OriginalOperand, LI);136}137// Destination is a smaller vector than source.138// - float3 v3 = vector4;139else if (SVT && DVT)140Output = loadVectorFromVector(B, SVT, DVT, OriginalOperand);141// Destination is the scalar type stored at the start of an aggregate.142// - struct S { float m };143// - float v = s.m;144else if 
(SST && SST->getTypeAtIndex(0u) == ToTy)145Output = loadFirstValueFromAggregate(B, ToTy, OriginalOperand, LI);146else147llvm_unreachable("Unimplemented implicit down-cast from load.");148149GR->replaceAllUsesWith(LI, Output, /* DeleteOld= */ true);150DeadInstructions.push_back(LI);151}152153// Creates an spv_insertelt instruction (equivalent to llvm's insertelement).154Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,155unsigned Index) {156Type *Int32Ty = Type::getInt32Ty(B.getContext());157SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),158Element->getType(), Int32Ty};159SmallVector<Value *> Args = {Vector, Element, B.getInt32(Index)};160Instruction *NewI =161B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});162buildAssignType(B, Vector->getType(), NewI);163return NewI;164}165166// Creates an spv_extractelt instruction (equivalent to llvm's167// extractelement).168Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,169unsigned Index) {170Type *Int32Ty = Type::getInt32Ty(B.getContext());171SmallVector<Type *, 3> Types = {ElementType, Vector->getType(), Int32Ty};172SmallVector<Value *> Args = {Vector, B.getInt32(Index)};173Instruction *NewI =174B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});175buildAssignType(B, ElementType, NewI);176return NewI;177}178179// Stores the given Src vector operand into the Dst vector, adjusting the size180// if required.181Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,182Align Alignment) {183FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());184FixedVectorType *DstType =185cast<FixedVectorType>(GR->findDeducedElementType(Dst));186assert(DstType->getNumElements() >= SrcType->getNumElements());187188LoadInst *LI = B.CreateLoad(DstType, Dst);189LI->setAlignment(Alignment);190Value *OldValues = LI;191buildAssignType(B, OldValues->getType(), OldValues);192Value *NewValues = Src;193194for (unsigned I = 0; I < 
SrcType->getNumElements(); ++I) {195Value *Element =196makeExtractElement(B, SrcType->getElementType(), NewValues, I);197OldValues = makeInsertElement(B, OldValues, Element, I);198}199200StoreInst *SI = B.CreateStore(OldValues, Dst);201SI->setAlignment(Alignment);202return SI;203}204205void buildGEPIndexChain(IRBuilder<> &B, Type *Search, Type *Aggregate,206SmallVectorImpl<Value *> &Indices) {207Indices.push_back(B.getInt32(0));208209if (Search == Aggregate)210return;211212if (auto *ST = dyn_cast<StructType>(Aggregate))213buildGEPIndexChain(B, Search, ST->getTypeAtIndex(0u), Indices);214else if (auto *AT = dyn_cast<ArrayType>(Aggregate))215buildGEPIndexChain(B, Search, AT->getElementType(), Indices);216else if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))217buildGEPIndexChain(B, Search, VT->getElementType(), Indices);218else219llvm_unreachable("Bad access chain?");220}221222// Stores the given Src value into the first entry of the Dst aggregate.223Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,224Type *DstPointeeType, Align Alignment) {225SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};226SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst};227buildGEPIndexChain(B, Src->getType(), DstPointeeType, Args);228auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});229GR->buildAssignPtr(B, Src->getType(), GEP);230StoreInst *SI = B.CreateStore(Src, GEP);231SI->setAlignment(Alignment);232return SI;233}234235bool isTypeFirstElementAggregate(Type *Search, Type *Aggregate) {236if (Search == Aggregate)237return true;238if (auto *ST = dyn_cast<StructType>(Aggregate))239return isTypeFirstElementAggregate(Search, ST->getTypeAtIndex(0u));240if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))241return isTypeFirstElementAggregate(Search, VT->getElementType());242if (auto *AT = dyn_cast<ArrayType>(Aggregate))243return isTypeFirstElementAggregate(Search, AT->getElementType());244return false;245}246247// 
Transforms a store instruction (or SPV intrinsic) using a ptrcast as248// operand into a valid logical SPIR-V store with no ptrcast.249void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,250Value *Dst, Align Alignment) {251Type *ToTy = GR->findDeducedElementType(Dst);252Type *FromTy = Src->getType();253254auto *S_VT = dyn_cast<FixedVectorType>(FromTy);255auto *D_ST = dyn_cast<StructType>(ToTy);256auto *D_VT = dyn_cast<FixedVectorType>(ToTy);257258B.SetInsertPoint(BadStore);259if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))260storeToFirstValueAggregate(B, Src, Dst, D_ST, Alignment);261else if (D_VT && S_VT)262storeVectorFromVector(B, Src, Dst, Alignment);263else if (D_VT && !S_VT && FromTy == D_VT->getElementType())264storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);265else266llvm_unreachable("Unsupported ptrcast use in store. Please fix.");267268DeadInstructions.push_back(BadStore);269}270271void legalizePointerCast(IntrinsicInst *II) {272Value *CastedOperand = II;273Value *OriginalOperand = II->getOperand(0);274275IRBuilder<> B(II->getContext());276std::vector<Value *> Users;277for (Use &U : II->uses())278Users.push_back(U.getUser());279280for (Value *User : Users) {281if (LoadInst *LI = dyn_cast<LoadInst>(User)) {282transformLoad(B, LI, CastedOperand, OriginalOperand);283continue;284}285286if (StoreInst *SI = dyn_cast<StoreInst>(User)) {287transformStore(B, SI, SI->getValueOperand(), OriginalOperand,288SI->getAlign());289continue;290}291292if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(User)) {293if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {294DeadInstructions.push_back(Intrin);295continue;296}297298if (Intrin->getIntrinsicID() == Intrinsic::spv_gep) {299GR->replaceAllUsesWith(CastedOperand, OriginalOperand,300/* DeleteOld= */ false);301continue;302}303304if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {305Align Alignment;306if (ConstantInt *C = 
dyn_cast<ConstantInt>(Intrin->getOperand(3)))307Alignment = Align(C->getZExtValue());308transformStore(B, Intrin, Intrin->getArgOperand(0), OriginalOperand,309Alignment);310continue;311}312}313314llvm_unreachable("Unsupported ptrcast user. Please fix.");315}316317DeadInstructions.push_back(II);318}319320public:321SPIRVLegalizePointerCast(SPIRVTargetMachine *TM) : FunctionPass(ID), TM(TM) {}322323virtual bool runOnFunction(Function &F) override {324const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(F);325GR = ST.getSPIRVGlobalRegistry();326DeadInstructions.clear();327328std::vector<IntrinsicInst *> WorkList;329for (auto &BB : F) {330for (auto &I : BB) {331auto *II = dyn_cast<IntrinsicInst>(&I);332if (II && II->getIntrinsicID() == Intrinsic::spv_ptrcast)333WorkList.push_back(II);334}335}336337for (IntrinsicInst *II : WorkList)338legalizePointerCast(II);339340for (Instruction *I : DeadInstructions)341I->eraseFromParent();342343return DeadInstructions.size() != 0;344}345346private:347SPIRVTargetMachine *TM = nullptr;348SPIRVGlobalRegistry *GR = nullptr;349std::vector<Instruction *> DeadInstructions;350351public:352static char ID;353};354} // namespace355356char SPIRVLegalizePointerCast::ID = 0;357INITIALIZE_PASS(SPIRVLegalizePointerCast, "spirv-legalize-bitcast",358"SPIRV legalize bitcast pass", false, false)359360FunctionPass *llvm::createSPIRVLegalizePointerCastPass(SPIRVTargetMachine *TM) {361return new SPIRVLegalizePointerCast(TM);362}363364365