diff --git a/crates/bpe-openai/Cargo.toml b/crates/bpe-openai/Cargo.toml
index b4379c3..c0164a0 100644
--- a/crates/bpe-openai/Cargo.toml
+++ b/crates/bpe-openai/Cargo.toml
@@ -15,7 +15,7 @@ bench = false
 [dependencies]
 bpe = { version = "0.1.0", path = "../bpe" }
 either = "1.13"
-fancy-regex = "0.13"
+regex-automata = "0.4"
 rmp-serde = "1"

 [dev-dependencies]
diff --git a/crates/bpe-openai/src/lib.rs b/crates/bpe-openai/src/lib.rs
index fd2c7c8..2ce2113 100644
--- a/crates/bpe-openai/src/lib.rs
+++ b/crates/bpe-openai/src/lib.rs
@@ -2,42 +2,48 @@ use std::sync::LazyLock;

 use bpe::byte_pair_encoding::BytePairEncoding;
 use either::Either;
-use fancy_regex::Regex;
+use regex_automata::{meta::Regex, util::captures::Captures, Anchored, Input};

 static BPE_R50K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
     let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_r50k_base.dict"));
     let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
-    let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
+    let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+";
     Tokenizer::new(bpe, Some(pat)).expect("valid regex")
 });

 static BPE_P50K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
     let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_p50k_base.dict"));
     let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
-    let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
-    Tokenizer::new(bpe, Some(pat)).expect("valid regex")
+    let pat1 = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+";
+    let pat2 = "\\s+\\s";
+    let pat3 = "\\s+";
+    Tokenizer::with_many(bpe, &[pat1, pat2, pat3]).expect("valid regex")
 });

 static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
     let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict"));
     let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
-    let pat = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
-    Tokenizer::new(bpe, Some(pat)).expect("valid regex")
+    let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+";
+    // Note: Rewrite the negative look-ahead with a positive pseudo look-ahead.
+    // The look-ahead character is dropped from the match by the SpecialRegexp iterator.
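+    // Here "\\s+\\s" matches the whitespace run plus one further whitespace character,
+    // and trimming that character re-creates "\\s+(?!\\S)"; whitespace at the very end
+    // of the input is picked up by the plain "\\s+" fallback instead.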
+ let pat2 = "\\s+\\s"; + let pat3 = "\\s+"; + Tokenizer::with_many(bpe, &[pat1, pat2, pat3]).expect("valid regex") }); static BPE_O200K_BASE: LazyLock = LazyLock::new(|| { let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k_base.dict")); let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data"); - let pat = [ + let pat1 = [ "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?", "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?", "\\p{N}{1,3}", " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*", "\\s*[\\r\\n]+", - "\\s+(?!\\S)", - "\\s+", ].join("|"); - Tokenizer::new(bpe, Some(&pat)).expect("valid regex") + let pat2 = "\\s+\\s"; + let pat3 = "\\s+"; + Tokenizer::with_many(bpe, &[pat1.as_str(), pat2, pat3]).expect("valid regex") }); pub use bpe::*; @@ -57,8 +63,15 @@ pub struct Tokenizer { impl Tokenizer { #[allow(clippy::result_large_err)] - pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> fancy_regex::Result { - let pat = pat.map(fancy_regex::Regex::new).transpose()?; + pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> Result { + let pat = pat.map(Regex::new).transpose().map_err(|_| ())?; + Ok(Self { bpe, pat }) + } + + /// When using multiple patterns, the second pattern is assumed to be a look-ahead pattern with + /// exactly one look-ahead character! + pub fn with_many(bpe: BytePairEncoding, patterns: &[&str]) -> Result { + let pat = Some(Regex::new_many(patterns).map_err(|_| ())?); Ok(Self { bpe, pat }) } @@ -78,16 +91,51 @@ impl Tokenizer { String::from_utf8(self.bpe.decode_tokens(tokens)).ok() } - pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator + 'a { + pub fn split<'a>(&'a self, input: &'a str) -> impl Iterator + 'a { match &self.pat { - Some(pat) => Either::Left(pat.find_iter(text).scan(0, |start, m| { - let m = m.expect("match succeeded"); - assert_eq!(*start, m.start(), "pattern should match all input text"); - *start = m.end(); - Some(m.as_str()) - })), - None => Either::Right(std::iter::once(text)), + Some(pat) => Either::Left(SpecialRegexp { + pat, + input, + last: 0, + caps: Captures::matches(pat.group_info().clone()), + }), + None => Either::Right(std::iter::once(input)), + } + } +} + +/// This is a small wrapper around the regex which emulates the behaviour of look-ahead by +/// dropping the look-ahead character from the match. The assumption here is that the +/// second pattern is always a look-ahead pattern, and that just a single character needs +/// to be dropped. With this little hack, we can keep most of the regex patterns as they are, +/// but achieve a >3x speedup. +/// +/// Alternatively, this could have been implemented with capture groups, but those were ~30% +/// slower than this approach with multiple patterns. 
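+///
+/// For example, with the p50k patterns, `"a  b"` splits into `"a"`, `" "` (the
+/// look-ahead pattern `"\s+\s"` matches both spaces and the trailing one is dropped),
+/// and `" b"`: exactly the pieces the original `"\s+(?!\S)"` alternative produced.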
+struct SpecialRegexp<'a> {
+    pat: &'a Regex,
+    input: &'a str,
+    last: usize,
+    caps: Captures,
+}
+
+impl<'a> Iterator for SpecialRegexp<'a> {
+    type Item = &'a str;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let input = Input::new(&self.input[self.last..]).anchored(Anchored::Yes);
+        self.caps.clear();
+        self.pat.captures(input, &mut self.caps);
+        let m = self.caps.get_match()?;
+        let start = self.last;
+        let mut end = self.last + m.range().end;
+        if m.pattern() == 1.into() {
+            let last = self.input[start..end].chars().rev().next().unwrap();
+            end -= last.len_utf8();
+            assert_ne!(end, start, "a look-ahead pattern must ALWAYS consume at least one character excluding the look-ahead character!");
         }
+        self.last = end;
+        Some(&self.input[start..end])
     }
 }
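A quick sanity check of the rewritten pseudo look-ahead (not part of the patch). The
sketch assumes an accessor such as `cl100k_base()` returning `&'static Tokenizer` for
BPE_CL100K_BASE; this diff does not show how the private statics are exposed:

    fn main() {
        // Hypothetical accessor; the statics in the diff above are private.
        let tok = bpe_openai::cl100k_base();
        // The look-ahead pattern leaves the last whitespace character attached to the
        // following word, just like the original "\s+(?!\S)" alternative did:
        let pieces: Vec<&str> = tok.split("hello   world").collect();
        assert_eq!(pieces, vec!["hello", "  ", " world"]);
    }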