
Commit 170165e

Merge pull request #13 from github/aneubeck/dict
Add constructor from dictionary and count_till_limit function
2 parents: 4ce06d1 + 89ac128

File tree

1 file changed: +30 −11 lines changed


crates/bpe/src/byte_pair_encoding.rs

Lines changed: 30 additions & 11 deletions
@@ -2,7 +2,6 @@ use std::cmp::Reverse;
 use std::collections::BinaryHeap;
 use std::hash::{Hash, Hasher};
 use std::ops::Range;
-use std::time::Instant;
 
 use daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
 use fnv::{FnvHashMap, FnvHasher};
@@ -181,32 +180,32 @@ impl BytePairEncoding {
         &BPE_CL100K
     }
 
+    /// Construct a BytePairEncoding instance from a tiktoken dictionary.
     pub fn from_tiktoken(tiktoken_bpe: &CoreBPE, num_tokens: usize) -> Self {
-        let start = Instant::now();
-        println!("loaded tiktoken: {:?}", start.elapsed());
+        Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
+    }
+
+    /// Construct a BytePairEncoding instance from an iterator which enumerates all tokens.
+    pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>) -> Self {
         let mut all_tokens = Vec::new();
         let mut all_tokens_rev = Vec::new();
         let mut token_starts = vec![0];
         let mut bytes_hash_to_token = FnvHashMap::default();
-        for i in 0..num_tokens {
-            let token = tiktoken_bpe._decode_native(&[i]);
+        for (i, token) in iter.enumerate() {
             bytes_hash_to_token.insert(hash_bytes(&token), i as u32);
             all_tokens_rev.extend(token.iter().copied().rev());
             all_tokens.extend(token);
             token_starts.push(all_tokens.len() as u32);
         }
         assert_eq!(bytes_hash_to_token.len() + 1, token_starts.len());
-        println!("copied tokens: {:?}", start.elapsed());
 
         let longest_searcher = DoubleArrayAhoCorasickBuilder::new()
             .match_kind(daachorse::MatchKind::LeftmostLongest)
             .build(token_iter(&all_tokens, &token_starts))
             .expect("failed to build AhoCorasick");
-        println!("constructed longest searcher: {:?}", start.elapsed());
 
         let overlapping_searcher =
             DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens, &token_starts)).expect("");
-        println!("constructed overlapping searcher: {:?}", start.elapsed());
         let overlapping_searcher_rev =
             DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens_rev, &token_starts))
                 .expect("");
@@ -216,7 +215,6 @@ impl BytePairEncoding {
                 next_match(&longest_searcher, &token[0..token.len() - 1]).unwrap_or(u32::MAX)
             })
             .collect();
-        println!("constructed next_prefix_match: {:?}", start.elapsed());
 
         let mut split_table = vec![];
         let mut pair_lookup = FnvHashMap::default();
@@ -243,8 +241,6 @@ impl BytePairEncoding {
                 split_table.push((id as u32, id as u32));
             }
         }
-        println!("constructed split table: {:?}", start.elapsed());
-
         Self {
             all_tokens,
             token_starts,
@@ -339,12 +335,35 @@ impl BytePairEncoding {
         last_token
     }
 
+    /// Counts the number of tokens produced when encoding the text.
     pub fn count(&self, text: &[u8]) -> usize {
         let mut enc = BacktrackEncoder::new(self, text);
         while enc.step().is_some() {}
         enc.count()
     }
 
+    /// Returns the token count iff the total token count stays within the specified `token_limit`.
+    /// Otherwise, it returns `None`.
+    /// This function can be faster than `count` when the `token_limit` is much smaller than the
+    /// token count of the provided text.
+    pub fn count_till_limit(&self, text: &[u8], token_limit: usize) -> Option<usize> {
+        let mut enc = BacktrackEncoder::new(self, text);
+        // When the text has exactly the desired number of tokens, it could in theory happen that
+        // the token_limit is exceeded before the end of the text is reached (while a different
+        // encoding is being tested). To be on the "safe" side, we add a little buffer for such cases.
+        // TODO: Determine exactly how large this buffer must be in the worst case.
+        let limit_with_buffer = token_limit.saturating_add(10);
+        while enc.step().is_some() {
+            if enc.count() > limit_with_buffer {
+                return None;
+            }
+        }
+        if enc.count() <= token_limit {
+            Some(enc.count())
+        } else {
+            None
+        }
+    }
+
     pub fn encode_via_table(&self, text: &[u8]) -> Vec<u32> {
         let last_token = self.encode_all_prefixes(text);
         let mut encoded = Vec::with_capacity(text.len() / 3);
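
count_till_limit is built for budget checks: when the caller only cares whether a text fits within a token budget, the encoder can bail out as soon as the running count passes the buffered limit instead of encoding the whole text, and returning Option<usize> rather than a bool still hands back the exact count when the text fits. A hedged sketch of such a check; the fits_budget helper is hypothetical, and bpe is assumed to be constructed as in the previous example.

use bpe::byte_pair_encoding::BytePairEncoding; // assumed module path

/// Hypothetical helper: true iff `text` encodes to at most `budget` tokens.
fn fits_budget(bpe: &BytePairEncoding, text: &[u8], budget: usize) -> bool {
    match bpe.count_till_limit(text, budget) {
        // Fits: `n` is the exact token count, with n <= budget.
        Some(n) => {
            eprintln!("{n} tokens, within budget of {budget}");
            true
        }
        // Over budget; the encoder may have aborted early.
        None => false,
    }
}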
