26 changes: 26 additions & 0 deletions bindings/python/src/tokenizer.rs
@@ -782,6 +782,16 @@ impl PyTokenizer {
self.tokenizer.get_vocab(with_added_tokens)
}

/// Get the special tokens mapping
///
/// Returns:
///     :obj:`Optional[Dict[str, List[str]]]`: The mapping from special token names (e.g. `eos_token`) to their token strings
#[pyo3(signature = ())]
#[pyo3(text_signature = "(self)")]
fn get_special_tokens_mapping(&self) -> Option<&HashMap<String, Vec<String>>> {
self.tokenizer.get_special_tokens_mapping()
}

/// Get the underlying vocabulary
///
/// Returns:
@@ -1848,6 +1858,22 @@ impl PyTokenizer {
fn set_decoder(&mut self, decoder: Option<PyRef<PyDecoder>>) {
self.tokenizer.with_decoder(decoder.map(|d| d.clone()));
}

/// The `optional` `eos_token` entry from the special tokens mapping, if one is set
#[getter]
fn get_eos_token(&self, py: Python<'_>) -> Option<Vec<String>> {
self.tokenizer
.get_special_tokens_mapping()
.and_then(|token| token.get("eos_token"))
// Clone the value out of the borrowed mapping so it can be handed back to Python.
.map(|v| v.clone())
}

/// Set the `eos_token` in the special tokens mapping
#[setter]
fn set_eos_token(&mut self, new_eos_token: Option<String>) {
    // Draft stub: SpecialTokensMapping does not yet expose a per-entry update,
    // so the new value cannot be applied to the mapping from here yet.
    let _ = new_eos_token;
}
}

#[cfg(test)]
36 changes: 35 additions & 1 deletion tokenizers/src/tokenizer/mod.rs
@@ -20,16 +20,17 @@ use std::{
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};

use crate::utils::iter::ResultShunt;
use crate::utils::parallelism::*;
use crate::utils::progress::{ProgressBar, ProgressStyle};
use crate::{special_tokens_mapping::SpecialTokensMapping, utils::iter::ResultShunt};

mod added_vocabulary;
mod encoding;
pub mod normalizer;
pub mod pattern;
pub mod pre_tokenizer;
mod serialization;
pub mod special_tokens_mapping;

// Re-export wrappers
pub use crate::decoders::DecoderWrapper;
@@ -293,6 +294,7 @@ pub struct TokenizerBuilder<M, N, PT, PP, D> {

truncation: Option<TruncationParams>,
padding: Option<PaddingParams>,
special_tokens_mapping: Option<SpecialTokensMapping>,
}

impl<M, N, PT, PP, D> Default for TokenizerBuilder<M, N, PT, PP, D>
@@ -327,6 +329,7 @@ where
added_vocabulary: AddedVocabulary::new(),
truncation: None,
padding: None,
special_tokens_mapping: None,
}
}

@@ -347,6 +350,7 @@ where
added_vocabulary: self.added_vocabulary,
truncation: self.truncation,
padding: self.padding,
special_tokens_mapping: self.special_tokens_mapping,
})
}

@@ -404,6 +408,14 @@ where
self.padding = padding;
self
}

pub fn with_special_tokens_mapping(
mut self,
special_tokens_mapping: Option<SpecialTokensMapping>,
) -> Self {
self.special_tokens_mapping = special_tokens_mapping;
self
}
}

#[derive(Serialize, Deserialize, Debug, Clone)]
@@ -480,6 +492,7 @@ where
added_vocabulary: t.added_vocabulary,
padding: t.padding,
truncation: t.truncation,
special_tokens_mapping: t.special_tokens_mapping,
})
}
}
@@ -524,6 +537,7 @@ pub struct TokenizerImpl<M, N, PT, PP, D> {
// General processing parameters
truncation: Option<TruncationParams>,
padding: Option<PaddingParams>,
special_tokens_mapping: Option<SpecialTokensMapping>,
}

impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
@@ -547,6 +561,7 @@ where

truncation: None,
padding: None,
special_tokens_mapping: None,
}
}

@@ -654,6 +669,25 @@ where
self.padding.as_ref()
}

/// Set the special_tokens_mapping
pub fn with_special_tokens_mapping(
&mut self,
special_tokens_mapping: Option<SpecialTokensMapping>,
) -> &mut Self {
self.special_tokens_mapping = special_tokens_mapping;
self
}

/// Get the currently set special tokens mapping
pub fn get_special_tokens_mapping(&self) -> Option<&SpecialTokensMapping> {
self.special_tokens_mapping.as_ref()
}

/// Get a mutable reference to the currently set special tokens mapping
pub fn get_special_tokens_mapping_mut(&mut self) -> Option<&mut SpecialTokensMapping> {
self.special_tokens_mapping.as_mut()
}

/// Get a mutable reference to the currently set padding parameters
pub fn get_padding_mut(&mut self) -> Option<&mut PaddingParams> {
self.padding.as_mut()
8 changes: 7 additions & 1 deletion tokenizers/src/tokenizer/serialization.rs
@@ -42,6 +42,7 @@
tokenizer.serialize_field("post_processor", &self.post_processor)?;
tokenizer.serialize_field("decoder", &self.decoder)?;
tokenizer.serialize_field("model", &self.model)?;
tokenizer.serialize_field("special_tokens_mapping", &self.special_tokens_mapping)?;

tokenizer.end()
}
@@ -63,6 +64,7 @@
"Tokenizer",
&[
"version",
"special_tokens_mapping",
"truncation",
"padding",
"added_tokens",
@@ -143,6 +145,9 @@
"post_processor" => {
builder = builder.with_post_processor(map.next_value()?);
}
"special_tokens_mapping" => {
builder = builder.with_special_tokens_mapping(map.next_value()?);
}

Check failure on line 149 in tokenizers/src/tokenizer/serialization.rs, from GitHub Actions jobs "Check it builds for Windows 32-bit (3.13)" and "Check everything builds & tests (ubuntu-latest)": the trait bound `SpecialTokensMapping: serde::Deserialize<'de>` is not satisfied
_ => {}
};
}
@@ -221,7 +226,8 @@
"continuing_subword_prefix": "",
"max_input_chars_per_word": 100,
"vocab": {}
}
},
"special_tokens_mapping": null
}"#;
let tokenizer = Tokenizer::from_str(tok_json).unwrap();
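For reference, a hedged sketch (not part of this PR's tests) of the JSON shape a populated mapping would take in place of the `null` above. It assumes serde_json is available, that SpecialTokensMapping ends up reachable as tokenizers::special_tokens_mapping::SpecialTokensMapping, and that the plain derive over the private `inner` field in special_tokens_mapping.rs is kept (no `#[serde(transparent)]` or `flatten`):

use std::collections::{BTreeMap, BTreeSet};

use tokenizers::special_tokens_mapping::SpecialTokensMapping; // assumed re-export path

fn main() {
    let mut inner = BTreeMap::new();
    inner.insert("eos_token".to_string(), BTreeSet::from([2u32]));
    let mapping = SpecialTokensMapping::new(inner);

    // With the named `inner` field and no extra serde attributes, this prints:
    // {"inner":{"eos_token":[2]}}
    println!("{}", serde_json::to_string(&mapping).unwrap());
}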

20 changes: 20 additions & 0 deletions tokenizers/src/tokenizer/special_tokens_mapping.rs
@@ -0,0 +1,20 @@
use std::collections::{BTreeMap, BTreeSet};

use serde::{Deserialize, Serialize};

/// A mapping from standard special token names such as `eos_token` or `bos_token`
/// (or any custom name like `my_token`) to the corresponding token ids.
///
/// `BTreeMap` and `BTreeSet` are used for ordered, deterministic serialization and
/// fast membership checks; the mapping supports replacing a single entry or the
/// whole mapping at once.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecialTokensMapping {
inner: BTreeMap<String, BTreeSet<u32>>,
}

impl SpecialTokensMapping {
pub fn new(inner: BTreeMap<String, BTreeSet<u32>>) -> Self {
Self { inner }
}
}
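A minimal usage sketch for the new mapping (illustrative only, not part of the PR). It assumes SpecialTokensMapping is reachable as tokenizers::special_tokens_mapping::SpecialTokensMapping once the module is re-exported, and it uses only the `new` constructor above together with the `with_special_tokens_mapping` / `get_special_tokens_mapping` methods added in tokenizer/mod.rs:

use std::collections::{BTreeMap, BTreeSet};

use tokenizers::models::wordpiece::WordPiece;
use tokenizers::special_tokens_mapping::SpecialTokensMapping; // assumed re-export path
use tokenizers::Tokenizer;

fn main() {
    // Map the conventional "eos_token" name to its token id(s).
    let mut inner = BTreeMap::new();
    inner.insert("eos_token".to_string(), BTreeSet::from([2u32]));

    // Attach the mapping to a tokenizer built around a default WordPiece model.
    let mut tokenizer = Tokenizer::new(WordPiece::default());
    tokenizer.with_special_tokens_mapping(Some(SpecialTokensMapping::new(inner)));

    // The mapping can then be read back from the tokenizer.
    assert!(tokenizer.get_special_tokens_mapping().is_some());
}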