
BertTokenizer and BertTokenizerFast have different behavior when "return_overflowing_tokens" is requested #28900

Open
ivlcic opened this issue Feb 6, 2024 · 7 comments · May be fixed by #34669
Labels: bug, Good Second Issue

Comments

ivlcic commented Feb 6, 2024

System Info

  • transformers version: 4.37.2
  • Platform: Linux-6.5.5-arch1-1-x86_64-with-glibc2.38
  • Python version: 3.11.5
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import BertTokenizer, BertTokenizerFast, BatchEncoding
n_tok = BertTokenizer.from_pretrained("bert-base-uncased")
f_tok = BertTokenizerFast.from_pretrained("bert-base-uncased")

text = "hello my name is nikola and i debug transformers now"

n_inputs: BatchEncoding = n_tok(text=text, add_special_tokens=True, max_length=6, truncation=True, padding='max_length', return_overflowing_tokens=True)
o = n_inputs.get("overflowing_tokens")
print(f'Overflowing {o}')
n_inputs['input_ids']


f_inputs: BatchEncoding = f_tok(text=text, add_special_tokens=True, max_length=6, truncation=True, padding='max_length', return_overflowing_tokens=True)
o = f_inputs.get("overflowing_tokens")
print(f'Overflowing {o}')
f_inputs['input_ids']

Expected behavior

For n_inputs['input_ids'] we get [101, 7592, 2026, 2171, 2003, 102], and
for f_inputs['input_ids'] we get [[101, 7592, 2026, 2171, 2003, 102], [101, 24794, 1998, 1045, 2139, 102], [101, 8569, 2290, 19081, 2085, 102]].
The outputs should be the same.

@ArthurZucker (Collaborator)

Hey! Thanks for opening this issue. Would you like to dive into this and open a PR for a fix? It might be a known bug, and overflowing tokens are not supported on all slow tokenizers. The fast tokenizer's behaviour is probably the right one.

ivlcic (Author) commented Feb 13, 2024

I don't know what the correct behaviour is. You can get the overflowing tokens from both tokenizers; it's just that the returned data structure needs to be more consistent. I prefer the fast tokenizer's behaviour, but its BatchEncoding returns None for overflowing_tokens, which is inconsistent with the API advertised in the reference documentation.
I can try to fix this in late March, but I would appreciate your decision on which direction the API should go, since I'm not an expert on the transformers API.
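
To make the inconsistency concrete, here is a minimal check reusing n_inputs (slow) and f_inputs (fast) from the reproduction snippet at the top of the issue; the commented values reflect the outputs quoted in this thread:

# Slow tokenizer: the documented "overflowing_tokens" key is present as a flat
# list of the truncated token ids, plus "num_truncated_tokens".
print(n_inputs.get("overflowing_tokens"))          # flat list of truncated ids
print(n_inputs.get("overflow_to_sample_mapping"))  # None

# Fast tokenizer: "overflowing_tokens" is absent; the overflow comes back as
# extra rows in input_ids, tracked by "overflow_to_sample_mapping" instead.
print(f_inputs.get("overflowing_tokens"))          # None
print(f_inputs.get("overflow_to_sample_mapping"))  # [0, 0, 0]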

huggingface deleted a comment from the github-actions bot on Mar 8, 2024
amyeroberts added the bug and Good Second Issue labels on Mar 8, 2024
@JINO-ROHIT (Contributor)

@ArthurZucker @amyeroberts I'm interested in taking up this issue.

Just wanted to confirm something else as well: shouldn't the behavior of AutoTokenizer match the specific tokenizer? E.g., I tried this:

from transformers import BertTokenizer, BertTokenizerFast, BatchEncoding, AutoTokenizer
n_tok = AutoTokenizer.from_pretrained("bert-base-uncased")
f_tok = BertTokenizerFast.from_pretrained("bert-base-uncased")

text = "hey this is jino, im just reading the api dont mind me"

n_inputs: BatchEncoding = n_tok(text=text, add_special_tokens=True, max_length=6, truncation=True, padding='max_length', return_overflowing_tokens=True)
o = n_inputs.get("overflowing_tokens")
print(f'Overflowing {o}')
print(n_inputs)

Output (much different from using the BertTokenizer shown by nikola above):
Overflowing None
{
"input_ids": [
[101, 10930, 10930, 2023, 2003, 102],
[101, 9743, 2080, 2054, 2015, 102],
[101, 2039, 9152, 23033, 2015, 102]
],
"token_type_ids": [
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]
],
"attention_mask": [
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]
],
"overflow_to_sample_mapping": [0, 0, 0]
}
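
One likely explanation for this particular difference (an observation, not the fix itself): AutoTokenizer.from_pretrained returns the fast implementation by default when one is available, so n_tok here is actually a BertTokenizerFast, which is why its output matches the fast format. Passing use_fast=False selects the slow BertTokenizer, for example:

from transformers import AutoTokenizer, BertTokenizerFast

auto_tok = AutoTokenizer.from_pretrained("bert-base-uncased")
print(isinstance(auto_tok, BertTokenizerFast))  # True -- AutoTokenizer defaults to the fast class

slow_tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)
print(type(slow_tok).__name__)  # BertTokenizer -- the slow implementation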

@ArthurZucker (Collaborator)

Yes, fast and slow tokenizers are supposed to give similar output (not the same format, but all the overflow information should match).

bayllama (Contributor) commented Jul 8, 2024

Hi @ArthurZucker / @amyeroberts,
For the following code, when the slow tokenizer is used:

from transformers import BertTokenizer, BertTokenizerFast, BatchEncoding, AutoTokenizer
n_tok = BertTokenizer.from_pretrained("bert-base-uncased")
f_tok = BertTokenizerFast.from_pretrained("bert-base-uncased")

text = "hey this is jino, im just reading the api dont mind me"

n_inputs: BatchEncoding = n_tok(text=text, add_special_tokens=True, max_length=6, truncation=True, padding='max_length', return_overflowing_tokens=True)
o = n_inputs.get("overflowing_tokens")
print(f'Overflowing {o}')
print(n_inputs)

The following is the output:

Overflowing [2080, 1010, 10047, 2074, 3752, 1996, 17928, 2123, 2102, 2568, 2033]
{'overflowing_tokens': [2080, 1010, 10047, 2074, 3752, 1996, 17928, 2123, 2102, 2568, 2033], 'num_truncated_tokens': 11, 'input_ids': [101, 4931, 2023, 2003, 9743, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1]}

What we notice is that this is different from the output of the fast tokenizer, where the overflowing tokens are split into multiple batches of max sequence length and appended to input_ids. Do we want the slow tokenizer to behave like the fast one, or is this the expected behavior?
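
For reference, a rough workaround that approximates the fast tokenizer's windows from the slow tokenizer's flat overflowing_tokens list — a minimal sketch only, assuming no stride and the standard single-sentence [CLS]/[SEP] special tokens, and not a proposed fix:

from transformers import BertTokenizer

max_length = 6
tok = BertTokenizer.from_pretrained("bert-base-uncased")
enc = tok(text="hey this is jino, im just reading the api dont mind me", add_special_tokens=True, max_length=max_length, truncation=True, padding='max_length', return_overflowing_tokens=True)

# Re-chunk the flat overflow list into max_length-sized rows, mimicking the fast
# tokenizer: each window gets [CLS]/[SEP] added and is padded to max_length.
window = max_length - 2  # room left after the two special tokens
rows = [enc["input_ids"]]
overflow = enc["overflowing_tokens"]
for start in range(0, len(overflow), window):
    chunk = overflow[start:start + window]
    ids = tok.build_inputs_with_special_tokens(chunk)
    ids += [tok.pad_token_id] * (max_length - len(ids))
    rows.append(ids)
print(rows)  # rows of length 6, similar in shape to the fast tokenizer's input_ids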

@ArthurZucker (Collaborator)

In an optimal world, we want the slow to match the fast! I am not certain which behaviour is "expected" in this specific case 😅

@tibor-reiss (Contributor)

Hi @ArthurZucker / @amyeroberts,
this is the commit which changed the behavior, i.e. the overflowing tokens are no longer returned in the dictionary under the "overflowing_tokens" key.
As people have already been asking: what is the preferred approach? Options (see the snippet after this list for the keys each tokenizer currently returns):

  • Bring back this key, which means the fast tokenizer would no longer return the overflowing tokens in "input_ids" but in "overflowing_tokens" instead, making the fast tokenizer consistent with the slow one. This seems like the simpler approach.
  • Adjust the slow tokenizer to match the fast one. This is quite a big refactoring, and it would also mean that the "overflowing_tokens" key is no longer necessary.

Please let me know what you think.
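
A small snippet to make the trade-off concrete — it just prints the keys each tokenizer currently returns for the same call; the trailing comments sketch how the two options would change those keys (an illustration, not part of any proposed patch):

from transformers import BertTokenizer, BertTokenizerFast

text = "hello my name is nikola and i debug transformers now"
kwargs = dict(add_special_tokens=True, max_length=6, truncation=True, padding='max_length', return_overflowing_tokens=True)

for tok_cls in (BertTokenizer, BertTokenizerFast):
    enc = tok_cls.from_pretrained("bert-base-uncased")(text, **kwargs)
    print(tok_cls.__name__, sorted(enc.keys()))
# Currently this prints (roughly):
#   BertTokenizer     ['attention_mask', 'input_ids', 'num_truncated_tokens', 'overflowing_tokens', 'token_type_ids']
#   BertTokenizerFast ['attention_mask', 'input_ids', 'overflow_to_sample_mapping', 'token_type_ids']
# Option 1 would move the fast output toward the first shape (single row plus
# "overflowing_tokens"); option 2 would move the slow output toward the second
# (one row per window plus "overflow_to_sample_mapping").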

tibor-reiss linked a pull request on Nov 9, 2024 that will close this issue