@@ -2,7 +2,6 @@ use std::cmp::Reverse;
 use std::collections::BinaryHeap;
 use std::hash::{Hash, Hasher};
 use std::ops::Range;
-use std::time::Instant;
 
 use daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
 use fnv::{FnvHashMap, FnvHasher};
@@ -181,32 +180,32 @@ impl BytePairEncoding {
         &BPE_CL100K
     }
 
+    /// Construct a BytePairEncoding instance from a tiktoken dictionary.
     pub fn from_tiktoken(tiktoken_bpe: &CoreBPE, num_tokens: usize) -> Self {
-        let start = Instant::now();
-        println!("loaded tiktoken: {:?}", start.elapsed());
+        Self::from_dictionary((0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])))
+    }
+
+    /// Construct a BytePairEncoding instance from an iterator which enumerates all tokens.
+    pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>) -> Self {
         let mut all_tokens = Vec::new();
         let mut all_tokens_rev = Vec::new();
         let mut token_starts = vec![0];
         let mut bytes_hash_to_token = FnvHashMap::default();
-        for i in 0..num_tokens {
-            let token = tiktoken_bpe._decode_native(&[i]);
+        for (i, token) in iter.enumerate() {
             bytes_hash_to_token.insert(hash_bytes(&token), i as u32);
             all_tokens_rev.extend(token.iter().copied().rev());
             all_tokens.extend(token);
             token_starts.push(all_tokens.len() as u32);
         }
         assert_eq!(bytes_hash_to_token.len() + 1, token_starts.len());
-        println!("copied tokens: {:?}", start.elapsed());
 
         let longest_searcher = DoubleArrayAhoCorasickBuilder::new()
             .match_kind(daachorse::MatchKind::LeftmostLongest)
             .build(token_iter(&all_tokens, &token_starts))
             .expect("failed to build AhoCorasick");
-        println!("constructed longest searcher: {:?}", start.elapsed());
 
         let overlapping_searcher =
             DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens, &token_starts)).expect("");
-        println!("constructed overlapping searcher: {:?}", start.elapsed());
         let overlapping_searcher_rev =
             DoubleArrayAhoCorasick::<u32>::new(token_iter(&all_tokens_rev, &token_starts))
                 .expect("");
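The hunk above splits construction in two: `from_tiktoken` becomes a thin wrapper that decodes each token id via `_decode_native`, while all dictionary-agnostic work moves into `from_dictionary`. As a minimal sketch (not part of this diff; the toy token list is assumed), the new constructor can be fed token bytes from any source:

```rust
// Sketch: build an encoding directly from raw token bytes, without a tiktoken
// CoreBPE. Token ids are assigned in iteration order, matching the
// `enumerate()` call inside `from_dictionary`.
let tokens: Vec<Vec<u8>> = vec![
    b"a".to_vec(),
    b"b".to_vec(),
    b"ab".to_vec(), // a merged token comes after the tokens it splits into
];
let bpe = BytePairEncoding::from_dictionary(tokens.into_iter());
```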
@@ -216,7 +215,6 @@ impl BytePairEncoding {
                 next_match(&longest_searcher, &token[0..token.len() - 1]).unwrap_or(u32::MAX)
             })
             .collect();
-        println!("constructed next_prefix_match: {:?}", start.elapsed());
 
         let mut split_table = vec![];
         let mut pair_lookup = FnvHashMap::default();
@@ -243,8 +241,6 @@ impl BytePairEncoding {
                 split_table.push((id as u32, id as u32));
             }
         }
-        println!("constructed split table: {:?}", start.elapsed());
-
         Self {
             all_tokens,
             token_starts,
@@ -339,12 +335,35 @@ impl BytePairEncoding {
         last_token
     }
 
+    /// Counts the number of tokens produced when encoding the text.
     pub fn count(&self, text: &[u8]) -> usize {
         let mut enc = BacktrackEncoder::new(self, text);
         while enc.step().is_some() {}
         enc.count()
     }
 
+    /// Returns the token count if the total token count stays below the specified `token_limit`.
+    /// Otherwise, `None` is returned.
+    /// This function can be faster than `count` when the token limit is much smaller than the
+    /// token count of the provided text.
+    pub fn count_till_limit(&self, text: &[u8], token_limit: usize) -> Option<usize> {
+        let mut enc = BacktrackEncoder::new(self, text);
+        // When the text has exactly the desired number of tokens, it could in theory happen that
+        // the token_limit is exceeded before the end of the text is reached (and a different
+        // encoding is tested). To be on the "safe" side, we add a little buffer for such cases.
+        // TODO: Determine exactly how large this buffer must be in the worst case.
+        let limit_with_buffer = token_limit.saturating_add(10);
+        while enc.step().is_some() {
+            if enc.count() > limit_with_buffer {
+                return None;
+            }
+        }
+        if enc.count() <= token_limit {
+            Some(enc.count())
+        } else {
+            None
+        }
+    }
+
     pub fn encode_via_table(&self, text: &[u8]) -> Vec<u32> {
         let last_token = self.encode_all_prefixes(text);
         let mut encoded = Vec::with_capacity(text.len() / 3);
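A short usage sketch for the new early-exit counter. The `cl100k()` accessor name is an assumption based on the `&BPE_CL100K` static visible in the second hunk; the relevant part is the `Option`-based contract of `count_till_limit`:

```rust
// Sketch: check whether a text fits a token budget without always paying for
// a full count. `BytePairEncoding::cl100k()` is assumed to return &BPE_CL100K.
let bpe = BytePairEncoding::cl100k();
let text = "the quick brown fox".as_bytes();
match bpe.count_till_limit(text, 8) {
    Some(n) => println!("fits the budget: {n} tokens"),
    None => println!("over 8 tokens; counting stopped early"),
}
```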