Smart update of metric collection #709

Merged 51 commits on Feb 8, 2022
Commits
7a00a63
update
SkafteNicki Oct 19, 2021
501aef3
Merge branch 'master' into smart_collection
Borda Jan 20, 2022
2264b18
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 20, 2022
a911b5e
update docs
SkafteNicki Jan 23, 2022
f5eac5f
fix imports
SkafteNicki Jan 23, 2022
b2cb581
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 23, 2022
307f8ba
Apply suggestions from code review
SkafteNicki Jan 24, 2022
e788128
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2022
6e29cbf
somethings working
SkafteNicki Jan 24, 2022
f391fdd
better solution
SkafteNicki Jan 24, 2022
c7c05c1
revert registry
SkafteNicki Jan 24, 2022
7996ede
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2022
ae505fc
revert
SkafteNicki Jan 24, 2022
efd6a95
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Jan 24, 2022
9495da2
flake8
SkafteNicki Jan 24, 2022
e695ad0
Merge branch 'master' into smart_collection
SkafteNicki Jan 24, 2022
44cf2b7
improve testing
SkafteNicki Jan 25, 2022
9af0392
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Jan 25, 2022
6754a93
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2022
7b8bd72
accuracy fix
SkafteNicki Jan 25, 2022
dcec4d4
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Jan 25, 2022
c9c0f50
Apply suggestions from code review
SkafteNicki Jan 27, 2022
3ada9af
Merge branch 'master' into smart_collection
SkafteNicki Jan 27, 2022
e957fda
fix tests
SkafteNicki Jan 31, 2022
4ff45cb
enhancement
SkafteNicki Jan 31, 2022
3abd4aa
doctest
SkafteNicki Jan 31, 2022
0b17d8d
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Jan 31, 2022
8eda1b6
Merge branch 'master' into smart_collection
SkafteNicki Jan 31, 2022
0c97613
mypy
SkafteNicki Jan 31, 2022
0ffb39d
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Jan 31, 2022
3f30985
typo
Borda Feb 1, 2022
9d24233
optimize
SkafteNicki Feb 1, 2022
01b3dee
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Feb 1, 2022
8fad12c
Merge branch 'master' into smart_collection
Borda Feb 5, 2022
475f4b2
Merge branch 'master' into smart_collection
SkafteNicki Feb 8, 2022
3512128
docs
SkafteNicki Feb 8, 2022
6b524b9
Merge branches 'smart_collection' and 'smart_collection' of https://g…
SkafteNicki Feb 8, 2022
5e187ae
fix tests
SkafteNicki Feb 8, 2022
18f393f
Apply suggestions from code review
Borda Feb 8, 2022
85a3174
Update docs/source/pages/overview.rst
SkafteNicki Feb 8, 2022
5638156
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 8, 2022
09d0613
Apply suggestions from code review
SkafteNicki Feb 8, 2022
a966667
Update docs/source/pages/overview.rst
SkafteNicki Feb 8, 2022
c24e8ad
speed tests
SkafteNicki Feb 8, 2022
5e0d801
Merge branch 'master' into smart_collection
SkafteNicki Feb 8, 2022
ef26080
add speed tests
SkafteNicki Feb 8, 2022
7262cf5
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Feb 8, 2022
a735762
flake8
SkafteNicki Feb 8, 2022
2c63299
fix docs
SkafteNicki Feb 8, 2022
035cea3
fix typing
SkafteNicki Feb 8, 2022
7c360e5
Apply suggestions from code review
SkafteNicki Feb 8, 2022
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for `MetricCollection` in `MetricTracker` ([#718](https://github.com/PyTorchLightning/metrics/pull/718))


- Added smart update of `MetricCollection` ([#709](https://github.com/PyTorchLightning/metrics/pull/709))

### Changed

- Used `torch.bucketize` in calibration error when `torch>1.8` for faster computations ([#769](https://github.com/PyTorchLightning/metrics/pull/769))
16 changes: 14 additions & 2 deletions torchmetrics/__init__.py
@@ -2,7 +2,7 @@
import logging as __logging
import os

from torchmetrics.__about__ import * # noqa: F401, F403
from torchmetrics.__about__ import * # noqa: F403

_logger = __logging.getLogger("torchmetrics")
_logger.addHandler(__logging.StreamHandler())
@@ -25,7 +25,7 @@
SignalDistortionRatio,
SignalNoiseRatio,
)
from torchmetrics.classification import ( # noqa: E402, F401
from torchmetrics.classification import ( # noqa: E402
AUC,
AUROC,
F1,
@@ -102,8 +102,19 @@
WordInfoLost,
WordInfoPreserved,
)
from torchmetrics.utilities.registry import register_compute_group
from torchmetrics.wrappers import BootStrapper, MetricTracker, MinMaxMetric, MultioutputWrapper # noqa: E402

register_compute_group(F1, FBeta, Recall, Precision, Specificity, StatScores)
register_compute_group(AUROC, AveragePrecision, PrecisionRecallCurve, ROC)
register_compute_group(BinnedPrecisionRecallCurve, BinnedAveragePrecision)
register_compute_group(CohenKappa, ConfusionMatrix, IoU, MatthewsCorrcoef)
register_compute_group(CosineSimilarity, SpearmanCorrcoef)
register_compute_group(
RetrievalMAP, RetrievalMRR, RetrievalFallOut, RetrievalNormalizedDCG, RetrievalPrecision, RetrievalRecall
)


__all__ = [
"functional",
"Accuracy",
@@ -185,6 +196,7 @@
"StatScores",
"SumMetric",
"SymmetricMeanAbsolutePercentageError",
"WER",
"TranslationEditRate",
"WordErrorRate",
"CharErrorRate",
93 changes: 86 additions & 7 deletions torchmetrics/collections.py
@@ -11,12 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Any, Dict, Hashable, Iterable, Optional, Sequence, Tuple, Union

import torch
from torch import nn
from torch import Tensor, nn

from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
@@ -45,6 +44,13 @@ class name as key for the output dict.

postfix: a string to append after the keys of the output dict

enable_compute_groups:
By default the MetricCollection will try to reduce the computations needed for the metrics in the collection
by checking if they belong to the same *compute group*. All metrics in a compute group share the same metric
state and therefore only differ in their compute step, e.g. accuracy, precision and recall can all be
computed from the true positives/negatives and false positives/negatives. Set this argument to ``False``
to disable this behaviour.
Raises:
ValueError:
If one of the elements of ``metrics`` is not an instance of ``pl.metrics.Metric``.
@@ -94,13 +100,15 @@ def __init__(
*additional_metrics: Metric,
prefix: Optional[str] = None,
postfix: Optional[str] = None,
enable_compute_groups: bool = True,
) -> None:
super().__init__()

self.add_metrics(metrics, *additional_metrics)

self.prefix = self._check_arg(prefix, "prefix")
self.postfix = self._check_arg(postfix, "postfix")
self.enable_compute_groups = enable_compute_groups

self.add_metrics(metrics, *additional_metrics)

@torch.jit.unused
def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
@@ -117,9 +125,45 @@ def update(self, *args: Any, **kwargs: Any) -> None:
Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
for _, m in self.items(keep_base=True):
m_kwargs = m._filter_kwargs(**kwargs)
m.update(*args, **m_kwargs)

if self._groups_checked:
for _, cg in self._groups.items():
m0 = getattr(self, cg[0])
m0.update(*args, **m0._filter_kwargs(**kwargs))
# copy the state to the remaining metrics in the compute group
for i in range(1, len(cg)):
mi = getattr(self, cg[i])
for state in m0._defaults:
setattr(mi, state, getattr(m0, state))

else:  # on the first update, run every metric separately to make sure the states match
for _, m in self.items(keep_base=True):
m_kwargs = m._filter_kwargs(**kwargs)
m.update(*args, **m_kwargs)
n_groups = len(self._groups)
for k, cg in self._groups.copy().items():
member1 = cg[0]  # check the first member against all others
for i, member2 in reversed(list(enumerate(cg[1:]))):
for state in self[member1]._defaults.keys():
# if the states do not match we need to divide the compute group
s1 = getattr(self[member1], state)
s2 = getattr(self[member2], state)
if (
(isinstance(s1, Tensor) and isinstance(s2, Tensor) and not torch.allclose(s1, s2))
or (isinstance(s1, list) and isinstance(s2, list) and not s1 == s2)
or (type(s1) != type(s2))
):

# split member2 into its own computational group
n_groups += 1
self._groups[f"cg{n_groups}"] = [member2]
self._groups[k].pop(i + 1)
break

self._groups_checked = True

def compute(self) -> Dict[str, Any]:
return {k: m.compute() for k, m in self.items()}
@@ -193,6 +237,41 @@ def add_metrics(
else:
raise ValueError("Unknown input to MetricCollection.")

if self.enable_compute_groups:
self._find_compute_groups()

def _find_compute_groups(self) -> None:
"""Find groups of metrics that share the same underlying states.

If such metrics exist, only one should be updated and the rest should just copy the state.
"""
from torchmetrics.utilities.registry import _COMPUTE_GROUP_REGISTRY

self._groups = {}

# Duplicates of the same metric belong to the same compute group
for k, v in self.items(keep_base=False):
self._groups.setdefault(v.__class__.__name__, set()).add(k)
for k, v in self._groups.items():
self._groups[k] = list(v)

# Find compute groups for remaining based on registry
for k, v in self._groups.copy().items():
for cg in _COMPUTE_GROUP_REGISTRY:
if k in cg and k in self._groups: # found one metric in compute group
# prevent comparing the metric to itself
compare_dict = self._groups.copy()
compare_dict.pop(k)
for kk, vv in compare_dict.items():
if kk in cg: # found another metric in compute group
self._groups[k] = [*self._groups[k], *compare_dict[kk]]
self._groups.pop(kk)

# Rename groups
self._groups = {f"cg{i}": v for i, v in enumerate(self._groups.values())}

self._groups_checked = False

def _set_name(self, base: str) -> str:
name = base if self.prefix is None else self.prefix + base
name = name if self.postfix is None else name + self.postfix
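Once `_groups_checked` is `True`, `update` runs only the first member of each compute group and copies its state onto the remaining members via `setattr`, as in the diff above. A torch-free sketch of that copy step, using toy classes and hypothetical state names (`tp`, `fp`) in place of real metric states:

```python
# Illustrative sketch of state sharing within a compute group
# (toy classes; not the real torchmetrics code).
class ToyMetric:
    def __init__(self):
        self._defaults = {"tp": 0, "fp": 0}  # names of the shared state attributes
        self.tp = 0
        self.fp = 0

    def update(self, tp, fp):
        self.tp += tp
        self.fp += fp


group = {"precision": ToyMetric(), "recall": ToyMetric()}
members = list(group)

# Update only the first member of the group ...
leader = group[members[0]]
leader.update(tp=3, fp=1)

# ... then copy its state onto the remaining members, as
# MetricCollection.update does for an already-checked compute group.
for name in members[1:]:
    follower = group[name]
    for state in leader._defaults:
        setattr(follower, state, getattr(leader, state))

print(group["recall"].tp, group["recall"].fp)  # 3 1
```

The saving is that the (potentially expensive) `update` runs once per group instead of once per metric; each member's own `compute` then reads the shared state.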
7 changes: 5 additions & 2 deletions torchmetrics/image/__init__.py
@@ -11,18 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.image.inception import IS, InceptionScore # noqa: F401
from torchmetrics.image.kid import KID, KernelInceptionDistance # noqa: F401
from torchmetrics.image.psnr import PSNR, PeakSignalNoiseRatio # noqa: F401
from torchmetrics.image.ssim import ( # noqa: F401
SSIM,
MultiScaleStructuralSimilarityIndexMeasure,
StructuralSimilarityIndexMeasure,
)
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCH_FIDELITY_AVAILABLE
from torchmetrics.utilities.registry import register_compute_group

if _TORCH_FIDELITY_AVAILABLE:
from torchmetrics.image.fid import FID, FrechetInceptionDistance # noqa: F401
from torchmetrics.image.inception import IS, InceptionScore # noqa: F401
from torchmetrics.image.kid import KID, KernelInceptionDistance # noqa: F401

register_compute_group(FrechetInceptionDistance, KernelInceptionDistance) # todo: what

if _LPIPS_AVAILABLE:
from torchmetrics.image.lpip import LPIPS, LearnedPerceptualImagePatchSimilarity # noqa: F401
17 changes: 17 additions & 0 deletions torchmetrics/testscript.py
@@ -0,0 +1,17 @@
import torch

from torchmetrics import F1, Accuracy, ConfusionMatrix, MetricCollection, Recall

m = MetricCollection(
{
"acc": Accuracy(num_classes=5),
"acc2": Accuracy(num_classes=5),
"acc3": Accuracy(num_classes=5, average="macro"),
"f1": F1(num_classes=5),
"recall": Recall(num_classes=5),
"confmat": ConfusionMatrix(num_classes=5),
}
)
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))
m.update(preds, target)
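The `update` call in the script above also triggers the first-update consistency check: the states of each group's members are compared, and any member whose state diverged (here `acc3`, built with `average="macro"`, may differ from `acc`) is split into its own group. A sketch of that check without torch, with plain lists standing in for tensor states:

```python
# Sketch of the first-update check that splits compute groups
# (plain lists stand in for tensor states; hypothetical helper, not the PR's code).
def states_match(s1, s2):
    """Mimic the PR's check: types must match and values must be equal."""
    if type(s1) is not type(s2):
        return False
    return s1 == s2


groups = {"cg0": ["acc", "acc_macro"]}
states = {"acc": [3, 1], "acc_macro": [2, 2]}  # diverged after one update

n_groups = len(groups)
for key, members in list(groups.items()):
    leader = members[0]
    # Walk the remaining members in reverse so pops don't shift indices.
    for i, member in reversed(list(enumerate(members[1:]))):
        if not states_match(states[leader], states[member]):
            groups[f"cg{n_groups}"] = [member]  # member gets its own group
            groups[key].pop(i + 1)
            n_groups += 1

print(groups)  # {'cg0': ['acc'], 'cg1': ['acc_macro']}
```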
36 changes: 36 additions & 0 deletions torchmetrics/utilities/registry.py
@@ -0,0 +1,36 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.metric import Metric

# define compute groups for metric collection
_COMPUTE_GROUP_REGISTRY = []


def register_compute_group(*metrics):
"""Register a compute group of metrics.

A compute group consists of metrics that share the underlying metric state, meaning that only their
`compute` method should differ. Compute groups are used in connection with MetricCollection to
reduce the computational cost of metrics that share the same underlying metric state. Registered
compute groups can be found in the global variable `_COMPUTE_GROUP_REGISTRY`.

Args:
*metrics: An iterable of metrics
"""
for m in metrics:
if not issubclass(m, Metric):
raise ValueError(
f"Expected all metrics in compute group to be a subclass of `torchmetrics.Metric` but got {m}"
)
_COMPUTE_GROUP_REGISTRY.append(tuple(m.__name__ for m in metrics))