// Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
// 35269 views
//===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This pass modifies function signatures containing aggregate arguments9// and/or return value before IRTranslator. Information about the original10// signatures is stored in metadata. It is used during call lowering to11// restore correct SPIR-V types of function arguments and return values.12// This pass also substitutes some llvm intrinsic calls with calls to newly13// generated functions (as the Khronos LLVM/SPIR-V Translator does).14//15// NOTE: this pass is a module-level one due to the necessity to modify16// GVs/functions.17//18//===----------------------------------------------------------------------===//1920#include "SPIRV.h"21#include "SPIRVSubtarget.h"22#include "SPIRVTargetMachine.h"23#include "SPIRVUtils.h"24#include "llvm/Analysis/ValueTracking.h"25#include "llvm/CodeGen/IntrinsicLowering.h"26#include "llvm/IR/IRBuilder.h"27#include "llvm/IR/IntrinsicInst.h"28#include "llvm/IR/Intrinsics.h"29#include "llvm/IR/IntrinsicsSPIRV.h"30#include "llvm/Transforms/Utils/Cloning.h"31#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"32#include <charconv>33#include <regex>3435using namespace llvm;3637namespace llvm {38void initializeSPIRVPrepareFunctionsPass(PassRegistry &);39}4041namespace {4243class SPIRVPrepareFunctions : public ModulePass {44const SPIRVTargetMachine &TM;45bool substituteIntrinsicCalls(Function *F);46Function *removeAggregateTypesFromSignature(Function *F);4748public:49static char ID;50SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) {51initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());52}5354bool runOnModule(Module &M) 
override;5556StringRef getPassName() const override { return "SPIRV prepare functions"; }5758void getAnalysisUsage(AnalysisUsage &AU) const override {59ModulePass::getAnalysisUsage(AU);60}61};6263} // namespace6465char SPIRVPrepareFunctions::ID = 0;6667INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",68"SPIRV prepare functions", false, false)6970std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {71Function *IntrinsicFunc = II->getCalledFunction();72assert(IntrinsicFunc && "Missing function");73std::string FuncName = IntrinsicFunc->getName().str();74std::replace(FuncName.begin(), FuncName.end(), '.', '_');75FuncName = "spirv." + FuncName;76return FuncName;77}7879static Function *getOrCreateFunction(Module *M, Type *RetTy,80ArrayRef<Type *> ArgTypes,81StringRef Name) {82FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);83Function *F = M->getFunction(Name);84if (F && F->getFunctionType() == FT)85return F;86Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);87if (F)88NewF->setDSOLocal(F->isDSOLocal());89NewF->setCallingConv(CallingConv::SPIR_FUNC);90return NewF;91}9293static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {94// For @llvm.memset.* intrinsic cases with constant value and length arguments95// are emulated via "storing" a constant array to the destination. 
For other96// cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the97// intrinsic to a loop via expandMemSetAsLoop().98if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))99if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))100return false; // It is handled later using OpCopyMemorySized.101102Module *M = Intrinsic->getModule();103std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);104if (Intrinsic->isVolatile())105FuncName += ".volatile";106// Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*107Function *F = M->getFunction(FuncName);108if (F) {109Intrinsic->setCalledFunction(F);110return true;111}112// TODO copy arguments attributes: nocapture writeonly.113FunctionCallee FC =114M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());115auto IntrinsicID = Intrinsic->getIntrinsicID();116Intrinsic->setCalledFunction(FC);117118F = dyn_cast<Function>(FC.getCallee());119assert(F && "Callee must be a function");120121switch (IntrinsicID) {122case Intrinsic::memset: {123auto *MSI = static_cast<MemSetInst *>(Intrinsic);124Argument *Dest = F->getArg(0);125Argument *Val = F->getArg(1);126Argument *Len = F->getArg(2);127Argument *IsVolatile = F->getArg(3);128Dest->setName("dest");129Val->setName("val");130Len->setName("len");131IsVolatile->setName("isvolatile");132BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);133IRBuilder<> IRB(EntryBB);134auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),135MSI->isVolatile());136IRB.CreateRetVoid();137expandMemSetAsLoop(cast<MemSetInst>(MemSet));138MemSet->eraseFromParent();139break;140}141case Intrinsic::bswap: {142BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);143IRBuilder<> IRB(EntryBB);144auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),145F->getArg(0));146IRB.CreateRet(BSwap);147IntrinsicLowering 
IL(M->getDataLayout());148IL.LowerIntrinsicCall(BSwap);149break;150}151default:152break;153}154return true;155}156157static std::string getAnnotation(Value *AnnoVal, Value *OptAnnoVal) {158if (auto *Ref = dyn_cast_or_null<GetElementPtrInst>(AnnoVal))159AnnoVal = Ref->getOperand(0);160if (auto *Ref = dyn_cast_or_null<BitCastInst>(OptAnnoVal))161OptAnnoVal = Ref->getOperand(0);162163std::string Anno;164if (auto *C = dyn_cast_or_null<Constant>(AnnoVal)) {165StringRef Str;166if (getConstantStringInfo(C, Str))167Anno = Str;168}169// handle optional annotation parameter in a way that Khronos Translator do170// (collect integers wrapped in a struct)171if (auto *C = dyn_cast_or_null<Constant>(OptAnnoVal);172C && C->getNumOperands()) {173Value *MaybeStruct = C->getOperand(0);174if (auto *Struct = dyn_cast<ConstantStruct>(MaybeStruct)) {175for (unsigned I = 0, E = Struct->getNumOperands(); I != E; ++I) {176if (auto *CInt = dyn_cast<ConstantInt>(Struct->getOperand(I)))177Anno += (I == 0 ? ": " : ", ") +178std::to_string(CInt->getType()->getIntegerBitWidth() == 1179? CInt->getZExtValue()180: CInt->getSExtValue());181}182} else if (auto *Struct = dyn_cast<ConstantAggregateZero>(MaybeStruct)) {183// { i32 i32 ... } zeroinitializer184for (unsigned I = 0, E = Struct->getType()->getStructNumElements();185I != E; ++I)186Anno += I == 0 ? 
": 0" : ", 0";187}188}189return Anno;190}191192static SmallVector<Metadata *> parseAnnotation(Value *I,193const std::string &Anno,194LLVMContext &Ctx,195Type *Int32Ty) {196// Try to parse the annotation string according to the following rules:197// annotation := ({kind} | {kind:value,value,...})+198// kind := number199// value := number | string200static const std::regex R(201"\\{(\\d+)(?:[:,](\\d+|\"[^\"]*\")(?:,(\\d+|\"[^\"]*\"))*)?\\}");202SmallVector<Metadata *> MDs;203int Pos = 0;204for (std::sregex_iterator205It = std::sregex_iterator(Anno.begin(), Anno.end(), R),206ItEnd = std::sregex_iterator();207It != ItEnd; ++It) {208if (It->position() != Pos)209return SmallVector<Metadata *>{};210Pos = It->position() + It->length();211std::smatch Match = *It;212SmallVector<Metadata *> MDsItem;213for (std::size_t i = 1; i < Match.size(); ++i) {214std::ssub_match SMatch = Match[i];215std::string Item = SMatch.str();216if (Item.length() == 0)217break;218if (Item[0] == '"') {219Item = Item.substr(1, Item.length() - 2);220// Acceptable format of the string snippet is:221static const std::regex RStr("^(\\d+)(?:,(\\d+))*$");222if (std::smatch MatchStr; std::regex_match(Item, MatchStr, RStr)) {223for (std::size_t SubIdx = 1; SubIdx < MatchStr.size(); ++SubIdx)224if (std::string SubStr = MatchStr[SubIdx].str(); SubStr.length())225MDsItem.push_back(ConstantAsMetadata::get(226ConstantInt::get(Int32Ty, std::stoi(SubStr))));227} else {228MDsItem.push_back(MDString::get(Ctx, Item));229}230} else if (int32_t Num;231std::from_chars(Item.data(), Item.data() + Item.size(), Num)232.ec == std::errc{}) {233MDsItem.push_back(234ConstantAsMetadata::get(ConstantInt::get(Int32Ty, Num)));235} else {236MDsItem.push_back(MDString::get(Ctx, Item));237}238}239if (MDsItem.size() == 0)240return SmallVector<Metadata *>{};241MDs.push_back(MDNode::get(Ctx, MDsItem));242}243return Pos == static_cast<int>(Anno.length()) ? 
MDs244: SmallVector<Metadata *>{};245}246247static void lowerPtrAnnotation(IntrinsicInst *II) {248LLVMContext &Ctx = II->getContext();249Type *Int32Ty = Type::getInt32Ty(Ctx);250251// Retrieve an annotation string from arguments.252Value *PtrArg = nullptr;253if (auto *BI = dyn_cast<BitCastInst>(II->getArgOperand(0)))254PtrArg = BI->getOperand(0);255else256PtrArg = II->getOperand(0);257std::string Anno =258getAnnotation(II->getArgOperand(1),2594 < II->arg_size() ? II->getArgOperand(4) : nullptr);260261// Parse the annotation.262SmallVector<Metadata *> MDs = parseAnnotation(II, Anno, Ctx, Int32Ty);263264// If the annotation string is not parsed successfully we don't know the265// format used and output it as a general UserSemantic decoration.266// Otherwise MDs is a Metadata tuple (a decoration list) in the format267// expected by `spirv.Decorations`.268if (MDs.size() == 0) {269auto UserSemantic = ConstantAsMetadata::get(ConstantInt::get(270Int32Ty, static_cast<uint32_t>(SPIRV::Decoration::UserSemantic)));271MDs.push_back(MDNode::get(Ctx, {UserSemantic, MDString::get(Ctx, Anno)}));272}273274// Build the internal intrinsic function.275IRBuilder<> IRB(II->getParent());276IRB.SetInsertPoint(II);277IRB.CreateIntrinsic(278Intrinsic::spv_assign_decoration, {PtrArg->getType()},279{PtrArg, MetadataAsValue::get(Ctx, MDNode::get(Ctx, MDs))});280II->replaceAllUsesWith(II->getOperand(0));281}282283static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {284// Get a separate function - otherwise, we'd have to rework the CFG of the285// current one. 
Then simply replace the intrinsic uses with a call to the new286// function.287// Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)288Module *M = FSHIntrinsic->getModule();289FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();290Type *FSHRetTy = FSHFuncTy->getReturnType();291const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);292Function *FSHFunc =293getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);294295if (!FSHFunc->empty()) {296FSHIntrinsic->setCalledFunction(FSHFunc);297return;298}299BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);300IRBuilder<> IRB(RotateBB);301Type *Ty = FSHFunc->getReturnType();302// Build the actual funnel shift rotate logic.303// In the comments, "int" is used interchangeably with "vector of int304// elements".305FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);306Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;307unsigned BitWidth = IntTy->getIntegerBitWidth();308ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});309Value *BitWidthForInsts =310VectorTy311? 
IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)312: BitWidthConstant;313Value *RotateModVal =314IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);315Value *FirstShift = nullptr, *SecShift = nullptr;316if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {317// Shift the less significant number right, the "rotate" number of bits318// will be 0-filled on the left as a result of this regular shift.319FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);320} else {321// Shift the more significant number left, the "rotate" number of bits322// will be 0-filled on the right as a result of this regular shift.323FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);324}325// We want the "rotate" number of the more significant int's LSBs (MSBs) to326// occupy the leftmost (rightmost) "0 space" left by the previous operation.327// Therefore, subtract the "rotate" number from the integer bitsize...328Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);329if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {330// ...and left-shift the more significant int by this number, zero-filling331// the LSBs.332SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);333} else {334// ...and right-shift the less significant int by this number, zero-filling335// the MSBs.336SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);337}338// A simple binary addition of the shifted ints yields the final result.339IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));340341FSHIntrinsic->setCalledFunction(FSHFunc);342}343344static void buildUMulWithOverflowFunc(Function *UMulFunc) {345// The function body is already created.346if (!UMulFunc->empty())347return;348349BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(),350"entry", UMulFunc);351IRBuilder<> IRB(EntryBB);352// Build the actual unsigned multiplication logic with the overflow353// indication. Do unsigned multiplication Mul = A * B. 
Then check354// if unsigned division Div = Mul / A is not equal to B. If so,355// then overflow has happened.356Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));357Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));358Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);359360// umul.with.overflow intrinsic return a structure, where the first element361// is the multiplication result, and the second is an overflow bit.362Type *StructTy = UMulFunc->getReturnType();363Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0});364Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});365IRB.CreateRet(Res);366}367368static void lowerExpectAssume(IntrinsicInst *II) {369// If we cannot use the SPV_KHR_expect_assume extension, then we need to370// ignore the intrinsic and move on. It should be removed later on by LLVM.371// Otherwise we should lower the intrinsic to the corresponding SPIR-V372// instruction.373// For @llvm.assume we have OpAssumeTrueKHR.374// For @llvm.expect we have OpExpectKHR.375//376// We need to lower this into a builtin and then the builtin into a SPIR-V377// instruction.378if (II->getIntrinsicID() == Intrinsic::assume) {379Function *F = Intrinsic::getDeclaration(380II->getModule(), Intrinsic::SPVIntrinsics::spv_assume);381II->setCalledFunction(F);382} else if (II->getIntrinsicID() == Intrinsic::expect) {383Function *F = Intrinsic::getDeclaration(384II->getModule(), Intrinsic::SPVIntrinsics::spv_expect,385{II->getOperand(0)->getType()});386II->setCalledFunction(F);387} else {388llvm_unreachable("Unknown intrinsic");389}390391return;392}393394static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID,395ArrayRef<unsigned> OpNos) {396Function *F = nullptr;397if (OpNos.empty()) {398F = Intrinsic::getDeclaration(II->getModule(), NewID);399} else {400SmallVector<Type *, 4> Tys;401for (unsigned OpNo : OpNos)402Tys.push_back(II->getOperand(OpNo)->getType());403F = 
Intrinsic::getDeclaration(II->getModule(), NewID, Tys);404}405II->setCalledFunction(F);406return true;407}408409static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) {410// Get a separate function - otherwise, we'd have to rework the CFG of the411// current one. Then simply replace the intrinsic uses with a call to the new412// function.413Module *M = UMulIntrinsic->getModule();414FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();415Type *FSHLRetTy = UMulFuncTy->getReturnType();416const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);417Function *UMulFunc =418getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);419buildUMulWithOverflowFunc(UMulFunc);420UMulIntrinsic->setCalledFunction(UMulFunc);421}422423// Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics424// or calls to proper generated functions. Returns True if F was modified.425bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {426bool Changed = false;427for (BasicBlock &BB : *F) {428for (Instruction &I : BB) {429auto Call = dyn_cast<CallInst>(&I);430if (!Call)431continue;432Function *CF = Call->getCalledFunction();433if (!CF || !CF->isIntrinsic())434continue;435auto *II = cast<IntrinsicInst>(Call);436switch (II->getIntrinsicID()) {437case Intrinsic::memset:438case Intrinsic::bswap:439Changed |= lowerIntrinsicToFunction(II);440break;441case Intrinsic::fshl:442case Intrinsic::fshr:443lowerFunnelShifts(II);444Changed = true;445break;446case Intrinsic::umul_with_overflow:447lowerUMulWithOverflow(II);448Changed = true;449break;450case Intrinsic::assume:451case Intrinsic::expect: {452const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F);453if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume))454lowerExpectAssume(II);455Changed = true;456} break;457case Intrinsic::lifetime_start:458Changed |= toSpvOverloadedIntrinsic(459II, Intrinsic::SPVIntrinsics::spv_lifetime_start, {1});460break;461case 
Intrinsic::lifetime_end:462Changed |= toSpvOverloadedIntrinsic(463II, Intrinsic::SPVIntrinsics::spv_lifetime_end, {1});464break;465case Intrinsic::ptr_annotation:466lowerPtrAnnotation(II);467Changed = true;468break;469}470}471}472return Changed;473}474475// Returns F if aggregate argument/return types are not present or cloned F476// function with the types replaced by i32 types. The change in types is477// noted in 'spv.cloned_funcs' metadata for later restoration.478Function *479SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {480IRBuilder<> B(F->getContext());481482bool IsRetAggr = F->getReturnType()->isAggregateType();483bool HasAggrArg =484std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {485return Arg.getType()->isAggregateType();486});487bool DoClone = IsRetAggr || HasAggrArg;488if (!DoClone)489return F;490SmallVector<std::pair<int, Type *>, 4> ChangedTypes;491Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();492if (IsRetAggr)493ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));494SmallVector<Type *, 4> ArgTypes;495for (const auto &Arg : F->args()) {496if (Arg.getType()->isAggregateType()) {497ArgTypes.push_back(B.getInt32Ty());498ChangedTypes.push_back(499std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));500} else501ArgTypes.push_back(Arg.getType());502}503FunctionType *NewFTy =504FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());505Function *NewF =506Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());507508ValueToValueMapTy VMap;509auto NewFArgIt = NewF->arg_begin();510for (auto &Arg : F->args()) {511StringRef ArgName = Arg.getName();512NewFArgIt->setName(ArgName);513VMap[&Arg] = &(*NewFArgIt++);514}515SmallVector<ReturnInst *, 8> Returns;516517CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,518Returns);519NewF->takeName(F);520521NamedMDNode *FuncMD 
=522F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");523SmallVector<Metadata *, 2> MDArgs;524MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));525for (auto &ChangedTyP : ChangedTypes)526MDArgs.push_back(MDNode::get(527B.getContext(),528{ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),529ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));530MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);531FuncMD->addOperand(ThisFuncMD);532533for (auto *U : make_early_inc_range(F->users())) {534if (auto *CI = dyn_cast<CallInst>(U))535CI->mutateFunctionType(NewF->getFunctionType());536U->replaceUsesOfWith(F, NewF);537}538539// register the mutation540if (RetType != F->getReturnType())541TM.getSubtarget<SPIRVSubtarget>(*F).getSPIRVGlobalRegistry()->addMutated(542NewF, F->getReturnType());543return NewF;544}545546bool SPIRVPrepareFunctions::runOnModule(Module &M) {547bool Changed = false;548for (Function &F : M)549Changed |= substituteIntrinsicCalls(&F);550551std::vector<Function *> FuncsWorklist;552for (auto &F : M)553FuncsWorklist.push_back(&F);554555for (auto *F : FuncsWorklist) {556Function *NewF = removeAggregateTypesFromSignature(F);557558if (NewF != F) {559F->eraseFromParent();560Changed = true;561}562}563return Changed;564}565566ModulePass *567llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) {568return new SPIRVPrepareFunctions(TM);569}570571572