17 changes: 9 additions & 8 deletions sentence_transformers/SentenceTransformer.py
@@ -48,6 +48,7 @@
     load_file_path,
     save_to_hub_args_decorator,
     truncate_embeddings,
+    truncate_masked_sequence,
 )
 
 logger = logging.getLogger(__name__)
@@ -560,15 +561,15 @@ def encode(
                 out_features["sentence_embedding"] = truncate_embeddings(
                     out_features["sentence_embedding"], self.truncate_dim
                 )
 
+                # Only truncate token embeddings when they were actually requested;
+                # otherwise this call would sit in the hot path of every encode() call.
+                if output_value == "token_embeddings" or output_value is None:
+                    if "token_embeddings" in out_features:
+                        out_features["token_embeddings"] = truncate_masked_sequence(
+                            out_features["token_embeddings"], out_features["attention_mask"]
+                        )
                 if output_value == "token_embeddings":
-                    embeddings = []
-                    for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]):
-                        last_mask_id = len(attention) - 1
-                        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
-                            last_mask_id -= 1
-
-                        embeddings.append(token_emb[0 : last_mask_id + 1])
+                    embeddings = out_features["token_embeddings"]
                 elif output_value is None:  # Return all outputs
                     embeddings = []
                     for sent_idx in range(len(out_features["sentence_embedding"])):
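
For orientation, here is a minimal sketch (not part of the diff) of the call path this hunk serves; the model name is an illustrative assumption, not one taken from this PR:

# Sketch only: assumes sentence-transformers is installed; the model name
# "all-MiniLM-L6-v2" is an illustrative choice, not taken from this PR.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
token_embeddings = model.encode(
    ["a short sentence", "a much longer sentence with several more tokens"],
    output_value="token_embeddings",
)
# One tensor per input sentence, with trailing padding rows already removed,
# so the tensors generally differ in length along the token dimension.
for emb in token_embeddings:
    print(emb.shape)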
35 changes: 35 additions & 0 deletions sentence_transformers/util.py
@@ -1509,3 +1509,38 @@ def disable_datasets_caching():
     finally:
         if is_originally_enabled:
             enable_caching()
+
+
+def truncate_masked_sequence(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> list[torch.Tensor]:
+    """
+    Remove padding tokens from a batch of token embeddings.
+
+    Args:
+        token_embeddings (torch.Tensor): The token embeddings.
+        attention_mask (torch.Tensor): The attention mask.
+
+    Returns:
+        list[torch.Tensor]: The truncated token embeddings. Each tensor in the list corresponds to one
+            sequence and contains all tokens up to, but not including, the first padding token.
+    """
+    out: list[torch.Tensor] = []
+    all_zero_mask = attention_mask == 0
+    for token_emb, zero_mask in zip(token_embeddings, all_zero_mask):
+        # Three cases:
+        # 1. No padding tokens. With dynamic padding this happens at least
+        #    once per batch (the longest sequence), unless padding is constant.
+        if not zero_mask.any():
+            out.append(token_emb)
+            continue
+        # 2. The first token is already a padding token. This should not
+        #    happen, but would yield an empty tensor below, so keep one token.
+        if zero_mask[0]:
+            last_mask_id = 1
+        else:
+            # 3. Padding starts somewhere after the first token: find the
+            #    first padding token and drop it and everything after it.
+            last_mask_id = zero_mask.float().argmax().item()
+
+        out.append(token_emb[:last_mask_id])
+
+    return out
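
A toy example (a sketch, not from the diff) exercising the three cases the comments above describe:

# Sketch only: toy tensors covering the three branches of truncate_masked_sequence.
import torch

from sentence_transformers.util import truncate_masked_sequence

token_embeddings = torch.arange(12, dtype=torch.float32).reshape(3, 4, 1)  # 3 sequences of 4 tokens
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],  # case 1: no padding, all 4 tokens kept
        [0, 0, 0, 0],  # case 2: mask starts with padding, exactly 1 token kept
        [1, 1, 0, 0],  # case 3: padding after real tokens, first 2 tokens kept
    ]
)
lengths = [t.shape[0] for t in truncate_masked_sequence(token_embeddings, attention_mask)]
assert lengths == [4, 1, 2]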
27 changes: 26 additions & 1 deletion tests/test_util.py
@@ -6,7 +6,7 @@
 import torch
 
 from sentence_transformers import SentenceTransformer, util
-from sentence_transformers.util import community_detection
+from sentence_transformers.util import community_detection, truncate_masked_sequence
 
 
 def test_normalize_embeddings() -> None:
@@ -275,3 +275,28 @@ def test_community_detection_gpu_support():
     ]
     result = community_detection(embeddings, threshold=0.8, min_community_size=2)
     assert sorted([sorted(community) for community in result]) == sorted([sorted(community) for community in expected])
+
+
+def test_truncate_masked_sequence() -> None:
+    """Tests whether util.truncate_masked_sequence works correctly."""
+    input_ids = torch.tensor([[101, 102, 103, 104, 105, 106], [101, 102, 103, 104, 105, 106]])
+    expected = [torch.tensor([101, 102, 103, 104]), torch.tensor([101, 102, 103, 104, 105, 106])]
+    attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]])
+    truncated_input_ids = truncate_masked_sequence(input_ids, attention_mask)
+    assert all(torch.equal(x, y) for x, y in zip(truncated_input_ids, expected))
+
+    old_output = _truncate_masked_sequence_old(input_ids, attention_mask)
+    assert all(torch.equal(x, y) for x, y in zip(truncated_input_ids, old_output))
+
+
+def _truncate_masked_sequence_old(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> list[torch.Tensor]:
+    """Reference implementation: the loop previously inlined in SentenceTransformer.encode()."""
+    out: list[torch.Tensor] = []
+    for token_emb, attention in zip(token_embeddings, attention_mask):
+        last_mask_id = len(attention) - 1
+        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
+            last_mask_id -= 1
+
+        out.append(token_emb[0 : last_mask_id + 1])
+
+    return out
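
The test above only covers cases 1 and 3; a possible extra test for case 2 (a mask that starts with a padding token) might look like this sketch, which is not part of the diff:

def test_truncate_masked_sequence_all_padding() -> None:
    """Sketch: a fully masked sequence should keep exactly one token (case 2 above)."""
    input_ids = torch.tensor([[101, 102, 103]])
    attention_mask = torch.tensor([[0, 0, 0]])
    truncated = truncate_masked_sequence(input_ids, attention_mask)
    assert len(truncated) == 1
    assert torch.equal(truncated[0], torch.tensor([101]))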