diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index a08e3fad1..c0951ba9d 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -48,6 +48,7 @@
     load_file_path,
     save_to_hub_args_decorator,
     truncate_embeddings,
+    truncate_masked_sequence,
 )
 
 logger = logging.getLogger(__name__)
@@ -560,15 +561,15 @@ def encode(
                 out_features["sentence_embedding"] = truncate_embeddings(
                     out_features["sentence_embedding"], self.truncate_dim
                 )
-
+                # Only truncate when token embeddings are requested; otherwise
+                # this call would sit in the hot path of every encode.
+                if output_value == "token_embeddings" or output_value is None:
+                    if "token_embeddings" in out_features:
+                        out_features["token_embeddings"] = truncate_masked_sequence(
+                            out_features["token_embeddings"], out_features["attention_mask"]
+                        )
                 if output_value == "token_embeddings":
-                    embeddings = []
-                    for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]):
-                        last_mask_id = len(attention) - 1
-                        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
-                            last_mask_id -= 1
-
-                        embeddings.append(token_emb[0 : last_mask_id + 1])
+                    embeddings = out_features["token_embeddings"]
                 elif output_value is None:  # Return all outputs
                     embeddings = []
                     for sent_idx in range(len(out_features["sentence_embedding"])):
diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py
index df1a80eb2..71deb1a83 100644
--- a/sentence_transformers/util.py
+++ b/sentence_transformers/util.py
@@ -1509,3 +1509,38 @@ def disable_datasets_caching():
     finally:
         if is_originally_enabled:
             enable_caching()
+
+
+def truncate_masked_sequence(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> list[torch.Tensor]:
+    """
+    Strip trailing padding tokens from a batch of token embeddings.
+
+    Args:
+        token_embeddings (torch.Tensor): The token embeddings.
+        attention_mask (torch.Tensor): The attention mask.
+
+    Returns:
+        list[torch.Tensor]: The truncated token embeddings, one tensor per sequence.
+            Each tensor contains every token up to, but not including, the first padding token.
+    """
+    out: list[torch.Tensor] = []
+    all_zero_mask = attention_mask == 0
+    for token_emb, zero_mask in zip(token_embeddings, all_zero_mask):
+        # Three cases:
+        # 1. No padding tokens. This happens for at least the longest sequence
+        #    in a batch, unless sequences are padded to a constant length.
+        if not zero_mask.any():
+            out.append(token_emb)
+            continue
+        # 2. The first token is already a padding token. This should not happen,
+        #    but we keep one token to mirror the old behavior for fully masked sequences.
+        if zero_mask[0]:
+            last_mask_id = 1
+        else:
+            # 3. Padding starts after the first token (the usual trailing-padding case).
+            #    Find the first padding token and drop it and everything after it.
+            last_mask_id = zero_mask.float().argmax().item()
+
+        out.append(token_emb[:last_mask_id])
+
+    return out
diff --git a/tests/test_util.py b/tests/test_util.py
index 0cecce2bb..8586eaae5 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -6,7 +6,7 @@
 import torch
 
 from sentence_transformers import SentenceTransformer, util
-from sentence_transformers.util import community_detection
+from sentence_transformers.util import community_detection, truncate_masked_sequence
 
 
 def test_normalize_embeddings() -> None:
@@ -275,3 +275,28 @@ def test_community_detection_gpu_support():
     ]
     result = community_detection(embeddings, threshold=0.8, min_community_size=2)
     assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected])
+
+
+def test_truncate_masked_sequence() -> None:
+    """Tests whether util.truncate_masked_sequence works correctly."""
+    input_ids = torch.tensor([[101, 102, 103, 104, 105, 106], [101, 102, 103, 104, 105, 106]])
+    expected = [torch.tensor([101, 102, 103, 104]), torch.tensor([101, 102, 103, 104, 105, 106])]
+    attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]])
+    truncated_input_ids = truncate_masked_sequence(input_ids, attention_mask)
+    assert all(torch.equal(x, y) for x, y in zip(truncated_input_ids, expected))
+
+    old_output = _truncate_masked_sequence_old(input_ids, attention_mask)
+    assert all(torch.equal(x, y) for x, y in zip(truncated_input_ids, old_output))
+
+
+def _truncate_masked_sequence_old(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> list[torch.Tensor]:
+    """The original loop-based truncation from encode, kept as a reference implementation."""
+    out: list[torch.Tensor] = []
+    for token_emb, attention in zip(token_embeddings, attention_mask):
+        last_mask_id = len(attention) - 1
+        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
+            last_mask_id -= 1
+
+        out.append(token_emb[0 : last_mask_id + 1])
+
+    return out
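For reference, a minimal usage sketch of the new helper, assuming right padding as in the tests above; the tensors below are illustrative stand-ins for real model outputs, not part of the diff:

    import torch

    from sentence_transformers.util import truncate_masked_sequence

    # Two sequences of length 6; the first carries two trailing padding tokens.
    token_embeddings = torch.randn(2, 6, 8)  # (batch, seq_len, hidden_dim)
    attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]])

    # Each sequence is truncated at its first padding token.
    truncated = truncate_masked_sequence(token_embeddings, attention_mask)
    print([t.shape for t in truncated])  # [torch.Size([4, 8]), torch.Size([6, 8])]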