Skip to content

Commit e0ac98e

Browse files
author
Tsyren Balzhanov
committed
Mine hard negatives: optionally output scores
1 parent 85ec645 commit e0ac98e

File tree

2 files changed

+130
-21
lines changed

2 files changed

+130
-21
lines changed

sentence_transformers/util/hard_negatives.py

Lines changed: 74 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def mine_hard_negatives(
4343
corpus_prompt: str | None = None,
4444
include_positives: bool = False,
4545
output_format: Literal["triplet", "n-tuple", "n-tuple-scores", "labeled-pair", "labeled-list"] = "triplet",
46+
include_scores: bool = False,
4647
batch_size: int = 32,
4748
faiss_batch_size: int = 16384,
4849
use_faiss: bool = False,
@@ -157,6 +158,18 @@ def mine_hard_negatives(
157158
'answer': "Richmond Football Club Richmond began 2017 with 5 straight wins, a feat it had not achieved since 1995. A series of close losses hampered the Tigers throughout the middle of the season, including a 5-point loss to the Western Bulldogs, 2-point loss to Fremantle, and a 3-point loss to the Giants. Richmond ended the season strongly with convincing victories over Fremantle and St Kilda in the final two rounds, elevating the club to 3rd on the ladder. Richmond's first final of the season against the Cats at the MCG attracted a record qualifying final crowd of 95,028; the Tigers won by 51 points. Having advanced to the first preliminary finals for the first time since 2001, Richmond defeated Greater Western Sydney by 36 points in front of a crowd of 94,258 to progress to the Grand Final against Adelaide, their first Grand Final appearance since 1982. The attendance was 100,021, the largest crowd to a grand final since 1986. The Crows led at quarter time and led by as many as 13, but the Tigers took over the game as it progressed and scored seven straight goals at one point. They eventually would win by 48 points – 16.12 (108) to Adelaide's 8.12 (60) – to end their 37-year flag drought.[22] Dustin Martin also became the first player to win a Premiership medal, the Brownlow Medal and the Norm Smith Medal in the same season, while Damien Hardwick was named AFL Coaches Association Coach of the Year. Richmond's jump from 13th to premiers also marked the biggest jump from one AFL season to the next.",
158159
'negative': "2018 NRL Grand Final The 2018 NRL Grand Final was the conclusive and premiership-deciding game of the 2018 National Rugby League season and was played on Sunday September 30 at Sydney's ANZ Stadium.[1] The match was contested between minor premiers the Sydney Roosters and defending premiers the Melbourne Storm. In front of a crowd of 82,688, Sydney won the match 21–6 to claim their 14th premiership title and their first since 2013. Roosters five-eighth Luke Keary was awarded the Clive Churchill Medal as the game's official man of the match."
159160
}
161+
>>> # To include similarity scores, use include_scores=True
162+
>>> dataset_with_scores = mine_hard_negatives(
163+
... dataset=dataset,
164+
... model=model,
165+
... include_scores=True,
166+
... # ... other parameters
167+
... )
168+
>>> dataset_with_scores
169+
Dataset({
170+
features: ['query', 'answer', 'negative', 'positive_score', 'negative_score'],
171+
num_rows: 487865
172+
})
160173
>>> dataset.push_to_hub("natural-questions-hard-negatives", "triplet-all")
161174
162175
Args:
@@ -204,13 +217,20 @@ def mine_hard_negatives(
204217
Defaults to False.
205218
output_format (Literal["triplet", "n-tuple", "n-tuple-scores", "labeled-pair", "labeled-list"]): Output format for the `datasets.Dataset`. Options are:
206219
207-
- "triplet": (anchor, positive, negative) triplets, i.e. 3 columns. Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.CachedMultipleNegativesRankingLoss`.
220+
- "triplet": (anchor, positive, negative) triplets, i.e. 3 columns. If `include_scores=True`, adds `positive_score` and `negative_score` columns (5 columns total). Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.CachedMultipleNegativesRankingLoss`.
208221
- "n-tuple": (anchor, positive, negative_1, ..., negative_n) tuples, i.e. 2 + num_negatives columns. Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.CachedMultipleNegativesRankingLoss`.
209222
- "n-tuple-scores": (anchor, positive, negative_1, ..., negative_n, score) tuples, i.e. 2 + num_negatives columns, but with one score value that's a list of similarities for the query-positive and each of the query-negative pairs. Useful for e.g. :class:`~sentence_transformers.sparse_encoder.losses.SparseMarginMSELoss`.
210-
- "labeled-pair": (anchor, passage, label) text tuples with a label of 0 for negative and 1 for positive, i.e. 3 columns. Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss`.
211-
- "labeled-list": (anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) triplets with labels of 0 for negative and 1 for positive, i.e. 3 columns. Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss`.
223+
- "labeled-pair": (anchor, passage, label) text tuples with a label of 0 for negative and 1 for positive, i.e. 3 columns. If `include_scores=True`, adds a `score` column (4 columns total). Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss`.
224+
- "labeled-list": (anchor, [doc1, doc2, ..., docN], [label1, label2, ..., labelN]) tuples with labels of 0 for negative and 1 for positive, i.e. 3 columns. If `include_scores=True`, adds a `scores` column with corresponding similarity scores (4 columns total). Useful for e.g. :class:`~sentence_transformers.cross_encoder.losses.LambdaLoss`.
212225
213226
Defaults to "triplet".
227+
include_scores (bool): Whether to include similarity scores in the output dataset. When True, adds score fields to the output:
228+
- For "triplet" format: adds `positive_score` and `negative_score` columns
229+
- For "labeled-pair" format: adds `score` column
230+
- For "labeled-list" format: adds `scores` column
231+
- For "n-tuple-scores" format: scores are always included regardless of this parameter
232+
- For "n-tuple" format: no scores are added
233+
Defaults to False.
214234
batch_size (int): Batch size for encoding the dataset. Defaults to 32.
215235
faiss_batch_size (int): Batch size for FAISS top-k search. Defaults to 16384.
216236
use_faiss (bool): Whether to use FAISS for similarity search. May be recommended for large datasets. Defaults to False.
@@ -227,8 +247,14 @@ def mine_hard_negatives(
227247
228248
229249
Returns:
230-
Dataset: A dataset containing (anchor, positive, negative) triplets, (anchor, passage, label) text tuples with
231-
a label, or (anchor, positive, negative_1, ..., negative_n) tuples.
250+
Dataset: A dataset containing the specified output format. When `include_scores=True`, score fields are added:
251+
- "triplet": (anchor, positive, negative, positive_score, negative_score) tuples
252+
- "labeled-pair": (anchor, passage, label, score) tuples
253+
- "labeled-list": (anchor, [passages], [labels], [scores]) tuples
254+
- "n-tuple": (anchor, positive, negative_1, ..., negative_n) tuples (no scores)
255+
- "n-tuple-scores": (anchor, positive, negative_1, ..., negative_n, score) tuples (always includes scores)
256+
257+
When `include_scores=False` (default), only the basic format is returned without score fields.
232258
"""
233259
if not is_datasets_available():
234260
raise ImportError("Please install `datasets` to use this function: `pip install datasets`.")
@@ -468,7 +494,7 @@ def mine_hard_negatives(
468494

469495
n_positives = [len(p) for p in positive_indices]
470496

471-
# re-sort the positives and all_queries according to the deduplicated queries
497+
# Reorder the positives and all_queries according to the deduplicated queries
472498
positives = []
473499
all_queries = []
474500
for idx in range(n_queries):
@@ -523,10 +549,12 @@ def mine_hard_negatives(
523549
# If there are multiple positives, we need to define which one to use for the margin
524550
# To be on the safe side, we will use the _minimum_ positive score (i.e., harder positive) for the margin
525551
max_positive_scores = torch.empty(n_queries, device=positive_scores.device, dtype=positive_scores.dtype)
526-
start_idx = 0
552+
positive_score_idx = 0
527553
for q_idx in range(n_queries):
528-
max_positive_scores[q_idx] = torch.min(positive_scores[start_idx : start_idx + n_positives[q_idx]])
529-
start_idx += n_positives[q_idx - 1]
554+
max_positive_scores[q_idx] = torch.min(
555+
positive_scores[positive_score_idx : positive_score_idx + n_positives[q_idx]]
556+
)
557+
positive_score_idx += n_positives[q_idx - 1]
530558

531559
if absolute_margin is not None:
532560
removed_indices = scores + absolute_margin > max_positive_scores.repeat(scores.size(1), 1).T
@@ -612,15 +640,15 @@ def mine_hard_negatives(
612640
negative_scores = negative_scores[indices_to_keep]
613641

614642
# the anchor_indices matrix is shaped [n_total_queries, n_negatives]
615-
start_idx = 0
643+
positive_score_idx = 0
616644
for q_idx in range(n_queries):
617-
anchor_indices[start_idx : start_idx + n_positives[q_idx]] = torch.tensor(q_idx).repeat(
645+
anchor_indices[positive_score_idx : positive_score_idx + n_positives[q_idx]] = torch.tensor(q_idx).repeat(
618646
n_positives[q_idx], num_negatives
619647
)
620-
pos_indices[start_idx : start_idx + n_positives[q_idx]] = (
648+
pos_indices[positive_score_idx : positive_score_idx + n_positives[q_idx]] = (
621649
positive_indices[q_idx].repeat(num_negatives, 1).T
622650
)
623-
start_idx += n_positives[q_idx]
651+
positive_score_idx += n_positives[q_idx]
624652

625653
anchor_indices = anchor_indices[indices_to_keep]
626654
positive_indices = pos_indices[indices_to_keep]
@@ -631,10 +659,20 @@ def mine_hard_negatives(
631659
"negative": [],
632660
}
633661

634-
for anchor_idx, positive_idx, negative_idx in zip(anchor_indices, positive_indices, indices):
662+
if include_scores:
663+
dataset_data["positive_score"] = []
664+
dataset_data["negative_score"] = []
665+
666+
positive_scores_expanded = positive_scores.repeat(num_negatives, 1).T[indices_to_keep]
667+
for anchor_idx, positive_corpus_idx, negative_idx, pos_score, neg_score in zip(
668+
anchor_indices, positive_indices, indices, positive_scores_expanded, negative_scores
669+
):
635670
dataset_data[anchor_column_name].append(queries[anchor_idx])
636-
dataset_data[positive_column_name].append(corpus[positive_idx])
671+
dataset_data[positive_column_name].append(corpus[positive_corpus_idx])
637672
dataset_data["negative"].append(corpus[negative_idx])
673+
if include_scores:
674+
dataset_data["positive_score"].append(pos_score.item())
675+
dataset_data["negative_score"].append(neg_score.item())
638676
difference_scores = positive_scores.repeat(num_negatives, 1).T[indices_to_keep] - negative_scores
639677
maximum_possible_samples = indices_to_keep.numel()
640678

@@ -647,17 +685,27 @@ def mine_hard_negatives(
647685
"label": [],
648686
}
649687

688+
if include_scores:
689+
dataset_data["score"] = []
690+
691+
positive_score_idx = 0
650692
for query_idx in range(n_queries):
651-
for positive_idx in positive_indices[query_idx]:
693+
for positive_corpus_idx in positive_indices[query_idx]:
652694
dataset_data[anchor_column_name].append(queries[query_idx])
653-
dataset_data[positive_column_name].append(corpus[positive_idx])
695+
dataset_data[positive_column_name].append(corpus[positive_corpus_idx])
654696
dataset_data["label"].append(1)
697+
if include_scores:
698+
dataset_data["score"].append(positive_scores[positive_score_idx].item())
699+
positive_score_idx += 1
700+
655701
for negative_idx, negative_score in zip(indices[query_idx], negative_scores[query_idx]):
656702
if negative_score == -float("inf"):
657703
continue
658704
dataset_data[anchor_column_name].append(queries[query_idx])
659705
dataset_data[positive_column_name].append(corpus[negative_idx])
660706
dataset_data["label"].append(0)
707+
if include_scores:
708+
dataset_data["score"].append(negative_score.item())
661709

662710
negative_scores = negative_scores[indices_to_keep]
663711
difference_scores = positive_scores.repeat(num_negatives, 1).T[indices_to_keep] - negative_scores
@@ -677,7 +725,7 @@ def mine_hard_negatives(
677725
for i, neg_indices in enumerate(indices.T, start=1)
678726
},
679727
}
680-
if output_format == "n-tuple-scores":
728+
if include_scores or output_format == "n-tuple-scores":
681729
dataset_data["score"] = torch.cat(
682730
[positive_scores[indices_to_keep].unsqueeze(-1), negative_scores], dim=1
683731
).tolist()
@@ -697,6 +745,14 @@ def mine_hard_negatives(
697745
],
698746
"labels": [[1] + [0] * sum(keep_row) for keep_row in indices_to_keep if keep_row.any()],
699747
}
748+
749+
if include_scores:
750+
dataset_data["scores"] = [
751+
[positive_scores[idx].item()] + negative_scores[idx][keep_row].tolist()
752+
for idx, keep_row in enumerate(indices_to_keep)
753+
if keep_row.any()
754+
]
755+
700756
negative_scores = negative_scores[indices_to_keep]
701757
difference_scores = positive_scores.repeat(num_negatives, 1).T[indices_to_keep] - negative_scores
702758
maximum_possible_samples = indices_to_keep.size(0)

tests/util/test_hard_negatives.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,13 +397,32 @@ def test_include_positives_with_labeled_formats(
397397
def test_output_formats(dataset: Dataset, static_retrieval_mrl_en_v1_model: SentenceTransformer) -> None:
398398
"""Test all output_format options."""
399399
model = static_retrieval_mrl_en_v1_model
400-
# Test triplet format
400+
401+
# Test triplet format without scores (default)
401402
result_triplet = mine_hard_negatives(dataset=dataset, model=model, output_format="triplet", verbose=False)
402403
assert "query" in result_triplet.column_names
403404
assert "passage" in result_triplet.column_names
404405
assert "negative" in result_triplet.column_names
406+
assert "positive_score" not in result_triplet.column_names
407+
assert "negative_score" not in result_triplet.column_names
405408
assert len(result_triplet.column_names) == 3
406409

410+
# Test triplet format with scores
411+
result_triplet_scores = mine_hard_negatives(
412+
dataset=dataset, model=model, output_format="triplet", include_scores=True, verbose=False
413+
)
414+
assert "query" in result_triplet_scores.column_names
415+
assert "passage" in result_triplet_scores.column_names
416+
assert "negative" in result_triplet_scores.column_names
417+
assert "positive_score" in result_triplet_scores.column_names
418+
assert "negative_score" in result_triplet_scores.column_names
419+
assert len(result_triplet_scores.column_names) == 5
420+
421+
# Verify scores are numeric values
422+
if len(result_triplet_scores) > 0:
423+
assert all(isinstance(score, (int, float)) for score in result_triplet_scores["positive_score"])
424+
assert all(isinstance(score, (int, float)) for score in result_triplet_scores["negative_score"])
425+
407426
# Test n-tuple format
408427
result_ntuple = mine_hard_negatives(
409428
dataset=dataset, model=model, num_negatives=2, output_format="n-tuple", verbose=False
@@ -426,21 +445,49 @@ def test_output_formats(dataset: Dataset, static_retrieval_mrl_en_v1_model: Sent
426445
# Verify scores are lists of expected length (1 positive + num_negatives)
427446
assert all(len(score) == 3 for score in result_scores["score"])
428447

429-
# Test labeled-pair format
448+
# Test labeled-pair format without scores (default)
430449
result_pair = mine_hard_negatives(dataset=dataset, model=model, output_format="labeled-pair", verbose=False)
431450
assert "query" in result_pair.column_names
432451
assert "passage" in result_pair.column_names
433452
assert "label" in result_pair.column_names
453+
assert "score" not in result_pair.column_names
454+
assert len(result_pair.column_names) == 3
455+
456+
# Test labeled-pair format with scores
457+
result_pair_scores = mine_hard_negatives(
458+
dataset=dataset, model=model, output_format="labeled-pair", include_scores=True, verbose=False
459+
)
460+
assert "query" in result_pair_scores.column_names
461+
assert "passage" in result_pair_scores.column_names
462+
assert "label" in result_pair_scores.column_names
463+
assert "score" in result_pair_scores.column_names
464+
assert len(result_pair_scores.column_names) == 4
434465

435466
# Verify labels are 0 or 1
436467
labels = set(result_pair["label"])
437468
assert labels == {0, 1}
438469

439-
# Test labeled-list format
470+
# Verify scores are numeric values
471+
if len(result_pair_scores) > 0:
472+
assert all(isinstance(score, (int, float)) for score in result_pair_scores["score"])
473+
474+
# Test labeled-list format without scores (default)
440475
result_list = mine_hard_negatives(dataset=dataset, model=model, output_format="labeled-list", verbose=False)
441476
assert "query" in result_list.column_names
442477
assert "passage" in result_list.column_names
443478
assert "labels" in result_list.column_names
479+
assert "scores" not in result_list.column_names
480+
assert len(result_list.column_names) == 3
481+
482+
# Test labeled-list format with scores
483+
result_list_scores = mine_hard_negatives(
484+
dataset=dataset, model=model, output_format="labeled-list", include_scores=True, verbose=False
485+
)
486+
assert "query" in result_list_scores.column_names
487+
assert "passage" in result_list_scores.column_names
488+
assert "labels" in result_list_scores.column_names
489+
assert "scores" in result_list_scores.column_names
490+
assert len(result_list_scores.column_names) == 4
444491

445492
# Verify each item in 'passage' is a list
446493
assert all(isinstance(p, list) for p in result_list["passage"])
@@ -451,6 +498,12 @@ def test_output_formats(dataset: Dataset, static_retrieval_mrl_en_v1_model: Sent
451498
assert label_list[0] == 1
452499
assert all(label == 0 for label in label_list[1:])
453500

501+
# Verify each item in 'scores' is a list with numeric values
502+
if len(result_list_scores) > 0:
503+
for score_list in result_list_scores["scores"]:
504+
assert isinstance(score_list, list)
505+
assert all(isinstance(score, (int, float)) for score in score_list)
506+
454507

455508
def test_batch_size(dataset: Dataset, static_retrieval_mrl_en_v1_model: SentenceTransformer) -> None:
456509
"""Test batch_size parameter."""

0 commit comments

Comments
 (0)