Skip to content

Commit 5f07fc2

Browse files
author
Lőrinc
committed
Lower backtrack_limit to fail earlier for invalid input
1 parent 21c5688 commit 5f07fc2

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

src/lib.rs

+15 -1
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ use std::num::NonZeroU64;
66
use std::thread;
77

88
use fancy_regex::Regex;
9+
use fancy_regex::RegexBuilder;
910
use pyo3::exceptions;
1011
use pyo3::prelude::*;
1112
use pyo3::pyclass;
@@ -417,7 +418,7 @@ impl CoreBPE {
417418
special_tokens_encoder: HashMap<String, Rank>,
418419
pattern: &str,
419420
) -> PyResult<Self> {
420-
let regex = Regex::new(pattern)
421+
let regex = RegexBuilder::new(pattern).backtrack_limit(10_000).build()
421422
.map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;
422423

423424
let special_regex = {
@@ -572,6 +573,7 @@ fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
572573

573574
#[cfg(test)]
574575
mod tests {
576+
use fancy_regex::RegexBuilder;
575577
use rustc_hash::FxHashMap as HashMap;
576578

577579
use crate::{byte_pair_split, Rank};
@@ -596,4 +598,16 @@ mod tests {
596598
let res = byte_pair_split(b"abab", &ranks);
597599
assert_eq!(res, vec![b"ab", b"ab"]);
598600
}
601+
602+
#[test]
603+
fn test_effect_of_backtrack_limit() {
604+
let regex = RegexBuilder::new(r"(a|b|ab)*(?=c)")
605+
.backtrack_limit(10)
606+
.build()
607+
.expect("Failed to build regex")
608+
.clone();
609+
610+
let input = "ab".repeat(100) + "c";
611+
assert!(regex.is_match(&input).is_err(), "Should throw");
612+
}
599613
}

tests/test_encoding.py

+1 -2
Original file line number | Diff line number | Diff line change
@@ -11,14 +11,13 @@
1111
from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES
1212

1313

14-
@pytest.mark.skip(reason="Takes a really long time to finish, but was added to reproduce a crash.")
1514
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
1615
def test_extremely_big_encoding(make_enc: Callable[[], tiktoken.Encoding]):
1716
enc = make_enc()
1817
for c in ["^", "0", "a", "'s"]: # TODO " ", "\n" are still failing
1918
print(f"Validating `{c}`")
2019

21-
big_value = c * 1_000_000
20+
big_value = c * 10_000
2221
assert big_value == enc.decode(enc.encode(big_value))
2322

2423
big_value = " " + big_value

0 commit comments

Comments
 (0)