1 change: 1 addition & 0 deletions bindings/python/src/trainers.rs
@@ -345,6 +345,7 @@ impl PyBpeTrainer {
}
"limit_alphabet" => builder = builder.limit_alphabet(val.extract()?),
"max_token_length" => builder = builder.max_token_length(val.extract()?),
"enforce_utf8_boundaries" => builder = builder.enforce_utf8_boundaries(val.extract()?),
"initial_alphabet" => {
let alphabet: Vec<String> = val.extract()?;
builder = builder.initial_alphabet(
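Note: this binding change only forwards the new keyword to the Rust builder. A minimal usage sketch from Python, assuming a build that includes this change (illustrative values, not part of the diff):

```python
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

trainer = trainers.BpeTrainer(
    vocab_size=300,
    show_progress=False,
    enforce_utf8_boundaries=True,  # the kwarg wired through by the match arm above
)
tokenizer.train_from_iterator(["some 🤗 text", "more 🦒 text"], trainer=trainer)
```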
43 changes: 43 additions & 0 deletions bindings/python/tests/bindings/test_trainers.py
@@ -74,6 +74,49 @@ def test_can_pickle(self):
)


def test_enforce_utf8_boundaries(self):
# This input is designed so that the most frequent merge candidate is invalid:
# a space (0x20) followed by 0xF0, the shared first byte of two different 4-byte
# UTF-8 characters. A less frequent but valid candidate is the first two bytes
# of the emoji (0xF0, 0x9F).
data = [" 🤗"] * 10 + [" 𝟑"] * 9

# Set up a tokenizer with a ByteLevel pre-tokenizer
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

# 1. Train with `enforce_utf8_boundaries=False` (unconstrained)
unconstrained_trainer = trainers.BpeTrainer(
vocab_size=260,
special_tokens=["<unk>"],
enforce_utf8_boundaries=False,
show_progress=False,
)
tokenizer.train_from_iterator(data, trainer=unconstrained_trainer)
vocab = tokenizer.get_vocab()

# The pre-tokenizer maps byte 0x20 to `Ġ` and 0xF0 to `ð`.
# The invalid merge of these two should be present.
invalid_token = "Ġð" # Bytes: [20, F0]
assert invalid_token in vocab, "Unconstrained trainer should learn the invalid merge"

# 2. Train with `enforce_utf8_boundaries=True` (constrained)
# We must re-initialize the tokenizer to start with a fresh model
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

# Train with enforce_utf8_boundaries=True
constrained_trainer = trainers.BpeTrainer(
vocab_size=260,
special_tokens=["<unk>"],
enforce_utf8_boundaries=True,
show_progress=False,
)
tokenizer.train_from_iterator(data, trainer=constrained_trainer)
vocab = tokenizer.get_vocab()

# The invalid merge should not be present when enforcing UTF-8 boundaries
assert invalid_token not in vocab, "Constrained trainer should not learn invalid merges"

class TestWordPieceTrainer:
def test_can_modify(self):
trainer = trainers.WordPieceTrainer(
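For reference, the frequency claim in the test comments can be verified directly on the raw UTF-8 bytes; the following is an illustrative check, not part of the diff:

```python
from collections import Counter

# Both " 🤗" and " 𝟑" encode to a space (0x20) followed by a 4-byte sequence
# starting with 0xF0, so the boundary-crossing pair (0x20, 0xF0) occurs in all
# 19 samples, while the valid prefix pair (0xF0, 0x9F) occurs only in the 🤗 samples.
data = [" 🤗"] * 10 + [" 𝟑"] * 9

pairs = Counter()
for sample in data:
    raw = sample.encode("utf-8")
    pairs.update(zip(raw, raw[1:]))

assert pairs[(0x20, 0xF0)] == 19  # the frequent but invalid candidate
assert pairs[(0xF0, 0x9F)] == 10  # the less frequent but valid candidate
```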
161 changes: 158 additions & 3 deletions tokenizers/src/models/bpe/trainer.rs
@@ -4,6 +4,7 @@ use super::{Pair, WithFirstLastIterator, Word, BPE};
use crate::parallelism::*;
use crate::tokenizer::{AddedToken, Result, Trainer};
use crate::utils::progress::{ProgressBar, ProgressStyle};
use crate::pre_tokenizers::byte_level::CHAR_BYTES;
use ahash::{AHashMap, AHashSet};
use compact_str::CompactString;
use dary_heap::OctonaryHeap;
@@ -48,6 +49,7 @@ struct Config {
continuing_subword_prefix: Option<String>,
end_of_word_suffix: Option<String>,
max_token_length: Option<usize>,
enforce_utf8_boundaries: bool,
}

/// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom
@@ -69,6 +71,7 @@ impl Default for BpeTrainerBuilder {
continuing_subword_prefix: None,
end_of_word_suffix: None,
max_token_length: None,
enforce_utf8_boundaries: false,
},
}
}
@@ -144,6 +147,13 @@ impl BpeTrainerBuilder {
self
}

/// Whether to enforce UTF-8 character boundaries during merges
#[must_use]
pub fn enforce_utf8_boundaries(mut self, enforce: bool) -> Self {
self.config.enforce_utf8_boundaries = enforce;
self
}

/// Constructs the final BpeTrainer
pub fn build(self) -> BpeTrainer {
BpeTrainer {
@@ -156,6 +166,7 @@ impl BpeTrainerBuilder {
continuing_subword_prefix: self.config.continuing_subword_prefix,
end_of_word_suffix: self.config.end_of_word_suffix,
max_token_length: self.config.max_token_length,
enforce_utf8_boundaries: self.config.enforce_utf8_boundaries,
words: AHashMap::new(),
}
}
@@ -199,6 +210,11 @@ pub struct BpeTrainer {
pub end_of_word_suffix: Option<String>,
/// An optional parameter to limit the max length of any single token
pub max_token_length: Option<usize>,
/// Whether to enforce UTF-8 character boundaries during merges. When true, only two kinds of merge are allowed:
/// 1. Complete UTF-8 characters with each other
/// 2. Single bytes that belong to the same UTF-8 character, built from left to right
///
/// This avoids creating tokens that are not valid UTF-8 sequences, at no cost to compression.
pub enforce_utf8_boundaries: bool,

words: AHashMap<CompactString, u64>,
}
@@ -210,6 +226,7 @@ impl Default for BpeTrainer {
}

impl BpeTrainer {

pub fn new(min_frequency: u64, vocab_size: usize) -> Self {
Self {
min_frequency,
@@ -270,6 +287,67 @@ impl BpeTrainer {
}
}

/// Helper for `is_merge_allowed`: recovers the original bytes of a token by
/// reversing the `ByteLevel` character mapping. Returns `None` if any character
/// is not part of the mapping.
fn get_original_bytes(&self, part: &str) -> Option<Vec<u8>> {
part.chars().map(|c| CHAR_BYTES.get(&c).copied()).collect()
}

/// Determines if a merge is allowed under UTF-8 boundary constraints.
///
/// This check is only performed if `enforce_utf8_boundaries` is true.
/// A merge is allowed if it meets one of the following criteria:
/// 1. Both tokens consist of complete characters.
/// 2. Both tokens are part of the same single character, and the second is a single byte.
/// This allows building multi-byte characters from their individual bytes left-to-right.
///
/// All other combinations, such as merging a complete character with a partial byte, are disallowed.
/// This function is designed to work on the character-mapped output of a `ByteLevel`
/// pre-tokenizer by reversing the mapping to check the original bytes.
fn is_merge_allowed(&self, pair: &Pair, id_to_word: &[CompactString]) -> bool {
if !self.enforce_utf8_boundaries {
return true;
}

let part_a = &id_to_word[pair.0 as usize];
let part_b = &id_to_word[pair.1 as usize];

// Get the original bytes by reversing the ByteLevel character mapping.
let bytes_a = self.get_original_bytes(part_a.as_ref()).unwrap_or_default();
let bytes_b = self.get_original_bytes(part_b.as_ref()).unwrap_or_default();

// A "complete" token is one whose underlying bytes form a valid UTF-8 string.
// For ByteLevel, this means single-byte ASCII chars (like a space) are complete,
// but single bytes from a multi-byte sequence (like 0xF0) are not.
let is_a_complete = std::str::from_utf8(&bytes_a).is_ok();
let is_b_complete = std::str::from_utf8(&bytes_b).is_ok();

// - Allow merging two complete tokens.
// - Any mix of complete and incomplete is disallowed.
if is_a_complete && is_b_complete {
return true;
}
if is_a_complete ^ is_b_complete {
return false;
}

// Here we know both tokens are incomplete.
// Allow merge only if building a valid UTF-8 prefix by appending a single byte.
if bytes_b.len() == 1 {
let mut merged = bytes_a;
merged.extend_from_slice(&bytes_b);
match std::str::from_utf8(&merged) {
// The merged bytes form one or more complete characters. Valid.
Ok(_) => true,
// The merged bytes are an incomplete but valid prefix. Valid.
Err(e) => e.error_len().is_none(),
}
} else {
// If part_b is not a single byte, it's not a valid continuation merge.
false
}
}

/// Compute the initial alphabet and limit it if relevant
fn compute_alphabet(
&self,
@@ -455,7 +533,7 @@ impl BpeTrainer {
let mut queue = OctonaryHeap::with_capacity(pair_counts.len());
where_to_update.drain().for_each(|(pair, pos)| {
let count = pair_counts[&pair];
if count > 0 {
if count > 0 && self.is_merge_allowed(&pair, &id_to_word) {
queue.push(Merge {
pair,
count: count as u64,
@@ -550,13 +628,13 @@ impl BpeTrainer {
for ((pair, change), iw) in changes {
let count = change * counts[iw] as i32;
*pair_counts.entry(pair).or_default() += count;
if change > 0 {
if change > 0 && self.is_merge_allowed(&pair, &id_to_word) {
where_to_update.entry(pair).or_default().insert(iw);
}
}
where_to_update.drain().for_each(|(pair, pos)| {
let count = pair_counts[&pair];
if count > 0 {
if count > 0 && self.is_merge_allowed(&pair, &id_to_word) {
queue.push(Merge {
pair,
count: count as u64,
@@ -644,8 +722,14 @@ impl Trainer for BpeTrainer {
#[cfg(test)]
mod tests {
use super::{BpeTrainer, Pair, BPE};
use crate::pre_tokenizers::byte_level::{bytes_char, ByteLevel};
use crate::tokenizer::{
OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer, Result, Trainer,
};
use ahash::AHashMap;
use compact_str::CompactString;
use std::collections::HashMap;
use std::sync::LazyLock;

#[test]
fn test_train() {
@@ -762,6 +846,7 @@ mod tests {
)
}
}

#[test]
fn bpe_test_max_token_length_direct_assert() {
/* more direct version of bpe_test_max_token_length test
@@ -831,4 +916,74 @@ mod tests {
.collect();
assert_eq!(trained_vocab, expected_vocab)
}

static BYTE_TO_CHAR: LazyLock<AHashMap<u8, char>> = LazyLock::new(bytes_char);

#[test]
fn test_bpe_utf8_boundary_enforcement_with_byte_level_pretokenizer() {
// Use the actual ByteLevel pre-tokenizer to process the input string.
let byte_level_pretok = ByteLevel::new(false, false, false);
let process_fn = |s: &str| -> Result<Vec<String>> {
let mut pretokenized = PreTokenizedString::from(s);
byte_level_pretok.pre_tokenize(&mut pretokenized)?;
Ok(pretokenized
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(word, _, _)| word.to_string())
.collect())
};

let sequence = " 🤗 🦒 🐹 🦦 🤗 𝟑".to_string();
let vocab_size = 25;

// --- Part 1: Unconstrained BPE ---
let mut unconstrained_trainer = BpeTrainer::builder()
.vocab_size(vocab_size)
.show_progress(false)
.enforce_utf8_boundaries(false)
.build();
unconstrained_trainer
.feed(std::iter::once(&sequence), &process_fn)
.unwrap();
let mut unconstrained_model = BPE::default();
unconstrained_trainer
.train(&mut unconstrained_model)
.unwrap();

let invalid_merge_token: String =
[BYTE_TO_CHAR[&b' '], BYTE_TO_CHAR[&0xF0]].iter().collect();
assert!(
unconstrained_model
.get_vocab()
.contains_key(&invalid_merge_token),
"Unconstrained vocab SHOULD contain the top frequency merge (bytes [20 F0])"
);

// --- Part 2: Constrained BPE ---
let mut constrained_trainer = BpeTrainer::builder()
.vocab_size(vocab_size)
.show_progress(false)
.enforce_utf8_boundaries(true)
.build();
constrained_trainer
.feed(std::iter::once(&sequence), &process_fn)
.unwrap();
let mut constrained_model = BPE::default();
constrained_trainer.train(&mut constrained_model).unwrap();

let valid_merge_token: String =
[BYTE_TO_CHAR[&0xF0], BYTE_TO_CHAR[&0x9F]].iter().collect();
assert!(
!constrained_model
.get_vocab()
.contains_key(&invalid_merge_token),
"Constrained vocab MUST NOT contain the invalid merge (bytes [20 F0])"
);
assert!(
constrained_model
.get_vocab()
.contains_key(&valid_merge_token),
"Constrained vocab SHOULD contain the next valid merge (bytes [F0 9F])"
);
}
}
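The rule implemented by `is_merge_allowed` above can be restated outside the diff. Below is an illustrative Python sketch of the same three-way decision, operating on raw bytes rather than on the `ByteLevel` character alphabet; the helper names are not part of the library:

```python
import codecs

def is_complete_utf8(b: bytes) -> bool:
    """The token's bytes decode to one or more whole characters."""
    try:
        b.decode("utf-8")
        return True
    except UnicodeDecodeError:
        return False

def is_utf8_prefix(b: bytes) -> bool:
    """The bytes are a valid UTF-8 prefix (at most the last character is truncated)."""
    decoder = codecs.getincrementaldecoder("utf-8")()
    try:
        decoder.decode(b, final=False)  # final=False tolerates a truncated trailing char
        return True
    except UnicodeDecodeError:
        return False

def merge_allowed(a: bytes, b: bytes) -> bool:
    # Two complete tokens may merge; mixing complete and incomplete is rejected;
    # two incomplete tokens may merge only when the right side is a single byte
    # and the concatenation is still a valid UTF-8 prefix.
    complete_a, complete_b = is_complete_utf8(a), is_complete_utf8(b)
    if complete_a and complete_b:
        return True
    if complete_a != complete_b:
        return False
    return len(b) == 1 and is_utf8_prefix(a + b)

assert merge_allowed(b" ", b"a")                 # two complete tokens
assert not merge_allowed(b" ", b"\xf0")          # complete token + partial byte
assert merge_allowed(b"\xf0", b"\x9f")           # building the 🤗 bytes left to right
assert not merge_allowed(b"\xf0", b"\x9f\xa4")   # right side is not a single byte
```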
2 changes: 1 addition & 1 deletion tokenizers/src/pre_tokenizers/byte_level.rs
@@ -45,7 +45,7 @@ static RE: LazyLock<SysRegex> = LazyLock::new(|| {
.unwrap()
});
static BYTES_CHAR: LazyLock<AHashMap<u8, char>> = LazyLock::new(bytes_char);
static CHAR_BYTES: LazyLock<AHashMap<char, u8>> =
pub(crate) static CHAR_BYTES: LazyLock<AHashMap<char, u8>> =
LazyLock::new(|| bytes_char().into_iter().map(|(c, b)| (b, c)).collect());

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
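The `CHAR_BYTES` table exposed here is the inverse of the GPT-2 style byte-to-unicode map used by `ByteLevel`. An illustrative Python reconstruction of that map (not the library's code) shows why the tests above look for `Ġ` (0x20) and `ð` (0xF0):

```python
def bytes_char():
    # Printable bytes map to themselves; the rest are shifted into U+0100 and up.
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    mapping = {b: chr(b) for b in printable}
    offset = 0
    for b in range(256):
        if b not in mapping:
            mapping[b] = chr(256 + offset)
            offset += 1
    return mapping

table = bytes_char()
assert table[0x20] == "Ġ"  # space, as seen in the invalid token "Ġð"
assert table[0xF0] == "ð"  # lead byte of 🤗 and 𝟑
```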
12 changes: 12 additions & 0 deletions tokenizers/src/tokenizer/mod.rs
@@ -534,6 +534,16 @@ where
PP: PostProcessor,
D: Decoder,
{
/// Validates compatibility between a trainer and the current tokenizer configuration.
/// Intended check: a `BpeTrainer` with `enforce_utf8_boundaries = true` requires a
/// `ByteLevel` pre-tokenizer. The body is currently a placeholder that always returns `Ok`.
fn _check_trainer_compat<T: Trainer>(
&self,
_trainer: &T,
) -> Result<()> {
Ok(())
}

/// Instantiate a new Tokenizer, with the given Model
pub fn new(model: M) -> Self {
Self {
@@ -1345,6 +1355,7 @@ where
where
T: Trainer<Model = M> + Sync,
{
self._check_trainer_compat(trainer)?; // check that settings are compatible
let mut len = 0;
for file in files.iter() {
len += File::open(file)
@@ -1420,6 +1431,7 @@ where
I: Iterator<Item = S> + Send,
S: AsRef<str> + Send,
{
self._check_trainer_compat(trainer)?; // check that settings are compatible
let (lower, upper) = sequences.size_hint();
let len = upper.unwrap_or(lower) as u64;
let progress = if trainer.should_show_progress() {
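`_check_trainer_compat` is currently a placeholder. For illustration only, the guard described in its doc comment could look roughly like the following at the Python level (hypothetical helper, assuming the trainer exposes the flag as an attribute):

```python
from tokenizers import pre_tokenizers

def check_trainer_compat(tokenizer, trainer):
    # Hypothetical guard: enforcing UTF-8 boundaries only makes sense when the
    # ByteLevel pre-tokenizer supplies the byte-to-character alphabet.
    if getattr(trainer, "enforce_utf8_boundaries", False) and not isinstance(
        tokenizer.pre_tokenizer, pre_tokenizers.ByteLevel
    ):
        raise ValueError("enforce_utf8_boundaries=True requires a ByteLevel pre-tokenizer")
```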