Merge pull request #30 from github/unsplittable-test-strings
Generate non-splittable test strings for worstcase benchmark
hendrikvanantwerpen authored Oct 22, 2024
2 parents e20fc1a + 0cb520e commit 17d5c3e
Showing 9 changed files with 400 additions and 302 deletions.
5 changes: 2 additions & 3 deletions crates/bpe/README.md
@@ -294,9 +294,8 @@ This suggests that pre-tokenization is not necessary from a performance perspective.
 
 ![encoding runtime comparison](./images/performance-comparison.svg)
 
-The graph below shows encoding results for input that is particularly challenging for tiktoken.
-The input consists of random ranges taken from the continuous list of all Unicode code points excluding whitespace.
-The performance of tiktoken shows a quadratic growth with the input size.
+The graph below shows encoding results when the input cannot be split in pre-tokenization and allows a better comparison of pure BPE performance.
+This case is particularly challenging for tiktoken, which shows a quadratic growth with the input size.
 The Huggingface encoder scales better, but becomes slower and slower compared to our implementation as input size increases.
 
 ![worst-case encoding runtime comparison](./images/performance-worstcase.svg)
16 changes: 10 additions & 6 deletions crates/bpe/benchmarks/performance.rs
@@ -1,7 +1,9 @@
 use std::time::Duration;
 
 use bpe::appendable_encoder::AppendableEncoder;
-use bpe::byte_pair_encoding::{create_test_string, select_test_string};
+use bpe::byte_pair_encoding::{
+    create_test_string, create_test_string_with_predicate, select_test_string,
+};
 use bpe::interval_encoding::IntervalEncoding;
 use bpe_benchmarks::*;
 use criterion::{
@@ -11,7 +13,7 @@ use rand::{thread_rng, Rng};
 
 fn counting_benchmark(c: &mut Criterion) {
     for (name, bpe, _, _) in TOKENIZERS.iter() {
-        let input = create_test_string(&bpe.bpe, 80000);
+        let input = create_test_string(&bpe.bpe, 80_000);
         let fast = IntervalEncoding::new(&bpe.bpe, input.as_bytes());
 
         let mut group = c.benchmark_group(format!("counting-{name}"));
@@ -185,19 +187,21 @@ fn comparison_benchmark(c: &mut Criterion) {
 }
 
 fn worstcase_comparison_benchmark(c: &mut Criterion) {
-    for (name, bpe, tiktoken, huggingface) in TOKENIZERS.iter() {
-        let text: String = ('\0'..char::MAX).filter(|c| !c.is_whitespace()).collect();
+    for (name, tok, tiktoken, huggingface) in TOKENIZERS.iter() {
+        let text = create_test_string_with_predicate(&tok.bpe, 100000, |text| {
+            tok.split(text).nth(1).is_none()
+        });
 
         let mut group = c.benchmark_group(format!("worstcase-{name}"));
-        for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000, 75000, 100000] {
+        for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000] {
             group.throughput(criterion::Throughput::Bytes(bytes as u64));
             group.bench_with_input(
                 BenchmarkId::new("backtracking", bytes),
                 &bytes,
                 |b, bytes| {
                     b.iter_batched(
                         || select_test_string(&text, *bytes),
-                        |text| bpe.encode(text),
+                        |text| tok.encode(text),
                         criterion::BatchSize::SmallInput,
                     )
                 },
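The new worst-case input is produced by `create_test_string_with_predicate`, and the predicate `tok.split(text).nth(1).is_none()` accepts a candidate string only if pre-tokenization leaves it as a single piece. A minimal standalone sketch of the `nth(1).is_none()` idiom (plain Rust, not code from this PR):

fn main() {
    // `nth(1)` skips one element and returns the next, so `is_none()` holds
    // exactly when an iterator yields at most one item.
    assert!(std::iter::once("whole input").nth(1).is_none()); // one pre-token: keep
    assert!(["first", "second"].into_iter().nth(1).is_some()); // splits: reject
}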
100 changes: 56 additions & 44 deletions crates/bpe/images/performance-appending.svg
116 changes: 67 additions & 49 deletions crates/bpe/images/performance-comparison.svg
96 changes: 56 additions & 40 deletions crates/bpe/images/performance-counting.svg
140 changes: 76 additions & 64 deletions crates/bpe/images/performance-encoding.svg
164 changes: 82 additions & 82 deletions crates/bpe/images/performance-worstcase.svg
47 changes: 39 additions & 8 deletions crates/bpe/src/byte_pair_encoding.rs
@@ -553,18 +553,49 @@ impl BytePairEncoding {
     }
 }
 
-/// Generate a test string by concatenating random tokens.
+/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least [`min_bytes`] long.
 #[cfg(feature = "rand")]
 pub fn create_test_string(bpe: &BytePairEncoding, min_bytes: usize) -> String {
+    create_test_string_with_predicate(bpe, min_bytes, |_| true)
+}
+
+/// Create a random test string for the given [`BytePairEncoding`]. The string will be at least [`min_bytes`] long.
+/// The given predicate enforces other properties on the generated string. Note that this can hurt performance or
+/// even cause non-termination!
+#[cfg(feature = "rand")]
+pub fn create_test_string_with_predicate(
+    bpe: &BytePairEncoding,
+    min_bytes: usize,
+    predicate: impl Fn(&str) -> bool,
+) -> String {
     use rand::{thread_rng, Rng};
+    // the string we accumulated thus far
     let mut result = String::new();
-    while result.len() < min_bytes {
-        let i = thread_rng().gen_range(0..bpe.num_tokens());
-        // We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's
-        // token set. The chance of constructing a valid UTF-8 character across a token boundary
-        // by picking random tokens is so small that it is unlikely to happen anyway.
-        if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i as u32)) {
-            result.push_str(token);
+    // the tokens we added so we can backtrack
+    let mut tokens = Vec::new();
+    'keep: while result.len() < min_bytes {
+        // try a few times to find a suitable token
+        'next: for _ in 0..8 {
+            // pick a random token and provisionally add it
+            let i = thread_rng().gen_range(0..bpe.num_tokens()) as u32;
+            // We only use tokens that are valid UTF-8. This is true for ~99% of tokens in OpenAI's
+            // token set. The chance of constructing a valid UTF-8 character across a token boundary
+            // by picking random tokens is so small that it is unlikely to happen anyway.
+            if let Ok(token) = std::str::from_utf8(bpe.token_bytes(i)) {
+                result.push_str(token);
+            } else {
+                continue 'next;
+            }
+            if predicate(&result) {
+                tokens.push(i);
+                continue 'keep;
+            } else {
+                result.truncate(result.len() - bpe.token_len(i));
+            }
+        }
+        // we didn't find anything after a few tries, backtrack
+        if let Some(i) = tokens.pop() {
+            result.truncate(result.len() - bpe.token_len(i));
         }
     }
     result
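A usage sketch for the new helper. The setup is hypothetical (`bpe` may be any constructed `BytePairEncoding`, with the `rand` feature enabled); the whitespace-free predicate is purely illustrative, while the benchmark above instead requires that pre-tokenization cannot split the string:

use bpe::byte_pair_encoding::{
    create_test_string, create_test_string_with_predicate, BytePairEncoding,
};

fn make_test_inputs(bpe: &BytePairEncoding) -> (String, String) {
    // Unconstrained generation, equivalent to the old create_test_string behavior.
    let any = create_test_string(bpe, 80_000);
    // Constrained generation: every accepted prefix must satisfy the predicate,
    // so the generator backtracks whenever a randomly picked token adds whitespace.
    let no_whitespace = create_test_string_with_predicate(bpe, 80_000, |s| {
        !s.chars().any(char::is_whitespace)
    });
    (any, no_whitespace)
}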
18 changes: 12 additions & 6 deletions crates/geo_filters/evaluation/accuracy.rs
@@ -105,7 +105,8 @@ impl Accuracy {
             .config
             .iter()
             .map(|c| {
-                simulation_config_from_str(c).expect(&format!("not a valid configuration: {}", c))
+                simulation_config_from_str(c)
+                    .unwrap_or_else(|_| panic!("not a valid configuration: {}", c))
             })
             .collect_vec();
         let set_sizes = if self.set_size.is_empty() {
@@ -118,9 +119,10 @@
 
         let mut output = self.output;
         output.set_extension("csv");
-        let f = File::create(&output).expect(&format!("cannot create file: {}", output.display()));
+        let f = File::create(&output)
+            .unwrap_or_else(|_| panic!("cannot create file: {}", output.display()));
         write_simulation_results(&configs, &set_sizes, results, f)
-            .expect(&format!("cannot write file: {}", output.display()));
+            .unwrap_or_else(|_| panic!("cannot write file: {}", output.display()));
         println!(" csv file = {}", output.display());
         println!();
     }
@@ -139,9 +141,9 @@ impl SimulationConfigParser {
         Self(Regex::new(re).expect(""), Arc::new(f))
     }
 
-    fn parse<'a>(&self, name: &str) -> Option<SimulationConfig> {
+    fn parse(&self, name: &str) -> Option<SimulationConfig> {
         self.0
-            .captures(&name)
+            .captures(name)
             .map(self.1.as_ref())
            .map(|p| (name.to_string(), p))
     }
@@ -225,7 +227,11 @@ fn simulation_config_from_str(name: &str) -> Result<SimulationConfig, String> {
 fn capture_usizes<const N: usize>(c: &Captures, is: [usize; N]) -> [usize; N] {
     let mut values = [0; N];
     for i in 0..is.len() {
-        values[i] = usize::from_str_radix(c.get(is[i]).expect("capture to exist").as_str(), 10)
+        values[i] = c
+            .get(is[i])
+            .expect("capture to exist")
+            .as_str()
+            .parse::<usize>()
             .expect("number string");
     }
     values
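The accuracy.rs changes are mechanical lint fixes: `expect(&format!(…))` builds the message string eagerly, even when the `Result` is `Ok` (clippy's `expect_fun_call`), whereas `unwrap_or_else(|_| panic!(…))` only formats on the error path; likewise `parse::<usize>()` replaces the redundant `from_str_radix(…, 10)`. A small sketch of the pattern, using a hypothetical `open_output` helper:

use std::fs::File;
use std::path::Path;

fn open_output(path: &Path) -> File {
    // Eager (lints as clippy::expect_fun_call): the message is formatted
    // even when File::create succeeds.
    //     File::create(path).expect(&format!("cannot create file: {}", path.display()))
    // Lazy: the message is only formatted if File::create returns Err.
    File::create(path).unwrap_or_else(|_| panic!("cannot create file: {}", path.display()))
}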
