Skip to content

Commit

Permalink
Generate test strings with multi-byte characters
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikvanantwerpen committed Oct 18, 2024
1 parent 0907c88 commit d430615
Showing 1 changed file with 37 additions and 10 deletions.
47 changes: 37 additions & 10 deletions crates/bpe/src/byte_pair_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,20 +567,47 @@ fn is_char_boundary(b: u8) -> bool {
/// Builds a random test string of at least `min_bytes` bytes by concatenating
/// the byte sequences of randomly chosen tokens from `bpe`.
///
/// Individual tokens are not required to be valid UTF-8 on their own: a
/// multi-byte character may be split across consecutive tokens. A token is kept
/// only if the accumulated bytes, excluding the possibly incomplete last
/// character, still form valid UTF-8; otherwise it is removed again. If several
/// attempts in a row fail, the previously accepted token is dropped as well.
///
/// # Panics
///
/// Panics if the `last >= valid_bytes` assertion fails, or if the final bytes
/// are unexpectedly not valid UTF-8 ("should be valid here").
// NOTE(review): this listing is a rendered diff without +/- markers, so the
// implementation being replaced (the lines using `text`, the bare `loop`, and
// the `all(is_char_boundary)` check) appears interleaved with the new one.
// Confirm against the repository before treating this span as compilable code.
#[cfg(feature = "rand")]
pub fn create_test_string(bpe: &BytePairEncoding, min_bytes: usize) -> String {
use rand::{thread_rng, Rng};
let mut text = String::new();
while text.len() < min_bytes {
loop {
// the bytes we accumulated thus far
let mut bytes = Vec::new();
// the tokens we added so we can backtrack
let mut tokens = Vec::new();
// the number of valid UTF-8 bytes
let mut valid_bytes = 0;
'keep: while valid_bytes < min_bytes {
// try a few times to find a suitable token
for _ in 0..8 {
// pick a random token and provisionally add it
let i = thread_rng().gen_range(0..bpe.num_tokens());
let s = bpe.token_bytes(i as u32);
if s.iter().all(|b| is_char_boundary(*b)) {
if let Ok(s) = std::str::from_utf8(s) {
text.push_str(s);
break;
}
bytes.extend(bpe.token_bytes(i as u32));
// test if the additional bytes are valid utf-8
// the last character is not included, because it may be incomplete
// `last` is the index of the final byte for which `is_char_boundary` holds,
// or 0 when no such byte exists
let last = bytes
.iter()
.rev()
.find_position(|b| is_char_boundary(**b))
.map_or(0, |(offset, _)| bytes.len() - (offset + 1));
assert!(last >= valid_bytes);
// only the bytes added since the last validated position need to be checked
if std::str::from_utf8(&bytes[valid_bytes..last]).is_ok() {
tokens.push(i);
valid_bytes = last;
continue 'keep;
} else {
// undo the provisional append
// (presumably `token_len(i)` equals `token_bytes(i).len()` — TODO confirm)
bytes.truncate(bytes.len() - bpe.token_len(i as u32));
}
}
// we didn't find anything after a few tries, backtrack
if let Some(i) = tokens.pop() {
bytes.truncate(bytes.len() - bpe.token_len(i as u32));
// recompute the validated prefix length the same way `last` is computed above
valid_bytes = bytes
.iter()
.rev()
.find_position(|b| is_char_boundary(**b))
.map_or(0, |(offset, _)| bytes.len() - (offset + 1));
}
}
text
// truncate to the known valid bytes
bytes.truncate(valid_bytes);
String::from_utf8(bytes).expect("should be valid here")
}

#[cfg(feature = "rand")]
Expand Down

0 comments on commit d430615

Please sign in to comment.