Skip to content

Commit

Permalink
Replace look-ahead with multiple patterns ==> 3x speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
aneubeck committed Oct 18, 2024
1 parent 5b127c9 commit 25188c8
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 21 deletions.
2 changes: 1 addition & 1 deletion crates/bpe-openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ bench = false
[dependencies]
bpe = { version = "0.1.0", path = "../bpe" }
either = "1.13"
fancy-regex = "0.13"
regex-automata = "0.4"
rmp-serde = "1"

[dev-dependencies]
Expand Down
88 changes: 68 additions & 20 deletions crates/bpe-openai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,48 @@ use std::sync::LazyLock;

use bpe::byte_pair_encoding::BytePairEncoding;
use either::Either;
use fancy_regex::Regex;
use regex_automata::{meta::Regex, util::captures::Captures, Anchored, Input};

/// Lazily initialised tokenizer for the `r50k_base` vocabulary.
///
/// The serialized dictionary is embedded with `include_bytes!` from `OUT_DIR` and decoded
/// with `rmp_serde` the first time the static is accessed.
// NOTE(review): this listing looks like a commit diff rendered without +/- markers, so the
// first `pat` binding (with the `\\s+(?!\\S)` look-ahead) is presumably the removed line and
// the second one its replacement — confirm against the commit.
static BPE_R50K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_r50k_base.dict"));
let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
// NOTE(review): the other tokenizers in this file replace the look-ahead with the extra
// patterns "\\s+\\s" and "\\s+" passed to `with_many`; here the look-ahead alternative is
// simply gone, which presumably changes how whitespace in front of a word is split — verify
// that this is intended.
let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+";
Tokenizer::new(bpe, Some(pat)).expect("valid regex")
});

/// Lazily initialised tokenizer for the `p50k_base` vocabulary.
///
/// The serialized dictionary is embedded with `include_bytes!` from `OUT_DIR` and decoded
/// with `rmp_serde` the first time the static is accessed.
static BPE_P50K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_p50k_base.dict"));
let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
// NOTE(review): the next two lines are presumably the removed pre-change version (a single
// pattern containing a negative look-ahead), shown here because the diff has no +/- markers
// — confirm against the commit.
let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
Tokenizer::new(bpe, Some(pat)).expect("valid regex")
// Replacement: three patterns handed to `with_many`. `pat2` is passed second, which is the
// position `with_many` documents as the look-ahead pattern with one look-ahead character.
let pat1 = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+";
let pat2 = "\\s+\\s";
let pat3 = "\\s+";
Tokenizer::with_many(bpe, &[pat1, pat2, pat3]).expect("valid regex")
});

/// Lazily initialised tokenizer for the `cl100k_base` vocabulary.
///
/// The serialized dictionary is embedded with `include_bytes!` from `OUT_DIR` and decoded
/// with `rmp_serde` the first time the static is accessed.
static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict"));
let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
// NOTE(review): the next two lines are presumably the removed pre-change version (a single
// pattern containing a negative look-ahead), shown here because the diff has no +/- markers
// — confirm against the commit.
let pat = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
Tokenizer::new(bpe, Some(pat)).expect("valid regex")
// Replacement: everything except the whitespace alternatives stays in the first pattern.
let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+";
// Note: Rewrite the negative look-ahead with a positive pseudo look-ahead.
// The look-ahead character is dropped from the match by the SpecialRegexp iterator.
// NOTE(review): a negative look-ahead also succeeds at the end of the input, where no
// further character exists to serve as the positive look-ahead; check how the iterator
// treats a whitespace run that ends the text.
let pat2 = "\\s+\\s";
let pat3 = "\\s+";
Tokenizer::with_many(bpe, &[pat1, pat2, pat3]).expect("valid regex")
});

/// Lazily initialised tokenizer for the `o200k_base` vocabulary.
///
/// The serialized dictionary is embedded with `include_bytes!` from `OUT_DIR` and decoded
/// with `rmp_serde` the first time the static is accessed. The main pattern is assembled by
/// joining its alternatives with `|`.
static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k_base.dict"));
let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
// NOTE(review): both the old (`pat`) and the new (`pat1`) binding appear because the diff
// has no +/- markers; presumably `let pat = [` is the removed line — confirm.
let pat = [
let pat1 = [
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
"\\p{N}{1,3}",
" ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
"\\s*[\\r\\n]+",
// NOTE(review): the two whitespace alternatives below are presumably the removed array
// elements; in the replacement they reappear as `pat2`/`pat3` — confirm.
"\\s+(?!\\S)",
"\\s+",
].join("|");
// NOTE(review): presumably the removed single-pattern constructor call.
Tokenizer::new(bpe, Some(&pat)).expect("valid regex")
// Replacement: `pat2` is passed second, which is the position `with_many` documents as the
// look-ahead pattern with one look-ahead character.
let pat2 = "\\s+\\s";
let pat3 = "\\s+";
Tokenizer::with_many(bpe, &[pat1.as_str(), pat2, pat3]).expect("valid regex")
});

pub use bpe::*;
Expand All @@ -57,8 +63,15 @@ pub struct Tokenizer {

impl Tokenizer {
/// Creates a tokenizer from `bpe` and an optional splitting pattern.
///
/// With `Some(pat)` the pattern is compiled into a regex; with `None` no regex is stored.
/// A pattern that fails to compile is reported as an error.
// NOTE(review): two signatures appear below because this listing looks like a diff without
// +/- markers; presumably the `fancy_regex` lines are the removed version — confirm.
// NOTE(review): with `()` as the error type, the `result_large_err` allow below is
// presumably no longer needed — verify with clippy.
#[allow(clippy::result_large_err)]
pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> fancy_regex::Result<Self> {
let pat = pat.map(fancy_regex::Regex::new).transpose()?;
pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> Result<Self, ()> {
// The concrete regex build error is discarded and replaced by `()`.
let pat = pat.map(Regex::new).transpose().map_err(|_| ())?;
Ok(Self { bpe, pat })
}

/// Creates a tokenizer whose splitting regex is compiled from several patterns at once.
///
/// The second entry of `patterns` is treated as a look-ahead pattern that consumes exactly
/// one look-ahead character. If the patterns cannot be compiled, `Err(())` is returned and
/// the underlying build error is discarded.
pub fn with_many(bpe: BytePairEncoding, patterns: &[&str]) -> Result<Self, ()> {
    match Regex::new_many(patterns) {
        Ok(regex) => Ok(Self {
            bpe,
            pat: Some(regex),
        }),
        Err(_) => Err(()),
    }
}

Expand All @@ -78,16 +91,51 @@ impl Tokenizer {
String::from_utf8(self.bpe.decode_tokens(tokens)).ok()
}

/// Splits the given text into the pieces matched by the tokenizer's pattern.
///
/// When the tokenizer has no pattern, the whole text is yielded as a single piece.
// NOTE(review): two signatures and two sets of match arms appear below because this listing
// looks like a diff without +/- markers; presumably the `text`/`find_iter` variant is the
// removed version and the `input`/`SpecialRegexp` variant its replacement — confirm.
pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
pub fn split<'a>(&'a self, input: &'a str) -> impl Iterator<Item = &str> + 'a {
match &self.pat {
// Presumably removed: walks all matches and asserts that each one starts exactly where
// the previous one ended, i.e. that the pattern covers the text without gaps.
Some(pat) => Either::Left(pat.find_iter(text).scan(0, |start, m| {
let m = m.expect("match succeeded");
assert_eq!(*start, m.start(), "pattern should match all input text");
*start = m.end();
Some(m.as_str())
})),
None => Either::Right(std::iter::once(text)),
// Replacement: hand the work to the `SpecialRegexp` iterator, starting at offset 0 with a
// capture buffer built from the regex's group info.
Some(pat) => Either::Left(SpecialRegexp {
pat,
input,
last: 0,
caps: Captures::matches(pat.group_info().clone()),
}),
None => Either::Right(std::iter::once(input)),
}
}
}

/// Iterator state for splitting text with a multi-pattern regex while imitating a
/// look-ahead.
///
/// The look-ahead is emulated by letting the second pattern consume the look-ahead
/// character and then trimming that character off the reported match. This relies on the
/// second pattern always being the look-ahead pattern and on a single character being
/// enough to trim. The trick keeps most of the regex patterns unchanged while giving a >3x
/// speedup.
///
/// Capture groups would have been an alternative implementation, but they turned out ~30%
/// slower than the multi-pattern approach.
struct SpecialRegexp<'s> {
    /// Compiled regex used for every anchored search.
    pat: &'s Regex,
    /// The complete text that is being split.
    input: &'s str,
    /// Byte offset into `input` at which the next search starts.
    last: usize,
    /// Capture buffer that is cleared and refilled on every search.
    caps: Captures,
}

/// Yields consecutive pieces of `input`, each found by a search anchored at the current
/// offset.
impl<'a> Iterator for SpecialRegexp<'a> {
    type Item = &'a str;

    /// Returns the next piece, or `None` once nothing matches at the current offset (which
    /// includes having reached the end of the input).
    ///
    /// # Panics
    ///
    /// Panics if trimming the look-ahead character off a match of the second pattern would
    /// leave an empty piece.
    fn next(&mut self) -> Option<Self::Item> {
        let start = self.last;
        let haystack = Input::new(&self.input[start..]).anchored(Anchored::Yes);
        self.caps.clear();
        self.pat.captures(haystack, &mut self.caps);
        let m = self.caps.get_match()?;
        if m.range().is_empty() {
            // A zero-width match would never advance `last`, so the iterator would yield
            // empty pieces forever. Stop instead.
            return None;
        }
        let mut end = start + m.range().end;
        // Pattern 1 is the pseudo look-ahead pattern: its final character only stands in for
        // the look-ahead and belongs to the next piece. The emulated negative look-ahead is
        // also satisfied at the end of the input, where there is no following character to
        // hand back, so a match that reaches the end of the input is kept whole.
        if m.pattern() == 1.into() && end < self.input.len() {
            let lookahead = self.input[start..end]
                .chars()
                .next_back()
                .expect("a non-empty match contains at least one character");
            end -= lookahead.len_utf8();
            assert_ne!(end, start, "a look-ahead pattern must ALWAYS consume at least one character excluding the look-ahead character!");
        }
        self.last = end;
        Some(&self.input[start..end])
    }
}

Expand Down

0 comments on commit 25188c8

Please sign in to comment.