Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
35294 views
//===- ConvergenceRegionAnalysis.h -----------------------------*- C++ -*--===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// The analysis determines the convergence region for each basic block of9// the module, and provides a tree-like structure describing the region10// hierarchy.11//12//===----------------------------------------------------------------------===//1314#include "SPIRVConvergenceRegionAnalysis.h"15#include "llvm/Analysis/LoopInfo.h"16#include "llvm/IR/Dominators.h"17#include "llvm/IR/IntrinsicInst.h"18#include "llvm/InitializePasses.h"19#include "llvm/Transforms/Utils/LoopSimplify.h"20#include <optional>21#include <queue>2223#define DEBUG_TYPE "spirv-convergence-region-analysis"2425using namespace llvm;2627namespace llvm {28void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);29} // namespace llvm3031INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass,32"convergence-region",33"SPIRV convergence regions analysis", true, true)34INITIALIZE_PASS_DEPENDENCY(LoopSimplify)35INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)36INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)37INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass,38"convergence-region", "SPIRV convergence regions analysis",39true, true)4041namespace llvm {42namespace SPIRV {43namespace {4445template <typename BasicBlockType, typename IntrinsicInstType>46std::optional<IntrinsicInstType *>47getConvergenceTokenInternal(BasicBlockType *BB) {48static_assert(std::is_const_v<IntrinsicInstType> ==49std::is_const_v<BasicBlockType>,50"Constness must match between input and output.");51static_assert(std::is_same_v<BasicBlock, std::remove_const_t<BasicBlockType>>,52"Input must be a basic block.");53static_assert(54std::is_same_v<IntrinsicInst, std::remove_const_t<IntrinsicInstType>>,55"Output type must be an intrinsic instruction.");5657for (auto &I : *BB) {58if (auto *II = dyn_cast<IntrinsicInst>(&I)) {59switch (II->getIntrinsicID()) {60case Intrinsic::experimental_convergence_entry:61case Intrinsic::experimental_convergence_loop:62return II;63case Intrinsic::experimental_convergence_anchor: {64auto Bundle = II->getOperandBundle(LLVMContext::OB_convergencectrl);65assert(Bundle->Inputs.size() == 1 &&66Bundle->Inputs[0]->getType()->isTokenTy());67auto TII = dyn_cast<IntrinsicInst>(Bundle->Inputs[0].get());68assert(TII != nullptr);69return TII;70}71}72}7374if (auto *CI = dyn_cast<CallInst>(&I)) {75auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);76if (!OB.has_value())77continue;78return dyn_cast<IntrinsicInst>(OB.value().Inputs[0]);79}80}8182return std::nullopt;83}8485// Given a ConvergenceRegion tree with |Start| as its root, finds the smallest86// region |Entry| belongs to. If |Entry| does not belong to the region defined87// by |Start|, this function returns |nullptr|.88ConvergenceRegion *findParentRegion(ConvergenceRegion *Start,89BasicBlock *Entry) {90ConvergenceRegion *Candidate = nullptr;91ConvergenceRegion *NextCandidate = Start;9293while (Candidate != NextCandidate && NextCandidate != nullptr) {94Candidate = NextCandidate;95NextCandidate = nullptr;9697// End of the search, we can return.98if (Candidate->Children.size() == 0)99return Candidate;100101for (auto *Child : Candidate->Children) {102if (Child->Blocks.count(Entry) != 0) {103NextCandidate = Child;104break;105}106}107}108109return Candidate;110}111112} // anonymous namespace113114std::optional<IntrinsicInst *> getConvergenceToken(BasicBlock *BB) {115return getConvergenceTokenInternal<BasicBlock, IntrinsicInst>(BB);116}117118std::optional<const IntrinsicInst *> getConvergenceToken(const BasicBlock *BB) {119return getConvergenceTokenInternal<const BasicBlock, const IntrinsicInst>(BB);120}121122ConvergenceRegion::ConvergenceRegion(DominatorTree &DT, LoopInfo &LI,123Function &F)124: DT(DT), LI(LI), Parent(nullptr) {125Entry = &F.getEntryBlock();126ConvergenceToken = getConvergenceToken(Entry);127for (auto &B : F) {128Blocks.insert(&B);129if (isa<ReturnInst>(B.getTerminator()))130Exits.insert(&B);131}132}133134ConvergenceRegion::ConvergenceRegion(135DominatorTree &DT, LoopInfo &LI,136std::optional<IntrinsicInst *> ConvergenceToken, BasicBlock *Entry,137SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits)138: DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry),139Exits(std::move(Exits)), Blocks(std::move(Blocks)) {140for ([[maybe_unused]] auto *BB : this->Exits)141assert(this->Blocks.count(BB) != 0);142assert(this->Blocks.count(this->Entry) != 0);143}144145void ConvergenceRegion::releaseMemory() {146// Parent memory is owned by the parent.147Parent = nullptr;148for (auto *Child : Children) {149Child->releaseMemory();150delete Child;151}152Children.resize(0);153}154155void ConvergenceRegion::dump(const unsigned IndentSize) const {156const std::string Indent(IndentSize, '\t');157dbgs() << Indent << this << ": {\n";158dbgs() << Indent << " Parent: " << Parent << "\n";159160if (ConvergenceToken.value_or(nullptr)) {161dbgs() << Indent162<< " ConvergenceToken: " << ConvergenceToken.value()->getName()163<< "\n";164}165166if (Entry->getName() != "")167dbgs() << Indent << " Entry: " << Entry->getName() << "\n";168else169dbgs() << Indent << " Entry: " << Entry << "\n";170171dbgs() << Indent << " Exits: { ";172for (const auto &Exit : Exits) {173if (Exit->getName() != "")174dbgs() << Exit->getName() << ", ";175else176dbgs() << Exit << ", ";177}178dbgs() << " }\n";179180dbgs() << Indent << " Blocks: { ";181for (const auto &Block : Blocks) {182if (Block->getName() != "")183dbgs() << Block->getName() << ", ";184else185dbgs() << Block << ", ";186}187dbgs() << " }\n";188189dbgs() << Indent << " Children: {\n";190for (const auto Child : Children)191Child->dump(IndentSize + 2);192dbgs() << Indent << " }\n";193194dbgs() << Indent << "}\n";195}196197class ConvergenceRegionAnalyzer {198199public:200ConvergenceRegionAnalyzer(Function &F, DominatorTree &DT, LoopInfo &LI)201: DT(DT), LI(LI), F(F) {}202203private:204bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const {205assert(From != To && "From == To. This is awkward.");206207// We only handle loop in the simplified form. This means:208// - a single back-edge, a single latch.209// - meaning the back-edge target can only be the loop header.210// - meaning the From can only be the loop latch.211if (!LI.isLoopHeader(To))212return false;213214auto *L = LI.getLoopFor(To);215if (L->contains(From) && L->isLoopLatch(From))216return true;217218return false;219}220221std::unordered_set<BasicBlock *>222findPathsToMatch(LoopInfo &LI, BasicBlock *From,223std::function<bool(const BasicBlock *)> isMatch) const {224std::unordered_set<BasicBlock *> Output;225226if (isMatch(From))227Output.insert(From);228229auto *Terminator = From->getTerminator();230for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {231auto *To = Terminator->getSuccessor(i);232if (isBackEdge(From, To))233continue;234235auto ChildSet = findPathsToMatch(LI, To, isMatch);236if (ChildSet.size() == 0)237continue;238239Output.insert(ChildSet.begin(), ChildSet.end());240Output.insert(From);241if (LI.isLoopHeader(From)) {242auto *L = LI.getLoopFor(From);243for (auto *BB : L->getBlocks()) {244Output.insert(BB);245}246}247}248249return Output;250}251252SmallPtrSet<BasicBlock *, 2>253findExitNodes(const SmallPtrSetImpl<BasicBlock *> &RegionBlocks) {254SmallPtrSet<BasicBlock *, 2> Exits;255256for (auto *B : RegionBlocks) {257auto *Terminator = B->getTerminator();258for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {259auto *Child = Terminator->getSuccessor(i);260if (RegionBlocks.count(Child) == 0)261Exits.insert(B);262}263}264265return Exits;266}267268public:269ConvergenceRegionInfo analyze() {270ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F);271std::queue<Loop *> ToProcess;272for (auto *L : LI.getLoopsInPreorder())273ToProcess.push(L);274275while (ToProcess.size() != 0) {276auto *L = ToProcess.front();277ToProcess.pop();278assert(L->isLoopSimplifyForm());279280auto CT = getConvergenceToken(L->getHeader());281SmallPtrSet<BasicBlock *, 8> RegionBlocks(L->block_begin(),282L->block_end());283SmallVector<BasicBlock *> LoopExits;284L->getExitingBlocks(LoopExits);285if (CT.has_value()) {286for (auto *Exit : LoopExits) {287auto N = findPathsToMatch(LI, Exit, [&CT](const BasicBlock *block) {288auto Token = getConvergenceToken(block);289if (Token == std::nullopt)290return false;291return Token.value() == CT.value();292});293RegionBlocks.insert(N.begin(), N.end());294}295}296297auto RegionExits = findExitNodes(RegionBlocks);298ConvergenceRegion *Region = new ConvergenceRegion(299DT, LI, CT, L->getHeader(), std::move(RegionBlocks),300std::move(RegionExits));301Region->Parent = findParentRegion(TopLevelRegion, Region->Entry);302assert(Region->Parent != nullptr && "This is impossible.");303Region->Parent->Children.push_back(Region);304}305306return ConvergenceRegionInfo(TopLevelRegion);307}308309private:310DominatorTree &DT;311LoopInfo &LI;312Function &F;313};314315ConvergenceRegionInfo getConvergenceRegions(Function &F, DominatorTree &DT,316LoopInfo &LI) {317ConvergenceRegionAnalyzer Analyzer(F, DT, LI);318return Analyzer.analyze();319}320321} // namespace SPIRV322323char SPIRVConvergenceRegionAnalysisWrapperPass::ID = 0;324325SPIRVConvergenceRegionAnalysisWrapperPass::326SPIRVConvergenceRegionAnalysisWrapperPass()327: FunctionPass(ID) {}328329bool SPIRVConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) {330DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();331LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();332333CRI = SPIRV::getConvergenceRegions(F, DT, LI);334// Nothing was modified.335return false;336}337338SPIRVConvergenceRegionAnalysis::Result339SPIRVConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) {340Result CRI;341auto &DT = AM.getResult<DominatorTreeAnalysis>(F);342auto &LI = AM.getResult<LoopAnalysis>(F);343CRI = SPIRV::getConvergenceRegions(F, DT, LI);344return CRI;345}346347AnalysisKey SPIRVConvergenceRegionAnalysis::Key;348349} // namespace llvm350351352