Path: blob/main/contrib/llvm-project/llvm/lib/Support/BalancedPartitioning.cpp
35232 views
//===- BalancedPartitioning.cpp -------------------------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This file implements BalancedPartitioning, a recursive balanced graph9// partitioning algorithm.10//11//===----------------------------------------------------------------------===//1213#include "llvm/Support/BalancedPartitioning.h"14#include "llvm/Support/Debug.h"15#include "llvm/Support/Format.h"16#include "llvm/Support/FormatVariadic.h"17#include "llvm/Support/ThreadPool.h"1819using namespace llvm;20#define DEBUG_TYPE "balanced-partitioning"2122void BPFunctionNode::dump(raw_ostream &OS) const {23OS << formatv("{{ID={0} Utilities={{{1:$[,]}} Bucket={2}}", Id,24make_range(UtilityNodes.begin(), UtilityNodes.end()), Bucket);25}2627template <typename Func>28void BalancedPartitioning::BPThreadPool::async(Func &&F) {29#if LLVM_ENABLE_THREADS30// This new thread could spawn more threads, so mark it as active31++NumActiveThreads;32TheThreadPool.async([=]() {33// Run the task34F();3536// This thread will no longer spawn new threads, so mark it as inactive37if (--NumActiveThreads == 0) {38// There are no more active threads, so mark as finished and notify39{40std::unique_lock<std::mutex> lock(mtx);41assert(!IsFinishedSpawning);42IsFinishedSpawning = true;43}44cv.notify_one();45}46});47#else48llvm_unreachable("threads are disabled");49#endif50}5152void BalancedPartitioning::BPThreadPool::wait() {53#if LLVM_ENABLE_THREADS54// TODO: We could remove the mutex and condition variable and use55// std::atomic::wait() instead, but that isn't available until C++2056{57std::unique_lock<std::mutex> lock(mtx);58cv.wait(lock, [&]() { return IsFinishedSpawning; });59assert(IsFinishedSpawning && NumActiveThreads == 0);60}61// Now we can call ThreadPool::wait() since all tasks have been submitted62TheThreadPool.wait();63#else64llvm_unreachable("threads are disabled");65#endif66}6768BalancedPartitioning::BalancedPartitioning(69const BalancedPartitioningConfig &Config)70: Config(Config) {71// Pre-computing log2 values72Log2Cache[0] = 0.0;73for (unsigned I = 1; I < LOG_CACHE_SIZE; I++)74Log2Cache[I] = std::log2(I);75}7677void BalancedPartitioning::run(std::vector<BPFunctionNode> &Nodes) const {78LLVM_DEBUG(79dbgs() << format(80"Partitioning %d nodes using depth %d and %d iterations per split\n",81Nodes.size(), Config.SplitDepth, Config.IterationsPerSplit));82std::optional<BPThreadPool> TP;83#if LLVM_ENABLE_THREADS84DefaultThreadPool TheThreadPool;85if (Config.TaskSplitDepth > 1)86TP.emplace(TheThreadPool);87#endif8889// Record the input order90for (unsigned I = 0; I < Nodes.size(); I++)91Nodes[I].InputOrderIndex = I;9293auto NodesRange = llvm::make_range(Nodes.begin(), Nodes.end());94auto BisectTask = [=, &TP]() {95bisect(NodesRange, /*RecDepth=*/0, /*RootBucket=*/1, /*Offset=*/0, TP);96};97if (TP) {98TP->async(std::move(BisectTask));99TP->wait();100} else {101BisectTask();102}103104llvm::stable_sort(NodesRange, [](const auto &L, const auto &R) {105return L.Bucket < R.Bucket;106});107108LLVM_DEBUG(dbgs() << "Balanced partitioning completed\n");109}110111void BalancedPartitioning::bisect(const FunctionNodeRange Nodes,112unsigned RecDepth, unsigned RootBucket,113unsigned Offset,114std::optional<BPThreadPool> &TP) const {115unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());116if (NumNodes <= 1 || RecDepth >= Config.SplitDepth) {117// We've reach the lowest level of the recursion tree. Fall back to the118// original order and assign to buckets.119llvm::sort(Nodes, [](const auto &L, const auto &R) {120return L.InputOrderIndex < R.InputOrderIndex;121});122for (auto &N : Nodes)123N.Bucket = Offset++;124return;125}126127LLVM_DEBUG(dbgs() << format("Bisect with %d nodes and root bucket %d\n",128NumNodes, RootBucket));129130std::mt19937 RNG(RootBucket);131132unsigned LeftBucket = 2 * RootBucket;133unsigned RightBucket = 2 * RootBucket + 1;134135// Split into two and assign to the left and right buckets136split(Nodes, LeftBucket);137138runIterations(Nodes, LeftBucket, RightBucket, RNG);139140// Split nodes wrt the resulting buckets141auto NodesMid =142llvm::partition(Nodes, [&](auto &N) { return N.Bucket == LeftBucket; });143unsigned MidOffset = Offset + std::distance(Nodes.begin(), NodesMid);144145auto LeftNodes = llvm::make_range(Nodes.begin(), NodesMid);146auto RightNodes = llvm::make_range(NodesMid, Nodes.end());147148auto LeftRecTask = [=, &TP]() {149bisect(LeftNodes, RecDepth + 1, LeftBucket, Offset, TP);150};151auto RightRecTask = [=, &TP]() {152bisect(RightNodes, RecDepth + 1, RightBucket, MidOffset, TP);153};154155if (TP && RecDepth < Config.TaskSplitDepth && NumNodes >= 4) {156TP->async(std::move(LeftRecTask));157TP->async(std::move(RightRecTask));158} else {159LeftRecTask();160RightRecTask();161}162}163164void BalancedPartitioning::runIterations(const FunctionNodeRange Nodes,165unsigned LeftBucket,166unsigned RightBucket,167std::mt19937 &RNG) const {168unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());169DenseMap<BPFunctionNode::UtilityNodeT, unsigned> UtilityNodeIndex;170for (auto &N : Nodes)171for (auto &UN : N.UtilityNodes)172++UtilityNodeIndex[UN];173// Remove utility nodes if they have just one edge or are connected to all174// functions175for (auto &N : Nodes)176llvm::erase_if(N.UtilityNodes, [&](auto &UN) {177return UtilityNodeIndex[UN] == 1 || UtilityNodeIndex[UN] == NumNodes;178});179180// Renumber utility nodes so they can be used to index into Signatures181UtilityNodeIndex.clear();182for (auto &N : Nodes)183for (auto &UN : N.UtilityNodes)184UN = UtilityNodeIndex.insert({UN, UtilityNodeIndex.size()}).first->second;185186// Initialize signatures187SignaturesT Signatures(/*Size=*/UtilityNodeIndex.size());188for (auto &N : Nodes) {189for (auto &UN : N.UtilityNodes) {190assert(UN < Signatures.size());191if (N.Bucket == LeftBucket) {192Signatures[UN].LeftCount++;193} else {194Signatures[UN].RightCount++;195}196}197}198199for (unsigned I = 0; I < Config.IterationsPerSplit; I++) {200unsigned NumMovedNodes =201runIteration(Nodes, LeftBucket, RightBucket, Signatures, RNG);202if (NumMovedNodes == 0)203break;204}205}206207unsigned BalancedPartitioning::runIteration(const FunctionNodeRange Nodes,208unsigned LeftBucket,209unsigned RightBucket,210SignaturesT &Signatures,211std::mt19937 &RNG) const {212// Init signature cost caches213for (auto &Signature : Signatures) {214if (Signature.CachedGainIsValid)215continue;216unsigned L = Signature.LeftCount;217unsigned R = Signature.RightCount;218assert((L > 0 || R > 0) && "incorrect signature");219float Cost = logCost(L, R);220Signature.CachedGainLR = 0.f;221Signature.CachedGainRL = 0.f;222if (L > 0)223Signature.CachedGainLR = Cost - logCost(L - 1, R + 1);224if (R > 0)225Signature.CachedGainRL = Cost - logCost(L + 1, R - 1);226Signature.CachedGainIsValid = true;227}228229// Compute move gains230typedef std::pair<float, BPFunctionNode *> GainPair;231std::vector<GainPair> Gains;232for (auto &N : Nodes) {233bool FromLeftToRight = (N.Bucket == LeftBucket);234float Gain = moveGain(N, FromLeftToRight, Signatures);235Gains.push_back(std::make_pair(Gain, &N));236}237238// Collect left and right gains239auto LeftEnd = llvm::partition(240Gains, [&](const auto &GP) { return GP.second->Bucket == LeftBucket; });241auto LeftRange = llvm::make_range(Gains.begin(), LeftEnd);242auto RightRange = llvm::make_range(LeftEnd, Gains.end());243244// Sort gains in descending order245auto LargerGain = [](const auto &L, const auto &R) {246return L.first > R.first;247};248llvm::stable_sort(LeftRange, LargerGain);249llvm::stable_sort(RightRange, LargerGain);250251unsigned NumMovedDataVertices = 0;252for (auto [LeftPair, RightPair] : llvm::zip(LeftRange, RightRange)) {253auto &[LeftGain, LeftNode] = LeftPair;254auto &[RightGain, RightNode] = RightPair;255// Stop when the gain is no longer beneficial256if (LeftGain + RightGain <= 0.f)257break;258// Try to exchange the nodes between buckets259if (moveFunctionNode(*LeftNode, LeftBucket, RightBucket, Signatures, RNG))260++NumMovedDataVertices;261if (moveFunctionNode(*RightNode, LeftBucket, RightBucket, Signatures, RNG))262++NumMovedDataVertices;263}264return NumMovedDataVertices;265}266267bool BalancedPartitioning::moveFunctionNode(BPFunctionNode &N,268unsigned LeftBucket,269unsigned RightBucket,270SignaturesT &Signatures,271std::mt19937 &RNG) const {272// Sometimes we skip the move. This helps to escape local optima273if (std::uniform_real_distribution<float>(0.f, 1.f)(RNG) <=274Config.SkipProbability)275return false;276277bool FromLeftToRight = (N.Bucket == LeftBucket);278// Update the current bucket279N.Bucket = (FromLeftToRight ? RightBucket : LeftBucket);280281// Update signatures and invalidate gain cache282if (FromLeftToRight) {283for (auto &UN : N.UtilityNodes) {284auto &Signature = Signatures[UN];285Signature.LeftCount--;286Signature.RightCount++;287Signature.CachedGainIsValid = false;288}289} else {290for (auto &UN : N.UtilityNodes) {291auto &Signature = Signatures[UN];292Signature.LeftCount++;293Signature.RightCount--;294Signature.CachedGainIsValid = false;295}296}297return true;298}299300void BalancedPartitioning::split(const FunctionNodeRange Nodes,301unsigned StartBucket) const {302unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());303auto NodesMid = Nodes.begin() + (NumNodes + 1) / 2;304305std::nth_element(Nodes.begin(), NodesMid, Nodes.end(), [](auto &L, auto &R) {306return L.InputOrderIndex < R.InputOrderIndex;307});308309for (auto &N : llvm::make_range(Nodes.begin(), NodesMid))310N.Bucket = StartBucket;311for (auto &N : llvm::make_range(NodesMid, Nodes.end()))312N.Bucket = StartBucket + 1;313}314315float BalancedPartitioning::moveGain(const BPFunctionNode &N,316bool FromLeftToRight,317const SignaturesT &Signatures) {318float Gain = 0.f;319for (auto &UN : N.UtilityNodes)320Gain += (FromLeftToRight ? Signatures[UN].CachedGainLR321: Signatures[UN].CachedGainRL);322return Gain;323}324325float BalancedPartitioning::logCost(unsigned X, unsigned Y) const {326return -(X * log2Cached(X + 1) + Y * log2Cached(Y + 1));327}328329float BalancedPartitioning::log2Cached(unsigned i) const {330return (i < LOG_CACHE_SIZE) ? Log2Cache[i] : std::log2(i);331}332333334