From ec14d05fe5659620287b6ed38a6d7c2f1ac88ca3 Mon Sep 17 00:00:00 2001
From: Hendrik van Antwerpen
Date: Fri, 18 Oct 2024 13:44:38 +0200
Subject: [PATCH] Be explicit about lookahead patterns

---
 crates/bpe-openai/src/lib.rs | 95 +++++++++++++++++++++++++-----------
 1 file changed, 66 insertions(+), 29 deletions(-)

diff --git a/crates/bpe-openai/src/lib.rs b/crates/bpe-openai/src/lib.rs
index 2bae78f..be8dfb2 100644
--- a/crates/bpe-openai/src/lib.rs
+++ b/crates/bpe-openai/src/lib.rs
@@ -16,12 +16,10 @@ static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
     let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict"));
     let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
     let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+$";
-    // Note: Rewrite the negative look-ahead with a positive pseudo look-ahead.
-    // The look-ahead character is dropped from the match by the SpecialRegexp iterator.
-    // Note: The negative look-ahead requires also the pattern `\\s+$` to handle end of file without dropping a character!
     let pat2 = "\\s+\\s";
     let pat3 = "\\s+";
-    Tokenizer::with_many(bpe, &[pat1, pat2, pat3]).expect("valid regex")
+    Tokenizer::new_lookahead(bpe, &[(pat1, false), (pat2, true), (pat3, false)])
+        .expect("valid regex")
 });
 
 static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
@@ -37,7 +35,8 @@
     ].join("|");
     let pat2 = "\\s+\\s";
     let pat3 = "\\s+";
-    Tokenizer::with_many(bpe, &[pat1.as_str(), pat2, pat3]).expect("valid regex")
+    Tokenizer::new_lookahead(bpe, &[(&pat1, false), (pat2, true), (pat3, false)])
+        .expect("valid regex")
 });
 
 pub use bpe::*;
@@ -52,22 +51,33 @@ pub struct Tokenizer {
     /// The byte-pair encoding for this tokenizer.
     pub bpe: BytePairEncoding,
     /// The pattern regex used to split the input.
-    pub pat: Option<Regex>,
+    pub pre: Option<Pretokenizer>,
+}
+
+pub struct Pretokenizer {
+    /// The pattern regex used to split the input.
+    pat: Regex,
+    /// For each pattern in the regex, a boolean indicating whether its last character is a look-ahead.
+    lookahead: Vec<bool>,
 }
 
 impl Tokenizer {
     /// Build a tokenizer with an optional pretokenization regex pattern.
     #[allow(clippy::result_large_err)]
     pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> Result<Self, BuildError> {
-        let pat = pat.map(Regex::new).transpose()?;
-        Ok(Self { bpe, pat })
+        let pre = pat.map(Pretokenizer::new).transpose()?;
+        Ok(Self { bpe, pre })
     }
 
-    /// When using multiple patterns, the second pattern is assumed to be a look-ahead pattern with
-    /// exactly one look-ahead character!
-    pub fn with_many(bpe: BytePairEncoding, patterns: &[&str]) -> Result<Self, BuildError> {
-        let pat = Some(Regex::new_many(patterns)?);
-        Ok(Self { bpe, pat })
+    /// Build a tokenizer with pretokenization regex patterns. If the boolean for a pattern is true,
+    /// the pattern is assumed to be a look-ahead pattern with exactly one look-ahead character!
+    #[allow(clippy::result_large_err)]
+    pub fn new_lookahead(
+        bpe: BytePairEncoding,
+        patterns: &[(&str, bool)],
+    ) -> Result<Self, BuildError> {
+        let pre = Some(Pretokenizer::new_lookahead(patterns)?);
+        Ok(Self { bpe, pre })
     }
 
     pub fn count(&self, text: &str) -> usize {
@@ -86,15 +96,41 @@ impl Tokenizer {
         String::from_utf8(self.bpe.decode_tokens(tokens)).ok()
     }
 
-    pub fn split<'a>(&'a self, input: &'a str) -> impl Iterator<Item = &str> + 'a {
-        match &self.pat {
-            Some(pat) => Either::Left(SpecialRegexp {
-                pat,
-                input,
-                last: 0,
-                caps: Captures::matches(pat.group_info().clone()),
-            }),
-            None => Either::Right(std::iter::once(input)),
+    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
+        match &self.pre {
+            Some(pre) => Either::Left(pre.split(text)),
+            None => Either::Right(std::iter::once(text)),
+        }
+    }
+}
+
+impl Pretokenizer {
+    /// Build a pretokenizer from the given regex pattern.
+    #[allow(clippy::result_large_err)]
+    fn new(pat: &str) -> Result<Self, BuildError> {
+        let pat = Regex::new(pat)?;
+        Ok(Self {
+            pat,
+            lookahead: vec![false],
+        })
+    }
+
+    /// Build a pretokenizer from the given regex patterns. If the boolean for a pattern is true,
+    /// the pattern is assumed to be a look-ahead pattern with exactly one look-ahead character!
+    #[allow(clippy::result_large_err)]
+    fn new_lookahead(pats: &[(&str, bool)]) -> Result<Self, BuildError> {
+        let (pats, lookahead): (Vec<_>, _) = pats.iter().copied().unzip();
+        let pat = Regex::new_many(&pats)?;
+        Ok(Self { pat, lookahead })
+    }
+
+    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
+        Splits {
+            pat: &self.pat,
+            lookahead: &self.lookahead,
+            text,
+            last: 0,
+            caps: Captures::matches(self.pat.group_info().clone()),
         }
     }
 }
@@ -107,25 +143,26 @@
 ///
 /// Alternatively, this could have been implemented with capture groups, but those were ~30%
 /// slower than this approach with multiple patterns.
-struct SpecialRegexp<'a> {
+struct Splits<'a> {
     pat: &'a Regex,
-    input: &'a str,
+    lookahead: &'a [bool],
+    text: &'a str,
     last: usize,
     caps: Captures,
 }
 
-impl<'a> Iterator for SpecialRegexp<'a> {
+impl<'a> Iterator for Splits<'a> {
     type Item = &'a str;
 
     fn next(&mut self) -> Option<Self::Item> {
-        let input = Input::new(&self.input[self.last..]).anchored(Anchored::Yes);
+        let input = Input::new(&self.text[self.last..]).anchored(Anchored::Yes);
         self.caps.clear();
         self.pat.captures(input, &mut self.caps);
         let m = self.caps.get_match()?;
         let start = self.last;
         let mut end = self.last + m.range().end;
-        if m.pattern() == 1.into() {
-            let last = self.input[start..end]
+        if self.lookahead[m.pattern().as_usize()] {
+            let last = self.text[start..end]
                 .chars()
                 .next_back()
                 .expect("Expected at least a look-ahead character!");
@@ -133,7 +170,7 @@
             assert_ne!(end, start, "a look-ahead pattern must ALWAYS consume at least one character excluding the look-ahead character!");
         }
         self.last = end;
-        Some(&self.input[start..end])
+        Some(&self.text[start..end])
     }
 }
 
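
A usage sketch of the new explicit look-ahead API (a reviewer sketch, not part
of the patch: it assumes the crate's `cl100k_base()` accessor for the
`BPE_CL100K_BASE` static above, and the expected pieces follow the GPT-4
pre-tokenizer whose negative look-ahead `\s+(?!\S)` the flagged pattern
emulates):

    use bpe_openai::cl100k_base;

    fn main() {
        let tok = cl100k_base();

        // `\s+\s` (pat2, flagged true) matches both spaces of "a  b", but the
        // Splits iterator drops the final look-ahead space from the piece and
        // re-anchors on it, so that space can prefix the next word piece.
        let pieces: Vec<&str> = tok.split("a  b").collect();
        assert_eq!(pieces, ["a", " ", " b"]);

        // At end of input, `\s+$` in pat1 matches at the same position and
        // takes precedence, consuming the trailing whitespace whole, so no
        // look-ahead character has to be dropped.
        let pieces: Vec<&str> = tok.split("a  ").collect();
        assert_eq!(pieces, ["a", "  "]);
    }

Passing a per-pattern boolean to `new_lookahead` replaces the previous implicit
convention that the second of several patterns is the look-ahead pattern, which
is what the commit subject means by being explicit.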