Skip to content

Commit ca40276

Browse files
authored
Merge pull request #3580 from TheBlueMatt/2025-01-peer-connected-consistency
Ensure `peer_disconnected` is called after a handler refuses a connection
2 parents f5c0433 + 4bc597a commit ca40276

File tree

5 files changed

+157
-21
lines changed

5 files changed

+157
-21
lines changed

lightning-net-tokio/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,7 @@ mod tests {
689689
) -> Result<(), ()> {
690690
Ok(())
691691
}
692+
fn peer_disconnected(&self, _their_node_id: PublicKey) {}
692693
fn handle_reply_channel_range(
693694
&self, _their_node_id: PublicKey, _msg: ReplyChannelRange,
694695
) -> Result<(), LightningError> {

lightning/src/ln/msgs.rs

+8
Original file line numberDiff line numberDiff line change
@@ -1579,6 +1579,8 @@ pub trait ChannelMessageHandler : MessageSendEventsProvider {
15791579
/// May return an `Err(())` if the features the peer supports are not sufficient to communicate
15801580
/// with us. Implementors should be somewhat conservative about doing so, however, as other
15811581
/// message handlers may still wish to communicate with this peer.
1582+
///
1583+
/// [`Self::peer_disconnected`] will not be called if `Err(())` is returned.
15821584
fn peer_connected(&self, their_node_id: PublicKey, msg: &Init, inbound: bool) -> Result<(), ()>;
15831585
/// Handle an incoming `channel_reestablish` message from the given peer.
15841586
fn handle_channel_reestablish(&self, their_node_id: PublicKey, msg: &ChannelReestablish);
@@ -1657,7 +1659,11 @@ pub trait RoutingMessageHandler : MessageSendEventsProvider {
16571659
/// May return an `Err(())` if the features the peer supports are not sufficient to communicate
16581660
/// with us. Implementors should be somewhat conservative about doing so, however, as other
16591661
/// message handlers may still wish to communicate with this peer.
1662+
///
1663+
/// [`Self::peer_disconnected`] will not be called if `Err(())` is returned.
16601664
fn peer_connected(&self, their_node_id: PublicKey, init: &Init, inbound: bool) -> Result<(), ()>;
1665+
/// Indicates a connection to the peer failed/an existing connection was lost.
1666+
fn peer_disconnected(&self, their_node_id: PublicKey);
16611667
/// Handles the reply of a query we initiated to learn about channels
16621668
/// for a given range of blocks. We can expect to receive one or more
16631669
/// replies to a single query.
@@ -1708,6 +1714,8 @@ pub trait OnionMessageHandler {
17081714
/// May return an `Err(())` if the features the peer supports are not sufficient to communicate
17091715
/// with us. Implementors should be somewhat conservative about doing so, however, as other
17101716
/// message handlers may still wish to communicate with this peer.
1717+
///
1718+
/// [`Self::peer_disconnected`] will not be called if `Err(())` is returned.
17111719
fn peer_connected(&self, their_node_id: PublicKey, init: &Init, inbound: bool) -> Result<(), ()>;
17121720

17131721
/// Indicates a connection to the peer failed/an existing connection was lost. Allows handlers to

lightning/src/ln/peer_handler.rs

+99-15
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ pub trait CustomMessageHandler: wire::CustomMessageReader {
8888
/// May return an `Err(())` if the features the peer supports are not sufficient to communicate
8989
/// with us. Implementors should be somewhat conservative about doing so, however, as other
9090
/// message handlers may still wish to communicate with this peer.
91+
///
92+
/// [`Self::peer_disconnected`] will not be called if `Err(())` is returned.
9193
fn peer_connected(&self, their_node_id: PublicKey, msg: &Init, inbound: bool) -> Result<(), ()>;
9294

9395
/// Gets the node feature flags which this handler itself supports. All available handlers are
@@ -119,6 +121,7 @@ impl RoutingMessageHandler for IgnoringMessageHandler {
119121
Option<(msgs::ChannelAnnouncement, Option<msgs::ChannelUpdate>, Option<msgs::ChannelUpdate>)> { None }
120122
fn get_next_node_announcement(&self, _starting_point: Option<&NodeId>) -> Option<msgs::NodeAnnouncement> { None }
121123
fn peer_connected(&self, _their_node_id: PublicKey, _init: &msgs::Init, _inbound: bool) -> Result<(), ()> { Ok(()) }
124+
fn peer_disconnected(&self, _their_node_id: PublicKey) { }
122125
fn handle_reply_channel_range(&self, _their_node_id: PublicKey, _msg: msgs::ReplyChannelRange) -> Result<(), LightningError> { Ok(()) }
123126
fn handle_reply_short_channel_ids_end(&self, _their_node_id: PublicKey, _msg: msgs::ReplyShortChannelIdsEnd) -> Result<(), LightningError> { Ok(()) }
124127
fn handle_query_channel_range(&self, _their_node_id: PublicKey, _msg: msgs::QueryChannelRange) -> Result<(), LightningError> { Ok(()) }
@@ -1714,14 +1717,20 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
17141717
}
17151718
if let Err(()) = self.message_handler.chan_handler.peer_connected(their_node_id, &msg, peer_lock.inbound_connection) {
17161719
log_debug!(logger, "Channel Handler decided we couldn't communicate with peer {}", log_pubkey!(their_node_id));
1720+
self.message_handler.route_handler.peer_disconnected(their_node_id);
17171721
return Err(PeerHandleError { }.into());
17181722
}
17191723
if let Err(()) = self.message_handler.onion_message_handler.peer_connected(their_node_id, &msg, peer_lock.inbound_connection) {
17201724
log_debug!(logger, "Onion Message Handler decided we couldn't communicate with peer {}", log_pubkey!(their_node_id));
1725+
self.message_handler.route_handler.peer_disconnected(their_node_id);
1726+
self.message_handler.chan_handler.peer_disconnected(their_node_id);
17211727
return Err(PeerHandleError { }.into());
17221728
}
17231729
if let Err(()) = self.message_handler.custom_message_handler.peer_connected(their_node_id, &msg, peer_lock.inbound_connection) {
17241730
log_debug!(logger, "Custom Message Handler decided we couldn't communicate with peer {}", log_pubkey!(their_node_id));
1731+
self.message_handler.route_handler.peer_disconnected(their_node_id);
1732+
self.message_handler.chan_handler.peer_disconnected(their_node_id);
1733+
self.message_handler.onion_message_handler.peer_disconnected(their_node_id);
17251734
return Err(PeerHandleError { }.into());
17261735
}
17271736

@@ -2533,6 +2542,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
25332542
debug_assert!(peer.their_node_id.is_some());
25342543
if let Some((node_id, _)) = peer.their_node_id {
25352544
log_trace!(WithContext::from(&self.logger, Some(node_id), None, None), "Disconnecting peer with id {} due to {}", node_id, reason);
2545+
self.message_handler.route_handler.peer_disconnected(node_id);
25362546
self.message_handler.chan_handler.peer_disconnected(node_id);
25372547
self.message_handler.onion_message_handler.peer_disconnected(node_id);
25382548
self.message_handler.custom_message_handler.peer_disconnected(node_id);
@@ -2557,6 +2567,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
25572567
let removed = self.node_id_to_descriptor.lock().unwrap().remove(&node_id);
25582568
debug_assert!(removed.is_some(), "descriptor maps should be consistent");
25592569
if !peer.handshake_complete() { return; }
2570+
self.message_handler.route_handler.peer_disconnected(node_id);
25602571
self.message_handler.chan_handler.peer_disconnected(node_id);
25612572
self.message_handler.onion_message_handler.peer_disconnected(node_id);
25622573
self.message_handler.custom_message_handler.peer_disconnected(node_id);
@@ -2856,6 +2867,16 @@ mod tests {
28562867

28572868
/// Custom message handler used by the tests in this module.
struct TestCustomMessageHandler {
28582869
/// Feature set supplied when the handler is constructed.
features: InitFeatures,
2870+
/// Records the `peer_connected`/`peer_disconnected` calls this handler receives.
conn_tracker: test_utils::ConnectionTracker,
2871+
}
2872+
2873+
impl TestCustomMessageHandler {
2874+
/// Builds a handler holding `features` and a newly created connection tracker.
fn new(features: InitFeatures) -> Self {
2875+
Self {
2876+
features,
2877+
conn_tracker: test_utils::ConnectionTracker::new(),
2878+
}
2879+
}
28592880
}
28602881

28612882
impl wire::CustomMessageReader for TestCustomMessageHandler {
@@ -2872,10 +2893,13 @@ mod tests {
28722893

28732894
fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { Vec::new() }
28742895

2896+
// Hand the disconnection to the tracker so tests can inspect it afterwards.
fn peer_disconnected(&self, their_node_id: PublicKey) {
2897+
self.conn_tracker.peer_disconnected(their_node_id);
2898+
}
28752899

2876-
fn peer_disconnected(&self, _their_node_id: PublicKey) {}
2877-
2878-
fn peer_connected(&self, _their_node_id: PublicKey, _msg: &Init, _inbound: bool) -> Result<(), ()> { Ok(()) }
2900+
// The tracker's result is returned as-is, so it decides whether this handler accepts the peer.
fn peer_connected(&self, their_node_id: PublicKey, _msg: &Init, _inbound: bool) -> Result<(), ()> {
2901+
self.conn_tracker.peer_connected(their_node_id)
2902+
}
28792903

28802904
fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() }
28812905

@@ -2898,7 +2922,7 @@ mod tests {
28982922
chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)),
28992923
logger: test_utils::TestLogger::with_id(i.to_string()),
29002924
routing_handler: test_utils::TestRoutingMessageHandler::new(),
2901-
custom_handler: TestCustomMessageHandler { features },
2925+
custom_handler: TestCustomMessageHandler::new(features),
29022926
node_signer: test_utils::TestNodeSigner::new(node_secret),
29032927
}
29042928
);
@@ -2921,7 +2945,7 @@ mod tests {
29212945
chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)),
29222946
logger: test_utils::TestLogger::new(),
29232947
routing_handler: test_utils::TestRoutingMessageHandler::new(),
2924-
custom_handler: TestCustomMessageHandler { features },
2948+
custom_handler: TestCustomMessageHandler::new(features),
29252949
node_signer: test_utils::TestNodeSigner::new(node_secret),
29262950
}
29272951
);
@@ -2941,7 +2965,7 @@ mod tests {
29412965
chan_handler: test_utils::TestChannelMessageHandler::new(network),
29422966
logger: test_utils::TestLogger::new(),
29432967
routing_handler: test_utils::TestRoutingMessageHandler::new(),
2944-
custom_handler: TestCustomMessageHandler { features },
2968+
custom_handler: TestCustomMessageHandler::new(features),
29452969
node_signer: test_utils::TestNodeSigner::new(node_secret),
29462970
}
29472971
);
@@ -2965,19 +2989,16 @@ mod tests {
29652989
peers
29662990
}
29672991

2968-
fn establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>) -> (FileDescriptor, FileDescriptor) {
2992+
fn try_establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>) -> (FileDescriptor, FileDescriptor, Result<bool, PeerHandleError>, Result<bool, PeerHandleError>) {
2993+
let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000};
2994+
let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001};
2995+
29692996
static FD_COUNTER: AtomicUsize = AtomicUsize::new(0);
29702997
let fd = FD_COUNTER.fetch_add(1, Ordering::Relaxed) as u16;
29712998

29722999
let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap();
29733000
let mut fd_a = FileDescriptor::new(fd);
2974-
let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000};
2975-
2976-
let id_b = peer_b.node_signer.get_node_id(Recipient::Node).unwrap();
2977-
let features_a = peer_a.init_features(id_b);
2978-
let features_b = peer_b.init_features(id_a);
29793001
let mut fd_b = FileDescriptor::new(fd);
2980-
let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001};
29813002

29823003
let initial_data = peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap();
29833004
peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap();
@@ -2989,11 +3010,30 @@ mod tests {
29893010

29903011
peer_b.process_events();
29913012
let b_data = fd_b.outbound_data.lock().unwrap().split_off(0);
2992-
assert_eq!(peer_a.read_event(&mut fd_a, &b_data).unwrap(), false);
3013+
let a_refused = peer_a.read_event(&mut fd_a, &b_data);
29933014

29943015
peer_a.process_events();
29953016
let a_data = fd_a.outbound_data.lock().unwrap().split_off(0);
2996-
assert_eq!(peer_b.read_event(&mut fd_b, &a_data).unwrap(), false);
3017+
let b_refused = peer_b.read_event(&mut fd_b, &a_data);
3018+
3019+
(fd_a, fd_b, a_refused, b_refused)
3020+
}
3021+
3022+
3023+
fn establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>) -> (FileDescriptor, FileDescriptor) {
3024+
let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000};
3025+
let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001};
3026+
3027+
let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap();
3028+
let id_b = peer_b.node_signer.get_node_id(Recipient::Node).unwrap();
3029+
3030+
let features_a = peer_a.init_features(id_b);
3031+
let features_b = peer_b.init_features(id_a);
3032+
3033+
let (fd_a, fd_b, a_refused, b_refused) = try_establish_connection(peer_a, peer_b);
3034+
3035+
assert_eq!(a_refused.unwrap(), false);
3036+
assert_eq!(b_refused.unwrap(), false);
29973037

29983038
assert_eq!(peer_a.peer_by_node_id(&id_b).unwrap().counterparty_node_id, id_b);
29993039
assert_eq!(peer_a.peer_by_node_id(&id_b).unwrap().socket_address, Some(addr_b));
@@ -3246,6 +3286,50 @@ mod tests {
32463286
assert_eq!(peers[0].peers.read().unwrap().len(), 0);
32473287
}
32483288

3289+
// `handler` encodes two things:
//  * bit 0 (`handler & 1`) is the index of the peer whose handler refuses the connection;
//  * the remaining bits (`handler & !1`) pick the refusing handler: 0 = channel handler,
//    2 = routing handler, 4 = custom message handler. Any other value panics.
fn do_test_peer_connected_error_disconnects(handler: usize) {
3290+
// Test that if a message handler fails a connection in `peer_connected` we reliably
3291+
// produce `peer_disconnected` events for all other message handlers (that saw a
3292+
// corresponding `peer_connected`).
3293+
let cfgs = create_peermgr_cfgs(2);
3294+
let peers = create_network(2, &cfgs);
3295+
3296+
// Arm exactly one handler on the selected peer so that it refuses the next connection.
match handler & !1 {
3297+
0 => {
3298+
peers[handler & 1].message_handler.chan_handler.conn_tracker.fail_connections.store(true, Ordering::Release);
3299+
}
3300+
2 => {
3301+
peers[handler & 1].message_handler.route_handler.conn_tracker.fail_connections.store(true, Ordering::Release);
3302+
}
3303+
4 => {
3304+
peers[handler & 1].message_handler.custom_message_handler.conn_tracker.fail_connections.store(true, Ordering::Release);
3305+
}
3306+
_ => panic!(),
3307+
}
3308+
let (_sd1, _sd2, a_refused, b_refused) = try_establish_connection(&peers[0], &peers[1]);
3309+
// The peer whose handler refused must report an error and must not list the other side.
if handler & 1 == 0 {
3310+
assert!(a_refused.is_err());
3311+
assert!(peers[0].list_peers().is_empty());
3312+
} else {
3313+
assert!(b_refused.is_err());
3314+
assert!(peers[1].list_peers().is_empty());
3315+
}
3316+
// At least one message handler should have seen the connection.
3317+
assert!(peers[handler & 1].message_handler.chan_handler.conn_tracker.had_peers.load(Ordering::Acquire) ||
3318+
peers[handler & 1].message_handler.route_handler.conn_tracker.had_peers.load(Ordering::Acquire) ||
3319+
peers[handler & 1].message_handler.custom_message_handler.conn_tracker.had_peers.load(Ordering::Acquire));
3320+
// And all three tracking message handlers (channel, routing and custom) should end up with
// no connected peers, i.e. each one that accepted the peer also saw the disconnection.
3321+
assert!(peers[handler & 1].message_handler.chan_handler.conn_tracker.connected_peers.lock().unwrap().is_empty());
3322+
assert!(peers[handler & 1].message_handler.route_handler.conn_tracker.connected_peers.lock().unwrap().is_empty());
3323+
assert!(peers[handler & 1].message_handler.custom_message_handler.conn_tracker.connected_peers.lock().unwrap().is_empty());
3324+
}
3325+
3326+
#[test]
3327+
fn test_peer_connected_error_disconnects() {
3328+
// 0..6 covers every combination of refusing peer (2) and refusing handler (3).
for i in 0..6 {
3329+
do_test_peer_connected_error_disconnects(i);
3330+
}
3331+
}
3332+
32493333
#[test]
32503334
fn test_do_attempt_write_data() {
32513335
// Create 2 peers with custom TestRoutingMessageHandlers and connect them.

lightning/src/routing/gossip.rs

+2
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,8 @@ where
706706
Ok(())
707707
}
708708

709+
fn peer_disconnected(&self, _their_node_id: PublicKey) {}
710+
709711
fn handle_reply_channel_range(
710712
&self, _their_node_id: PublicKey, _msg: ReplyChannelRange,
711713
) -> Result<(), LightningError> {

lightning/src/util/test_utils.rs

+47-6
Original file line numberDiff line numberDiff line change
@@ -887,10 +887,45 @@ impl chaininterface::BroadcasterInterface for TestBroadcaster {
887887
}
888888
}
889889

890+
/// Test helper that records the `peer_connected`/`peer_disconnected` calls a message handler
/// receives and can be told to refuse new connections.
pub struct ConnectionTracker {
891+
/// Set to `true` by every `peer_connected` call, including calls that end up refused.
pub had_peers: AtomicBool,
892+
/// Peers for which `peer_connected` returned `Ok(())` and which have not since been passed
/// to `peer_disconnected`.
pub connected_peers: Mutex<Vec<PublicKey>>,
893+
/// When `true`, `peer_connected` returns `Err(())` and does not record the peer.
pub fail_connections: AtomicBool,
894+
}
895+
896+
impl ConnectionTracker {
897+
/// Creates a tracker with no recorded peers that accepts connections.
pub fn new() -> Self {
898+
Self {
899+
had_peers: AtomicBool::new(false),
900+
connected_peers: Mutex::new(Vec::new()),
901+
fail_connections: AtomicBool::new(false),
902+
}
903+
}
904+
905+
/// Records a connection attempt from `their_node_id`.
///
/// Returns `Err(())` without recording the peer when `fail_connections` is set, otherwise
/// adds the peer to `connected_peers` and returns `Ok(())`.
///
/// # Panics
///
/// Panics if `their_node_id` is already in `connected_peers`.
pub fn peer_connected(&self, their_node_id: PublicKey) -> Result<(), ()> {
906+
// Flagged before the refusal check, so refused attempts still count as "had peers".
self.had_peers.store(true, Ordering::Release);
907+
let mut connected_peers = self.connected_peers.lock().unwrap();
908+
assert!(!connected_peers.contains(&their_node_id));
909+
if self.fail_connections.load(Ordering::Acquire) {
910+
Err(())
911+
} else {
912+
connected_peers.push(their_node_id);
913+
Ok(())
914+
}
915+
}
916+
917+
/// Removes `their_node_id` from `connected_peers`.
///
/// # Panics
///
/// Panics if `peer_connected` was never called on this tracker, or if `their_node_id` is
/// not currently in `connected_peers` (for example because its connection was refused).
pub fn peer_disconnected(&self, their_node_id: PublicKey) {
918+
assert!(self.had_peers.load(Ordering::Acquire));
919+
let mut connected_peers = self.connected_peers.lock().unwrap();
920+
assert!(connected_peers.contains(&their_node_id));
921+
connected_peers.retain(|id| *id != their_node_id);
922+
}
923+
}
924+
890925
pub struct TestChannelMessageHandler {
891926
pub pending_events: Mutex<Vec<events::MessageSendEvent>>,
892927
expected_recv_msgs: Mutex<Option<Vec<wire::Message<()>>>>,
893-
connected_peers: Mutex<HashSet<PublicKey>>,
928+
pub conn_tracker: ConnectionTracker,
894929
chain_hash: ChainHash,
895930
}
896931

@@ -905,7 +940,7 @@ impl TestChannelMessageHandler {
905940
TestChannelMessageHandler {
906941
pending_events: Mutex::new(Vec::new()),
907942
expected_recv_msgs: Mutex::new(None),
908-
connected_peers: Mutex::new(new_hash_set()),
943+
conn_tracker: ConnectionTracker::new(),
909944
chain_hash,
910945
}
911946
}
@@ -1017,15 +1052,14 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler {
10171052
self.received_msg(wire::Message::ChannelReestablish(msg.clone()));
10181053
}
10191054
fn peer_disconnected(&self, their_node_id: PublicKey) {
1020-
assert!(self.connected_peers.lock().unwrap().remove(&their_node_id));
1055+
self.conn_tracker.peer_disconnected(their_node_id)
10211056
}
10221057
fn peer_connected(
10231058
&self, their_node_id: PublicKey, _msg: &msgs::Init, _inbound: bool,
10241059
) -> Result<(), ()> {
1025-
assert!(self.connected_peers.lock().unwrap().insert(their_node_id.clone()));
10261060
// Don't bother with `received_msg` for Init as its auto-generated and we don't want to
10271061
// bother re-generating the expected Init message in all tests.
1028-
Ok(())
1062+
self.conn_tracker.peer_connected(their_node_id)
10291063
}
10301064
fn handle_error(&self, _their_node_id: PublicKey, msg: &msgs::ErrorMessage) {
10311065
self.received_msg(wire::Message::Error(msg.clone()));
@@ -1155,6 +1189,7 @@ pub struct TestRoutingMessageHandler {
11551189
pub pending_events: Mutex<Vec<events::MessageSendEvent>>,
11561190
pub request_full_sync: AtomicBool,
11571191
pub announcement_available_for_sync: AtomicBool,
1192+
pub conn_tracker: ConnectionTracker,
11581193
}
11591194

11601195
impl TestRoutingMessageHandler {
@@ -1166,6 +1201,7 @@ impl TestRoutingMessageHandler {
11661201
pending_events,
11671202
request_full_sync: AtomicBool::new(false),
11681203
announcement_available_for_sync: AtomicBool::new(false),
1204+
conn_tracker: ConnectionTracker::new(),
11691205
}
11701206
}
11711207
}
@@ -1240,7 +1276,12 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
12401276
timestamp_range: u32::max_value(),
12411277
},
12421278
});
1243-
Ok(())
1279+
1280+
self.conn_tracker.peer_connected(their_node_id)
1281+
}
1282+
1283+
// Hand the disconnection to the tracker so tests can inspect it afterwards.
fn peer_disconnected(&self, their_node_id: PublicKey) {
1284+
self.conn_tracker.peer_disconnected(their_node_id);
12441285
}
12451286

12461287
fn handle_reply_channel_range(

0 commit comments

Comments
 (0)