Skip to content

Commit c11f0a2

Browse files
committed
add a version of traverse_path that's optimized for small int
1 parent 7a59879 commit c11f0a2

File tree

2 files changed

+105
-3
lines changed

2 files changed

+105
-3
lines changed

src/run_program.rs

+11-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
use super::traverse_path::traverse_path;
2-
use crate::allocator::{Allocator, Checkpoint, NodePtr, SExp};
1+
use super::traverse_path::{traverse_path, traverse_path_fast};
2+
use crate::allocator::{Allocator, Checkpoint, NodePtr, NodeVisitor, SExp};
33
use crate::cost::Cost;
44
use crate::dialect::{Dialect, OperatorSet};
55
use crate::err_utils::err;
@@ -279,7 +279,15 @@ impl<'a, D: Dialect> RunProgramContext<'a, D> {
279279
// put a bunch of ops on op_stack
280280
let SExp::Pair(op_node, op_list) = self.allocator.sexp(program) else {
281281
// the program is just a bitfield path through the env tree
282-
let r: Reduction = traverse_path(self.allocator, self.allocator.atom(program), env)?;
282+
let r: Reduction = self.allocator.visit_node(program, |node| -> Response {
283+
match node {
284+
NodeVisitor::Buffer(buf) => traverse_path(self.allocator, buf, env),
285+
NodeVisitor::U32(val) => traverse_path_fast(self.allocator, *val, env),
286+
NodeVisitor::Pair(_, _) => {
287+
panic!("expected atom, got pair");
288+
}
289+
}
290+
})?;
283291
self.push(r.1)?;
284292
return Ok(r.0);
285293
};

src/traverse_path.rs

+94
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,42 @@ pub fn traverse_path(allocator: &Allocator, node_index: &[u8], args: NodePtr) ->
7272
Ok(Reduction(cost, arg_list))
7373
}
7474

75+
// Fast variant of traverse_path for a path that is already a u32 rather than a byte buffer.
// The cost calculation for this version of traverse_path assumes the node_index has the canonical
76+
// integer representation (which is true for SmallAtom in the allocator). If there are any
77+
// redundant leading zeros, the slow path must be used.
//
// The path is consumed from the least significant bit upward: a 0 bit steps to the left child
// of the current pair and a 1 bit steps to the right child. The most significant set bit only
// marks the end of the path and selects nothing. A node_index of 0 returns nil without
// looking at args.
//
// Returns the node that was reached together with the accumulated cost, or a
// "path into atom" error carrying the offending node when a non-pair is reached while path
// bits remain.
78+
pub fn traverse_path_fast(allocator: &Allocator, mut node_index: u32, args: NodePtr) -> Response {
79+
// a path of 0 is charged the same fixed cost as a path with no steps and resolves to nil
if node_index == 0 {
80+
return Ok(Reduction(
81+
TRAVERSE_BASE_COST + TRAVERSE_COST_PER_BIT,
82+
allocator.nil(),
83+
));
84+
}
85+
86+
// the node the path currently points at; starts at the root of args
let mut arg_list: NodePtr = args;
87+
88+
// fixed part of the cost; the per-step part is added once the loop is done
let mut cost: Cost = TRAVERSE_BASE_COST + TRAVERSE_COST_PER_BIT;
89+
// number of steps taken, i.e. path bits consumed below the terminating top bit
let mut num_bits = 0;
90+
// node_index == 1 means only the terminating bit is left, so the walk is complete
while node_index != 1 {
91+
// every remaining step needs a pair to descend into
let SExp::Pair(left, right) = allocator.sexp(arg_list) else {
92+
return Err(EvalErr(arg_list, "path into atom".into()));
93+
};
94+
95+
// lowest bit picks the branch: 1 = right, 0 = left
let is_bit_set: bool = (node_index & 0x01) != 0;
96+
arg_list = if is_bit_set { right } else { left };
97+
node_index >>= 1;
98+
num_bits += 1
99+
}
100+
101+
cost += num_bits * TRAVERSE_COST_PER_BIT;
102+
// A positive number whose top byte has its high bit set (e.g. 0x80, 0x8000) needs a leading
103+
// zero byte in its canonical form, so we also add the cost of that zero byte. That is the
// case exactly when the terminating bit is bit 7, 15, 23 or 31, i.e. after that many steps.
104+
if num_bits == 7 || num_bits == 15 || num_bits == 23 || num_bits == 31 {
105+
cost += TRAVERSE_COST_PER_ZERO_BYTE;
106+
}
107+
108+
Ok(Reduction(cost, arg_list))
109+
}
110+
75111
#[test]
76112
fn test_msb_mask() {
77113
assert_eq!(msb_mask(0x0), 0x0);
@@ -166,3 +202,61 @@ fn test_traverse_path() {
166202
EvalErr(n2, "path into atom".to_string())
167203
);
168204
}
205+
206+
// Exercises traverse_path_fast on atoms, a single pair and a two-element list, checking both
// the node that is reached and the reported cost, followed by paths that run into an atom.
// In the expectations below a path with no steps costs 44 and every step adds 4.
#[test]
207+
fn test_traverse_path_fast_fast() {
208+
use crate::allocator::Allocator;
209+
210+
let mut a = Allocator::new();
211+
let nul = a.nil();
212+
let n1 = a.new_atom(&[0, 1, 2]).unwrap();
213+
let n2 = a.new_atom(&[4, 5, 6]).unwrap();
214+
215+
// path 0 yields nil regardless of the environment; path 1 yields the environment itself
assert_eq!(traverse_path_fast(&a, 0, n1).unwrap(), Reduction(44, nul));
216+
assert_eq!(traverse_path_fast(&a, 0b1, n1).unwrap(), Reduction(44, n1));
217+
assert_eq!(traverse_path_fast(&a, 0b1, n2).unwrap(), Reduction(44, n2));
218+
219+
// n3 = (n1 . n2): 0b10 takes the left branch, 0b11 the right branch
let n3 = a.new_pair(n1, n2).unwrap();
220+
assert_eq!(traverse_path_fast(&a, 0b1, n3).unwrap(), Reduction(44, n3));
221+
assert_eq!(traverse_path_fast(&a, 0b10, n3).unwrap(), Reduction(48, n1));
222+
assert_eq!(traverse_path_fast(&a, 0b11, n3).unwrap(), Reduction(48, n2));
223+
// NOTE(review): this repeats the previous assertion verbatim
assert_eq!(traverse_path_fast(&a, 0b11, n3).unwrap(), Reduction(48, n2));
224+
225+
// list = (n2 . (n1 . nil))
let list = a.new_pair(n1, nul).unwrap();
226+
let list = a.new_pair(n2, list).unwrap();
227+
228+
// one step left: the first element
assert_eq!(
229+
traverse_path_fast(&a, 0b10, list).unwrap(),
230+
Reduction(48, n2)
231+
);
232+
// right, then left: the second element
assert_eq!(
233+
traverse_path_fast(&a, 0b101, list).unwrap(),
234+
Reduction(52, n1)
235+
);
236+
// right, then right: the terminating nil
assert_eq!(
237+
traverse_path_fast(&a, 0b111, list).unwrap(),
238+
Reduction(52, nul)
239+
);
240+
241+
// errors: each path below still has steps left when it arrives at an atom, and the error
// carries the atom that was reached
242+
assert_eq!(
243+
traverse_path_fast(&a, 0b1011, list).unwrap_err(),
244+
EvalErr(nul, "path into atom".to_string())
245+
);
246+
assert_eq!(
247+
traverse_path_fast(&a, 0b1101, list).unwrap_err(),
248+
EvalErr(n1, "path into atom".to_string())
249+
);
250+
assert_eq!(
251+
traverse_path_fast(&a, 0b1001, list).unwrap_err(),
252+
EvalErr(n1, "path into atom".to_string())
253+
);
254+
assert_eq!(
255+
traverse_path_fast(&a, 0b1010, list).unwrap_err(),
256+
EvalErr(n2, "path into atom".to_string())
257+
);
258+
assert_eq!(
259+
traverse_path_fast(&a, 0b1110, list).unwrap_err(),
260+
EvalErr(n2, "path into atom".to_string())
261+
);
262+
}

0 commit comments

Comments
 (0)