Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVRegularizer.cpp
35266 views
//===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This pass implements regularization of LLVM IR for SPIR-V. The prototype of9// the pass was taken from SPIRV-LLVM translator.10//11//===----------------------------------------------------------------------===//1213#include "SPIRV.h"14#include "SPIRVTargetMachine.h"15#include "llvm/Demangle/Demangle.h"16#include "llvm/IR/InstIterator.h"17#include "llvm/IR/InstVisitor.h"18#include "llvm/IR/PassManager.h"19#include "llvm/Transforms/Utils/Cloning.h"2021#include <list>2223#define DEBUG_TYPE "spirv-regularizer"2425using namespace llvm;2627namespace llvm {28void initializeSPIRVRegularizerPass(PassRegistry &);29}3031namespace {32struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {33DenseMap<Function *, Function *> Old2NewFuncs;3435public:36static char ID;37SPIRVRegularizer() : FunctionPass(ID) {38initializeSPIRVRegularizerPass(*PassRegistry::getPassRegistry());39}40bool runOnFunction(Function &F) override;41StringRef getPassName() const override { return "SPIR-V Regularizer"; }4243void getAnalysisUsage(AnalysisUsage &AU) const override {44FunctionPass::getAnalysisUsage(AU);45}46void visitCallInst(CallInst &CI);4748private:49void visitCallScalToVec(CallInst *CI, StringRef MangledName,50StringRef DemangledName);51void runLowerConstExpr(Function &F);52};53} // namespace5455char SPIRVRegularizer::ID = 0;5657INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,58false)5960// Since SPIR-V cannot represent constant expression, constant expressions61// in LLVM IR need to be lowered to instructions. For each function,62// the constant expressions used by instructions of the function are replaced63// by instructions placed in the entry block since it dominates all other BBs.64// Each constant expression only needs to be lowered once in each function65// and all uses of it by instructions in that function are replaced by66// one instruction.67// TODO: remove redundant instructions for common subexpression.68void SPIRVRegularizer::runLowerConstExpr(Function &F) {69LLVMContext &Ctx = F.getContext();70std::list<Instruction *> WorkList;71for (auto &II : instructions(F))72WorkList.push_back(&II);7374auto FBegin = F.begin();75while (!WorkList.empty()) {76Instruction *II = WorkList.front();7778auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {79if (isa<Function>(V))80return V;81auto *CE = cast<ConstantExpr>(V);82LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);83auto ReplInst = CE->getAsInstruction();84auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();85ReplInst->insertBefore(InsPoint);86LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');87std::vector<Instruction *> Users;88// Do not replace use during iteration of use. Do it in another loop.89for (auto U : CE->users()) {90LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');91auto InstUser = dyn_cast<Instruction>(U);92// Only replace users in scope of current function.93if (InstUser && InstUser->getParent()->getParent() == &F)94Users.push_back(InstUser);95}96for (auto &User : Users) {97if (ReplInst->getParent() == User->getParent() &&98User->comesBefore(ReplInst))99ReplInst->moveBefore(User);100User->replaceUsesOfWith(CE, ReplInst);101}102return ReplInst;103};104105WorkList.pop_front();106auto LowerConstantVec = [&II, &LowerOp, &WorkList,107&Ctx](ConstantVector *Vec,108unsigned NumOfOp) -> Value * {109if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {110return isa<ConstantExpr>(V) || isa<Function>(V);111})) {112// Expand a vector of constexprs and construct it back with113// series of insertelement instructions.114std::list<Value *> OpList;115std::transform(Vec->op_begin(), Vec->op_end(),116std::back_inserter(OpList),117[LowerOp](Value *V) { return LowerOp(V); });118Value *Repl = nullptr;119unsigned Idx = 0;120auto *PhiII = dyn_cast<PHINode>(II);121Instruction *InsPoint =122PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;123std::list<Instruction *> ReplList;124for (auto V : OpList) {125if (auto *Inst = dyn_cast<Instruction>(V))126ReplList.push_back(Inst);127Repl = InsertElementInst::Create(128(Repl ? Repl : PoisonValue::get(Vec->getType())), V,129ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "", InsPoint);130}131WorkList.splice(WorkList.begin(), ReplList);132return Repl;133}134return nullptr;135};136for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {137auto *Op = II->getOperand(OI);138if (auto *Vec = dyn_cast<ConstantVector>(Op)) {139Value *ReplInst = LowerConstantVec(Vec, OI);140if (ReplInst)141II->replaceUsesOfWith(Op, ReplInst);142} else if (auto CE = dyn_cast<ConstantExpr>(Op)) {143WorkList.push_front(cast<Instruction>(LowerOp(CE)));144} else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {145auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());146if (!ConstMD)147continue;148Constant *C = ConstMD->getValue();149Value *ReplInst = nullptr;150if (auto *Vec = dyn_cast<ConstantVector>(C))151ReplInst = LowerConstantVec(Vec, OI);152if (auto *CE = dyn_cast<ConstantExpr>(C))153ReplInst = LowerOp(CE);154if (!ReplInst)155continue;156Metadata *RepMD = ValueAsMetadata::get(ReplInst);157Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);158II->setOperand(OI, RepMDVal);159WorkList.push_front(cast<Instruction>(ReplInst));160}161}162}163}164165// It fixes calls to OCL builtins that accept vector arguments and one of them166// is actually a scalar splat.167void SPIRVRegularizer::visitCallInst(CallInst &CI) {168auto F = CI.getCalledFunction();169if (!F)170return;171172auto MangledName = F->getName();173char *NameStr = itaniumDemangle(F->getName().data());174if (!NameStr)175return;176StringRef DemangledName(NameStr);177178// TODO: add support for other builtins.179if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") ||180DemangledName.starts_with("min") || DemangledName.starts_with("max"))181visitCallScalToVec(&CI, MangledName, DemangledName);182free(NameStr);183}184185void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,186StringRef DemangledName) {187// Check if all arguments have the same type - it's simple case.188auto Uniform = true;189Type *Arg0Ty = CI->getOperand(0)->getType();190auto IsArg0Vector = isa<VectorType>(Arg0Ty);191for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)192Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;193if (Uniform)194return;195196auto *OldF = CI->getCalledFunction();197Function *NewF = nullptr;198if (!Old2NewFuncs.count(OldF)) {199AttributeList Attrs = CI->getCalledFunction()->getAttributes();200SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};201auto *NewFTy =202FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());203NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),204*OldF->getParent());205ValueToValueMapTy VMap;206auto NewFArgIt = NewF->arg_begin();207for (auto &Arg : OldF->args()) {208auto ArgName = Arg.getName();209NewFArgIt->setName(ArgName);210VMap[&Arg] = &(*NewFArgIt++);211}212SmallVector<ReturnInst *, 8> Returns;213CloneFunctionInto(NewF, OldF, VMap,214CloneFunctionChangeType::LocalChangesOnly, Returns);215NewF->setAttributes(Attrs);216Old2NewFuncs[OldF] = NewF;217} else {218NewF = Old2NewFuncs[OldF];219}220assert(NewF);221222// This produces an instruction sequence that implements a splat of223// CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst224// and ShuffleVectorInst to generate the same code as the SPIR-V translator.225// For instance (transcoding/OpMin.ll), this call226// call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)227// is translated to228// %8 = OpUndef %v2uint229// %14 = OpConstantComposite %v2uint %uint_1 %uint_10230// ...231// %10 = OpCompositeInsert %v2uint %uint_5 %8 0232// %11 = OpVectorShuffle %v2uint %10 %8 0 0233// %call = OpExtInst %v2uint %1 s_min %14 %11234auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);235PoisonValue *PVal = PoisonValue::get(Arg0Ty);236Instruction *Inst =237InsertElementInst::Create(PVal, CI->getOperand(1), ConstInt, "", CI);238ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();239Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);240Value *NewVec = new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI);241CI->setOperand(1, NewVec);242CI->replaceUsesOfWith(OldF, NewF);243CI->mutateFunctionType(NewF->getFunctionType());244}245246bool SPIRVRegularizer::runOnFunction(Function &F) {247runLowerConstExpr(F);248visit(F);249for (auto &OldNew : Old2NewFuncs) {250Function *OldF = OldNew.first;251Function *NewF = OldNew.second;252NewF->takeName(OldF);253OldF->eraseFromParent();254}255return true;256}257258FunctionPass *llvm::createSPIRVRegularizerPass() {259return new SPIRVRegularizer();260}261262263