CCF
Loading...
Searching...
No Matches
thread_messaging.h
Go to the documentation of this file.
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the Apache 2.0 License.
3#pragma once
4
5#include "ccf/ds/logger.h"
7#include "ds/ccf_assert.h"
8
9#include <atomic>
10#include <chrono>
11#include <cstddef>
12
13namespace threading
14{
/// Base type for messages passed between threads. Holds the callback to run
/// for the message and an intrusive `next` link used to chain messages into
/// a singly linked list.
struct ThreadMsg
{
  /// Invoked with ownership of the message.
  void (*cb)(std::unique_ptr<ThreadMsg>);
  /// Following message in the list, or null if there is none.
  std::atomic<ThreadMsg*> next{nullptr};

  ThreadMsg(void (*callback)(std::unique_ptr<ThreadMsg>)) : cb(callback) {}

  // Virtual so derived messages can be destroyed through a ThreadMsg pointer.
  virtual ~ThreadMsg() = default;
};
24
25 template <typename Payload>
26 struct alignas(16) Tmsg : public ThreadMsg
27 {
28 Payload data;
29
30 template <typename... Args>
31 Tmsg(void (*_cb)(std::unique_ptr<Tmsg<Payload>>), Args&&... args) :
32 ThreadMsg(reinterpret_cast<void (*)(std::unique_ptr<ThreadMsg>)>(_cb)),
33 data(std::forward<Args>(args)...)
34 {}
35
36 void reset_cb(void (*_cb)(std::unique_ptr<Tmsg<Payload>>))
37 {
38 cb = reinterpret_cast<void (*)(std::unique_ptr<ThreadMsg>)>(_cb);
39 }
40
41 virtual ~Tmsg() = default;
42 };
43
44 class ThreadMessaging;
45
47 {
48 std::atomic<ThreadMsg*> item_head = nullptr;
49 ThreadMsg* local_msg = nullptr;
50
51 public:
52 TaskQueue() = default;
53
55 {
56 if (local_msg == nullptr && item_head != nullptr)
57 {
58 local_msg = item_head.exchange(nullptr);
59 reverse_local_messages();
60 }
61
62 if (local_msg == nullptr)
63 {
64 return false;
65 }
66
67 ThreadMsg* current = local_msg;
68 local_msg = local_msg->next;
69
70 current->cb(std::unique_ptr<ThreadMsg>(current));
71 return true;
72 }
73
74 void add_task(ThreadMsg* item)
75 {
76 ThreadMsg* tmp_head;
77 do
78 {
79 tmp_head = item_head.load();
80 item->next = tmp_head;
81 } while (!item_head.compare_exchange_strong(tmp_head, item));
82 }
83
/// Key identifying a delayed task: the time offset at which it becomes due
/// plus a counter that distinguishes entries sharing the same offset.
struct TimerEntry
{
  // NOTE(review): only the signature of the default constructor is visible
  // in this listing; zero-initialising both fields is assumed - confirm
  // against the original header.
  TimerEntry() : time_offset(0), counter(0) {}

  TimerEntry(std::chrono::milliseconds time_offset_, uint64_t counter_) :
    time_offset(time_offset_),
    counter(counter_)
  {}

  /// Offset at which the entry becomes due.
  std::chrono::milliseconds time_offset;
  /// Tie-breaker for entries with equal time_offset.
  uint64_t counter;
};
95
97 {
98 bool operator()(const TimerEntry& lhs, const TimerEntry& rhs) const
99 {
100 if (lhs.time_offset != rhs.time_offset)
101 {
102 return lhs.time_offset < rhs.time_offset;
103 }
104
105 return lhs.counter < rhs.counter;
106 }
107 };
108
110 std::unique_ptr<ThreadMsg> item, std::chrono::milliseconds ms)
111 {
112 TimerEntry entry = {time_offset + ms, time_entry_counter++};
113 if (timer_map.empty() || entry.time_offset <= next_time_offset)
114 {
115 next_time_offset = entry.time_offset;
116 }
117
118 timer_map.emplace(entry, std::move(item));
119 return entry;
120 }
121
123 {
124 auto num_erased = timer_map.erase(timer_entry);
125 CCF_ASSERT(num_erased <= 1, "Too many items erased");
126 if (!timer_map.empty() && timer_entry.time_offset <= next_time_offset)
127 {
128 next_time_offset = timer_map.begin()->first.time_offset;
129 }
130 return num_erased != 0;
131 }
132
133 void tick(std::chrono::milliseconds elapsed)
134 {
135 time_offset += elapsed;
136
137 bool updated = false;
138
139 while (!timer_map.empty() && next_time_offset <= time_offset &&
140 timer_map.begin()->first.time_offset <= time_offset)
141 {
142 updated = true;
143 auto it = timer_map.begin();
144
145 auto& cb = it->second->cb;
146 auto msg = std::move(it->second);
147 timer_map.erase(it);
148 cb(std::move(msg));
149 }
150
151 if (updated && !timer_map.empty())
152 {
153 next_time_offset = timer_map.begin()->first.time_offset;
154 }
155 }
156
157 std::chrono::milliseconds get_current_time_offset()
158 {
159 return time_offset;
160 }
161
162 private:
163 std::chrono::milliseconds time_offset = std::chrono::milliseconds(0);
164 uint64_t time_entry_counter = 0;
165 std::map<TimerEntry, std::unique_ptr<ThreadMsg>, TimerEntryCompare>
166 timer_map;
167 std::chrono::milliseconds next_time_offset;
168
169 void reverse_local_messages()
170 {
171 if (local_msg == nullptr)
172 return;
173
174 ThreadMsg *prev = nullptr, *current = nullptr, *next = nullptr;
175 current = local_msg;
176 while (current != nullptr)
177 {
178 next = current->next;
179 current->next = prev;
180 prev = current;
181 current = next;
182 }
183 // now let the head point at the last node (prev)
184 local_msg = prev;
185 }
186
187 void drop()
188 {
189 while (true)
190 {
191 if (local_msg == nullptr && item_head != nullptr)
192 {
193 local_msg = item_head.exchange(nullptr);
194 reverse_local_messages();
195 }
196
197 if (local_msg == nullptr)
198 {
199 break;
200 }
201
202 ThreadMsg* current = local_msg;
203 local_msg = local_msg->next;
204 delete current;
205 }
206 }
207
208 friend ThreadMessaging;
209 };
210
212 {
213 std::atomic<bool> finished;
214 std::vector<TaskQueue> tasks; // Fixed-size at construction
215
216 // Drop all pending tasks, this is only ever to be used
217 // on shutdown, to avoid leaks, and after all thread but
218 // the main one have been shut down.
219 void drop_tasks()
220 {
221 for (auto& t : tasks)
222 {
223 t.drop();
224 }
225 }
226
227 inline TaskQueue& get_tasks(uint16_t task_id)
228 {
229 if (task_id >= tasks.size())
230 {
231 throw std::runtime_error(fmt::format(
232 "Attempting to access task_id >= task_count, task_id:{}, "
233 "task_count:{}",
234 task_id,
235 tasks.size()));
236 }
237 return tasks[task_id];
238 }
239
240 static std::unique_ptr<ThreadMessaging>& get_singleton()
241 {
242 static std::unique_ptr<ThreadMessaging> singleton = nullptr;
243 return singleton;
244 }
245
246 public:
247 static constexpr uint16_t max_num_threads = 24;
248
249 ThreadMessaging(uint16_t num_task_queues) :
250 finished(false),
251 tasks(num_task_queues)
252 {
253 if (num_task_queues > max_num_threads)
254 {
255 throw std::logic_error(fmt::format(
256 "ThreadMessaging constructed with too many tasks: {} > {}",
257 num_task_queues,
259 }
260 }
261
263 {
264 drop_tasks();
265 }
266
267 static void init(uint16_t num_task_queues)
268 {
269 auto& singleton = get_singleton();
270 if (singleton != nullptr)
271 {
272 throw std::logic_error("Called init() multiple times");
273 }
274
275 singleton = std::make_unique<ThreadMessaging>(num_task_queues);
276 }
277
278 static void shutdown()
279 {
280 get_singleton().reset();
281 }
282
284 {
285 auto& singleton = get_singleton();
286 if (singleton == nullptr)
287 {
288 throw std::logic_error(
289 "Attempted to access global ThreadMessaging instance without first "
290 "calling init()");
291 }
292
293 return *singleton;
294 }
295
296 void set_finished(bool v = true)
297 {
298 finished.store(v);
299 }
300
// Repeatedly calls run_next_task() on `task` until is_finished() returns
// true. The loop does not pause when no task is pending.
// NOTE(review): the declaration of `task` is absent from this listing
// (embedded line 303 is missing) - presumably the calling thread's
// TaskQueue obtained via get_tasks(); confirm against the original header.
301 void run()
302 {
304
305 while (!is_finished())
306 {
307 task.run_next_task();
308 }
309 }
310
// Runs at most one pending task on `task` and returns what
// run_next_task() returns (true if a task was executed).
// NOTE(review): the declaration of `task` is absent from this listing
// (embedded line 313 is missing) - presumably the calling thread's
// TaskQueue; confirm against the original header.
311 bool run_one()
312 {
314 return task.run_next_task();
315 }
316
317 template <typename Payload>
318 void add_task(uint16_t tid, std::unique_ptr<Tmsg<Payload>> msg)
319 {
320 TaskQueue& task = get_tasks(tid);
321
322 task.add_task(reinterpret_cast<ThreadMsg*>(msg.release()));
323 }
324
// Schedules `msg` on `task` to run after `ms` and returns the value
// produced by TaskQueue::add_task_after().
// NOTE(review): two lines are absent from this listing - the first line of
// the declaration (embedded line 326) and the declaration of `task`
// (embedded line 329); confirm both against the original header.
325 template <typename Payload>
327 std::unique_ptr<Tmsg<Payload>> msg, std::chrono::milliseconds ms)
328 {
330 return task.add_task_after(std::move(msg), ms);
331 }
332
// Forwards `timer_entry` to cancel_timer_task() on `task` and returns its
// result.
// NOTE(review): the declaration line (embedded line 333) and the
// declaration of `task` (embedded line 335) are absent from this listing;
// confirm both against the original header.
334 {
336 return task.cancel_timer_task(timer_entry);
337 }
338
// Returns the time offset reported by `task`.
// NOTE(review): the declaration of `task` is absent from this listing
// (embedded line 341 is missing) - presumably the calling thread's
// TaskQueue; confirm against the original header.
339 std::chrono::milliseconds get_current_time_offset()
340 {
342 return task.get_current_time_offset();
343 }
344
345 struct TickMsg
346 {
347 TickMsg(std::chrono::milliseconds elapsed_, TaskQueue& task_) :
348 elapsed(elapsed_),
349 task(task_)
350 {}
351
352 std::chrono::milliseconds elapsed;
354 };
355
356 static void tick_cb(std::unique_ptr<Tmsg<TickMsg>> msg)
357 {
358 msg->data.task.tick(msg->data.elapsed);
359 }
360
361 void tick(std::chrono::milliseconds elapsed)
362 {
363 for (auto i = 0ul; i < tasks.size(); ++i)
364 {
365 auto& task = get_tasks(i);
366 auto msg = std::make_unique<Tmsg<TickMsg>>(&tick_cb, elapsed, task);
367 task.add_task(msg.release());
368 }
369 }
370
371 uint16_t get_execution_thread(uint32_t i)
372 {
373 uint16_t tid = ccf::threading::MAIN_THREAD_ID;
374 if (tasks.size() > 1)
375 {
376 // If we have multiple task queues, then we distinguish the main thread
377 // from the remaining workers; anything asking for an execution thread
378 // does _not_ go to the main thread's queue
379 tid = (i % (tasks.size() - 1));
380 ++tid;
381 }
382
383 return tid;
384 }
385
386 uint16_t thread_count() const
387 {
388 return tasks.size();
389 }
390
391 private:
392 bool is_finished()
393 {
394 return finished.load();
395 }
396 };
397};
#define CCF_ASSERT(expr, msg)
Definition ccf_assert.h:14
Definition thread_messaging.h:47
void tick(std::chrono::milliseconds elapsed)
Definition thread_messaging.h:133
bool run_next_task()
Definition thread_messaging.h:54
void add_task(ThreadMsg *item)
Definition thread_messaging.h:74
std::chrono::milliseconds get_current_time_offset()
Definition thread_messaging.h:157
TimerEntry add_task_after(std::unique_ptr< ThreadMsg > item, std::chrono::milliseconds ms)
Definition thread_messaging.h:109
bool cancel_timer_task(TimerEntry timer_entry)
Definition thread_messaging.h:122
Definition thread_messaging.h:212
ThreadMessaging(uint16_t num_task_queues)
Definition thread_messaging.h:249
void tick(std::chrono::milliseconds elapsed)
Definition thread_messaging.h:361
void run()
Definition thread_messaging.h:301
std::chrono::milliseconds get_current_time_offset()
Definition thread_messaging.h:339
static void shutdown()
Definition thread_messaging.h:278
uint16_t thread_count() const
Definition thread_messaging.h:386
static ThreadMessaging & instance()
Definition thread_messaging.h:283
bool cancel_timer_task(TaskQueue::TimerEntry timer_entry)
Definition thread_messaging.h:333
bool run_one()
Definition thread_messaging.h:311
TaskQueue::TimerEntry add_task_after(std::unique_ptr< Tmsg< Payload > > msg, std::chrono::milliseconds ms)
Definition thread_messaging.h:326
void add_task(uint16_t tid, std::unique_ptr< Tmsg< Payload > > msg)
Definition thread_messaging.h:318
static void init(uint16_t num_task_queues)
Definition thread_messaging.h:267
void set_finished(bool v=true)
Definition thread_messaging.h:296
~ThreadMessaging()
Definition thread_messaging.h:262
static void tick_cb(std::unique_ptr< Tmsg< TickMsg > > msg)
Definition thread_messaging.h:356
uint16_t get_execution_thread(uint32_t i)
Definition thread_messaging.h:371
static constexpr uint16_t max_num_threads
Definition thread_messaging.h:247
uint16_t get_current_thread_id()
Definition thread_local.cpp:15
STL namespace.
Definition thread_messaging.h:14
Definition thread_messaging.h:97
bool operator()(const TimerEntry &lhs, const TimerEntry &rhs) const
Definition thread_messaging.h:98
Definition thread_messaging.h:85
TimerEntry(std::chrono::milliseconds time_offset_, uint64_t counter_)
Definition thread_messaging.h:87
std::chrono::milliseconds time_offset
Definition thread_messaging.h:92
TimerEntry()
Definition thread_messaging.h:86
uint64_t counter
Definition thread_messaging.h:93
Definition thread_messaging.h:346
std::chrono::milliseconds elapsed
Definition thread_messaging.h:352
TaskQueue & task
Definition thread_messaging.h:353
TickMsg(std::chrono::milliseconds elapsed_, TaskQueue &task_)
Definition thread_messaging.h:347
Definition thread_messaging.h:16
void(* cb)(std::unique_ptr< ThreadMsg >)
Definition thread_messaging.h:17
std::atomic< ThreadMsg * > next
Definition thread_messaging.h:18
virtual ~ThreadMsg()=default
ThreadMsg(void(*_cb)(std::unique_ptr< ThreadMsg >))
Definition thread_messaging.h:20
Definition thread_messaging.h:27
Tmsg(void(*_cb)(std::unique_ptr< Tmsg< Payload > >), Args &&... args)
Definition thread_messaging.h:31
virtual ~Tmsg()=default
void reset_cb(void(*_cb)(std::unique_ptr< Tmsg< Payload > >))
Definition thread_messaging.h:36
Payload data
Definition thread_messaging.h:28