thread-pool
Loading...
Searching...
No Matches
thread_pool.h
Go to the documentation of this file.
1#pragma once
2
3#include <atomic>
4#include <barrier>
5#include <concepts>
6#include <deque>
7#include <functional>
8#include <future>
9#include <memory>
10#include <semaphore>
11#include <thread>
12#include <type_traits>
13#ifdef __has_include
14# if __has_include(<version>)
15# include <version>
16# endif
17#endif
18
20
21namespace dp {
namespace details {
/// Default task type stored by thread_pool.
/// Uses std::move_only_function<void()> when the standard library advertises
/// it (C++23, so move-only tasks can be stored). Otherwise it falls back to the
/// copyable std::function<void()>.
#if defined(__cpp_lib_move_only_function)
using default_function_type = std::move_only_function<void()>;
#else
using default_function_type = std::function<void()>;
#endif
} // namespace details
30
31 template <typename FunctionType = details::default_function_type,
32 typename ThreadType = std::jthread>
33 requires std::invocable<FunctionType> &&
34 std::is_same_v<void, std::invoke_result_t<FunctionType>>
36 public:
37 explicit thread_pool(
38 const unsigned int &number_of_threads = std::thread::hardware_concurrency())
39 : tasks_(number_of_threads) {
40 std::size_t current_id = 0;
41 for (std::size_t i = 0; i < number_of_threads; ++i) {
42 priority_queue_.push_back(size_t(current_id));
43 try {
44 threads_.emplace_back([&, id = current_id](const std::stop_token &stop_tok) {
45 do {
46 // wait until signaled
47 tasks_[id].signal.acquire();
48
49 do {
50 // invoke the task
51 while (auto task = tasks_[id].tasks.pop_front()) {
52 try {
53 pending_tasks_.fetch_sub(1, std::memory_order_release);
54 std::invoke(std::move(task.value()));
55 } catch (...) {
56 }
57 }
58
59 // try to steal a task
60 for (std::size_t j = 1; j < tasks_.size(); ++j) {
61 const std::size_t index = (id + j) % tasks_.size();
62 if (auto task = tasks_[index].tasks.steal()) {
63 // steal a task
64 pending_tasks_.fetch_sub(1, std::memory_order_release);
65 std::invoke(std::move(task.value()));
66 // stop stealing once we have invoked a stolen task
67 break;
68 }
69 }
70
71 } while (pending_tasks_.load(std::memory_order_acquire) > 0);
72
73 priority_queue_.rotate_to_front(id);
74
75 } while (!stop_tok.stop_requested());
76 });
77 // increment the thread id
78 ++current_id;
79
80 } catch (...) {
81 // catch all
82
83 // remove one item from the tasks
84 tasks_.pop_back();
85
86 // remove our thread from the priority queue
87 std::ignore = priority_queue_.pop_back();
88 }
89 }
90 }
91
93 // stop all threads
94 for (std::size_t i = 0; i < threads_.size(); ++i) {
95 threads_[i].request_stop();
96 tasks_[i].signal.release();
97 threads_[i].join();
98 }
99 }
100
102 thread_pool(const thread_pool &) = delete;
104
115 template <typename Function, typename... Args,
116 typename ReturnType = std::invoke_result_t<Function &&, Args &&...>>
117 requires std::invocable<Function, Args...>
118 [[nodiscard]] std::future<ReturnType> enqueue(Function f, Args... args) {
119#if __cpp_lib_move_only_function
120 // we can do this in C++23 because we now have support for move only functions
121 std::promise<ReturnType> promise;
122 auto future = promise.get_future();
123 auto task = [func = std::move(f), ... largs = std::move(args),
124 promise = std::move(promise)]() mutable {
125 try {
126 if constexpr (std::is_same_v<ReturnType, void>) {
127 func(largs...);
128 promise.set_value();
129 } else {
130 promise.set_value(func(largs...));
131 }
132 } catch (...) {
133 promise.set_exception(std::current_exception());
134 }
135 };
136 enqueue_task(std::move(task));
137 return future;
138#else
139 /*
140 * use shared promise here so that we don't break the promise later (until C++23)
141 *
142 * with C++23 we can do the following:
143 *
144 * std::promise<ReturnType> promise;
145 * auto future = promise.get_future();
146 * auto task = [func = std::move(f), ...largs = std::move(args),
147 promise = std::move(promise)]() mutable {...};
148 */
149 auto shared_promise = std::make_shared<std::promise<ReturnType>>();
150 auto task = [func = std::move(f), ... largs = std::move(args),
151 promise = shared_promise]() {
152 try {
153 if constexpr (std::is_same_v<ReturnType, void>) {
154 func(largs...);
155 promise->set_value();
156 } else {
157 promise->set_value(func(largs...));
158 }
159
160 } catch (...) {
161 promise->set_exception(std::current_exception());
162 }
163 };
164
165 // get the future before enqueuing the task
166 auto future = shared_promise->get_future();
167 // enqueue the task
168 enqueue_task(std::move(task));
169 return future;
170#endif
171 }
172
180 template <typename Function, typename... Args>
181 requires std::invocable<Function, Args...> &&
182 std::is_same_v<void, std::invoke_result_t<Function &&, Args &&...>>
183 void enqueue_detach(Function &&func, Args &&...args) {
184 enqueue_task(
185 std::move([f = std::forward<Function>(func),
186 ... largs = std::forward<Args>(args)]() mutable -> decltype(auto) {
187 // suppress exceptions
188 try {
189 std::invoke(f, largs...);
190 } catch (...) {
191 }
192 }));
193 }
194
195 [[nodiscard]] auto size() const { return threads_.size(); }
196
197 private:
198 template <typename Function>
199 void enqueue_task(Function &&f) {
200 auto i_opt = priority_queue_.copy_front_and_rotate_to_back();
201 if (!i_opt.has_value()) {
202 // would only be a problem if there are zero threads
203 return;
204 }
205 auto i = *(i_opt);
206 pending_tasks_.fetch_add(1, std::memory_order_relaxed);
207 tasks_[i].tasks.push_back(std::forward<Function>(f));
208 tasks_[i].signal.release();
209 }
210
211 struct task_item {
213 std::binary_semaphore signal{0};
214 };
215
216 std::vector<ThreadType> threads_;
217 std::deque<task_item> tasks_;
219 std::atomic_int_fast64_t pending_tasks_{};
220 };
221
227} // namespace dp
Definition thread_pool.h:35
~thread_pool()
Definition thread_pool.h:92
auto size() const
Definition thread_pool.h:195
std::future< ReturnType > enqueue(Function f, Args... args)
Enqueue a task into the thread pool that returns a result.
Definition thread_pool.h:118
void enqueue_detach(Function &&func, Args &&...args)
Enqueue a task to be executed in the thread pool that returns void.
Definition thread_pool.h:183
thread_pool & operator=(const thread_pool &)=delete
thread_pool(const thread_pool &)=delete
thread pool is non-copyable
thread_pool(const unsigned int &number_of_threads=std::thread::hardware_concurrency())
Definition thread_pool.h:37
Definition thread_safe_queue.h:25
void push_back(T &&value)
Definition thread_safe_queue.h:32
std::optional< T > pop_back()
Definition thread_safe_queue.h:56
void rotate_to_front(const T &item)
Definition thread_safe_queue.h:74
std::function< void()> default_function_type
Definition thread_pool.h:27
Definition thread_pool.h:21