Retrieval Metrics Indexes Parameter needs Continuous Indexes? #2757
Sadly, there doesn't seem to be a retrieval metrics category, so I've added it here.

I seem to have had a crucial misunderstanding of the retrieval metrics indexes parameter. I want confirmation that I've got it right now, and to raise awareness of this issue.

I've always understood the indexes parameter as a way for torchmetrics to differentiate the queries on BATCH-LEVEL. As in, the following two code snippets should be equal:

indexes = tensor([0, 0, 0, 0, 1, 1, 1])
preds = tensor([0.4, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5])
preds2 = tensor([0.6, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5])
target = tensor([True, False, False, True, True, False, True])
target2 = ~target
retP.update(preds, target, indexes=indexes)
retP.update(preds2, target2, indexes=indexes)
retP.compute()
# tensor(0.5000)
retP.reset()
retP(torch.concat([preds, preds2]), torch.concat([target, target2]), indexes=torch.concat([indexes, indexes + 2]))
# tensor(0.3750)

When in reality the first one is equal to:

retP(torch.concat([preds, preds2]), torch.concat([target, target2]), indexes=torch.concat([indexes, indexes]))
# tensor(0.5000)

It believes the first query of the second batch belongs to the same query as the first query of the first batch, instead of acknowledging that it's a different batch and therefore probably a different query. So the code I've used in my collator:

query_idces = torch.arange(batch_size).unsqueeze(1).expand(batch_size, num_labels)

which only iterates on the batch level, seems to have been wrong the whole time.

Might be that this is something obvious for people who work more with retrieval, but it wasn't obvious to me from the docs and I found absolutely zero discussions about this online. ChatGPT-4o also didn't get that one right, whatever that's worth. Could somebody confirm this finding? And tell me whether I was stupid and this should have been obvious, or whether some clarification should be added to the docs?
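For anyone who wants to see the grouping effect without torchmetrics installed, here is a minimal pure-Python sketch. It is not the library's implementation: the helper names are made up, ties are broken by input order, and k=2 is an assumption (the post does not show how retP was constructed; k=2 just happens to reproduce the two numbers above).

```python
from collections import defaultdict

def precision_at_k(scores, relevant, k):
    # Rank by score, descending; sorted() is stable, so ties keep input order.
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    return sum(relevant[i] for i in order[:k]) / k

def mean_precision_by_query(preds, target, indexes, k):
    # Hypothetical helper: group positions by query id, score each group,
    # then average over groups. Equal ids land in the same group no matter
    # which batch they originally came from.
    groups = defaultdict(list)
    for position, query_id in enumerate(indexes):
        groups[query_id].append(position)
    per_query = [
        precision_at_k([preds[i] for i in pos], [target[i] for i in pos], k)
        for pos in groups.values()
    ]
    return sum(per_query) / len(per_query)

indexes = [0, 0, 0, 0, 1, 1, 1]
preds = [0.4, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5]
preds2 = [0.6, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5]
target = [True, False, False, True, True, False, True]
target2 = [not t for t in target]

all_preds = preds + preds2
all_target = target + target2

# Same ids in both batches: only two groups exist, each holding two batches.
reused = mean_precision_by_query(all_preds, all_target, indexes + indexes, k=2)

# Ids shifted for the second batch: four separate groups.
shifted = mean_precision_by_query(
    all_preds, all_target, indexes + [i + 2 for i in indexes], k=2
)

print(reused)   # 0.5
print(shifted)  # 0.375
```

The only difference between the two calls is the id list, which is what makes the two results diverge.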
Replies: 1 comment
Thank you for raising this important point about the indexes parameter in TorchMetrics retrieval metrics. Your observation is correct: the indexes parameter is expected to contain continuous indices that distinguish queries globally, not just within each batch. This means if the same index is used in multiple batches, TorchMetrics treats those as the same query, which can lead to unintended results if batch-level differentiation was assumed.

This behavior can indeed cause confusion, especially because it’s not clearly documented. Clarifying in the docs that indexes should uniquely identify queries across the entire dataset or evaluation run, rather than resetting per batch, would help avoid…
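One possible way to get run-wide unique ids is to carry a running offset from batch to batch. The sketch below uses plain Python lists so it runs without torch; the class name QueryIndexOffset and its method are hypothetical and not part of TorchMetrics.

```python
class QueryIndexOffset:
    """Hypothetical helper: turns batch-local query ids into run-wide unique ids."""

    def __init__(self):
        self.seen = 0  # number of queries handed out so far

    def next_batch(self, batch_size, num_labels):
        # One row per query; the query's id is repeated once per candidate.
        rows = [[self.seen + q] * num_labels for q in range(batch_size)]
        self.seen += batch_size
        return rows


offsets = QueryIndexOffset()
first = offsets.next_batch(batch_size=2, num_labels=3)
second = offsets.next_batch(batch_size=2, num_labels=3)

print(first)   # [[0, 0, 0], [1, 1, 1]]
print(second)  # [[2, 2, 2], [3, 3, 3]]
```

With tensors, the same idea would amount to adding the running offset to torch.arange(batch_size) in the collator line quoted in the question. A per-process counter like this may still produce colliding ids when several workers or devices build batches independently, so deriving the id from each sample's position in the dataset is probably the more robust choice there.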