@@ -7,7 +7,7 @@ use std::future::Future;
77use std:: ops:: Deref ;
88use std:: pin:: Pin ;
99use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
10- use std:: sync:: { Arc , Mutex } ;
10+ use std:: sync:: { Arc , Mutex , RwLock } ;
1111use std:: task:: { Context , Poll , Waker } ;
1212use std:: vec:: Vec ;
1313
@@ -82,10 +82,13 @@ fn channel() -> (ChannelSend, ChannelRecv) {
8282/// A [`HrnResolver`] which uses lightning onion messages and DNSSEC proofs to request DNS
8383/// resolution directly from untrusted lightning nodes, providing privacy through onion routing.
8484///
85- /// This implements LDK's [`DNSResolverMessageHandler`], which it uses to send onion messages (you
86- /// should make sure to call LDK's [`PeerManager::process_events`] after a query begins) and
85+ /// This implements LDK's [`DNSResolverMessageHandler`], which it uses to send onion messages and
8786/// process response messages.
8887///
88+ /// Note that after a query begines, [`PeerManager::process_events`] should be called to ensure the
89+ /// query message goes out in a timely manner. You can call [`Self::register_post_queue_action`] to
90+ /// have this happen automatically.
91+ ///
8992/// [`PeerManager::process_events`]: lightning::ln::peer_handler::PeerManager::process_events
9093pub struct LDKOnionMessageDNSSECHrnResolver < N : Deref < Target = NetworkGraph < L > > , L : Deref >
9194where
9699 next_id : AtomicUsize ,
97100 pending_resolutions : Mutex < HashMap < HumanReadableName , Vec < ( PaymentId , ChannelSend ) > > > ,
98101 message_queue : Mutex < Vec < ( DNSResolverMessage , MessageSendInstructions ) > > ,
102+ pm_event_poker : RwLock < Option < Box < dyn Fn ( ) + Send + Sync > > > ,
99103}
100104
101105impl < N : Deref < Target = NetworkGraph < L > > , L : Deref > LDKOnionMessageDNSSECHrnResolver < N , L >
@@ -113,9 +117,18 @@ where
113117 resolver : OMNameResolver :: new ( 0 , 0 ) ,
114118 pending_resolutions : Mutex :: new ( HashMap :: new ( ) ) ,
115119 message_queue : Mutex :: new ( Vec :: new ( ) ) ,
120+ pm_event_poker : RwLock :: new ( None ) ,
116121 }
117122 }
118123
124+ /// Sets a callback which is called any time a new resolution begins and a message is available
125+ /// to be sent. This should generally call [`PeerManager::process_events`].
126+ ///
127+ /// [`PeerManager::process_events`]: lightning::ln::peer_handler::PeerManager::process_events
128+ pub fn register_post_queue_action ( & self , callback : Box < dyn Fn ( ) + Send + Sync > ) {
129+ * self . pm_event_poker . write ( ) . unwrap ( ) = Some ( callback) ;
130+ }
131+
119132 fn init_resolve_hrn < ' a > (
120133 & ' a self , hrn : & HumanReadableName ,
121134 ) -> Result < ChannelRecv , & ' static str > {
@@ -168,31 +181,40 @@ where
168181 self . resolver . resolve_name ( payment_id, hrn. clone ( ) , & OsRng ) . map_err ( |_| err) ?;
169182 let context = MessageContext :: DNSResolver ( dns_context) ;
170183
171- let mut queue = self . message_queue . lock ( ) . unwrap ( ) ;
172- for destination in dns_resolvers {
173- let instructions =
174- MessageSendInstructions :: WithReplyPath { destination, context : context. clone ( ) } ;
175- queue. push ( ( DNSResolverMessage :: DNSSECQuery ( query. clone ( ) ) , instructions) ) ;
176- }
177-
178184 let ( send, recv) = channel ( ) ;
179- let mut pending_resolutions = self . pending_resolutions . lock ( ) . unwrap ( ) ;
180- let senders = pending_resolutions. entry ( hrn. clone ( ) ) . or_insert_with ( Vec :: new) ;
181- senders. push ( ( payment_id, send) ) ;
182-
183- // If we're running in no-std, we won't expire lookups with the time updates above, so walk
184- // the pending resolution list and expire them here.
185- pending_resolutions. retain ( |_name, resolutions| {
186- resolutions. retain ( |( _payment_id, resolution) | {
187- let has_receiver = resolution. receiver_alive ( ) ;
188- if !has_receiver {
189- // TODO: Once LDK 0.2 ships, expire the pending resolution in the resolver:
190- // self.resolver.expire_pending_resolution(name, payment_id);
191- }
192- has_receiver
185+ {
186+ let mut pending_resolutions = self . pending_resolutions . lock ( ) . unwrap ( ) ;
187+ let senders = pending_resolutions. entry ( hrn. clone ( ) ) . or_insert_with ( Vec :: new) ;
188+ senders. push ( ( payment_id, send) ) ;
189+
190+ // If we're running in no-std, we won't expire lookups with the time updates above, so walk
191+ // the pending resolution list and expire them here.
192+ pending_resolutions. retain ( |_name, resolutions| {
193+ resolutions. retain ( |( _payment_id, resolution) | {
194+ let has_receiver = resolution. receiver_alive ( ) ;
195+ if !has_receiver {
196+ // TODO: Once LDK 0.2 ships, expire the pending resolution in the resolver:
197+ // self.resolver.expire_pending_resolution(name, payment_id);
198+ }
199+ has_receiver
200+ } ) ;
201+ !resolutions. is_empty ( )
193202 } ) ;
194- !resolutions. is_empty ( )
195- } ) ;
203+ }
204+
205+ {
206+ let mut queue = self . message_queue . lock ( ) . unwrap ( ) ;
207+ for destination in dns_resolvers {
208+ let instructions =
209+ MessageSendInstructions :: WithReplyPath { destination, context : context. clone ( ) } ;
210+ queue. push ( ( DNSResolverMessage :: DNSSECQuery ( query. clone ( ) ) , instructions) ) ;
211+ }
212+ }
213+
214+ let callback = self . pm_event_poker . read ( ) . unwrap ( ) ;
215+ if let Some ( callback) = & * callback {
216+ callback ( ) ;
217+ }
196218
197219 Ok ( recv)
198220 }
0 commit comments