Merge pull request #12 from github/aneubeck/prepend
Support prependable encoder
aneubeck authored Aug 15, 2024
Parents: 0a012df + 2797581 · Commit: 4ce06d1
Showing 8 changed files with 177 additions and 34 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -3,6 +3,7 @@
 A collection of useful algorithms written in Rust. Currently contains:
 
 - [`geo_filters`](crates/geo_filters): probabilistic data structures that solve the [Distinct Count Problem](https://en.wikipedia.org/wiki/Count-distinct_problem) using geometric filters.
+- [`bpe`](crates/bpe): fast, correct, and novel algorithms for the [Byte Pair Encoding Algorithm](https://en.wikipedia.org/wiki/Large_language_model#BPE) which are particularly useful for chunking of documents.
 
 ## Background

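The chunking use-case this new bullet mentions is what the incremental encoders in this PR are built for: the running token count is maintained byte by byte, so a chunker can grow a chunk until a token budget is exceeded. Below is a minimal sketch of that idea; the `AppendableEncoder` calls are taken from the diff further down, while the helper function, its name, and the module paths are assumptions for illustration, not part of this commit.

use bpe::appendable_encoder::AppendableEncoder;
use bpe::byte_pair_encoding::BytePairEncoding;

// Hypothetical helper: length in bytes of the longest prefix of `text`
// that encodes to at most `max_tokens` tokens.
fn chunk_prefix_len(bpe: &BytePairEncoding, text: &[u8], max_tokens: usize) -> usize {
    let mut enc = AppendableEncoder::new(bpe);
    for (i, b) in text.iter().enumerate() {
        enc.push(*b); // amortized O(1) per appended byte
        if enc.token_count() > max_tokens {
            return i; // the prefix before this byte still fit the budget
        }
    }
    text.len() // the whole text fits into one chunk
}

fn main() {
    let bpe = BytePairEncoding::cl100k();
    let len = chunk_prefix_len(bpe, "Hello, world!".as_bytes(), 2);
    println!("first chunk: {len} bytes");
}

Since `token_count()` is an O(1) query, the scan stays linear in the chunk length.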
2 changes: 1 addition & 1 deletion crates/bpe/Cargo.toml
@@ -9,7 +9,7 @@ bench = false
 
 [dependencies]
 #daachorse = "1"
-daachorse = { git = "https://github.com/aneubeck/daachorse.git", rev = "22f471532a25d90a320eae0902c759db2b8fe962" }
+daachorse = { git = "https://github.com/aneubeck/daachorse.git", rev = "ac44a471a7be5a139535173073b8f1cd2e33bcbd" }
 fnv = "1.0"
 itertools = "0.12"
 once_cell = "1"
69 changes: 44 additions & 25 deletions crates/bpe/src/appendable_encoder.rs
@@ -1,26 +1,22 @@
-use daachorse::bytewise::iter::OverlappingStepper;
-
 use crate::byte_pair_encoding::BytePairEncoding;
 
-/// Encoder which keeps track of the encoding length.
+struct State {
+    state: u32,
+    last_token: u32,
+    count: u32,
+}
+
+/// Encoder which keeps track of the encoding length while appending characters.
 pub struct AppendableEncoder<'a> {
     bpe: &'a BytePairEncoding,
-    stepper: OverlappingStepper<'a, u32>,
-    // TODO: If we only want to answer the length of the input text, then we could
-    // replace these vectors with some fixed size arrays. Essentially we can only
-    // go back up to the length of the longest token. This way we can save some memory
-    // and reallocations.
-    last_token: Vec<u32>,
-    counts: Vec<u32>,
+    states: Vec<State>,
 }
 
 impl<'a> AppendableEncoder<'a> {
     pub fn new(bpe: &'a BytePairEncoding) -> Self {
         Self {
             bpe,
-            stepper: bpe.overlapping_searcher.overlapping_stepper(),
-            last_token: vec![],
-            counts: vec![],
+            states: vec![],
         }
     }
 
@@ -31,24 +27,43 @@ impl<'a> AppendableEncoder<'a> {
         }
     }
 
+    pub fn truncate(&mut self, len: usize) {
+        self.states.truncate(len);
+    }
+
     /// Appends a byte to the input string which should be tokenized.
     /// The operation is amortized O(1) (due to vector resizing).
     pub fn push(&mut self, c: u8) {
-        self.stepper.consume(c);
-        while let Some(m) = self.stepper.next() {
+        let (state, iter) = self.bpe.overlapping_searcher.consume(
+            self.states
+                .last()
+                .map(|s| s.state)
+                .unwrap_or_else(|| self.bpe.overlapping_searcher.start_state()),
+            self.states.len() + 1,
+            c,
+        );
+        for m in iter {
             let new_token = m.value();
             let new_range = m.start()..m.end();
-            assert_eq!(new_range.end, self.last_token.len() + 1);
+            assert_eq!(new_range.end, self.states.len() + 1);
             if new_range.start == 0 {
-                self.last_token.push(new_token);
-                self.counts.push(1);
+                self.states.push(State {
+                    state,
+                    last_token: new_token,
+                    count: 1,
+                });
                 break;
             } else {
-                let prev_token = unsafe { *self.last_token.get_unchecked(new_range.start - 1) };
+                let prev_token =
+                    unsafe { self.states.get_unchecked(new_range.start - 1).last_token };
                 if self.bpe.is_valid_token_pair(prev_token, new_token) {
-                    self.last_token.push(new_token);
-                    let prev_count = unsafe { *self.counts.get_unchecked(new_range.start - 1) };
-                    self.counts.push(prev_count + 1);
+                    let prev_count =
+                        unsafe { self.states.get_unchecked(new_range.start - 1).count };
+                    self.states.push(State {
+                        state,
+                        last_token: new_token,
+                        count: prev_count + 1,
+                    });
                     break;
                 }
             }
@@ -57,13 +72,17 @@ impl<'a> AppendableEncoder<'a> {
 
     /// Returns the number of tokens required to tokenize the input text.
     /// This operation is O(1) and can be called at any point in time.
+    pub fn token_count(&self) -> usize {
+        self.states.last().map(|s| s.count).unwrap_or(0) as usize
+    }
+
     pub fn len(&self) -> usize {
-        self.counts.last().copied().unwrap_or(0) as usize
+        self.states.len()
     }
 
     /// Returns true if the structure represents the empty string.
     pub fn is_empty(&self) -> bool {
-        self.counts.is_empty()
+        self.states.is_empty()
     }
 }

@@ -79,7 +98,7 @@ mod tests {
         let mut enc = AppendableEncoder::new(bpe);
         let input_string = create_test_bytes(bpe, 100);
         for (i, c) in input_string.iter().enumerate() {
-            assert_eq!(enc.len(), bpe.count(&input_string[0..i]));
+            assert_eq!(enc.token_count(), bpe.count(&input_string[0..i]));
             enc.push(*c);
         }
     }
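Taken together, the changes in this file replace the daachorse `OverlappingStepper` with explicit state threading through `consume(state, pos, byte)`, and merge the parallel `last_token`/`counts` vectors into a single `Vec<State>` that additionally snapshots the automaton state at every position. That snapshot is what makes the new `truncate` method sound: rolling the encoder back to a shorter prefix is a plain vector truncation, and the next `push` resumes from the state stored at the new last position. A usage sketch of the post-merge API (not compiled here; `cl100k()` is the constructor the crate's own tests use, and `extend` is the bulk-append counterpart of `push`):

use bpe::appendable_encoder::AppendableEncoder;
use bpe::byte_pair_encoding::BytePairEncoding;

fn main() {
    let bpe = BytePairEncoding::cl100k();
    let mut enc = AppendableEncoder::new(bpe);

    // Append bytes one at a time; the token count is maintained incrementally.
    enc.extend("Hello, world!".bytes());
    let full = enc.token_count(); // O(1) query for the whole string

    // Roll back the last 8 bytes: this just truncates the internal State
    // vector, after which pushes resume from the stored automaton state.
    enc.truncate(enc.len() - 8);
    assert!(enc.token_count() <= full);
    println!("tokens after truncate: {}", enc.token_count());
}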
30 changes: 22 additions & 8 deletions crates/bpe/src/byte_pair_encoding.rs
@@ -49,6 +49,12 @@ pub struct BytePairEncoding {
         deserialize_with = "deserialize_daac"
     )]
     pub(crate) overlapping_searcher: DoubleArrayAhoCorasick<u32>,
+    /// An Aho-Corasick automaton to find ALL tokens in a byte sequence which is being processed in reverse order.
+    #[serde(
+        serialize_with = "serialize_daac",
+        deserialize_with = "deserialize_daac"
+    )]
+    pub(crate) overlapping_searcher_rev: DoubleArrayAhoCorasick<u32>,
     /// Mapping from a token to the next longest prefix token.
     /// This is in principle information represented by the AhoCorasick automaton.
     /// But we don't have efficient access to it and therefore store it here again.
@@ -179,11 +185,13 @@ impl BytePairEncoding {
         let start = Instant::now();
         println!("loaded tiktoken: {:?}", start.elapsed());
         let mut all_tokens = Vec::new();
+        let mut all_tokens_rev = Vec::new();
         let mut token_starts = vec![0];
         let mut bytes_hash_to_token = FnvHashMap::default();
         for i in 0..num_tokens {
             let token = tiktoken_bpe._decode_native(&[i]);
             bytes_hash_to_token.insert(hash_bytes(&token), i as u32);
+            all_tokens_rev.extend(token.iter().copied().rev());
             all_tokens.extend(token);
             token_starts.push(all_tokens.len() as u32);
         }
@@ -199,6 +207,9 @@ impl BytePairEncoding {
         let overlapping_searcher =
             DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens, &token_starts)).expect("");
         println!("constructed overlapping searcher: {:?}", start.elapsed());
+        let overlapping_searcher_rev =
+            DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens_rev, &token_starts))
+                .expect("");
 
         let next_prefix_match: Vec<_> = token_iter(&all_tokens, &token_starts)
             .map(|token| {
Expand Down Expand Up @@ -239,6 +250,7 @@ impl BytePairEncoding {
token_starts,
bytes_hash_to_token,
overlapping_searcher,
overlapping_searcher_rev,
longest_searcher,
next_prefix_match,
pair_lookup,
@@ -304,10 +316,11 @@ impl BytePairEncoding {
     /// Computes for every prefix of the input text a corresponding last token.
     pub(crate) fn encode_all_prefixes(&self, text: &[u8]) -> Vec<u32> {
         let mut last_token = Vec::with_capacity(text.len());
-        let mut stepper = self.overlapping_searcher.overlapping_stepper();
-        for c in text {
-            stepper.consume(*c);
-            while let Some(m) = stepper.next() {
+        let mut state = self.overlapping_searcher.start_state();
+        for (pos, c) in text.iter().enumerate() {
+            let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c);
+            state = s;
+            for m in iter {
                 let new_token = m.value();
                 let new_range = m.start()..m.end();
                 assert_eq!(new_range.end, last_token.len() + 1);
@@ -420,11 +433,12 @@ impl BytePairEncoding {
     /// tokenization produced by the original BPE algorithm.
     pub fn encode_minimal(&self, text: &[u8]) -> Vec<u32> {
        let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len());
-        let mut stepper = self.overlapping_searcher.overlapping_stepper();
-        for c in text {
-            stepper.consume(*c);
+        let mut state = self.overlapping_searcher.start_state();
+        for (pos, c) in text.iter().enumerate() {
+            let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c);
+            state = s;
             let mut best = (0, u32::MAX);
-            while let Some(m) = stepper.next() {
+            for m in iter {
                 if m.start() == 0 {
                     best = (m.value(), 1);
                     break;
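The reason for the second automaton: the prependable encoder (new file below) grows the text leftwards, while an Aho-Corasick scan only moves forwards. Building `overlapping_searcher_rev` over byte-reversed tokens (`all_tokens_rev` above) turns back-to-front consumption of the text into an ordinary forward scan: a token sits at the front of the current suffix exactly when its reversed bytes sit at the end of the reversed stream. A standalone illustration of that invariant with hypothetical token bytes (plain Rust, no crate dependencies):

fn main() {
    let token: &[u8] = b"he";
    let text: &[u8] = b"hello";

    // Reverse both, mirroring how `all_tokens_rev` is filled above.
    let token_rev: Vec<u8> = token.iter().copied().rev().collect();
    let text_rev: Vec<u8> = text.iter().copied().rev().collect();

    // A prefix match in the original text ...
    assert!(text.starts_with(token));
    // ... is a suffix match on the reversed stream, which the reversed
    // automaton detects while bytes are fed to it back to front.
    assert!(text_rev.ends_with(&token_rev));
    println!("ok");
}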
Binary file modified crates/bpe/src/data/bpe_cl100k.dict
1 change: 1 addition & 0 deletions crates/bpe/src/lib.rs
@@ -3,3 +3,4 @@ pub mod backtrack_encoder
 mod bitfield;
 pub mod byte_pair_encoding;
 pub mod interval_encoding;
+pub mod prependable_encoder;
105 changes: 105 additions & 0 deletions crates/bpe/src/prependable_encoder.rs
@@ -0,0 +1,105 @@
use crate::byte_pair_encoding::BytePairEncoding;

struct State {
    state: u32,
    prev_token: u32,
    count: u32,
}

/// Encoder which keeps track of the encoding length while prepending characters.
pub struct PrependableEncoder<'a> {
    bpe: &'a BytePairEncoding,
    states: Vec<State>,
}

impl<'a> PrependableEncoder<'a> {
    pub fn new(bpe: &'a BytePairEncoding) -> Self {
        Self {
            bpe,
            states: vec![],
        }
    }

    pub fn truncate(&mut self, len: usize) {
        self.states.truncate(len);
    }

    /// Prepends multiple bytes to the input string.
    pub fn extend(&mut self, iter: impl Iterator<Item = u8>) {
        for c in iter {
            self.push(c);
        }
    }

    /// Prepends a byte to the input string which should be tokenized.
    /// The operation is amortized O(1) (due to vector resizing).
    pub fn push(&mut self, c: u8) {
        let (state, iter) = self.bpe.overlapping_searcher_rev.consume(
            self.states
                .last()
                .map(|s| s.state)
                .unwrap_or_else(|| self.bpe.overlapping_searcher_rev.start_state()),
            self.states.len() + 1,
            c,
        );
        for m in iter {
            let new_token = m.value();
            let new_range = m.start()..m.end();
            assert_eq!(new_range.end, self.states.len() + 1);
            if new_range.start == 0 {
                self.states.push(State {
                    state,
                    prev_token: new_token,
                    count: 1,
                });
                break;
            } else {
                let next_token =
                    unsafe { self.states.get_unchecked(new_range.start - 1).prev_token };
                if self.bpe.is_valid_token_pair(new_token, next_token) {
                    let prev_count =
                        unsafe { self.states.get_unchecked(new_range.start - 1).count };
                    self.states.push(State {
                        state,
                        prev_token: new_token,
                        count: prev_count + 1,
                    });
                    break;
                }
            }
        }
    }

    /// Returns the number of tokens required to tokenize the input text.
    /// This operation is O(1) and can be called at any point in time.
    pub fn token_count(&self) -> usize {
        self.states.last().map(|s| s.count).unwrap_or(0) as usize
    }

    pub fn len(&self) -> usize {
        self.states.len()
    }

    /// Returns true if the structure represents the empty string.
    pub fn is_empty(&self) -> bool {
        self.states.is_empty()
    }
}

#[cfg(test)]
mod tests {
    use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};

    use super::PrependableEncoder;

    #[test]
    fn test_prependable_encoder() {
        let bpe = BytePairEncoding::cl100k();
        let mut enc = PrependableEncoder::new(bpe);
        let input_string = create_test_bytes(bpe, 100);
        for (i, c) in input_string.iter().enumerate().rev() {
            enc.push(*c);
            assert_eq!(enc.token_count(), bpe.count(&input_string[i..]));
        }
    }
}
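The new encoder is the mirror image of `AppendableEncoder`: it stores `prev_token` chains instead of `last_token` chains, drives the byte-reversed `overlapping_searcher_rev`, and flips the compatibility check to `is_valid_token_pair(new_token, next_token)` because the newly added token now sits to the left of its neighbour. A usage sketch mirroring the test above (assumed, not compiled here; module paths follow the `lib.rs` change above):

use bpe::byte_pair_encoding::BytePairEncoding;
use bpe::prependable_encoder::PrependableEncoder;

fn main() {
    let bpe = BytePairEncoding::cl100k();
    let mut enc = PrependableEncoder::new(bpe);

    // Feed the text back to front: each push() logically prepends one byte,
    // so after processing index i the encoder represents the suffix text[i..].
    let text = "Hello, world!".as_bytes();
    for b in text.iter().rev() {
        enc.push(*b);
    }
    println!("tokens for the full string: {}", enc.token_count());
}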
3 changes: 3 additions & 0 deletions crates/geo_filters/src/diff_count.rs
@@ -518,6 +518,9 @@ mod tests {
         let masked_a = masked(&a, mask, mask_size);
         let masked_b = masked(&b, mask, mask_size);
         let masked_expected = masked(&expected, mask, mask_size);
+        // FIXME: test failed once with:
+        //   left: ~12.37563 (msb: [390, 334, 263, 242, 222, 215, 164, 148, 100, 97, 66, 36], |lsb|: 36)
+        //   right: ~12.37563 (msb: [390, 334, 263, 242, 222, 215, 164, 148, 100, 97, 66, 36], |lsb|: 0)
         assert_eq!(masked_expected, xor(&masked_a, &masked_b));
     }
 }