
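For multi-label classification, use the Precision and Recall metrics with the task="multilabel" argument, and pass num_labels (the number of labels) rather than num_classes. The mdmc_average parameter targets multi-dimensional multi-class inputs, so it does not help here, and specifying num_classes alone does not configure the metrics for multi-label data.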
Here is a simple example of how to compute these metrics correctly using torchmetrics:

import torch
from torchmetrics.classification import Precision, Recall

# Example multi-label targets and predictions
target = torch.tensor([
    [0, 0, 1, 1, 0],  # sample 1 belongs to class 2 and 3
    [0, 0, 1, 0, 0],  # sample 2 belongs to class 2
])

# Predictions (e.g., logits or binary predictions)
preds = torch.tensor([
    [0, 0, 0, 0, 0],  # no predicted positives for sample 1

Answer selected by Borda