@@ -19,11 +19,12 @@ pub(crate) trait DbTxPruneExt: DbTxMut {
1919 mut delete_callback : impl FnMut ( TableRow < T > ) ,
2020 ) -> Result < ( usize , bool ) , DatabaseError > {
2121 let mut cursor = self . cursor_write :: < T > ( ) ?;
22- let mut keys = keys. into_iter ( ) ;
22+ let mut keys = keys. into_iter ( ) . peekable ( ) ;
2323
2424 let mut deleted_entries = 0 ;
2525
26- for key in & mut keys {
26+ let mut done = true ;
27+ while keys. peek ( ) . is_some ( ) {
2728 if limiter. is_limit_reached ( ) {
2829 debug ! (
2930 target: "providers::db" ,
@@ -33,9 +34,11 @@ pub(crate) trait DbTxPruneExt: DbTxMut {
3334 table = %T :: NAME ,
3435 "Pruning limit reached"
3536 ) ;
37+ done = false ;
3638 break
3739 }
3840
41+ let key = keys. next ( ) . expect ( "peek() said Some" ) ;
3942 let row = cursor. seek_exact ( key) ?;
4043 if let Some ( row) = row {
4144 cursor. delete_current ( ) ?;
@@ -45,7 +48,6 @@ pub(crate) trait DbTxPruneExt: DbTxMut {
4548 }
4649 }
4750
48- let done = keys. next ( ) . is_none ( ) ;
4951 Ok ( ( deleted_entries, done) )
5052 }
5153
@@ -124,3 +126,158 @@ pub(crate) trait DbTxPruneExt: DbTxMut {
124126}
125127
126128impl < Tx > DbTxPruneExt for Tx where Tx : DbTxMut { }
129+
130+ #[ cfg( test) ]
131+ mod tests {
132+ use super :: DbTxPruneExt ;
133+ use crate :: PruneLimiter ;
134+ use reth_db_api:: tables;
135+ use reth_primitives_traits:: SignerRecoverable ;
136+ use reth_provider:: { DBProvider , DatabaseProviderFactory } ;
137+ use reth_stages:: test_utils:: { StorageKind , TestStageDB } ;
138+ use reth_testing_utils:: generators:: { self , random_block_range, BlockRangeParams } ;
139+ use std:: sync:: {
140+ atomic:: { AtomicUsize , Ordering } ,
141+ Arc ,
142+ } ;
143+
144+ struct CountingIter {
145+ data : Vec < u64 > ,
146+ calls : Arc < AtomicUsize > ,
147+ }
148+
149+ impl CountingIter {
150+ fn new ( data : Vec < u64 > , calls : Arc < AtomicUsize > ) -> Self {
151+ Self { data, calls }
152+ }
153+ }
154+
155+ struct CountingIntoIter {
156+ inner : std:: vec:: IntoIter < u64 > ,
157+ calls : Arc < AtomicUsize > ,
158+ }
159+
160+ impl Iterator for CountingIntoIter {
161+ type Item = u64 ;
162+ fn next ( & mut self ) -> Option < Self :: Item > {
163+ let res = self . inner . next ( ) ;
164+ self . calls . fetch_add ( 1 , Ordering :: SeqCst ) ;
165+ res
166+ }
167+ }
168+
169+ impl IntoIterator for CountingIter {
170+ type Item = u64 ;
171+ type IntoIter = CountingIntoIter ;
172+ fn into_iter ( self ) -> Self :: IntoIter {
173+ CountingIntoIter { inner : self . data . into_iter ( ) , calls : self . calls }
174+ }
175+ }
176+
177+ #[ test]
178+ fn prune_table_with_iterator_early_exit_does_not_overconsume ( ) {
179+ let db = TestStageDB :: default ( ) ;
180+ let mut rng = generators:: rng ( ) ;
181+
182+ let blocks = random_block_range (
183+ & mut rng,
184+ 1 ..=3 ,
185+ BlockRangeParams {
186+ parent : Some ( alloy_primitives:: B256 :: ZERO ) ,
187+ tx_count : 2 ..3 ,
188+ ..Default :: default ( )
189+ } ,
190+ ) ;
191+ db. insert_blocks ( blocks. iter ( ) , StorageKind :: Database ( None ) ) . expect ( "insert blocks" ) ;
192+
193+ let mut tx_senders = Vec :: new ( ) ;
194+ for block in & blocks {
195+ tx_senders. reserve_exact ( block. transaction_count ( ) ) ;
196+ for transaction in & block. body ( ) . transactions {
197+ tx_senders. push ( (
198+ tx_senders. len ( ) as u64 ,
199+ transaction. recover_signer ( ) . expect ( "recover signer" ) ,
200+ ) ) ;
201+ }
202+ }
203+ let total = tx_senders. len ( ) ;
204+ db. insert_transaction_senders ( tx_senders) . expect ( "insert transaction senders" ) ;
205+
206+ let provider = db. factory . database_provider_rw ( ) . unwrap ( ) ;
207+
208+ let calls = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
209+ let keys: Vec < u64 > = ( 0 ..total as u64 ) . collect ( ) ;
210+ let counting_iter = CountingIter :: new ( keys, calls. clone ( ) ) ;
211+
212+ let mut limiter = PruneLimiter :: default ( ) . set_deleted_entries_limit ( 2 ) ;
213+
214+ let ( pruned, done) = provider
215+ . tx_ref ( )
216+ . prune_table_with_iterator :: < tables:: TransactionSenders > (
217+ counting_iter,
218+ & mut limiter,
219+ |_| { } ,
220+ )
221+ . expect ( "prune" ) ;
222+
223+ assert_eq ! ( pruned, 2 ) ;
224+ assert ! ( !done) ;
225+ assert_eq ! ( calls. load( Ordering :: SeqCst ) , pruned + 1 ) ;
226+
227+ provider. commit ( ) . expect ( "commit" ) ;
228+ assert_eq ! ( db. table:: <tables:: TransactionSenders >( ) . unwrap( ) . len( ) , total - 2 ) ;
229+ }
230+
231+ #[ test]
232+ fn prune_table_with_iterator_consumes_to_end_reports_done ( ) {
233+ let db = TestStageDB :: default ( ) ;
234+ let mut rng = generators:: rng ( ) ;
235+
236+ let blocks = random_block_range (
237+ & mut rng,
238+ 1 ..=2 ,
239+ BlockRangeParams {
240+ parent : Some ( alloy_primitives:: B256 :: ZERO ) ,
241+ tx_count : 1 ..2 ,
242+ ..Default :: default ( )
243+ } ,
244+ ) ;
245+ db. insert_blocks ( blocks. iter ( ) , StorageKind :: Database ( None ) ) . expect ( "insert blocks" ) ;
246+
247+ let mut tx_senders = Vec :: new ( ) ;
248+ for block in & blocks {
249+ for transaction in & block. body ( ) . transactions {
250+ tx_senders. push ( (
251+ tx_senders. len ( ) as u64 ,
252+ transaction. recover_signer ( ) . expect ( "recover signer" ) ,
253+ ) ) ;
254+ }
255+ }
256+ let total = tx_senders. len ( ) ;
257+ db. insert_transaction_senders ( tx_senders) . expect ( "insert transaction senders" ) ;
258+
259+ let provider = db. factory . database_provider_rw ( ) . unwrap ( ) ;
260+
261+ let calls = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
262+ let keys: Vec < u64 > = ( 0 ..total as u64 ) . collect ( ) ;
263+ let counting_iter = CountingIter :: new ( keys, calls. clone ( ) ) ;
264+
265+ let mut limiter = PruneLimiter :: default ( ) . set_deleted_entries_limit ( usize:: MAX ) ;
266+
267+ let ( pruned, done) = provider
268+ . tx_ref ( )
269+ . prune_table_with_iterator :: < tables:: TransactionSenders > (
270+ counting_iter,
271+ & mut limiter,
272+ |_| { } ,
273+ )
274+ . expect ( "prune" ) ;
275+
276+ assert_eq ! ( pruned, total) ;
277+ assert ! ( done) ;
278+ assert_eq ! ( calls. load( Ordering :: SeqCst ) , total + 1 ) ;
279+
280+ provider. commit ( ) . expect ( "commit" ) ;
281+ assert_eq ! ( db. table:: <tables:: TransactionSenders >( ) . unwrap( ) . len( ) , 0 ) ;
282+ }
283+ }
0 commit comments