Merge pull request #13 from github/aneubeck/dict
Add constructor from dictionary and count_till_limit function
aneubeck authored Aug 16, 2024
2 parents 4ce06d1 + 89ac128 commit 170165e
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions crates/bpe/src/byte_pair_encoding.rs
@@ -2,7 +2,6 @@ use std::cmp::Reverse;
 use std::collections::BinaryHeap;
 use std::hash::{Hash, Hasher};
 use std::ops::Range;
-use std::time::Instant;
 
 use daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
 use fnv::{FnvHashMap, FnvHasher};
@@ -181,32 +180,32 @@ impl BytePairEncoding {
         &BPE_CL100K
     }
 
+    /// Construct a BytePairEncoding instance from a tiktoken dictionary.
     pub fn from_tiktoken(tiktoken_bpe: &CoreBPE, num_tokens: usize) -> Self {
-        let start = Instant::now();
-        println!("loaded tiktoken: {:?}", start.elapsed());
+        Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
+    }
+
+    /// Construct a BytePairEncoding instance from an iterator which enumerates all tokens.
+    pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>) -> Self {
         let mut all_tokens = Vec::new();
         let mut all_tokens_rev = Vec::new();
         let mut token_starts = vec![0];
         let mut bytes_hash_to_token = FnvHashMap::default();
-        for i in 0..num_tokens {
-            let token = tiktoken_bpe._decode_native(&[i]);
+        for (i, token) in iter.enumerate() {
             bytes_hash_to_token.insert(hash_bytes(&token), i as u32);
             all_tokens_rev.extend(token.iter().copied().rev());
             all_tokens.extend(token);
             token_starts.push(all_tokens.len() as u32);
         }
         assert_eq!(bytes_hash_to_token.len() + 1, token_starts.len());
-        println!("copied tokens: {:?}", start.elapsed());
 
         let longest_searcher = DoubleArrayAhoCorasickBuilder::new()
             .match_kind(daachorse::MatchKind::LeftmostLongest)
             .build(token_iter(&all_tokens, &token_starts))
             .expect("failed to build AhoCorasick");
-        println!("constructed longest searcher: {:?}", start.elapsed());
 
         let overlapping_searcher =
             DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens, &token_starts)).expect("");
-        println!("constructed overlapping searcher: {:?}", start.elapsed());
         let overlapping_searcher_rev =
             DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens_rev, &token_starts))
                 .expect("");
@@ -216,7 +215,6 @@ impl BytePairEncoding {
                 next_match(&longest_searcher, &token[0..token.len() - 1]).unwrap_or(u32::MAX)
             })
             .collect();
-        println!("constructed next_prefix_match: {:?}", start.elapsed());
 
         let mut split_table = vec![];
         let mut pair_lookup = FnvHashMap::default();
@@ -243,8 +241,6 @@ impl BytePairEncoding {
                 split_table.push((id as u32, id as u32));
             }
         }
-        println!("constructed split table: {:?}", start.elapsed());
-
         Self {
             all_tokens,
             token_starts,
@@ -339,12 +335,35 @@ impl BytePairEncoding {
         last_token
     }
 
+    /// Counts the number of tokens produced when encoding the text.
     pub fn count(&self, text: &[u8]) -> usize {
         let mut enc = BacktrackEncoder::new(self, text);
         while enc.step().is_some() {}
         enc.count()
     }
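A brief usage sketch, not part of the commit, assuming one of the `bpe` instances constructed above:

    // Number of tokens this input would encode to.
    let n = bpe.count("Hello, world!".as_bytes());
    println!("{n} tokens");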

+    /// Returns the token count iff the total token count stays within the specified `token_limit`.
+    /// Otherwise, it returns `None`.
+    /// This function can be faster than `count` when `token_limit` is much smaller than the text's full token count.
+    pub fn count_till_limit(&self, text: &[u8], token_limit: usize) -> Option<usize> {
+        let mut enc = BacktrackEncoder::new(self, text);
+        // When the text has exactly the desired number of tokens, it could in theory happen that
+        // the token_limit is exceeded before the end of the text is reached (while a different
+        // encoding is being tested). To be on the "safe" side, we add a little buffer for such cases.
+        // TODO: Determine exactly how large this buffer must be in the worst case.
+        let limit_with_buffer = token_limit.saturating_add(10);
+        while enc.step().is_some() {
+            if enc.count() > limit_with_buffer {
+                return None;
+            }
+        }
+        if enc.count() <= token_limit {
+            Some(enc.count())
+        } else {
+            None
+        }
+    }
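A sketch of the new early-exit counter under the same assumptions; the limit of 10 is arbitrary:

    // Some(n) if the text encodes to at most 10 tokens, None otherwise. For long
    // inputs this returns early, shortly after the limit (plus the small safety
    // buffer) is exceeded, instead of encoding the whole text.
    match bpe.count_till_limit(b"a potentially very long input text", 10) {
        Some(n) => println!("fits: {n} tokens"),
        None => println!("exceeds 10 tokens"),
    }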

     pub fn encode_via_table(&self, text: &[u8]) -> Vec<u32> {
         let last_token = self.encode_all_prefixes(text);
         let mut encoded = Vec::with_capacity(text.len() / 3);
