Commit ec14d05

Be explicit about lookahead patterns
1 parent b543057 commit ec14d05

File tree: 1 file changed (+66 -29 lines)

crates/bpe-openai/src/lib.rs

Lines changed: 66 additions & 29 deletions
@@ -16,12 +16,10 @@ static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
     let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict"));
     let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
     let pat1 = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+$";
-    // Note: Rewrite the negative look-ahead with a positive pseudo look-ahead.
-    // The look-ahead character is dropped from the match by the SpecialRegexp iterator.
-    // Note: The negative look-ahead requires also the pattern `\\s+$` to handle end of file without dropping a character!
     let pat2 = "\\s+\\s";
     let pat3 = "\\s+";
-    Tokenizer::with_many(bpe, &[pat1, pat2, pat3]).expect("valid regex")
+    Tokenizer::new_lookahead(bpe, &[(pat1, false), (pat2, true), (pat3, false)])
+        .expect("valid regex")
 });

 static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
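
The removed comments describe the trick that the new API now makes explicit: pat2 `\s+\s` is a positive pseudo look-ahead that replaces the negative look-ahead `\s+(?!\S)` of the original OpenAI pattern. It matches a whitespace run plus one extra whitespace character, and the split iterator drops that final character from the emitted piece. On "a  b", for example, the two spaces match pat2, the second space is dropped as the look-ahead character, and " b" is then consumed by pat1's `[^\r\n\p{L}\p{N}]?\p{L}+` alternative, yielding the pieces "a", " ", " b" (a worked sketch follows the last hunk below). The `\s+$` alternative in pat1 covers whitespace at the very end of the input, where there is no look-ahead character left to drop.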
@@ -37,7 +35,8 @@ static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
     ].join("|");
     let pat2 = "\\s+\\s";
     let pat3 = "\\s+";
-    Tokenizer::with_many(bpe, &[pat1.as_str(), pat2, pat3]).expect("valid regex")
+    Tokenizer::new_lookahead(bpe, &[(&pat1, false), (pat2, true), (pat3, false)])
+        .expect("valid regex")
 });

 pub use bpe::*;
@@ -52,22 +51,33 @@ pub struct Tokenizer {
     /// The byte-pair encoding for this tokenizer.
     pub bpe: BytePairEncoding,
     /// The pattern regex used to split the input.
-    pub pat: Option<Regex>,
+    pub pre: Option<Pretokenizer>,
+}
+
+pub struct Pretokenizer {
+    /// The pattern regex used to split the input.
+    pat: Regex,
+    /// For each pattern in the regex a boolean whether the last character is a look-ahead.
+    lookahead: Vec<bool>,
 }

 impl Tokenizer {
     /// Build a tokenizer with an optional pretokenization regex pattern.
     #[allow(clippy::result_large_err)]
     pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> Result<Self, BuildError> {
-        let pat = pat.map(Regex::new).transpose()?;
-        Ok(Self { bpe, pat })
+        let pre = pat.map(Pretokenizer::new).transpose()?;
+        Ok(Self { bpe, pre })
     }

-    /// When using multiple patterns, the second pattern is assumed to be a look-ahead pattern with
-    /// exactly one look-ahead character!
-    pub fn with_many(bpe: BytePairEncoding, patterns: &[&str]) -> Result<Self, BuildError> {
-        let pat = Some(Regex::new_many(patterns)?);
-        Ok(Self { bpe, pat })
+    /// Build a tokenizer with pretokenization regex patterns. If the boolean for a pattern is true,
+    /// the pattern is assumed to be a look-ahead pattern with exactly one look-ahead character!
+    #[allow(clippy::result_large_err)]
+    pub fn new_lookahead(
+        bpe: BytePairEncoding,
+        patterns: &[(&str, bool)],
+    ) -> Result<Self, BuildError> {
+        let pre = Some(Pretokenizer::new_lookahead(patterns)?);
+        Ok(Self { bpe, pre })
     }

     pub fn count(&self, text: &str) -> usize {
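
A minimal usage sketch of the changed constructor (hedged: `bpe` stands for an already-deserialized `BytePairEncoding` as in the statics above, and the patterns are toy ones rather than the real cl100k/o200k sets):

    let tok = Tokenizer::new_lookahead(bpe, &[
        (" ?\\p{L}+", false), // ordinary pattern
        ("\\s+\\s", true),    // look-ahead pattern: the final `\s` is dropped from each match
        ("\\s+", false),      // ordinary fallback, e.g. for trailing whitespace
    ])
    .expect("valid regex");
    // Splitting tiles the input: "a  b" -> "a", " ", " b".
    assert_eq!(tok.split("a  b").collect::<Vec<_>>(), vec!["a", " ", " b"]);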
@@ -86,15 +96,41 @@ impl Tokenizer {
         String::from_utf8(self.bpe.decode_tokens(tokens)).ok()
     }

-    pub fn split<'a>(&'a self, input: &'a str) -> impl Iterator<Item = &str> + 'a {
-        match &self.pat {
-            Some(pat) => Either::Left(SpecialRegexp {
-                pat,
-                input,
-                last: 0,
-                caps: Captures::matches(pat.group_info().clone()),
-            }),
-            None => Either::Right(std::iter::once(input)),
+    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
+        match &self.pre {
+            Some(pre) => Either::Left(pre.split(text)),
+            None => Either::Right(std::iter::once(text)),
+        }
+    }
+}
+
+impl Pretokenizer {
+    /// Build a pretokenizer from the given regex pattern.
+    #[allow(clippy::result_large_err)]
+    fn new(pat: &str) -> Result<Self, BuildError> {
+        let pat = Regex::new(pat)?;
+        Ok(Self {
+            pat,
+            lookahead: vec![false],
+        })
+    }
+
+    /// Build a pretokenizer from the given regex patterns. If the boolean for a pattern is true,
+    /// the pattern is assumed to be a look-ahead pattern with exactly one look-ahead character!
+    #[allow(clippy::result_large_err)]
+    fn new_lookahead(pats: &[(&str, bool)]) -> Result<Self, BuildError> {
+        let (pats, lookahead): (Vec<_>, _) = pats.iter().copied().unzip();
+        let pat = Regex::new_many(&pats)?;
+        Ok(Self { pat, lookahead })
+    }
+
+    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
+        Splits {
+            pat: &self.pat,
+            lookahead: &self.lookahead,
+            text,
+            last: 0,
+            caps: Captures::matches(self.pat.group_info().clone()),
         }
     }
 }
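
Both constructors uphold the invariant that `lookahead` carries exactly one flag per compiled pattern: `Pretokenizer::new` compiles a single pattern and seeds the vector with one `false`, while `new_lookahead` unzips the `(pattern, bool)` pairs so the flags line up with the pattern IDs that `Regex::new_many` assigns in order.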
@@ -107,33 +143,34 @@ impl Tokenizer {
 ///
 /// Alternatively, this could have been implemented with capture groups, but those were ~30%
 /// slower than this approach with multiple patterns.
-struct SpecialRegexp<'a> {
+struct Splits<'a> {
     pat: &'a Regex,
-    input: &'a str,
+    lookahead: &'a [bool],
+    text: &'a str,
     last: usize,
     caps: Captures,
 }

-impl<'a> Iterator for SpecialRegexp<'a> {
+impl<'a> Iterator for Splits<'a> {
     type Item = &'a str;

     fn next(&mut self) -> Option<Self::Item> {
-        let input = Input::new(&self.input[self.last..]).anchored(Anchored::Yes);
+        let input = Input::new(&self.text[self.last..]).anchored(Anchored::Yes);
         self.caps.clear();
         self.pat.captures(input, &mut self.caps);
         let m = self.caps.get_match()?;
         let start = self.last;
         let mut end = self.last + m.range().end;
-        if m.pattern() == 1.into() {
-            let last = self.input[start..end]
+        if self.lookahead[m.pattern().as_usize()] {
+            let last = self.text[start..end]
                 .chars()
                 .next_back()
                 .expect("Expected at least a look-ahead character!");
             end -= last.len_utf8();
             assert_ne!(end, start, "a look-ahead pattern must ALWAYS consume at least one character excluding the look-ahead character!");
         }
         self.last = end;
-        Some(&self.input[start..end])
+        Some(&self.text[start..end])
     }
 }
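
For reference, here is a self-contained sketch of the same anchored match-and-trim loop, assembled only from the regex-automata calls visible in the diff (toy patterns and a hypothetical main, not the real tokenizer setup):

    use regex_automata::{meta::Regex, util::captures::Captures, Anchored, Input};

    fn main() {
        // Pattern 1 is the look-ahead pattern: `\s+\s` emulates `\s+(?!\S)`,
        // and the final whitespace character is dropped from each emitted piece.
        let pats = [(" ?\\p{L}+", false), ("\\s+\\s", true), ("\\s+", false)];
        let lookahead: Vec<bool> = pats.iter().map(|&(_, la)| la).collect();
        let patterns: Vec<&str> = pats.iter().map(|&(p, _)| p).collect();
        let pat = Regex::new_many(&patterns).expect("valid regex");
        let mut caps = Captures::matches(pat.group_info().clone());

        let text = "a  b";
        let mut last = 0;
        while last < text.len() {
            // Anchor each search at the end of the previous piece so the
            // pieces tile the input.
            let input = Input::new(&text[last..]).anchored(Anchored::Yes);
            caps.clear();
            pat.captures(input, &mut caps);
            let Some(m) = caps.get_match() else { break };
            let mut end = last + m.range().end;
            if lookahead[m.pattern().as_usize()] {
                // Drop the single look-ahead character; a look-ahead match is
                // at least two characters, so the piece stays non-empty.
                let c = text[last..end].chars().next_back().expect("look-ahead char");
                end -= c.len_utf8();
            }
            println!("{:?}", &text[last..end]); // "a", then " ", then " b"
            last = end;
        }
    }

If a look-ahead match ever consisted of only its look-ahead character, `last` would stop advancing and the iterator would yield the same empty piece forever; that is the failure mode the `assert_ne!` in `Splits::next` guards against.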