Skip to content

Commit fdad047

Browse files
committed
add combined scorer
1 parent b3e7ab0 commit fdad047

File tree

1 file changed

+238
-6
lines changed

1 file changed

+238
-6
lines changed

lightning/src/routing/scoring.rs

+238-6
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,7 @@ where L::Target: Logger {
478478
channel_liquidities: ChannelLiquidities,
479479
}
480480
/// Container for live and historical liquidity bounds for each channel.
481+
#[derive(Clone)]
481482
pub struct ChannelLiquidities(HashMap<u64, ChannelLiquidity>);
482483

483484
impl ChannelLiquidities {
@@ -516,10 +517,6 @@ impl ChannelLiquidities {
516517
self.0.get(short_channel_id)
517518
}
518519

519-
fn get_mut(&mut self, short_channel_id: &u64) -> Option<&mut ChannelLiquidity> {
520-
self.0.get_mut(short_channel_id)
521-
}
522-
523520
fn insert(&mut self, short_channel_id: u64, liquidity: ChannelLiquidity) -> Option<ChannelLiquidity> {
524521
self.0.insert(short_channel_id, liquidity)
525522
}
@@ -532,6 +529,12 @@ impl ChannelLiquidities {
532529
self.0.entry(short_channel_id)
533530
}
534531

532+
#[cfg(test)]
533+
fn get_mut(&mut self, short_channel_id: &u64) -> Option<&mut ChannelLiquidity> {
534+
self.0.get_mut(short_channel_id)
535+
}
536+
537+
#[cfg(test)]
535538
fn serialized_length(&self) -> usize {
536539
self.0.serialized_length()
537540
}
@@ -886,6 +889,7 @@ impl ProbabilisticScoringDecayParameters {
886889
/// first node in the ordering of the channel's counterparties. Thus, swapping the two liquidity
887890
/// offset fields gives the opposite direction.
888891
#[repr(C)] // Force the fields in memory to be in the order we specify
892+
#[derive(Clone)]
889893
pub struct ChannelLiquidity {
890894
/// Lower channel liquidity bound in terms of an offset from zero.
891895
min_liquidity_offset_msat: u64,
@@ -1156,6 +1160,15 @@ impl ChannelLiquidity {
11561160
}
11571161
}
11581162

1163+
fn merge(&mut self, other: &Self) {
1164+
// Take average for min/max liquidity offsets.
1165+
self.min_liquidity_offset_msat = (self.min_liquidity_offset_msat + other.min_liquidity_offset_msat) / 2;
1166+
self.max_liquidity_offset_msat = (self.max_liquidity_offset_msat + other.max_liquidity_offset_msat) / 2;
1167+
1168+
// Merge historical liquidity data.
1169+
self.liquidity_history.merge(&other.liquidity_history);
1170+
}
1171+
11591172
/// Returns a view of the channel liquidity directed from `source` to `target` assuming
11601173
/// `capacity_msat`.
11611174
fn as_directed(
@@ -1689,6 +1702,91 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreUpdate for Probabilistic
16891702
}
16901703
}
16911704

1705+
/// A probabilistic scorer that combines local and external information to score channels. This scorer is
1706+
/// shadow-tracking local only scores, so that it becomes possible to cleanly merge external scores when they become
1707+
/// available.
1708+
pub struct CombinedScorer<G: Deref<Target = NetworkGraph<L>>, L: Deref> where L::Target: Logger {
1709+
local_only_scorer: ProbabilisticScorer<G, L>,
1710+
scorer: ProbabilisticScorer<G, L>,
1711+
}
1712+
1713+
impl<G: Deref<Target = NetworkGraph<L>> + Clone, L: Deref + Clone> CombinedScorer<G, L> where L::Target: Logger {
1714+
/// Create a new combined scorer with the given local scorer.
1715+
pub fn new(local_scorer: ProbabilisticScorer<G, L>) -> Self {
1716+
let decay_params = local_scorer.decay_params;
1717+
let network_graph = local_scorer.network_graph.clone();
1718+
let logger = local_scorer.logger.clone();
1719+
let mut scorer = ProbabilisticScorer::new(decay_params, network_graph, logger);
1720+
1721+
scorer.channel_liquidities = local_scorer.channel_liquidities.clone();
1722+
1723+
Self {
1724+
local_only_scorer: local_scorer,
1725+
scorer: scorer,
1726+
}
1727+
}
1728+
1729+
/// Merge external channel liquidity information into the scorer.
1730+
pub fn merge(&mut self, mut external_scores: ChannelLiquidities, duration_since_epoch: Duration) {
1731+
// Decay both sets of scores to make them comparable and mergeable.
1732+
self.local_only_scorer.time_passed(duration_since_epoch);
1733+
external_scores.time_passed(duration_since_epoch, self.local_only_scorer.decay_params);
1734+
1735+
let local_scores = &self.local_only_scorer.channel_liquidities;
1736+
1737+
// For each channel, merge the external liquidity information with the isolated local liquidity information.
1738+
for (scid, mut liquidity) in external_scores.0 {
1739+
if let Some(local_liquidity) = local_scores.get(&scid) {
1740+
liquidity.merge(local_liquidity);
1741+
}
1742+
self.scorer.channel_liquidities.insert(scid, liquidity);
1743+
}
1744+
}
1745+
}
1746+
1747+
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreLookUp for CombinedScorer<G, L> where L::Target: Logger {
1748+
type ScoreParams = ProbabilisticScoringFeeParameters;
1749+
1750+
fn channel_penalty_msat(
1751+
&self, candidate: &CandidateRouteHop, usage: ChannelUsage, score_params: &ProbabilisticScoringFeeParameters
1752+
) -> u64 {
1753+
self.scorer.channel_penalty_msat(candidate, usage, score_params)
1754+
}
1755+
}
1756+
1757+
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreUpdate for CombinedScorer<G, L> where L::Target: Logger {
1758+
fn payment_path_failed(&mut self,path: &Path,short_channel_id:u64,duration_since_epoch:Duration) {
1759+
self.local_only_scorer.payment_path_failed(path, short_channel_id, duration_since_epoch);
1760+
self.scorer.payment_path_failed(path, short_channel_id, duration_since_epoch);
1761+
}
1762+
1763+
fn payment_path_successful(&mut self,path: &Path,duration_since_epoch:Duration) {
1764+
self.local_only_scorer.payment_path_successful(path, duration_since_epoch);
1765+
self.scorer.payment_path_successful(path, duration_since_epoch);
1766+
}
1767+
1768+
fn probe_failed(&mut self,path: &Path,short_channel_id:u64,duration_since_epoch:Duration) {
1769+
self.local_only_scorer.probe_failed(path, short_channel_id, duration_since_epoch);
1770+
self.scorer.probe_failed(path, short_channel_id, duration_since_epoch);
1771+
}
1772+
1773+
fn probe_successful(&mut self,path: &Path,duration_since_epoch:Duration) {
1774+
self.local_only_scorer.probe_successful(path, duration_since_epoch);
1775+
self.scorer.probe_successful(path, duration_since_epoch);
1776+
}
1777+
1778+
fn time_passed(&mut self,duration_since_epoch:Duration) {
1779+
self.local_only_scorer.time_passed(duration_since_epoch);
1780+
self.scorer.time_passed(duration_since_epoch);
1781+
}
1782+
}
1783+
1784+
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> Writeable for CombinedScorer<G, L> where L::Target: Logger {
1785+
fn write<W: crate::util::ser::Writer>(&self, writer: &mut W) -> Result<(), crate::io::Error> {
1786+
self.local_only_scorer.write(writer)
1787+
}
1788+
}
1789+
16921790
#[cfg(c_bindings)]
16931791
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> Score for ProbabilisticScorer<G, L>
16941792
where L::Target: Logger {}
@@ -1868,6 +1966,13 @@ mod bucketed_history {
18681966
self.buckets[bucket] = self.buckets[bucket].saturating_add(BUCKET_FIXED_POINT_ONE);
18691967
}
18701968
}
1969+
1970+
/// Returns the average of the buckets between the two trackers.
1971+
pub(crate) fn merge(&mut self, other: &Self) -> () {
1972+
for (index, bucket) in self.buckets.iter_mut().enumerate() {
1973+
*bucket = (*bucket + other.buckets[index]) / 2;
1974+
}
1975+
}
18711976
}
18721977

18731978
impl_writeable_tlv_based!(HistoricalBucketRangeTracker, { (0, buckets, required) });
@@ -1964,6 +2069,13 @@ mod bucketed_history {
19642069
-> DirectedHistoricalLiquidityTracker<&'a mut HistoricalLiquidityTracker> {
19652070
DirectedHistoricalLiquidityTracker { source_less_than_target, tracker: self }
19662071
}
2072+
2073+
/// Merges the historical liquidity data from another tracker into this one.
2074+
pub fn merge(&mut self, other: &Self) {
2075+
self.min_liquidity_offset_history.merge(&other.min_liquidity_offset_history);
2076+
self.max_liquidity_offset_history.merge(&other.max_liquidity_offset_history);
2077+
self.recalculate_valid_point_count();
2078+
}
19672079
}
19682080

19692081
/// A set of buckets representing the history of where we've seen the minimum- and maximum-
@@ -2122,6 +2234,72 @@ mod bucketed_history {
21222234
Some((cumulative_success_prob * (1024.0 * 1024.0 * 1024.0)) as u64)
21232235
}
21242236
}
2237+
2238+
#[cfg(test)]
2239+
mod tests {
2240+
use crate::routing::scoring::ProbabilisticScoringFeeParameters;
2241+
2242+
use super::{HistoricalBucketRangeTracker, HistoricalLiquidityTracker};
2243+
#[test]
2244+
fn historical_liquidity_bucket_merge() {
2245+
let mut bucket1 = HistoricalBucketRangeTracker::new();
2246+
bucket1.track_datapoint(100, 1000);
2247+
assert_eq!(
2248+
bucket1.buckets,
2249+
[
2250+
0u16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2251+
0, 0, 0, 0, 0, 0, 0
2252+
]
2253+
);
2254+
2255+
let mut bucket2 = HistoricalBucketRangeTracker::new();
2256+
bucket2.track_datapoint(0, 1000);
2257+
assert_eq!(
2258+
bucket2.buckets,
2259+
[
2260+
32u16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2261+
0, 0, 0, 0, 0, 0, 0
2262+
]
2263+
);
2264+
2265+
bucket1.merge(&bucket2);
2266+
assert_eq!(
2267+
bucket1.buckets,
2268+
[
2269+
16u16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2270+
0, 0, 0, 0, 0, 0, 0
2271+
]
2272+
);
2273+
}
2274+
2275+
#[test]
2276+
fn historical_liquidity_tracker_merge() {
2277+
let params = ProbabilisticScoringFeeParameters::default();
2278+
2279+
let probability1: Option<u64>;
2280+
let mut tracker1 = HistoricalLiquidityTracker::new();
2281+
{
2282+
let mut directed_tracker1 = tracker1.as_directed_mut(true);
2283+
directed_tracker1.track_datapoint(100, 200, 1000);
2284+
probability1 = directed_tracker1
2285+
.calculate_success_probability_times_billion(&params, 500, 1000);
2286+
}
2287+
2288+
let mut tracker2 = HistoricalLiquidityTracker::new();
2289+
{
2290+
let mut directed_tracker2 = tracker2.as_directed_mut(true);
2291+
directed_tracker2.track_datapoint(200, 300, 1000);
2292+
}
2293+
2294+
tracker1.merge(&tracker2);
2295+
2296+
let directed_tracker1 = tracker1.as_directed(true);
2297+
let probability =
2298+
directed_tracker1.calculate_success_probability_times_billion(&params, 500, 1000);
2299+
2300+
assert_ne!(probability1, probability);
2301+
}
2302+
}
21252303
}
21262304

21272305
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> Writeable for ProbabilisticScorer<G, L> where L::Target: Logger {
@@ -2215,15 +2393,15 @@ impl Readable for ChannelLiquidity {
22152393

22162394
#[cfg(test)]
22172395
mod tests {
2218-
use super::{ChannelLiquidity, HistoricalLiquidityTracker, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters, ProbabilisticScorer};
2396+
use super::{ChannelLiquidity, HistoricalLiquidityTracker, ProbabilisticScorer, ProbabilisticScoringDecayParameters, ProbabilisticScoringFeeParameters};
22192397
use crate::blinded_path::BlindedHop;
22202398
use crate::util::config::UserConfig;
22212399

22222400
use crate::ln::channelmanager;
22232401
use crate::ln::msgs::{ChannelAnnouncement, ChannelUpdate, UnsignedChannelAnnouncement, UnsignedChannelUpdate};
22242402
use crate::routing::gossip::{EffectiveCapacity, NetworkGraph, NodeId};
22252403
use crate::routing::router::{BlindedTail, Path, RouteHop, CandidateRouteHop, PublicHopCandidate};
2226-
use crate::routing::scoring::{ChannelUsage, ScoreLookUp, ScoreUpdate};
2404+
use crate::routing::scoring::{ChannelLiquidities, ChannelUsage, CombinedScorer, ScoreLookUp, ScoreUpdate};
22272405
use crate::util::ser::{ReadableArgs, Writeable};
22282406
use crate::util::test_utils::{self, TestLogger};
22292407

@@ -2233,6 +2411,7 @@ mod tests {
22332411
use bitcoin::network::Network;
22342412
use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
22352413
use core::time::Duration;
2414+
use std::rc::Rc;
22362415
use crate::io;
22372416

22382417
fn source_privkey() -> SecretKey {
@@ -3724,6 +3903,59 @@ mod tests {
37243903
assert_eq!(scorer.historical_estimated_payment_success_probability(42, &target, amount_msat, &params, false),
37253904
Some(0.0));
37263905
}
3906+
3907+
#[test]
3908+
fn combined_scorer() {
3909+
let logger = TestLogger::new();
3910+
let network_graph = network_graph(&logger);
3911+
let params = ProbabilisticScoringFeeParameters::default();
3912+
let mut scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger);
3913+
scorer.payment_path_failed(&payment_path_for_amount(600), 42, Duration::ZERO);
3914+
3915+
let mut combined_scorer = CombinedScorer::new(scorer);
3916+
3917+
// Verify that the combined_scorer has the correct liquidity range after a failed 600 msat payment.
3918+
let liquidity_range = combined_scorer.scorer.estimated_channel_liquidity_range(42, &target_node_id());
3919+
assert_eq!(liquidity_range.unwrap(), (0, 600));
3920+
3921+
let source = source_node_id();
3922+
let usage = ChannelUsage {
3923+
amount_msat: 750,
3924+
inflight_htlc_msat: 0,
3925+
effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_000, htlc_maximum_msat: 1_000 },
3926+
};
3927+
3928+
{
3929+
let network_graph = network_graph.read_only();
3930+
let channel = network_graph.channel(42).unwrap();
3931+
let (info, _) = channel.as_directed_from(&source).unwrap();
3932+
let candidate = CandidateRouteHop::PublicHop(PublicHopCandidate {
3933+
info,
3934+
short_channel_id: 42,
3935+
});
3936+
3937+
let penalty = combined_scorer.channel_penalty_msat(&candidate, usage, &params);
3938+
3939+
let mut external_liquidity = ChannelLiquidity::new(Duration::ZERO);
3940+
let logger_rc = Rc::new(&logger); // Why necessary and not above for the network graph?
3941+
external_liquidity.as_directed_mut(&source_node_id(), &target_node_id(), 1_000).
3942+
successful(1000, Duration::ZERO, format_args!("test channel"), logger_rc.as_ref());
3943+
3944+
let mut external_scores = ChannelLiquidities::new();
3945+
3946+
external_scores.insert(42, external_liquidity);
3947+
combined_scorer.merge(external_scores, Duration::ZERO);
3948+
3949+
let penalty_after_merge = combined_scorer.channel_penalty_msat(&candidate, usage, &params);
3950+
3951+
// Since the external source observed a successful payment, the penalty should be lower after the merge.
3952+
assert!(penalty_after_merge < penalty);
3953+
}
3954+
3955+
// Verify that after the merge with a successful payment, the liquidity range is increased.
3956+
let liquidity_range = combined_scorer.scorer.estimated_channel_liquidity_range(42, &target_node_id());
3957+
assert_eq!(liquidity_range.unwrap(), (0, 300));
3958+
}
37273959
}
37283960

37293961
#[cfg(ldk_bench)]

0 commit comments

Comments
 (0)