Skip to content

Test generated-route valididty in fuzzing #3728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 91 additions & 5 deletions fuzz/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ use lightning::ln::channel_state::{ChannelCounterparty, ChannelDetails, ChannelS
use lightning::ln::channelmanager;
use lightning::ln::msgs;
use lightning::ln::types::ChannelId;
use lightning::routing::gossip::{NetworkGraph, RoutingFees};
use lightning::routing::gossip::{NetworkGraph, NodeId, RoutingFees};
use lightning::routing::router::{
find_route, PaymentParameters, RouteHint, RouteHintHop, RouteParameters,
find_route, Payee, PaymentParameters, RouteHint, RouteHintHop, RouteParameters,
};
use lightning::routing::scoring::{
ProbabilisticScorer, ProbabilisticScoringDecayParameters, ProbabilisticScoringFeeParameters,
Expand Down Expand Up @@ -296,7 +296,7 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
let final_value_msat = slice_to_be64(get_slice!(8));
let final_cltv_expiry_delta = slice_to_be32(get_slice!(4));
let route_params = $route_params(final_value_msat, final_cltv_expiry_delta, target);
let _ = find_route(
let route = find_route(
&our_pubkey,
&route_params,
&net_graph,
Expand All @@ -309,6 +309,91 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
&ProbabilisticScoringFeeParameters::default(),
&random_seed_bytes,
);
if let Ok(route) = route {
// If we generated a route, check that it is valid
// TODO: Check CLTV deltas
assert_eq!(route.route_params.as_ref(), Some(&route_params));
let graph = net_graph.read_only();
let mut blinded_path_payment_amts = new_hash_map();
let mut total_fee = 0;
let mut total_sent = 0;
for path in &route.paths {
total_fee += path.fee_msat();
total_sent += path.final_value_msat();
let unblinded_recipient = path.hops.last().expect("No hops").pubkey;
let mut hops = path.hops.iter().peekable();
let payee = &route_params.payment_params.payee;
'path_check: while let Some(hop) = hops.next() {
if let Some(next) = hops.peek().cloned() {
let amt_sent: u64 = hops.clone().map(|hop| hop.fee_msat).sum();
if let Payee::Clear { route_hints, .. } = payee {
// If we paid to an invoice with clear route hints, check
// whether we pulled from a route hint first, and if not fall
// back to searching through the public network graph.
for hint_path in route_hints.iter() {
let mut hint_hops = hint_path.0.iter().peekable();
while let Some(hint) = hint_hops.next() {
let next_hint_hop_key = hint_hops
.peek()
.map(|hint_hop| hint_hop.src_node_id)
.unwrap_or(unblinded_recipient);

let matches_hint = hint.src_node_id == hop.pubkey
&& hint.short_channel_id == next.short_channel_id
&& next_hint_hop_key == next.pubkey;
let prop = hint.fees.proportional_millionths as u128;
let base = hint.fees.base_msat as u128;
let min_fee = amt_sent as u128 * prop / 1000000 + base;
if matches_hint {
assert!(min_fee <= hop.fee_msat as u128);
continue 'path_check;
}
}
}
}
let chan = graph.channel(hop.short_channel_id).expect("No chan");
assert!(chan.one_to_two.is_some() && chan.two_to_one.is_some());
let fees = if chan.node_one == NodeId::from_pubkey(&hop.pubkey) {
chan.one_to_two.as_ref().unwrap().fees
} else {
chan.two_to_one.as_ref().unwrap().fees
};
let prop_fee = fees.proportional_millionths as u128;
let base_fee = fees.base_msat as u128;
let min_fee = amt_sent as u128 * prop_fee / 1_000_000 + base_fee;
assert!(min_fee <= hop.fee_msat as u128);
} else {
if let Payee::Blinded { route_hints, .. } = payee {
let tail = path.blinded_tail.as_ref().expect("No blinded path");
if tail.hops.len() == 1 {
// We don't consider the payinfo for one-hop blinded paths
// since they're not "real" blinded paths.
continue;
}
// TODO: We should add some kind of coverage of trampoline hops
assert!(tail.trampoline_hops.is_empty());
let hint_filter = |hint: &&BlindedPaymentPath| {
// We store a unique counter in each encrypted_payload.
let hint_id = &hint.blinded_hops()[0].encrypted_payload;
*hint_id == tail.hops[0].encrypted_payload
};
let mut matching_hints = route_hints.iter().filter(hint_filter);
let used_hint = matching_hints.next().unwrap();
assert!(matching_hints.next().is_none());
let key = &tail.hops[0].encrypted_payload;
let used = blinded_path_payment_amts.entry(key).or_insert(0u64);
let blind_intro_amt = tail.final_value_msat + hop.fee_msat;
*used += blind_intro_amt;
assert!(*used <= used_hint.payinfo.htlc_maximum_msat);
assert!(blind_intro_amt >= used_hint.payinfo.htlc_minimum_msat);
}
break;
}
}
}
assert!(total_sent >= final_value_msat);
assert!(total_fee <= route_params.max_total_routing_fee_msat.unwrap());
}
}
};
}
Expand Down Expand Up @@ -383,7 +468,8 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
let dummy_pk = PublicKey::from_slice(&[2; 33]).unwrap();
let last_hops: Vec<BlindedPaymentPath> = last_hops_unblinded
.into_iter()
.map(|hint| {
.enumerate()
.map(|(hint_idx, hint)| {
let hop = &hint.0[0];
let payinfo = BlindedPayInfo {
fee_base_msat: hop.fees.base_msat,
Expand All @@ -398,7 +484,7 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
for _ in 0..num_blinded_hops {
blinded_hops.push(BlindedHop {
blinded_node_id: dummy_pk,
encrypted_payload: Vec::new(),
encrypted_payload: hint_idx.to_ne_bytes().to_vec(),
});
}
BlindedPaymentPath::from_raw(
Expand Down
19 changes: 16 additions & 3 deletions lightning/src/routing/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1771,6 +1771,7 @@ struct NodeCounters<'a> {
network_graph: &'a ReadOnlyNetworkGraph<'a>,
private_node_id_to_node_counter: HashMap<NodeId, u32>,
private_hop_key_cache: HashMap<PublicKey, (NodeId, u32)>,
have_blinded_payee_counter: bool,
}

struct NodeCountersBuilder<'a>(NodeCounters<'a>);
Expand All @@ -1781,9 +1782,17 @@ impl<'a> NodeCountersBuilder<'a> {
network_graph,
private_node_id_to_node_counter: new_hash_map(),
private_hop_key_cache: new_hash_map(),
have_blinded_payee_counter: false,
})
}

fn select_node_counter_for_blinded_payee_and_build(mut self) -> (u32, NodeCounters<'a>) {
let next_node_counter = self.0.network_graph.max_node_counter() + 1 +
self.0.private_node_id_to_node_counter.len() as u32;
self.0.have_blinded_payee_counter = true;
(next_node_counter, self.0)
}

fn select_node_counter_for_pubkey(&mut self, pubkey: PublicKey) -> u32 {
let id = NodeId::from_pubkey(&pubkey);
let counter = self.select_node_counter_for_id(id);
Expand All @@ -1809,7 +1818,8 @@ impl<'a> NodeCountersBuilder<'a> {
impl<'a> NodeCounters<'a> {
fn max_counter(&self) -> u32 {
self.network_graph.max_node_counter() +
self.private_node_id_to_node_counter.len() as u32
self.private_node_id_to_node_counter.len() as u32 +
if self.have_blinded_payee_counter { 1 } else { 0 }
}

fn private_node_counter_from_pubkey(&self, pubkey: &PublicKey) -> Option<&(NodeId, u32)> {
Expand Down Expand Up @@ -2416,7 +2426,6 @@ where L::Target: Logger {
let mut node_counter_builder = NodeCountersBuilder::new(&network_graph);

let payer_node_counter = node_counter_builder.select_node_counter_for_pubkey(*our_node_pubkey);
let payee_node_counter = node_counter_builder.select_node_counter_for_pubkey(maybe_dummy_payee_pk);

for route in payment_params.payee.unblinded_route_hints().iter() {
for hop in route.0.iter() {
Expand Down Expand Up @@ -2455,7 +2464,11 @@ where L::Target: Logger {
}
}

let node_counters = node_counter_builder.build();
let (payee_node_counter, node_counters) = if let Some(pubkey) = payment_params.payee.node_id() {
(node_counter_builder.select_node_counter_for_pubkey(pubkey), node_counter_builder.build())
} else {
node_counter_builder.select_node_counter_for_blinded_payee_and_build()
};

let introduction_node_id_cache = calculate_blinded_path_intro_points(
&payment_params, &node_counters, network_graph, &logger, our_node_id, &first_hop_targets,
Expand Down
Loading