Skip to content

Commit 4bc597a

Browse files
committed
Assert peer_{dis,}connected consistency across test handlers
This adds a `ConnectionTracker` test util which is used across `TestChannelMessageHandler`, `TestRoutingMessageHandler` and `TestCustomMessageHandler`, asserting that `peer_connected` and `peer_disconnected` methods are well-ordered. This expands test coverage from just `TestChannelMessageHandler` to cover all test handlers and adds some useful features which we'll use to test the fix in the next commit. This also adds an additional test which tests `peer_{dis,}connected` consistency when a handler refuses a connection by returning an `Err` from `peer_connected`.
1 parent 07148db commit 4bc597a

File tree

2 files changed

+134
-22
lines changed

2 files changed

+134
-22
lines changed

lightning/src/ln/peer_handler.rs

+88-15
Original file line numberDiff line numberDiff line change
@@ -2867,6 +2867,16 @@ mod tests {
28672867

28682868
struct TestCustomMessageHandler {
28692869
features: InitFeatures,
2870+
conn_tracker: test_utils::ConnectionTracker,
2871+
}
2872+
2873+
impl TestCustomMessageHandler {
2874+
fn new(features: InitFeatures) -> Self {
2875+
Self {
2876+
features,
2877+
conn_tracker: test_utils::ConnectionTracker::new(),
2878+
}
2879+
}
28702880
}
28712881

28722882
impl wire::CustomMessageReader for TestCustomMessageHandler {
@@ -2883,10 +2893,13 @@ mod tests {
28832893

28842894
fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { Vec::new() }
28852895

2896+
fn peer_disconnected(&self, their_node_id: PublicKey) {
2897+
self.conn_tracker.peer_disconnected(their_node_id);
2898+
}
28862899

2887-
fn peer_disconnected(&self, _their_node_id: PublicKey) {}
2888-
2889-
fn peer_connected(&self, _their_node_id: PublicKey, _msg: &Init, _inbound: bool) -> Result<(), ()> { Ok(()) }
2900+
fn peer_connected(&self, their_node_id: PublicKey, _msg: &Init, _inbound: bool) -> Result<(), ()> {
2901+
self.conn_tracker.peer_connected(their_node_id)
2902+
}
28902903

28912904
fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() }
28922905

@@ -2909,7 +2922,7 @@ mod tests {
29092922
chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)),
29102923
logger: test_utils::TestLogger::with_id(i.to_string()),
29112924
routing_handler: test_utils::TestRoutingMessageHandler::new(),
2912-
custom_handler: TestCustomMessageHandler { features },
2925+
custom_handler: TestCustomMessageHandler::new(features),
29132926
node_signer: test_utils::TestNodeSigner::new(node_secret),
29142927
}
29152928
);
@@ -2932,7 +2945,7 @@ mod tests {
29322945
chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)),
29332946
logger: test_utils::TestLogger::new(),
29342947
routing_handler: test_utils::TestRoutingMessageHandler::new(),
2935-
custom_handler: TestCustomMessageHandler { features },
2948+
custom_handler: TestCustomMessageHandler::new(features),
29362949
node_signer: test_utils::TestNodeSigner::new(node_secret),
29372950
}
29382951
);
@@ -2952,7 +2965,7 @@ mod tests {
29522965
chan_handler: test_utils::TestChannelMessageHandler::new(network),
29532966
logger: test_utils::TestLogger::new(),
29542967
routing_handler: test_utils::TestRoutingMessageHandler::new(),
2955-
custom_handler: TestCustomMessageHandler { features },
2968+
custom_handler: TestCustomMessageHandler::new(features),
29562969
node_signer: test_utils::TestNodeSigner::new(node_secret),
29572970
}
29582971
);
@@ -2976,19 +2989,16 @@ mod tests {
29762989
peers
29772990
}
29782991

2979-
fn establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>) -> (FileDescriptor, FileDescriptor) {
2992+
fn try_establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>) -> (FileDescriptor, FileDescriptor, Result<bool, PeerHandleError>, Result<bool, PeerHandleError>) {
2993+
let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000};
2994+
let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001};
2995+
29802996
static FD_COUNTER: AtomicUsize = AtomicUsize::new(0);
29812997
let fd = FD_COUNTER.fetch_add(1, Ordering::Relaxed) as u16;
29822998

29832999
let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap();
29843000
let mut fd_a = FileDescriptor::new(fd);
2985-
let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000};
2986-
2987-
let id_b = peer_b.node_signer.get_node_id(Recipient::Node).unwrap();
2988-
let features_a = peer_a.init_features(id_b);
2989-
let features_b = peer_b.init_features(id_a);
29903001
let mut fd_b = FileDescriptor::new(fd);
2991-
let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001};
29923002

29933003
let initial_data = peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap();
29943004
peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap();
@@ -3000,11 +3010,30 @@ mod tests {
30003010

30013011
peer_b.process_events();
30023012
let b_data = fd_b.outbound_data.lock().unwrap().split_off(0);
3003-
assert_eq!(peer_a.read_event(&mut fd_a, &b_data).unwrap(), false);
3013+
let a_refused = peer_a.read_event(&mut fd_a, &b_data);
30043014

30053015
peer_a.process_events();
30063016
let a_data = fd_a.outbound_data.lock().unwrap().split_off(0);
3007-
assert_eq!(peer_b.read_event(&mut fd_b, &a_data).unwrap(), false);
3017+
let b_refused = peer_b.read_event(&mut fd_b, &a_data);
3018+
3019+
(fd_a, fd_b, a_refused, b_refused)
3020+
}
3021+
3022+
3023+
fn establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, IgnoringMessageHandler, &'a test_utils::TestLogger, &'a TestCustomMessageHandler, &'a test_utils::TestNodeSigner>) -> (FileDescriptor, FileDescriptor) {
3024+
let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000};
3025+
let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001};
3026+
3027+
let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap();
3028+
let id_b = peer_b.node_signer.get_node_id(Recipient::Node).unwrap();
3029+
3030+
let features_a = peer_a.init_features(id_b);
3031+
let features_b = peer_b.init_features(id_a);
3032+
3033+
let (fd_a, fd_b, a_refused, b_refused) = try_establish_connection(peer_a, peer_b);
3034+
3035+
assert_eq!(a_refused.unwrap(), false);
3036+
assert_eq!(b_refused.unwrap(), false);
30083037

30093038
assert_eq!(peer_a.peer_by_node_id(&id_b).unwrap().counterparty_node_id, id_b);
30103039
assert_eq!(peer_a.peer_by_node_id(&id_b).unwrap().socket_address, Some(addr_b));
@@ -3257,6 +3286,50 @@ mod tests {
32573286
assert_eq!(peers[0].peers.read().unwrap().len(), 0);
32583287
}
32593288

3289+
fn do_test_peer_connected_error_disconnects(handler: usize) {
3290+
// Test that if a message handler fails a connection in `peer_connected` we reliably
3291+
// produce `peer_disconnected` events for all other message handlers (that saw a
3292+
// corresponding `peer_connected`).
3293+
let cfgs = create_peermgr_cfgs(2);
3294+
let peers = create_network(2, &cfgs);
3295+
3296+
match handler & !1 {
3297+
0 => {
3298+
peers[handler & 1].message_handler.chan_handler.conn_tracker.fail_connections.store(true, Ordering::Release);
3299+
}
3300+
2 => {
3301+
peers[handler & 1].message_handler.route_handler.conn_tracker.fail_connections.store(true, Ordering::Release);
3302+
}
3303+
4 => {
3304+
peers[handler & 1].message_handler.custom_message_handler.conn_tracker.fail_connections.store(true, Ordering::Release);
3305+
}
3306+
_ => panic!(),
3307+
}
3308+
let (_sd1, _sd2, a_refused, b_refused) = try_establish_connection(&peers[0], &peers[1]);
3309+
if handler & 1 == 0 {
3310+
assert!(a_refused.is_err());
3311+
assert!(peers[0].list_peers().is_empty());
3312+
} else {
3313+
assert!(b_refused.is_err());
3314+
assert!(peers[1].list_peers().is_empty());
3315+
}
3316+
// At least one message handler should have seen the connection.
3317+
assert!(peers[handler & 1].message_handler.chan_handler.conn_tracker.had_peers.load(Ordering::Acquire) ||
3318+
peers[handler & 1].message_handler.route_handler.conn_tracker.had_peers.load(Ordering::Acquire) ||
3319+
peers[handler & 1].message_handler.custom_message_handler.conn_tracker.had_peers.load(Ordering::Acquire));
3320+
// And both message handlers doing tracking should see the disconnection
3321+
assert!(peers[handler & 1].message_handler.chan_handler.conn_tracker.connected_peers.lock().unwrap().is_empty());
3322+
assert!(peers[handler & 1].message_handler.route_handler.conn_tracker.connected_peers.lock().unwrap().is_empty());
3323+
assert!(peers[handler & 1].message_handler.custom_message_handler.conn_tracker.connected_peers.lock().unwrap().is_empty());
3324+
}
3325+
3326+
#[test]
3327+
fn test_peer_connected_error_disconnects() {
3328+
for i in 0..6 {
3329+
do_test_peer_connected_error_disconnects(i);
3330+
}
3331+
}
3332+
32603333
#[test]
32613334
fn test_do_attempt_write_data() {
32623335
// Create 2 peers with custom TestRoutingMessageHandlers and connect them.

lightning/src/util/test_utils.rs

+46-7
Original file line numberDiff line numberDiff line change
@@ -889,10 +889,45 @@ impl chaininterface::BroadcasterInterface for TestBroadcaster {
889889
}
890890
}
891891

892+
pub struct ConnectionTracker {
893+
pub had_peers: AtomicBool,
894+
pub connected_peers: Mutex<Vec<PublicKey>>,
895+
pub fail_connections: AtomicBool,
896+
}
897+
898+
impl ConnectionTracker {
899+
pub fn new() -> Self {
900+
Self {
901+
had_peers: AtomicBool::new(false),
902+
connected_peers: Mutex::new(Vec::new()),
903+
fail_connections: AtomicBool::new(false),
904+
}
905+
}
906+
907+
pub fn peer_connected(&self, their_node_id: PublicKey) -> Result<(), ()> {
908+
self.had_peers.store(true, Ordering::Release);
909+
let mut connected_peers = self.connected_peers.lock().unwrap();
910+
assert!(!connected_peers.contains(&their_node_id));
911+
if self.fail_connections.load(Ordering::Acquire) {
912+
Err(())
913+
} else {
914+
connected_peers.push(their_node_id);
915+
Ok(())
916+
}
917+
}
918+
919+
pub fn peer_disconnected(&self, their_node_id: PublicKey) {
920+
assert!(self.had_peers.load(Ordering::Acquire));
921+
let mut connected_peers = self.connected_peers.lock().unwrap();
922+
assert!(connected_peers.contains(&their_node_id));
923+
connected_peers.retain(|id| *id != their_node_id);
924+
}
925+
}
926+
892927
pub struct TestChannelMessageHandler {
893928
pub pending_events: Mutex<Vec<events::MessageSendEvent>>,
894929
expected_recv_msgs: Mutex<Option<Vec<wire::Message<()>>>>,
895-
connected_peers: Mutex<HashSet<PublicKey>>,
930+
pub conn_tracker: ConnectionTracker,
896931
chain_hash: ChainHash,
897932
}
898933

@@ -907,7 +942,7 @@ impl TestChannelMessageHandler {
907942
TestChannelMessageHandler {
908943
pending_events: Mutex::new(Vec::new()),
909944
expected_recv_msgs: Mutex::new(None),
910-
connected_peers: Mutex::new(new_hash_set()),
945+
conn_tracker: ConnectionTracker::new(),
911946
chain_hash,
912947
}
913948
}
@@ -1019,15 +1054,14 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler {
10191054
self.received_msg(wire::Message::ChannelReestablish(msg.clone()));
10201055
}
10211056
fn peer_disconnected(&self, their_node_id: PublicKey) {
1022-
assert!(self.connected_peers.lock().unwrap().remove(&their_node_id));
1057+
self.conn_tracker.peer_disconnected(their_node_id)
10231058
}
10241059
fn peer_connected(
10251060
&self, their_node_id: PublicKey, _msg: &msgs::Init, _inbound: bool,
10261061
) -> Result<(), ()> {
1027-
assert!(self.connected_peers.lock().unwrap().insert(their_node_id.clone()));
10281062
// Don't bother with `received_msg` for Init as its auto-generated and we don't want to
10291063
// bother re-generating the expected Init message in all tests.
1030-
Ok(())
1064+
self.conn_tracker.peer_connected(their_node_id)
10311065
}
10321066
fn handle_error(&self, _their_node_id: PublicKey, msg: &msgs::ErrorMessage) {
10331067
self.received_msg(wire::Message::Error(msg.clone()));
@@ -1157,6 +1191,7 @@ pub struct TestRoutingMessageHandler {
11571191
pub pending_events: Mutex<Vec<events::MessageSendEvent>>,
11581192
pub request_full_sync: AtomicBool,
11591193
pub announcement_available_for_sync: AtomicBool,
1194+
pub conn_tracker: ConnectionTracker,
11601195
}
11611196

11621197
impl TestRoutingMessageHandler {
@@ -1168,6 +1203,7 @@ impl TestRoutingMessageHandler {
11681203
pending_events,
11691204
request_full_sync: AtomicBool::new(false),
11701205
announcement_available_for_sync: AtomicBool::new(false),
1206+
conn_tracker: ConnectionTracker::new(),
11711207
}
11721208
}
11731209
}
@@ -1242,10 +1278,13 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
12421278
timestamp_range: u32::max_value(),
12431279
},
12441280
});
1245-
Ok(())
1281+
1282+
self.conn_tracker.peer_connected(their_node_id)
12461283
}
12471284

1248-
fn peer_disconnected(&self, _their_node_id: PublicKey) {}
1285+
fn peer_disconnected(&self, their_node_id: PublicKey) {
1286+
self.conn_tracker.peer_disconnected(their_node_id);
1287+
}
12491288

12501289
fn handle_reply_channel_range(
12511290
&self, _their_node_id: PublicKey, _msg: msgs::ReplyChannelRange,

0 commit comments

Comments
 (0)