Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 82 additions & 18 deletions sentence_transformers/util/hard_negatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def mine_hard_negatives(
corpus_prompt: str | None = None,
include_positives: bool = False,
output_format: Literal["triplet", "n-tuple", "n-tuple-scores", "labeled-pair", "labeled-list"] = "triplet",
include_scores: bool = False,
batch_size: int = 32,
faiss_batch_size: int = 16384,
use_faiss: bool = False,
Expand Down Expand Up @@ -157,6 +158,18 @@ def mine_hard_negatives(
'answer': "Richmond Football Club Richmond began 2017 with 5 straight wins, a feat it had not achieved since 1995. A series of close losses hampered the Tigers throughout the middle of the season, including a 5-point loss to the Western Bulldogs, 2-point loss to Fremantle, and a 3-point loss to the Giants. Richmond ended the season strongly with convincing victories over Fremantle and St Kilda in the final two rounds, elevating the club to 3rd on the ladder. Richmond's first final of the season against the Cats at the MCG attracted a record qualifying final crowd of 95,028; the Tigers won by 51 points. Having advanced to the first preliminary finals for the first time since 2001, Richmond defeated Greater Western Sydney by 36 points in front of a crowd of 94,258 to progress to the Grand Final against Adelaide, their first Grand Final appearance since 1982. The attendance was 100,021, the largest crowd to a grand final since 1986. The Crows led at quarter time and led by as many as 13, but the Tigers took over the game as it progressed and scored seven straight goals at one point. They eventually would win by 48 points – 16.12 (108) to Adelaide's 8.12 (60) – to end their 37-year flag drought.[22] Dustin Martin also became the first player to win a Premiership medal, the Brownlow Medal and the Norm Smith Medal in the same season, while Damien Hardwick was named AFL Coaches Association Coach of the Year. Richmond's jump from 13th to premiers also marked the biggest jump from one AFL season to the next.",
'negative': "2018 NRL Grand Final The 2018 NRL Grand Final was the conclusive and premiership-deciding game of the 2018 National Rugby League season and was played on Sunday September 30 at Sydney's ANZ Stadium.[1] The match was contested between minor premiers the Sydney Roosters and defending premiers the Melbourne Storm. In front of a crowd of 82,688, Sydney won the match 21–6 to claim their 14th premiership title and their first since 2013. Roosters five-eighth Luke Keary was awarded the Clive Churchill Medal as the game's official man of the match."
}
>>> # To include similarity scores, use include_scores=True
>>> dataset_with_scores = mine_hard_negatives(
... dataset=dataset,
... model=model,
... include_scores=True,
... # ... other parameters
... )
>>> dataset_with_scores
Dataset({
features: ['query', 'answer', 'negative', 'positive_score', 'negative_score'],
num_rows: 487865
})
>>> dataset.push_to_hub("natural-questions-hard-negatives", "triplet-all")

Args:
Expand Down Expand Up @@ -204,13 +217,20 @@ def mine_hard_negatives(
Defaults to False.
output_format (Literal["triplet", "n-tuple", "n-tuple-scores", "labeled-pair", "labeled-list"]): Output format for the `datasets.Dataset`. Options are:

- "triplet": (anchor, positive, negative) triplets, i.e. 3 columns. Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.CachedMultipleNegativesRankingLoss`.
- "triplet": (anchor, positive, negative) triplets, i.e. 3 columns. If `include_scores=True`, adds `positive_score` and `negative_score` columns (5 columns total). Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.CachedMultipleNegativesRankingLoss`.
- "n-tuple": (anchor, positive, negative_1, ..., negative_n) tuples, i.e. 2 + num_negatives columns. If `include_scores=True`, adds a `score` column (3 + num_negatives columns total). Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.CachedMultipleNegativesRankingLoss`.
- "n-tuple-scores": (anchor, positive, negative_1, ..., negative_n, score) tuples, i.e. 3 + num_negatives columns, where the final `score` column is a list of similarities for the query-positive pair followed by each of the query-negative pairs. Useful for e.g. :class:`~sentence_transformers.sparse_encoder.losses.SparseMarginMSELoss`.
- "labeled-pair": (anchor, passage, label) text tuples with a label of 0 for negative and 1 for positive, i.e. 3 columns. Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss`.
- "labeled-list": (anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) triplets with labels of 0 for negative and 1 for positive, i.e. 3 columns. Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss`.
- "labeled-pair": (anchor, passage, label) text tuples with a label of 0 for negative and 1 for positive, i.e. 3 columns. If `include_scores=True`, adds a `score` column (4 columns total). Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss`.
- "labeled-list": (anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) tuples with labels of 0 for negative and 1 for positive, i.e. 3 columns. If `include_scores=True`, adds a `scores` column with corresponding similarity scores (4 columns total). Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss`.

Defaults to "triplet".
include_scores (bool): Whether to include similarity scores in the output dataset. When True, adds score fields to the output:
- For "triplet" format: adds `positive_score` and `negative_score` columns
- For "labeled-pair" format: adds `score` column
- For "labeled-list" format: adds `scores` column
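- For "n-tuple-scores" format (deprecated): scores are always included regardless of this parameter
- For "n-tuple" format: adds a `score` column holding a list with the query-positive similarity followed by each query-negative similarity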
Defaults to False.
batch_size (int): Batch size for encoding the dataset. Defaults to 32.
faiss_batch_size (int): Batch size for FAISS top-k search. Defaults to 16384.
use_faiss (bool): Whether to use FAISS for similarity search. May be recommended for large datasets. Defaults to False.
Expand All @@ -227,8 +247,14 @@ def mine_hard_negatives(


Returns:
Dataset: A dataset containing (anchor, positive, negative) triplets, (anchor, passage, label) text tuples with
a label, or (anchor, positive, negative_1, ..., negative_n) tuples.
Dataset: A dataset containing the specified output format. When `include_scores=True`, score fields are added:
- "triplet": (anchor, positive, negative, positive_score, negative_score) tuples
- "labeled-pair": (anchor, passage, label, score) tuples
- "labeled-list": (anchor, [passages], [labels], [scores]) tuples
- "n-tuple": (anchor, positive, negative_1, ..., negative_n, score) tuples
- "n-tuple-scores": (anchor, positive, negative_1, ..., negative_n, score) tuples (always includes scores)

When `include_scores=False` (default), only the basic format is returned without score fields.
"""
if not is_datasets_available():
raise ImportError("Please install `datasets` to use this function: `pip install datasets`.")
Expand Down Expand Up @@ -468,7 +494,7 @@ def mine_hard_negatives(

n_positives = [len(p) for p in positive_indices]

# re-sort the positives and all_queries according to the deduplicated queries
# Reorder the positives and all_queries according to the deduplicated queries
positives = []
all_queries = []
for idx in range(n_queries):
Expand Down Expand Up @@ -523,10 +549,12 @@ def mine_hard_negatives(
# If there are multiple positives, we need to define which one to use for the margin
# To be on the safe side, we will use the _minimum_ positive score (i.e., harder positive) for the margin
max_positive_scores = torch.empty(n_queries, device=positive_scores.device, dtype=positive_scores.dtype)
start_idx = 0
positive_score_idx = 0
for q_idx in range(n_queries):
max_positive_scores[q_idx] = torch.min(positive_scores[start_idx : start_idx + n_positives[q_idx]])
start_idx += n_positives[q_idx - 1]
max_positive_scores[q_idx] = torch.min(
positive_scores[positive_score_idx : positive_score_idx + n_positives[q_idx]]
)
positive_score_idx += n_positives[q_idx - 1]

if absolute_margin is not None:
removed_indices = scores + absolute_margin > max_positive_scores.repeat(scores.size(1), 1).T
Expand Down Expand Up @@ -612,15 +640,15 @@ def mine_hard_negatives(
negative_scores = negative_scores[indices_to_keep]

# the anchor_indices matrix is shaped [n_total_queries, n_negatives]
start_idx = 0
positive_score_idx = 0
for q_idx in range(n_queries):
anchor_indices[start_idx : start_idx + n_positives[q_idx]] = torch.tensor(q_idx).repeat(
anchor_indices[positive_score_idx : positive_score_idx + n_positives[q_idx]] = torch.tensor(q_idx).repeat(
n_positives[q_idx], num_negatives
)
pos_indices[start_idx : start_idx + n_positives[q_idx]] = (
pos_indices[positive_score_idx : positive_score_idx + n_positives[q_idx]] = (
positive_indices[q_idx].repeat(num_negatives, 1).T
)
start_idx += n_positives[q_idx]
positive_score_idx += n_positives[q_idx]

anchor_indices = anchor_indices[indices_to_keep]
positive_indices = pos_indices[indices_to_keep]
Expand All @@ -631,10 +659,20 @@ def mine_hard_negatives(
"negative": [],
}

for anchor_idx, positive_idx, negative_idx in zip(anchor_indices, positive_indices, indices):
if include_scores:
dataset_data["positive_score"] = []
dataset_data["negative_score"] = []

positive_scores_expanded = positive_scores.repeat(num_negatives, 1).T[indices_to_keep]
for anchor_idx, positive_corpus_idx, negative_idx, pos_score, neg_score in zip(
anchor_indices, positive_indices, indices, positive_scores_expanded, negative_scores
):
dataset_data[anchor_column_name].append(queries[anchor_idx])
dataset_data[positive_column_name].append(corpus[positive_idx])
dataset_data[positive_column_name].append(corpus[positive_corpus_idx])
dataset_data["negative"].append(corpus[negative_idx])
if include_scores:
dataset_data["positive_score"].append(pos_score.item())
dataset_data["negative_score"].append(neg_score.item())
difference_scores = positive_scores.repeat(num_negatives, 1).T[indices_to_keep] - negative_scores
maximum_possible_samples = indices_to_keep.numel()

Expand All @@ -647,23 +685,41 @@ def mine_hard_negatives(
"label": [],
}

if include_scores:
dataset_data["score"] = []

positive_score_idx = 0
for query_idx in range(n_queries):
for positive_idx in positive_indices[query_idx]:
for positive_corpus_idx in positive_indices[query_idx]:
dataset_data[anchor_column_name].append(queries[query_idx])
dataset_data[positive_column_name].append(corpus[positive_idx])
dataset_data[positive_column_name].append(corpus[positive_corpus_idx])
dataset_data["label"].append(1)
if include_scores:
dataset_data["score"].append(positive_scores[positive_score_idx].item())
positive_score_idx += 1

for negative_idx, negative_score in zip(indices[query_idx], negative_scores[query_idx]):
if negative_score == -float("inf"):
continue
dataset_data[anchor_column_name].append(queries[query_idx])
dataset_data[positive_column_name].append(corpus[negative_idx])
dataset_data["label"].append(0)
if include_scores:
dataset_data["score"].append(negative_score.item())

negative_scores = negative_scores[indices_to_keep]
difference_scores = positive_scores.repeat(num_negatives, 1).T[indices_to_keep] - negative_scores
maximum_possible_samples = n_queries * num_negatives + len(dataset)

elif output_format in ("n-tuple", "n-tuple-scores"):
if output_format == "n-tuple-scores":
logger.warning(
"\"n-tuple-scores\" value of `output_format` is deprecated. '" \
"'Use \"n-tuple\" with `include_scores=True instead"
)
output_format = "n-tuple"
include_scores = True

# Keep only rows where all num_negatives negatives were found
indices_to_keep = (negative_scores != -float("inf")).all(dim=1)
negative_scores = negative_scores[indices_to_keep]
Expand All @@ -677,7 +733,7 @@ def mine_hard_negatives(
for i, neg_indices in enumerate(indices.T, start=1)
},
}
if output_format == "n-tuple-scores":
if include_scores:
dataset_data["score"] = torch.cat(
[positive_scores[indices_to_keep].unsqueeze(-1), negative_scores], dim=1
).tolist()
Expand All @@ -697,6 +753,14 @@ def mine_hard_negatives(
],
"labels": [[1] + [0] * sum(keep_row) for keep_row in indices_to_keep if keep_row.any()],
}

if include_scores:
dataset_data["scores"] = [
[positive_scores[idx].item()] + negative_scores[idx][keep_row].tolist()
for idx, keep_row in enumerate(indices_to_keep)
if keep_row.any()
]

negative_scores = negative_scores[indices_to_keep]
difference_scores = positive_scores.repeat(num_negatives, 1).T[indices_to_keep] - negative_scores
maximum_possible_samples = indices_to_keep.size(0)
Expand Down
72 changes: 68 additions & 4 deletions tests/util/test_hard_negatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,21 +397,51 @@ def test_include_positives_with_labeled_formats(
def test_output_formats(dataset: Dataset, static_retrieval_mrl_en_v1_model: SentenceTransformer) -> None:
"""Test all output_format options."""
model = static_retrieval_mrl_en_v1_model
# Test triplet format

# Test triplet format without scores (default)
result_triplet = mine_hard_negatives(dataset=dataset, model=model, output_format="triplet", verbose=False)
assert "query" in result_triplet.column_names
assert "passage" in result_triplet.column_names
assert "negative" in result_triplet.column_names
assert "positive_score" not in result_triplet.column_names
assert "negative_score" not in result_triplet.column_names
assert len(result_triplet.column_names) == 3

# Test n-tuple format
# Test triplet format with scores
result_triplet_scores = mine_hard_negatives(
dataset=dataset, model=model, output_format="triplet", include_scores=True, verbose=False
)
assert "query" in result_triplet_scores.column_names
assert "passage" in result_triplet_scores.column_names
assert "negative" in result_triplet_scores.column_names
assert "positive_score" in result_triplet_scores.column_names
assert "negative_score" in result_triplet_scores.column_names
assert len(result_triplet_scores.column_names) == 5

# Verify scores are numeric values
if len(result_triplet_scores) > 0:
assert all(isinstance(score, (int, float)) for score in result_triplet_scores["positive_score"])
assert all(isinstance(score, (int, float)) for score in result_triplet_scores["negative_score"])

# Test n-tuple format without scores (default)
result_ntuple = mine_hard_negatives(
dataset=dataset, model=model, num_negatives=2, output_format="n-tuple", verbose=False
)
assert "query" in result_ntuple.column_names
assert "passage" in result_ntuple.column_names
assert "negative_1" in result_ntuple.column_names
assert "negative_2" in result_ntuple.column_names
assert "score" not in result_ntuple.column_names

# Test n-tuple format with scores
result_ntuple_scores = mine_hard_negatives(
dataset=dataset, model=model, num_negatives=2, output_format="n-tuple", include_scores=True, verbose=False
)
assert "query" in result_ntuple_scores.column_names
assert "passage" in result_ntuple_scores.column_names
assert "negative_1" in result_ntuple_scores.column_names
assert "negative_2" in result_ntuple_scores.column_names
assert "score" in result_ntuple_scores.column_names

# Test n-tuple-scores format
result_scores = mine_hard_negatives(
Expand All @@ -426,21 +456,49 @@ def test_output_formats(dataset: Dataset, static_retrieval_mrl_en_v1_model: Sent
# Verify scores are lists of expected length (1 positive + num_negatives)
assert all(len(score) == 3 for score in result_scores["score"])

# Test labeled-pair format
# Test labeled-pair format without scores (default)
result_pair = mine_hard_negatives(dataset=dataset, model=model, output_format="labeled-pair", verbose=False)
assert "query" in result_pair.column_names
assert "passage" in result_pair.column_names
assert "label" in result_pair.column_names
assert "score" not in result_pair.column_names
assert len(result_pair.column_names) == 3

# Test labeled-pair format with scores
result_pair_scores = mine_hard_negatives(
dataset=dataset, model=model, output_format="labeled-pair", include_scores=True, verbose=False
)
assert "query" in result_pair_scores.column_names
assert "passage" in result_pair_scores.column_names
assert "label" in result_pair_scores.column_names
assert "score" in result_pair_scores.column_names
assert len(result_pair_scores.column_names) == 4

# Verify labels are 0 or 1
labels = set(result_pair["label"])
assert labels == {0, 1}

# Test labeled-list format
# Verify scores are numeric values
if len(result_pair_scores) > 0:
assert all(isinstance(score, (int, float)) for score in result_pair_scores["score"])

# Test labeled-list format without scores (default)
result_list = mine_hard_negatives(dataset=dataset, model=model, output_format="labeled-list", verbose=False)
assert "query" in result_list.column_names
assert "passage" in result_list.column_names
assert "labels" in result_list.column_names
assert "scores" not in result_list.column_names
assert len(result_list.column_names) == 3

# Test labeled-list format with scores
result_list_scores = mine_hard_negatives(
dataset=dataset, model=model, output_format="labeled-list", include_scores=True, verbose=False
)
assert "query" in result_list_scores.column_names
assert "passage" in result_list_scores.column_names
assert "labels" in result_list_scores.column_names
assert "scores" in result_list_scores.column_names
assert len(result_list_scores.column_names) == 4

# Verify each item in 'passage' is a list
assert all(isinstance(p, list) for p in result_list["passage"])
Expand All @@ -451,6 +509,12 @@ def test_output_formats(dataset: Dataset, static_retrieval_mrl_en_v1_model: Sent
assert label_list[0] == 1
assert all(label == 0 for label in label_list[1:])

# Verify each item in 'scores' is a list with numeric values
if len(result_list_scores) > 0:
for score_list in result_list_scores["scores"]:
assert isinstance(score_list, list)
assert all(isinstance(score, (int, float)) for score in score_list)


def test_batch_size(dataset: Dataset, static_retrieval_mrl_en_v1_model: SentenceTransformer) -> None:
"""Test batch_size parameter."""
Expand Down