From c7c875e51bf0f97e782caade035f5a709edf6ba3 Mon Sep 17 00:00:00 2001
From: Tibor Reiss
Date: Sat, 9 Nov 2024 11:59:22 +0100
Subject: [PATCH] Adjust slow tokenizer for return_overflowing_tokens

---
 src/transformers/tokenization_utils_base.py | 135 ++++++++++++--------
 1 file changed, 82 insertions(+), 53 deletions(-)

diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 381f3ef497d9bd..663cce4f6d4bbb 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -23,7 +23,7 @@
 import os
 import re
 import warnings
-from collections import UserDict
+from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sized
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -3445,10 +3445,6 @@ def prepare_for_model(
             **kwargs,
         )
 
-        pair = bool(pair_ids is not None)
-        len_ids = len(ids)
-        len_pair_ids = len(pair_ids) if pair else 0
-
         if return_token_type_ids and not add_special_tokens:
             raise ValueError(
                 "Asking to return token_type_ids while setting add_special_tokens to False "
@@ -3473,64 +3469,97 @@
         if return_attention_mask is None:
             return_attention_mask = "attention_mask" in self.model_input_names
 
-        encoded_inputs = {}
-
-        # Compute the total size of the returned encodings
-        total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
+        def _calc(ids, pair_ids):
+            pair = bool(pair_ids is not None)
+            len_ids = len(ids)
+            len_pair_ids = len(pair_ids) if pair else 0
 
-        # Truncation: Handle max sequence length
-        overflowing_tokens = []
-        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
-            ids, pair_ids, overflowing_tokens = self.truncate_sequences(
-                ids,
-                pair_ids=pair_ids,
-                num_tokens_to_remove=total_len - max_length,
-                truncation_strategy=truncation_strategy,
-                stride=stride,
+            # Compute the total size of the returned encodings
+            total_len = (
+                len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
             )
 
-        if return_overflowing_tokens:
-            encoded_inputs["overflowing_tokens"] = overflowing_tokens
-            encoded_inputs["num_truncated_tokens"] = total_len - max_length
+            # Truncation: Handle max sequence length
+            overflowing_tokens = []
+            num_tokens_to_remove = 0
+            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+                num_tokens_to_remove = total_len - max_length
+                ids, pair_ids, overflowing_tokens = self.truncate_sequences(
+                    ids,
+                    pair_ids=pair_ids,
+                    num_tokens_to_remove=num_tokens_to_remove,
+                    truncation_strategy=truncation_strategy,
+                    stride=stride,
+                )
 
-        # Add special tokens
-        if add_special_tokens:
-            sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
-            token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
-        else:
-            sequence = ids + pair_ids if pair else ids
-            token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
-
-        # Build output dictionary
-        encoded_inputs["input_ids"] = sequence
-        if return_token_type_ids:
-            encoded_inputs["token_type_ids"] = token_type_ids
-        if return_special_tokens_mask:
+            # Add special tokens
             if add_special_tokens:
-                encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+                sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
+                token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
             else:
-                encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+                sequence = ids + pair_ids if pair else ids
+                token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
+
+            # Build output dictionary
+            encoded_inputs = {}
+            encoded_inputs["input_ids"] = sequence
+            if return_token_type_ids:
+                encoded_inputs["token_type_ids"] = token_type_ids
+            if return_special_tokens_mask:
+                if add_special_tokens:
+                    encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+                else:
+                    encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+
+            # Check lengths
+            self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
+
+            # Padding
+            if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+                encoded_inputs = self.pad(
+                    encoded_inputs,
+                    max_length=max_length,
+                    padding=padding_strategy.value,
+                    pad_to_multiple_of=pad_to_multiple_of,
+                    padding_side=padding_side,
+                    return_attention_mask=return_attention_mask,
+                )
 
-        # Check lengths
-        self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
+            if return_length:
+                encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+            return encoded_inputs, num_tokens_to_remove
 
-        # Padding
-        if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
-            encoded_inputs = self.pad(
-                encoded_inputs,
-                max_length=max_length,
-                padding=padding_strategy.value,
-                pad_to_multiple_of=pad_to_multiple_of,
-                padding_side=padding_side,
-                return_attention_mask=return_attention_mask,
+        def _add_to_encoding_dict(encoding_dict, e):
+            encoding_dict["input_ids"].append(e["input_ids"])
+            if return_token_type_ids:
+                encoding_dict["token_type_ids"].append(e["token_type_ids"])
+            if return_attention_mask:
+                encoding_dict["attention_mask"].append(e["attention_mask"])
+            if return_special_tokens_mask:
+                encoding_dict["special_tokens_mask"].append(e["special_tokens_mask"])
+            if return_offsets_mapping:
+                encoding_dict["offset_mapping"].append(e["offsets"])
+            if return_length:
+                encoding_dict["length"].append(len(e["input_ids"]))
+
+        encoded_inputs, num_tokens_to_remove = _calc(ids, pair_ids)
+
+        if return_overflowing_tokens and num_tokens_to_remove:
+            encoding_dict = defaultdict(list)
+            _add_to_encoding_dict(encoding_dict, encoded_inputs)
+            while num_tokens_to_remove > 0:
+                encoded_inputs, num_tokens_to_remove = _calc(
+                    ids[-num_tokens_to_remove:], pair_ids[-num_tokens_to_remove:] if pair_ids is not None else None
+                )
+                _add_to_encoding_dict(encoding_dict, encoded_inputs)
+            batch_outputs = BatchEncoding(
+                encoding_dict, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
             )
 
-        if return_length:
-            encoded_inputs["length"] = len(encoded_inputs["input_ids"])
-
-        batch_outputs = BatchEncoding(
-            encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
-        )
+        else:
+            batch_outputs = BatchEncoding(
+                encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+            )
 
         return batch_outputs
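
For context, a minimal usage sketch of the code path this patch adjusts: calling a slow (Python) tokenizer with truncation and return_overflowing_tokens=True. This is an illustration only, not part of the patch; the checkpoint name, the input text, and the exact structure of the returned fields are assumptions.

from transformers import AutoTokenizer

# use_fast=False selects the slow tokenizer, whose prepare_for_model() is changed above
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)

text = "a fairly long input sentence " * 50  # long enough to overflow max_length

encoded = tokenizer(
    text,
    truncation=True,
    max_length=32,
    return_overflowing_tokens=True,
)

# With this change, the slow tokenizer is expected to return one input_ids entry per
# overflowing chunk (mirroring the fast tokenizer) instead of a flat
# "overflowing_tokens" list next to a single truncated sequence.
print(len(encoded["input_ids"]))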