// Path: src/utils/threadpool/WorkStealingThreadPool.h
/****************************************************************************/1// Eclipse SUMO, Simulation of Urban MObility; see https://eclipse.dev/sumo2// Copyright (C) 2020-2025 German Aerospace Center (DLR) and others.3// This program and the accompanying materials are made available under the4// terms of the Eclipse Public License 2.0 which is available at5// https://www.eclipse.org/legal/epl-2.0/6// This Source Code may also be made available under the following Secondary7// Licenses when the conditions for such availability set forth in the Eclipse8// Public License 2.0 are satisfied: GNU General Public License, version 29// or later which is available at10// https://www.gnu.org/licenses/old-licenses/gpl-2.0-standalone.html11// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later12/****************************************************************************/13/// @file WorkStealingThreadPool.h14/// @author Michael Behrisch15/// @date 2020-09-0916///17// Threadpool implementation,18// based on https://github.com/vukis/Cpp-Utilities/tree/master/ThreadPool19/****************************************************************************/20#pragma once21#include <config.h>2223#include <algorithm>24#include <thread>25#include "TaskQueue.h"262728template<typename CONTEXT = int>29class WorkStealingThreadPool {30public:3132explicit WorkStealingThreadPool(const bool workSteal, const std::vector<CONTEXT>& context)33: myQueues{ context.size() }, myTryoutCount(workSteal ? 
1 : 0) {34size_t index = 0;35for (const CONTEXT& c : context) {36if (workSteal) {37myThreads.emplace_back([this, index, c] { workStealRun(index, c); });38} else {39myThreads.emplace_back([this, index, c] { run(index, c); });40}41index++;42}43}4445~WorkStealingThreadPool() {46for (auto& queue : myQueues) {47queue.setEnabled(false);48}49for (auto& thread : myThreads) {50thread.join();51}52}5354template<typename TaskT>55auto executeAsync(TaskT&& task, int idx = -1) -> std::future<decltype(task(std::declval<CONTEXT>()))> {56const auto index = idx == -1 ? myQueueIndex++ : idx;57if (myTryoutCount > 0) {58for (size_t n = 0; n != myQueues.size() * myTryoutCount; ++n) {59// Here we need not to std::forward just copy task.60// Because if the universal reference of task has bound to an r-value reference61// then std::forward will have the same effect as std::move and thus task is not required to contain a valid task.62// Universal reference must only be std::forward'ed a exactly zero or one times.63bool success = false;64auto result = myQueues[(index + n) % myQueues.size()].tryPush(task, success);6566if (success) {67return result;68}69}70}71return myQueues[index % myQueues.size()].push(std::forward<TaskT>(task));72}7374void waitAll() {75std::vector<std::future<void>> results;76for (int n = 0; n != (int)myQueues.size(); ++n) {77results.push_back(executeAsync([](CONTEXT) {}, n));78}79for (auto& r : results) {80r.wait();81}82}8384private:85void run(size_t queueIndex, const CONTEXT& context) {86while (myQueues[queueIndex].isEnabled()) {87typename TaskQueue<CONTEXT>::TaskPtrType task;88if (myQueues[queueIndex].waitAndPop(task)) {89task->exec(context);90}91}92}9394void workStealRun(size_t queueIndex, const CONTEXT& context) {95while (myQueues[queueIndex].isEnabled()) {96typename TaskQueue<CONTEXT>::TaskPtrType task;97for (size_t n = 0; n != myQueues.size()*myTryoutCount; ++n) {98if (myQueues[(queueIndex + n) % myQueues.size()].tryPop(task)) {99break;100}101}102if (!task && 
!myQueues[queueIndex].waitAndPop(task)) {103return;104}105task->exec(context);106}107}108109private:110std::vector<TaskQueue<CONTEXT> > myQueues;111std::atomic<size_t> myQueueIndex{ 0 };112const size_t myTryoutCount;113std::vector<std::thread> myThreads;114};115116117