diff --git a/CMakeLists.txt b/CMakeLists.txt index 597c4a9f44a..84adb8c879b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -389,6 +389,7 @@ add_ccf_static_library( ccf_endpoints ccfcrypto ccf_kv + ccf_tasks nghttp2 ${CMAKE_THREAD_LIBS_INIT} curl @@ -549,7 +550,6 @@ if(BUILD_TESTS) ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/serialized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/serializer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/hash.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/thread_messaging.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/lru.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/hex.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/ds/test/contiguous_set.cpp @@ -584,7 +584,7 @@ if(BUILD_TESTS) ${CMAKE_CURRENT_SOURCE_DIR}/src/consensus/aft/test/view_history.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/consensus/aft/test/committable_suffix.cpp ) - target_link_libraries(raft_test PRIVATE ccfcrypto) + target_link_libraries(raft_test PRIVATE ccfcrypto ccf_tasks) add_unit_test( raft_enclave_test @@ -615,7 +615,9 @@ if(BUILD_TESTS) add_unit_test( history_test ${CMAKE_CURRENT_SOURCE_DIR}/src/node/test/history.cpp ) - target_link_libraries(history_test PRIVATE ccfcrypto http_parser ccf_kv) + target_link_libraries( + history_test PRIVATE ccfcrypto http_parser ccf_kv ccf_tasks + ) add_unit_test( encryptor_test ${CMAKE_CURRENT_SOURCE_DIR}/src/node/test/encryptor.cpp @@ -648,23 +650,26 @@ if(BUILD_TESTS) ) target_link_libraries( historical_queries_test PRIVATE http_parser ccf_kv ccf_endpoints + ccf_tasks ) add_unit_test( indexing_test ${CMAKE_CURRENT_SOURCE_DIR}/src/indexing/test/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/indexing/test/lfs.cpp ) - target_link_libraries(indexing_test PRIVATE ccf_endpoints ccf_kv) + target_link_libraries(indexing_test PRIVATE ccf_endpoints ccf_kv ccf_tasks) add_unit_test( snapshot_test ${CMAKE_CURRENT_SOURCE_DIR}/src/node/test/snapshot.cpp ) - target_link_libraries(snapshot_test PRIVATE ccf_kv) + target_link_libraries(snapshot_test PRIVATE ccf_kv ccf_tasks) add_unit_test( snapshotter_test ${CMAKE_CURRENT_SOURCE_DIR}/src/node/test/snapshotter.cpp ) - target_link_libraries(snapshotter_test PRIVATE ccf_kv ccf_endpoints) + target_link_libraries( + snapshotter_test PRIVATE ccf_kv ccf_endpoints ccf_tasks + ) add_unit_test( node_info_json_test @@ -715,8 +720,14 @@ if(BUILD_TESTS) ${CCF_DIR}/src/node/quote.cpp ${CCF_DIR}/src/node/uvm_endorsements.cpp ) target_link_libraries( - frontend_test PRIVATE ${CMAKE_THREAD_LIBS_INIT} http_parser ccf_js - ccf_endpoints ccfcrypto ccf_kv + frontend_test + PRIVATE ${CMAKE_THREAD_LIBS_INIT} + http_parser + ccf_js + ccf_endpoints + ccfcrypto + ccf_kv + ccf_tasks ) add_unit_test( @@ -768,7 +779,7 @@ if(BUILD_TESTS) raft_driver ${CMAKE_CURRENT_SOURCE_DIR}/src/consensus/aft/test/driver.cpp src/enclave/thread_local.cpp ) - target_link_libraries(raft_driver PRIVATE ccfcrypto) + target_link_libraries(raft_driver PRIVATE ccfcrypto ccf_tasks) target_include_directories(raft_driver PRIVATE src/aft) add_test( @@ -811,7 +822,7 @@ if(BUILD_TESTS) add_picobench( history_bench SRCS src/node/test/history_bench.cpp src/enclave/thread_local.cpp - LINK_LIBS ccf_kv + LINK_LIBS ccf_kv ccf_tasks ) add_picobench( diff --git a/cmake/common.cmake b/cmake/common.cmake index d22799281b9..e6d77f9e5fd 100644 --- a/cmake/common.cmake +++ b/cmake/common.cmake @@ -225,8 +225,11 @@ function(add_picobench name) bash -c "$ --samples=10 --out-fmt=csv --output=${name}.csv && cat ${name}.csv" ) - - set_property(TEST ${name} PROPERTY LABELS benchmark) + set_property( + TEST ${name} + APPEND + PROPERTY LABELS benchmark + ) add_san_test_properties(${name}) endfunction() diff --git a/src/consensus/aft/raft.h b/src/consensus/aft/raft.h index 66601f676c6..13156288a9c 100644 --- a/src/consensus/aft/raft.h +++ b/src/consensus/aft/raft.h @@ -7,6 +7,7 @@ #include "ccf/service/reconfiguration_type.h" #include "ccf/tx_id.h" #include "ccf/tx_status.h" +#include "ds/ccf_assert.h" #include "ds/internal_logger.h" #include "ds/serialized.h" #include "impl/state.h" diff --git a/src/consensus/aft/test/driver.cpp b/src/consensus/aft/test/driver.cpp index 903856cc8f5..cbead7255ed 100644 --- a/src/consensus/aft/test/driver.cpp +++ b/src/consensus/aft/test/driver.cpp @@ -38,8 +38,6 @@ int main(int argc, char** argv) #endif ccf::logger::config::level() = ccf::LoggerLevel::DEBUG; - threading::ThreadMessaging::init(1); - const std::string filename = argv[1]; std::ifstream fstream; diff --git a/src/consensus/aft/test/main.cpp b/src/consensus/aft/test/main.cpp index d42bfd51d2d..17f1d693795 100644 --- a/src/consensus/aft/test/main.cpp +++ b/src/consensus/aft/test/main.cpp @@ -1010,7 +1010,6 @@ DOCTEST_TEST_CASE( int main(int argc, char** argv) { - threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/crypto/openssl/symmetric_key.cpp b/src/crypto/openssl/symmetric_key.cpp index bd18b3dcc2b..f073d872578 100644 --- a/src/crypto/openssl/symmetric_key.cpp +++ b/src/crypto/openssl/symmetric_key.cpp @@ -6,7 +6,6 @@ #include "ccf/crypto/openssl/openssl_wrappers.h" #include "ccf/crypto/symmetric_key.h" #include "ds/internal_logger.h" -#include "ds/thread_messaging.h" #include #include diff --git a/src/ds/serializer.h b/src/ds/serializer.h index 7e7aed5ae0c..91ee53f4e71 100644 --- a/src/ds/serializer.h +++ b/src/ds/serializer.h @@ -27,6 +27,11 @@ namespace serializer { const uint8_t* data; const size_t size; + + operator std::span() const + { + return std::span(data, size); + } }; namespace details diff --git a/src/ds/test/messaging.cpp b/src/ds/test/messaging.cpp index f92e5918e7e..e71b5422e1b 100644 --- a/src/ds/test/messaging.cpp +++ b/src/ds/test/messaging.cpp @@ -5,7 +5,6 @@ #include "../non_blocking.h" #include "../ring_buffer.h" #include "../serialized.h" -#include "../thread_messaging.h" #include #include diff --git a/src/ds/test/thread_messaging.cpp b/src/ds/test/thread_messaging.cpp deleted file mode 100644 index 26678b3b2c2..00000000000 --- a/src/ds/test/thread_messaging.cpp +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the Apache 2.0 License. -#include "../thread_messaging.h" - -#include - -struct Foo -{ - bool& happened; - - static size_t count; - - Foo(bool& h) : happened(h) - { - count++; - } - - ~Foo() - { - count--; - } -}; - -size_t Foo::count = 0; - -static void always(std::unique_ptr<::threading::Tmsg> msg) -{ - msg->data.happened = true; -} - -static void never(std::unique_ptr<::threading::Tmsg> msg) -{ - CHECK(false); -} - -TEST_CASE("ThreadMessaging API" * doctest::test_suite("threadmessaging")) -{ - { - ::threading::ThreadMessaging tm(1); - - static constexpr auto worker_thread_id = ccf::threading::MAIN_THREAD_ID + 1; - - bool happened_main_thread = false; - bool happened_worker_thread = false; - - tm.add_task( - ccf::threading::MAIN_THREAD_ID, - std::make_unique<::threading::Tmsg>(&always, happened_main_thread)); - - REQUIRE_THROWS(tm.add_task( - worker_thread_id, - std::make_unique<::threading::Tmsg>( - &always, happened_worker_thread))); - - REQUIRE(tm.run_one()); - REQUIRE_FALSE(tm.run_one()); - - REQUIRE(happened_main_thread); - REQUIRE_FALSE(happened_worker_thread); - } - - { - // Create a ThreadMessaging with task queues for main thread + 1 worker - // thread - ::threading::ThreadMessaging tm(2); - - static constexpr auto worker_a_id = 1; - static constexpr auto worker_b_id = 2; - - ccf::threading::reset_thread_id_generator(worker_a_id); - - bool happened_0 = false; - bool happened_1 = false; - bool happened_2 = false; - bool happened_3 = false; - - // Queue single task for main thread: - // - set happened_0 - tm.add_task( - ccf::threading::MAIN_THREAD_ID, - std::make_unique>(&always, happened_0)); - - // Queue 2 tasks for worker a: - // - set happened_1 - // - set happened_2 - tm.add_task( - worker_a_id, - std::make_unique<::threading::Tmsg>(&always, happened_1)); - tm.add_task( - worker_a_id, - std::make_unique<::threading::Tmsg>(&always, happened_2)); - - // Fail to queue task for worker b, tm is too small - REQUIRE_THROWS(tm.add_task( - worker_b_id, - std::make_unique<::threading::Tmsg>(&always, happened_3))); - - // Run single task on main thread - REQUIRE(tm.run_one()); - // Confirm there are no more tasks for main thread - REQUIRE_FALSE(tm.run_one()); - - // Confirm only first task has been executed - REQUIRE(happened_0); - REQUIRE_FALSE(happened_1); - REQUIRE_FALSE(happened_2); - REQUIRE_FALSE(happened_3); - - std::thread t([&]() { - // Run tasks for worker "a" - REQUIRE(ccf::threading::get_current_thread_id() == worker_a_id); - - REQUIRE(tm.run_one()); - REQUIRE(happened_1); - REQUIRE_FALSE(happened_2); - - REQUIRE(tm.run_one()); - REQUIRE(happened_2); - - REQUIRE_FALSE(tm.run_one()); - }); - - t.join(); - - REQUIRE(happened_0); - REQUIRE(happened_1); - REQUIRE(happened_2); - REQUIRE_FALSE(happened_3); - } -} - -// Note: this only works with ASAN turned on, which catches m2 not being -// freed. -TEST_CASE( - "Unpopped messages are freed" * doctest::test_suite("threadmessaging")) -{ - bool happened = false; - - { - ::threading::ThreadMessaging tm(1); - - auto m1 = std::make_unique<::threading::Tmsg>(&always, happened); - tm.add_task(0, std::move(m1)); - - // Task payload (and TMsg) is freed after running - tm.run_one(); - CHECK(Foo::count == 0); - - auto m2 = std::make_unique<::threading::Tmsg>(&never, happened); - tm.add_task(0, std::move(m2)); - // Task is owned by the queue, hasn't run - CHECK(Foo::count == 1); - } - // Task payload (and TMsg) is also freed if it hasn't run - // but the queue was destructed - CHECK(Foo::count == 0); - - CHECK(happened); -} - -TEST_CASE("Unique thread IDs" * doctest::test_suite("threadmessaging")) -{ - std::mutex assigned_ids_lock; - std::vector assigned_ids; - - const auto main_thread_id = ccf::threading::get_current_thread_id(); - REQUIRE(main_thread_id == ccf::threading::MAIN_THREAD_ID); - assigned_ids.push_back(main_thread_id); - - auto fn = [&]() { - { - std::lock_guard guard(assigned_ids_lock); - const auto current_thread_id = ccf::threading::get_current_thread_id(); - assigned_ids.push_back(current_thread_id); - } - }; - - constexpr size_t num_threads = 20; - constexpr size_t expected_ids = num_threads + 1; // Includes MAIN_THREAD_ID - std::vector threads; - for (auto i = 0; i < num_threads; ++i) - { - threads.emplace_back(fn); - } - - size_t attempts = 0; - constexpr size_t max_attempts = 5; - while (true) - { - { - std::lock_guard guard(assigned_ids_lock); - if (assigned_ids.size() == expected_ids) - { - break; - } - } - - REQUIRE(++attempts < max_attempts); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - } - - REQUIRE(assigned_ids.size() == expected_ids); - - for (auto& thread : threads) - { - thread.join(); - } - - const auto unique = std::unique(assigned_ids.begin(), assigned_ids.end()); - REQUIRE_MESSAGE( - unique == assigned_ids.end(), - fmt::format( - "Thread IDs are not unique: {}", fmt::join(assigned_ids, ", "))); -} diff --git a/src/ds/thread_messaging.h b/src/ds/thread_messaging.h deleted file mode 100644 index 075d4787b19..00000000000 --- a/src/ds/thread_messaging.h +++ /dev/null @@ -1,397 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the Apache 2.0 License. -#pragma once - -#include "ccf/threading/thread_ids.h" -#include "ds/ccf_assert.h" -#include "ds/internal_logger.h" - -#include -#include -#include - -namespace threading -{ - struct ThreadMsg - { - void (*cb)(std::unique_ptr); - std::atomic next = nullptr; - - ThreadMsg(void (*_cb)(std::unique_ptr)) : cb(_cb) {} - - virtual ~ThreadMsg() = default; - }; - - template - struct alignas(16) Tmsg : public ThreadMsg - { - Payload data; - - template - Tmsg(void (*_cb)(std::unique_ptr>), Args&&... args) : - ThreadMsg(reinterpret_cast)>(_cb)), - data(std::forward(args)...) - {} - - void reset_cb(void (*_cb)(std::unique_ptr>)) - { - cb = reinterpret_cast)>(_cb); - } - - virtual ~Tmsg() = default; - }; - - class ThreadMessaging; - - class TaskQueue - { - std::atomic item_head = nullptr; - ThreadMsg* local_msg = nullptr; - - public: - TaskQueue() = default; - - bool run_next_task() - { - if (local_msg == nullptr && item_head != nullptr) - { - local_msg = item_head.exchange(nullptr); - reverse_local_messages(); - } - - if (local_msg == nullptr) - { - return false; - } - - ThreadMsg* current = local_msg; - local_msg = local_msg->next; - - current->cb(std::unique_ptr(current)); - return true; - } - - void add_task(ThreadMsg* item) - { - ThreadMsg* tmp_head; - do - { - tmp_head = item_head.load(); - item->next = tmp_head; - } while (!item_head.compare_exchange_strong(tmp_head, item)); - } - - struct TimerEntry - { - TimerEntry() : time_offset(0), counter(0) {} - TimerEntry(std::chrono::milliseconds time_offset_, uint64_t counter_) : - time_offset(time_offset_), - counter(counter_) - {} - - std::chrono::milliseconds time_offset; - uint64_t counter; - }; - - struct TimerEntryCompare - { - bool operator()(const TimerEntry& lhs, const TimerEntry& rhs) const - { - if (lhs.time_offset != rhs.time_offset) - { - return lhs.time_offset < rhs.time_offset; - } - - return lhs.counter < rhs.counter; - } - }; - - TimerEntry add_task_after( - std::unique_ptr item, std::chrono::milliseconds ms) - { - TimerEntry entry = {time_offset + ms, time_entry_counter++}; - if (timer_map.empty() || entry.time_offset <= next_time_offset) - { - next_time_offset = entry.time_offset; - } - - timer_map.emplace(entry, std::move(item)); - return entry; - } - - bool cancel_timer_task(TimerEntry timer_entry) - { - auto num_erased = timer_map.erase(timer_entry); - CCF_ASSERT(num_erased <= 1, "Too many items erased"); - if (!timer_map.empty() && timer_entry.time_offset <= next_time_offset) - { - next_time_offset = timer_map.begin()->first.time_offset; - } - return num_erased != 0; - } - - void tick(std::chrono::milliseconds elapsed) - { - time_offset += elapsed; - - bool updated = false; - - while (!timer_map.empty() && next_time_offset <= time_offset && - timer_map.begin()->first.time_offset <= time_offset) - { - updated = true; - auto it = timer_map.begin(); - - auto& cb = it->second->cb; - auto msg = std::move(it->second); - timer_map.erase(it); - cb(std::move(msg)); - } - - if (updated && !timer_map.empty()) - { - next_time_offset = timer_map.begin()->first.time_offset; - } - } - - std::chrono::milliseconds get_current_time_offset() - { - return time_offset; - } - - private: - std::chrono::milliseconds time_offset = std::chrono::milliseconds(0); - uint64_t time_entry_counter = 0; - std::map, TimerEntryCompare> - timer_map; - std::chrono::milliseconds next_time_offset; - - void reverse_local_messages() - { - if (local_msg == nullptr) - return; - - ThreadMsg *prev = nullptr, *current = nullptr, *next = nullptr; - current = local_msg; - while (current != nullptr) - { - next = current->next; - current->next = prev; - prev = current; - current = next; - } - // now let the head point at the last node (prev) - local_msg = prev; - } - - void drop() - { - while (true) - { - if (local_msg == nullptr && item_head != nullptr) - { - local_msg = item_head.exchange(nullptr); - reverse_local_messages(); - } - - if (local_msg == nullptr) - { - break; - } - - ThreadMsg* current = local_msg; - local_msg = local_msg->next; - delete current; - } - } - - friend ThreadMessaging; - }; - - class ThreadMessaging - { - std::atomic finished; - std::vector tasks; // Fixed-size at construction - - // Drop all pending tasks, this is only ever to be used - // on shutdown, to avoid leaks, and after all thread but - // the main one have been shut down. - void drop_tasks() - { - for (auto& t : tasks) - { - t.drop(); - } - } - - inline TaskQueue& get_tasks(uint16_t task_id) - { - if (task_id >= tasks.size()) - { - throw std::runtime_error(fmt::format( - "Attempting to access task_id >= task_count, task_id:{}, " - "task_count:{}", - task_id, - tasks.size())); - } - return tasks[task_id]; - } - - static std::unique_ptr& get_singleton() - { - static std::unique_ptr singleton = nullptr; - return singleton; - } - - public: - static constexpr uint16_t max_num_threads = 24; - - ThreadMessaging(uint16_t num_task_queues) : - finished(false), - tasks(num_task_queues) - { - if (num_task_queues > max_num_threads) - { - throw std::logic_error(fmt::format( - "ThreadMessaging constructed with too many tasks: {} > {}", - num_task_queues, - max_num_threads)); - } - } - - ~ThreadMessaging() - { - drop_tasks(); - } - - static void init(uint16_t num_task_queues) - { - auto& singleton = get_singleton(); - if (singleton != nullptr) - { - throw std::logic_error("Called init() multiple times"); - } - - singleton = std::make_unique(num_task_queues); - } - - static void shutdown() - { - get_singleton().reset(); - } - - static ThreadMessaging& instance() - { - auto& singleton = get_singleton(); - if (singleton == nullptr) - { - throw std::logic_error( - "Attempted to access global ThreadMessaging instance without first " - "calling init()"); - } - - return *singleton; - } - - void set_finished(bool v = true) - { - finished.store(v); - } - - void run() - { - TaskQueue& task = get_tasks(ccf::threading::get_current_thread_id()); - - while (!is_finished()) - { - task.run_next_task(); - } - } - - bool run_one() - { - TaskQueue& task = get_tasks(ccf::threading::get_current_thread_id()); - return task.run_next_task(); - } - - template - void add_task(uint16_t tid, std::unique_ptr> msg) - { - TaskQueue& task = get_tasks(tid); - - task.add_task(reinterpret_cast(msg.release())); - } - - template - TaskQueue::TimerEntry add_task_after( - std::unique_ptr> msg, std::chrono::milliseconds ms) - { - TaskQueue& task = get_tasks(ccf::threading::get_current_thread_id()); - return task.add_task_after(std::move(msg), ms); - } - - bool cancel_timer_task(TaskQueue::TimerEntry timer_entry) - { - TaskQueue& task = get_tasks(ccf::threading::get_current_thread_id()); - return task.cancel_timer_task(timer_entry); - } - - std::chrono::milliseconds get_current_time_offset() - { - TaskQueue& task = get_tasks(ccf::threading::get_current_thread_id()); - return task.get_current_time_offset(); - } - - struct TickMsg - { - TickMsg(std::chrono::milliseconds elapsed_, TaskQueue& task_) : - elapsed(elapsed_), - task(task_) - {} - - std::chrono::milliseconds elapsed; - TaskQueue& task; - }; - - static void tick_cb(std::unique_ptr> msg) - { - msg->data.task.tick(msg->data.elapsed); - } - - void tick(std::chrono::milliseconds elapsed) - { - for (auto i = 0ul; i < tasks.size(); ++i) - { - auto& task = get_tasks(i); - auto msg = std::make_unique>(&tick_cb, elapsed, task); - task.add_task(msg.release()); - } - } - - uint16_t get_execution_thread(uint32_t i) - { - uint16_t tid = ccf::threading::MAIN_THREAD_ID; - if (tasks.size() > 1) - { - // If we have multiple task queues, then we distinguish the main thread - // from the remaining workers; anything asking for an execution thread - // does _not_ go to the main thread's queue - tid = (i % (tasks.size() - 1)); - ++tid; - } - - return tid; - } - - uint16_t thread_count() const - { - return tasks.size(); - } - - private: - bool is_finished() - { - return finished.load(); - } - }; -}; diff --git a/src/enclave/enclave.h b/src/enclave/enclave.h index f1de383c71d..e2cf67e368a 100644 --- a/src/enclave/enclave.h +++ b/src/enclave/enclave.h @@ -48,6 +48,7 @@ namespace ccf std::unique_ptr node; ringbuffer::WriterPtr to_host = nullptr; std::chrono::high_resolution_clock::time_point last_tick_time; + std::atomic worker_stop_signal = false; StartType start_type; @@ -231,9 +232,9 @@ namespace ccf lfs_access->register_message_handlers(bp.get_dispatcher()); DISPATCHER_SET_MESSAGE_HANDLER( - bp, AdminMessage::stop, [&bp](const uint8_t*, size_t) { + bp, AdminMessage::stop, [this, &bp](const uint8_t*, size_t) { bp.set_finished(); - ::threading::ThreadMessaging::instance().set_finished(); + this->worker_stop_signal.store(true); }); DISPATCHER_SET_MESSAGE_HANDLER( @@ -263,7 +264,7 @@ namespace ccf node->tick(elapsed_ms); historical_state_cache->tick(elapsed_ms); - ::threading::ThreadMessaging::instance().tick(elapsed_ms); + ccf::tasks::tick(elapsed_ms); // When recovering, no signature should be emitted while the // public ledger is being read if (!node->is_reading_public_ledger()) @@ -384,17 +385,24 @@ namespace ccf // First, read some messages from the ringbuffer auto read = bp.read_n(max_messages, circuit->read_from_outside()); - // Then, execute some thread messages - size_t thread_msg = 0; - while (thread_msg < max_messages && - ::threading::ThreadMessaging::instance().run_one()) + // Then, execute some tasks + auto& job_board = ccf::tasks::get_main_job_board(); + ccf::tasks::Task task = job_board.get_task(); + size_t tasks_done = 0; + while (task != nullptr) { - thread_msg++; + task->do_task(); + ++tasks_done; + if (tasks_done >= max_messages) + { + break; + } + task = job_board.get_task(); } - // If no messages were read from the ringbuffer and no thread - // messages were executed, idle - if (read == 0 && thread_msg == 0) + // If no messages were read from the ringbuffer and tasks were + // executed, idle + if (read == 0 && tasks_done == 0) { std::this_thread::yield(); } @@ -407,27 +415,22 @@ namespace ccf } } - struct Msg - { - uint64_t tid; - }; - - static void init_thread_cb(std::unique_ptr<::threading::Tmsg> msg) - { - LOG_DEBUG_FMT("First thread CB:{}", msg->data.tid); - } - bool run_worker() { LOG_DEBUG_FMT("Running worker thread"); { - auto msg = std::make_unique<::threading::Tmsg>(&init_thread_cb); - msg->data.tid = ccf::threading::get_current_thread_id(); - ::threading::ThreadMessaging::instance().add_task( - msg->data.tid, std::move(msg)); + auto& job_board = ccf::tasks::get_main_job_board(); + const auto timeout = std::chrono::milliseconds(100); - ::threading::ThreadMessaging::instance().run(); + while (!worker_stop_signal.load()) + { + auto task = job_board.wait_for_task(timeout); + if (task != nullptr) + { + task->do_task(); + } + } } return true; diff --git a/src/enclave/main.cpp b/src/enclave/main.cpp index 0a5c45be436..8132a6463f6 100644 --- a/src/enclave/main.cpp +++ b/src/enclave/main.cpp @@ -60,16 +60,6 @@ extern "C" { num_pending_threads = (uint16_t)num_worker_threads + 1; - - if (num_pending_threads > threading::ThreadMessaging::max_num_threads) - { - LOG_FAIL_FMT("Too many threads: {}", num_pending_threads); - return CreateNodeStatus::TooManyThreads; - } - - // Initialise singleton instance of ThreadMessaging, now that number of - // threads are known - threading::ThreadMessaging::init(num_pending_threads); } // 2-tx reconfiguration is currently experimental, disable it in release @@ -195,11 +185,7 @@ extern "C" if (tid == ccf::threading::MAIN_THREAD_ID) { auto s = e.load()->run_main(); - while (num_complete_threads != - threading::ThreadMessaging::instance().thread_count() - 1) - { - } - threading::ThreadMessaging::shutdown(); + return s; } auto s = e.load()->run_worker(); diff --git a/src/enclave/session.h b/src/enclave/session.h index 9a4790099fd..867676f16a5 100644 --- a/src/enclave/session.h +++ b/src/enclave/session.h @@ -3,8 +3,10 @@ #pragma once #include "ccf/node/session.h" -#include "ds/thread_messaging.h" #include "enclave/tls_session.h" +#include "tasks/ordered_tasks.h" +#include "tasks/task.h" +#include "tasks/task_system.h" #include "tcp/msg_types.h" #include @@ -15,20 +17,71 @@ namespace ccf public std::enable_shared_from_this { private: - size_t execution_thread; + std::shared_ptr task_scheduler; + std::atomic is_closing = false; - struct SendRecvMsg + struct SessionDataTask : public ccf::tasks::ITaskAction { std::vector data; std::shared_ptr self; + + SessionDataTask( + std::span d, std::shared_ptr s) : + self(s) + { + data.assign(d.begin(), d.end()); + } + }; + + struct HandleIncomingDataTask : public SessionDataTask + { + using SessionDataTask::SessionDataTask; + + void do_action() override + { + if (self->is_closing.load()) + { + return; + } + + self->handle_incoming_data_thread(std::move(data)); + } + + const std::string& get_name() const override + { + static const std::string name = + "ThreadedSession::HandleIncomingDataTask"; + return name; + } + }; + + struct SendDataTask : public SessionDataTask + { + using SessionDataTask::SessionDataTask; + + void do_action() override + { + self->send_data_thread(std::move(data)); + } + + const std::string& get_name() const override + { + static const std::string name = "ThreadedSession::SendDataTask"; + return name; + } }; public: - ThreadedSession(int64_t thread_affinity) + ThreadedSession(int64_t session_id) + { + task_scheduler = ccf::tasks::OrderedTasks::create( + ccf::tasks::get_main_job_board(), + fmt::format("Session {}", session_id)); + } + + ~ThreadedSession() { - execution_thread = - ::threading::ThreadMessaging::instance().get_execution_thread( - thread_affinity); + task_scheduler->cancel_task(); } // Implement Session::handle_incoming_data by dispatching a thread message @@ -37,19 +90,8 @@ namespace ccf { auto [_, body] = ringbuffer::read_message<::tcp::tcp_inbound>(data); - auto msg = std::make_unique<::threading::Tmsg>( - &handle_incoming_data_cb); - msg->data.self = this->shared_from_this(); - msg->data.data.assign(body.data, body.data + body.size); - - ::threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); - } - - static void handle_incoming_data_cb( - std::unique_ptr<::threading::Tmsg> msg) - { - msg->data.self->handle_incoming_data_thread(std::move(msg->data.data)); + task_scheduler->add_action( + std::make_shared(body, shared_from_this())); } virtual void handle_incoming_data_thread(std::vector&& data) = 0; @@ -58,22 +100,21 @@ namespace ccf // that eventually invokes the virtual send_data_thread() void send_data(std::span data) override { - auto msg = - std::make_unique<::threading::Tmsg>(&send_data_cb); - msg->data.self = this->shared_from_this(); - msg->data.data.assign(data.begin(), data.end()); - - ::threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); + task_scheduler->add_action( + std::make_shared(data, shared_from_this())); } - static void send_data_cb( - std::unique_ptr<::threading::Tmsg> msg) + virtual void send_data_thread(std::vector&& data) = 0; + + void close_session() override { - msg->data.self->send_data_thread(std::move(msg->data.data)); + is_closing.store(true); + + task_scheduler->add_action(ccf::tasks::make_basic_action( + [self = shared_from_this()]() { self->close_session_thread(); })); } - virtual void send_data_thread(std::vector&& data) = 0; + virtual void close_session_thread() = 0; }; class EncryptedSession : public ThreadedSession @@ -96,21 +137,9 @@ namespace ccf {} public: - void send_data(std::span data) override - { - // Override send_data rather than send_data_thread, as the TLSSession - // handles dispatching for thread affinity - tls_io->send_raw(data.data(), data.size()); - } - void send_data_thread(std::vector&& data) override { - throw std::logic_error("Unimplemented"); - } - - void close_session() override - { - tls_io->close(); + tls_io->send_data(data.data(), data.size()); } void handle_incoming_data_thread(std::vector&& data) override @@ -150,6 +179,11 @@ namespace ccf n_read = tls_io->read(data.data(), data.size(), false); } } + + void close_session_thread() override + { + tls_io->close(); + } }; class UnencryptedSession : public ccf::ThreadedSession @@ -178,7 +212,7 @@ namespace ccf serializer::ByteRange{data.data(), data.size()}); } - void close_session() override + void close_session_thread() override { RINGBUFFER_WRITE_MESSAGE( ::tcp::tcp_stop, to_host, session_id, std::string("Session closed")); diff --git a/src/enclave/tls_session.h b/src/enclave/tls_session.h index 44d15e5da6a..0b1c72a1009 100644 --- a/src/enclave/tls_session.h +++ b/src/enclave/tls_session.h @@ -5,7 +5,6 @@ #include "ds/internal_logger.h" #include "ds/messaging.h" #include "ds/ring_buffer.h" -#include "ds/thread_messaging.h" #include "tcp/msg_types.h" #include "tls/context.h" #include "tls/tls.h" @@ -32,7 +31,6 @@ namespace ccf protected: ringbuffer::WriterPtr to_host; ::tcp::ConnID session_id; - size_t execution_thread; private: std::vector pending_write; @@ -57,17 +55,6 @@ namespace ccf return status == ready || status == handshake; } - struct SendRecvMsg - { - std::vector data; - std::shared_ptr self; - }; - - struct EmptyMsg - { - std::shared_ptr self; - }; - public: TLSSession( int64_t session_id_, @@ -78,9 +65,6 @@ namespace ccf ctx(std::move(ctx_)), status(handshake) { - execution_thread = - ::threading::ThreadMessaging::instance().get_execution_thread( - session_id); ctx->set_bio(this, send_callback_openssl, recv_callback_openssl); } @@ -236,11 +220,6 @@ namespace ccf void recv_buffered(const uint8_t* data, size_t size) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called recv_buffered from incorrect thread"); - } - if (can_recv()) { pending_read.insert(pending_read.end(), data, data + size); @@ -252,32 +231,6 @@ namespace ccf void close() { status = closing; - if (ccf::threading::get_current_thread_id() != execution_thread) - { - auto msg = std::make_unique<::threading::Tmsg>(&close_cb); - msg->data.self = this->shared_from_this(); - - ::threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); - } - else - { - // Close inline immediately - close_thread(); - } - } - - static void close_cb(std::unique_ptr<::threading::Tmsg> msg) - { - msg->data.self->close_thread(); - } - - virtual void close_thread() - { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called close_thread from incorrect thread"); - } switch (status) { @@ -327,39 +280,8 @@ namespace ccf } } - void send_raw(const uint8_t* data, size_t size) + void send_data(const uint8_t* data, size_t size) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - auto msg = - std::make_unique<::threading::Tmsg>(&send_raw_cb); - msg->data.self = this->shared_from_this(); - msg->data.data = std::vector(data, data + size); - - ::threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); - } - else - { - // Send inline immediately - send_raw_thread(data, size); - } - } - - private: - static void send_raw_cb(std::unique_ptr<::threading::Tmsg> msg) - { - msg->data.self->send_raw_thread( - msg->data.data.data(), msg->data.data.size()); - } - - void send_raw_thread(const uint8_t* data, size_t size) - { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error( - "Called send_raw_thread from incorrect thread"); - } // Writes as much of the data as possible. If the data cannot all // be written now, we store the remainder. We // will try to send pending writes again whenever write() is called. @@ -381,23 +303,14 @@ namespace ccf flush(); } + private: void send_buffered(const std::vector& data) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called send_buffered from incorrect thread"); - } - pending_write.insert(pending_write.end(), data.begin(), data.end()); } void flush() { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called flush from incorrect thread"); - } - do_handshake(); if (!can_send()) @@ -574,10 +487,6 @@ namespace ccf int handle_recv(uint8_t* buf, size_t len) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called handle_recv from incorrect thread"); - } if (pending_read.size() > 0) { // Use the pending data vector. This is populated when the host diff --git a/src/http/curl.h b/src/http/curl.h index 1b822c66ce2..0b953372244 100644 --- a/src/http/curl.h +++ b/src/http/curl.h @@ -4,9 +4,7 @@ #include "ccf/ds/nonstd.h" #include "ccf/rest_verb.h" -#include "ccf/threading/thread_ids.h" #include "ds/internal_logger.h" -#include "ds/thread_messaging.h" #include "host/proxy.h" #include @@ -389,7 +387,6 @@ namespace ccf::curl std::unique_ptr response; ResponseHeaders response_headers; std::optional response_callback; - std::optional response_thread; public: CurlRequest( @@ -399,17 +396,14 @@ namespace ccf::curl UniqueSlist&& headers_, std::unique_ptr&& request_body_, std::unique_ptr&& response_, - std::optional&& response_callback_, - std::optional response_thread_ = - threading::get_current_thread_id()) : + std::optional&& response_callback_) : curl_handle(std::move(curl_handle_)), method(method_), url(std::move(url_)), headers(std::move(headers_)), request_body(std::move(request_body_)), response(std::move(response_)), - response_callback(std::move(response_callback_)), - response_thread(response_thread_) + response_callback(std::move(response_callback_)) { if (url.empty()) { @@ -542,11 +536,6 @@ namespace ccf::curl { return response_headers.data; } - - [[nodiscard]] std::optional get_response_thread() const - { - return response_thread; - } }; class CurlRequestCURLM : public UniqueCURLM @@ -606,26 +595,9 @@ namespace ccf::curl // destructor of CurlRequest curl_multi_remove_handle(p.get(), easy); - // dispatch the response handling to a thread for processing - if (request->get_response_thread().has_value()) - { - using Data = - std::tuple, CURLcode>; - ::threading::ThreadMessaging::instance().add_task( - request->get_response_thread().value(), - std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - auto& [curl_request, curl_code] = msg->data; - CurlRequest::handle_response( - std::move(curl_request), curl_code); - }, - std::make_tuple(std::move(request_data_ptr), result))); - } - else - { - // If the response thread is not set, run on the uv thread - CurlRequest::handle_response(std::move(request_data_ptr), result); - } + // handle response inline. Note that if this is expensive, it should + // defer its work to a task + CurlRequest::handle_response(std::move(request_data_ptr), result); } } while (msgq > 0); return running_handles; diff --git a/src/http/http2_session.h b/src/http/http2_session.h index d74a7f959f1..138303d56a9 100644 --- a/src/http/http2_session.h +++ b/src/http/http2_session.h @@ -258,9 +258,7 @@ namespace http responder_lookup(responder_lookup_) { server_parser->set_outgoing_data_handler( - [this](std::span data) { - this->tls_io->send_raw(data.data(), data.size()); - }); + [this](std::span data) { send_data(data); }); } ~HTTP2ServerSession() @@ -460,9 +458,7 @@ namespace http client_parser(*this) { client_parser.set_outgoing_data_handler( - [this](std::span data) { - this->tls_io->send_raw(data.data(), data.size()); - }); + [this](std::span data) { send_data(data); }); } bool parse(std::span data) override diff --git a/src/http/http_session.h b/src/http/http_session.h index d7016ceb2e0..2ba05cee40d 100644 --- a/src/http/http_session.h +++ b/src/http/http_session.h @@ -67,7 +67,7 @@ namespace http ccf::errors::RequestBodyTooLarge, e.what()}); - tls_io->close(); + close_session(); } catch (RequestHeaderTooLargeException& e) { @@ -83,7 +83,7 @@ namespace http ccf::errors::RequestHeaderTooLarge, e.what()}); - tls_io->close(); + close_session(); } catch (const std::exception& e) { @@ -113,7 +113,7 @@ namespace http {}, std::move(response_body)); - tls_io->close(); + close_session(); } return false; @@ -157,7 +157,7 @@ namespace http HTTP_STATUS_INTERNAL_SERVER_ERROR, ccf::errors::InternalError, fmt::format("Error constructing RpcContext: {}", e.what())}); - tls_io->close(); + close_session(); } std::shared_ptr search = @@ -181,7 +181,7 @@ namespace http if (rpc_ctx->terminate_session) { - tls_io->close(); + close_session(); } } } @@ -195,7 +195,7 @@ namespace http // On any exception, close the connection. LOG_FAIL_FMT("Closing connection"); LOG_DEBUG_FMT("Closing connection due to exception: {}", e.what()); - tls_io->close(); + close_session(); throw; } } @@ -224,7 +224,7 @@ namespace http ); auto data = response.build_response(); - tls_io->send_raw(data.data(), data.size()); + send_data(data); return true; } diff --git a/src/http/test/curl_test.cpp b/src/http/test/curl_test.cpp index f09b19a0930..52918bddc66 100644 --- a/src/http/test/curl_test.cpp +++ b/src/http/test/curl_test.cpp @@ -126,8 +126,7 @@ TEST_CASE("CurlmLibuvContext") std::move(headers), std::move(body), std::make_unique(SIZE_MAX), - std::move(response_callback), - std::nullopt); + std::move(response_callback)); ccf::curl::CurlmLibuvContextSingleton::get_instance()->attach_request( std::move(request)); @@ -192,8 +191,7 @@ TEST_CASE("CurlmLibuvContext slow") std::move(headers), std::move(body), std::make_unique(SIZE_MAX), - std::move(response_callback), - std::nullopt); + std::move(response_callback)); ccf::curl::CurlmLibuvContextSingleton::get_instance()->attach_request( std::move(request)); @@ -266,8 +264,7 @@ TEST_CASE("CurlmLibuvContext timeouts") std::move(headers), std::move(body), std::make_unique(SIZE_MAX), - std::move(response_callback), - std::nullopt); + std::move(response_callback)); ccf::curl::CurlmLibuvContextSingleton::get_instance()->attach_request( std::move(request)); @@ -346,8 +343,7 @@ TEST_CASE("CurlmLibuvContext multiple init") std::move(headers), std::move(body), std::make_unique(SIZE_MAX), - std::move(response_callback), - std::nullopt); + std::move(response_callback)); ccf::curl::CurlmLibuvContextSingleton::get_instance()->attach_request( std::move(request)); diff --git a/src/indexing/test/indexing.cpp b/src/indexing/test/indexing.cpp index 245cdb84560..b4b8ccbdfab 100644 --- a/src/indexing/test/indexing.cpp +++ b/src/indexing/test/indexing.cpp @@ -953,7 +953,6 @@ TEST_CASE( int main(int argc, char** argv) { - threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/node/channels.h b/src/node/channels.h index f5757b33a22..69accec698e 100644 --- a/src/node/channels.h +++ b/src/node/channels.h @@ -15,8 +15,7 @@ #include "ds/internal_logger.h" #include "ds/serialized.h" #include "ds/state_machine.h" -#include "ds/thread_messaging.h" -#include "node_types.h" +#include "node/node_types.h" #include #include diff --git a/src/node/history.h b/src/node/history.h index 0c8b2d39143..28e1050169c 100644 --- a/src/node/history.h +++ b/src/node/history.h @@ -10,12 +10,13 @@ #include "crypto/openssl/hash.h" #include "crypto/openssl/key_pair.h" #include "ds/internal_logger.h" -#include "ds/thread_messaging.h" #include "endian.h" #include "kv/kv_types.h" #include "kv/store.h" #include "node_signature_verify.h" #include "service/tables/signatures.h" +#include "tasks/basic_task.h" +#include "tasks/task_system.h" #include #include @@ -573,8 +574,7 @@ namespace ccf ccf::crypto::COSEVerifierUniquePtr cose_verifier{}; std::vector cose_cert_cached{}; - std::optional<::threading::TaskQueue::TimerEntry> - emit_signature_timer_entry = std::nullopt; + ccf::tasks::Task emit_signature_periodic_task; size_t sig_tx_interval; size_t sig_ms_interval; @@ -645,72 +645,57 @@ namespace ccf void start_signature_emit_timer() override { - struct EmitSigMsg - { - EmitSigMsg(HashedTxHistory* self_) : self(self_) {} - HashedTxHistory* self; - }; - - auto emit_sig_msg = std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - auto self = msg->data.self; + const auto delay = std::chrono::milliseconds(sig_ms_interval); - std::unique_lock mguard( - self->signature_lock, std::defer_lock); + emit_signature_periodic_task = ccf::tasks::make_basic_task([this]() { + std::unique_lock mguard( + this->signature_lock, std::defer_lock); - bool should_emit_signature = false; + bool should_emit_signature = false; - if (mguard.try_lock()) + if (mguard.try_lock()) + { + auto consensus = this->store.get_consensus(); + if (consensus != nullptr) { - auto consensus = self->store.get_consensus(); - if (consensus != nullptr) + auto sig_disp = consensus->get_signature_disposition(); + switch (sig_disp) { - auto sig_disp = consensus->get_signature_disposition(); - switch (sig_disp) + case ccf::kv::Consensus::SignatureDisposition::CANT_REPLICATE: { - case ccf::kv::Consensus::SignatureDisposition::CANT_REPLICATE: - { - break; - } - case ccf::kv::Consensus::SignatureDisposition::CAN_SIGN: - { - if (self->store.committable_gap() > 0) - { - should_emit_signature = true; - } - break; - } - case ccf::kv::Consensus::SignatureDisposition::SHOULD_SIGN: + break; + } + case ccf::kv::Consensus::SignatureDisposition::CAN_SIGN: + { + if (this->store.committable_gap() > 0) { should_emit_signature = true; - break; } + break; + } + case ccf::kv::Consensus::SignatureDisposition::SHOULD_SIGN: + { + should_emit_signature = true; + break; } } } + } - if (should_emit_signature) - { - msg->data.self->emit_signature(); - } - - self->emit_signature_timer_entry = - ::threading::ThreadMessaging::instance().add_task_after( - std::move(msg), std::chrono::milliseconds(self->sig_ms_interval)); - }, - this); + if (should_emit_signature) + { + this->emit_signature(); + } + }); - emit_signature_timer_entry = - ::threading::ThreadMessaging::instance().add_task_after( - std::move(emit_sig_msg), std::chrono::milliseconds(sig_ms_interval)); + ccf::tasks::add_periodic_task(emit_signature_periodic_task, delay, delay); } ~HashedTxHistory() { - if (emit_signature_timer_entry.has_value()) + if (emit_signature_periodic_task != nullptr) { - ::threading::ThreadMessaging::instance().cancel_timer_task( - *emit_signature_timer_entry); + emit_signature_periodic_task->cancel_task(); } } diff --git a/src/node/jwt_key_auto_refresh.h b/src/node/jwt_key_auto_refresh.h index 8f5faf8d641..1e6f1d5be58 100644 --- a/src/node/jwt_key_auto_refresh.h +++ b/src/node/jwt_key_auto_refresh.h @@ -24,6 +24,8 @@ namespace ccf ccf::crypto::Pem node_cert; std::atomic_size_t attempts; + ccf::tasks::Task periodic_refresh_task; + public: JwtKeyAutoRefresh( size_t refresh_interval_s, @@ -43,62 +45,54 @@ namespace ccf attempts(0) {} - struct RefreshTimeMsg + ~JwtKeyAutoRefresh() { - RefreshTimeMsg(JwtKeyAutoRefresh& self_) : self(self_) {} - - JwtKeyAutoRefresh& self; - }; + stop(); + } void start() { - auto refresh_msg = std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - if (!msg->data.self.consensus->can_replicate()) - { - LOG_DEBUG_FMT( - "JWT key auto-refresh: Node is not primary, skipping"); - } - else - { - msg->data.self.refresh_jwt_keys(); - } - LOG_DEBUG_FMT( - "JWT key auto-refresh: Scheduling in {}s", - msg->data.self.refresh_interval_s); - auto delay = std::chrono::seconds(msg->data.self.refresh_interval_s); - ::threading::ThreadMessaging::instance().add_task_after( - std::move(msg), delay); - }, - *this); + LOG_DEBUG_FMT("JWT key initial auto-refresh"); + periodic_refresh_task = ccf::tasks::make_basic_task([this]() { + if (!this->consensus->can_replicate()) + { + LOG_DEBUG_FMT("JWT key auto-refresh: Node is not primary, skipping"); + } + else + { + this->refresh_jwt_keys(); + } - LOG_DEBUG_FMT( - "JWT key auto-refresh: Scheduling in {}s", refresh_interval_s); - auto delay = std::chrono::seconds(refresh_interval_s); - ::threading::ThreadMessaging::instance().add_task_after( - std::move(refresh_msg), delay); + LOG_DEBUG_FMT( + "JWT key auto-refresh: Scheduling in {}s", this->refresh_interval_s); + }); + + const std::chrono::seconds period(refresh_interval_s); + ccf::tasks::add_periodic_task(periodic_refresh_task, period, period); } - void schedule_once() + void stop() { - auto refresh_msg = std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - if (!msg->data.self.consensus->can_replicate()) - { - LOG_DEBUG_FMT( - "JWT key one-off refresh: Node is not primary, skipping"); - } - else - { - msg->data.self.refresh_jwt_keys(); - } - }, - *this); + if (periodic_refresh_task != nullptr) + { + periodic_refresh_task->cancel_task(); + } + } + void schedule_once() + { LOG_DEBUG_FMT("JWT key one-off refresh: Scheduling without delay"); - auto delay = std::chrono::seconds(0); - ::threading::ThreadMessaging::instance().add_task_after( - std::move(refresh_msg), delay); + ccf::tasks::add_task(ccf::tasks::make_basic_task([this]() { + if (!this->consensus->can_replicate()) + { + LOG_DEBUG_FMT( + "JWT key one-off refresh: Node is not primary, skipping"); + } + else + { + this->refresh_jwt_keys(); + } + })); } template diff --git a/src/node/node_state.h b/src/node/node_state.h index f4b62c3d8f9..afcd8c4510c 100644 --- a/src/node/node_state.h +++ b/src/node/node_state.h @@ -111,21 +111,6 @@ namespace ccf std::atomic stop_noticed = false; - struct NodeStateMsg - { - NodeStateMsg( - NodeState& self_, - View create_view_ = 0, - bool create_consortium_ = true) : - self(self_), - create_view(create_view_), - create_consortium(create_consortium_) - {} - NodeState& self; - View create_view; - bool create_consortium; - }; - // // kv store, replication, and I/O // @@ -173,6 +158,8 @@ namespace ccf // the lifetime of the node ccf::kv::Version startup_seqno = 0; + ccf::tasks::Task join_periodic_task; + std::shared_ptr make_encryptor() { #ifdef USE_NULL_ENCRYPTOR @@ -865,6 +852,12 @@ namespace ccf sm.advance(NodeStartupState::partOfNetwork); } + if (join_periodic_task != nullptr) + { + join_periodic_task->cancel_task(); + join_periodic_task = nullptr; + } + LOG_INFO_FMT( "Node has now joined the network as node {}: {}", self, @@ -925,23 +918,18 @@ namespace ccf { initiate_join_unsafe(); - auto timer_msg = std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - std::lock_guard guard(msg->data.self.lock); - if (msg->data.self.sm.check(NodeStartupState::pending)) - { - msg->data.self.initiate_join_unsafe(); - auto delay = std::chrono::milliseconds( - msg->data.self.config.join.retry_timeout); - - ::threading::ThreadMessaging::instance().add_task_after( - std::move(msg), delay); - } - }, - *this); + join_periodic_task = ccf::tasks::make_basic_task([this]() { + std::lock_guard guard(this->lock); + if (this->sm.check(NodeStartupState::pending)) + { + this->initiate_join_unsafe(); + } + }); - ::threading::ThreadMessaging::instance().add_task_after( - std::move(timer_msg), config.join.retry_timeout); + ccf::tasks::add_periodic_task( + join_periodic_task, + config.join.retry_timeout, + config.join.retry_timeout); } void auto_refresh_jwt_keys() @@ -2103,30 +2091,23 @@ namespace ccf { // Service creation transaction is asynchronous to avoid deadlocks // (e.g. https://github.com/microsoft/CCF/issues/3788) - auto msg = std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - if (!msg->data.self.send_create_request( - msg->data.self.serialize_create_request( - msg->data.create_view, msg->data.create_consortium))) + ccf::tasks::add_task( + ccf::tasks::make_basic_task([this, create_view, create_consortium]() { + if (!this->send_create_request( + this->serialize_create_request(create_view, create_consortium))) { throw std::runtime_error( "Service creation request could not be committed"); } - if (msg->data.create_consortium) + if (create_consortium) { - msg->data.self.advance_part_of_network(); + this->advance_part_of_network(); } else { - msg->data.self.advance_part_of_public_network(); + this->advance_part_of_public_network(); } - }, - *this, - create_view, - create_consortium); - - ::threading::ThreadMessaging::instance().add_task( - threading::get_current_thread_id(), std::move(msg)); + })); } void begin_private_recovery() diff --git a/src/node/node_to_node_channel_manager.h b/src/node/node_to_node_channel_manager.h index 19d8610d927..96bd07f3e03 100644 --- a/src/node/node_to_node_channel_manager.h +++ b/src/node/node_to_node_channel_manager.h @@ -4,6 +4,7 @@ #include "ccf/pal/locking.h" #include "channels.h" +#include "ds/ccf_assert.h" #include "node/node_to_node.h" namespace ccf diff --git a/src/node/quote_endorsements_client.h b/src/node/quote_endorsements_client.h index a131c42ae46..e5f2bcef945 100644 --- a/src/node/quote_endorsements_client.h +++ b/src/node/quote_endorsements_client.h @@ -4,7 +4,6 @@ #include "ccf/pal/attestation.h" #include "ccf/pal/attestation_sev_snp_endorsements.h" -#include "ds/thread_messaging.h" #include "enclave/rpc_sessions.h" #include "http/curl.h" @@ -138,54 +137,28 @@ namespace ccf fetch_unsafe(); } - void fetch_unsafe() + struct HandleResponseTask : public ccf::tasks::BaseTask { - const auto& server = servers.front(); - const auto& endpoint = server.front(); - - curl::UniqueCURL curl_handle; - - // Set curl get - curl_handle.set_opt(CURLOPT_HTTPGET, 1L); - // If the server does not respond at all within this time timeout - curl_handle.set_opt(CURLOPT_CONNECTTIMEOUT, server_connection_timeout_s); - // If the server does not completely response within this time timeout - curl_handle.set_opt(CURLOPT_TIMEOUT, server_response_timeout_s); - - auto url = fmt::format( - "{}://{}:{}{}{}", - endpoint.tls ? "https" : "http", - endpoint.host, - endpoint.port, - endpoint.uri, - get_formatted_query(endpoint.params)); - - if (endpoint.tls) - { - // Note: server CA is not checked here as this client is not sending - // private data. If the server was malicious and the certificate chain - // was bogus, the verification of the endorsement of the quote would - // fail anyway. - curl_handle.set_opt(CURLOPT_SSL_VERIFYHOST, 0L); - curl_handle.set_opt(CURLOPT_SSL_VERIFYPEER, 0L); - curl_handle.set_opt(CURLOPT_SSL_VERIFYSTATUS, 0L); - } + std::shared_ptr self; + std::unique_ptr request; + CURLcode curl_response; + long status_code; + + HandleResponseTask( + std::shared_ptr self_, + std::unique_ptr&& request_, + CURLcode curl_response_, + long status_code_) : + self(self_), + request(std::move(request_)), + curl_response(curl_response_), + status_code(status_code_) + {} - auto headers = ccf::curl::UniqueSlist(); - for (auto const& [k, v] : endpoint.headers) + void do_task_implementation() override { - headers.append(k, v); - } - headers.append(http::headers::HOST, endpoint.host); - - auto response_callback = ([this, lifetime = shared_from_this()]( - std::unique_ptr&& request, - CURLcode curl_response, - long status_code) { - std::lock_guard guard(this->lock); + std::lock_guard guard(self->lock); - const auto& server = servers.front(); - const auto& endpoint = server.front(); auto* response_body = request->get_response_body(); const auto& response_headers = request->get_response_headers(); @@ -196,7 +169,8 @@ namespace ccf "{} bytes", response_body->buffer.size()); - handle_success_response_unsafe(std::move(response_body->buffer)); + self->handle_success_response_unsafe( + std::move(response_body->buffer)); return; } @@ -206,15 +180,17 @@ namespace ccf curl_response, status_code); - if (server_retries_count >= max_retries_count(server)) + if ( + self->server_retries_count >= + max_retries_count(self->servers.front())) { - servers.pop_front(); + self->servers.pop_front(); - if (servers.empty()) + if (self->servers.empty()) { auto servers_tried = std::accumulate( - config.servers.begin(), - config.servers.end(), + self->config.servers.begin(), + self->config.servers.end(), std::string{}, [](const std::string& a, const Server& b) { return a + (a.length() > 0 ? ", " : "") + b.front().host; @@ -223,19 +199,21 @@ namespace ccf "Giving up retrying fetching attestation endorsements from [{}] " "after {} attempts", servers_tried, - total_retries_count); + self->total_retries_count); throw ccf::pal::AttestationCollateralFetchingTimeout( "Timed out fetching attestation endorsements from all " "configured servers"); } - server_retries_count = 0; - fetch_unsafe(); + self->server_retries_count = 0; + self->fetch_unsafe(); } else { - ++this->server_retries_count; - ++this->total_retries_count; + ++self->server_retries_count; + ++self->total_retries_count; + + const auto& endpoint = self->servers.front().front(); constexpr size_t default_retry_after_s = 3; size_t retry_after_s = default_retry_after_s; @@ -269,16 +247,77 @@ namespace ccf retry_after_s); } - auto msg = - std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> - msg) { msg->data.self->fetch(); }, - shared_from_this(), - server); + const std::chrono::seconds retry_after(retry_after_s); - ::threading::ThreadMessaging::instance().add_task_after( - std::move(msg), std::chrono::seconds(retry_after_s)); + LOG_INFO_FMT( + "{} endorsements endpoint had too many requests. Retrying " + "in {}s", + endpoint, + retry_after_s); + + ccf::tasks::add_delayed_task( + ccf::tasks::make_basic_task( + [self = this->self]() { self->fetch(); }), + retry_after); } + } + + const std::string& get_name() const override + { + static const std::string name = + "QuoteEndorsementsClient::HandleResponseTask"; + return name; + } + }; + + void fetch_unsafe() + { + const auto& server = servers.front(); + const auto& endpoint = server.front(); + + curl::UniqueCURL curl_handle; + + // Set curl get + curl_handle.set_opt(CURLOPT_HTTPGET, 1L); + // If the server does not respond at all within this time timeout + curl_handle.set_opt(CURLOPT_CONNECTTIMEOUT, server_connection_timeout_s); + // If the server does not completely response within this time timeout + curl_handle.set_opt(CURLOPT_TIMEOUT, server_response_timeout_s); + + auto url = fmt::format( + "{}://{}:{}{}{}", + endpoint.tls ? "https" : "http", + endpoint.host, + endpoint.port, + endpoint.uri, + get_formatted_query(endpoint.params)); + + if (endpoint.tls) + { + // Note: server CA is not checked here as this client is not sending + // private data. If the server was malicious and the certificate chain + // was bogus, the verification of the endorsement of the quote would + // fail anyway. + curl_handle.set_opt(CURLOPT_SSL_VERIFYHOST, 0L); + curl_handle.set_opt(CURLOPT_SSL_VERIFYPEER, 0L); + curl_handle.set_opt(CURLOPT_SSL_VERIFYSTATUS, 0L); + } + + auto headers = ccf::curl::UniqueSlist(); + for (auto const& [k, v] : endpoint.headers) + { + headers.append(k, v); + } + headers.append(http::headers::HOST, endpoint.host); + + auto response_callback = ([self = shared_from_this()]( + std::unique_ptr&& request, + CURLcode curl_response, + long status_code) { + std::shared_ptr response_task = + std::make_shared( + self, std::move(request), curl_response, status_code); + ccf::tasks::add_task(response_task); }); auto request = std::make_unique( @@ -294,7 +333,6 @@ namespace ccf LOG_INFO_FMT( "Fetching endorsements for attestation report at {}", request->get_url()); - curl::CurlmLibuvContextSingleton::get_instance()->attach_request( std::move(request)); } diff --git a/src/node/retired_nodes_cleanup.h b/src/node/retired_nodes_cleanup.h index dd2de0de079..799e18bfa38 100644 --- a/src/node/retired_nodes_cleanup.h +++ b/src/node/retired_nodes_cleanup.h @@ -2,8 +2,9 @@ // Licensed under the Apache 2.0 License. #pragma once -#include "ds/thread_messaging.h" #include "node_client.h" +#include "tasks/basic_task.h" +#include "tasks/task_system.h" namespace ccf { @@ -30,24 +31,10 @@ namespace ccf node_client->make_request(request); } - struct RetiredNodeCleanupMsg - { - RetiredNodeCleanupMsg(RetiredNodeCleanup& self_) : self(self_) {} - - RetiredNodeCleanup& self; - }; - void cleanup() { - auto cleanup_msg = - std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - msg->data.self.send_cleanup_retired_nodes(); - }, - *this); - - ::threading::ThreadMessaging::instance().add_task( - ccf::threading::get_current_thread_id(), std::move(cleanup_msg)); + ccf::tasks::add_task(ccf::tasks::make_basic_task( + [this]() { this->send_cleanup_retired_nodes(); })); } }; } \ No newline at end of file diff --git a/src/node/rpc/forwarder.h b/src/node/rpc/forwarder.h index c1b6109db99..acf69ed3721 100644 --- a/src/node/rpc/forwarder.h +++ b/src/node/rpc/forwarder.h @@ -8,6 +8,8 @@ #include "http/http_rpc_context.h" #include "kv/kv_types.h" #include "node/node_to_node.h" +#include "tasks/basic_task.h" +#include "tasks/task_system.h" namespace ccf { @@ -34,58 +36,11 @@ namespace ccf using ForwardedCommandId = ForwardedHeader_v2::ForwardedCommandId; ForwardedCommandId next_command_id = 0; - struct TimeoutTask - { - ::threading::TaskQueue::TimerEntry timer_entry; - uint16_t thread_id; - }; - - std::unordered_map timeout_tasks; + std::unordered_map timeout_tasks; ccf::pal::Mutex timeout_tasks_lock; using IsCallerCertForwarded = bool; - struct SendTimeoutErrorMsg - { - SendTimeoutErrorMsg( - Forwarder* forwarder_, - const ccf::NodeId& to_, - size_t client_session_id_, - const std::chrono::milliseconds& timeout_) : - forwarder(forwarder_), - to(to_), - client_session_id(client_session_id_), - timeout(timeout_) - {} - - Forwarder* forwarder; - ccf::NodeId to; - size_t client_session_id; - std::chrono::milliseconds timeout; - }; - - struct CancelTimerMsg - { - ::threading::TaskQueue::TimerEntry timer_entry; - }; - - std::unique_ptr<::threading::Tmsg> - create_timeout_error_task( - const ccf::NodeId& to, - size_t client_session_id, - const std::chrono::milliseconds& timeout) - { - return std::make_unique<::threading::Tmsg>( - [](std::unique_ptr<::threading::Tmsg> msg) { - msg->data.forwarder->send_timeout_error_response( - msg->data.to, msg->data.client_session_id, msg->data.timeout); - }, - this, - to, - client_session_id, - timeout); - } - void send_timeout_error_response( NodeId to, size_t client_session_id, @@ -108,18 +63,6 @@ namespace ccf } } - static void cancel_forwarding_task_cb( - std::unique_ptr<::threading::Tmsg> msg) - { - cancel_forwarding_task(msg->data.timer_entry); - } - - static void cancel_forwarding_task( - ::threading::TaskQueue::TimerEntry timer_entry) - { - ::threading::ThreadMessaging::instance().cancel_timer_task(timer_entry); - } - public: Forwarder( std::weak_ptr rpcresponder, @@ -171,10 +114,11 @@ namespace ccf { std::lock_guard guard(timeout_tasks_lock); command_id = next_command_id++; - timeout_tasks[command_id] = { - ::threading::ThreadMessaging::instance().add_task_after( - create_timeout_error_task(to, client_session_id, timeout), timeout), - ccf::threading::get_current_thread_id()}; + auto task = ccf::tasks::make_basic_task([=, this]() { + this->send_timeout_error_response(to, client_session_id, timeout); + }); + timeout_tasks[command_id] = task; + ccf::tasks::add_delayed_task(task, timeout); } const auto view_opt = session_ctx->active_view; @@ -459,20 +403,7 @@ namespace ccf auto it = timeout_tasks.find(cmd_id); if (it != timeout_tasks.end()) { - if ( - ccf::threading::get_current_thread_id() != it->second.thread_id) - { - auto msg = std::make_unique<::threading::Tmsg>( - &cancel_forwarding_task_cb); - msg->data.timer_entry = it->second.timer_entry; - - ::threading::ThreadMessaging::instance().add_task( - it->second.thread_id, std::move(msg)); - } - else - { - cancel_forwarding_task(it->second.timer_entry); - } + it->second->cancel_task(); it = timeout_tasks.erase(it); } else diff --git a/src/node/rpc/test/frontend_test.cpp b/src/node/rpc/test/frontend_test.cpp index 00f907671c2..dafa668eb0f 100644 --- a/src/node/rpc/test/frontend_test.cpp +++ b/src/node/rpc/test/frontend_test.cpp @@ -1775,7 +1775,6 @@ TEST_CASE("Manual conflicts") int main(int argc, char** argv) { - ::threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/node/snapshotter.h b/src/node/snapshotter.h index bddfbfa2189..7e462a30978 100644 --- a/src/node/snapshotter.h +++ b/src/node/snapshotter.h @@ -6,12 +6,12 @@ #include "consensus/ledger_enclave_types.h" #include "ds/ccf_assert.h" #include "ds/internal_logger.h" -#include "ds/thread_messaging.h" #include "kv/kv_types.h" #include "kv/store.h" #include "node/network_state.h" #include "node/snapshot_serdes.h" #include "service/tables/snapshot_evidence.h" +#include "tasks/task_system.h" #include #include @@ -96,18 +96,35 @@ namespace ccf serialised_receipt); } - struct SnapshotMsg + struct SnapshotTask : public ccf::tasks::BaseTask { std::shared_ptr self; std::unique_ptr snapshot; uint32_t generation_count; - }; - static void snapshot_cb(std::unique_ptr<::threading::Tmsg> msg) - { - msg->data.self->snapshot_( - std::move(msg->data.snapshot), msg->data.generation_count); - } + const std::string name; + + SnapshotTask( + std::shared_ptr _self, + std::unique_ptr&& _snapshot, + uint32_t _generation_count) : + self(_self), + snapshot(std::move(_snapshot)), + generation_count(_generation_count), + name(fmt::format( + "snapshot@{}[{}]", snapshot->get_version(), generation_count)) + {} + + void do_task_implementation() override + { + self->snapshot_(std::move(snapshot), generation_count); + } + + const std::string& get_name() const override + { + return name; + } + }; void snapshot_( std::unique_ptr snapshot, @@ -438,13 +455,13 @@ namespace ccf void schedule_snapshot(::consensus::Index idx) { static uint32_t generation_count = 0; - auto msg = std::make_unique<::threading::Tmsg>(&snapshot_cb); - msg->data.self = shared_from_this(); - msg->data.snapshot = store->snapshot_unsafe_maps(idx); - msg->data.generation_count = generation_count++; - auto& tm = ::threading::ThreadMessaging::instance(); - tm.add_task(tm.get_execution_thread(generation_count), std::move(msg)); + auto task = std::make_shared( + shared_from_this(), + store->snapshot_unsafe_maps(idx), + generation_count++); + + ccf::tasks::add_task(task); } void commit(::consensus::Index idx, bool generate_snapshot) override diff --git a/src/node/test/historical_queries.cpp b/src/node/test/historical_queries.cpp index 43b073ea920..3b517939e98 100644 --- a/src/node/test/historical_queries.cpp +++ b/src/node/test/historical_queries.cpp @@ -1993,7 +1993,6 @@ TEST_CASE("Valid merkle proof from receipts") int main(int argc, char** argv) { - threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/node/test/history.cpp b/src/node/test/history.cpp index 774bc48ab18..7d91a754f8d 100644 --- a/src/node/test/history.cpp +++ b/src/node/test/history.cpp @@ -499,7 +499,6 @@ TEST_CASE( int main(int argc, char** argv) { - threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/node/test/history_bench.cpp b/src/node/test/history_bench.cpp index 12055778ed2..73898587368 100644 --- a/src/node/test/history_bench.cpp +++ b/src/node/test/history_bench.cpp @@ -158,7 +158,6 @@ PICOBENCH(append_compact<1000>).iterations(sizes); int main(int argc, char* argv[]) { ccf::logger::config::level() = ccf::LoggerLevel::FATAL; - ::threading::ThreadMessaging::init(1); picobench::runner runner; runner.parse_cmd_line(argc, argv); diff --git a/src/node/test/snapshot.cpp b/src/node/test/snapshot.cpp index 57c9b651d99..c0b85ad0a65 100644 --- a/src/node/test/snapshot.cpp +++ b/src/node/test/snapshot.cpp @@ -178,7 +178,6 @@ TEST_CASE("Snapshot with merkle tree" * doctest::test_suite("snapshot")) int main(int argc, char** argv) { - threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/node/test/snapshotter.cpp b/src/node/test/snapshotter.cpp index f3e9c236d26..2cfc94b82ad 100644 --- a/src/node/test/snapshotter.cpp +++ b/src/node/test/snapshotter.cpp @@ -21,6 +21,15 @@ auto node_kp = ccf::crypto::make_key_pair(); using StringString = ccf::kv::Map; using rb_msg = std::pair; +void run_one_task() +{ + auto task = ccf::tasks::get_main_job_board().get_task(); + if (task != nullptr) + { + task->do_task(); + } +} + auto read_ringbuffer_out(ringbuffer::Circuit& circuit) { std::optional idx = std::nullopt; @@ -167,7 +176,7 @@ TEST_CASE("Regular snapshotting") REQUIRE_FALSE(record_signature(history, snapshotter, snapshot_idx - 1)); commit_idx = snapshot_idx - 1; snapshotter->commit(commit_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE_THROWS_AS( read_latest_snapshot_evidence(network.tables), std::logic_error); @@ -183,7 +192,7 @@ TEST_CASE("Regular snapshotting") commit_idx = snapshot_idx + 1; snapshotter->commit(commit_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); REQUIRE(snapshot_allocate_msg.has_value()); @@ -221,7 +230,7 @@ TEST_CASE("Regular snapshotting") commit_idx = snapshot_idx + 1; snapshotter->commit(commit_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); REQUIRE(snapshot_allocate_msg.has_value()); @@ -255,7 +264,7 @@ TEST_CASE("Regular snapshotting") { commit_idx = snapshot_idx + 2; snapshotter->commit(commit_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_ringbuffer_out(eio) == std::nullopt); } @@ -270,7 +279,7 @@ TEST_CASE("Regular snapshotting") commit_idx = snapshot_idx; snapshotter->commit(commit_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); REQUIRE(snapshot_allocate_msg.has_value()); @@ -329,7 +338,7 @@ TEST_CASE("Rollback before snapshot is committed") REQUIRE(record_signature(history, snapshotter, snapshot_idx)); snapshotter->commit(snapshot_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); @@ -364,7 +373,7 @@ TEST_CASE("Rollback before snapshot is committed") REQUIRE(record_signature(history, snapshotter, snapshot_idx)); snapshotter->commit(snapshot_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); REQUIRE(snapshot_allocate_msg.has_value()); @@ -395,7 +404,7 @@ TEST_CASE("Rollback before snapshot is committed") REQUIRE_FALSE(record_signature(history, snapshotter, snapshot_idx)); snapshotter->commit(snapshot_idx, true); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); REQUIRE(snapshot_allocate_msg.has_value()); @@ -418,7 +427,7 @@ TEST_CASE("Rollback before snapshot is committed") read_ringbuffer_out(eio) == rb_msg({::consensus::snapshot_commit, snapshot_idx})); - threading::ThreadMessaging::instance().run_one(); + run_one_task(); } } @@ -484,7 +493,7 @@ TEST_CASE("Rekey ledger while snapshot is in progress") INFO("Finally, schedule snapshot creation"); { - threading::ThreadMessaging::instance().run_one(); + run_one_task(); REQUIRE(read_latest_snapshot_evidence(network.tables) == snapshot_idx); auto snapshot_allocate_msg = read_snapshot_allocate_out(eio); REQUIRE(snapshot_allocate_msg.has_value()); @@ -518,7 +527,6 @@ TEST_CASE("Rekey ledger while snapshot is in progress") int main(int argc, char** argv) { - threading::ThreadMessaging::init(1); doctest::Context context; context.applyCommandLine(argc, argv); int res = context.run(); diff --git a/src/quic/quic_session.h b/src/quic/quic_session.h index ca956764b55..688068dce80 100644 --- a/src/quic/quic_session.h +++ b/src/quic/quic_session.h @@ -6,7 +6,6 @@ #include "ds/messaging.h" #include "ds/pending_io.h" #include "ds/ring_buffer.h" -#include "ds/thread_messaging.h" #include "enclave/session.h" #include "udp/msg_types.h" @@ -20,7 +19,8 @@ namespace quic protected: ringbuffer::WriterPtr to_host; ccf::tls::ConnID session_id; - size_t execution_thread; + + std::shared_ptr task_scheduler; enum Status { @@ -55,12 +55,14 @@ namespace quic session_id(session_id_), status(handshake) { - execution_thread = - threading::ThreadMessaging::instance().get_execution_thread(session_id); + task_scheduler = ccf::tasks::OrderedTasks::create( + ccf::tasks::get_main_job_board(), + fmt::format("Session {}", session_id)); } ~QUICSession() { + task_scheduler->cancel_task(); // RINGBUFFER_WRITE_MESSAGE(quic::quic_closed, to_host, session_id); } @@ -144,45 +146,68 @@ namespace quic void recv_buffered(const uint8_t* data, size_t size, sockaddr addr) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called recv_buffered from incorrect thread"); - } LOG_TRACE_FMT("QUIC Session recv_buffered with {} bytes", size); pending_reads.emplace_back(const_cast(data), size, addr); do_handshake(); } - struct SendRecvMsg + struct SessionDataTask : public ccf::tasks::ITaskAction { - std::vector data; std::shared_ptr self; + std::vector data; sockaddr addr; + + SessionDataTask( + std::shared_ptr s, + std::span d, + sockaddr sa) : + self(s), + addr(sa) + { + data.assign(d.begin(), d.end()); + } }; - static void send_raw_cb(std::unique_ptr> msg) + struct SendDataTask : public SessionDataTask { - msg->data.self->send_raw_thread(msg->data.data, msg->data.addr); - } + using SessionDataTask::SessionDataTask; - void send_raw(const uint8_t* data, size_t size, sockaddr addr) + void do_action() override + { + self->send_raw_thread(data, addr); + } + + const std::string& get_name() const override + { + static const std::string name = "quic::SendDataTask"; + return name; + } + }; + + struct RecvDataTask : public SessionDataTask { - auto msg = std::make_unique>(&send_raw_cb); - msg->data.self = this->shared_from_this(); - msg->data.data = std::vector(data, data + size); - msg->data.addr = addr; + using SessionDataTask::SessionDataTask; - threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); + void do_action() override + { + self->recv(data.data(), data.size(), addr); + } + + const std::string& get_name() const override + { + static const std::string name = "quic::RecvDataTask"; + return name; + } + }; + + void send_raw(const uint8_t* data, size_t size, sockaddr addr) + { + task_scheduler->add_action(std::make_shared( + shared_from_this(), std::span{data, size}, addr)); } void send_raw_thread(const std::vector& data, sockaddr addr) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error( - "Called send_raw_thread from incorrect thread"); - } // Writes as much of the data as possible. If the data cannot all // be written now, we store the remainder. We // will try to send pending writes again whenever write() is called. @@ -206,22 +231,25 @@ namespace quic void send_buffered(const std::vector& data, sockaddr addr) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called send_buffered from incorrect thread"); - } - pending_writes.emplace_back( const_cast(data.data()), data.size(), addr); } - void flush() + void handle_incoming_data(std::span data) override { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called flush from incorrect thread"); - } + auto [_, addr_family, addr_data, body] = + ringbuffer::read_message(data); + + task_scheduler->add_action(std::make_shared( + shared_from_this(), + body, + udp::sockaddr_decode(addr_family, addr_data))); + } + virtual void recv(const uint8_t* data_, size_t size_, sockaddr addr_) = 0; + + void flush() + { do_handshake(); if (status != ready) @@ -250,32 +278,15 @@ namespace quic PendingBuffer::clear_empty(pending_writes); } - struct EmptyMsg - { - std::shared_ptr self; - }; - - static void close_cb(std::unique_ptr> msg) - { - msg->data.self->close_thread(); - } - void close_session() override { - auto msg = std::make_unique>(&close_cb); - msg->data.self = this->shared_from_this(); - - threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); + auto self = shared_from_this(); + task_scheduler->add_action( + ccf::tasks::make_basic_action([self]() { self->close_thread(); })); } void close_thread() { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called close_thread from incorrect thread"); - } - switch (status) { case handshake: @@ -349,11 +360,6 @@ namespace quic int handle_recv(uint8_t* buf, size_t len, sockaddr addr) { - if (ccf::threading::get_current_thread_id() != execution_thread) - { - throw std::runtime_error("Called handle_recv from incorrect thread"); - } - size_t len_read = 0; for (auto& read : pending_reads) { @@ -417,27 +423,7 @@ namespace quic send_raw(data.data(), data.size(), addr); } - static void recv_cb(std::unique_ptr> msg) - { - reinterpret_cast(msg->data.self.get()) - ->recv_(msg->data.data.data(), msg->data.data.size(), msg->data.addr); - } - - void handle_incoming_data(std::span data) override - { - auto [_, addr_family, addr_data, body] = - ringbuffer::read_message(data); - - auto msg = std::make_unique>(&recv_cb); - msg->data.self = this->shared_from_this(); - msg->data.data.assign(body.data, body.data + body.size); - msg->data.addr = udp::sockaddr_decode(addr_family, addr_data); - - threading::ThreadMessaging::instance().add_task( - execution_thread, std::move(msg)); - } - - void recv_(const uint8_t* data_, size_t size_, sockaddr addr_) + void recv(const uint8_t* data_, size_t size_, sockaddr addr_) override { recv_buffered(data_, size_, addr_); addr = addr_;