Skip to content

Commit

Permalink
add count_till_limit function
Browse files Browse the repository at this point in the history
  • Loading branch information
aneubeck committed Aug 16, 2024
1 parent f98cbd4 commit 89ac128
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions crates/bpe/src/byte_pair_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,35 @@ impl BytePairEncoding {
last_token
}

/// Returns the number of tokens that encoding `text` produces.
pub fn count(&self, text: &[u8]) -> usize {
    let mut encoder = BacktrackEncoder::new(self, text);
    // Drive the encoder until it reports that no further step is available;
    // the intermediate step results are irrelevant, only the final count is used.
    loop {
        if encoder.step().is_none() {
            break;
        }
    }
    encoder.count()
}

/// Returns `Some(count)` if encoding `text` produces at most `token_limit` tokens
/// (the bound is inclusive), and `None` otherwise.
///
/// This function can be faster than `count` when `token_limit` is much smaller than the
/// number of tokens in the provided text, because encoding is abandoned as soon as the
/// running token count exceeds the limit plus a small safety buffer.
pub fn count_till_limit(&self, text: &[u8], token_limit: usize) -> Option<usize> {
    // When the text has exactly the desired number of tokens, then it could in theory happen
    // that the token_limit is exceeded before the end of the text is reached (and a different
    // encoding is tested). To be on the "safe" side, we add a little buffer for such cases.
    // TODO: Determine exactly how large this buffer must be in the worst case.
    const LIMIT_BUFFER: usize = 10;

    let mut enc = BacktrackEncoder::new(self, text);
    // `saturating_add` keeps the buffered limit from wrapping around for limits near `usize::MAX`.
    let limit_with_buffer = token_limit.saturating_add(LIMIT_BUFFER);
    while enc.step().is_some() {
        if enc.count() > limit_with_buffer {
            // Early exit: the running count is beyond the buffered limit.
            return None;
        }
    }
    // The encoder ran to completion; the buffer only applied while stepping, so the final
    // count is checked against the caller's actual limit.
    let total = enc.count();
    (total <= token_limit).then_some(total)
}

pub fn encode_via_table(&self, text: &[u8]) -> Vec<u32> {
let last_token = self.encode_all_prefixes(text);
let mut encoded = Vec::with_capacity(text.len() / 3);
Expand Down

0 comments on commit 89ac128

Please sign in to comment.