Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
35267 views
//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// Merge the multiple exit targets of a convergence region into a single block.9// Each exit target will be assigned a constant value, and a phi node + switch10// will allow the new exit target to re-route to the correct basic block.11//12//===----------------------------------------------------------------------===//1314#include "Analysis/SPIRVConvergenceRegionAnalysis.h"15#include "SPIRV.h"16#include "SPIRVSubtarget.h"17#include "SPIRVTargetMachine.h"18#include "SPIRVUtils.h"19#include "llvm/ADT/DenseMap.h"20#include "llvm/ADT/SmallPtrSet.h"21#include "llvm/Analysis/LoopInfo.h"22#include "llvm/CodeGen/IntrinsicLowering.h"23#include "llvm/IR/CFG.h"24#include "llvm/IR/Dominators.h"25#include "llvm/IR/IRBuilder.h"26#include "llvm/IR/IntrinsicInst.h"27#include "llvm/IR/Intrinsics.h"28#include "llvm/IR/IntrinsicsSPIRV.h"29#include "llvm/InitializePasses.h"30#include "llvm/Transforms/Utils/Cloning.h"31#include "llvm/Transforms/Utils/LoopSimplify.h"32#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"3334using namespace llvm;3536namespace llvm {37void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);3839class SPIRVMergeRegionExitTargets : public FunctionPass {40public:41static char ID;4243SPIRVMergeRegionExitTargets() : FunctionPass(ID) {44initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());45};4647// Gather all the successors of |BB|.48// This function asserts if the terminator neither a branch, switch or return.49std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {50std::unordered_set<BasicBlock *> output;51auto *T = BB->getTerminator();5253if (auto *BI = dyn_cast<BranchInst>(T)) {54output.insert(BI->getSuccessor(0));55if (BI->isConditional())56output.insert(BI->getSuccessor(1));57return output;58}5960if (auto *SI = dyn_cast<SwitchInst>(T)) {61output.insert(SI->getDefaultDest());62for (auto &Case : SI->cases())63output.insert(Case.getCaseSuccessor());64return output;65}6667assert(isa<ReturnInst>(T) && "Unhandled terminator type.");68return output;69}7071/// Create a value in BB set to the value associated with the branch the block72/// terminator will take.73llvm::Value *createExitVariable(74BasicBlock *BB,75const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {76auto *T = BB->getTerminator();77if (isa<ReturnInst>(T))78return nullptr;7980IRBuilder<> Builder(BB);81Builder.SetInsertPoint(T);8283if (auto *BI = dyn_cast<BranchInst>(T)) {8485BasicBlock *LHSTarget = BI->getSuccessor(0);86BasicBlock *RHSTarget =87BI->isConditional() ? BI->getSuccessor(1) : nullptr;8889Value *LHS = TargetToValue.count(LHSTarget) != 090? TargetToValue.at(LHSTarget)91: nullptr;92Value *RHS = TargetToValue.count(RHSTarget) != 093? TargetToValue.at(RHSTarget)94: nullptr;9596if (LHS == nullptr || RHS == nullptr)97return LHS == nullptr ? RHS : LHS;98return Builder.CreateSelect(BI->getCondition(), LHS, RHS);99}100101// TODO: add support for switch cases.102llvm_unreachable("Unhandled terminator type.");103}104105/// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.106void replaceBranchTargets(BasicBlock *BB,107const SmallPtrSet<BasicBlock *, 4> &ToReplace,108BasicBlock *NewTarget) {109auto *T = BB->getTerminator();110if (isa<ReturnInst>(T))111return;112113if (auto *BI = dyn_cast<BranchInst>(T)) {114for (size_t i = 0; i < BI->getNumSuccessors(); i++) {115if (ToReplace.count(BI->getSuccessor(i)) != 0)116BI->setSuccessor(i, NewTarget);117}118return;119}120121if (auto *SI = dyn_cast<SwitchInst>(T)) {122for (size_t i = 0; i < SI->getNumSuccessors(); i++) {123if (ToReplace.count(SI->getSuccessor(i)) != 0)124SI->setSuccessor(i, NewTarget);125}126return;127}128129assert(false && "Unhandled terminator type.");130}131132// Run the pass on the given convergence region, ignoring the sub-regions.133// Returns true if the CFG changed, false otherwise.134bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,135const SPIRV::ConvergenceRegion *CR) {136// Gather all the exit targets for this region.137SmallPtrSet<BasicBlock *, 4> ExitTargets;138for (BasicBlock *Exit : CR->Exits) {139for (BasicBlock *Target : gatherSuccessors(Exit)) {140if (CR->Blocks.count(Target) == 0)141ExitTargets.insert(Target);142}143}144145// If we have zero or one exit target, nothing do to.146if (ExitTargets.size() <= 1)147return false;148149// Create the new single exit target.150auto F = CR->Entry->getParent();151auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);152IRBuilder<> Builder(NewExitTarget);153154// CodeGen output needs to be stable. Using the set as-is would order155// the targets differently depending on the allocation pattern.156// Sorting per basic-block ordering in the function.157std::vector<BasicBlock *> SortedExitTargets;158std::vector<BasicBlock *> SortedExits;159for (BasicBlock &BB : *F) {160if (ExitTargets.count(&BB) != 0)161SortedExitTargets.push_back(&BB);162if (CR->Exits.count(&BB) != 0)163SortedExits.push_back(&BB);164}165166// Creating one constant per distinct exit target. This will be route to the167// correct target.168DenseMap<BasicBlock *, ConstantInt *> TargetToValue;169for (BasicBlock *Target : SortedExitTargets)170TargetToValue.insert(171std::make_pair(Target, Builder.getInt32(TargetToValue.size())));172173// Creating one variable per exit node, set to the constant matching the174// targeted external block.175std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;176for (auto Exit : SortedExits) {177llvm::Value *Value = createExitVariable(Exit, TargetToValue);178ExitToVariable.emplace_back(std::make_pair(Exit, Value));179}180181// Gather the correct value depending on the exit we came from.182llvm::PHINode *node =183Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());184for (auto [BB, Value] : ExitToVariable) {185node->addIncoming(Value, BB);186}187188// Creating the switch to jump to the correct exit target.189llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],190SortedExitTargets.size() - 1);191for (size_t i = 1; i < SortedExitTargets.size(); i++) {192BasicBlock *BB = SortedExitTargets[i];193Sw->addCase(TargetToValue[BB], BB);194}195196// Fix exit branches to redirect to the new exit.197for (auto Exit : CR->Exits)198replaceBranchTargets(Exit, ExitTargets, NewExitTarget);199200return true;201}202203/// Run the pass on the given convergence region and sub-regions (DFS).204/// Returns true if a region/sub-region was modified, false otherwise.205/// This returns as soon as one region/sub-region has been modified.206bool runOnConvergenceRegion(LoopInfo &LI,207const SPIRV::ConvergenceRegion *CR) {208for (auto *Child : CR->Children)209if (runOnConvergenceRegion(LI, Child))210return true;211212return runOnConvergenceRegionNoRecurse(LI, CR);213}214215#if !NDEBUG216/// Validates each edge exiting the region has the same destination basic217/// block.218void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {219for (auto *Child : CR->Children)220validateRegionExits(Child);221222std::unordered_set<BasicBlock *> ExitTargets;223for (auto *Exit : CR->Exits) {224auto Set = gatherSuccessors(Exit);225for (auto *BB : Set) {226if (CR->Blocks.count(BB) == 0)227ExitTargets.insert(BB);228}229}230231assert(ExitTargets.size() <= 1);232}233#endif234235virtual bool runOnFunction(Function &F) override {236LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();237const auto *TopLevelRegion =238getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()239.getRegionInfo()240.getTopLevelRegion();241242// FIXME: very inefficient method: each time a region is modified, we bubble243// back up, and recompute the whole convergence region tree. Once the244// algorithm is completed and test coverage good enough, rewrite this pass245// to be efficient instead of simple.246bool modified = false;247while (runOnConvergenceRegion(LI, TopLevelRegion)) {248TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()249.getRegionInfo()250.getTopLevelRegion();251modified = true;252}253254#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)255validateRegionExits(TopLevelRegion);256#endif257return modified;258}259260void getAnalysisUsage(AnalysisUsage &AU) const override {261AU.addRequired<DominatorTreeWrapperPass>();262AU.addRequired<LoopInfoWrapperPass>();263AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();264FunctionPass::getAnalysisUsage(AU);265}266};267} // namespace llvm268269char SPIRVMergeRegionExitTargets::ID = 0;270271INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",272"SPIRV split region exit blocks", false, false)273INITIALIZE_PASS_DEPENDENCY(LoopSimplify)274INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)275INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)276INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)277278INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",279"SPIRV split region exit blocks", false, false)280281FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {282return new SPIRVMergeRegionExitTargets();283}284285286