Skip to content

Commit 07a1902

Browse files
110CodingPthunderbiscuit
authored andcommitted
feat: add balance method to Wallet
which in turn calls `Wallet::balance_with_params_conf_threshold` under the hood. Also added some test utilities and a test to check the implementation of the new methods.
1 parent 366dbbe commit 07a1902

File tree

3 files changed

+403
-42
lines changed

3 files changed

+403
-42
lines changed

src/test_utils.rs

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#![allow(unused)]
33
use alloc::string::ToString;
44
use alloc::sync::Arc;
5+
use core::fmt;
56
use core::str::FromStr;
67

78
use bdk_chain::{BlockId, ConfirmationBlockTime, TxUpdate};
@@ -303,50 +304,57 @@ impl From<ConfirmationBlockTime> for ReceiveTo {
303304
// OutPoint { txid, vout: 0 }
304305
// }
305306

306-
// /// Insert a checkpoint into the wallet. This can be used to extend the wallet's local chain
307-
// /// or to insert a block that did not exist previously. Note that if replacing a block with
308-
// /// a different one at the same height, then all later blocks are evicted as well.
309-
// pub fn insert_checkpoint(wallet: &mut Wallet, block: BlockId) {
310-
// let mut cp = wallet.latest_checkpoint();
311-
// cp = cp.insert(block);
312-
// wallet
313-
// .apply_update(Update {
314-
// chain: Some(cp),
315-
// ..Default::default()
316-
// })
317-
// .unwrap();
318-
// }
307+
/// Insert a checkpoint into the wallet. This can be used to extend the wallet's local chain
308+
/// or to insert a block that did not exist previously. Note that if replacing a block with
309+
/// a different one at the same height, then all later blocks are evicted as well.
310+
pub fn insert_checkpoint<K: Ord + Clone + fmt::Debug>(wallet: &mut Wallet<K>, block: BlockId) {
311+
let mut cp = wallet.latest_checkpoint();
312+
cp = cp.insert(block);
313+
wallet
314+
.apply_update(Update {
315+
chain: Some(cp),
316+
..Default::default()
317+
})
318+
.unwrap();
319+
}
319320

320-
// /// Inserts a transaction into the local view, assuming it is currently present in the mempool.
321-
// ///
322-
// /// This can be used, for example, to track a transaction immediately after it is broadcast.
323-
// pub fn insert_tx(wallet: &mut Wallet, tx: Transaction) {
324-
// let txid = tx.compute_txid();
325-
// let seen_at = std::time::UNIX_EPOCH.elapsed().unwrap().as_secs();
326-
// let mut tx_update = TxUpdate::default();
327-
// tx_update.txs = vec![Arc::new(tx)];
328-
// tx_update.seen_ats = [(txid, seen_at)].into();
329-
// wallet
330-
// .apply_update(Update {
331-
// tx_update,
332-
// ..Default::default()
333-
// })
334-
// .expect("failed to apply update");
335-
// }
321+
/// Inserts a transaction into the local view, assuming it is currently present in the mempool.
322+
///
323+
/// This can be used, for example, to track a transaction immediately after it is broadcast.
324+
pub fn insert_tx<K>(wallet: &mut Wallet<K>, tx: Transaction)
325+
where
326+
K: Ord + fmt::Debug + Clone,
327+
{
328+
let txid = tx.compute_txid();
329+
let seen_at = std::time::UNIX_EPOCH.elapsed().unwrap().as_secs();
330+
let mut tx_update = TxUpdate::default();
331+
tx_update.txs = vec![Arc::new(tx)];
332+
tx_update.seen_ats = [(txid, seen_at)].into();
333+
wallet
334+
.apply_update(Update {
335+
tx_update,
336+
..Default::default()
337+
})
338+
.expect("failed to apply update");
339+
}
336340

337-
// /// Simulates confirming a tx with `txid` by applying an update to the wallet containing
338-
// /// the given `anchor`. Note: to be considered confirmed the anchor block must exist in
339-
// /// the current active chain.
340-
// pub fn insert_anchor(wallet: &mut Wallet, txid: Txid, anchor: ConfirmationBlockTime) {
341-
// let mut tx_update = TxUpdate::default();
342-
// tx_update.anchors = [(anchor, txid)].into();
343-
// wallet
344-
// .apply_update(Update {
345-
// tx_update,
346-
// ..Default::default()
347-
// })
348-
// .expect("failed to apply update");
349-
// }
341+
/// Simulates confirming a tx with `txid` by applying an update to the wallet containing
342+
/// the given `anchor`. Note: to be considered confirmed the anchor block must exist in
343+
/// the current active chain.
344+
pub fn insert_anchor<K: Ord + fmt::Debug + Clone>(
345+
wallet: &mut Wallet<K>,
346+
txid: Txid,
347+
anchor: ConfirmationBlockTime,
348+
) {
349+
let mut tx_update = TxUpdate::default();
350+
tx_update.anchors = [(anchor, txid)].into();
351+
wallet
352+
.apply_update(Update {
353+
tx_update,
354+
..Default::default()
355+
})
356+
.expect("failed to apply update");
357+
}
350358

351359
// /// Marks the given `txid` seen as unconfirmed at `seen_at`
352360
// pub fn insert_seen_at(wallet: &mut Wallet, txid: Txid, seen_at: u64) {

src/wallet/mod.rs

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,157 @@ where
571571
}
572572
}
573573

574+
impl<K> Wallet<K>
575+
where
576+
K: Ord + Clone + Debug,
577+
{
578+
/// Compute the wallet balance with canonical params, confirmation threshold, and trust
579+
/// predicate.
580+
///
581+
/// Panics if `conf_threshold` is equal to 0.
582+
pub fn balance_with_params_conf_threshold(
583+
&self,
584+
params: CanonicalizationParams,
585+
conf_threshold: u32,
586+
trust_predicate: impl Fn(
587+
&OutPoint,
588+
&HashMap<Txid, CanonicalTx<Arc<Transaction>, ConfirmationBlockTime>>,
589+
) -> bool,
590+
) -> Balance {
591+
use crate::chain::ChainOracle;
592+
let mut confirmed = Amount::ZERO;
593+
let mut trusted_pending = Amount::ZERO;
594+
let mut untrusted_pending = Amount::ZERO;
595+
let mut immature = Amount::ZERO;
596+
597+
let mut canon_txs = HashMap::new();
598+
let mut canon_spends = HashMap::new();
599+
let outpoints = self.tx_graph.index.outpoints().iter().cloned();
600+
601+
for res in self.tx_graph.graph().try_list_canonical_txs(
602+
&self.chain,
603+
self.chain.tip().block_id(),
604+
params,
605+
) {
606+
let canonical_tx = res.expect("oracle is infallible");
607+
let txid = canonical_tx.tx_node.txid;
608+
609+
if !canonical_tx.tx_node.is_coinbase() {
610+
for txin in &canonical_tx.tx_node.tx.input {
611+
let _res = canon_spends.insert(txin.previous_output, txid);
612+
assert!(_res.is_none(), "tried to replace {_res:?} with {txid:?} ")
613+
}
614+
}
615+
616+
canon_txs.insert(txid, canonical_tx);
617+
}
618+
619+
let unspent_txouts = outpoints.into_iter().filter_map(|(_, outpoint)| {
620+
if canon_spends.contains_key(&outpoint) {
621+
return None;
622+
}
623+
let canon_tx = canon_txs.get(&outpoint.txid)?;
624+
let txout = canon_tx
625+
.tx_node
626+
.tx
627+
.output
628+
.get(outpoint.vout as usize)
629+
.cloned()
630+
.expect("oracle is infallible");
631+
let chain_position = canon_tx.chain_position;
632+
let is_on_coinbase = canon_tx.tx_node.is_coinbase();
633+
Some(FullTxOut {
634+
chain_position,
635+
outpoint,
636+
is_on_coinbase,
637+
spent_by: None,
638+
txout,
639+
})
640+
});
641+
642+
let target_height = self.chain.tip().height().checked_sub(
643+
conf_threshold
644+
.checked_sub(1)
645+
.expect("conf threshold should be positive integer"),
646+
);
647+
let curr_height = self.chain.tip().height();
648+
649+
for full_txout in unspent_txouts {
650+
match full_txout.chain_position {
651+
ChainPosition::Confirmed { .. } => match target_height {
652+
Some(ht) => {
653+
if full_txout.is_confirmed_and_spendable(ht) {
654+
confirmed += full_txout.txout.value;
655+
} else if full_txout.is_confirmed_and_spendable(curr_height) {
656+
if full_txout.is_on_coinbase {
657+
confirmed += full_txout.txout.value;
658+
} else if trust_predicate(&full_txout.outpoint, &canon_txs) {
659+
trusted_pending += full_txout.txout.value;
660+
} else {
661+
untrusted_pending += full_txout.txout.value;
662+
}
663+
} else if !full_txout.is_mature(curr_height) {
664+
immature += full_txout.txout.value;
665+
}
666+
}
667+
None => {
668+
if full_txout.is_confirmed_and_spendable(curr_height) {
669+
if full_txout.is_on_coinbase {
670+
confirmed += full_txout.txout.value;
671+
} else if trust_predicate(&full_txout.outpoint, &canon_txs) {
672+
trusted_pending += full_txout.txout.value;
673+
} else {
674+
untrusted_pending += full_txout.txout.value;
675+
}
676+
} else if !full_txout.is_mature(curr_height) {
677+
immature += full_txout.txout.value;
678+
}
679+
}
680+
},
681+
ChainPosition::Unconfirmed { .. } => {
682+
if trust_predicate(&full_txout.outpoint, &canon_txs) {
683+
trusted_pending += full_txout.txout.value;
684+
} else {
685+
untrusted_pending += full_txout.txout.value;
686+
}
687+
}
688+
}
689+
}
690+
691+
Balance {
692+
confirmed,
693+
trusted_pending,
694+
untrusted_pending,
695+
immature,
696+
}
697+
}
698+
699+
/// Compute the wallet balance with the default parameters.
700+
pub fn balance(&self) -> Balance {
701+
self.balance_with_params_conf_threshold(
702+
CanonicalizationParams::default(),
703+
1,
704+
|outpoint, canon_txs| {
705+
let mut trusted = true;
706+
let canon_tx = canon_txs.get(&outpoint.txid).expect("oracle is infallible");
707+
for txin in &canon_tx.tx_node.tx.input {
708+
trusted = trusted
709+
&& self
710+
.tx_graph
711+
.index
712+
.outpoints()
713+
.iter()
714+
.any(|(_, item)| *item == txin.previous_output);
715+
}
716+
if canon_tx.tx_node.tx.input.is_empty() {
717+
trusted = false;
718+
}
719+
trusted
720+
},
721+
)
722+
}
723+
}
724+
574725
// TODO: replace with `PersistedWallet`
575726
#[cfg(feature = "rusqlite")]
576727
impl<K> Wallet<K>

0 commit comments

Comments
 (0)