From da154eb6b49c8664a370fea9d80e954031261af4 Mon Sep 17 00:00:00 2001
From: stephantul
Date: Thu, 20 Mar 2025 14:44:17 +0100
Subject: [PATCH 1/6] fix inconsistency in token embeddings

---
 sentence_transformers/SentenceTransformer.py | 14 ++++------
 sentence_transformers/util.py                | 29 ++++++++++++++++++++
 2 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index b9f46da0a..3d788f018 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -46,6 +46,7 @@
     is_sentence_transformer_model,
     load_dir_path,
     load_file_path,
+    process_attention_mask,
     save_to_hub_args_decorator,
     truncate_embeddings,
 )
@@ -628,15 +629,12 @@ def encode(
                 out_features["sentence_embedding"] = truncate_embeddings(
                     out_features["sentence_embedding"], self.truncate_dim
                 )
-
+                if "token_embeddings" in out_features:
+                    out_features["token_embeddings"] = process_attention_mask(
+                        out_features["token_embeddings"], out_features["attention_mask"]
+                    )
                 if output_value == "token_embeddings":
-                    embeddings = []
-                    for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]):
-                        last_mask_id = len(attention) - 1
-                        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
-                            last_mask_id -= 1
-
-                        embeddings.append(token_emb[0 : last_mask_id + 1])
+                    embeddings = out_features["token_embeddings"]
                 elif output_value is None:  # Return all outputs
                     embeddings = []
                     for sent_idx in range(len(out_features["sentence_embedding"])):
diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py
index df1a80eb2..30909ffc9 100644
--- a/sentence_transformers/util.py
+++ b/sentence_transformers/util.py
@@ -1509,3 +1509,32 @@ def disable_datasets_caching():
     finally:
         if is_originally_enabled:
             enable_caching()
+
+
+def process_attention_mask(token_embeddings: list[torch.Tensor], attention_mask: torch.Tensor) -> list[torch.Tensor]:
+    """Process the attention mask to remove padding tokens from the token embeddings."""
+    out: list[torch.Tensor] = []
+    for token_emb, attention in zip(token_embeddings, attention_mask):
+        if attention[0] == 0:
+            last_mask_id = 1
+        else:
+            last_mask_id = (attention == 0).float().argmax().item()
+
+        out.append(token_emb[:last_mask_id])
+
+    return out
+
+
+def process_attention_mask_old(
+    token_embeddings: list[torch.Tensor], attention_mask: torch.Tensor
+) -> list[torch.Tensor]:
+    """Process the attention mask to remove padding tokens from the token embeddings."""
+    out: list[torch.Tensor] = []
+    for token_emb, attention in zip(token_embeddings, attention_mask):
+        last_mask_id = len(attention) - 1
+        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
+            last_mask_id -= 1
+
+        out.append(token_emb[0 : last_mask_id + 1])
+
+    return out

From 46a38a7286b7cf450bd7e15e38fd3769e48703fd Mon Sep 17 00:00:00 2001
From: stephantul
Date: Thu, 20 Mar 2025 14:44:48 +0100
Subject: [PATCH 2/6] remove old function

---
 sentence_transformers/util.py | 15 ---------------
 1 file changed, 15 deletions(-)

diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py
index 30909ffc9..f05b33946 100644
--- a/sentence_transformers/util.py
+++ b/sentence_transformers/util.py
@@ -1523,18 +1523,3 @@ def process_attention_mask(token_embeddings: list[torch.Tensor], attention_mask:
         out.append(token_emb[:last_mask_id])
 
     return out
-
-
-def process_attention_mask_old(
-    token_embeddings: list[torch.Tensor], attention_mask: torch.Tensor
-) -> list[torch.Tensor]:
-    """Process the attention mask to remove padding tokens from the token embeddings."""
-    out: list[torch.Tensor] = []
-    for token_emb, attention in zip(token_embeddings, attention_mask):
-        last_mask_id = len(attention) - 1
-        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
-            last_mask_id -= 1
-
-        out.append(token_emb[0 : last_mask_id + 1])
-
-    return out
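
The new process_attention_mask above replaces the old reverse while-loop with a single vectorized scan: (attention == 0).float().argmax() relies on argmax resolving ties to the first maximal element, so it yields the index of the first zero in the mask, i.e. where the trailing padding starts. A minimal standalone sketch of the trick with toy tensors that are not from the patch; note the no-padding edge case, which patch 5 below guards explicitly:

    import torch

    # Trailing padding: argmax over the inverted mask finds the first zero.
    attention = torch.tensor([1, 1, 1, 1, 0, 0])
    print(int((attention == 0).float().argmax()))  # 4 -> keep tokens [:4]

    # Edge case: with no padding at all, the inverted mask is all zeros,
    # argmax returns 0, and slicing with [:0] would drop the whole sequence.
    full = torch.tensor([1, 1, 1])
    print(int((full == 0).float().argmax()))  # 0, not 3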
From 0c61def59689710fd9115d2a0cb7e6415412e369 Mon Sep 17 00:00:00 2001
From: stephantul
Date: Thu, 20 Mar 2025 18:10:02 +0100
Subject: [PATCH 3/6] rename function

---
 sentence_transformers/SentenceTransformer.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 625059c2e..c0951ba9d 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -46,9 +46,9 @@
     is_sentence_transformer_model,
     load_dir_path,
     load_file_path,
-    process_attention_mask,
     save_to_hub_args_decorator,
     truncate_embeddings,
+    truncate_masked_sequence,
 )
 
 logger = logging.getLogger(__name__)
@@ -561,10 +561,13 @@ def encode(
                 out_features["sentence_embedding"] = truncate_embeddings(
                     out_features["sentence_embedding"], self.truncate_dim
                 )
-                if "token_embeddings" in out_features:
-                    out_features["token_embeddings"] = process_attention_mask(
-                        out_features["token_embeddings"], out_features["attention_mask"]
-                    )
+                # Do this check because otherwise this function is always in the
+                # hot path.
+                if output_value == "token_embeddings" or output_value is None:
+                    if "token_embeddings" in out_features:
+                        out_features["token_embeddings"] = truncate_masked_sequence(
+                            out_features["token_embeddings"], out_features["attention_mask"]
+                        )
                 if output_value == "token_embeddings":
                     embeddings = out_features["token_embeddings"]
                 elif output_value is None:  # Return all outputs

From 9f5f7beed08d0ea9ff2388b4054ac863afcee4e3 Mon Sep 17 00:00:00 2001
From: stephantul
Date: Thu, 20 Mar 2025 18:10:37 +0100
Subject: [PATCH 4/6] add extra test

---
 tests/test_util.py | 27 ++++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/tests/test_util.py b/tests/test_util.py
index 0cecce2bb..8586eaae5 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -6,7 +6,7 @@
 import torch
 
 from sentence_transformers import SentenceTransformer, util
-from sentence_transformers.util import community_detection
+from sentence_transformers.util import community_detection, truncate_masked_sequence
 
 
 def test_normalize_embeddings() -> None:
@@ -275,3 +275,28 @@ def test_community_detection_gpu_support():
     ]
     result = community_detection(embeddings, threshold=0.8, min_community_size=2)
     assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected])
+
+
+def test_truncate_masked_sequence() -> None:
+    """Tests whether util.truncate_masked_sequence works correctly."""
+    input_ids = torch.tensor([[101, 102, 103, 104, 105, 106], [101, 102, 103, 104, 105, 106]])
+    expected = [torch.tensor([101, 102, 103, 104]), torch.tensor([101, 102, 103, 104, 105, 106])]
+    attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]])
+    truncated_input_ids = truncate_masked_sequence(input_ids, attention_mask)
+    assert all(torch.equal(x, y) for x, y in zip(truncated_input_ids, expected))
+
+    old_output = _truncate_masked_sequence_old(input_ids, attention_mask)
+    assert all(torch.equal(x, y) for x, y in zip(truncated_input_ids, old_output))
+
+
+def _truncate_masked_sequence_old(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> list[torch.Tensor]:
+    """Reference implementation of the old truncation loop, kept for comparison."""
+    out: list[torch.Tensor] = []
+    for token_emb, attention in zip(token_embeddings, attention_mask):
+        last_mask_id = len(attention) - 1
+        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
+            last_mask_id -= 1
+
+        out.append(token_emb[0 : last_mask_id + 1])
+
+    return out
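
For context, the code path these tests pin down is the one users reach through SentenceTransformer.encode with output_value="token_embeddings". A rough usage sketch under the patched behavior; the checkpoint name is only illustrative, any SentenceTransformer model works the same way:

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    # One tensor per input text, with trailing padding stripped by the
    # truncation helper patched above; sequence lengths therefore differ.
    token_embeddings = model.encode(
        ["a short text", "a somewhat longer input text"],
        output_value="token_embeddings",
    )
    for emb in token_embeddings:
        print(emb.shape)  # (num_kept_tokens, hidden_dim)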
From 1a27449ae57149d7611abe0f3bb08be4919b80b9 Mon Sep 17 00:00:00 2001
From: stephantul
Date: Thu, 20 Mar 2025 19:49:13 +0100
Subject: [PATCH 5/6] add docstring

---
 sentence_transformers/util.py | 31 ++++++++++++++++++++++++++-----
 1 file changed, 26 insertions(+), 5 deletions(-)

diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py
index f05b33946..6c6159a6d 100644
--- a/sentence_transformers/util.py
+++ b/sentence_transformers/util.py
@@ -1511,14 +1511,35 @@ def disable_datasets_caching():
             enable_caching()
 
 
-def process_attention_mask(token_embeddings: list[torch.Tensor], attention_mask: torch.Tensor) -> list[torch.Tensor]:
-    """Process the attention mask to remove padding tokens from the token embeddings."""
+def truncate_masked_sequence(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> list[torch.Tensor]:
+    """
+    Process a tensor to remove padding tokens.
+
+    Args:
+        token_embeddings (torch.Tensor): The token embeddings.
+        attention_mask (torch.Tensor): The attention mask.
+
+    Returns:
+        list[torch.Tensor]: The processed token embeddings. Each tensor in the list corresponds to a sequence.
+            Each tensor contains all tokens up until the first padding token in the trailing sequence of padding tokens.
+    """
     out: list[torch.Tensor] = []
-    for token_emb, attention in zip(token_embeddings, attention_mask):
-        if attention[0] == 0:
+    all_zero_mask = attention_mask == 0
+    for token_emb, zero_mask in zip(token_embeddings, all_zero_mask):
+        # Three cases:
+        # 1. No padding tokens. This happens at least once per batch
+        #    unless padding is constant.
+        if not any(zero_mask):
+            out.append(token_emb)
+            continue
+        # 2. The first token is already a padding token.
+        #    This should not happen, but leads to weird cases if we don't check.
+        if zero_mask[0]:
             last_mask_id = 1
         else:
-            last_mask_id = (attention == 0).float().argmax().item()
+            # 3. The padding tokens are in the middle of the sequence.
+            #    We find the last padding token and remove it and everything after it.
+            last_mask_id = zero_mask.float().argmax().item()
 
         out.append(token_emb[:last_mask_id])
 
     return out

From c23068c6083deb179c06938de697c3c4ed395d0a Mon Sep 17 00:00:00 2001
From: stephantul
Date: Thu, 20 Mar 2025 20:07:53 +0100
Subject: [PATCH 6/6] fix comment

---
 sentence_transformers/util.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py
index 6c6159a6d..71deb1a83 100644
--- a/sentence_transformers/util.py
+++ b/sentence_transformers/util.py
@@ -1538,7 +1538,7 @@ def truncate_masked_sequence(token_embeddings: torch.Tensor, attention_mask: tor
             last_mask_id = 1
         else:
             # 3. The padding tokens are in the middle of the sequence.
-            #    We find the last padding token and remove it and everything after it.
+            #    We find the first padding token and remove it and everything after it.
             last_mask_id = zero_mask.float().argmax().item()
 
         out.append(token_emb[:last_mask_id])
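
With the full series applied, the three commented cases in truncate_masked_sequence can be exercised directly. A small behavioral sketch with toy tensors, assuming the patched sentence_transformers.util is importable:

    import torch
    from sentence_transformers.util import truncate_masked_sequence

    token_embeddings = torch.arange(18, dtype=torch.float32).reshape(3, 6, 1)
    attention_mask = torch.tensor([
        [1, 1, 1, 1, 1, 1],  # case 1: no padding -> sequence kept whole
        [0, 1, 1, 1, 0, 0],  # case 2: first token masked -> truncated to length 1
        [1, 1, 1, 0, 0, 0],  # case 3: trailing padding -> cut at the first zero
    ])
    for seq in truncate_masked_sequence(token_embeddings, attention_mask):
        print(seq.shape[0])  # 6, then 1, then 3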