18 changes: 12 additions & 6 deletions sentence_transformers/evaluation/BinaryClassificationEvaluator.py
@@ -149,7 +149,12 @@ def from_input_examples(cls, examples: list[InputExample], **kwargs):
return cls(sentences1, sentences2, scores, **kwargs)

def __call__(
self, model: SentenceTransformer, output_path: str | None = None, epoch: int = -1, steps: int = -1
self,
model: SentenceTransformer,
output_path: str | None = None,
epoch: int = -1,
steps: int = -1,
encode_args: dict = {},
) -> dict[str, float]:
"""
Compute the evaluation metrics for the given model.
@@ -159,6 +164,7 @@ def __call__(
output_path (str, optional): Path to save the evaluation results CSV file. Defaults to None.
epoch (int, optional): The epoch number. Defaults to -1.
steps (int, optional): The number of steps. Defaults to -1.
encode_args (dict, optional): The args to be passed to the encode method. Defaults to {}.

Returns:
Dict[str, float]: A dictionary containing the evaluation metrics.
@@ -178,7 +184,7 @@ def __call__(
if not self.similarity_fn_names:
self.similarity_fn_names = [model.similarity_fn_name]
self._append_csv_headers(self.similarity_fn_names)
scores = self.compute_metrices(model)
scores = self.compute_metrices(model, encode_args)

file_output_data = [epoch, steps]

@@ -220,17 +226,17 @@ def __call__(
self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
return metrics

def compute_metrices(self, model: SentenceTransformer) -> dict[str, dict[str, float]]:
def compute_metrices(self, model: SentenceTransformer, encode_args: dict) -> dict[str, dict[str, float]]:
try:
# If the sentences are hashable, then we can use a set to avoid embedding the same sentences multiple
# times
sentences = list(set(self.sentences1 + self.sentences2))
except TypeError:
# Otherwise we just embed everything, e.g. if the sentences are images for evaluating a CLIP model
embeddings1 = self.embed_inputs(model, self.sentences1)
embeddings2 = self.embed_inputs(model, self.sentences2)
embeddings1 = self.embed_inputs(model, self.sentences1, **encode_args)
embeddings2 = self.embed_inputs(model, self.sentences2, **encode_args)
else:
embeddings = self.embed_inputs(model, sentences)
embeddings = self.embed_inputs(model, sentences, **encode_args)
emb_dict = {sent: emb for sent, emb in zip(sentences, embeddings)}
embeddings1 = [emb_dict[sent] for sent in self.sentences1]
embeddings2 = [emb_dict[sent] for sent in self.sentences2]
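Taken together, this file's hunks thread a single `encode_args` dict from `__call__` down to every `embed_inputs` call. A minimal sketch of how the new keyword could be used once this change lands; the model name, sentence pairs, and the `normalize_embeddings` encode option are illustrative assumptions, not part of the diff, and which `model.encode` keywords are safe to pass depends on which ones `embed_inputs` already sets itself:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import BinaryClassificationEvaluator

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative model choice

dup_evaluator = BinaryClassificationEvaluator(
    sentences1=["A plane is taking off.", "A man is playing a flute."],
    sentences2=["An air plane is taking off.", "A man is eating pasta."],
    labels=[1, 0],
    name="toy-duplicates",
)

# The new keyword is expanded as **encode_args inside compute_metrices,
# so these options reach model.encode for both sentence lists.
metrics = dup_evaluator(model, encode_args={"normalize_embeddings": True})
print(metrics)
```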
11 changes: 8 additions & 3 deletions sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py
@@ -151,7 +151,12 @@ def from_input_examples(cls, examples: list[InputExample], **kwargs):
return cls(sentences1, sentences2, scores, **kwargs)

def __call__(
self, model: SentenceTransformer, output_path: str | None = None, epoch: int = -1, steps: int = -1
self,
model: SentenceTransformer,
output_path: str | None = None,
epoch: int = -1,
steps: int = -1,
encode_args: dict = {},
) -> dict[str, float]:
if epoch != -1:
if steps == -1:
@@ -165,8 +170,8 @@ def __call__(

logger.info(f"EmbeddingSimilarityEvaluator: Evaluating the model on the {self.name} dataset{out_txt}:")

embeddings1 = self.embed_inputs(model, self.sentences1)
embeddings2 = self.embed_inputs(model, self.sentences2)
embeddings1 = self.embed_inputs(model, self.sentences1, **encode_args)
embeddings2 = self.embed_inputs(model, self.sentences2, **encode_args)
# Binary and ubinary embeddings are packed, so we need to unpack them for the distance metrics
if self.precision == "binary":
embeddings1 = (embeddings1 + 128).astype(np.uint8)
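The same pattern applies here: `encode_args` only carries plain `model.encode` keywords, while `precision` stays a constructor argument and the binary/ubinary unpacking above still runs after encoding. A short sketch with illustrative data; `device` is used as the example encode keyword:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative

sts_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=["A cat sits on the mat."],
    sentences2=["A feline rests on a rug."],
    scores=[0.8],  # gold similarity score
    name="toy-sts",
)

# Forwarded into both embed_inputs calls shown in the diff.
metrics = sts_evaluator(model, encode_args={"device": "cpu"})
```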
7 changes: 6 additions & 1 deletion sentence_transformers/evaluation/LabelAccuracyEvaluator.py
@@ -47,7 +47,12 @@ def __init__(self, dataloader: DataLoader, name: str = "", softmax_model=None, w
self.primary_metric = "accuracy"

def __call__(
self, model: SentenceTransformer, output_path: str | None = None, epoch: int = -1, steps: int = -1
self,
model: SentenceTransformer,
output_path: str | None = None,
epoch: int = -1,
steps: int = -1,
encode_args: dict = {},
) -> dict[str, float]:
model.eval()
total = 0
7 changes: 4 additions & 3 deletions sentence_transformers/evaluation/MSEEvaluator.py
@@ -79,6 +79,7 @@ def __init__(
name: str = "",
write_csv: bool = True,
truncate_dim: int | None = None,
encode_args: dict = {},
):
super().__init__()
self.truncate_dim = truncate_dim
@@ -92,10 +93,10 @@ def __init__(
self.write_csv = write_csv
self.primary_metric = "negative_mse"

self.source_embeddings = self.embed_inputs(teacher_model, source_sentences)
self.source_embeddings = self.embed_inputs(teacher_model, source_sentences, **encode_args)

def __call__(
self, model: SentenceTransformer, output_path: str | None = None, epoch=-1, steps=-1
self, model: SentenceTransformer, output_path: str | None = None, epoch=-1, steps=-1, encode_args: dict = {}
) -> dict[str, float]:
if epoch != -1:
if steps == -1:
@@ -107,7 +108,7 @@ def __call__(
if self.truncate_dim is not None:
out_txt += f" (truncated to {self.truncate_dim})"

target_embeddings = self.embed_inputs(model, self.target_sentences)
target_embeddings = self.embed_inputs(model, self.target_sentences, **encode_args)

mse = ((self.source_embeddings - target_embeddings) ** 2).mean()
mse = mse * 100
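Note that `MSEEvaluator` now accepts `encode_args` in two places: the constructor value configures the teacher encoding pass done at init time, while the `__call__` value configures the student encoding pass at evaluation time. A hedged sketch; model names and keyword values are illustrative only:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import MSEEvaluator

teacher = SentenceTransformer("all-mpnet-base-v2")  # illustrative teacher
student = SentenceTransformer("all-MiniLM-L6-v2")   # illustrative student

# Constructor-level encode_args -> teacher_model encoding of the source sentences.
mse_evaluator = MSEEvaluator(
    source_sentences=["How old are you?"],
    target_sentences=["Wie alt bist du?"],
    teacher_model=teacher,
    encode_args={"normalize_embeddings": True},
)

# Call-level encode_args -> student encoding of the target sentences.
metrics = mse_evaluator(student, encode_args={"normalize_embeddings": True})
```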
12 changes: 9 additions & 3 deletions sentence_transformers/evaluation/MSEEvaluatorFromDataFrame.py
@@ -48,6 +48,7 @@ def __init__(
name: str = "",
write_csv: bool = True,
truncate_dim: int | None = None,
encode_args: dict = {},
):
super().__init__()
self.combinations = combinations
@@ -81,11 +82,16 @@ def __init__(

all_source_sentences = list(all_source_sentences)

all_src_embeddings = self.embed_inputs(teacher_model, all_source_sentences)
all_src_embeddings = self.embed_inputs(teacher_model, all_source_sentences, **encode_args)
self.teacher_embeddings = {sent: emb for sent, emb in zip(all_source_sentences, all_src_embeddings)}

def __call__(
self, model: SentenceTransformer, output_path: str | None = None, epoch: int = -1, steps: int = -1
self,
model: SentenceTransformer,
output_path: str | None = None,
epoch: int = -1,
steps: int = -1,
encode_args: dict = {},
) -> dict[str, float]:
model.eval()

@@ -94,7 +100,7 @@ def __call__(
src_sentences, trg_sentences = self.data[(src_lang, trg_lang)]

src_embeddings = np.asarray([self.teacher_embeddings[sent] for sent in src_sentences])
trg_embeddings = np.asarray(self.embed_inputs(model, trg_sentences))
trg_embeddings = np.asarray(self.embed_inputs(model, trg_sentences, **encode_args))

mse = ((src_embeddings - trg_embeddings) ** 2).mean()
mse *= 100
28 changes: 20 additions & 8 deletions sentence_transformers/evaluation/RerankingEvaluator.py
@@ -135,7 +135,12 @@ def __init__(
self.primary_metric = f"ndcg@{self.at_k}"

def __call__(
self, model: SentenceTransformer, output_path: str | None = None, epoch: int = -1, steps: int = -1
self,
model: SentenceTransformer,
output_path: str | None = None,
epoch: int = -1,
steps: int = -1,
encode_args: dict = {},
) -> dict[str, float]:
"""
Evaluates the model on the dataset and returns the evaluation metrics.
@@ -161,7 +166,7 @@ def __call__(

logger.info(f"RerankingEvaluator: Evaluating the model on the {self.name} dataset{out_txt}:")

scores = self.compute_metrices(model)
scores = self.compute_metrices(model, encode_args)
mean_ap = scores["map"]
mean_mrr = scores["mrr"]
mean_ndcg = scores["ndcg"]
@@ -197,12 +202,13 @@ def __call__(
self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
return metrics

def compute_metrices(self, model: SentenceTransformer):
def compute_metrices(self, model: SentenceTransformer, encode_args: dict):
"""
Computes the evaluation metrics for the given model.

Args:
model (SentenceTransformer): The SentenceTransformer model to compute metrics for.
encode_args (dict): The args to be passed to the encode method. Defaults to {}.

Returns:
Dict[str, float]: A dictionary containing the evaluation metrics.
@@ -213,12 +219,13 @@ def compute_metrices(self, model: SentenceTransformer):
else self.compute_metrices_individual(model)
)

def compute_metrices_batched(self, model: SentenceTransformer):
def compute_metrices_batched(self, model: SentenceTransformer, encode_args: dict):
"""
Computes the evaluation metrics in a batched way, by batching all queries and all documents together.

Args:
model (SentenceTransformer): The SentenceTransformer model to compute metrics for.
encode_args (dict): The args to be passed to the encode method. Defaults to {}.

Returns:
Dict[str, float]: A dictionary containing the evaluation metrics.
@@ -241,7 +248,7 @@ def compute_metrices_batched(self, model: SentenceTransformer):
all_docs.extend(sample["negative"])

all_docs_embs = self.embed_inputs(
model, all_docs, encode_fn_name="document", show_progress_bar=self.show_progress_bar
model, all_docs, encode_fn_name="document", show_progress_bar=self.show_progress_bar, **encode_args
)

# Compute scores
@@ -286,12 +293,13 @@ def compute_metrices_batched(self, model: SentenceTransformer):

return {"map": mean_ap, "mrr": mean_mrr, "ndcg": mean_ndcg}

def compute_metrices_individual(self, model: SentenceTransformer):
def compute_metrices_individual(self, model: SentenceTransformer, encode_args: dict):
"""
Computes the evaluation metrics individually by embedding every (query, positive, negative) tuple individually.

Args:
model (SentenceTransformer): The SentenceTransformer model to compute metrics for.
encode_args (dict): The args to be passed to the encode method. Defaults to {}.

Returns:
Dict[str, float]: A dictionary containing the evaluation metrics.
@@ -311,8 +319,12 @@ def compute_metrices_individual(self, model: SentenceTransformer):
docs = positive + negative
is_relevant = [1] * len(positive) + [0] * len(negative)

query_emb = self.embed_inputs(model, [query], encode_fn_name="query", show_progress_bar=False)
docs_emb = self.embed_inputs(model, docs, encode_fn_name="document", show_progress_bar=False)
query_emb = self.embed_inputs(
model, [query], encode_fn_name="query", show_progress_bar=False, **encode_args
)
docs_emb = self.embed_inputs(
model, docs, encode_fn_name="document", show_progress_bar=False, **encode_args
)

pred_scores = self.similarity_fct(query_emb, docs_emb)
if len(pred_scores.shape) > 1:
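Because these hunks already pass `encode_fn_name` and `show_progress_bar` explicitly alongside `**encode_args`, repeating either key in `encode_args` would raise a duplicate-keyword `TypeError`; it should be limited to other `encode` options not already set by `embed_inputs`. A sketch with illustrative data and an illustrative encode keyword:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import RerankingEvaluator

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative

samples = [
    {
        "query": "What is the capital of France?",
        "positive": ["Paris is the capital of France."],
        "negative": ["Berlin is the capital of Germany."],
    }
]

rerank_evaluator = RerankingEvaluator(samples, name="toy-rerank")

# Reaches both the query and the document encoding passes shown above.
metrics = rerank_evaluator(model, encode_args={"normalize_embeddings": True})
```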
9 changes: 7 additions & 2 deletions sentence_transformers/evaluation/SequentialEvaluator.py
@@ -36,12 +36,17 @@ def __init__(self, evaluators: Iterable[SentenceEvaluator], main_score_function=
self.main_score_function = main_score_function

def __call__(
self, model: SentenceTransformer, output_path: str | None = None, epoch: int = -1, steps: int = -1
self,
model: SentenceTransformer,
output_path: str | None = None,
epoch: int = -1,
steps: int = -1,
encode_args: dict = {},
) -> dict[str, float]:
evaluations = []
scores = []
for evaluator_idx, evaluator in enumerate(self.evaluators):
evaluation = evaluator(model, output_path, epoch, steps)
evaluation = evaluator(model, output_path, epoch, steps, encode_args)

if not isinstance(evaluation, dict):
scores.append(evaluation)
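As the hunk shows, `SequentialEvaluator` hands the same `encode_args` dict to every wrapped evaluator as a fifth positional argument, so each sub-evaluator's `__call__` must accept the new parameter. A sketch reusing the evaluator objects built in the earlier sketches (those names are illustrative, not part of the diff):

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import SequentialEvaluator

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative

# dup_evaluator and sts_evaluator come from the earlier sketches.
seq_evaluator = SequentialEvaluator([dup_evaluator, sts_evaluator])

# One dict is forwarded to every wrapped evaluator in turn.
metrics = seq_evaluator(model, encode_args={"normalize_embeddings": True})
```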
11 changes: 8 additions & 3 deletions sentence_transformers/evaluation/TranslationEvaluator.py
@@ -100,7 +100,12 @@ def __init__(
self.primary_metric = "mean_accuracy"

def __call__(
self, model: SentenceTransformer, output_path: str | None = None, epoch: int = -1, steps: int = -1
self,
model: SentenceTransformer,
output_path: str | None = None,
epoch: int = -1,
steps: int = -1,
encode_args: dict = {},
) -> dict[str, float]:
if epoch != -1:
if steps == -1:
@@ -114,8 +119,8 @@ def __call__(

logger.info(f"Evaluating translation matching Accuracy of the model on the {self.name} dataset{out_txt}:")

embeddings1 = torch.stack(self.embed_inputs(model, self.source_sentences))
embeddings2 = torch.stack(self.embed_inputs(model, self.target_sentences))
embeddings1 = torch.stack(self.embed_inputs(model, self.source_sentences, **encode_args))
embeddings2 = torch.stack(self.embed_inputs(model, self.target_sentences, **encode_args))

cos_sims = pytorch_cos_sim(embeddings1, embeddings2).detach().cpu().numpy()

13 changes: 9 additions & 4 deletions sentence_transformers/evaluation/TripletEvaluator.py
@@ -162,7 +162,12 @@ def from_input_examples(cls, examples: list[InputExample], **kwargs):
return cls(anchors, positives, negatives, **kwargs)

def __call__(
self, model: SentenceTransformer, output_path: str | None = None, epoch: int = -1, steps: int = -1
self,
model: SentenceTransformer,
output_path: str | None = None,
epoch: int = -1,
steps: int = -1,
encode_args: dict = {},
) -> dict[str, float]:
if epoch != -1:
if steps == -1:
@@ -176,9 +181,9 @@ def __call__(

logger.info(f"TripletEvaluator: Evaluating the model on the {self.name} dataset{out_txt}:")

embeddings_anchors = self.embed_inputs(model, self.anchors)
embeddings_positives = self.embed_inputs(model, self.positives)
embeddings_negatives = self.embed_inputs(model, self.negatives)
embeddings_anchors = self.embed_inputs(model, self.anchors, **encode_args)
embeddings_positives = self.embed_inputs(model, self.positives, **encode_args)
embeddings_negatives = self.embed_inputs(model, self.negatives, **encode_args)

if not self.similarity_fn_names:
self.similarity_fn_names = [model.similarity_fn_name]
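Finally, `TripletEvaluator` encodes anchors, positives, and negatives with the same `encode_args` dict. A last sketch under the same assumptions as the earlier ones (illustrative data, illustrative encode keyword):

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TripletEvaluator

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative

triplet_evaluator = TripletEvaluator(
    anchors=["What is Python?"],
    positives=["Python is a programming language."],
    negatives=["Pythons are large snakes."],
    name="toy-triplets",
)

# The same dict is used for all three embed_inputs calls in the diff.
metrics = triplet_evaluator(model, encode_args={"normalize_embeddings": True})
```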