Skip to content

Commit 0f04c63

Browse files
committed
use allocator_api and bumpalo to speed up find_path()
1 parent 316cc83 commit 0f04c63

File tree

5 files changed

+30
-21
lines changed

5 files changed

+30
-21
lines changed

Cargo.lock

+3-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ clap = "4.5.29"
6363
rand_chacha = "0.3.1"
6464
bitvec = "1.0.1"
6565
arbitrary = { version = "1.4.1", features = ["derive"] }
66+
bumpalo = { version = "3.17.0", features = ["allocator_api"] }
6667

6768
[dependencies]
6869
lazy_static = { workspace = true }
@@ -81,6 +82,7 @@ sha3 = "0.10.8"
8182
rand = { workspace = true }
8283
hex = { workspace = true }
8384
sha1 = { workspace = true }
85+
bumpalo = { workspace = true }
8486

8587
[dev-dependencies]
8688
rstest = { workspace = true }

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#![feature(allocator_api)]
2+
13
pub mod allocator;
24
pub mod bls_ops;
35
pub mod chia_dialect;

src/serde/path_builder.rs

+15-14
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::alloc::Allocator;
2+
13
#[repr(u8)]
24
#[derive(PartialEq, Eq, Clone, Debug, Copy, Hash)]
35
pub enum ChildPos {
@@ -10,25 +12,23 @@ pub enum ChildPos {
1012
/// path is built from left to right, since it's parsed right to left when
1113
/// followed).
1214
#[derive(Clone, Debug, PartialEq)]
13-
pub struct PathBuilder {
15+
pub struct PathBuilder<A: Allocator> {
1416
// TODO: It might make sense to implement small object optimization here.
1517
// The vast majority of paths are just a single byte, statically allocate 8
1618
// would seem reasonable
17-
store: Vec<u8>,
19+
store: Vec<u8, A>,
1820
/// the bit the next write will happen to (counts down)
1921
bit_pos: u8,
2022
}
2123

22-
impl Default for PathBuilder {
23-
fn default() -> Self {
24+
impl<A: Allocator> PathBuilder<A> {
25+
pub fn new(allocator: A) -> Self {
2426
Self {
25-
store: Vec::with_capacity(16),
27+
store: Vec::with_capacity_in(16, allocator),
2628
bit_pos: 7,
2729
}
2830
}
29-
}
3031

31-
impl PathBuilder {
3232
pub fn clear(&mut self) {
3333
self.bit_pos = 7;
3434
self.store.clear();
@@ -49,7 +49,7 @@ impl PathBuilder {
4949
}
5050
}
5151

52-
pub fn done(mut self) -> Vec<u8> {
52+
pub fn done(mut self) -> Vec<u8, A> {
5353
if self.bit_pos < 7 {
5454
let right_shift = self.bit_pos + 1;
5555
let left_shift = 7 - self.bit_pos;
@@ -138,9 +138,10 @@ mod tests {
138138
use crate::serde::serialized_length_atom;
139139
use hex;
140140
use rstest::rstest;
141+
use std::alloc::System;
141142

142-
fn build_path(input: &[u8]) -> PathBuilder {
143-
let mut path = PathBuilder::default();
143+
fn build_path(input: &[u8]) -> PathBuilder<System> {
144+
let mut path = PathBuilder::new(System);
144145
// keep in mind that paths are built in reverse order (starting from the
145146
// target).
146147
for (idx, b) in input.iter().enumerate() {
@@ -215,7 +216,7 @@ mod tests {
215216
#[case(80, 80, "ffffffffffffffffffff")]
216217
#[case(80, 79, "7fffffffffffffffffff")]
217218
fn test_truncate(#[case] num_bits: usize, #[case] truncate: u32, #[case] expect: &str) {
218-
let mut path = PathBuilder::default();
219+
let mut path = PathBuilder::new(System);
219220
for _i in 0..num_bits {
220221
path.push(ChildPos::Right);
221222
}
@@ -260,7 +261,7 @@ mod tests {
260261
#[case(80, 15, "ffff")]
261262
#[case(80, 79, "ffffffffffffffffffff")]
262263
fn test_truncate_add(#[case] num_bits: usize, #[case] truncate: u32, #[case] expect: &str) {
263-
let mut path = PathBuilder::default();
264+
let mut path = PathBuilder::new(System);
264265
for _i in 0..num_bits {
265266
path.push(ChildPos::Right);
266267
}
@@ -274,7 +275,7 @@ mod tests {
274275
fn test_clear(
275276
#[values(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)] num_bits: usize,
276277
) {
277-
let mut path = PathBuilder::default();
278+
let mut path = PathBuilder::new(System);
278279
for _i in 0..num_bits {
279280
path.push(ChildPos::Right);
280281
}
@@ -331,7 +332,7 @@ mod tests {
331332
#[case(513)]
332333
#[case(0xfff9)]
333334
fn test_serialized_length(#[case] num_bits: u32) {
334-
let mut path = PathBuilder::default();
335+
let mut path = PathBuilder::new(System);
335336
for _ in 0..num_bits {
336337
path.push(ChildPos::Right);
337338
}

src/serde/tree_cache.rs

+8-5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::allocator::{Allocator, NodePtr, SExp};
33
use crate::serde::serialized_length_atom;
44
use crate::serde::RandomState;
55
use crate::serde::VisitedNodes;
6+
use bumpalo::Bump;
67
use rand::prelude::*;
78
use sha1::{Digest, Sha1};
89
use std::collections::hash_map::Entry;
@@ -39,9 +40,9 @@ struct NodeEntry {
3940
pub on_stack: u32,
4041
}
4142

42-
struct PartialPath {
43+
struct PartialPath<'alloc> {
4344
// the path we've built so far
44-
path: PathBuilder,
45+
path: PathBuilder<&'alloc Bump>,
4546
// if we're traversing the stack, this is the stack position. Note that this
4647
// is not an index into the stack array, it's a counter of how far away from
4748
// the top of the stack we are. 0 means we're at the top, and we've found
@@ -387,6 +388,8 @@ impl TreeCache {
387388

388389
let mut seen = VisitedNodes::new(self.node_entry.len() as u32);
389390

391+
let arena = Bump::new();
392+
390393
// We perform a breadth-first search from the node we're finding a path
391394
// to, up through its parents until we find the top of the stack. Note
392395
// since nodes are deduplicated, they may have multiple parents.
@@ -400,7 +403,7 @@ impl TreeCache {
400403

401404
// this child pos represents the path terminator bit
402405
partial_paths.push(PartialPath {
403-
path: PathBuilder::default(),
406+
path: PathBuilder::new(&arena),
404407
stack_pos: -1,
405408
idx,
406409
child: ChildPos::Right,
@@ -410,7 +413,7 @@ impl TreeCache {
410413
// the ones whose length is "path_length", which is incremented for every pass
411414
let mut pass_length = 0;
412415

413-
let ret: PathBuilder = loop {
416+
let ret: PathBuilder<&Bump> = loop {
414417
if partial_paths.is_empty() {
415418
return None;
416419
}
@@ -533,7 +536,7 @@ impl TreeCache {
533536
if u64::from(backref_len) + 1 > entry.serialized_length {
534537
None
535538
} else {
536-
Some(ret.done())
539+
Some(ret.done().to_vec())
537540
}
538541
}
539542
}

0 commit comments

Comments
 (0)