// Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp
// 35266 views
//===- JumpTableToSwitch.cpp ----------------------------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//78#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"9#include "llvm/ADT/SmallVector.h"10#include "llvm/Analysis/ConstantFolding.h"11#include "llvm/Analysis/DomTreeUpdater.h"12#include "llvm/Analysis/OptimizationRemarkEmitter.h"13#include "llvm/Analysis/PostDominators.h"14#include "llvm/IR/IRBuilder.h"15#include "llvm/Support/CommandLine.h"16#include "llvm/Transforms/Utils/BasicBlockUtils.h"1718using namespace llvm;1920static cl::opt<unsigned>21JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,22cl::desc("Only split jump tables with size less or "23"equal than JumpTableSizeThreshold."),24cl::init(10));2526// TODO: Consider adding a cost model for profitability analysis of this27// transformation. 
Currently we replace a jump table with a switch if all the28// functions in the jump table are smaller than the provided threshold.29static cl::opt<unsigned> FunctionSizeThreshold(30"jump-table-to-switch-function-size-threshold", cl::Hidden,31cl::desc("Only split jump tables containing functions whose sizes are less "32"or equal than this threshold."),33cl::init(50));3435#define DEBUG_TYPE "jump-table-to-switch"3637namespace {38struct JumpTableTy {39Value *Index;40SmallVector<Function *, 10> Funcs;41};42} // anonymous namespace4344static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,45PointerType *PtrTy) {46Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand());47if (!Ptr)48return std::nullopt;4950GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr);51if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())52return std::nullopt;5354Function &F = *GEP->getParent()->getParent();55const DataLayout &DL = F.getDataLayout();56const unsigned BitWidth =57DL.getIndexSizeInBits(GEP->getPointerAddressSpace());58MapVector<Value *, APInt> VariableOffsets;59APInt ConstantOffset(BitWidth, 0);60if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))61return std::nullopt;62if (VariableOffsets.size() != 1)63return std::nullopt;64// TODO: consider supporting more general patterns65if (!ConstantOffset.isZero())66return std::nullopt;67APInt StrideBytes = VariableOffsets.front().second;68const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType());69if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0)70return std::nullopt;71const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue();72if (N > JumpTableSizeThreshold)73return std::nullopt;7475JumpTableTy JumpTable;76JumpTable.Index = VariableOffsets.front().first;77JumpTable.Funcs.reserve(N);78for (uint64_t Index = 0; Index < N; ++Index) {79// ConstantOffset is zero.80APInt Offset = Index * StrideBytes;81Constant *C =82ConstantFoldLoadFromConst(GV->getInitializer(), 
PtrTy, Offset, DL);83auto *Func = dyn_cast_or_null<Function>(C);84if (!Func || Func->isDeclaration() ||85Func->getInstructionCount() > FunctionSizeThreshold)86return std::nullopt;87JumpTable.Funcs.push_back(Func);88}89return JumpTable;90}9192static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,93DomTreeUpdater &DTU,94OptimizationRemarkEmitter &ORE) {95const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());9697SmallVector<DominatorTree::UpdateType, 8> DTUpdates;98BasicBlock *BB = CB->getParent();99BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr,100BB->getName() + Twine(".tail"));101DTUpdates.push_back({DominatorTree::Delete, BB, Tail});102BB->getTerminator()->eraseFromParent();103104Function &F = *BB->getParent();105BasicBlock *BBUnreachable = BasicBlock::Create(106F.getContext(), "default.switch.case.unreachable", &F, Tail);107IRBuilder<> BuilderUnreachable(BBUnreachable);108BuilderUnreachable.CreateUnreachable();109110IRBuilder<> Builder(BB);111SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable);112DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable});113114IRBuilder<> BuilderTail(CB);115PHINode *PHI =116IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());117118for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {119BasicBlock *B = BasicBlock::Create(Func->getContext(),120"call." 
+ Twine(Index), &F, Tail);121DTUpdates.push_back({DominatorTree::Insert, BB, B});122DTUpdates.push_back({DominatorTree::Insert, B, Tail});123124CallBase *Call = cast<CallBase>(CB->clone());125Call->setCalledFunction(Func);126Call->insertInto(B, B->end());127Switch->addCase(128cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);129BranchInst::Create(Tail, B);130if (PHI)131PHI->addIncoming(Call, B);132}133DTU.applyUpdates(DTUpdates);134ORE.emit([&]() {135return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)136<< "expanded indirect call into switch";137});138if (PHI)139CB->replaceAllUsesWith(PHI);140CB->eraseFromParent();141return Tail;142}143144PreservedAnalyses JumpTableToSwitchPass::run(Function &F,145FunctionAnalysisManager &AM) {146OptimizationRemarkEmitter &ORE =147AM.getResult<OptimizationRemarkEmitterAnalysis>(F);148DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);149PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);150DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);151bool Changed = false;152for (BasicBlock &BB : make_early_inc_range(F)) {153BasicBlock *CurrentBB = &BB;154while (CurrentBB) {155BasicBlock *SplittedOutTail = nullptr;156for (Instruction &I : make_early_inc_range(*CurrentBB)) {157auto *Call = dyn_cast<CallInst>(&I);158if (!Call || Call->getCalledFunction() || Call->isMustTailCall())159continue;160auto *L = dyn_cast<LoadInst>(Call->getCalledOperand());161// Skip atomic or volatile loads.162if (!L || !L->isSimple())163continue;164auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand());165if (!GEP)166continue;167auto *PtrTy = dyn_cast<PointerType>(L->getType());168assert(PtrTy && "call operand must be a pointer");169std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);170if (!JumpTable)171continue;172SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE);173Changed = true;174break;175}176CurrentBB = SplittedOutTail ? 
SplittedOutTail : nullptr;177}178}179180if (!Changed)181return PreservedAnalyses::all();182183PreservedAnalyses PA;184if (DT)185PA.preserve<DominatorTreeAnalysis>();186if (PDT)187PA.preserve<PostDominatorTreeAnalysis>();188return PA;189}190191192