Add method to return attn-mask for HF Tokenizer. #60

Open · wants to merge 4 commits into base: main
17 changes: 12 additions & 5 deletions include/tokenizers_c.h
@@ -17,8 +17,8 @@ extern "C" {
typedef void* TokenizerHandle;

typedef struct {
int* token_ids;
size_t len;
int* token_ids;
size_t len;
} TokenizerEncodeResult;

TokenizerHandle tokenizers_new_from_str(const char* json, size_t len);
@@ -28,10 +28,17 @@ TokenizerHandle byte_level_bpe_tokenizers_new_from_str(const char* vocab, size_t
const char* added_tokens,
size_t added_tokens_len);

void tokenizers_encode(TokenizerHandle handle, const char* data, size_t len, int add_special_token, TokenizerEncodeResult* result);
void tokenizers_encode(TokenizerHandle handle, const char* data, size_t len, int add_special_token,
TokenizerEncodeResult* result);

void tokenizers_encode_batch(TokenizerHandle handle, const char** data, size_t* len, size_t num_seqs,
int add_special_token, TokenizerEncodeResult* results);
void tokenizers_encode_batch(TokenizerHandle handle, const char** data, size_t* len,
size_t num_seqs, int add_special_token,
TokenizerEncodeResult* results);

void tokenizers_encode_batch_with_mask(TokenizerHandle handle, const char** data, size_t* len,
size_t num_seqs, int add_special_token,
TokenizerEncodeResult* results,
TokenizerEncodeResult* masks);

void tokenizers_free_encode_results(TokenizerEncodeResult* results, size_t num_seqs);

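For orientation, here is a minimal sketch (not part of this PR) of how the new tokenizers_encode_batch_with_mask entry point could be driven from C++. The demo function, the example strings, and the assumption that tokenizer_json already holds the contents of a HuggingFace tokenizer.json are illustrative only; the struct layout, the function signatures, and freeing both output arrays through tokenizers_free_encode_results follow the header above. The attention-mask values come back in the token_ids field of the second TokenizerEncodeResult array.

#include <tokenizers_c.h>

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical demo: encode a batch and print "token_id/mask" pairs.
void encode_with_mask_demo(const std::string& tokenizer_json,
                           const std::vector<std::string>& texts) {
  TokenizerHandle handle = tokenizers_new_from_str(tokenizer_json.data(), tokenizer_json.size());

  // The C API takes parallel arrays of string pointers and lengths.
  std::vector<const char*> ptrs;
  std::vector<size_t> lens;
  for (const std::string& t : texts) {
    ptrs.push_back(t.data());
    lens.push_back(t.size());
  }

  // One result slot per sequence, for the ids and for the mask.
  std::vector<TokenizerEncodeResult> results(texts.size());
  std::vector<TokenizerEncodeResult> masks(texts.size());
  tokenizers_encode_batch_with_mask(handle, ptrs.data(), lens.data(), texts.size(),
                                    /*add_special_token=*/1, results.data(), masks.data());

  for (size_t i = 0; i < texts.size(); ++i) {
    // results[i] and masks[i] have the same length: one mask value per token id.
    for (size_t j = 0; j < results[i].len; ++j) {
      std::printf("%d/%d ", results[i].token_ids[j], masks[i].token_ids[j]);
    }
    std::printf("\n");
  }

  // Both arrays own Rust-allocated buffers and are released the same way.
  tokenizers_free_encode_results(results.data(), texts.size());
  tokenizers_free_encode_results(masks.data(), texts.size());
  // Releasing the tokenizer handle itself is outside this excerpt and omitted here.
}
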
70 changes: 66 additions & 4 deletions include/tokenizers_cpp.h
@@ -6,10 +6,11 @@
#ifndef TOKENIZERS_CPP_H_
#define TOKENIZERS_CPP_H_

#include <tokenizers_c.h>

#include <memory>
#include <string>
#include <tuple>
#include <vector>

namespace tokenizers {

/*!
@@ -57,13 +58,14 @@ class Tokenizer {
virtual size_t GetVocabSize() = 0;

/*!
* \brief Convert the given id to its corresponding token if it exists. If not, return an
* empty string.
* \brief Convert the given id to its corresponding token if it exists. If
* not, return an empty string.
*/
virtual std::string IdToToken(int32_t token_id) = 0;

/*!
* \brief Convert the given token to its corresponding id if it exists. If not, return -1.
* \brief Convert the given token to its corresponding id if it exists. If
* not, return -1.
*/
virtual int32_t TokenToId(const std::string& token) = 0;

@@ -106,5 +108,65 @@ class Tokenizer {
static std::unique_ptr<Tokenizer> FromBlobRWKVWorld(const std::string& model_blob);
};

class HFTokenizer : public Tokenizer {
public:
explicit HFTokenizer(TokenizerHandle handle);

HFTokenizer(const HFTokenizer&);
HFTokenizer(HFTokenizer&& other);

~HFTokenizer();

// use i32 to be consistent with sentencepiece
std::vector<int32_t> Encode(const std::string& text, bool add_special_tokens);

// use i32 to be consistent with sentencepiece
std::vector<int32_t> Encode(const std::string& text) final;

// version specific to HFTokenizer, which adds special tokens flag
std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts,
bool add_special_tokens);

std::tuple<std::vector<std::vector<int32_t>>, std::vector<std::vector<int32_t>>>
EncodeBatchWithMask(const std::vector<std::string>& texts, bool add_special_tokens);

std::vector<std::vector<int32_t>> EncodeBatch(const std::vector<std::string>& texts) final;

// use i32 to be consistent with sentencepiece
std::string Decode(const std::vector<int32_t>& ids, bool skip_special_tokens);

std::string Decode(const std::vector<int32_t>& ids) final;

size_t GetVocabSize() final;

std::string IdToToken(int32_t id) final;

int32_t TokenToId(const std::string& token) final;

/*!
* \brief Create HF tokenizer from a single in-memory json blob.
*
* \param json_blob The json blob.
* \return The created tokenizer.
*/
static std::unique_ptr<HFTokenizer> FromBlobJSON(const std::string& json_blob);

/*!
* \brief Create BPE tokenizer
*
* \param vocab_blob The blob that contains vocabs.
* \param merges_blob The blob that contains the merges.
* \param added_tokens The added tokens.
* \return The created tokenizer.
*/
static std::unique_ptr<HFTokenizer> FromBlobByteLevelBPE(const std::string& vocab_blob,
const std::string& merges_blob,
const std::string& added_tokens = "");

private:
// internal handle
TokenizerHandle handle_{nullptr};
};

} // namespace tokenizers
#endif // TOKENIZERS_CPP_H_
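
At the C++ level, the user-facing addition is HFTokenizer::EncodeBatchWithMask, which returns the token ids and the attention mask as two aligned vectors of vectors. A minimal usage sketch (not part of this PR), assuming the application has already read tokenizer.json into a string; the hf_mask_demo helper and sample texts are illustrative only:

#include <tokenizers_cpp.h>

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Hypothetical demo: tokenizer_json holds the contents of a HuggingFace tokenizer.json.
void hf_mask_demo(const std::string& tokenizer_json) {
  std::unique_ptr<tokenizers::HFTokenizer> tok =
      tokenizers::HFTokenizer::FromBlobJSON(tokenizer_json);

  std::vector<std::string> texts = {"What is the capital of Canada?", "hello"};
  auto [ids, mask] = tok->EncodeBatchWithMask(texts, /*add_special_tokens=*/true);

  // ids[i] and mask[i] are aligned; in the HF convention a mask value of 1 marks a
  // real token and 0 marks padding (all 1s unless the tokenizer is configured to pad).
  for (size_t i = 0; i < ids.size(); ++i) {
    std::cout << texts[i] << ": " << ids[i].size() << " tokens, "
              << mask[i].size() << " mask values\n";
  }
}

The existing EncodeBatch overloads keep their signatures for current callers; as the Rust changes below show, the old batch path now calls encode_batch_with_mask internally and simply discards the mask.
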
78 changes: 70 additions & 8 deletions rust/src/lib.rs
@@ -91,12 +91,24 @@ impl TokenizerWrapper {
return encoded.get_ids().to_vec();
}

pub fn encode_batch(&mut self, texts: Vec<&str>, add_special_tokens: bool) -> Vec<Vec<u32>> {
let results = self.tokenizer.encode_batch(texts, add_special_tokens).unwrap()
.into_iter()
.map(|encoded| encoded.get_ids().to_vec())
pub fn encode_batch_with_mask(
&mut self,
texts: Vec<&str>,
add_special_tokens: bool,
) -> (Vec<Vec<u32>>, Vec<Vec<u32>>) {
let encoded = self
.tokenizer
.encode_batch(texts, add_special_tokens)
.unwrap();
let tokens = encoded
.iter()
.map(|e| e.get_ids().to_vec())
.collect::<Vec<Vec<u32>>>();
return results;
let attention_mask = encoded
.iter()
.map(|e| e.get_attention_mask().to_vec())
.collect::<Vec<Vec<u32>>>();
return (tokens, attention_mask);
}

pub fn decode(&mut self, ids: &[u32], skip_special_tokens: bool) {
@@ -170,10 +182,49 @@ extern "C" fn tokenizers_encode_batch(
unsafe {
let input_data = (0..num_seqs)
.map(|i| {
std::str::from_utf8(std::slice::from_raw_parts(*input_cstr.offset(i as isize), *input_len.offset(i as isize))).unwrap()
std::str::from_utf8(std::slice::from_raw_parts(
*input_cstr.offset(i as isize),
*input_len.offset(i as isize),
))
.unwrap()
})
.collect::<Vec<&str>>();
let (encoded_batch, _encoded_masks) =
(*handle).encode_batch_with_mask(input_data, add_special_tokens != 0);
for (i, encoded) in encoded_batch.into_iter().enumerate() {
let len = encoded.len();
let result = TokenizerEncodeResult {
token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
len: len,
};
*out_result.offset(i as isize) = result;
}
}
}

#[no_mangle]
extern "C" fn tokenizers_encode_batch_with_mask(
handle: *mut TokenizerWrapper,
input_cstr: *const *const u8,
input_len: *const usize,
num_seqs: usize,
add_special_tokens: i32,
out_result: *mut TokenizerEncodeResult,
out_mask: *mut TokenizerEncodeResult,
) {
unsafe {
let input_data = (0..num_seqs)
.map(|i| {
std::str::from_utf8(std::slice::from_raw_parts(
*input_cstr.offset(i as isize),
*input_len.offset(i as isize),
))
.unwrap()
})
.collect::<Vec<&str>>();
let encoded_batch = (*handle).encode_batch(input_data, add_special_tokens != 0);
let (encoded_batch, encoded_mask) =
(*handle).encode_batch_with_mask(input_data, add_special_tokens != 0);

for (i, encoded) in encoded_batch.into_iter().enumerate() {
let len = encoded.len();
let result = TokenizerEncodeResult {
@@ -182,6 +233,14 @@ extern "C" fn tokenizers_encode_batch(
};
*out_result.offset(i as isize) = result;
}
for (i, encoded) in encoded_mask.into_iter().enumerate() {
let len = encoded.len();
let result = TokenizerEncodeResult {
token_ids: Box::into_raw(encoded.into_boxed_slice()) as *mut u32,
len: len,
};
*out_mask.offset(i as isize) = result;
}
}
}

@@ -190,7 +249,10 @@ extern "C" fn tokenizers_free_encode_results(results: *mut TokenizerEncodeResult
unsafe {
let slice = std::slice::from_raw_parts_mut(results, num_seqs);
for result in &mut *slice {
drop(Box::from_raw(std::slice::from_raw_parts_mut(result.token_ids, result.len)));
drop(Box::from_raw(std::slice::from_raw_parts_mut(
result.token_ids,
result.len,
)));
}
}
}