From 0067c8289c6e5b1273f0a91bcc591b579db59420 Mon Sep 17 00:00:00 2001 From: Eddy Ashton Date: Fri, 31 Oct 2025 14:42:12 +0000 Subject: [PATCH 1/3] Aggressive rebase --- .ruff.toml | 3 +- CMakeLists.txt | 31 +- cmake/common.cmake | 7 +- include/ccf/ds/logger.h | 4 +- include/ccf/threading/thread_ids.h | 3 + src/consensus/aft/raft.h | 1 + src/consensus/aft/test/driver.cpp | 2 - src/consensus/aft/test/main.cpp | 1 - src/crypto/openssl/symmetric_key.cpp | 1 - src/ds/serializer.h | 5 + src/ds/test/messaging.cpp | 1 - src/ds/test/thread_messaging.cpp | 214 ---------- src/ds/thread_messaging.h | 397 ------------------ src/enclave/enclave.h | 55 +-- src/enclave/main.cpp | 16 +- src/enclave/session.h | 124 ++++-- src/enclave/thread_local.cpp | 16 + src/enclave/tls_session.h | 95 +---- src/http/curl.h | 38 +- src/http/http2_session.h | 8 +- src/http/http_session.h | 14 +- src/http/test/curl_test.cpp | 12 +- src/indexing/test/indexing.cpp | 1 - src/node/channels.h | 3 +- src/node/history.h | 85 ++-- src/node/jwt_key_auto_refresh.h | 86 ++-- src/node/node_state.h | 73 ++-- src/node/node_to_node_channel_manager.h | 1 + src/node/quote_endorsements_client.h | 184 ++++---- src/node/retired_nodes_cleanup.h | 21 +- src/node/rpc/forwarder.h | 87 +--- src/node/rpc/test/frontend_test.cpp | 1 - src/node/snapshotter.h | 45 +- src/node/test/historical_queries.cpp | 1 - src/node/test/history.cpp | 1 - src/node/test/history_bench.cpp | 1 - src/node/test/snapshot.cpp | 1 - src/node/test/snapshotter.cpp | 30 +- src/quic/quic_session.h | 148 +++---- .../test/demo/concurrent_queue_interface.h | 25 ++ .../test/demo/locking_concurrent_queue.h | 58 +++ src/tasks/test/flush_all_jobs.h | 54 +++ src/tasks/test/merge_bench.cpp | 86 ++++ src/tasks/test/merge_sort.h | 73 ++++ src/tasks/test/promises.cpp | 110 +++++ src/tasks/test/sleep_bench.cpp | 61 +++ src/tasks/test/task_system_thread.h | 48 +++ 47 files changed, 1035 insertions(+), 1297 deletions(-) delete mode 100644 src/ds/test/thread_messaging.cpp delete mode 100644 src/ds/thread_messaging.h create mode 100644 src/tasks/test/demo/concurrent_queue_interface.h create mode 100644 src/tasks/test/demo/locking_concurrent_queue.h create mode 100644 src/tasks/test/flush_all_jobs.h create mode 100644 src/tasks/test/merge_bench.cpp create mode 100644 src/tasks/test/merge_sort.h create mode 100644 src/tasks/test/promises.cpp create mode 100644 src/tasks/test/sleep_bench.cpp create mode 100644 src/tasks/test/task_system_thread.h diff --git a/.ruff.toml b/.ruff.toml index ac1a59c604b5..6fa99b180959 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -1,2 +1 @@ -extend-exclude = ["*_pb2*.py"] -line-length = 2000 \ No newline at end of file +line-length = 320 \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 597c4a9f44a5..84adb8c879bc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -389,6 +389,7 @@ add_ccf_static_library( ccf_endpoints ccfcrypto ccf_kv + ccf_tasks nghttp2 ${CMAKE_THREAD_LIBS_INIT} curl @@ -549,7 +550,6 @@ if(BUILD_TESTS) ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/serialized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/serializer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/hash.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/thread_messaging.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/lru.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/hex.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/contiguous_set.cpp @@ -584,7 +584,7 @@ if(BUILD_TESTS) ${CMAKE_CURRENT_SOURCE_DIR}/src/consensus/aft/test/view_history.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/consensus/aft/test/committable_suffix.cpp ) - target_link_libraries(raft_test PRIVATE ccfcrypto) + target_link_libraries(raft_test PRIVATE ccfcrypto ccf_tasks) add_unit_test( raft_enclave_test @@ -615,7 +615,9 @@ if(BUILD_TESTS) add_unit_test( history_test ${CMAKE_CURRENT_SOURCE_DIR}/src/node/test/history.cpp ) - target_link_libraries(history_test PRIVATE ccfcrypto http_parser ccf_kv) + target_link_libraries( + history_test PRIVATE ccfcrypto http_parser ccf_kv ccf_tasks + ) add_unit_test( encryptor_test ${CMAKE_CURRENT_SOURCE_DIR}/src/node/test/encryptor.cpp @@ -648,23 +650,26 @@ if(BUILD_TESTS) ) target_link_libraries( historical_queries_test PRIVATE http_parser ccf_kv ccf_endpoints + ccf_tasks ) add_unit_test( indexing_test ${CMAKE_CURRENT_SOURCE_DIR}/src/indexing/test/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/indexing/test/lfs.cpp ) - target_link_libraries(indexing_test PRIVATE ccf_endpoints ccf_kv) + target_link_libraries(indexing_test PRIVATE ccf_endpoints ccf_kv ccf_tasks) add_unit_test( snapshot_test ${CMAKE_CURRENT_SOURCE_DIR}/src/node/test/snapshot.cpp ) - target_link_libraries(snapshot_test PRIVATE ccf_kv) + target_link_libraries(snapshot_test PRIVATE ccf_kv ccf_tasks) add_unit_test( snapshotter_test ${CMAKE_CURRENT_SOURCE_DIR}/src/node/test/snapshotter.cpp ) - target_link_libraries(snapshotter_test PRIVATE ccf_kv ccf_endpoints) + target_link_libraries( + snapshotter_test PRIVATE ccf_kv ccf_endpoints ccf_tasks + ) add_unit_test( node_info_json_test @@ -715,8 +720,14 @@ if(BUILD_TESTS) ${CCF_DIR}/src/node/quote.cpp ${CCF_DIR}/src/node/uvm_endorsements.cpp ) target_link_libraries( - frontend_test PRIVATE ${CMAKE_THREAD_LIBS_INIT} http_parser ccf_js - ccf_endpoints ccfcrypto ccf_kv + frontend_test + PRIVATE ${CMAKE_THREAD_LIBS_INIT} + http_parser + ccf_js + ccf_endpoints + ccfcrypto + ccf_kv + ccf_tasks ) add_unit_test( @@ -768,7 +779,7 @@ if(BUILD_TESTS) raft_driver ${CMAKE_CURRENT_SOURCE_DIR}/src/consensus/aft/test/driver.cpp src/enclave/thread_local.cpp ) - target_link_libraries(raft_driver PRIVATE ccfcrypto) + target_link_libraries(raft_driver PRIVATE ccfcrypto ccf_tasks) target_include_directories(raft_driver PRIVATE src/aft) add_test( @@ -811,7 +822,7 @@ if(BUILD_TESTS) add_picobench( history_bench SRCS src/node/test/history_bench.cpp src/enclave/thread_local.cpp - LINK_LIBS ccf_kv + LINK_LIBS ccf_kv ccf_tasks ) add_picobench( diff --git a/cmake/common.cmake b/cmake/common.cmake index d22799281b96..e6d77f9e5fdd 100644 --- a/cmake/common.cmake +++ b/cmake/common.cmake @@ -225,8 +225,11 @@ function(add_picobench name) bash -c "$ --samples=10 --out-fmt=csv --output=${name}.csv && cat ${name}.csv" ) - - set_property(TEST ${name} PROPERTY LABELS benchmark) + set_property( + TEST ${name} + APPEND + PROPERTY LABELS benchmark + ) add_san_test_properties(${name}) endfunction() diff --git a/include/ccf/ds/logger.h b/include/ccf/ds/logger.h index 52662270724c..40b12dbba8d8 100644 --- a/include/ccf/ds/logger.h +++ b/include/ccf/ds/logger.h @@ -44,7 +44,7 @@ namespace ccf::logger std::string tag; std::string file_name; size_t line_number; - uint16_t thread_id; + std::string thread_id; std::ostringstream ss; std::string msg; @@ -59,7 +59,7 @@ namespace ccf::logger file_name(file_name_), line_number(line_number_) { - thread_id = ccf::threading::get_current_thread_id(); + thread_id = ccf::threading::get_current_thread_name(); } template diff --git a/include/ccf/threading/thread_ids.h b/include/ccf/threading/thread_ids.h index 2def60951587..4f59b088a0e0 100644 --- a/include/ccf/threading/thread_ids.h +++ b/include/ccf/threading/thread_ids.h @@ -22,4 +22,7 @@ namespace ccf::threading uint16_t get_current_thread_id(); void set_current_thread_id(ThreadID to); void reset_thread_id_generator(ThreadID to = MAIN_THREAD_ID); + + std::string get_current_thread_name(); + void set_current_thread_name(std::string_view sv); } \ No newline at end of file diff --git a/src/consensus/aft/raft.h b/src/consensus/aft/raft.h index 66601f676c62..13156288a9c9 100644 --- a/src/consensus/aft/raft.h +++ b/src/consensus/aft/raft.h @@ -7,6 +7,7 @@ #include "ccf/service/reconfiguration_type.h" #include "ccf/tx_id.h" #include "ccf/tx_status.h" +#include "ds/ccf_assert.h" #include "ds/internal_logger.h" #include "ds/serialized.h" #include "impl/state.h" diff --git a/src/consensus/aft/test/driver.cpp b/src/consensus/aft/test/driver.cpp index 903856cc8f5d..cbead7255edd 100644 --- a/src/consensus/aft/test/driver.cpp +++ b/src/consensus/aft/test/driver.cpp @@ -38,8 +38,6 @@ int main(int argc, char** argv) #endif ccf::logger::config::level() = ccf::LoggerLevel::DEBUG; - threading::ThreadMessaging::init(1); - const std::string filename = argv[1]; std::ifstream fstream; diff --git a/src/consensus/aft/test/main.cpp b/src/consensus/aft/test/main.cpp index d42bfd51d2d0..17f1d6937955 100644 --- a/src/consensus/aft/test/main.cpp +++ b/src/consensus/aft/test/main.cpp @@ -1010,7 +1010,6 @@ DOCTEST_TEST_CASE( int main(int argc, char** argv) { - threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/crypto/openssl/symmetric_key.cpp b/src/crypto/openssl/symmetric_key.cpp index bd18b3dcc2b6..f073d8725788 100644 --- a/src/crypto/openssl/symmetric_key.cpp +++ b/src/crypto/openssl/symmetric_key.cpp @@ -6,7 +6,6 @@ #include "ccf/crypto/openssl/openssl_wrappers.h" #include "ccf/crypto/symmetric_key.h" #include "ds/internal_logger.h" -#include "ds/thread_messaging.h" #include #include diff --git a/src/ds/serializer.h b/src/ds/serializer.h index 7e7aed5ae0ce..91ee53f4e712 100644 --- a/src/ds/serializer.h +++ b/src/ds/serializer.h @@ -27,6 +27,11 @@ namespace serializer { const uint8_t* data; const size_t size; + + operator std::span() const + { + return std::span(data, size); + } }; namespace details diff --git a/src/ds/test/messaging.cpp b/src/ds/test/messaging.cpp index f92e5918e7ed..e71b5422e1b7 100644 --- a/src/ds/test/messaging.cpp +++ b/src/ds/test/messaging.cpp @@ -5,7 +5,6 @@ #include "../non_blocking.h" #include "../ring_buffer.h" #include "../serialized.h" -#include "../thread_messaging.h" #include #include diff --git a/src/ds/test/thread_messaging.cpp b/src/ds/test/thread_messaging.cpp deleted file mode 100644 index 26678b3b2c26..000000000000 --- a/src/ds/test/thread_messaging.cpp +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the Apache 2.0 License. -#include "../thread_messaging.h" - -#include - -struct Foo -{ - bool& happened; - - static size_t count; - - Foo(bool& h) : happened(h) - { - count++; - } - - ~Foo() - { - count--; - } -}; - -size_t Foo::count = 0; - -static void always(std::unique_ptr<::threading::Tmsg> msg) -{ - msg->data.happened = true; -} - -static void never(std::unique_ptr<::threading::Tmsg> msg) -{ - CHECK(false); -} - -TEST_CASE("ThreadMessaging API" * doctest::test_suite("threadmessaging")) -{ - { - ::threading::ThreadMessaging tm(1); - - static constexpr auto worker_thread_id = ccf::threading::MAIN_THREAD_ID + 1; - - bool happened_main_thread = false; - bool happened_worker_thread = false; - - tm.add_task( - ccf::threading::MAIN_THREAD_ID, - std::make_unique<::threading::Tmsg>(&always, happened_main_thread)); - - REQUIRE_THROWS(tm.add_task( - worker_thread_id, - std::make_unique<::threading::Tmsg>( - &always, happened_worker_thread))); - - REQUIRE(tm.run_one()); - REQUIRE_FALSE(tm.run_one()); - - REQUIRE(happened_main_thread); - REQUIRE_FALSE(happened_worker_thread); - } - - { - // Create a ThreadMessaging with task queues for main thread + 1 worker - // thread - ::threading::ThreadMessaging tm(2); - - static constexpr auto worker_a_id = 1; - static constexpr auto worker_b_id = 2; - - ccf::threading::reset_thread_id_generator(worker_a_id); - - bool happened_0 = false; - bool happened_1 = false; - bool happened_2 = false; - bool happened_3 = false; - - // Queue single task for main thread: - // - set happened_0 - tm.add_task( - ccf::threading::MAIN_THREAD_ID, - std::make_unique>(&always, happened_0)); - - // Queue 2 tasks for worker a: - // - set happened_1 - // - set happened_2 - tm.add_task( - worker_a_id, - std::make_unique<::threading::Tmsg>(&always, happened_1)); - tm.add_task( - worker_a_id, - std::make_unique<::threading::Tmsg>(&always, happened_2)); - - // Fail to queue task for worker b, tm is too small - REQUIRE_THROWS(tm.add_task( - worker_b_id, - std::make_unique<::threading::Tmsg>(&always, happened_3))); - - // Run single task on main thread - REQUIRE(tm.run_one()); - // Confirm there are no more tasks for main thread - REQUIRE_FALSE(tm.run_one()); - - // Confirm only first task has been executed - REQUIRE(happened_0); - REQUIRE_FALSE(happened_1); - REQUIRE_FALSE(happened_2); - REQUIRE_FALSE(happened_3); - - std::thread t([&]() { - // Run tasks for worker "a" - REQUIRE(ccf::threading::get_current_thread_id() == worker_a_id); - - REQUIRE(tm.run_one()); - REQUIRE(happened_1); - REQUIRE_FALSE(happened_2); - - REQUIRE(tm.run_one()); - REQUIRE(happened_2); - - REQUIRE_FALSE(tm.run_one()); - }); - - t.join(); - - REQUIRE(happened_0); - REQUIRE(happened_1); - REQUIRE(happened_2); - REQUIRE_FALSE(happened_3); - } -} - -// Note: this only works with ASAN turned on, which catches m2 not being -// freed. -TEST_CASE( - "Unpopped messages are freed" * doctest::test_suite("threadmessaging")) -{ - bool happened = false; - - { - ::threading::ThreadMessaging tm(1); - - auto m1 = std::make_unique<::threading::Tmsg>(&always, happened); - tm.add_task(0, std::move(m1)); - - // Task payload (and TMsg) is freed after running - tm.run_one(); - CHECK(Foo::count == 0); - - auto m2 = std::make_unique<::threading::Tmsg>(&never, happened); - tm.add_task(0, std::move(m2)); - // Task is owned by the queue, hasn't run - CHECK(Foo::count == 1); - } - // Task payload (and TMsg) is also freed if it hasn't run - // but the queue was destructed - CHECK(Foo::count == 0); - - CHECK(happened); -} - -TEST_CASE("Unique thread IDs" * doctest::test_suite("threadmessaging")) -{ - std::mutex assigned_ids_lock; - std::vector assigned_ids; - - const auto main_thread_id = ccf::threading::get_current_thread_id(); - REQUIRE(main_thread_id == ccf::threading::MAIN_THREAD_ID); - assigned_ids.push_back(main_thread_id); - - auto fn = [&]() { - { - std::lock_guard guard(assigned_ids_lock); - const auto current_thread_id = ccf::threading::get_current_thread_id(); - assigned_ids.push_back(current_thread_id); - } - }; - - constexpr size_t num_threads = 20; - constexpr size_t expected_ids = num_threads + 1; // Includes MAIN_THREAD_ID - std::vector threads; - for (auto i = 0; i < num_threads; ++i) - { - threads.emplace_back(fn); - } - - size_t attempts = 0; - constexpr size_t max_attempts = 5; - while (true) - { - { - std::lock_guard guard(assigned_ids_lock); - if (assigned_ids.size() == expected_ids) - { - break; - } - } - - REQUIRE(++attempts < max_attempts); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - } - - REQUIRE(assigned_ids.size() == expected_ids); - - for (auto& thread : threads) - { - thread.join(); - } - - const auto unique = std::unique(assigned_ids.begin(), assigned_ids.end()); - REQUIRE_MESSAGE( - unique == assigned_ids.end(), - fmt::format( - "Thread IDs are not unique: {}", fmt::join(assigned_ids, ", "))); -} diff --git a/src/ds/thread_messaging.h b/src/ds/thread_messaging.h deleted file mode 100644 index 075d4787b196..000000000000 --- a/src/ds/thread_messaging.h +++ /dev/null @@ -1,397 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the Apache 2.0 License. -#pragma once - -#include "ccf/threading/thread_ids.h" -#include "ds/ccf_assert.h" -#include "ds/internal_logger.h" - -#include -#include -#include - -namespace threading -{ - struct ThreadMsg - { - void (*cb)(std::unique_ptr); - std::atomic next = nullptr; - - ThreadMsg(void (*_cb)(std::unique_ptr)) : cb(_cb) {} - - virtual ~ThreadMsg() = default; - }; - - template - struct alignas(16) Tmsg : public ThreadMsg - { - Payload data; - - template - Tmsg(void (*_cb)(std::unique_ptr>), Args&&... args) : - ThreadMsg(reinterpret_cast)>(_cb)), - data(std::forward(args)...) - {} - - void reset_cb(void (*_cb)(std::unique_ptr>)) - { - cb = reinterpret_cast)>(_cb); - } - - virtual ~Tmsg() = default; - }; - - class ThreadMessaging; - - class TaskQueue - { - std::atomic item_head = nullptr; - ThreadMsg* local_msg = nullptr; - - public: - TaskQueue() = default; - - bool run_next_task() - { - if (local_msg == nullptr && item_head != nullptr) - { - local_msg = item_head.exchange(nullptr); - reverse_local_messages(); - } - - if (local_msg == nullptr) - { - return false; - } - - ThreadMsg* current = local_msg; - local_msg = local_msg->next; - - current->cb(std::unique_ptr(current)); - return true; - } - - void add_task(ThreadMsg* item) - { - ThreadMsg* tmp_head; - do - { - tmp_head = item_head.load(); - item->next = tmp_head; - } while (!item_head.compare_exchange_strong(tmp_head, item)); - } - - struct TimerEntry - { - TimerEntry() : time_offset(0), counter(0) {} - TimerEntry(std::chrono::milliseconds time_offset_, uint64_t counter_) : - time_offset(time_offset_), - counter(counter_) - {} - - std::chrono::milliseconds time_offset; - uint64_t counter; - }; - - struct TimerEntryCompare - { - bool operator()(const TimerEntry& lhs, const TimerEntry& rhs) const - { - if (lhs.time_offset != rhs.time_offset) - { - return lhs.time_offset < rhs.time_offset; - } - - return lhs.counter < rhs.counter; - } - }; - - TimerEntry add_task_after( - std::unique_ptr item, std::chrono::milliseconds ms) - { - TimerEntry entry = {time_offset + ms, time_entry_counter++}; - if (timer_map.empty() || entry.time_offset <= next_time_offset) - { - next_time_offset = entry.time_offset; - } - - timer_map.emplace(entry, std::move(item)); - return entry; - } - - bool cancel_timer_task(TimerEntry timer_entry) - { - auto num_erased = timer_map.erase(timer_entry); - CCF_ASSERT(num_erased <= 1, "Too many items erased"); - if (!timer_map.empty() && timer_entry.time_offset <= next_time_offset) - { - next_time_offset = timer_map.begin()->first.time_offset; - } - return num_erased != 0; - } - - void tick(std::chrono::milliseconds elapsed) - { - time_offset += elapsed; - - bool updated = false; - - while (!timer_map.empty() && next_time_offset <= time_offset && - timer_map.begin()->first.time_offset <= time_offset) - { - updated = true; - auto it = timer_map.begin(); - - auto& cb = it->second->cb; - auto msg = std::move(it->second); - timer_map.erase(it); - cb(std::move(msg)); - } - - if (updated && !timer_map.empty()) - { - next_time_offset = timer_map.begin()->first.time_offset; - } - } - - std::chrono::milliseconds get_current_time_offset() - { - return time_offset; - } - - private: - std::chrono::milliseconds time_offset = std::chrono::milliseconds(0); - uint64_t time_entry_counter = 0; - std::map, TimerEntryCompare> - timer_map; - std::chrono::milliseconds next_time_offset; - - void reverse_local_messages() - { - if (local_msg == nullptr) - return; - - ThreadMsg *prev = nullptr, *current = nullptr, *next = nullptr; - current = local_msg; - while (current != nullptr) - { - next = current->next; - current->next = prev; - prev = current; - current = next; - } - // now let the head point at the last node (prev) - local_msg = prev; - } - - void drop() - { - while (true) - { - if (local_msg == nullptr && item_head != nullptr) - { - local_msg = item_head.exchange(nullptr); - reverse_local_messages(); - } - - if (local_msg == nullptr) - { - break; - } - - ThreadMsg* current = local_msg; - local_msg = local_msg->next; - delete current; - } - } - - friend ThreadMessaging; - }; - - class ThreadMessaging - { - std::atomic finished; - std::vector tasks; // Fixed-size at construction - - // Drop all pending tasks, this is only ever to be used - // on shutdown, to avoid leaks, and after all thread but - // the main one have been shut down. - void drop_tasks() - { - for (auto& t : tasks) - { - t.drop(); - } - } - - inline TaskQueue& get_tasks(uint16_t task_id) - { - if (task_id >= tasks.size()) - { - throw std::runtime_error(fmt::format( - "Attempting to access task_id >= task_count, task_id:{}, " - "task_count:{}", - task_id, - tasks.size())); - } - return tasks[task_id]; - } - - static std::unique_ptr& get_singleton() - { - static std::unique_ptr singleton = nullptr; - return singleton; - } - - public: - static constexpr uint16_t max_num_threads = 24; - - ThreadMessaging(uint16_t num_task_queues) : - finished(false), - tasks(num_task_queues) - { - if (num_task_queues > max_num_threads) - { - throw std::logic_error(fmt::format( - "ThreadMessaging constructed with too many tasks: {} > {}", - num_task_queues, - max_num_threads)); - } - } - - ~ThreadMessaging() - { - drop_tasks(); - } - - static void init(uint16_t num_task_queues) - { - auto& singleton = get_singleton(); - if (singleton != nullptr) - { - throw std::logic_error("Called init() multiple times"); - } - - singleton = std::make_unique(num_task_queues); - } - - static void shutdown() - { - get_singleton().reset(); - } - - static ThreadMessaging& instance() - { - auto& singleton = get_singleton(); - if (singleton == nullptr) - { - throw std::logic_error( - "Attempted to access global ThreadMessaging instance without first " - "calling init()"); - } - - return *singleton; - } - - void set_finished(bool v = true) - { - finished.store(v); - } - - void run() - { - TaskQueue& task = get_tasks(ccf::threading::get_current_thread_id()); - - while (!is_finished()) - { - task.run_next_task(); - } - } - - bool run_one() - { - TaskQueue& task = get_tasks(ccf::threading::get_current_thread_id()); - return task.run_next_task(); - } - - template - void add_task(uint16_t tid, std::unique_ptr> msg) - { - TaskQueue& task = get_tasks(tid); - - task.add_task(reinterpret_cast(msg.release())); - } - - template - TaskQueue::TimerEntry add_task_after( - std::unique_ptr> msg, std::chrono::milliseconds ms) - { - TaskQueue& task = get_tasks(ccf::threading::get_current_thread_id()); - return task.add_task_after(std::move(msg), ms); - } - - bool cancel_timer_task(TaskQueue::TimerEntry timer_entry) - { - TaskQueue& task = get_tasks(ccf::threading::get_current_thread_id()); - return task.cancel_timer_task(timer_entry); - } - - std::chrono::milliseconds get_current_time_offset() - { - TaskQueue& task = get_tasks(ccf::threading::get_current_thread_id()); - return task.get_current_time_offset(); - } - - struct TickMsg - { - TickMsg(std::chrono::milliseconds elapsed_, TaskQueue& task_) : - elapsed(elapsed_), - task(task_) - {} - - std::chrono::milliseconds elapsed; - TaskQueue& task; - }; - - static void tick_cb(std::unique_ptr> msg) - { - msg->data.task.tick(msg->data.elapsed); - } - - void tick(std::chrono::milliseconds elapsed) - { - for (auto i = 0ul; i < tasks.size(); ++i) - { - auto& task = get_tasks(i); - auto msg = std::make_unique>(&tick_cb, elapsed, task); - task.add_task(msg.release()); - } - } - - uint16_t get_execution_thread(uint32_t i) - { - uint16_t tid = ccf::threading::MAIN_THREAD_ID; - if (tasks.size() > 1) - { - // If we have multiple task queues, then we distinguish the main thread - // from the remaining workers; anything asking for an execution thread - // does _not_ go to the main thread's queue - tid = (i % (tasks.size() - 1)); - ++tid; - } - - return tid; - } - - uint16_t thread_count() const - { - return tasks.size(); - } - - private: - bool is_finished() - { - return finished.load(); - } - }; -}; diff --git a/src/enclave/enclave.h b/src/enclave/enclave.h index f1de383c71d3..e2cf67e368a3 100644 --- a/src/enclave/enclave.h +++ b/src/enclave/enclave.h @@ -48,6 +48,7 @@ namespace ccf std::unique_ptr node; ringbuffer::WriterPtr to_host = nullptr; std::chrono::high_resolution_clock::time_point last_tick_time; + std::atomic worker_stop_signal = false; StartType start_type; @@ -231,9 +232,9 @@ namespace ccf lfs_access->register_message_handlers(bp.get_dispatcher()); DISPATCHER_SET_MESSAGE_HANDLER( - bp, AdminMessage::stop, [&bp](const uint8_t*, size_t) { + bp, AdminMessage::stop, [this, &bp](const uint8_t*, size_t) { bp.set_finished(); - ::threading::ThreadMessaging::instance().set_finished(); + this->worker_stop_signal.store(true); }); DISPATCHER_SET_MESSAGE_HANDLER( @@ -263,7 +264,7 @@ namespace ccf node->tick(elapsed_ms); historical_state_cache->tick(elapsed_ms); - ::threading::ThreadMessaging::instance().tick(elapsed_ms); + ccf::tasks::tick(elapsed_ms); // When recovering, no signature should be emitted while the // public ledger is being read if (!node->is_reading_public_ledger()) @@ -384,17 +385,24 @@ namespace ccf // First, read some messages from the ringbuffer auto read = bp.read_n(max_messages, circuit->read_from_outside()); - // Then, execute some thread messages - size_t thread_msg = 0; - while (thread_msg < max_messages && - ::threading::ThreadMessaging::instance().run_one()) + // Then, execute some tasks + auto& job_board = ccf::tasks::get_main_job_board(); + ccf::tasks::Task task = job_board.get_task(); + size_t tasks_done = 0; + while (task != nullptr) { - thread_msg++; + task->do_task(); + ++tasks_done; + if (tasks_done >= max_messages) + { + break; + } + task = job_board.get_task(); } - // If no messages were read from the ringbuffer and no thread - // messages were executed, idle - if (read == 0 && thread_msg == 0) + // If no messages were read from the ringbuffer and tasks were + // executed, idle + if (read == 0 && tasks_done == 0) { std::this_thread::yield(); } @@ -407,27 +415,22 @@ namespace ccf } } - struct Msg - { - uint64_t tid; - }; - - static void init_thread_cb(std::unique_ptr<::threading::Tmsg> msg) - { - LOG_DEBUG_FMT("First thread CB:{}", msg->data.tid); - } - bool run_worker() { LOG_DEBUG_FMT("Running worker thread"); { - auto msg = std::make_unique<::threading::Tmsg>(&init_thread_cb); - msg->data.tid = ccf::threading::get_current_thread_id(); - ::threading::ThreadMessaging::instance().add_task( - msg->data.tid, std::move(msg)); + auto& job_board = ccf::tasks::get_main_job_board(); + const auto timeout = std::chrono::milliseconds(100); - ::threading::ThreadMessaging::instance().run(); + while (!worker_stop_signal.load()) + { + auto task = job_board.wait_for_task(timeout); + if (task != nullptr) + { + task->do_task(); + } + } } return true; diff --git a/src/enclave/main.cpp b/src/enclave/main.cpp index 0a5c45be436c..8132a6463f6c 100644 --- a/src/enclave/main.cpp +++ b/src/enclave/main.cpp @@ -60,16 +60,6 @@ extern "C" { num_pending_threads = (uint16_t)num_worker_threads + 1; - - if (num_pending_threads > threading::ThreadMessaging::max_num_threads) - { - LOG_FAIL_FMT("Too many threads: {}", num_pending_threads); - return CreateNodeStatus::TooManyThreads; - } - - // Initialise singleton instance of ThreadMessaging, now that number of - // threads are known - threading::ThreadMessaging::init(num_pending_threads); } // 2-tx reconfiguration is currently experimental, disable it in release @@ -195,11 +185,7 @@ extern "C" if (tid == ccf::threading::MAIN_THREAD_ID) { auto s = e.load()->run_main(); - while (num_complete_threads != - threading::ThreadMessaging::instance().thread_count() - 1) - { - } - threading::ThreadMessaging::shutdown(); + return s; } auto s = e.load()->run_worker(); diff --git a/src/enclave/session.h b/src/enclave/session.h index 9a4790099fd8..867676f16a59 100644 --- a/src/enclave/session.h +++ b/src/enclave/session.h @@ -3,8 +3,10 @@ #pragma once #include "ccf/node/session.h" -#include "ds/thread_messaging.h" #include "enclave/tls_session.h" +#include "tasks/ordered_tasks.h" +#include "tasks/task.h" +#include "tasks/task_system.h" #include "tcp/msg_types.h" #include @@ -15,20 +17,71 @@ namespace ccf public std::enable_shared_from_this { private: - size_t execution_thread; + std::shared_ptr task_scheduler; + std::atomic is_closing = false; - struct SendRecvMsg + struct SessionDataTask : public ccf::tasks::ITaskAction { std::vector data; std::shared_ptr self; + + SessionDataTask( + std::span d, std::shared_ptr s) : + self(s) + { + data.assign(d.begin(), d.end()); + } + }; + + struct HandleIncomingDataTask : public SessionDataTask + { + using SessionDataTask::SessionDataTask; + + void do_action() override + { + if (self->is_closing.load()) + { + return; + } + + self->handle_incoming_data_thread(std::move(data)); + } + + const std::string& get_name() const override + { + static const std::string name = + "ThreadedSession::HandleIncomingDataTask"; + return name; + } + }; + + struct SendDataTask : public SessionDataTask + { + using SessionDataTask::SessionDataTask; + + void do_action() override + { + self->send_data_thread(std::move(data)); + } + + const std::string& get_name() const override + { + static const std::string name = "ThreadedSession::SendDataTask"; + return name; + } }; public: - ThreadedSession(int64_t thread_affinity) + ThreadedSession(int64_t session_id) + { + task_scheduler = ccf::tasks::OrderedTasks::create( + ccf::tasks::get_main_job_board(), + fmt::format("Session {}", session_id)); + } + + ~ThreadedSession() { - execution_thread = - ::threading::ThreadMessaging::instance().get_execution_thread( - thread_affinity); + task_scheduler->cancel_task(); } // Implement Session::handle_incoming_data by dispatching a thread message @@ -37,19 +90,8 @@ namespace ccf { auto [_, body] = ringbuffer::read_message<::tcp::tcp_inbound>(data); - auto msg = std::make_unique<::threading::Tmsg>( - &handle_incoming_data_cb); - msg->data.self = this->shared_from_this(); - msg->data.data.assign(body.data, body.data + body.size); - - ::threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); - } - - static void handle_incoming_data_cb( - std::unique_ptr<::threading::Tmsg> msg) - { - msg->data.self->handle_incoming_data_thread(std::move(msg->data.data)); + task_scheduler->add_action( + std::make_shared(body, shared_from_this())); } virtual void handle_incoming_data_thread(std::vector&& data) = 0; @@ -58,22 +100,21 @@ namespace ccf // that eventually invokes the virtual send_data_thread() void send_data(std::span data) override { - auto msg = - std::make_unique<::threading::Tmsg>(&send_data_cb); - msg->data.self = this->shared_from_this(); - msg->data.data.assign(data.begin(), data.end()); - - ::threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); + task_scheduler->add_action( + std::make_shared(data, shared_from_this())); } - static void send_data_cb( - std::unique_ptr<::threading::Tmsg> msg) + virtual void send_data_thread(std::vector&& data) = 0; + + void close_session() override { - msg->data.self->send_data_thread(std::move(msg->data.data)); + is_closing.store(true); + + task_scheduler->add_action(ccf::tasks::make_basic_action( + [self = shared_from_this()]() { self->close_session_thread(); })); } - virtual void send_data_thread(std::vector&& data) = 0; + virtual void close_session_thread() = 0; }; class EncryptedSession : public ThreadedSession @@ -96,21 +137,9 @@ namespace ccf {} public: - void send_data(std::span data) override - { - // Override send_data rather than send_data_thread, as the TLSSession - // handles dispatching for thread affinity - tls_io->send_raw(data.data(), data.size()); - } - void send_data_thread(std::vector&& data) override { - throw std::logic_error("Unimplemented"); - } - - void close_session() override - { - tls_io->close(); + tls_io->send_data(data.data(), data.size()); } void handle_incoming_data_thread(std::vector&& data) override @@ -150,6 +179,11 @@ namespace ccf n_read = tls_io->read(data.data(), data.size(), false); } } + + void close_session_thread() override + { + tls_io->close(); + } }; class UnencryptedSession : public ccf::ThreadedSession @@ -178,7 +212,7 @@ namespace ccf serializer::ByteRange{data.data(), data.size()}); } - void close_session() override + void close_session_thread() override { RINGBUFFER_WRITE_MESSAGE( ::tcp::tcp_stop, to_host, session_id, std::string("Session closed")); diff --git a/src/enclave/thread_local.cpp b/src/enclave/thread_local.cpp index 23c8bdeda858..79b75010d712 100644 --- a/src/enclave/thread_local.cpp +++ b/src/enclave/thread_local.cpp @@ -7,6 +7,7 @@ namespace ccf::threading namespace { std::atomic next_thread_id = MAIN_THREAD_ID; + thread_local std::optional this_thread_name = std::nullopt; } uint16_t& current_thread_id() @@ -29,4 +30,19 @@ namespace ccf::threading { next_thread_id.store(to); } + + std::string get_current_thread_name() + { + if (!this_thread_name.has_value()) + { + this_thread_name = fmt::format("{}", get_current_thread_id()); + } + + return this_thread_name.value(); + } + + void set_current_thread_name(std::string_view sv) + { + this_thread_name = sv; + } } \ No newline at end of file diff --git a/src/enclave/tls_session.h b/src/enclave/tls_session.h index 44d15e5da6a0..0b1c72a1009a 100644 --- a/src/enclave/tls_session.h +++ b/src/enclave/tls_session.h @@ -5,7 +5,6 @@ #include "ds/internal_logger.h" #include "ds/messaging.h" #include "ds/ring_buffer.h" -#include "ds/thread_messaging.h" #include "tcp/msg_types.h" #include "tls/context.h" #include "tls/tls.h" @@ -32,7 +31,6 @@ namespace ccf protected: ringbuffer::WriterPtr to_host; ::tcp::ConnID session_id; - size_t execution_thread; private: std::vector pending_write; @@ -57,17 +55,6 @@ namespace ccf return status == ready || status == handshake; } - struct SendRecvMsg - { - std::vector data; - std::shared_ptr self; - }; - - struct EmptyMsg - { - std::shared_ptr self; - }; - public: TLSSession( int64_t session_id_, @@ -78,9 +65,6 @@ namespace ccf ctx(std::move(ctx_)), status(handshake) { - execution_thread = - ::threading::ThreadMessaging::instance().get_execution_thread( - session_id); ctx->set_bio(this, send_callback_openssl, recv_callback_openssl); } @@ -236,11 +220,6 @@ namespace ccf void recv_buffered(const uint8_t* data, size_t size) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called recv_buffered from incorrect thread"); - } - if (can_recv()) { pending_read.insert(pending_read.end(), data, data + size); @@ -252,32 +231,6 @@ namespace ccf void close() { status = closing; - if (ccf::threading::get_current_thread_id() != execution_thread) - { - auto msg = std::make_unique<::threading::Tmsg>(&close_cb); - msg->data.self = this->shared_from_this(); - - ::threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); - } - else - { - // Close inline immediately - close_thread(); - } - } - - static void close_cb(std::unique_ptr<::threading::Tmsg> msg) - { - msg->data.self->close_thread(); - } - - virtual void close_thread() - { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called close_thread from incorrect thread"); - } switch (status) { @@ -327,39 +280,8 @@ namespace ccf } } - void send_raw(const uint8_t* data, size_t size) + void send_data(const uint8_t* data, size_t size) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - auto msg = - std::make_unique<::threading::Tmsg>(&send_raw_cb); - msg->data.self = this->shared_from_this(); - msg->data.data = std::vector(data, data + size); - - ::threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); - } - else - { - // Send inline immediately - send_raw_thread(data, size); - } - } - - private: - static void send_raw_cb(std::unique_ptr<::threading::Tmsg> msg) - { - msg->data.self->send_raw_thread( - msg->data.data.data(), msg->data.data.size()); - } - - void send_raw_thread(const uint8_t* data, size_t size) - { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error( - "Called send_raw_thread from incorrect thread"); - } // Writes as much of the data as possible. If the data cannot all // be written now, we store the remainder. We // will try to send pending writes again whenever write() is called. @@ -381,23 +303,14 @@ namespace ccf flush(); } + private: void send_buffered(const std::vector& data) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called send_buffered from incorrect thread"); - } - pending_write.insert(pending_write.end(), data.begin(), data.end()); } void flush() { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called flush from incorrect thread"); - } - do_handshake(); if (!can_send()) @@ -574,10 +487,6 @@ namespace ccf int handle_recv(uint8_t* buf, size_t len) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called handle_recv from incorrect thread"); - } if (pending_read.size() > 0) { // Use the pending data vector. This is populated when the host diff --git a/src/http/curl.h b/src/http/curl.h index 1b822c66ce23..0b953372244f 100644 --- a/src/http/curl.h +++ b/src/http/curl.h @@ -4,9 +4,7 @@ #include "ccf/ds/nonstd.h" #include "ccf/rest_verb.h" -#include "ccf/threading/thread_ids.h" #include "ds/internal_logger.h" -#include "ds/thread_messaging.h" #include "host/proxy.h" #include @@ -389,7 +387,6 @@ namespace ccf::curl std::unique_ptr response; ResponseHeaders response_headers; std::optional response_callback; - std::optional response_thread; public: CurlRequest( @@ -399,17 +396,14 @@ namespace ccf::curl UniqueSlist&& headers_, std::unique_ptr&& request_body_, std::unique_ptr&& response_, - std::optional&& response_callback_, - std::optional response_thread_ = - threading::get_current_thread_id()) : + std::optional&& response_callback_) : curl_handle(std::move(curl_handle_)), method(method_), url(std::move(url_)), headers(std::move(headers_)), request_body(std::move(request_body_)), response(std::move(response_)), - response_callback(std::move(response_callback_)), - response_thread(response_thread_) + response_callback(std::move(response_callback_)) { if (url.empty()) { @@ -542,11 +536,6 @@ namespace ccf::curl { return response_headers.data; } - - [[nodiscard]] std::optional get_response_thread() const - { - return response_thread; - } }; class CurlRequestCURLM : public UniqueCURLM @@ -606,26 +595,9 @@ namespace ccf::curl // destructor of CurlRequest curl_multi_remove_handle(p.get(), easy); - // dispatch the response handling to a thread for processing - if (request->get_response_thread().has_value()) - { - using Data = - std::tuple, CURLcode>; - ::threading::ThreadMessaging::instance().add_task( - request->get_response_thread().value(), - std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - auto& [curl_request, curl_code] = msg->data; - CurlRequest::handle_response( - std::move(curl_request), curl_code); - }, - std::make_tuple(std::move(request_data_ptr), result))); - } - else - { - // If the response thread is not set, run on the uv thread - CurlRequest::handle_response(std::move(request_data_ptr), result); - } + // handle response inline. Note that if this is expensive, it should + // defer its work to a task + CurlRequest::handle_response(std::move(request_data_ptr), result); } } while (msgq > 0); return running_handles; diff --git a/src/http/http2_session.h b/src/http/http2_session.h index d74a7f959f1f..138303d56a9b 100644 --- a/src/http/http2_session.h +++ b/src/http/http2_session.h @@ -258,9 +258,7 @@ namespace http responder_lookup(responder_lookup_) { server_parser->set_outgoing_data_handler( - [this](std::span data) { - this->tls_io->send_raw(data.data(), data.size()); - }); + [this](std::span data) { send_data(data); }); } ~HTTP2ServerSession() @@ -460,9 +458,7 @@ namespace http client_parser(*this) { client_parser.set_outgoing_data_handler( - [this](std::span data) { - this->tls_io->send_raw(data.data(), data.size()); - }); + [this](std::span data) { send_data(data); }); } bool parse(std::span data) override diff --git a/src/http/http_session.h b/src/http/http_session.h index d7016ceb2e07..2ba05cee40d6 100644 --- a/src/http/http_session.h +++ b/src/http/http_session.h @@ -67,7 +67,7 @@ namespace http ccf::errors::RequestBodyTooLarge, e.what()}); - tls_io->close(); + close_session(); } catch (RequestHeaderTooLargeException& e) { @@ -83,7 +83,7 @@ namespace http ccf::errors::RequestHeaderTooLarge, e.what()}); - tls_io->close(); + close_session(); } catch (const std::exception& e) { @@ -113,7 +113,7 @@ namespace http {}, std::move(response_body)); - tls_io->close(); + close_session(); } return false; @@ -157,7 +157,7 @@ namespace http HTTP_STATUS_INTERNAL_SERVER_ERROR, ccf::errors::InternalError, fmt::format("Error constructing RpcContext: {}", e.what())}); - tls_io->close(); + close_session(); } std::shared_ptr search = @@ -181,7 +181,7 @@ namespace http if (rpc_ctx->terminate_session) { - tls_io->close(); + close_session(); } } } @@ -195,7 +195,7 @@ namespace http // On any exception, close the connection. LOG_FAIL_FMT("Closing connection"); LOG_DEBUG_FMT("Closing connection due to exception: {}", e.what()); - tls_io->close(); + close_session(); throw; } } @@ -224,7 +224,7 @@ namespace http ); auto data = response.build_response(); - tls_io->send_raw(data.data(), data.size()); + send_data(data); return true; } diff --git a/src/http/test/curl_test.cpp b/src/http/test/curl_test.cpp index f09b19a0930a..52918bddc662 100644 --- a/src/http/test/curl_test.cpp +++ b/src/http/test/curl_test.cpp @@ -126,8 +126,7 @@ TEST_CASE("CurlmLibuvContext") std::move(headers), std::move(body), std::make_unique(SIZE_MAX), - std::move(response_callback), - std::nullopt); + std::move(response_callback)); ccf::curl::CurlmLibuvContextSingleton::get_instance()->attach_request( std::move(request)); @@ -192,8 +191,7 @@ TEST_CASE("CurlmLibuvContext slow") std::move(headers), std::move(body), std::make_unique(SIZE_MAX), - std::move(response_callback), - std::nullopt); + std::move(response_callback)); ccf::curl::CurlmLibuvContextSingleton::get_instance()->attach_request( std::move(request)); @@ -266,8 +264,7 @@ TEST_CASE("CurlmLibuvContext timeouts") std::move(headers), std::move(body), std::make_unique(SIZE_MAX), - std::move(response_callback), - std::nullopt); + std::move(response_callback)); ccf::curl::CurlmLibuvContextSingleton::get_instance()->attach_request( std::move(request)); @@ -346,8 +343,7 @@ TEST_CASE("CurlmLibuvContext multiple init") std::move(headers), std::move(body), std::make_unique(SIZE_MAX), - std::move(response_callback), - std::nullopt); + std::move(response_callback)); ccf::curl::CurlmLibuvContextSingleton::get_instance()->attach_request( std::move(request)); diff --git a/src/indexing/test/indexing.cpp b/src/indexing/test/indexing.cpp index 245cdb845601..b4b8ccbdfab9 100644 --- a/src/indexing/test/indexing.cpp +++ b/src/indexing/test/indexing.cpp @@ -953,7 +953,6 @@ TEST_CASE( int main(int argc, char** argv) { - threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/node/channels.h b/src/node/channels.h index f5757b33a228..69accec698e0 100644 --- a/src/node/channels.h +++ b/src/node/channels.h @@ -15,8 +15,7 @@ #include "ds/internal_logger.h" #include "ds/serialized.h" #include "ds/state_machine.h" -#include "ds/thread_messaging.h" -#include "node_types.h" +#include "node/node_types.h" #include #include diff --git a/src/node/history.h b/src/node/history.h index 0c8b2d391431..28e1050169c9 100644 --- a/src/node/history.h +++ b/src/node/history.h @@ -10,12 +10,13 @@ #include "crypto/openssl/hash.h" #include "crypto/openssl/key_pair.h" #include "ds/internal_logger.h" -#include "ds/thread_messaging.h" #include "endian.h" #include "kv/kv_types.h" #include "kv/store.h" #include "node_signature_verify.h" #include "service/tables/signatures.h" +#include "tasks/basic_task.h" +#include "tasks/task_system.h" #include #include @@ -573,8 +574,7 @@ namespace ccf ccf::crypto::COSEVerifierUniquePtr cose_verifier{}; std::vector cose_cert_cached{}; - std::optional<::threading::TaskQueue::TimerEntry> - emit_signature_timer_entry = std::nullopt; + ccf::tasks::Task emit_signature_periodic_task; size_t sig_tx_interval; size_t sig_ms_interval; @@ -645,72 +645,57 @@ namespace ccf void start_signature_emit_timer() override { - struct EmitSigMsg - { - EmitSigMsg(HashedTxHistory* self_) : self(self_) {} - HashedTxHistory* self; - }; - - auto emit_sig_msg = std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - auto self = msg->data.self; + const auto delay = std::chrono::milliseconds(sig_ms_interval); - std::unique_lock mguard( - self->signature_lock, std::defer_lock); + emit_signature_periodic_task = ccf::tasks::make_basic_task([this]() { + std::unique_lock mguard( + this->signature_lock, std::defer_lock); - bool should_emit_signature = false; + bool should_emit_signature = false; - if (mguard.try_lock()) + if (mguard.try_lock()) + { + auto consensus = this->store.get_consensus(); + if (consensus != nullptr) { - auto consensus = self->store.get_consensus(); - if (consensus != nullptr) + auto sig_disp = consensus->get_signature_disposition(); + switch (sig_disp) { - auto sig_disp = consensus->get_signature_disposition(); - switch (sig_disp) + case ccf::kv::Consensus::SignatureDisposition::CANT_REPLICATE: { - case ccf::kv::Consensus::SignatureDisposition::CANT_REPLICATE: - { - break; - } - case ccf::kv::Consensus::SignatureDisposition::CAN_SIGN: - { - if (self->store.committable_gap() > 0) - { - should_emit_signature = true; - } - break; - } - case ccf::kv::Consensus::SignatureDisposition::SHOULD_SIGN: + break; + } + case ccf::kv::Consensus::SignatureDisposition::CAN_SIGN: + { + if (this->store.committable_gap() > 0) { should_emit_signature = true; - break; } + break; + } + case ccf::kv::Consensus::SignatureDisposition::SHOULD_SIGN: + { + should_emit_signature = true; + break; } } } + } - if (should_emit_signature) - { - msg->data.self->emit_signature(); - } - - self->emit_signature_timer_entry = - ::threading::ThreadMessaging::instance().add_task_after( - std::move(msg), std::chrono::milliseconds(self->sig_ms_interval)); - }, - this); + if (should_emit_signature) + { + this->emit_signature(); + } + }); - emit_signature_timer_entry = - ::threading::ThreadMessaging::instance().add_task_after( - std::move(emit_sig_msg), std::chrono::milliseconds(sig_ms_interval)); + ccf::tasks::add_periodic_task(emit_signature_periodic_task, delay, delay); } ~HashedTxHistory() { - if (emit_signature_timer_entry.has_value()) + if (emit_signature_periodic_task != nullptr) { - ::threading::ThreadMessaging::instance().cancel_timer_task( - *emit_signature_timer_entry); + emit_signature_periodic_task->cancel_task(); } } diff --git a/src/node/jwt_key_auto_refresh.h b/src/node/jwt_key_auto_refresh.h index 8f5faf8d641f..1e6f1d5be587 100644 --- a/src/node/jwt_key_auto_refresh.h +++ b/src/node/jwt_key_auto_refresh.h @@ -24,6 +24,8 @@ namespace ccf ccf::crypto::Pem node_cert; std::atomic_size_t attempts; + ccf::tasks::Task periodic_refresh_task; + public: JwtKeyAutoRefresh( size_t refresh_interval_s, @@ -43,62 +45,54 @@ namespace ccf attempts(0) {} - struct RefreshTimeMsg + ~JwtKeyAutoRefresh() { - RefreshTimeMsg(JwtKeyAutoRefresh& self_) : self(self_) {} - - JwtKeyAutoRefresh& self; - }; + stop(); + } void start() { - auto refresh_msg = std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - if (!msg->data.self.consensus->can_replicate()) - { - LOG_DEBUG_FMT( - "JWT key auto-refresh: Node is not primary, skipping"); - } - else - { - msg->data.self.refresh_jwt_keys(); - } - LOG_DEBUG_FMT( - "JWT key auto-refresh: Scheduling in {}s", - msg->data.self.refresh_interval_s); - auto delay = std::chrono::seconds(msg->data.self.refresh_interval_s); - ::threading::ThreadMessaging::instance().add_task_after( - std::move(msg), delay); - }, - *this); + LOG_DEBUG_FMT("JWT key initial auto-refresh"); + periodic_refresh_task = ccf::tasks::make_basic_task([this]() { + if (!this->consensus->can_replicate()) + { + LOG_DEBUG_FMT("JWT key auto-refresh: Node is not primary, skipping"); + } + else + { + this->refresh_jwt_keys(); + } - LOG_DEBUG_FMT( - "JWT key auto-refresh: Scheduling in {}s", refresh_interval_s); - auto delay = std::chrono::seconds(refresh_interval_s); - ::threading::ThreadMessaging::instance().add_task_after( - std::move(refresh_msg), delay); + LOG_DEBUG_FMT( + "JWT key auto-refresh: Scheduling in {}s", this->refresh_interval_s); + }); + + const std::chrono::seconds period(refresh_interval_s); + ccf::tasks::add_periodic_task(periodic_refresh_task, period, period); } - void schedule_once() + void stop() { - auto refresh_msg = std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - if (!msg->data.self.consensus->can_replicate()) - { - LOG_DEBUG_FMT( - "JWT key one-off refresh: Node is not primary, skipping"); - } - else - { - msg->data.self.refresh_jwt_keys(); - } - }, - *this); + if (periodic_refresh_task != nullptr) + { + periodic_refresh_task->cancel_task(); + } + } + void schedule_once() + { LOG_DEBUG_FMT("JWT key one-off refresh: Scheduling without delay"); - auto delay = std::chrono::seconds(0); - ::threading::ThreadMessaging::instance().add_task_after( - std::move(refresh_msg), delay); + ccf::tasks::add_task(ccf::tasks::make_basic_task([this]() { + if (!this->consensus->can_replicate()) + { + LOG_DEBUG_FMT( + "JWT key one-off refresh: Node is not primary, skipping"); + } + else + { + this->refresh_jwt_keys(); + } + })); } template diff --git a/src/node/node_state.h b/src/node/node_state.h index f4b62c3d8f9a..afcd8c4510c1 100644 --- a/src/node/node_state.h +++ b/src/node/node_state.h @@ -111,21 +111,6 @@ namespace ccf std::atomic stop_noticed = false; - struct NodeStateMsg - { - NodeStateMsg( - NodeState& self_, - View create_view_ = 0, - bool create_consortium_ = true) : - self(self_), - create_view(create_view_), - create_consortium(create_consortium_) - {} - NodeState& self; - View create_view; - bool create_consortium; - }; - // // kv store, replication, and I/O // @@ -173,6 +158,8 @@ namespace ccf // the lifetime of the node ccf::kv::Version startup_seqno = 0; + ccf::tasks::Task join_periodic_task; + std::shared_ptr make_encryptor() { #ifdef USE_NULL_ENCRYPTOR @@ -865,6 +852,12 @@ namespace ccf sm.advance(NodeStartupState::partOfNetwork); } + if (join_periodic_task != nullptr) + { + join_periodic_task->cancel_task(); + join_periodic_task = nullptr; + } + LOG_INFO_FMT( "Node has now joined the network as node {}: {}", self, @@ -925,23 +918,18 @@ namespace ccf { initiate_join_unsafe(); - auto timer_msg = std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - std::lock_guard guard(msg->data.self.lock); - if (msg->data.self.sm.check(NodeStartupState::pending)) - { - msg->data.self.initiate_join_unsafe(); - auto delay = std::chrono::milliseconds( - msg->data.self.config.join.retry_timeout); - - ::threading::ThreadMessaging::instance().add_task_after( - std::move(msg), delay); - } - }, - *this); + join_periodic_task = ccf::tasks::make_basic_task([this]() { + std::lock_guard guard(this->lock); + if (this->sm.check(NodeStartupState::pending)) + { + this->initiate_join_unsafe(); + } + }); - ::threading::ThreadMessaging::instance().add_task_after( - std::move(timer_msg), config.join.retry_timeout); + ccf::tasks::add_periodic_task( + join_periodic_task, + config.join.retry_timeout, + config.join.retry_timeout); } void auto_refresh_jwt_keys() @@ -2103,30 +2091,23 @@ namespace ccf { // Service creation transaction is asynchronous to avoid deadlocks // (e.g. https://github.com/microsoft/CCF/issues/3788) - auto msg = std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - if (!msg->data.self.send_create_request( - msg->data.self.serialize_create_request( - msg->data.create_view, msg->data.create_consortium))) + ccf::tasks::add_task( + ccf::tasks::make_basic_task([this, create_view, create_consortium]() { + if (!this->send_create_request( + this->serialize_create_request(create_view, create_consortium))) { throw std::runtime_error( "Service creation request could not be committed"); } - if (msg->data.create_consortium) + if (create_consortium) { - msg->data.self.advance_part_of_network(); + this->advance_part_of_network(); } else { - msg->data.self.advance_part_of_public_network(); + this->advance_part_of_public_network(); } - }, - *this, - create_view, - create_consortium); - - ::threading::ThreadMessaging::instance().add_task( - threading::get_current_thread_id(), std::move(msg)); + })); } void begin_private_recovery() diff --git a/src/node/node_to_node_channel_manager.h b/src/node/node_to_node_channel_manager.h index 19d8610d927e..96bd07f3e037 100644 --- a/src/node/node_to_node_channel_manager.h +++ b/src/node/node_to_node_channel_manager.h @@ -4,6 +4,7 @@ #include "ccf/pal/locking.h" #include "channels.h" +#include "ds/ccf_assert.h" #include "node/node_to_node.h" namespace ccf diff --git a/src/node/quote_endorsements_client.h b/src/node/quote_endorsements_client.h index 66cf344cb543..e5f2bcef945b 100644 --- a/src/node/quote_endorsements_client.h +++ b/src/node/quote_endorsements_client.h @@ -4,7 +4,6 @@ #include "ccf/pal/attestation.h" #include "ccf/pal/attestation_sev_snp_endorsements.h" -#include "ds/thread_messaging.h" #include "enclave/rpc_sessions.h" #include "http/curl.h" @@ -70,22 +69,6 @@ namespace ccf Server server; }; - struct QuoteEndorsementsClientTimeoutMsg - { - QuoteEndorsementsClientTimeoutMsg( - const std::shared_ptr& self_, - const EndpointInfo& endpoint_, - size_t request_id_) : - self(self_), - endpoint(endpoint_), - request_id(request_id_) - {} - - std::shared_ptr self; - EndpointInfo endpoint; - size_t request_id; - }; - void handle_success_response_unsafe(std::vector&& data) { auto& server = servers.front(); @@ -154,54 +137,28 @@ namespace ccf fetch_unsafe(); } - void fetch_unsafe() + struct HandleResponseTask : public ccf::tasks::BaseTask { - const auto& server = servers.front(); - const auto& endpoint = server.front(); - - curl::UniqueCURL curl_handle; - - // Set curl get - curl_handle.set_opt(CURLOPT_HTTPGET, 1L); - // If the server does not respond at all within this time timeout - curl_handle.set_opt(CURLOPT_CONNECTTIMEOUT, server_connection_timeout_s); - // If the server does not completely response within this time timeout - curl_handle.set_opt(CURLOPT_TIMEOUT, server_response_timeout_s); - - auto url = fmt::format( - "{}://{}:{}{}{}", - endpoint.tls ? "https" : "http", - endpoint.host, - endpoint.port, - endpoint.uri, - get_formatted_query(endpoint.params)); - - if (endpoint.tls) - { - // Note: server CA is not checked here as this client is not sending - // private data. If the server was malicious and the certificate chain - // was bogus, the verification of the endorsement of the quote would - // fail anyway. - curl_handle.set_opt(CURLOPT_SSL_VERIFYHOST, 0L); - curl_handle.set_opt(CURLOPT_SSL_VERIFYPEER, 0L); - curl_handle.set_opt(CURLOPT_SSL_VERIFYSTATUS, 0L); - } + std::shared_ptr self; + std::unique_ptr request; + CURLcode curl_response; + long status_code; + + HandleResponseTask( + std::shared_ptr self_, + std::unique_ptr&& request_, + CURLcode curl_response_, + long status_code_) : + self(self_), + request(std::move(request_)), + curl_response(curl_response_), + status_code(status_code_) + {} - auto headers = ccf::curl::UniqueSlist(); - for (auto const& [k, v] : endpoint.headers) + void do_task_implementation() override { - headers.append(k, v); - } - headers.append(http::headers::HOST, endpoint.host); - - auto response_callback = ([this, lifetime = shared_from_this()]( - std::unique_ptr&& request, - CURLcode curl_response, - long status_code) { - std::lock_guard guard(this->lock); + std::lock_guard guard(self->lock); - const auto& server = servers.front(); - const auto& endpoint = server.front(); auto* response_body = request->get_response_body(); const auto& response_headers = request->get_response_headers(); @@ -212,7 +169,8 @@ namespace ccf "{} bytes", response_body->buffer.size()); - handle_success_response_unsafe(std::move(response_body->buffer)); + self->handle_success_response_unsafe( + std::move(response_body->buffer)); return; } @@ -222,15 +180,17 @@ namespace ccf curl_response, status_code); - if (server_retries_count >= max_retries_count(server)) + if ( + self->server_retries_count >= + max_retries_count(self->servers.front())) { - servers.pop_front(); + self->servers.pop_front(); - if (servers.empty()) + if (self->servers.empty()) { auto servers_tried = std::accumulate( - config.servers.begin(), - config.servers.end(), + self->config.servers.begin(), + self->config.servers.end(), std::string{}, [](const std::string& a, const Server& b) { return a + (a.length() > 0 ? ", " : "") + b.front().host; @@ -239,19 +199,21 @@ namespace ccf "Giving up retrying fetching attestation endorsements from [{}] " "after {} attempts", servers_tried, - total_retries_count); + self->total_retries_count); throw ccf::pal::AttestationCollateralFetchingTimeout( "Timed out fetching attestation endorsements from all " "configured servers"); } - server_retries_count = 0; - fetch_unsafe(); + self->server_retries_count = 0; + self->fetch_unsafe(); } else { - ++this->server_retries_count; - ++this->total_retries_count; + ++self->server_retries_count; + ++self->total_retries_count; + + const auto& endpoint = self->servers.front().front(); constexpr size_t default_retry_after_s = 3; size_t retry_after_s = default_retry_after_s; @@ -285,16 +247,77 @@ namespace ccf retry_after_s); } - auto msg = - std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> - msg) { msg->data.self->fetch(); }, - shared_from_this(), - server); + const std::chrono::seconds retry_after(retry_after_s); - ::threading::ThreadMessaging::instance().add_task_after( - std::move(msg), std::chrono::seconds(retry_after_s)); + LOG_INFO_FMT( + "{} endorsements endpoint had too many requests. Retrying " + "in {}s", + endpoint, + retry_after_s); + + ccf::tasks::add_delayed_task( + ccf::tasks::make_basic_task( + [self = this->self]() { self->fetch(); }), + retry_after); } + } + + const std::string& get_name() const override + { + static const std::string name = + "QuoteEndorsementsClient::HandleResponseTask"; + return name; + } + }; + + void fetch_unsafe() + { + const auto& server = servers.front(); + const auto& endpoint = server.front(); + + curl::UniqueCURL curl_handle; + + // Set curl get + curl_handle.set_opt(CURLOPT_HTTPGET, 1L); + // If the server does not respond at all within this time timeout + curl_handle.set_opt(CURLOPT_CONNECTTIMEOUT, server_connection_timeout_s); + // If the server does not completely response within this time timeout + curl_handle.set_opt(CURLOPT_TIMEOUT, server_response_timeout_s); + + auto url = fmt::format( + "{}://{}:{}{}{}", + endpoint.tls ? "https" : "http", + endpoint.host, + endpoint.port, + endpoint.uri, + get_formatted_query(endpoint.params)); + + if (endpoint.tls) + { + // Note: server CA is not checked here as this client is not sending + // private data. If the server was malicious and the certificate chain + // was bogus, the verification of the endorsement of the quote would + // fail anyway. + curl_handle.set_opt(CURLOPT_SSL_VERIFYHOST, 0L); + curl_handle.set_opt(CURLOPT_SSL_VERIFYPEER, 0L); + curl_handle.set_opt(CURLOPT_SSL_VERIFYSTATUS, 0L); + } + + auto headers = ccf::curl::UniqueSlist(); + for (auto const& [k, v] : endpoint.headers) + { + headers.append(k, v); + } + headers.append(http::headers::HOST, endpoint.host); + + auto response_callback = ([self = shared_from_this()]( + std::unique_ptr&& request, + CURLcode curl_response, + long status_code) { + std::shared_ptr response_task = + std::make_shared( + self, std::move(request), curl_response, status_code); + ccf::tasks::add_task(response_task); }); auto request = std::make_unique( @@ -310,7 +333,6 @@ namespace ccf LOG_INFO_FMT( "Fetching endorsements for attestation report at {}", request->get_url()); - curl::CurlmLibuvContextSingleton::get_instance()->attach_request( std::move(request)); } diff --git a/src/node/retired_nodes_cleanup.h b/src/node/retired_nodes_cleanup.h index dd2de0de0792..799e18bfa38c 100644 --- a/src/node/retired_nodes_cleanup.h +++ b/src/node/retired_nodes_cleanup.h @@ -2,8 +2,9 @@ // Licensed under the Apache 2.0 License. #pragma once -#include "ds/thread_messaging.h" #include "node_client.h" +#include "tasks/basic_task.h" +#include "tasks/task_system.h" namespace ccf { @@ -30,24 +31,10 @@ namespace ccf node_client->make_request(request); } - struct RetiredNodeCleanupMsg - { - RetiredNodeCleanupMsg(RetiredNodeCleanup& self_) : self(self_) {} - - RetiredNodeCleanup& self; - }; - void cleanup() { - auto cleanup_msg = - std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - msg->data.self.send_cleanup_retired_nodes(); - }, - *this); - - ::threading::ThreadMessaging::instance().add_task( - ccf::threading::get_current_thread_id(), std::move(cleanup_msg)); + ccf::tasks::add_task(ccf::tasks::make_basic_task( + [this]() { this->send_cleanup_retired_nodes(); })); } }; } \ No newline at end of file diff --git a/src/node/rpc/forwarder.h b/src/node/rpc/forwarder.h index c1b6109db995..acf69ed3721c 100644 --- a/src/node/rpc/forwarder.h +++ b/src/node/rpc/forwarder.h @@ -8,6 +8,8 @@ #include "http/http_rpc_context.h" #include "kv/kv_types.h" #include "node/node_to_node.h" +#include "tasks/basic_task.h" +#include "tasks/task_system.h" namespace ccf { @@ -34,58 +36,11 @@ namespace ccf using ForwardedCommandId = ForwardedHeader_v2::ForwardedCommandId; ForwardedCommandId next_command_id = 0; - struct TimeoutTask - { - ::threading::TaskQueue::TimerEntry timer_entry; - uint16_t thread_id; - }; - - std::unordered_map timeout_tasks; + std::unordered_map timeout_tasks; ccf::pal::Mutex timeout_tasks_lock; using IsCallerCertForwarded = bool; - struct SendTimeoutErrorMsg - { - SendTimeoutErrorMsg( - Forwarder* forwarder_, - const ccf::NodeId& to_, - size_t client_session_id_, - const std::chrono::milliseconds& timeout_) : - forwarder(forwarder_), - to(to_), - client_session_id(client_session_id_), - timeout(timeout_) - {} - - Forwarder* forwarder; - ccf::NodeId to; - size_t client_session_id; - std::chrono::milliseconds timeout; - }; - - struct CancelTimerMsg - { - ::threading::TaskQueue::TimerEntry timer_entry; - }; - - std::unique_ptr<::threading::Tmsg> - create_timeout_error_task( - const ccf::NodeId& to, - size_t client_session_id, - const std::chrono::milliseconds& timeout) - { - return std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - msg->data.forwarder->send_timeout_error_response( - msg->data.to, msg->data.client_session_id, msg->data.timeout); - }, - this, - to, - client_session_id, - timeout); - } - void send_timeout_error_response( NodeId to, size_t client_session_id, @@ -108,18 +63,6 @@ namespace ccf } } - static void cancel_forwarding_task_cb( - std::unique_ptr<::threading::Tmsg> msg) - { - cancel_forwarding_task(msg->data.timer_entry); - } - - static void cancel_forwarding_task( - ::threading::TaskQueue::TimerEntry timer_entry) - { - ::threading::ThreadMessaging::instance().cancel_timer_task(timer_entry); - } - public: Forwarder( std::weak_ptr rpcresponder, @@ -171,10 +114,11 @@ namespace ccf { std::lock_guard guard(timeout_tasks_lock); command_id = next_command_id++; - timeout_tasks[command_id] = { - ::threading::ThreadMessaging::instance().add_task_after( - create_timeout_error_task(to, client_session_id, timeout), timeout), - ccf::threading::get_current_thread_id()}; + auto task = ccf::tasks::make_basic_task([=, this]() { + this->send_timeout_error_response(to, client_session_id, timeout); + }); + timeout_tasks[command_id] = task; + ccf::tasks::add_delayed_task(task, timeout); } const auto view_opt = session_ctx->active_view; @@ -459,20 +403,7 @@ namespace ccf auto it = timeout_tasks.find(cmd_id); if (it != timeout_tasks.end()) { - if ( - ccf::threading::get_current_thread_id() != it->second.thread_id) - { - auto msg = std::make_unique<::threading::Tmsg>( - &cancel_forwarding_task_cb); - msg->data.timer_entry = it->second.timer_entry; - - ::threading::ThreadMessaging::instance().add_task( - it->second.thread_id, std::move(msg)); - } - else - { - cancel_forwarding_task(it->second.timer_entry); - } + it->second->cancel_task(); it = timeout_tasks.erase(it); } else diff --git a/src/node/rpc/test/frontend_test.cpp b/src/node/rpc/test/frontend_test.cpp index 00f907671c28..dafa668eb0fe 100644 --- a/src/node/rpc/test/frontend_test.cpp +++ b/src/node/rpc/test/frontend_test.cpp @@ -1775,7 +1775,6 @@ TEST_CASE("Manual conflicts") int main(int argc, char** argv) { - ::threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/node/snapshotter.h b/src/node/snapshotter.h index bddfbfa2189c..7e462a30978c 100644 --- a/src/node/snapshotter.h +++ b/src/node/snapshotter.h @@ -6,12 +6,12 @@ #include "consensus/ledger_enclave_types.h" #include "ds/ccf_assert.h" #include "ds/internal_logger.h" -#include "ds/thread_messaging.h" #include "kv/kv_types.h" #include "kv/store.h" #include "node/network_state.h" #include "node/snapshot_serdes.h" #include "service/tables/snapshot_evidence.h" +#include "tasks/task_system.h" #include #include @@ -96,18 +96,35 @@ namespace ccf serialised_receipt); } - struct SnapshotMsg + struct SnapshotTask : public ccf::tasks::BaseTask { std::shared_ptr self; std::unique_ptr snapshot; uint32_t generation_count; - }; - static void snapshot_cb(std::unique_ptr<::threading::Tmsg> msg) - { - msg->data.self->snapshot_( - std::move(msg->data.snapshot), msg->data.generation_count); - } + const std::string name; + + SnapshotTask( + std::shared_ptr _self, + std::unique_ptr&& _snapshot, + uint32_t _generation_count) : + self(_self), + snapshot(std::move(_snapshot)), + generation_count(_generation_count), + name(fmt::format( + "snapshot@{}[{}]", snapshot->get_version(), generation_count)) + {} + + void do_task_implementation() override + { + self->snapshot_(std::move(snapshot), generation_count); + } + + const std::string& get_name() const override + { + return name; + } + }; void snapshot_( std::unique_ptr snapshot, @@ -438,13 +455,13 @@ namespace ccf void schedule_snapshot(::consensus::Index idx) { static uint32_t generation_count = 0; - auto msg = std::make_unique<::threading::Tmsg>(&snapshot_cb); - msg->data.self = shared_from_this(); - msg->data.snapshot = store->snapshot_unsafe_maps(idx); - msg->data.generation_count = generation_count++; - auto& tm = ::threading::ThreadMessaging::instance(); - tm.add_task(tm.get_execution_thread(generation_count), std::move(msg)); + auto task = std::make_shared( + shared_from_this(), + store->snapshot_unsafe_maps(idx), + generation_count++); + + ccf::tasks::add_task(task); } void commit(::consensus::Index idx, bool generate_snapshot) override diff --git a/src/node/test/historical_queries.cpp b/src/node/test/historical_queries.cpp index 43b073ea9206..3b517939e980 100644 --- a/src/node/test/historical_queries.cpp +++ b/src/node/test/historical_queries.cpp @@ -1993,7 +1993,6 @@ TEST_CASE("Valid merkle proof from receipts") int main(int argc, char** argv) { - threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/node/test/history.cpp b/src/node/test/history.cpp index 774bc48ab183..7d91a754f8dc 100644 --- a/src/node/test/history.cpp +++ b/src/node/test/history.cpp @@ -499,7 +499,6 @@ TEST_CASE( int main(int argc, char** argv) { - threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/node/test/history_bench.cpp b/src/node/test/history_bench.cpp index 12055778ed2c..738985873689 100644 --- a/src/node/test/history_bench.cpp +++ b/src/node/test/history_bench.cpp @@ -158,7 +158,6 @@ PICOBENCH(append_compact<1000>).iterations(sizes); int main(int argc, char* argv[]) { ccf::logger::config::level() = ccf::LoggerLevel::FATAL; - ::threading::ThreadMessaging::init(1); picobench::runner runner; runner.parse_cmd_line(argc, argv); diff --git a/src/node/test/snapshot.cpp b/src/node/test/snapshot.cpp index 57c9b651d998..c0b85ad0a651 100644 --- a/src/node/test/snapshot.cpp +++ b/src/node/test/snapshot.cpp @@ -178,7 +178,6 @@ TEST_CASE("Snapshot with merkle tree" * doctest::test_suite("snapshot")) int main(int argc, char** argv) { - threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/node/test/snapshotter.cpp b/src/node/test/snapshotter.cpp index f3e9c236d264..2cfc94b82adb 100644 --- a/src/node/test/snapshotter.cpp +++ b/src/node/test/snapshotter.cpp @@ -21,6 +21,15 @@ auto node_kp = ccf::crypto::make_key_pair(); using StringString = ccf::kv::Map; using rb_msg = std::pair; +void run_one_task() +{ + auto task = ccf::tasks::get_main_job_board().get_task(); + if (task != nullptr) + { + task->do_task(); + } +} + auto read_ringbuffer_out(ringbuffer::Circuit& circuit) { std::optional idx = std::nullopt; @@ -167,7 +176,7 @@ TEST_CASE("Regular snapshotting") REQUIRE_FALSE(record_signature(history, snapshotter, snapshot_idx - 1)); commit_idx = snapshot_idx - 1; snapshotter->commit(commit_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE_THROWS_AS( read_latest_snapshot_evidence(network.tables), std::logic_error); @@ -183,7 +192,7 @@ TEST_CASE("Regular snapshotting") commit_idx = snapshot_idx + 1; snapshotter->commit(commit_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); REQUIRE(snapshot_allocate_msg.has_value()); @@ -221,7 +230,7 @@ TEST_CASE("Regular snapshotting") commit_idx = snapshot_idx + 1; snapshotter->commit(commit_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); REQUIRE(snapshot_allocate_msg.has_value()); @@ -255,7 +264,7 @@ TEST_CASE("Regular snapshotting") { commit_idx = snapshot_idx + 2; snapshotter->commit(commit_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_ringbuffer_out(eio) == std::nullopt); } @@ -270,7 +279,7 @@ TEST_CASE("Regular snapshotting") commit_idx = snapshot_idx; snapshotter->commit(commit_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); REQUIRE(snapshot_allocate_msg.has_value()); @@ -329,7 +338,7 @@ TEST_CASE("Rollback before snapshot is committed") REQUIRE(record_signature(history, snapshotter, snapshot_idx)); snapshotter->commit(snapshot_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); @@ -364,7 +373,7 @@ TEST_CASE("Rollback before snapshot is committed") REQUIRE(record_signature(history, snapshotter, snapshot_idx)); snapshotter->commit(snapshot_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); REQUIRE(snapshot_allocate_msg.has_value()); @@ -395,7 +404,7 @@ TEST_CASE("Rollback before snapshot is committed") REQUIRE_FALSE(record_signature(history, snapshotter, snapshot_idx)); snapshotter->commit(snapshot_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); REQUIRE(snapshot_allocate_msg.has_value()); @@ -418,7 +427,7 @@ TEST_CASE("Rollback before snapshot is committed") read_ringbuffer_out(eio) == rb_msg({::consensus::snapshot_commit, snapshot_idx})); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); } } @@ -484,7 +493,7 @@ TEST_CASE("Rekey ledger while snapshot is in progress") INFO("Finally, schedule snapshot creation"); { - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); REQUIRE(snapshot_allocate_msg.has_value()); @@ -518,7 +527,6 @@ TEST_CASE("Rekey ledger while snapshot is in progress") int main(int argc, char** argv) { - threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/quic/quic_session.h b/src/quic/quic_session.h index ca956764b552..688068dce801 100644 --- a/src/quic/quic_session.h +++ b/src/quic/quic_session.h @@ -6,7 +6,6 @@ #include "ds/messaging.h" #include "ds/pending_io.h" #include "ds/ring_buffer.h" -#include "ds/thread_messaging.h" #include "enclave/session.h" #include "udp/msg_types.h" @@ -20,7 +19,8 @@ namespace quic protected: ringbuffer::WriterPtr to_host; ccf::tls::ConnID session_id; - size_t execution_thread; + + std::shared_ptr task_scheduler; enum Status { @@ -55,12 +55,14 @@ namespace quic session_id(session_id_), status(handshake) { - execution_thread = - threading::ThreadMessaging::instance().get_execution_thread(session_id); + task_scheduler = ccf::tasks::OrderedTasks::create( + ccf::tasks::get_main_job_board(), + fmt::format("Session {}", session_id)); } ~QUICSession() { + task_scheduler->cancel_task(); // RINGBUFFER_WRITE_MESSAGE(quic::quic_closed, to_host, session_id); } @@ -144,45 +146,68 @@ namespace quic void recv_buffered(const uint8_t* data, size_t size, sockaddr addr) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called recv_buffered from incorrect thread"); - } LOG_TRACE_FMT("QUIC Session recv_buffered with {} bytes", size); pending_reads.emplace_back(const_cast(data), size, addr); do_handshake(); } - struct SendRecvMsg + struct SessionDataTask : public ccf::tasks::ITaskAction { - std::vector data; std::shared_ptr self; + std::vector data; sockaddr addr; + + SessionDataTask( + std::shared_ptr s, + std::span d, + sockaddr sa) : + self(s), + addr(sa) + { + data.assign(d.begin(), d.end()); + } }; - static void send_raw_cb(std::unique_ptr> msg) + struct SendDataTask : public SessionDataTask { - msg->data.self->send_raw_thread(msg->data.data, msg->data.addr); - } + using SessionDataTask::SessionDataTask; - void send_raw(const uint8_t* data, size_t size, sockaddr addr) + void do_action() override + { + self->send_raw_thread(data, addr); + } + + const std::string& get_name() const override + { + static const std::string name = "quic::SendDataTask"; + return name; + } + }; + + struct RecvDataTask : public SessionDataTask { - auto msg = std::make_unique>(&send_raw_cb); - msg->data.self = this->shared_from_this(); - msg->data.data = std::vector(data, data + size); - msg->data.addr = addr; + using SessionDataTask::SessionDataTask; - threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); + void do_action() override + { + self->recv(data.data(), data.size(), addr); + } + + const std::string& get_name() const override + { + static const std::string name = "quic::RecvDataTask"; + return name; + } + }; + + void send_raw(const uint8_t* data, size_t size, sockaddr addr) + { + task_scheduler->add_action(std::make_shared( + shared_from_this(), std::span{data, size}, addr)); } void send_raw_thread(const std::vector& data, sockaddr addr) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error( - "Called send_raw_thread from incorrect thread"); - } // Writes as much of the data as possible. If the data cannot all // be written now, we store the remainder. We // will try to send pending writes again whenever write() is called. @@ -206,22 +231,25 @@ namespace quic void send_buffered(const std::vector& data, sockaddr addr) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called send_buffered from incorrect thread"); - } - pending_writes.emplace_back( const_cast(data.data()), data.size(), addr); } - void flush() + void handle_incoming_data(std::span data) override { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called flush from incorrect thread"); - } + auto [_, addr_family, addr_data, body] = + ringbuffer::read_message(data); + + task_scheduler->add_action(std::make_shared( + shared_from_this(), + body, + udp::sockaddr_decode(addr_family, addr_data))); + } + virtual void recv(const uint8_t* data_, size_t size_, sockaddr addr_) = 0; + + void flush() + { do_handshake(); if (status != ready) @@ -250,32 +278,15 @@ namespace quic PendingBuffer::clear_empty(pending_writes); } - struct EmptyMsg - { - std::shared_ptr self; - }; - - static void close_cb(std::unique_ptr> msg) - { - msg->data.self->close_thread(); - } - void close_session() override { - auto msg = std::make_unique>(&close_cb); - msg->data.self = this->shared_from_this(); - - threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); + auto self = shared_from_this(); + task_scheduler->add_action( + ccf::tasks::make_basic_action([self]() { self->close_thread(); })); } void close_thread() { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called close_thread from incorrect thread"); - } - switch (status) { case handshake: @@ -349,11 +360,6 @@ namespace quic int handle_recv(uint8_t* buf, size_t len, sockaddr addr) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called handle_recv from incorrect thread"); - } - size_t len_read = 0; for (auto& read : pending_reads) { @@ -417,27 +423,7 @@ namespace quic send_raw(data.data(), data.size(), addr); } - static void recv_cb(std::unique_ptr> msg) - { - reinterpret_cast(msg->data.self.get()) - ->recv_(msg->data.data.data(), msg->data.data.size(), msg->data.addr); - } - - void handle_incoming_data(std::span data) override - { - auto [_, addr_family, addr_data, body] = - ringbuffer::read_message(data); - - auto msg = std::make_unique>(&recv_cb); - msg->data.self = this->shared_from_this(); - msg->data.data.assign(body.data, body.data + body.size); - msg->data.addr = udp::sockaddr_decode(addr_family, addr_data); - - threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); - } - - void recv_(const uint8_t* data_, size_t size_, sockaddr addr_) + void recv(const uint8_t* data_, size_t size_, sockaddr addr_) override { recv_buffered(data_, size_, addr_); addr = addr_; diff --git a/src/tasks/test/demo/concurrent_queue_interface.h b/src/tasks/test/demo/concurrent_queue_interface.h new file mode 100644 index 000000000000..7500de0e4788 --- /dev/null +++ b/src/tasks/test/demo/concurrent_queue_interface.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include +#include + +namespace ccf::tasks +{ + template + class IConcurrentQueue + { + public: + using ValueType = T; + + virtual ~IConcurrentQueue() = default; + + virtual bool empty() = 0; + + virtual void push_back(const T& t) = 0; + virtual void emplace_back(T&& t) = 0; + + virtual std::optional try_pop() = 0; + }; +} \ No newline at end of file diff --git a/src/tasks/test/demo/locking_concurrent_queue.h b/src/tasks/test/demo/locking_concurrent_queue.h new file mode 100644 index 000000000000..2599274ce8ef --- /dev/null +++ b/src/tasks/test/demo/locking_concurrent_queue.h @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "./concurrent_queue_interface.h" + +#include +#include + +namespace ccf::tasks +{ + template + class LockingConcurrentQueue : public IConcurrentQueue + { + protected: + std::mutex mutex; + std::deque deque; + + public: + bool empty() override + { + std::lock_guard lock(mutex); + return deque.empty(); + } + + size_t size() + { + std::lock_guard lock(mutex); + return deque.size(); + } + + void push_back(const T& t) override + { + std::lock_guard lock(mutex); + deque.push_back(t); + } + + void emplace_back(T&& t) override + { + std::lock_guard lock(mutex); + deque.emplace_back(std::move(t)); + } + + std::optional try_pop() override + { + std::lock_guard lock(mutex); + + if (deque.empty()) + { + return std::nullopt; + } + + std::optional val = deque.front(); + deque.pop_front(); + return val; + } + }; +} diff --git a/src/tasks/test/flush_all_jobs.h b/src/tasks/test/flush_all_jobs.h new file mode 100644 index 000000000000..66ecc1164fa6 --- /dev/null +++ b/src/tasks/test/flush_all_jobs.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "tasks/task_system.h" + +static inline void flush_all_jobs( + std::atomic& stop_signal, + size_t worker_count, + std::chrono::seconds kill_after = std::chrono::seconds(5)) +{ + std::vector workers; + for (size_t i = 0; i < worker_count; ++i) + { + workers.emplace_back([&stop_signal]() { + while (!stop_signal.load()) + { + auto task = ccf::tasks::get_main_job_board().get_task(); + if (task != nullptr) + { + task->do_task(); + } + std::this_thread::yield(); + } + }); + } + + using TClock = std::chrono::steady_clock; + auto now = TClock::now(); + + const auto hard_end = now + kill_after; + + while (true) + { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + now = TClock::now(); + if (now > hard_end) + { + break; + } + + if (stop_signal.load()) + { + break; + } + } + + stop_signal.store(true); + + for (auto& worker : workers) + { + worker.join(); + } +} \ No newline at end of file diff --git a/src/tasks/test/merge_bench.cpp b/src/tasks/test/merge_bench.cpp new file mode 100644 index 000000000000..236e2cd5bcf1 --- /dev/null +++ b/src/tasks/test/merge_bench.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. + +#include "flush_all_jobs.h" +#include "merge_sort.h" + +#include + +#define PICOBENCH_DONT_BIND_TO_ONE_CORE +#define PICOBENCH_IMPLEMENT_WITH_MAIN +#include + +#define FMT_HEADER_ONLY +#include +#include + +static inline std::span get_merge_sort_data(size_t n) +{ + static std::vector data; + static std::random_device rd; + static std::mt19937 g(rd()); + + while (data.size() < n) + { + data.emplace_back(rand()); + } + + auto begin = data.begin(); + auto end = begin + n; + std::shuffle(begin, end, g); + + return {begin, end}; +} + +void do_merge_sort(picobench::state& s, size_t worker_count, size_t data_size) +{ + auto ns = get_merge_sort_data(data_size); + if (std::is_sorted(ns.begin(), ns.end())) + { + throw std::logic_error("Initial data already sorted"); + } + + std::atomic stop_signal{false}; + + ccf::tasks::add_task( + std::make_shared(ns.begin(), ns.end(), stop_signal)); + + s.start_timer(); + flush_all_jobs(stop_signal, worker_count); + s.stop_timer(); + + if (!std::is_sorted(ns.begin(), ns.end())) + { + throw std::logic_error("Final data not sorted"); + } +} + +template +static void benchmark_mergesort(picobench::state& s) +{ + do_merge_sort(s, num_threads, s.iterations()); +} + +namespace +{ + const std::vector data_sizes{1'000, 1'000'000}; + + auto threads_1 = benchmark_mergesort<1>; + auto threads_2 = benchmark_mergesort<2>; + auto threads_3 = benchmark_mergesort<3>; + auto threads_4 = benchmark_mergesort<4>; + auto threads_5 = benchmark_mergesort<5>; + auto threads_6 = benchmark_mergesort<6>; + auto threads_7 = benchmark_mergesort<7>; + auto threads_8 = benchmark_mergesort<8>; + + PICOBENCH_SUITE("merge sort"); + PICOBENCH(threads_1).iterations(data_sizes).baseline(); + PICOBENCH(threads_2).iterations(data_sizes); + PICOBENCH(threads_3).iterations(data_sizes); + PICOBENCH(threads_4).iterations(data_sizes); + PICOBENCH(threads_5).iterations(data_sizes); + PICOBENCH(threads_6).iterations(data_sizes); + PICOBENCH(threads_7).iterations(data_sizes); + PICOBENCH(threads_8).iterations(data_sizes); +} diff --git a/src/tasks/test/merge_sort.h b/src/tasks/test/merge_sort.h new file mode 100644 index 000000000000..410ba6d7de46 --- /dev/null +++ b/src/tasks/test/merge_sort.h @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "tasks/task_system.h" + +#include + +struct MergeSortTask : public ccf::tasks::BaseTask, + public std::enable_shared_from_this +{ + // How many items will we actually directly sort, vs forking 2 new tasks to + // sub-sort + static constexpr size_t sort_threshold = 50; + + using Iterator = std::span::iterator; + + Iterator begin; + Iterator end; + std::atomic& stop_signal; + std::shared_ptr parent; + std::atomic sub_tasks; + + MergeSortTask( + Iterator b, + Iterator e, + std::atomic& ss, + std::shared_ptr p = nullptr) : + begin(b), + end(e), + parent(p), + stop_signal(ss) + {} + + void merge() + { + std::sort(begin, end); + + if (parent != nullptr) + { + if (--parent->sub_tasks == 0) + { + parent->merge(); + } + } + else + { + stop_signal.store(true); + } + } + + void do_task_implementation() override + { + const auto dist = std::distance(begin, end); + if (dist >= sort_threshold) + { + sub_tasks.store(2); + + auto self = shared_from_this(); + + auto mid_point = begin + (dist / 2); + + ccf::tasks::add_task( + std::make_shared(begin, mid_point, stop_signal, self)); + ccf::tasks::add_task( + std::make_shared(mid_point, end, stop_signal, self)); + } + else + { + merge(); + } + } +}; diff --git a/src/tasks/test/promises.cpp b/src/tasks/test/promises.cpp new file mode 100644 index 000000000000..fa6da879001b --- /dev/null +++ b/src/tasks/test/promises.cpp @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// NB: Not currently built, just a sketch of using variants + promise types to +// report a (potentially cancelled or erroring) async result + +struct Cancelled +{ + std::string reason; +}; + +struct TimedOut +{}; + +struct Actual +{ + size_t x; + std::string s; +}; + +template +using TResult = std::variant; + +using Result = TResult; + +void do_it(std::promise& result) +{ + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + auto choice = rand() % 4; + if (choice == 0) + { + std::cout << "do_it producing a real value" << std::endl; + result.set_value(Actual{.x = 42, .s = "hello world"}); + } + else if (choice == 1) + { + std::cout << "do_it simulating a cancellation" << std::endl; + result.set_value(Cancelled{.reason = "Dumb luck"}); + } + else if (choice == 2) + { + std::cout << "do_it simulating a timeout" << std::endl; + result.set_value(TimedOut{}); + } + else if (choice == 3) + { + std::cout << "do_it simulating an exception" << std::endl; + result.set_exception( + std::make_exception_ptr(std::logic_error("I blew up"))); + } +} + +int main() +{ + for (auto i = 0; i < 10; ++i) + { + std::cout << "Iteration " << i << std::endl; + std::promise result; + + std::future future = result.get_future(); + + std::thread t(do_it, std::ref(result)); + + try + { + std::cout << "About to call future.get()" << std::endl; + auto r = future.get(); + std::visit( + [](auto&& arg) { + using T = std::decay_t; + if constexpr (std::is_same_v) + { + std::cout << " Result is an actual value, with x = " << arg.x + << " and s = " << arg.s << std::endl; + } + else if constexpr (std::is_same_v) + { + std::cout << " Operation was cancelled, because: " << arg.reason + << std::endl; + } + else if constexpr (std::is_same_v) + { + std::cout << " Operation timed out" << std::endl; + } + else + { + static_assert(false, "Non-exhaustive visitor!"); + } + }, + r); + } + catch (const std::exception& e) + { + std::cout << " Exception thrown: " << e.what() << std::endl; + } + + t.join(); + } +} diff --git a/src/tasks/test/sleep_bench.cpp b/src/tasks/test/sleep_bench.cpp new file mode 100644 index 000000000000..7c9ce854c1fc --- /dev/null +++ b/src/tasks/test/sleep_bench.cpp @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. + +#include "flush_all_jobs.h" +#include "tasks/basic_task.h" + +#define PICOBENCH_DONT_BIND_TO_ONE_CORE +#include + +#define FMT_HEADER_ONLY +#include +#include + +void sleep_with_many_workers( + picobench::state& s, size_t worker_count, size_t num_sleeps) +{ + std::atomic stop_signal{false}; + + for (auto i = 0; i < num_sleeps; ++i) + { + ccf::tasks::add_task(ccf::tasks::make_basic_task( + []() { std::this_thread::sleep_for(std::chrono::milliseconds(1)); })); + } + + ccf::tasks::add_task( + ccf::tasks::make_basic_task([&]() { stop_signal.store(true); })); + + s.start_timer(); + flush_all_jobs(stop_signal, worker_count); + s.stop_timer(); +} + +template +static void benchmark_sleeps(picobench::state& s) +{ + sleep_with_many_workers(s, num_threads, s.iterations()); +} + +namespace +{ + const std::vector num_sleeps{100, 1000}; + + auto threads_1 = benchmark_sleeps<1>; + auto threads_2 = benchmark_sleeps<2>; + auto threads_3 = benchmark_sleeps<3>; + auto threads_4 = benchmark_sleeps<4>; + auto threads_5 = benchmark_sleeps<5>; + auto threads_6 = benchmark_sleeps<6>; + auto threads_7 = benchmark_sleeps<7>; + auto threads_8 = benchmark_sleeps<8>; + + PICOBENCH_SUITE("sleeps"); + PICOBENCH(threads_1).iterations(num_sleeps).baseline(); + PICOBENCH(threads_2).iterations(num_sleeps); + PICOBENCH(threads_3).iterations(num_sleeps); + PICOBENCH(threads_4).iterations(num_sleeps); + PICOBENCH(threads_5).iterations(num_sleeps); + PICOBENCH(threads_6).iterations(num_sleeps); + PICOBENCH(threads_7).iterations(num_sleeps); + PICOBENCH(threads_8).iterations(num_sleeps); +} diff --git a/src/tasks/test/task_system_thread.h b/src/tasks/test/task_system_thread.h new file mode 100644 index 000000000000..cc74a2c5d3ec --- /dev/null +++ b/src/tasks/test/task_system_thread.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the Apache 2.0 License. +#pragma once + +#include "tasks/task_system.h" + +#include +#include +#include + +namespace ccf::tasks::test +{ + struct TaskSystemThread + { + std::chrono::milliseconds polling_period; + std::thread thread; + std::atomic terminate = false; + + TaskSystemThread( + std::chrono::milliseconds _polling_period = + std::chrono::milliseconds(10)) : + polling_period(_polling_period) + { + thread = std::thread([this]() { + while (!this->terminate.load()) + { + ccf::tasks::tick(this->polling_period); + + auto& job_board = ccf::tasks::get_main_job_board(); + auto task = job_board.get_task(); + while (task != nullptr) + { + task->do_task(); + task = job_board.get_task(); + } + + std::this_thread::sleep_for(this->polling_period); + } + }); + } + + ~TaskSystemThread() + { + terminate.store(true); + thread.join(); + } + }; +} From fe1f2fe06b41ee97e188029da0dc25b9f1af63c6 Mon Sep 17 00:00:00 2001 From: Eddy Ashton Date: Fri, 31 Oct 2025 14:49:06 +0000 Subject: [PATCH 2/3] Delete some accidentally re-introduced files --- .../test/demo/concurrent_queue_interface.h | 25 ---- .../test/demo/locking_concurrent_queue.h | 58 --------- src/tasks/test/flush_all_jobs.h | 54 --------- src/tasks/test/merge_bench.cpp | 86 -------------- src/tasks/test/merge_sort.h | 73 ------------ src/tasks/test/promises.cpp | 110 ------------------ src/tasks/test/sleep_bench.cpp | 61 ---------- src/tasks/test/task_system_thread.h | 48 -------- 8 files changed, 515 deletions(-) delete mode 100644 src/tasks/test/demo/concurrent_queue_interface.h delete mode 100644 src/tasks/test/demo/locking_concurrent_queue.h delete mode 100644 src/tasks/test/flush_all_jobs.h delete mode 100644 src/tasks/test/merge_bench.cpp delete mode 100644 src/tasks/test/merge_sort.h delete mode 100644 src/tasks/test/promises.cpp delete mode 100644 src/tasks/test/sleep_bench.cpp delete mode 100644 src/tasks/test/task_system_thread.h diff --git a/src/tasks/test/demo/concurrent_queue_interface.h b/src/tasks/test/demo/concurrent_queue_interface.h deleted file mode 100644 index 7500de0e4788..000000000000 --- a/src/tasks/test/demo/concurrent_queue_interface.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the Apache 2.0 License. -#pragma once - -#include -#include - -namespace ccf::tasks -{ - template - class IConcurrentQueue - { - public: - using ValueType = T; - - virtual ~IConcurrentQueue() = default; - - virtual bool empty() = 0; - - virtual void push_back(const T& t) = 0; - virtual void emplace_back(T&& t) = 0; - - virtual std::optional try_pop() = 0; - }; -} \ No newline at end of file diff --git a/src/tasks/test/demo/locking_concurrent_queue.h b/src/tasks/test/demo/locking_concurrent_queue.h deleted file mode 100644 index 2599274ce8ef..000000000000 --- a/src/tasks/test/demo/locking_concurrent_queue.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the Apache 2.0 License. -#pragma once - -#include "./concurrent_queue_interface.h" - -#include -#include - -namespace ccf::tasks -{ - template - class LockingConcurrentQueue : public IConcurrentQueue - { - protected: - std::mutex mutex; - std::deque deque; - - public: - bool empty() override - { - std::lock_guard lock(mutex); - return deque.empty(); - } - - size_t size() - { - std::lock_guard lock(mutex); - return deque.size(); - } - - void push_back(const T& t) override - { - std::lock_guard lock(mutex); - deque.push_back(t); - } - - void emplace_back(T&& t) override - { - std::lock_guard lock(mutex); - deque.emplace_back(std::move(t)); - } - - std::optional try_pop() override - { - std::lock_guard lock(mutex); - - if (deque.empty()) - { - return std::nullopt; - } - - std::optional val = deque.front(); - deque.pop_front(); - return val; - } - }; -} diff --git a/src/tasks/test/flush_all_jobs.h b/src/tasks/test/flush_all_jobs.h deleted file mode 100644 index 66ecc1164fa6..000000000000 --- a/src/tasks/test/flush_all_jobs.h +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the Apache 2.0 License. -#pragma once - -#include "tasks/task_system.h" - -static inline void flush_all_jobs( - std::atomic& stop_signal, - size_t worker_count, - std::chrono::seconds kill_after = std::chrono::seconds(5)) -{ - std::vector workers; - for (size_t i = 0; i < worker_count; ++i) - { - workers.emplace_back([&stop_signal]() { - while (!stop_signal.load()) - { - auto task = ccf::tasks::get_main_job_board().get_task(); - if (task != nullptr) - { - task->do_task(); - } - std::this_thread::yield(); - } - }); - } - - using TClock = std::chrono::steady_clock; - auto now = TClock::now(); - - const auto hard_end = now + kill_after; - - while (true) - { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - now = TClock::now(); - if (now > hard_end) - { - break; - } - - if (stop_signal.load()) - { - break; - } - } - - stop_signal.store(true); - - for (auto& worker : workers) - { - worker.join(); - } -} \ No newline at end of file diff --git a/src/tasks/test/merge_bench.cpp b/src/tasks/test/merge_bench.cpp deleted file mode 100644 index 236e2cd5bcf1..000000000000 --- a/src/tasks/test/merge_bench.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the Apache 2.0 License. - -#include "flush_all_jobs.h" -#include "merge_sort.h" - -#include - -#define PICOBENCH_DONT_BIND_TO_ONE_CORE -#define PICOBENCH_IMPLEMENT_WITH_MAIN -#include - -#define FMT_HEADER_ONLY -#include -#include - -static inline std::span get_merge_sort_data(size_t n) -{ - static std::vector data; - static std::random_device rd; - static std::mt19937 g(rd()); - - while (data.size() < n) - { - data.emplace_back(rand()); - } - - auto begin = data.begin(); - auto end = begin + n; - std::shuffle(begin, end, g); - - return {begin, end}; -} - -void do_merge_sort(picobench::state& s, size_t worker_count, size_t data_size) -{ - auto ns = get_merge_sort_data(data_size); - if (std::is_sorted(ns.begin(), ns.end())) - { - throw std::logic_error("Initial data already sorted"); - } - - std::atomic stop_signal{false}; - - ccf::tasks::add_task( - std::make_shared(ns.begin(), ns.end(), stop_signal)); - - s.start_timer(); - flush_all_jobs(stop_signal, worker_count); - s.stop_timer(); - - if (!std::is_sorted(ns.begin(), ns.end())) - { - throw std::logic_error("Final data not sorted"); - } -} - -template -static void benchmark_mergesort(picobench::state& s) -{ - do_merge_sort(s, num_threads, s.iterations()); -} - -namespace -{ - const std::vector data_sizes{1'000, 1'000'000}; - - auto threads_1 = benchmark_mergesort<1>; - auto threads_2 = benchmark_mergesort<2>; - auto threads_3 = benchmark_mergesort<3>; - auto threads_4 = benchmark_mergesort<4>; - auto threads_5 = benchmark_mergesort<5>; - auto threads_6 = benchmark_mergesort<6>; - auto threads_7 = benchmark_mergesort<7>; - auto threads_8 = benchmark_mergesort<8>; - - PICOBENCH_SUITE("merge sort"); - PICOBENCH(threads_1).iterations(data_sizes).baseline(); - PICOBENCH(threads_2).iterations(data_sizes); - PICOBENCH(threads_3).iterations(data_sizes); - PICOBENCH(threads_4).iterations(data_sizes); - PICOBENCH(threads_5).iterations(data_sizes); - PICOBENCH(threads_6).iterations(data_sizes); - PICOBENCH(threads_7).iterations(data_sizes); - PICOBENCH(threads_8).iterations(data_sizes); -} diff --git a/src/tasks/test/merge_sort.h b/src/tasks/test/merge_sort.h deleted file mode 100644 index 410ba6d7de46..000000000000 --- a/src/tasks/test/merge_sort.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the Apache 2.0 License. -#pragma once - -#include "tasks/task_system.h" - -#include - -struct MergeSortTask : public ccf::tasks::BaseTask, - public std::enable_shared_from_this -{ - // How many items will we actually directly sort, vs forking 2 new tasks to - // sub-sort - static constexpr size_t sort_threshold = 50; - - using Iterator = std::span::iterator; - - Iterator begin; - Iterator end; - std::atomic& stop_signal; - std::shared_ptr parent; - std::atomic sub_tasks; - - MergeSortTask( - Iterator b, - Iterator e, - std::atomic& ss, - std::shared_ptr p = nullptr) : - begin(b), - end(e), - parent(p), - stop_signal(ss) - {} - - void merge() - { - std::sort(begin, end); - - if (parent != nullptr) - { - if (--parent->sub_tasks == 0) - { - parent->merge(); - } - } - else - { - stop_signal.store(true); - } - } - - void do_task_implementation() override - { - const auto dist = std::distance(begin, end); - if (dist >= sort_threshold) - { - sub_tasks.store(2); - - auto self = shared_from_this(); - - auto mid_point = begin + (dist / 2); - - ccf::tasks::add_task( - std::make_shared(begin, mid_point, stop_signal, self)); - ccf::tasks::add_task( - std::make_shared(mid_point, end, stop_signal, self)); - } - else - { - merge(); - } - } -}; diff --git a/src/tasks/test/promises.cpp b/src/tasks/test/promises.cpp deleted file mode 100644 index fa6da879001b..000000000000 --- a/src/tasks/test/promises.cpp +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the Apache 2.0 License. - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// NB: Not currently built, just a sketch of using variants + promise types to -// report a (potentially cancelled or erroring) async result - -struct Cancelled -{ - std::string reason; -}; - -struct TimedOut -{}; - -struct Actual -{ - size_t x; - std::string s; -}; - -template -using TResult = std::variant; - -using Result = TResult; - -void do_it(std::promise& result) -{ - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - auto choice = rand() % 4; - if (choice == 0) - { - std::cout << "do_it producing a real value" << std::endl; - result.set_value(Actual{.x = 42, .s = "hello world"}); - } - else if (choice == 1) - { - std::cout << "do_it simulating a cancellation" << std::endl; - result.set_value(Cancelled{.reason = "Dumb luck"}); - } - else if (choice == 2) - { - std::cout << "do_it simulating a timeout" << std::endl; - result.set_value(TimedOut{}); - } - else if (choice == 3) - { - std::cout << "do_it simulating an exception" << std::endl; - result.set_exception( - std::make_exception_ptr(std::logic_error("I blew up"))); - } -} - -int main() -{ - for (auto i = 0; i < 10; ++i) - { - std::cout << "Iteration " << i << std::endl; - std::promise result; - - std::future future = result.get_future(); - - std::thread t(do_it, std::ref(result)); - - try - { - std::cout << "About to call future.get()" << std::endl; - auto r = future.get(); - std::visit( - [](auto&& arg) { - using T = std::decay_t; - if constexpr (std::is_same_v) - { - std::cout << " Result is an actual value, with x = " << arg.x - << " and s = " << arg.s << std::endl; - } - else if constexpr (std::is_same_v) - { - std::cout << " Operation was cancelled, because: " << arg.reason - << std::endl; - } - else if constexpr (std::is_same_v) - { - std::cout << " Operation timed out" << std::endl; - } - else - { - static_assert(false, "Non-exhaustive visitor!"); - } - }, - r); - } - catch (const std::exception& e) - { - std::cout << " Exception thrown: " << e.what() << std::endl; - } - - t.join(); - } -} diff --git a/src/tasks/test/sleep_bench.cpp b/src/tasks/test/sleep_bench.cpp deleted file mode 100644 index 7c9ce854c1fc..000000000000 --- a/src/tasks/test/sleep_bench.cpp +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the Apache 2.0 License. - -#include "flush_all_jobs.h" -#include "tasks/basic_task.h" - -#define PICOBENCH_DONT_BIND_TO_ONE_CORE -#include - -#define FMT_HEADER_ONLY -#include -#include - -void sleep_with_many_workers( - picobench::state& s, size_t worker_count, size_t num_sleeps) -{ - std::atomic stop_signal{false}; - - for (auto i = 0; i < num_sleeps; ++i) - { - ccf::tasks::add_task(ccf::tasks::make_basic_task( - []() { std::this_thread::sleep_for(std::chrono::milliseconds(1)); })); - } - - ccf::tasks::add_task( - ccf::tasks::make_basic_task([&]() { stop_signal.store(true); })); - - s.start_timer(); - flush_all_jobs(stop_signal, worker_count); - s.stop_timer(); -} - -template -static void benchmark_sleeps(picobench::state& s) -{ - sleep_with_many_workers(s, num_threads, s.iterations()); -} - -namespace -{ - const std::vector num_sleeps{100, 1000}; - - auto threads_1 = benchmark_sleeps<1>; - auto threads_2 = benchmark_sleeps<2>; - auto threads_3 = benchmark_sleeps<3>; - auto threads_4 = benchmark_sleeps<4>; - auto threads_5 = benchmark_sleeps<5>; - auto threads_6 = benchmark_sleeps<6>; - auto threads_7 = benchmark_sleeps<7>; - auto threads_8 = benchmark_sleeps<8>; - - PICOBENCH_SUITE("sleeps"); - PICOBENCH(threads_1).iterations(num_sleeps).baseline(); - PICOBENCH(threads_2).iterations(num_sleeps); - PICOBENCH(threads_3).iterations(num_sleeps); - PICOBENCH(threads_4).iterations(num_sleeps); - PICOBENCH(threads_5).iterations(num_sleeps); - PICOBENCH(threads_6).iterations(num_sleeps); - PICOBENCH(threads_7).iterations(num_sleeps); - PICOBENCH(threads_8).iterations(num_sleeps); -} diff --git a/src/tasks/test/task_system_thread.h b/src/tasks/test/task_system_thread.h deleted file mode 100644 index cc74a2c5d3ec..000000000000 --- a/src/tasks/test/task_system_thread.h +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the Apache 2.0 License. -#pragma once - -#include "tasks/task_system.h" - -#include -#include -#include - -namespace ccf::tasks::test -{ - struct TaskSystemThread - { - std::chrono::milliseconds polling_period; - std::thread thread; - std::atomic terminate = false; - - TaskSystemThread( - std::chrono::milliseconds _polling_period = - std::chrono::milliseconds(10)) : - polling_period(_polling_period) - { - thread = std::thread([this]() { - while (!this->terminate.load()) - { - ccf::tasks::tick(this->polling_period); - - auto& job_board = ccf::tasks::get_main_job_board(); - auto task = job_board.get_task(); - while (task != nullptr) - { - task->do_task(); - task = job_board.get_task(); - } - - std::this_thread::sleep_for(this->polling_period); - } - }); - } - - ~TaskSystemThread() - { - terminate.store(true); - thread.join(); - } - }; -} From 711d7b8c20731cfdeca07fc4e686649741698552 Mon Sep 17 00:00:00 2001 From: Eddy Ashton Date: Fri, 31 Oct 2025 15:08:47 +0000 Subject: [PATCH 3/3] Revert unused thread naming --- include/ccf/ds/logger.h | 4 ++-- include/ccf/threading/thread_ids.h | 3 --- src/enclave/thread_local.cpp | 16 ---------------- 3 files changed, 2 insertions(+), 21 deletions(-) diff --git a/include/ccf/ds/logger.h b/include/ccf/ds/logger.h index 40b12dbba8d8..52662270724c 100644 --- a/include/ccf/ds/logger.h +++ b/include/ccf/ds/logger.h @@ -44,7 +44,7 @@ namespace ccf::logger std::string tag; std::string file_name; size_t line_number; - std::string thread_id; + uint16_t thread_id; std::ostringstream ss; std::string msg; @@ -59,7 +59,7 @@ namespace ccf::logger file_name(file_name_), line_number(line_number_) { - thread_id = ccf::threading::get_current_thread_name(); + thread_id = ccf::threading::get_current_thread_id(); } template diff --git a/include/ccf/threading/thread_ids.h b/include/ccf/threading/thread_ids.h index 4f59b088a0e0..2def60951587 100644 --- a/include/ccf/threading/thread_ids.h +++ b/include/ccf/threading/thread_ids.h @@ -22,7 +22,4 @@ namespace ccf::threading uint16_t get_current_thread_id(); void set_current_thread_id(ThreadID to); void reset_thread_id_generator(ThreadID to = MAIN_THREAD_ID); - - std::string get_current_thread_name(); - void set_current_thread_name(std::string_view sv); } \ No newline at end of file diff --git a/src/enclave/thread_local.cpp b/src/enclave/thread_local.cpp index 79b75010d712..23c8bdeda858 100644 --- a/src/enclave/thread_local.cpp +++ b/src/enclave/thread_local.cpp @@ -7,7 +7,6 @@ namespace ccf::threading namespace { std::atomic next_thread_id = MAIN_THREAD_ID; - thread_local std::optional this_thread_name = std::nullopt; } uint16_t& current_thread_id() @@ -30,19 +29,4 @@ namespace ccf::threading { next_thread_id.store(to); } - - std::string get_current_thread_name() - { - if (!this_thread_name.has_value()) - { - this_thread_name = fmt::format("{}", get_current_thread_id()); - } - - return this_thread_name.value(); - } - - void set_current_thread_name(std::string_view sv) - { - this_thread_name = sv; - } } \ No newline at end of file