Highlights
TorchMetrics v0.9 is now out, and it brings significant changes to how the forward method works. This blog post goes over these improvements and how they affect both users of TorchMetrics and those who implement custom metrics. TorchMetrics v0.9 also includes several new metrics and bug fixes.
Blog: TorchMetrics v0.9 — Faster forward
The Story of the Forward Method
Since the beginning of TorchMetrics, forward has served the dual purpose of calculating the metric on the current batch and accumulating it in a global state. Internally, this was achieved by calling update twice: once for each purpose, which meant repeating the same computation. However, for many metrics, calling update twice is unnecessary to obtain both the local batch statistics and the accumulated global state, because the global statistics are simple reductions of the local batch states.
In v0.9, we have finally implemented logic that takes advantage of this and only calls update once, followed by a simple reduction. As the figure below shows, this can make a single call to forward up to 2x faster in v0.9 than in v0.8 for the same metric.
With the improvements to forward, many metrics have become significantly faster (up to 2x)
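As a minimal, dependency-free sketch of the idea (the class and names below are illustrative, not TorchMetrics internals), the single-update strategy can be pictured like this: cache the accumulated global state, run update once on the current batch, compute the batch value, then merge the cached state back with a cheap reduction instead of a second update call.

```python
class MeanMetric:
    """Toy running-mean metric mimicking the single-update forward."""

    full_state_update = False  # global state is a simple reduction of batch states

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, values):
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        return self.total / self.count

    def forward(self, values):
        # cache the accumulated global state
        cached = (self.total, self.count)
        # reset and run update ONCE on the current batch
        self.total, self.count = 0.0, 0
        self.update(values)
        batch_value = self.compute()
        # restore the global state via a cheap reduction (here: addition)
        self.total += cached[0]
        self.count += cached[1]
        return batch_value


m = MeanMetric()
print(m.forward([1.0, 2.0, 3.0]))  # batch mean: 2.0
print(m.forward([5.0]))            # batch mean: 5.0
print(m.compute())                 # accumulated mean: 2.75
```

The expensive part of the old scheme was running update twice per batch; here the merge step is two additions, which is why metrics with costly update logic gain the most.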
It should be noted that this change mainly benefits metrics (for example, ConfusionMatrix) where calling update is expensive.

We went through all existing metrics in TorchMetrics and enabled this feature for all appropriate metrics, which was almost 95% of them. We want to stress that if you are using metrics from TorchMetrics, nothing about the API has changed, and no code changes are necessary.
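For those curious about what the two code paths look like, here is an illustrative toy accuracy metric (a sketch, not the TorchMetrics implementation) showing how a full_state_update flag can select between the old double-update behaviour and the new single-update path:

```python
class ToyAccuracy:
    """Toy accuracy metric contrasting the v0.8 and v0.9 forward strategies."""

    def __init__(self, full_state_update=False):
        self.full_state_update = full_state_update
        self.correct, self.total = 0, 0

    def update(self, preds, target):
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        return self.correct / self.total

    def forward(self, preds, target):
        if self.full_state_update:
            # v0.8-style path: update the global state, then repeat the
            # full computation on a scratch instance for the batch value.
            self.update(preds, target)
            scratch = ToyAccuracy()
            scratch.update(preds, target)
            return scratch.compute()
        # v0.9-style path: swap out the global state, update ONCE on the
        # batch, then merge the states back with a cheap sum reduction.
        global_state = (self.correct, self.total)
        self.correct, self.total = 0, 0
        self.update(preds, target)
        batch_acc = self.compute()
        self.correct += global_state[0]
        self.total += global_state[1]
        return batch_acc


metric = ToyAccuracy(full_state_update=False)
print(metric.forward([1, 0, 1], [1, 1, 1]))  # batch accuracy 2/3
print(metric.forward([0, 0], [0, 1]))        # batch accuracy 0.5
print(metric.compute())                      # accumulated accuracy 0.6
```

The flag defaults matter: the double-update path is always safe, while the single-update path is only valid when the global state really is a simple reduction of the batch states, which is why it had to be enabled metric by metric.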
[0.9.0] - 2022-05-31
Added
- `RetrievalPrecisionRecallCurve` and `RetrievalRecallAtFixedPrecision` to retrieval package (#951)
- `full_state_update` that determines whether `forward` should call `update` once or twice (#984, #1033)
- `Dice` to classification package (#1021)
- `segm` as IOU for mean average precision (#822)

Changed
- `reduction` argument renamed to `average` in Jaccard score, with additional options added (#874)

Removed
- Deprecated `compute_on_step` argument (#962, #967, #979, #990, #991, #993, #1005, #1004, #1007)

Fixed
- Non-empty state `dict` for a few metrics (#1012)
- `torch.double` support in stat score metrics (#1023)
- `FID` calculation for non-equal-size real and fake inputs (#1028)
- Case where `KLDivergence` could output `NaN` (#1030)
- Default value for `mdmc_average` in `Accuracy` as per documentation (#1036)
- `MetricCollection` missing update of property (#1052)

Contributors
@Borda, @burglarhobbit, @charlielito, @gianscarpe, @MrShevan, @phaseolud, @razmikmelikbekyan, @SkafteNicki, @tanmoyio, @vumichien
If we forgot someone because your commit email does not match your GitHub account, let us know :]
This discussion was created from the release Faster forward.