Commit 4ce06d1

Merge pull request #12 from github/aneubeck/prepend

Support prependable encoder

2 parents 0a012df + 2797581

8 files changed: +177 additions, -34 deletions

README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -3,6 +3,7 @@
 A collection of useful algorithms written in Rust. Currently contains:
 
 - [`geo_filters`](crates/geo_filters): probabilistic data structures that solve the [Distinct Count Problem](https://en.wikipedia.org/wiki/Count-distinct_problem) using geometric filters.
+- [`bpe`](crates/bpe): fast, correct, and novel algorithms for the [Byte Pair Encoding Algorithm](https://en.wikipedia.org/wiki/Large_language_model#BPE) which are particularly useful for chunking of documents.
 
 ## Background
 
```

crates/bpe/Cargo.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -9,7 +9,7 @@ bench = false
 
 [dependencies]
 #daachorse = "1"
-daachorse = { git = "https://github.com/aneubeck/daachorse.git", rev = "22f471532a25d90a320eae0902c759db2b8fe962" }
+daachorse = { git = "https://github.com/aneubeck/daachorse.git", rev = "ac44a471a7be5a139535173073b8f1cd2e33bcbd" }
 fnv = "1.0"
 itertools = "0.12"
 once_cell = "1"
```

crates/bpe/src/appendable_encoder.rs

Lines changed: 44 additions & 25 deletions

```diff
@@ -1,26 +1,22 @@
-use daachorse::bytewise::iter::OverlappingStepper;
-
 use crate::byte_pair_encoding::BytePairEncoding;
 
-/// Encoder which keeps track of the encoding length.
+struct State {
+    state: u32,
+    last_token: u32,
+    count: u32,
+}
+
+/// Encoder which keeps track of the encoding length while appending characters.
 pub struct AppendableEncoder<'a> {
     bpe: &'a BytePairEncoding,
-    stepper: OverlappingStepper<'a, u32>,
-    // TODO: If we only want to answer the length of the input text, then we could
-    // replace these vectors with some fixed size arrays. Essentially we can only
-    // go back up to the length of the longest token. This way can save some memory
-    // and reallocations.
-    last_token: Vec<u32>,
-    counts: Vec<u32>,
+    states: Vec<State>,
 }
 
 impl<'a> AppendableEncoder<'a> {
     pub fn new(bpe: &'a BytePairEncoding) -> Self {
         Self {
             bpe,
-            stepper: bpe.overlapping_searcher.overlapping_stepper(),
-            last_token: vec![],
-            counts: vec![],
+            states: vec![],
         }
     }
 
@@ -31,24 +27,43 @@ impl<'a> AppendableEncoder<'a> {
         }
     }
 
+    pub fn truncate(&mut self, len: usize) {
+        self.states.truncate(len);
+    }
+
     /// Appends a byte to the input string which should be tokenized.
     /// The operation is amortized O(1) (due to vector resizing).
     pub fn push(&mut self, c: u8) {
-        self.stepper.consume(c);
-        while let Some(m) = self.stepper.next() {
+        let (state, iter) = self.bpe.overlapping_searcher.consume(
+            self.states
+                .last()
+                .map(|s| s.state)
+                .unwrap_or_else(|| self.bpe.overlapping_searcher.start_state()),
+            self.states.len() + 1,
+            c,
+        );
+        for m in iter {
             let new_token = m.value();
             let new_range = m.start()..m.end();
-            assert_eq!(new_range.end, self.last_token.len() + 1);
+            assert_eq!(new_range.end, self.states.len() + 1);
             if new_range.start == 0 {
-                self.last_token.push(new_token);
-                self.counts.push(1);
+                self.states.push(State {
+                    state,
+                    last_token: new_token,
+                    count: 1,
+                });
                 break;
             } else {
-                let prev_token = unsafe { *self.last_token.get_unchecked(new_range.start - 1) };
+                let prev_token =
+                    unsafe { self.states.get_unchecked(new_range.start - 1).last_token };
                 if self.bpe.is_valid_token_pair(prev_token, new_token) {
-                    self.last_token.push(new_token);
-                    let prev_count = unsafe { *self.counts.get_unchecked(new_range.start - 1) };
-                    self.counts.push(prev_count + 1);
+                    let prev_count =
+                        unsafe { self.states.get_unchecked(new_range.start - 1).count };
+                    self.states.push(State {
+                        state,
+                        last_token: new_token,
+                        count: prev_count + 1,
+                    });
                     break;
                 }
             }
@@ -57,13 +72,17 @@ impl<'a> AppendableEncoder<'a> {
 
     /// Returns the number of tokens required to tokenize the input text.
     /// This operation is O(1) and can be called at any point in time.
+    pub fn token_count(&self) -> usize {
+        self.states.last().map(|s| s.count).unwrap_or(0) as usize
+    }
+
     pub fn len(&self) -> usize {
-        self.counts.last().copied().unwrap_or(0) as usize
+        self.states.len()
    }
 
     /// Returns true if the structure represents the empty string.
     pub fn is_empty(&self) -> bool {
-        self.counts.is_empty()
+        self.states.is_empty()
     }
 }
 
@@ -79,7 +98,7 @@ mod tests {
         let mut enc = AppendableEncoder::new(bpe);
         let input_string = create_test_bytes(bpe, 100);
         for (i, c) in input_string.iter().enumerate() {
-            assert_eq!(enc.len(), bpe.count(&input_string[0..i]));
+            assert_eq!(enc.token_count(), bpe.count(&input_string[0..i]));
            enc.push(*c);
         }
     }
```
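The key change in this refactor is that each pushed byte now records the automaton state alongside the last token and token count in a single `State` entry. Since all per-byte information lives in one vector, rolling the encoder back to a shorter prefix is just `truncate`, and the next `push` resumes from the stored state. A minimal usage sketch (the input string is made up, and the `bpe::appendable_encoder` module path is assumed from the crate layout):

```rust
use bpe::appendable_encoder::AppendableEncoder;
use bpe::byte_pair_encoding::BytePairEncoding;

fn main() {
    let bpe = BytePairEncoding::cl100k();
    let mut enc = AppendableEncoder::new(bpe);
    for b in "Hello, world!".bytes() {
        enc.push(b);
    }
    let tokens = enc.token_count(); // O(1): token count for the current prefix
    let mark = enc.len(); // number of bytes pushed so far

    for b in " And some more text.".bytes() {
        enc.push(b);
    }

    // Roll back to the marked prefix; the token count is restored with it.
    enc.truncate(mark);
    assert_eq!(enc.token_count(), tokens);
}
```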

crates/bpe/src/byte_pair_encoding.rs

Lines changed: 22 additions & 8 deletions

```diff
@@ -49,6 +49,12 @@ pub struct BytePairEncoding {
         deserialize_with = "deserialize_daac"
     )]
     pub(crate) overlapping_searcher: DoubleArrayAhoCorasick<u32>,
+    /// An aho corasick automaton to find ALL tokens in a byte sequence which is being processed in reverse order.
+    #[serde(
+        serialize_with = "serialize_daac",
+        deserialize_with = "deserialize_daac"
+    )]
+    pub(crate) overlapping_searcher_rev: DoubleArrayAhoCorasick<u32>,
     /// Mapping from a token to the next longest prefix token.
     /// This is in principle information represented by the AhoCorasick automaton.
     /// But we don't have efficient access to it and therefore store it here again.
@@ -179,11 +185,13 @@ impl BytePairEncoding {
         let start = Instant::now();
         println!("loaded tiktoken: {:?}", start.elapsed());
         let mut all_tokens = Vec::new();
+        let mut all_tokens_rev = Vec::new();
         let mut token_starts = vec![0];
         let mut bytes_hash_to_token = FnvHashMap::default();
         for i in 0..num_tokens {
             let token = tiktoken_bpe._decode_native(&[i]);
             bytes_hash_to_token.insert(hash_bytes(&token), i as u32);
+            all_tokens_rev.extend(token.iter().copied().rev());
             all_tokens.extend(token);
             token_starts.push(all_tokens.len() as u32);
         }
@@ -199,6 +207,9 @@ impl BytePairEncoding {
         let overlapping_searcher =
             DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens, &token_starts)).expect("");
         println!("constructed overlapping searcher: {:?}", start.elapsed());
+        let overlapping_searcher_rev =
+            DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens_rev, &token_starts))
+                .expect("");
 
         let next_prefix_match: Vec<_> = token_iter(&all_tokens, &token_starts)
             .map(|token| {
@@ -239,6 +250,7 @@ impl BytePairEncoding {
             token_starts,
             bytes_hash_to_token,
             overlapping_searcher,
+            overlapping_searcher_rev,
             longest_searcher,
             next_prefix_match,
             pair_lookup,
@@ -304,10 +316,11 @@ impl BytePairEncoding {
     /// Computes for every prefix of the input text a corresponding last token.
     pub(crate) fn encode_all_prefixes(&self, text: &[u8]) -> Vec<u32> {
         let mut last_token = Vec::with_capacity(text.len());
-        let mut stepper = self.overlapping_searcher.overlapping_stepper();
-        for c in text {
-            stepper.consume(*c);
-            while let Some(m) = stepper.next() {
+        let mut state = self.overlapping_searcher.start_state();
+        for (pos, c) in text.iter().enumerate() {
+            let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c);
+            state = s;
+            for m in iter {
                 let new_token = m.value();
                 let new_range = m.start()..m.end();
                 assert_eq!(new_range.end, last_token.len() + 1);
@@ -420,11 +433,12 @@ impl BytePairEncoding {
     /// tokenization produced by the original BPE algorithm.
     pub fn encode_minimal(&self, text: &[u8]) -> Vec<u32> {
         let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len());
-        let mut stepper = self.overlapping_searcher.overlapping_stepper();
-        for c in text {
-            stepper.consume(*c);
+        let mut state = self.overlapping_searcher.start_state();
+        for (pos, c) in text.iter().enumerate() {
+            let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c);
+            state = s;
             let mut best = (0, u32::MAX);
-            while let Some(m) = stepper.next() {
+            for m in iter {
                 if m.start() == 0 {
                     best = (m.value(), 1);
                     break;
```
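The reversed automaton exists so that text can also be processed back to front: because every token is inserted with its bytes reversed, a match in `overlapping_searcher_rev` ending at offset `e` of the reversed input corresponds to a token starting `e` bytes before the end of the original input. A small sketch of this trick using the upstream daachorse 1.x API (`find_overlapping_iter`), not the forked `consume`/`start_state` API this commit relies on; the five-token vocabulary is made up:

```rust
use daachorse::DoubleArrayAhoCorasick;

fn main() {
    // A made-up token vocabulary; the commit builds this from the BPE tokens.
    let tokens = ["a", "b", "c", "ab", "abc"];
    // Reverse every token's bytes, mirroring `all_tokens_rev` in the commit.
    let rev_patterns: Vec<Vec<u8>> = tokens
        .iter()
        .map(|t| t.bytes().rev().collect())
        .collect();
    let ac = DoubleArrayAhoCorasick::<u32>::new(rev_patterns).unwrap();

    let text = "abc";
    let rev_text: Vec<u8> = text.bytes().rev().collect();
    for m in ac.find_overlapping_iter(&rev_text) {
        // A match ending at offset e in the reversed text is a token that
        // starts at offset text.len() - e in the original text.
        let start = text.len() - m.end();
        println!("token {:?} starts at {}", tokens[m.value() as usize], start);
    }
}
```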

crates/bpe/src/data/bpe_cl100k.dict

3.3 MB (binary file not shown)

crates/bpe/src/lib.rs

Lines changed: 1 addition & 0 deletions

```diff
@@ -3,3 +3,4 @@ pub mod backtrack_encoder;
 mod bitfield;
 pub mod byte_pair_encoding;
 pub mod interval_encoding;
+pub mod prependable_encoder;
```

crates/bpe/src/prependable_encoder.rs

Lines changed: 105 additions & 0 deletions

```diff
@@ -0,0 +1,105 @@
+use crate::byte_pair_encoding::BytePairEncoding;
+
+struct State {
+    state: u32,
+    prev_token: u32,
+    count: u32,
+}
+
+/// Encoder which keeps track of the encoding length while prepending characters.
+pub struct PrependableEncoder<'a> {
+    bpe: &'a BytePairEncoding,
+    states: Vec<State>,
+}
+
+impl<'a> PrependableEncoder<'a> {
+    pub fn new(bpe: &'a BytePairEncoding) -> Self {
+        Self {
+            bpe,
+            states: vec![],
+        }
+    }
+
+    pub fn truncate(&mut self, len: usize) {
+        self.states.truncate(len);
+    }
+
+    /// Prepends multiple bytes to the input string.
+    pub fn extend(&mut self, iter: impl Iterator<Item = u8>) {
+        for c in iter {
+            self.push(c);
+        }
+    }
+
+    /// Prepends a byte to the input string which should be tokenized.
+    /// The operation is amortized O(1) (due to vector resizing).
+    pub fn push(&mut self, c: u8) {
+        let (state, iter) = self.bpe.overlapping_searcher_rev.consume(
+            self.states
+                .last()
+                .map(|s| s.state)
+                .unwrap_or_else(|| self.bpe.overlapping_searcher_rev.start_state()),
+            self.states.len() + 1,
+            c,
+        );
+        for m in iter {
+            let new_token = m.value();
+            let new_range = m.start()..m.end();
+            assert_eq!(new_range.end, self.states.len() + 1);
+            if new_range.start == 0 {
+                self.states.push(State {
+                    state,
+                    prev_token: new_token,
+                    count: 1,
+                });
+                break;
+            } else {
+                let next_token =
+                    unsafe { self.states.get_unchecked(new_range.start - 1).prev_token };
+                if self.bpe.is_valid_token_pair(new_token, next_token) {
+                    let prev_count =
+                        unsafe { self.states.get_unchecked(new_range.start - 1).count };
+                    self.states.push(State {
+                        state,
+                        prev_token: new_token,
+                        count: prev_count + 1,
+                    });
+                    break;
+                }
+            }
+        }
+    }
+
+    /// Returns the number of tokens required to tokenize the input text.
+    /// This operation is O(1) and can be called at any point in time.
+    pub fn token_count(&self) -> usize {
+        self.states.last().map(|s| s.count).unwrap_or(0) as usize
+    }
+
+    pub fn len(&self) -> usize {
+        self.states.len()
+    }
+
+    /// Returns true if the structure represents the empty string.
+    pub fn is_empty(&self) -> bool {
+        self.states.is_empty()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
+
+    use super::PrependableEncoder;
+
+    #[test]
+    fn test_prependable_encoder() {
+        let bpe = BytePairEncoding::cl100k();
+        let mut enc = PrependableEncoder::new(bpe);
+        let input_string = create_test_bytes(bpe, 100);
+        for (i, c) in input_string.iter().enumerate().rev() {
+            enc.push(*c);
+            assert_eq!(enc.token_count(), bpe.count(&input_string[i..]));
+        }
+    }
+}
```
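This new encoder is the mirror image of `AppendableEncoder`: bytes are pushed from the end of the text toward the front, each `State` stores the previous token of the suffix encoding, and pair validity is checked with the argument order swapped (`is_valid_token_pair(new_token, next_token)`). A minimal usage sketch along the lines of the test above (the input string is made up; the `bpe::prependable_encoder` module path follows the `lib.rs` change in this commit):

```rust
use bpe::byte_pair_encoding::BytePairEncoding;
use bpe::prependable_encoder::PrependableEncoder;

fn main() {
    let bpe = BytePairEncoding::cl100k();
    let mut enc = PrependableEncoder::new(bpe);
    let text = "Hello, world!";
    // Push bytes back to front: after pushing byte i, the encoder
    // represents the suffix text[i..].
    for b in text.bytes().rev() {
        enc.push(b);
    }
    // O(1) query: number of tokens needed to encode the current suffix.
    println!("{} bytes -> {} tokens", enc.len(), enc.token_count());
}
```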

crates/geo_filters/src/diff_count.rs

Lines changed: 3 additions & 0 deletions

```diff
@@ -518,6 +518,9 @@
         let masked_a = masked(&a, mask, mask_size);
         let masked_b = masked(&b, mask, mask_size);
         let masked_expected = masked(&expected, mask, mask_size);
+        // FIXME: test failed once with:
+        // left:  ~12.37563 (msb: [390, 334, 263, 242, 222, 215, 164, 148, 100, 97, 66, 36], |lsb|: 36)
+        // right: ~12.37563 (msb: [390, 334, 263, 242, 222, 215, 164, 148, 100, 97, 66, 36], |lsb|: 0)
         assert_eq!(masked_expected, xor(&masked_a, &masked_b));
     }
 }
```
