Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
eclipse
GitHub Repository: eclipse/sumo
Path: blob/main/src/utils/threadpool/WorkStealingThreadPool.h
169678 views
1
/****************************************************************************/
2
// Eclipse SUMO, Simulation of Urban MObility; see https://eclipse.dev/sumo
3
// Copyright (C) 2020-2025 German Aerospace Center (DLR) and others.
4
// This program and the accompanying materials are made available under the
5
// terms of the Eclipse Public License 2.0 which is available at
6
// https://www.eclipse.org/legal/epl-2.0/
7
// This Source Code may also be made available under the following Secondary
8
// Licenses when the conditions for such availability set forth in the Eclipse
9
// Public License 2.0 are satisfied: GNU General Public License, version 2
10
// or later which is available at
11
// https://www.gnu.org/licenses/old-licenses/gpl-2.0-standalone.html
12
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
13
/****************************************************************************/
14
/// @file WorkStealingThreadPool.h
15
/// @author Michael Behrisch
16
/// @date 2020-09-09
17
///
18
// Threadpool implementation,
19
// based on https://github.com/vukis/Cpp-Utilities/tree/master/ThreadPool
20
/****************************************************************************/
21
#pragma once
22
#include <config.h>
23
24
#include <algorithm>
25
#include <thread>
26
#include "TaskQueue.h"
27
28
29
template<typename CONTEXT = int>
30
class WorkStealingThreadPool {
31
public:
32
33
explicit WorkStealingThreadPool(const bool workSteal, const std::vector<CONTEXT>& context)
34
: myQueues{ context.size() }, myTryoutCount(workSteal ? 1 : 0) {
35
size_t index = 0;
36
for (const CONTEXT& c : context) {
37
if (workSteal) {
38
myThreads.emplace_back([this, index, c] { workStealRun(index, c); });
39
} else {
40
myThreads.emplace_back([this, index, c] { run(index, c); });
41
}
42
index++;
43
}
44
}
45
46
~WorkStealingThreadPool() {
47
for (auto& queue : myQueues) {
48
queue.setEnabled(false);
49
}
50
for (auto& thread : myThreads) {
51
thread.join();
52
}
53
}
54
55
template<typename TaskT>
56
auto executeAsync(TaskT&& task, int idx = -1) -> std::future<decltype(task(std::declval<CONTEXT>()))> {
57
const auto index = idx == -1 ? myQueueIndex++ : idx;
58
if (myTryoutCount > 0) {
59
for (size_t n = 0; n != myQueues.size() * myTryoutCount; ++n) {
60
// Here we need not to std::forward just copy task.
61
// Because if the universal reference of task has bound to an r-value reference
62
// then std::forward will have the same effect as std::move and thus task is not required to contain a valid task.
63
// Universal reference must only be std::forward'ed a exactly zero or one times.
64
bool success = false;
65
auto result = myQueues[(index + n) % myQueues.size()].tryPush(task, success);
66
67
if (success) {
68
return result;
69
}
70
}
71
}
72
return myQueues[index % myQueues.size()].push(std::forward<TaskT>(task));
73
}
74
75
void waitAll() {
76
std::vector<std::future<void>> results;
77
for (int n = 0; n != (int)myQueues.size(); ++n) {
78
results.push_back(executeAsync([](CONTEXT) {}, n));
79
}
80
for (auto& r : results) {
81
r.wait();
82
}
83
}
84
85
private:
86
void run(size_t queueIndex, const CONTEXT& context) {
87
while (myQueues[queueIndex].isEnabled()) {
88
typename TaskQueue<CONTEXT>::TaskPtrType task;
89
if (myQueues[queueIndex].waitAndPop(task)) {
90
task->exec(context);
91
}
92
}
93
}
94
95
void workStealRun(size_t queueIndex, const CONTEXT& context) {
96
while (myQueues[queueIndex].isEnabled()) {
97
typename TaskQueue<CONTEXT>::TaskPtrType task;
98
for (size_t n = 0; n != myQueues.size()*myTryoutCount; ++n) {
99
if (myQueues[(queueIndex + n) % myQueues.size()].tryPop(task)) {
100
break;
101
}
102
}
103
if (!task && !myQueues[queueIndex].waitAndPop(task)) {
104
return;
105
}
106
task->exec(context);
107
}
108
}
109
110
private:
111
std::vector<TaskQueue<CONTEXT> > myQueues;
112
std::atomic<size_t> myQueueIndex{ 0 };
113
const size_t myTryoutCount;
114
std::vector<std::thread> myThreads;
115
};
116
117