Hi! Is there a function or a set of arguments I can use to calculate precision and recall for a multi-label problem? By multi-label I mean that each sample can belong to more than one class. The following does not return what I would expect:

```python
import torch
from torchmetrics import Accuracy, Precision, Recall

target = torch.tensor([
    [0, 0, 1, 1, 0],  # Sample 1 belongs to classes 2 and 3 (zero-indexed)
    [0, 0, 1, 0, 0],  # Sample 2 belongs to class 2 (zero-indexed)
])
preds = torch.tensor([
    [0, 0, 0, 0, 0],  # Sample 1 predicted to belong to no class
    [0, 0, 0, 0, 0],  # Sample 2 predicted to belong to no class
])

metric = Precision(num_classes=5, mdmc_average="samplewise")
print(metric(preds, target))
```

It returns:

Thanks!
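For reference, here is what sample-wise averaging (presumably what `mdmc_average="samplewise"` was meant to request) gives for these tensors when worked out by hand: precision and recall are computed over each sample's row of labels and then averaged across samples. This is a minimal plain-Python sketch, not torchmetrics' implementation; the helper names are hypothetical, and it assumes an undefined ratio (0/0) is scored as 0.

```python
def safe_div(num, den):
    # Assumption: an undefined ratio (0/0) is reported as 0.0
    return num / den if den else 0.0


def samplewise_precision_recall(preds, target):
    """Average per-sample precision and recall over multi-hot rows."""
    precisions, recalls = [], []
    for p_row, t_row in zip(preds, target):
        # Count TP/FP/FN within this sample's labels only
        tp = sum(1 for p, t in zip(p_row, t_row) if p == 1 and t == 1)
        fp = sum(1 for p, t in zip(p_row, t_row) if p == 1 and t == 0)
        fn = sum(1 for p, t in zip(p_row, t_row) if p == 0 and t == 1)
        precisions.append(safe_div(tp, tp + fp))
        recalls.append(safe_div(tp, tp + fn))
    n = len(precisions)
    return sum(precisions) / n, sum(recalls) / n


target = [
    [0, 0, 1, 1, 0],  # sample 1: classes 2 and 3
    [0, 0, 1, 0, 0],  # sample 2: class 2
]
preds = [
    [0, 0, 0, 0, 0],  # nothing predicted for sample 1
    [0, 0, 0, 0, 0],  # nothing predicted for sample 2
]
print(samplewise_precision_recall(preds, target))  # (0.0, 0.0)
```

Because nothing is predicted as positive, every true-positive count is zero, so micro, macro and sample-wise averages all come out as 0 under this convention.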
For multi-label classification, use the `Precision` and `Recall` metrics with the `task="multilabel"` argument, and pass the number of labels as `num_labels`. The `mdmc_average` parameter only controls averaging for multi-dimensional multi-class inputs, so it is not the right setting for multi-label data, and specifying `num_classes` alone may not be enough. Here is a simple example of computing these metrics correctly with torchmetrics:

```python
import torch
from torchmetrics.classification import Precision, Recall

# Multi-hot targets: one column per label
target = torch.tensor([
    [0, 0, 1, 1, 0],  # sample 1 belongs to classes 2 and 3
    [0, 0, 1, 0, 0],  # sample 2 belongs to class 2
])
# Predictions: hard 0/1 labels here; float probabilities or logits are
# also accepted and are binarized using `threshold`
preds = torch.tensor([
    [0, 0, 0, 0, 0],  # no predicted positives for sample 1
    [0, 0, 0, 0, 0],  # no predicted positives for sample 2
])

# Initialize the metrics for the multilabel task
precision = Precision(task="multilabel", num_labels=5, threshold=0.5)
recall = Recall(task="multilabel", num_labels=5, threshold=0.5)

print("Precision:", precision(preds, target))
print("Recall:", recall(preds, target))
```

Both metrics correctly return 0 here, since there are no true positive predictions. This should resolve the issue of the metric returning an unexpected non-zero value despite having no true positives.
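torchmetrics also accepts an `average` argument (`"micro"`, `"macro"`, `"weighted"` or `"none"`), and the choice starts to matter once some labels are predicted correctly and others are not; check the documentation for the default in your version. Below is a minimal plain-Python sketch of what micro, macro and per-label (`"none"`) averaging compute, using hypothetical predictions and assuming an undefined ratio (0/0) is scored as 0. The helper names are illustrative and not part of torchmetrics, and `"weighted"` is deliberately left out.

```python
def per_label_counts(preds, target):
    """Return a list of (TP, FP, FN) tuples, one per label (column)."""
    counts = []
    for j in range(len(target[0])):
        pairs = [(p[j], t[j]) for p, t in zip(preds, target)]
        tp = sum(1 for p, t in pairs if p == 1 and t == 1)
        fp = sum(1 for p, t in pairs if p == 1 and t == 0)
        fn = sum(1 for p, t in pairs if p == 0 and t == 1)
        counts.append((tp, fp, fn))
    return counts


def safe_div(num, den):
    # Assumption: an undefined ratio (0/0) is reported as 0.0
    return num / den if den else 0.0


def precision_recall(preds, target, average="micro"):
    counts = per_label_counts(preds, target)
    if average == "micro":
        # Pool the counts of all labels, then divide once
        tp = sum(c[0] for c in counts)
        fp = sum(c[1] for c in counts)
        fn = sum(c[2] for c in counts)
        return safe_div(tp, tp + fp), safe_div(tp, tp + fn)
    # Score each label separately
    prec = [safe_div(tp, tp + fp) for tp, fp, _ in counts]
    rec = [safe_div(tp, tp + fn) for tp, _, fn in counts]
    if average == "none":
        return prec, rec
    if average == "macro":
        # Unweighted mean over labels
        return sum(prec) / len(prec), sum(rec) / len(rec)
    raise ValueError(f"unsupported average: {average!r}")


target = [
    [0, 0, 1, 1, 0],
    [0, 0, 1, 0, 0],
]
preds = [  # hypothetical predictions
    [0, 0, 1, 0, 1],  # label 2 correct, label 3 missed, label 4 spurious
    [0, 0, 1, 0, 0],  # label 2 correct
]
for avg in ("micro", "macro", "none"):
    print(avg, precision_recall(preds, target, avg))
# micro: precision = recall = 2/3
# macro: precision = recall = 0.2
```

Micro averaging pools the counts across labels (2 TP, 1 FP and 1 FN here), while macro averaging gives every label equal weight, so labels that never occur and are never predicted pull the mean down under the 0/0-as-0 convention.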