Skip to content

Commit 3a3cbef

Browse files
committed
f Move enqueue to EventQueueNotifierGuard to enforce it's held
1 parent 8cfa0ae commit 3a3cbef

File tree

6 files changed

+56
-117
lines changed

6 files changed

+56
-117
lines changed

lightning-liquidity/src/events/event_queue.rs

+27-90
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ use alloc::collections::VecDeque;
55
use alloc::vec::Vec;
66

77
use core::future::Future;
8-
#[cfg(debug_assertions)]
9-
use core::sync::atomic::{AtomicU8, Ordering};
108
use core::task::{Poll, Waker};
119

1210
/// The maximum queue size we allow before starting to drop events.
@@ -17,8 +15,6 @@ pub(crate) struct EventQueue {
1715
waker: Arc<Mutex<Option<Waker>>>,
1816
#[cfg(feature = "std")]
1917
condvar: Arc<crate::sync::Condvar>,
20-
#[cfg(debug_assertions)]
21-
num_held_notifier_guards: Arc<AtomicU8>,
2218
}
2319

2420
impl EventQueue {
@@ -30,25 +26,6 @@ impl EventQueue {
3026
waker,
3127
#[cfg(feature = "std")]
3228
condvar: Arc::new(crate::sync::Condvar::new()),
33-
#[cfg(debug_assertions)]
34-
num_held_notifier_guards: Arc::new(AtomicU8::new(0)),
35-
}
36-
}
37-
38-
pub fn enqueue<E: Into<LiquidityEvent>>(&self, event: E) {
39-
#[cfg(debug_assertions)]
40-
{
41-
let num_held_notifier_guards = self.num_held_notifier_guards.load(Ordering::Relaxed);
42-
debug_assert!(
43-
num_held_notifier_guards > 0,
44-
"We should be holding at least one notifier guard whenever enqueuing new events"
45-
);
46-
}
47-
let mut queue = self.queue.lock().unwrap();
48-
if queue.len() < MAX_EVENT_QUEUE_SIZE {
49-
queue.push_back(event.into());
50-
} else {
51-
return;
5229
}
5330
}
5431

@@ -91,76 +68,36 @@ impl EventQueue {
9168

9269
// Returns an [`EventQueueNotifierGuard`] that will notify about new event when dropped.
9370
pub fn notifier(&self) -> EventQueueNotifierGuard {
94-
#[cfg(debug_assertions)]
95-
{
96-
self.num_held_notifier_guards.fetch_add(1, Ordering::Relaxed);
97-
}
98-
EventQueueNotifierGuard {
99-
queue: Arc::clone(&self.queue),
100-
waker: Arc::clone(&self.waker),
101-
#[cfg(feature = "std")]
102-
condvar: Arc::clone(&self.condvar),
103-
#[cfg(debug_assertions)]
104-
num_held_notifier_guards: Arc::clone(&self.num_held_notifier_guards),
105-
}
106-
}
107-
}
108-
109-
impl Drop for EventQueue {
110-
fn drop(&mut self) {
111-
#[cfg(debug_assertions)]
112-
{
113-
let num_held_notifier_guards = self.num_held_notifier_guards.load(Ordering::Relaxed);
114-
debug_assert!(
115-
num_held_notifier_guards == 0,
116-
"We should not be holding any notifier guards when the event queue is dropped"
117-
);
118-
}
71+
EventQueueNotifierGuard(self)
11972
}
12073
}
12174

12275
// A guard type that will notify about new events when dropped.
12376
#[must_use]
124-
pub(crate) struct EventQueueNotifierGuard {
125-
queue: Arc<Mutex<VecDeque<LiquidityEvent>>>,
126-
waker: Arc<Mutex<Option<Waker>>>,
127-
#[cfg(feature = "std")]
128-
condvar: Arc<crate::sync::Condvar>,
129-
#[cfg(debug_assertions)]
130-
num_held_notifier_guards: Arc<AtomicU8>,
77+
pub(crate) struct EventQueueNotifierGuard<'a>(&'a EventQueue);
78+
79+
impl<'a> EventQueueNotifierGuard<'a> {
80+
pub fn enqueue<E: Into<LiquidityEvent>>(&self, event: E) {
81+
let mut queue = self.0.queue.lock().unwrap();
82+
if queue.len() < MAX_EVENT_QUEUE_SIZE {
83+
queue.push_back(event.into());
84+
} else {
85+
return;
86+
}
87+
}
13188
}
13289

133-
impl Drop for EventQueueNotifierGuard {
90+
impl<'a> Drop for EventQueueNotifierGuard<'a> {
13491
fn drop(&mut self) {
135-
let should_notify = !self.queue.lock().unwrap().is_empty();
92+
let should_notify = !self.0.queue.lock().unwrap().is_empty();
13693

13794
if should_notify {
138-
if let Some(waker) = self.waker.lock().unwrap().take() {
95+
if let Some(waker) = self.0.waker.lock().unwrap().take() {
13996
waker.wake();
14097
}
14198

14299
#[cfg(feature = "std")]
143-
self.condvar.notify_one();
144-
}
145-
146-
#[cfg(debug_assertions)]
147-
{
148-
let res = self.num_held_notifier_guards.fetch_update(
149-
Ordering::Relaxed,
150-
Ordering::Relaxed,
151-
|x| Some(x.saturating_sub(1)),
152-
);
153-
match res {
154-
Ok(previous_value) if previous_value == 0 => debug_assert!(
155-
false,
156-
"num_held_notifier_guards counter out-of-sync! This should never happen!"
157-
),
158-
Err(_) => debug_assert!(
159-
false,
160-
"num_held_notifier_guards counter out-of-sync! This should never happen!"
161-
),
162-
_ => {},
163-
}
100+
self.0.condvar.notify_one();
164101
}
165102
}
166103
}
@@ -209,8 +146,8 @@ mod tests {
209146
});
210147

211148
for _ in 0..3 {
212-
let _guard = event_queue.notifier();
213-
event_queue.enqueue(expected_event.clone());
149+
let guard = event_queue.notifier();
150+
guard.enqueue(expected_event.clone());
214151
}
215152

216153
assert_eq!(event_queue.wait_next_event(), expected_event);
@@ -235,25 +172,25 @@ mod tests {
235172
let mut delayed_enqueue = false;
236173

237174
for _ in 0..25 {
238-
let _guard = event_queue.notifier();
239-
event_queue.enqueue(expected_event.clone());
175+
let guard = event_queue.notifier();
176+
guard.enqueue(expected_event.clone());
240177
enqueued_events.fetch_add(1, Ordering::SeqCst);
241178
}
242179

243180
loop {
244181
tokio::select! {
245182
_ = tokio::time::sleep(Duration::from_millis(10)), if !delayed_enqueue => {
246-
let _guard = event_queue.notifier();
247-
event_queue.enqueue(expected_event.clone());
183+
let guard = event_queue.notifier();
184+
guard.enqueue(expected_event.clone());
248185
enqueued_events.fetch_add(1, Ordering::SeqCst);
249186
delayed_enqueue = true;
250187
}
251188
e = event_queue.next_event_async() => {
252189
assert_eq!(e, expected_event);
253190
received_events.fetch_add(1, Ordering::SeqCst);
254191

255-
let _guard = event_queue.notifier();
256-
event_queue.enqueue(expected_event.clone());
192+
let guard = event_queue.notifier();
193+
guard.enqueue(expected_event.clone());
257194
enqueued_events.fetch_add(1, Ordering::SeqCst);
258195
}
259196
e = event_queue.next_event_async() => {
@@ -285,9 +222,9 @@ mod tests {
285222
std::thread::spawn(move || {
286223
// Sleep a bit before we enqueue the events everybody is waiting for.
287224
std::thread::sleep(Duration::from_millis(20));
288-
let _guard = thread_queue.notifier();
289-
thread_queue.enqueue(thread_event.clone());
290-
thread_queue.enqueue(thread_event.clone());
225+
let guard = thread_queue.notifier();
226+
guard.enqueue(thread_event.clone());
227+
guard.enqueue(thread_event.clone());
291228
});
292229

293230
let e = event_queue.next_event_async().await;

lightning-liquidity/src/lsps0/client.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@ where
6161
fn handle_response(
6262
&self, response: LSPS0Response, counterparty_node_id: &PublicKey,
6363
) -> Result<(), LightningError> {
64-
let _event_queue_notifier = self.pending_events.notifier();
64+
let event_queue_notifier = self.pending_events.notifier();
6565

6666
match response {
6767
LSPS0Response::ListProtocols(LSPS0ListProtocolsResponse { protocols }) => {
68-
self.pending_events.enqueue(LSPS0ClientEvent::ListProtocolsResponse {
68+
event_queue_notifier.enqueue(LSPS0ClientEvent::ListProtocolsResponse {
6969
counterparty_node_id: *counterparty_node_id,
7070
protocols,
7171
});

lightning-liquidity/src/lsps1/client.rs

+12-12
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ where
110110
&self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey,
111111
result: LSPS1GetInfoResponse,
112112
) -> Result<(), LightningError> {
113-
let _event_queue_notifier = self.pending_events.notifier();
113+
let event_queue_notifier = self.pending_events.notifier();
114114

115115
let outer_state_lock = self.per_peer_state.write().unwrap();
116116
match outer_state_lock.get(counterparty_node_id) {
@@ -127,7 +127,7 @@ where
127127
});
128128
}
129129

130-
self.pending_events.enqueue(LSPS1ClientEvent::SupportedOptionsReady {
130+
event_queue_notifier.enqueue(LSPS1ClientEvent::SupportedOptionsReady {
131131
counterparty_node_id: *counterparty_node_id,
132132
supported_options: result.options,
133133
request_id,
@@ -148,7 +148,7 @@ where
148148
&self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey,
149149
error: LSPSResponseError,
150150
) -> Result<(), LightningError> {
151-
let _event_queue_notifier = self.pending_events.notifier();
151+
let event_queue_notifier = self.pending_events.notifier();
152152

153153
let outer_state_lock = self.per_peer_state.read().unwrap();
154154
match outer_state_lock.get(counterparty_node_id) {
@@ -165,7 +165,7 @@ where
165165
});
166166
}
167167

168-
self.pending_events.enqueue(LSPS1ClientEvent::SupportedOptionsRequestFailed {
168+
event_queue_notifier.enqueue(LSPS1ClientEvent::SupportedOptionsRequestFailed {
169169
request_id: request_id.clone(),
170170
counterparty_node_id: *counterparty_node_id,
171171
error: error.clone(),
@@ -227,7 +227,7 @@ where
227227
&self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey,
228228
response: LSPS1CreateOrderResponse,
229229
) -> Result<(), LightningError> {
230-
let _event_queue_notifier = self.pending_events.notifier();
230+
let event_queue_notifier = self.pending_events.notifier();
231231

232232
let outer_state_lock = self.per_peer_state.read().unwrap();
233233
match outer_state_lock.get(counterparty_node_id) {
@@ -244,7 +244,7 @@ where
244244
});
245245
}
246246

247-
self.pending_events.enqueue(LSPS1ClientEvent::OrderCreated {
247+
event_queue_notifier.enqueue(LSPS1ClientEvent::OrderCreated {
248248
request_id,
249249
counterparty_node_id: *counterparty_node_id,
250250
order_id: response.order_id,
@@ -271,7 +271,7 @@ where
271271
&self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey,
272272
error: LSPSResponseError,
273273
) -> Result<(), LightningError> {
274-
let _event_queue_notifier = self.pending_events.notifier();
274+
let event_queue_notifier = self.pending_events.notifier();
275275

276276
let outer_state_lock = self.per_peer_state.read().unwrap();
277277
match outer_state_lock.get(counterparty_node_id) {
@@ -288,7 +288,7 @@ where
288288
});
289289
}
290290

291-
self.pending_events.enqueue(LSPS1ClientEvent::OrderRequestFailed {
291+
event_queue_notifier.enqueue(LSPS1ClientEvent::OrderRequestFailed {
292292
request_id: request_id.clone(),
293293
counterparty_node_id: *counterparty_node_id,
294294
error: error.clone(),
@@ -350,7 +350,7 @@ where
350350
&self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey,
351351
response: LSPS1CreateOrderResponse,
352352
) -> Result<(), LightningError> {
353-
let _event_queue_notifier = self.pending_events.notifier();
353+
let event_queue_notifier = self.pending_events.notifier();
354354

355355
let outer_state_lock = self.per_peer_state.read().unwrap();
356356
match outer_state_lock.get(counterparty_node_id) {
@@ -367,7 +367,7 @@ where
367367
});
368368
}
369369

370-
self.pending_events.enqueue(LSPS1ClientEvent::OrderStatus {
370+
event_queue_notifier.enqueue(LSPS1ClientEvent::OrderStatus {
371371
request_id,
372372
counterparty_node_id: *counterparty_node_id,
373373
order_id: response.order_id,
@@ -394,7 +394,7 @@ where
394394
&self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey,
395395
error: LSPSResponseError,
396396
) -> Result<(), LightningError> {
397-
let _event_queue_notifier = self.pending_events.notifier();
397+
let event_queue_notifier = self.pending_events.notifier();
398398

399399
let outer_state_lock = self.per_peer_state.read().unwrap();
400400
match outer_state_lock.get(counterparty_node_id) {
@@ -411,7 +411,7 @@ where
411411
});
412412
}
413413

414-
self.pending_events.enqueue(LSPS1ClientEvent::OrderRequestFailed {
414+
event_queue_notifier.enqueue(LSPS1ClientEvent::OrderRequestFailed {
415415
request_id: request_id.clone(),
416416
counterparty_node_id: *counterparty_node_id,
417417
error: error.clone(),

lightning-liquidity/src/lsps1/service.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ where
198198
&self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey,
199199
params: LSPS1CreateOrderRequest,
200200
) -> Result<(), LightningError> {
201+
let event_queue_notifier = self.pending_events.notifier();
201202
if !is_valid(&params.order, &self.config.supported_options.as_ref().unwrap()) {
202203
let response = LSPS1Response::CreateOrderError(LSPSResponseError {
203204
code: LSPS1_CREATE_ORDER_REQUEST_ORDER_MISMATCH_ERROR_CODE,
@@ -231,7 +232,7 @@ where
231232
.insert(request_id.clone(), LSPS1Request::CreateOrder(params.clone()));
232233
}
233234

234-
self.pending_events.enqueue(LSPS1ServiceEvent::RequestForPaymentDetails {
235+
event_queue_notifier.enqueue(LSPS1ServiceEvent::RequestForPaymentDetails {
235236
request_id,
236237
counterparty_node_id: *counterparty_node_id,
237238
order: params.order,
@@ -315,6 +316,7 @@ where
315316
&self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey,
316317
params: LSPS1GetOrderRequest,
317318
) -> Result<(), LightningError> {
319+
let event_queue_notifier = self.pending_events.notifier();
318320
let outer_state_lock = self.per_peer_state.read().unwrap();
319321
match outer_state_lock.get(counterparty_node_id) {
320322
Some(inner_state_lock) => {
@@ -333,7 +335,7 @@ where
333335

334336
if let Err(e) = outbound_channel.awaiting_payment() {
335337
peer_state_lock.outbound_channels_by_order_id.remove(&params.order_id);
336-
self.pending_events.enqueue(LSPS1ServiceEvent::Refund {
338+
event_queue_notifier.enqueue(LSPS1ServiceEvent::Refund {
337339
request_id,
338340
counterparty_node_id: *counterparty_node_id,
339341
order_id: params.order_id,
@@ -345,7 +347,7 @@ where
345347
.pending_requests
346348
.insert(request_id.clone(), LSPS1Request::GetOrder(params.clone()));
347349

348-
self.pending_events.enqueue(LSPS1ServiceEvent::CheckPaymentConfirmation {
350+
event_queue_notifier.enqueue(LSPS1ServiceEvent::CheckPaymentConfirmation {
349351
request_id,
350352
counterparty_node_id: *counterparty_node_id,
351353
order_id: params.order_id,

lightning-liquidity/src/lsps2/client.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ where
191191
&self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey,
192192
result: LSPS2GetInfoResponse,
193193
) -> Result<(), LightningError> {
194-
let _event_queue_notifier = self.pending_events.notifier();
194+
let event_queue_notifier = self.pending_events.notifier();
195195

196196
let outer_state_lock = self.per_peer_state.read().unwrap();
197197
match outer_state_lock.get(counterparty_node_id) {
@@ -208,7 +208,7 @@ where
208208
});
209209
}
210210

211-
self.pending_events.enqueue(LSPS2ClientEvent::OpeningParametersReady {
211+
event_queue_notifier.enqueue(LSPS2ClientEvent::OpeningParametersReady {
212212
request_id,
213213
counterparty_node_id: *counterparty_node_id,
214214
opening_fee_params_menu: result.opening_fee_params_menu,
@@ -259,7 +259,7 @@ where
259259
&self, request_id: LSPSRequestId, counterparty_node_id: &PublicKey,
260260
result: LSPS2BuyResponse,
261261
) -> Result<(), LightningError> {
262-
let _event_queue_notifier = self.pending_events.notifier();
262+
let event_queue_notifier = self.pending_events.notifier();
263263

264264
let outer_state_lock = self.per_peer_state.read().unwrap();
265265
match outer_state_lock.get(counterparty_node_id) {
@@ -276,7 +276,7 @@ where
276276
})?;
277277

278278
if let Ok(intercept_scid) = result.jit_channel_scid.to_scid() {
279-
self.pending_events.enqueue(LSPS2ClientEvent::InvoiceParametersReady {
279+
event_queue_notifier.enqueue(LSPS2ClientEvent::InvoiceParametersReady {
280280
request_id,
281281
counterparty_node_id: *counterparty_node_id,
282282
intercept_scid,

0 commit comments

Comments
 (0)