diff --git a/include/mp/proxy-io.h b/include/mp/proxy-io.h
index 29cd3a4..83b725d 100644
--- a/include/mp/proxy-io.h
+++ b/include/mp/proxy-io.h
@@ -165,6 +165,8 @@ class EventLoop
     //! Add/remove remote client reference counts.
     void addClient(std::unique_lock<std::mutex>& lock);
    void removeClient(std::unique_lock<std::mutex>& lock);
+    //! Check if loop should exit.
+    bool done(std::unique_lock<std::mutex>& lock);
 
     Logger log()
     {
diff --git a/src/mp/proxy.cpp b/src/mp/proxy.cpp
index bcfbde8..c7516b5 100644
--- a/src/mp/proxy.cpp
+++ b/src/mp/proxy.cpp
@@ -188,29 +188,31 @@ void EventLoop::loop()
     kj::Own<kj::AsyncIoStream> wait_stream{
         m_io_context.lowLevelProvider->wrapSocketFd(m_wait_fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)};
+    int post_fd{m_post_fd};
     char buffer = 0;
     for (;;) {
         size_t read_bytes = wait_stream->read(&buffer, 0, 1).wait(m_io_context.waitScope);
-        if (read_bytes == 1) {
-            std::unique_lock<std::mutex> lock(m_mutex);
-            if (m_post_fn) {
-                Unlock(lock, *m_post_fn);
-                m_post_fn = nullptr;
-            }
-        } else {
-            throw std::logic_error("EventLoop wait_stream closed unexpectedly");
-        }
-        m_cv.notify_all();
-        if (m_num_clients == 0 && m_async_fns.empty()) {
-            log() << "EventLoop::loop done, cancelling event listeners.";
-            m_task_set.reset();
-            log() << "EventLoop::loop bye.";
+        if (read_bytes != 1) throw std::logic_error("EventLoop wait_stream closed unexpectedly");
+        std::unique_lock<std::mutex> lock(m_mutex);
+        if (m_post_fn) {
+            Unlock(lock, *m_post_fn);
+            m_post_fn = nullptr;
+            m_cv.notify_all();
+        } else if (done(lock)) {
+            // Intentionally do not break if m_post_fn was set, even if done()
+            // would return true, to ensure that the removeClient write(post_fd)
+            // call always succeeds and the loop does not exit between the time
+            // that the done condition is set and the write call is made.
             break;
         }
     }
+    log() << "EventLoop::loop done, cancelling event listeners.";
+    m_task_set.reset();
+    log() << "EventLoop::loop bye.";
     wait_stream = nullptr;
+    KJ_SYSCALL(::close(post_fd));
+    std::unique_lock<std::mutex> lock(m_mutex);
     m_wait_fd = -1;
-    KJ_SYSCALL(::close(m_post_fd));
     m_post_fd = -1;
 }
@@ -222,9 +224,10 @@ void EventLoop::post(const std::function<void()>& fn)
     std::unique_lock<std::mutex> lock(m_mutex);
     m_cv.wait(lock, [this] { return m_post_fn == nullptr; });
     m_post_fn = &fn;
+    int post_fd{m_post_fd};
     Unlock(lock, [&] {
         char buffer = 0;
-        KJ_SYSCALL(write(m_post_fd, &buffer, 1));
+        KJ_SYSCALL(write(post_fd, &buffer, 1));
     });
     m_cv.wait(lock, [this, &fn] { return m_post_fn != &fn; });
 }
@@ -233,13 +236,13 @@ void EventLoop::addClient(std::unique_lock<std::mutex>& lock) { m_num_clients +=
 
 void EventLoop::removeClient(std::unique_lock<std::mutex>& lock)
 {
-    assert(m_num_clients > 0);
     m_num_clients -= 1;
-    if (m_num_clients == 0) {
+    if (done(lock)) {
         m_cv.notify_all();
+        int post_fd{m_post_fd};
         Unlock(lock, [&] {
             char buffer = 0;
-            KJ_SYSCALL(write(m_post_fd, &buffer, 1)); // NOLINT(bugprone-suspicious-semicolon)
+            KJ_SYSCALL(write(post_fd, &buffer, 1)); // NOLINT(bugprone-suspicious-semicolon)
         });
     }
 }
@@ -268,6 +271,14 @@ void EventLoop::startAsyncThread(std::unique_lock<std::mutex>& lock)
     }
 }
 
+bool EventLoop::done(std::unique_lock<std::mutex>& lock)
+{
+    assert(m_num_clients >= 0);
+    assert(lock.owns_lock());
+    assert(lock.mutex() == &m_mutex);
+    return m_num_clients == 0 && m_async_fns.empty();
+}
+
 std::tuple<ConnThread, bool> SetThread(ConnThreads& threads, std::mutex& mutex, Connection* connection,
     std::function<Thread::Client()> make_thread)
 {
     std::unique_lock<std::mutex> lock(mutex);