Commit c7c875e

Adjust slow tokenizer for return_overflowing_tokens

tibor-reiss committed Nov 9, 2024
1 parent a06a0d1 commit c7c875e
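
In outline, the change makes the slow (pure-Python) tokenizer emit every overflow chunk instead of only the first truncation remainder. A standalone sketch of that chunking scheme, ignoring special tokens and stride (plain Python, names hypothetical, not code from the diff):

    def chunk_ids(ids, max_length):
        # Re-encode the overflowing tail until nothing is left over,
        # as the new _calc loop in the diff below does.
        chunks = []
        remaining = ids
        while remaining:
            chunks.append(remaining[:max_length])
            remaining = remaining[max_length:]
        return chunks

    print(chunk_ids(list(range(10)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]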
Showing 1 changed file with 82 additions and 53 deletions.
135 changes: 82 additions & 53 deletions src/transformers/tokenization_utils_base.py
@@ -23,7 +23,7 @@
 import os
 import re
 import warnings
-from collections import UserDict
+from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sized
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -3445,10 +3445,6 @@ def prepare_for_model(
             **kwargs,
         )

-        pair = bool(pair_ids is not None)
-        len_ids = len(ids)
-        len_pair_ids = len(pair_ids) if pair else 0
-
         if return_token_type_ids and not add_special_tokens:
             raise ValueError(
                 "Asking to return token_type_ids while setting add_special_tokens to False "
@@ -3473,64 +3469,97 @@ def prepare_for_model(
         if return_attention_mask is None:
             return_attention_mask = "attention_mask" in self.model_input_names

-        encoded_inputs = {}
-
-        # Compute the total size of the returned encodings
-        total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
-
-        # Truncation: Handle max sequence length
-        overflowing_tokens = []
-        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
-            ids, pair_ids, overflowing_tokens = self.truncate_sequences(
-                ids,
-                pair_ids=pair_ids,
-                num_tokens_to_remove=total_len - max_length,
-                truncation_strategy=truncation_strategy,
-                stride=stride,
-            )
-
-        if return_overflowing_tokens:
-            encoded_inputs["overflowing_tokens"] = overflowing_tokens
-            encoded_inputs["num_truncated_tokens"] = total_len - max_length
-
-        # Add special tokens
-        if add_special_tokens:
-            sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
-            token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
-        else:
-            sequence = ids + pair_ids if pair else ids
-            token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
-
-        # Build output dictionary
-        encoded_inputs["input_ids"] = sequence
-        if return_token_type_ids:
-            encoded_inputs["token_type_ids"] = token_type_ids
-        if return_special_tokens_mask:
-            if add_special_tokens:
-                encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
-            else:
-                encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
-
-        # Check lengths
-        self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
-
-        # Padding
-        if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
-            encoded_inputs = self.pad(
-                encoded_inputs,
-                max_length=max_length,
-                padding=padding_strategy.value,
-                pad_to_multiple_of=pad_to_multiple_of,
-                padding_side=padding_side,
-                return_attention_mask=return_attention_mask,
-            )
-
-        if return_length:
-            encoded_inputs["length"] = len(encoded_inputs["input_ids"])
-
-        batch_outputs = BatchEncoding(
-            encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
-        )
+        def _calc(ids, pair_ids):
+            pair = bool(pair_ids is not None)
+            len_ids = len(ids)
+            len_pair_ids = len(pair_ids) if pair else 0
+
+            # Compute the total size of the returned encodings
+            total_len = (
+                len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
+            )
+
+            # Truncation: Handle max sequence length
+            overflowing_tokens = []
+            num_tokens_to_remove = 0
+            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+                num_tokens_to_remove = total_len - max_length
+                ids, pair_ids, overflowing_tokens = self.truncate_sequences(
+                    ids,
+                    pair_ids=pair_ids,
+                    num_tokens_to_remove=num_tokens_to_remove,
+                    truncation_strategy=truncation_strategy,
+                    stride=stride,
+                )
+
+            # Add special tokens
+            if add_special_tokens:
+                sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
+                token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
+            else:
+                sequence = ids + pair_ids if pair else ids
+                token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
+
+            # Build output dictionary
+            encoded_inputs = {}
+            encoded_inputs["input_ids"] = sequence
+            if return_token_type_ids:
+                encoded_inputs["token_type_ids"] = token_type_ids
+            if return_special_tokens_mask:
+                if add_special_tokens:
+                    encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+                else:
+                    encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+
+            # Check lengths
+            self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
+
+            # Padding
+            if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+                encoded_inputs = self.pad(
+                    encoded_inputs,
+                    max_length=max_length,
+                    padding=padding_strategy.value,
+                    pad_to_multiple_of=pad_to_multiple_of,
+                    padding_side=padding_side,
+                    return_attention_mask=return_attention_mask,
+                )
+
+            if return_length:
+                encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+            return encoded_inputs, num_tokens_to_remove
+
+        def _add_to_encoding_dict(encoding_dict, e):
+            encoding_dict["input_ids"].append(e["input_ids"])
+            if return_token_type_ids:
+                encoding_dict["token_type_ids"].append(e["token_type_ids"])
+            if return_attention_mask:
+                encoding_dict["attention_mask"].append(e["attention_mask"])
+            if return_special_tokens_mask:
+                encoding_dict["special_tokens_mask"].append(e["special_tokens_mask"])
+            if return_offsets_mapping:
+                encoding_dict["offset_mapping"].append(e["offsets"])
+            if return_length:
+                encoding_dict["length"].append(len(e["input_ids"]))
+
+        encoded_inputs, num_tokens_to_remove = _calc(ids, pair_ids)
+
+        if return_overflowing_tokens and num_tokens_to_remove:
+            encoding_dict = defaultdict(list)
+            _add_to_encoding_dict(encoding_dict, encoded_inputs)
+            while num_tokens_to_remove > 0:
+                encoded_inputs, num_tokens_to_remove = _calc(
+                    ids[-num_tokens_to_remove:], pair_ids[-num_tokens_to_remove:] if pair_ids is not None else None
+                )
+                _add_to_encoding_dict(encoding_dict, encoded_inputs)
+            batch_outputs = BatchEncoding(
+                encoding_dict, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+            )
+        else:
+            batch_outputs = BatchEncoding(
+                encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+            )

         return batch_outputs
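
Not part of the commit — a minimal usage sketch of the behavior this change targets, assuming a BERT-style slow (pure-Python) tokenizer; the checkpoint name, input text, and max_length are illustrative only, not taken from the diff:

    from transformers import BertTokenizer

    # Slow tokenizer; the fast (Rust-backed) tokenizers already chunk overflow.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    enc = tokenizer(
        "one two three four five six seven eight nine ten",
        max_length=6,
        truncation=True,
        return_overflowing_tokens=True,
    )

    # With the loop over _calc above, each overflowing tail is re-encoded as
    # its own sequence, so input_ids holds one list per chunk (mirroring the
    # fast tokenizers) rather than a single truncated sequence plus a raw
    # "overflowing_tokens" list.
    print(enc["input_ids"])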

