Smart update of metric collection #709

Merged
merged 51 commits on Feb 8, 2022
Commits
7a00a63
update
SkafteNicki Oct 19, 2021
501aef3
Merge branch 'master' into smart_collection
Borda Jan 20, 2022
2264b18
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 20, 2022
a911b5e
update docs
SkafteNicki Jan 23, 2022
f5eac5f
fix imports
SkafteNicki Jan 23, 2022
b2cb581
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 23, 2022
307f8ba
Apply suggestions from code review
SkafteNicki Jan 24, 2022
e788128
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2022
6e29cbf
somethings working
SkafteNicki Jan 24, 2022
f391fdd
better solution
SkafteNicki Jan 24, 2022
c7c05c1
revert registry
SkafteNicki Jan 24, 2022
7996ede
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2022
ae505fc
revert
SkafteNicki Jan 24, 2022
efd6a95
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Jan 24, 2022
9495da2
flake8
SkafteNicki Jan 24, 2022
e695ad0
Merge branch 'master' into smart_collection
SkafteNicki Jan 24, 2022
44cf2b7
improve testing
SkafteNicki Jan 25, 2022
9af0392
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Jan 25, 2022
6754a93
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2022
7b8bd72
accuracy fix
SkafteNicki Jan 25, 2022
dcec4d4
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Jan 25, 2022
c9c0f50
Apply suggestions from code review
SkafteNicki Jan 27, 2022
3ada9af
Merge branch 'master' into smart_collection
SkafteNicki Jan 27, 2022
e957fda
fix tests
SkafteNicki Jan 31, 2022
4ff45cb
enhancement
SkafteNicki Jan 31, 2022
3abd4aa
doctest
SkafteNicki Jan 31, 2022
0b17d8d
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Jan 31, 2022
8eda1b6
Merge branch 'master' into smart_collection
SkafteNicki Jan 31, 2022
0c97613
mypy
SkafteNicki Jan 31, 2022
0ffb39d
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Jan 31, 2022
3f30985
typo
Borda Feb 1, 2022
9d24233
optimize
SkafteNicki Feb 1, 2022
01b3dee
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Feb 1, 2022
8fad12c
Merge branch 'master' into smart_collection
Borda Feb 5, 2022
475f4b2
Merge branch 'master' into smart_collection
SkafteNicki Feb 8, 2022
3512128
docs
SkafteNicki Feb 8, 2022
6b524b9
Merge branches 'smart_collection' and 'smart_collection' of https://g…
SkafteNicki Feb 8, 2022
5e187ae
fix tests
SkafteNicki Feb 8, 2022
18f393f
Apply suggestions from code review
Borda Feb 8, 2022
85a3174
Update docs/source/pages/overview.rst
SkafteNicki Feb 8, 2022
5638156
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 8, 2022
09d0613
Apply suggestions from code review
SkafteNicki Feb 8, 2022
a966667
Update docs/source/pages/overview.rst
SkafteNicki Feb 8, 2022
c24e8ad
speed tests
SkafteNicki Feb 8, 2022
5e0d801
Merge branch 'master' into smart_collection
SkafteNicki Feb 8, 2022
ef26080
add speed tests
SkafteNicki Feb 8, 2022
7262cf5
Merge branch 'smart_collection' of https://github.com/PyTorchLightnin…
SkafteNicki Feb 8, 2022
a735762
flake8
SkafteNicki Feb 8, 2022
2c63299
fix docs
SkafteNicki Feb 8, 2022
035cea3
fix typing
SkafteNicki Feb 8, 2022
7c360e5
Apply suggestions from code review
SkafteNicki Feb 8, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new image metric `UniversalImageQualityIndex` ([#824](https://github.com/PyTorchLightning/metrics/pull/824))


- Added smart update of `MetricCollection` ([#709](https://github.com/PyTorchLightning/metrics/pull/709))


### Changed


11 changes: 11 additions & 0 deletions docs/source/pages/overview.rst
@@ -300,6 +300,17 @@ inside your LightningModule
have the same call signature. If this is not the case, input that should be
given to different metrics can be given as keyword arguments to the collection.

An additional advantage of using the ``MetricCollection`` object is that it will
automatically try to reduce the computations needed by finding groups of metrics
that share the same underlying metric state. If such a group of metrics is found, only one
of them is actually updated, and the updated state is broadcast to the rest
of the metrics within the group. In the example above, this leads to a 2x-3x lower computational
cost compared to disabling this feature. However, the speedup comes with a fixed upfront cost:
the state groups have to be determined after the first update. This overhead can exceed the gained
speedup for a very low number of steps (roughly up to 100), but it still results in an overall speedup
for anything beyond that. If the groups are known beforehand, they can also be set manually to avoid
the extra cost of the dynamic search. See the *compute_groups* argument in the class docs below for
more information on this topic.
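The state-sharing idea behind compute groups can be sketched in plain ``torch`` (the helper names below are illustrative, not part of the torchmetrics API): precision and recall both reduce to the same confusion-matrix state, so a single update of that shared state can serve both metrics.

```python
import torch


def update_confmat(confmat, preds, target, num_classes=3):
    # Accumulate one confusion matrix; this single state serves both metrics.
    idx = target * num_classes + preds
    confmat += torch.bincount(idx, minlength=num_classes**2).reshape(num_classes, num_classes)
    return confmat


def precision_from(confmat):
    # Macro precision: diagonal over column sums (predicted counts).
    tp = confmat.diag().float()
    return (tp / confmat.sum(dim=0).clamp(min=1)).mean()


def recall_from(confmat):
    # Macro recall: diagonal over row sums (target counts).
    tp = confmat.diag().float()
    return (tp / confmat.sum(dim=1).clamp(min=1)).mean()


confmat = torch.zeros(3, 3, dtype=torch.long)
preds = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 1, 2])
confmat = update_confmat(confmat, preds, target)
# One shared state, two derived results -- the essence of a compute group.
print(precision_from(confmat), recall_from(confmat))
```

In a ``MetricCollection`` with compute groups enabled, this is what happens behind the scenes: only one metric per group runs its ``update``, and the resulting state is shared with the others in the group before ``compute``.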

.. autoclass:: torchmetrics.MetricCollection
:noindex:

129 changes: 126 additions & 3 deletions tests/bases/test_collections.py
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import time
from copy import deepcopy

import pytest
import torch

from tests.helpers import seed_all
from tests.helpers.testers import DummyMetricDiff, DummyMetricSum
from torchmetrics import Metric
from torchmetrics.classification import Accuracy
from torchmetrics.collections import MetricCollection
from torchmetrics import Accuracy, CohenKappa, ConfusionMatrix, F1Score, Metric, MetricCollection, Precision, Recall

seed_all(42)

@@ -278,3 +278,126 @@ def compute(self):
    mc2 = MetricCollection([MyAccuracy(), DummyMetric()])
    mc(torch.tensor([0, 1]), torch.tensor([0, 1]), kwarg="kwarg")
    mc2(torch.tensor([0, 1]), torch.tensor([0, 1]), kwarg="kwarg", kwarg2="kwarg2")


@pytest.mark.parametrize(
    "metrics, expected",
    [
        # a single metric forms its own compute group
        (Accuracy(3), {0: ["Accuracy"]}),
        # two metrics of the same class form a compute group
        ({"acc0": Accuracy(3), "acc1": Accuracy(3)}, {0: ["acc0", "acc1"]}),
        # two metrics from the registry form a compute group
        ([Precision(3), Recall(3)], {0: ["Precision", "Recall"]}),
        # two metrics from different classes give two compute groups
        ([ConfusionMatrix(3), Recall(3)], {0: ["ConfusionMatrix"], 1: ["Recall"]}),
        # multiple groups with multiple metrics
        (
            [ConfusionMatrix(3), CohenKappa(3), Recall(3), Precision(3)],
            {0: ["ConfusionMatrix", "CohenKappa"], 1: ["Recall", "Precision"]},
        ),
        # complex example
        (
            {
                "acc": Accuracy(3),
                "acc2": Accuracy(3),
                "acc3": Accuracy(num_classes=3, average="macro"),
                "f1": F1Score(3),
                "recall": Recall(3),
                "confmat": ConfusionMatrix(3),
            },
            {0: ["acc", "acc2", "f1", "recall"], 1: ["acc3"], 2: ["confmat"]},
        ),
    ],
)
def test_check_compute_groups(metrics, expected):
    """Check that compute groups are formed after initialization."""
    m = MetricCollection(deepcopy(metrics), compute_groups=True)
    # construct without compute groups for comparison
    m2 = MetricCollection(deepcopy(metrics), compute_groups=False)

    assert len(m.compute_groups) == len(m)
    assert m2.compute_groups == {}

    preds = torch.randn(10, 3).softmax(dim=-1)
    target = torch.randint(3, (10,))
    m.update(preds, target)
    m2.update(preds, target)

    assert m.compute_groups == expected
    assert m2.compute_groups == {}

    preds = torch.randn(10, 3).softmax(dim=-1)
    target = torch.randint(3, (10,))
    # compute groups should kick in here
    m.update(preds, target)
    m2.update(preds, target)

    # compare results for correctness
    res_cg = m.compute()
    res_without_cg = m2.compute()
    for key in res_cg.keys():
        assert torch.allclose(res_cg[key], res_without_cg[key])


@pytest.mark.parametrize(
    "metrics",
    [
        {"acc0": Accuracy(3), "acc1": Accuracy(3)},
        [Precision(3), Recall(3)],
        [ConfusionMatrix(3), CohenKappa(3), Recall(3), Precision(3)],
        {
            "acc": Accuracy(3),
            "acc2": Accuracy(3),
            "acc3": Accuracy(num_classes=3, average="macro"),
            "f1": F1Score(3),
            "recall": Recall(3),
            "confmat": ConfusionMatrix(3),
        },
    ],
)
@pytest.mark.parametrize("steps", [100, 1000])
def test_check_compute_groups_is_faster(metrics, steps):
    """Check that updating with compute groups is faster than without them."""
    m = MetricCollection(deepcopy(metrics), compute_groups=True)
    # construct without compute groups for comparison
    m2 = MetricCollection(deepcopy(metrics), compute_groups=False)

    preds = torch.randn(10, 3).softmax(dim=-1)
    target = torch.randint(3, (10,))

    start = time.time()
    for _ in range(steps):
        m.update(preds, target)
    time_cg = time.time() - start

    start = time.time()
    for _ in range(steps):
        m2.update(preds, target)
    time_no_cg = time.time() - start

    assert time_cg < time_no_cg, "using compute groups was not faster"


def test_compute_group_define_by_user():
    """Check that the user can provide the compute groups."""
    m = MetricCollection(
        ConfusionMatrix(3), Recall(3), Precision(3), compute_groups=[["ConfusionMatrix"], ["Recall", "Precision"]]
    )

    # check that the groups are not going to be checked in the first update
    assert m._groups_checked
    assert m.compute_groups == {0: ["ConfusionMatrix"], 1: ["Recall", "Precision"]}

    preds = torch.randn(10, 3).softmax(dim=-1)
    target = torch.randint(3, (10,))
    m.update(preds, target)
    assert m.compute()


def test_error_on_wrong_specified_compute_groups():
    """Test that an error is raised if the user mis-specifies the compute groups."""
    with pytest.raises(ValueError, match="Input Accuracy in `compute_groups`.*"):
        MetricCollection(
            ConfusionMatrix(3), Recall(3), Precision(3), compute_groups=[["ConfusionMatrix"], ["Recall", "Accuracy"]]
        )
7 changes: 4 additions & 3 deletions torchmetrics/classification/accuracy.py
@@ -206,9 +206,6 @@ def __init__(
            dist_sync_fn=dist_sync_fn,
        )

        self.add_state("correct", default=tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

        if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
            raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}")

@@ -219,6 +216,10 @@
        self.mode: DataType = None  # type: ignore
        self.multiclass = multiclass

        if self.subset_accuracy:
            self.add_state("correct", default=tensor(0), dist_reduce_fx="sum")
            self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Update state with predictions and targets. See
        :ref:`references/modules:input types` for more information on input
2 changes: 1 addition & 1 deletion torchmetrics/classification/cohen_kappa.py
@@ -107,7 +107,7 @@ def __init__(
        if self.weights not in allowed_weights:
            raise ValueError(f"Argument weights needs to one of the following: {allowed_weights}")

        self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")
        self.add_state("confmat", default=torch.zeros(num_classes, num_classes, dtype=torch.long), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Update state with predictions and targets.