@@ -36,6 +36,7 @@ use bitcoin::{BlockHash, ScriptBuf, Transaction, Txid};
36
36
37
37
use core:: future:: Future ;
38
38
use core:: ops:: Deref ;
39
+ use core:: pin:: Pin ;
39
40
use core:: sync:: atomic:: { AtomicBool , Ordering } ;
40
41
use core:: task;
41
42
@@ -414,7 +415,7 @@ where
414
415
/// Returns `Err` on persistence failure, in which case the call may be safely retried.
415
416
///
416
417
/// [`Event::SpendableOutputs`]: crate::events::Event::SpendableOutputs
417
- pub fn track_spendable_outputs (
418
+ pub async fn track_spendable_outputs (
418
419
& self , output_descriptors : Vec < SpendableOutputDescriptor > , channel_id : Option < ChannelId > ,
419
420
exclude_static_outputs : bool , delay_until_height : Option < u32 > ,
420
421
) -> Result < ( ) , ( ) > {
@@ -430,29 +431,34 @@ where
430
431
return Ok ( ( ) ) ;
431
432
}
432
433
433
- let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
434
- for descriptor in relevant_descriptors {
435
- let output_info = TrackedSpendableOutput {
436
- descriptor,
437
- channel_id,
438
- status : OutputSpendStatus :: PendingInitialBroadcast {
439
- delayed_until_height : delay_until_height,
440
- } ,
441
- } ;
442
-
443
- let mut outputs = state_lock. persistent . outputs . iter ( ) ;
444
- if outputs. find ( |o| o. descriptor == output_info. descriptor ) . is_some ( ) {
445
- continue ;
446
- }
434
+ let persist_fut;
435
+ {
436
+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
437
+ for descriptor in relevant_descriptors {
438
+ let output_info = TrackedSpendableOutput {
439
+ descriptor,
440
+ channel_id,
441
+ status : OutputSpendStatus :: PendingInitialBroadcast {
442
+ delayed_until_height : delay_until_height,
443
+ } ,
444
+ } ;
447
445
448
- state_lock. persistent . outputs . push ( output_info) ;
446
+ let mut outputs = state_lock. persistent . outputs . iter ( ) ;
447
+ if outputs. find ( |o| o. descriptor == output_info. descriptor ) . is_some ( ) {
448
+ continue ;
449
+ }
450
+
451
+ state_lock. persistent . outputs . push ( output_info) ;
452
+ }
453
+ persist_fut = self . persist_state ( & state_lock. persistent ) ;
454
+ state_lock. dirty = false ;
449
455
}
450
- self . persist_state ( & state_lock. persistent ) . map_err ( |e| {
451
- log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
452
- } ) ?;
453
- state_lock. dirty = false ;
454
456
455
- Ok ( ( ) )
457
+ persist_fut. await . map_err ( |e| {
458
+ self . sweeper_state . lock ( ) . unwrap ( ) . dirty = true ;
459
+
460
+ log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
461
+ } )
456
462
}
457
463
458
464
/// Returns a list of the currently tracked spendable outputs.
@@ -508,30 +514,42 @@ where
508
514
} ;
509
515
510
516
// See if there is anything to sweep before requesting a change address.
517
+ let persist_fut;
518
+ let has_respends;
511
519
{
512
520
let mut sweeper_state = self . sweeper_state . lock ( ) . unwrap ( ) ;
513
521
514
522
let cur_height = sweeper_state. persistent . best_block . height ;
515
- let has_respends =
523
+ has_respends =
516
524
sweeper_state. persistent . outputs . iter ( ) . any ( |o| filter_fn ( o, cur_height) ) ;
517
- if !has_respends {
525
+ if !has_respends && sweeper_state . dirty {
518
526
// If there is nothing to sweep, we still persist the state if it is dirty.
519
- if sweeper_state. dirty {
520
- self . persist_state ( & sweeper_state. persistent ) . map_err ( |e| {
521
- log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
522
- } ) ?;
523
- sweeper_state. dirty = false ;
524
- }
525
-
526
- return Ok ( ( ) ) ;
527
+ persist_fut = Some ( self . persist_state ( & sweeper_state. persistent ) ) ;
528
+ sweeper_state. dirty = false ;
529
+ } else {
530
+ persist_fut = None ;
527
531
}
528
532
}
529
533
534
+ if let Some ( persist_fut) = persist_fut {
535
+ persist_fut. await . map_err ( |e| {
536
+ self . sweeper_state . lock ( ) . unwrap ( ) . dirty = true ;
537
+
538
+ log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
539
+ } ) ?;
540
+ } ;
541
+
542
+ if !has_respends {
543
+ // If there is nothing to sweep, we return early.
544
+ return Ok ( ( ) ) ;
545
+ }
546
+
530
547
// Request a new change address outside of the mutex to avoid the mutex crossing await.
531
548
let change_destination_script =
532
549
self . change_destination_source . get_change_destination_script ( ) . await ?;
533
550
534
551
// Sweep the outputs.
552
+ let persist_fut;
535
553
{
536
554
let mut sweeper_state = self . sweeper_state . lock ( ) . unwrap ( ) ;
537
555
@@ -581,14 +599,17 @@ where
581
599
output_info. status . broadcast ( cur_hash, cur_height, spending_tx. clone ( ) ) ;
582
600
}
583
601
584
- self . persist_state ( & sweeper_state. persistent ) . map_err ( |e| {
585
- log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
586
- } ) ?;
602
+ persist_fut = self . persist_state ( & sweeper_state. persistent ) ;
587
603
sweeper_state. dirty = false ;
588
-
589
604
self . broadcaster . broadcast_transactions ( & [ & spending_tx] ) ;
590
605
}
591
606
607
+ persist_fut. await . map_err ( |e| {
608
+ self . sweeper_state . lock ( ) . unwrap ( ) . dirty = true ;
609
+
610
+ log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
611
+ } ) ?;
612
+
592
613
Ok ( ( ) )
593
614
}
594
615
@@ -614,25 +635,19 @@ where
614
635
sweeper_state. dirty = true ;
615
636
}
616
637
617
- fn persist_state ( & self , sweeper_state : & PersistentSweeperState ) -> Result < ( ) , io:: Error > {
618
- self . kv_store
619
- . write (
620
- OUTPUT_SWEEPER_PERSISTENCE_PRIMARY_NAMESPACE ,
621
- OUTPUT_SWEEPER_PERSISTENCE_SECONDARY_NAMESPACE ,
622
- OUTPUT_SWEEPER_PERSISTENCE_KEY ,
623
- & sweeper_state. encode ( ) ,
624
- )
625
- . map_err ( |e| {
626
- log_error ! (
627
- self . logger,
628
- "Write for key {}/{}/{} failed due to: {}" ,
629
- OUTPUT_SWEEPER_PERSISTENCE_PRIMARY_NAMESPACE ,
630
- OUTPUT_SWEEPER_PERSISTENCE_SECONDARY_NAMESPACE ,
631
- OUTPUT_SWEEPER_PERSISTENCE_KEY ,
632
- e
633
- ) ;
634
- e
635
- } )
638
+ fn persist_state < ' a > (
639
+ & self , sweeper_state : & PersistentSweeperState ,
640
+ ) -> Pin < Box < dyn Future < Output = Result < ( ) , io:: Error > > + ' a + Send > > {
641
+ let encoded = & sweeper_state. encode ( ) ;
642
+
643
+ let result = self . kv_store . write (
644
+ OUTPUT_SWEEPER_PERSISTENCE_PRIMARY_NAMESPACE ,
645
+ OUTPUT_SWEEPER_PERSISTENCE_SECONDARY_NAMESPACE ,
646
+ OUTPUT_SWEEPER_PERSISTENCE_KEY ,
647
+ encoded,
648
+ ) ;
649
+
650
+ Box :: pin ( async move { result } )
636
651
}
637
652
638
653
fn spend_outputs (
@@ -1005,16 +1020,18 @@ where
1005
1020
}
1006
1021
1007
1022
/// Tells the sweeper to track the given outputs descriptors. Wraps [`OutputSweeper::track_spendable_outputs`].
1008
- pub fn track_spendable_outputs (
1023
+ pub async fn track_spendable_outputs (
1009
1024
& self , output_descriptors : Vec < SpendableOutputDescriptor > , channel_id : Option < ChannelId > ,
1010
1025
exclude_static_outputs : bool , delay_until_height : Option < u32 > ,
1011
1026
) -> Result < ( ) , ( ) > {
1012
- self . sweeper . track_spendable_outputs (
1013
- output_descriptors,
1014
- channel_id,
1015
- exclude_static_outputs,
1016
- delay_until_height,
1017
- )
1027
+ self . sweeper
1028
+ . track_spendable_outputs (
1029
+ output_descriptors,
1030
+ channel_id,
1031
+ exclude_static_outputs,
1032
+ delay_until_height,
1033
+ )
1034
+ . await
1018
1035
}
1019
1036
1020
1037
/// Returns a list of the currently tracked spendable outputs. Wraps [`OutputSweeper::tracked_spendable_outputs`].
0 commit comments