@@ -88,6 +88,8 @@ pub trait CustomMessageHandler: wire::CustomMessageReader {
88
88
/// May return an `Err(())` if the features the peer supports are not sufficient to communicate
89
89
/// with us. Implementors should be somewhat conservative about doing so, however, as other
90
90
/// message handlers may still wish to communicate with this peer.
91
+ ///
92
+ /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned.
91
93
fn peer_connected ( & self , their_node_id : PublicKey , msg : & Init , inbound : bool ) -> Result < ( ) , ( ) > ;
92
94
93
95
/// Gets the node feature flags which this handler itself supports. All available handlers are
@@ -119,6 +121,7 @@ impl RoutingMessageHandler for IgnoringMessageHandler {
119
121
Option < ( msgs:: ChannelAnnouncement , Option < msgs:: ChannelUpdate > , Option < msgs:: ChannelUpdate > ) > { None }
120
122
fn get_next_node_announcement ( & self , _starting_point : Option < & NodeId > ) -> Option < msgs:: NodeAnnouncement > { None }
121
123
fn peer_connected ( & self , _their_node_id : PublicKey , _init : & msgs:: Init , _inbound : bool ) -> Result < ( ) , ( ) > { Ok ( ( ) ) }
124
+ fn peer_disconnected ( & self , _their_node_id : PublicKey ) { }
122
125
fn handle_reply_channel_range ( & self , _their_node_id : PublicKey , _msg : msgs:: ReplyChannelRange ) -> Result < ( ) , LightningError > { Ok ( ( ) ) }
123
126
fn handle_reply_short_channel_ids_end ( & self , _their_node_id : PublicKey , _msg : msgs:: ReplyShortChannelIdsEnd ) -> Result < ( ) , LightningError > { Ok ( ( ) ) }
124
127
fn handle_query_channel_range ( & self , _their_node_id : PublicKey , _msg : msgs:: QueryChannelRange ) -> Result < ( ) , LightningError > { Ok ( ( ) ) }
@@ -1714,14 +1717,20 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
1714
1717
}
1715
1718
if let Err ( ( ) ) = self . message_handler . chan_handler . peer_connected ( their_node_id, & msg, peer_lock. inbound_connection ) {
1716
1719
log_debug ! ( logger, "Channel Handler decided we couldn't communicate with peer {}" , log_pubkey!( their_node_id) ) ;
1720
+ self . message_handler . route_handler . peer_disconnected ( their_node_id) ;
1717
1721
return Err ( PeerHandleError { } . into ( ) ) ;
1718
1722
}
1719
1723
if let Err ( ( ) ) = self . message_handler . onion_message_handler . peer_connected ( their_node_id, & msg, peer_lock. inbound_connection ) {
1720
1724
log_debug ! ( logger, "Onion Message Handler decided we couldn't communicate with peer {}" , log_pubkey!( their_node_id) ) ;
1725
+ self . message_handler . route_handler . peer_disconnected ( their_node_id) ;
1726
+ self . message_handler . chan_handler . peer_disconnected ( their_node_id) ;
1721
1727
return Err ( PeerHandleError { } . into ( ) ) ;
1722
1728
}
1723
1729
if let Err ( ( ) ) = self . message_handler . custom_message_handler . peer_connected ( their_node_id, & msg, peer_lock. inbound_connection ) {
1724
1730
log_debug ! ( logger, "Custom Message Handler decided we couldn't communicate with peer {}" , log_pubkey!( their_node_id) ) ;
1731
+ self . message_handler . route_handler . peer_disconnected ( their_node_id) ;
1732
+ self . message_handler . chan_handler . peer_disconnected ( their_node_id) ;
1733
+ self . message_handler . onion_message_handler . peer_disconnected ( their_node_id) ;
1725
1734
return Err ( PeerHandleError { } . into ( ) ) ;
1726
1735
}
1727
1736
@@ -2533,6 +2542,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
2533
2542
debug_assert ! ( peer. their_node_id. is_some( ) ) ;
2534
2543
if let Some ( ( node_id, _) ) = peer. their_node_id {
2535
2544
log_trace ! ( WithContext :: from( & self . logger, Some ( node_id) , None , None ) , "Disconnecting peer with id {} due to {}" , node_id, reason) ;
2545
+ self . message_handler . route_handler . peer_disconnected ( node_id) ;
2536
2546
self . message_handler . chan_handler . peer_disconnected ( node_id) ;
2537
2547
self . message_handler . onion_message_handler . peer_disconnected ( node_id) ;
2538
2548
self . message_handler . custom_message_handler . peer_disconnected ( node_id) ;
@@ -2557,6 +2567,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
2557
2567
let removed = self . node_id_to_descriptor . lock ( ) . unwrap ( ) . remove ( & node_id) ;
2558
2568
debug_assert ! ( removed. is_some( ) , "descriptor maps should be consistent" ) ;
2559
2569
if !peer. handshake_complete ( ) { return ; }
2570
+ self . message_handler . route_handler . peer_disconnected ( node_id) ;
2560
2571
self . message_handler . chan_handler . peer_disconnected ( node_id) ;
2561
2572
self . message_handler . onion_message_handler . peer_disconnected ( node_id) ;
2562
2573
self . message_handler . custom_message_handler . peer_disconnected ( node_id) ;
@@ -2856,6 +2867,16 @@ mod tests {
2856
2867
2857
2868
struct TestCustomMessageHandler {
2858
2869
features : InitFeatures ,
2870
+ conn_tracker : test_utils:: ConnectionTracker ,
2871
+ }
2872
+
2873
+ impl TestCustomMessageHandler {
2874
+ fn new ( features : InitFeatures ) -> Self {
2875
+ Self {
2876
+ features,
2877
+ conn_tracker : test_utils:: ConnectionTracker :: new ( ) ,
2878
+ }
2879
+ }
2859
2880
}
2860
2881
2861
2882
impl wire:: CustomMessageReader for TestCustomMessageHandler {
@@ -2872,10 +2893,13 @@ mod tests {
2872
2893
2873
2894
fn get_and_clear_pending_msg ( & self ) -> Vec < ( PublicKey , Self :: CustomMessage ) > { Vec :: new ( ) }
2874
2895
2896
+ fn peer_disconnected ( & self , their_node_id : PublicKey ) {
2897
+ self . conn_tracker . peer_disconnected ( their_node_id) ;
2898
+ }
2875
2899
2876
- fn peer_disconnected ( & self , _their_node_id : PublicKey ) { }
2877
-
2878
- fn peer_connected ( & self , _their_node_id : PublicKey , _msg : & Init , _inbound : bool ) -> Result < ( ) , ( ) > { Ok ( ( ) ) }
2900
+ fn peer_connected ( & self , their_node_id : PublicKey , _msg : & Init , _inbound : bool ) -> Result < ( ) , ( ) > {
2901
+ self . conn_tracker . peer_connected ( their_node_id )
2902
+ }
2879
2903
2880
2904
fn provided_node_features ( & self ) -> NodeFeatures { NodeFeatures :: empty ( ) }
2881
2905
@@ -2898,7 +2922,7 @@ mod tests {
2898
2922
chan_handler : test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) ,
2899
2923
logger : test_utils:: TestLogger :: with_id ( i. to_string ( ) ) ,
2900
2924
routing_handler : test_utils:: TestRoutingMessageHandler :: new ( ) ,
2901
- custom_handler : TestCustomMessageHandler { features } ,
2925
+ custom_handler : TestCustomMessageHandler :: new ( features) ,
2902
2926
node_signer : test_utils:: TestNodeSigner :: new ( node_secret) ,
2903
2927
}
2904
2928
) ;
@@ -2921,7 +2945,7 @@ mod tests {
2921
2945
chan_handler : test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) ,
2922
2946
logger : test_utils:: TestLogger :: new ( ) ,
2923
2947
routing_handler : test_utils:: TestRoutingMessageHandler :: new ( ) ,
2924
- custom_handler : TestCustomMessageHandler { features } ,
2948
+ custom_handler : TestCustomMessageHandler :: new ( features) ,
2925
2949
node_signer : test_utils:: TestNodeSigner :: new ( node_secret) ,
2926
2950
}
2927
2951
) ;
@@ -2941,7 +2965,7 @@ mod tests {
2941
2965
chan_handler : test_utils:: TestChannelMessageHandler :: new ( network) ,
2942
2966
logger : test_utils:: TestLogger :: new ( ) ,
2943
2967
routing_handler : test_utils:: TestRoutingMessageHandler :: new ( ) ,
2944
- custom_handler : TestCustomMessageHandler { features } ,
2968
+ custom_handler : TestCustomMessageHandler :: new ( features) ,
2945
2969
node_signer : test_utils:: TestNodeSigner :: new ( node_secret) ,
2946
2970
}
2947
2971
) ;
@@ -2965,19 +2989,16 @@ mod tests {
2965
2989
peers
2966
2990
}
2967
2991
2968
- fn establish_connection < ' a > ( peer_a : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > , peer_b : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > ) -> ( FileDescriptor , FileDescriptor ) {
2992
+ fn try_establish_connection < ' a > ( peer_a : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > , peer_b : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > ) -> ( FileDescriptor , FileDescriptor , Result < bool , PeerHandleError > , Result < bool , PeerHandleError > ) {
2993
+ let addr_a = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1000 } ;
2994
+ let addr_b = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1001 } ;
2995
+
2969
2996
static FD_COUNTER : AtomicUsize = AtomicUsize :: new ( 0 ) ;
2970
2997
let fd = FD_COUNTER . fetch_add ( 1 , Ordering :: Relaxed ) as u16 ;
2971
2998
2972
2999
let id_a = peer_a. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ;
2973
3000
let mut fd_a = FileDescriptor :: new ( fd) ;
2974
- let addr_a = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1000 } ;
2975
-
2976
- let id_b = peer_b. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ;
2977
- let features_a = peer_a. init_features ( id_b) ;
2978
- let features_b = peer_b. init_features ( id_a) ;
2979
3001
let mut fd_b = FileDescriptor :: new ( fd) ;
2980
- let addr_b = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1001 } ;
2981
3002
2982
3003
let initial_data = peer_b. new_outbound_connection ( id_a, fd_b. clone ( ) , Some ( addr_a. clone ( ) ) ) . unwrap ( ) ;
2983
3004
peer_a. new_inbound_connection ( fd_a. clone ( ) , Some ( addr_b. clone ( ) ) ) . unwrap ( ) ;
@@ -2989,11 +3010,30 @@ mod tests {
2989
3010
2990
3011
peer_b. process_events ( ) ;
2991
3012
let b_data = fd_b. outbound_data . lock ( ) . unwrap ( ) . split_off ( 0 ) ;
2992
- assert_eq ! ( peer_a. read_event( & mut fd_a, & b_data) . unwrap ( ) , false ) ;
3013
+ let a_refused = peer_a. read_event ( & mut fd_a, & b_data) ;
2993
3014
2994
3015
peer_a. process_events ( ) ;
2995
3016
let a_data = fd_a. outbound_data . lock ( ) . unwrap ( ) . split_off ( 0 ) ;
2996
- assert_eq ! ( peer_b. read_event( & mut fd_b, & a_data) . unwrap( ) , false ) ;
3017
+ let b_refused = peer_b. read_event ( & mut fd_b, & a_data) ;
3018
+
3019
+ ( fd_a, fd_b, a_refused, b_refused)
3020
+ }
3021
+
3022
+
3023
+ fn establish_connection < ' a > ( peer_a : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > , peer_b : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > ) -> ( FileDescriptor , FileDescriptor ) {
3024
+ let addr_a = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1000 } ;
3025
+ let addr_b = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1001 } ;
3026
+
3027
+ let id_a = peer_a. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ;
3028
+ let id_b = peer_b. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ;
3029
+
3030
+ let features_a = peer_a. init_features ( id_b) ;
3031
+ let features_b = peer_b. init_features ( id_a) ;
3032
+
3033
+ let ( fd_a, fd_b, a_refused, b_refused) = try_establish_connection ( peer_a, peer_b) ;
3034
+
3035
+ assert_eq ! ( a_refused. unwrap( ) , false ) ;
3036
+ assert_eq ! ( b_refused. unwrap( ) , false ) ;
2997
3037
2998
3038
assert_eq ! ( peer_a. peer_by_node_id( & id_b) . unwrap( ) . counterparty_node_id, id_b) ;
2999
3039
assert_eq ! ( peer_a. peer_by_node_id( & id_b) . unwrap( ) . socket_address, Some ( addr_b) ) ;
@@ -3246,6 +3286,50 @@ mod tests {
3246
3286
assert_eq ! ( peers[ 0 ] . peers. read( ) . unwrap( ) . len( ) , 0 ) ;
3247
3287
}
3248
3288
3289
+ fn do_test_peer_connected_error_disconnects ( handler : usize ) {
3290
+ // Test that if a message handler fails a connection in `peer_connected` we reliably
3291
+ // produce `peer_disconnected` events for all other message handlers (that saw a
3292
+ // corresponding `peer_connected`).
3293
+ let cfgs = create_peermgr_cfgs ( 2 ) ;
3294
+ let peers = create_network ( 2 , & cfgs) ;
3295
+
3296
+ match handler & !1 {
3297
+ 0 => {
3298
+ peers[ handler & 1 ] . message_handler . chan_handler . conn_tracker . fail_connections . store ( true , Ordering :: Release ) ;
3299
+ }
3300
+ 2 => {
3301
+ peers[ handler & 1 ] . message_handler . route_handler . conn_tracker . fail_connections . store ( true , Ordering :: Release ) ;
3302
+ }
3303
+ 4 => {
3304
+ peers[ handler & 1 ] . message_handler . custom_message_handler . conn_tracker . fail_connections . store ( true , Ordering :: Release ) ;
3305
+ }
3306
+ _ => panic ! ( ) ,
3307
+ }
3308
+ let ( _sd1, _sd2, a_refused, b_refused) = try_establish_connection ( & peers[ 0 ] , & peers[ 1 ] ) ;
3309
+ if handler & 1 == 0 {
3310
+ assert ! ( a_refused. is_err( ) ) ;
3311
+ assert ! ( peers[ 0 ] . list_peers( ) . is_empty( ) ) ;
3312
+ } else {
3313
+ assert ! ( b_refused. is_err( ) ) ;
3314
+ assert ! ( peers[ 1 ] . list_peers( ) . is_empty( ) ) ;
3315
+ }
3316
+ // At least one message handler should have seen the connection.
3317
+ assert ! ( peers[ handler & 1 ] . message_handler. chan_handler. conn_tracker. had_peers. load( Ordering :: Acquire ) ||
3318
+ peers[ handler & 1 ] . message_handler. route_handler. conn_tracker. had_peers. load( Ordering :: Acquire ) ||
3319
+ peers[ handler & 1 ] . message_handler. custom_message_handler. conn_tracker. had_peers. load( Ordering :: Acquire ) ) ;
3320
+ // And both message handlers doing tracking should see the disconnection
3321
+ assert ! ( peers[ handler & 1 ] . message_handler. chan_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ;
3322
+ assert ! ( peers[ handler & 1 ] . message_handler. route_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ;
3323
+ assert ! ( peers[ handler & 1 ] . message_handler. custom_message_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ;
3324
+ }
3325
+
3326
+ #[ test]
3327
+ fn test_peer_connected_error_disconnects ( ) {
3328
+ for i in 0 ..6 {
3329
+ do_test_peer_connected_error_disconnects ( i) ;
3330
+ }
3331
+ }
3332
+
3249
3333
#[ test]
3250
3334
fn test_do_attempt_write_data ( ) {
3251
3335
// Create 2 peers with custom TestRoutingMessageHandlers and connect them.
0 commit comments