Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion timm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,10 @@ def train(config: dict[str, t.Any]):
if eval_metrics:
    mlflow.log_metric("val loss", eval_metrics["loss"], step=epoch)
    mlflow.log_metric("val accuracy", eval_metrics["top1"], step=epoch)
    # Log FA / AFA for every verification rate against the same epoch step as
    # the metrics above; without `step` MLflow records every epoch at its
    # default step, so the per-epoch history is lost.
    for vr in utils.EVAL_VERIFICATION_RATES:
        # round() rather than int(): int() truncates, so a rate whose product
        # is stored as e.g. 28.999999999999996 would be labelled 028, not 029.
        vr_percent = round(100 * vr)
        mlflow.log_metric(f"FA at {vr_percent:03d}", eval_metrics[f"fa@{vr}"], step=epoch)
        mlflow.log_metric(f"AFA at {vr_percent:03d}", eval_metrics[f"afa@{vr}"], step=epoch)


if output_dir is not None:
lrs = [param_group['lr'] for param_group in optimizer.param_groups]
Expand Down Expand Up @@ -1152,6 +1156,7 @@ def validate(
losses_m = utils.AverageMeter()
top1_m = utils.AverageMeter()
top5_m = utils.AverageMeter()
correct_with_confidences_m = utils.CorrectnessOfPredictionsWithConfidencesMeter()

model.eval()

Expand Down Expand Up @@ -1193,6 +1198,7 @@ def validate(
losses_m.update(reduced_loss.item(), input.size(0))
top1_m.update(acc1.item(), output.size(0))
top5_m.update(acc5.item(), output.size(0))
correct_with_confidences_m.update(output, target)

batch_time_m.update(time.time() - end)
end = time.time()
Expand All @@ -1206,7 +1212,32 @@ def validate(
f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
)

metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
# Assemble the returned metrics in a fixed order: the scalar averages first,
# then one "fa@<rate>" entry per verification rate, then one "afa@<rate>"
# entry per verification rate.
verification_rates = utils.EVAL_VERIFICATION_RATES
metrics = OrderedDict()
metrics["loss"] = losses_m.avg
metrics["top1"] = top1_m.avg
metrics["top5"] = top5_m.avg

fa_per_rate = correct_with_confidences_m.final_accuracy(verification_rates)
for vr, fa in zip(verification_rates, fa_per_rate):
    metrics[f"fa@{vr}"] = fa

afa_per_rate = correct_with_confidences_m.average_final_accuracy(verification_rates)
for vr, afa in zip(verification_rates, afa_per_rate):
    metrics[f"afa@{vr}"] = afa


return metrics

Expand Down
2 changes: 1 addition & 1 deletion timm/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
world_info_from_env, is_distributed_env, is_primary
from .jit import set_jit_legacy, set_jit_fuser
from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
from .metrics import AverageMeter, accuracy, CorrectnessOfPredictionsWithConfidencesMeter, EVAL_VERIFICATION_RATES
from .misc import natural_key, add_bool_arg, ParseKwargs
from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
from .model_ema import ModelEma, ModelEmaV2, ModelEmaV3
Expand Down
52 changes: 52 additions & 0 deletions timm/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch

# Verification rates at which the FA / AFA metrics are evaluated. Each value is
# the fraction of evaluation samples (taken lowest-confidence first) that
# CorrectnessOfPredictionsWithConfidencesMeter counts as verified.
EVAL_VERIFICATION_RATES = [0.01, 0.02, 0.05, 0.1, 0.2]

class AverageMeter:
"""Computes and stores the average and current value"""
Expand All @@ -22,6 +24,56 @@ def update(self, val, n=1):
self.avg = self.sum / self.count


class CorrectnessOfPredictionsWithConfidencesMeter:
    """Collects top-1 correctness and confidence per sample to compute FA / AFA.

    For every sample seen by `update` the meter stores whether the top-1
    prediction matched the target and the top-1 score taken from `output`
    (the raw value returned by `topk`, no softmax is applied here).

    "Final accuracy" (FA) at verification rate `vr` treats the
    `round(vr * N)` lowest-confidence samples as verified (counted as correct)
    and scores the remaining samples by the model's own prediction.
    "Average final accuracy" (AFA) is the mean of FA over verification counts
    `0 .. round(vr * N) - 1`.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop everything accumulated so far."""
        self.predictions_correct = []
        self.confidences = []

    def update(self, output, target):
        """Record correctness and confidence for one batch.

        Args:
            output: 2-D tensor of per-class scores, one row per sample
                (`topk` is taken over the last dimension).
            target: tensor of ground-truth class indices, one per sample.
        """
        confidences, preds = output.topk(k=1)
        preds = preds.t()
        correct = preds.eq(target.reshape(1, -1).expand_as(preds)).flatten()

        # Keep the history on CPU and detached so it neither pins device
        # memory nor keeps an autograd graph alive.
        self.predictions_correct.append(correct.detach().cpu())
        self.confidences.append(confidences.detach().cpu())

    def _correct_sorted_by_confidence(self):
        """Return the correctness flags ordered from least to most confident."""
        if not self.predictions_correct:
            # torch.cat() rejects an empty list; report "no samples" instead.
            return torch.empty(0, dtype=torch.bool)
        correct = torch.cat(self.predictions_correct)
        confidences = torch.cat(self.confidences)
        return correct[confidences.flatten().argsort()]

    @staticmethod
    def _num_verified(vr, n_total):
        """Number of samples verified at rate `vr`, clamped to [0, n_total]."""
        return min(max(round(vr * n_total), 0), n_total)

    def final_accuracy(self, vrs):
        """Compute FA for each verification rate in `vrs`.

        Returns:
            A list of Python floats, one per rate. Every entry is NaN when no
            samples have been recorded.
        """
        correct_sorted = self._correct_sorted_by_confidence()
        n_total = len(correct_sorted)
        if n_total == 0:
            return [float("nan") for _ in vrs]

        def _fa(vr):
            n_verified = self._num_verified(vr, n_total)
            n_correct_unverified = correct_sorted[n_verified:].sum().item()
            return (n_verified + n_correct_unverified) / n_total

        return [_fa(vr) for vr in vrs]

    def average_final_accuracy(self, vrs):
        """Compute AFA for each verification rate in `vrs`.

        Returns:
            A list of Python floats, one per rate. Every entry is NaN when no
            samples have been recorded. When a rate verifies zero samples the
            value falls back to the unverified accuracy (FA with nothing
            verified) instead of dividing by zero.
        """
        correct_sorted = self._correct_sorted_by_confidence()
        n_total = len(correct_sorted)
        if n_total == 0:
            return [float("nan") for _ in vrs]

        def _afa(vr):
            # see https://drive.google.com/file/d/1Uag8VtD3RwsoS8hs59X6T5u_iwuqspkS/view
            # for derivation of this formula
            n_verified = self._num_verified(vr, n_total)
            n_correct_unverified = correct_sorted[n_verified:].sum().item()
            if n_verified == 0:
                # The average over an empty range of verification counts is
                # undefined; nothing is verified, so report plain accuracy.
                return n_correct_unverified / n_total
            # Sample i (0-based, i < n_verified) is still unverified for the
            # verification counts 0..i, i.e. in (i + 1) of the n_verified terms.
            afa_weights = torch.arange(1, n_verified + 1, dtype=torch.float64) / n_verified
            weighted_correct = (afa_weights * correct_sorted[:n_verified]).sum().item()
            return (
                (n_verified - 1) / 2
                + weighted_correct
                + n_correct_unverified
            ) / n_total

        return [_afa(vr) for vr in vrs]


def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
maxk = min(max(topk), output.size()[1])
Expand Down