
simplify model builders #5001

Merged
36 commits merged into main from prototype-model-cleanup on Dec 7, 2021

Changes from 32 commits

Commits (36)
98444db
simplify model builders
pmeier Nov 29, 2021
ba95639
cleanup
pmeier Nov 29, 2021
9185309
refactor kwonly to pos or kw handling
pmeier Nov 30, 2021
c893734
put weight verification back
pmeier Nov 30, 2021
cd9e5b9
revert num categories checks
pmeier Nov 30, 2021
96c9e8e
Merge branch 'main' into prototype-model-cleanup
pmeier Nov 30, 2021
ae56d66
fix default weights
pmeier Nov 30, 2021
7a2fd53
cleanup
pmeier Nov 30, 2021
0134267
remove manual parameter map
pmeier Nov 30, 2021
9cde2b7
refactor decorator interface
pmeier Nov 30, 2021
ff9134b
Merge branch 'main' into prototype-model-cleanup
pmeier Nov 30, 2021
c14117a
address review comments
pmeier Dec 2, 2021
71255cb
cleanup
pmeier Dec 2, 2021
a49c4cf
Merge branch 'main' into prototype-model-cleanup
pmeier Dec 2, 2021
4f1c55e
Merge branch 'prototype-model-cleanup' of https://github.com/pmeier/v…
pmeier Dec 2, 2021
ef675f3
refactor callable default
pmeier Dec 3, 2021
54a0d08
Merge branch 'main' into prototype-model-cleanup
pmeier Dec 3, 2021
226c498
Merge branch 'main' into prototype-model-cleanup
pmeier Dec 6, 2021
bb069f4
fix type annotation
pmeier Dec 6, 2021
be6e8e7
process ungrouped models
pmeier Dec 6, 2021
3d4494a
cleanup
pmeier Dec 6, 2021
eff0011
more cleanup
pmeier Dec 6, 2021
11ec91e
use decorator for detection models
pmeier Dec 6, 2021
d698709
add decorator for quantization models
pmeier Dec 6, 2021
9c5e0fe
add decorator for segmentation models
pmeier Dec 6, 2021
bfd0aac
add decorator for video models
pmeier Dec 6, 2021
76e58e0
remove old helpers
pmeier Dec 6, 2021
6efefd1
Merge branch 'main' into prototype-model-cleanup
pmeier Dec 6, 2021
b06df00
fix resnet50
pmeier Dec 6, 2021
7660f3a
Adding verification back on InceptionV3
datumbox Dec 6, 2021
1c54c74
Add kwargs in DeepLabV3
datumbox Dec 6, 2021
31f8a22
Add kwargs on FCN
datumbox Dec 6, 2021
f51dc7f
Merge branch 'main' into prototype-model-cleanup
datumbox Dec 6, 2021
dbf6419
Fix typing on Deeplab
datumbox Dec 7, 2021
65d5b5d
Fix typing on FCN
datumbox Dec 7, 2021
bc358bc
Merge branch 'main' into prototype-model-cleanup
datumbox Dec 7, 2021
86 changes: 86 additions & 0 deletions test/test_prototype_models.py
@@ -5,6 +5,8 @@
import torch
from common_utils import cpu_and_gpu, run_on_env_var
from torchvision.prototype import models
from torchvision.prototype.models._api import WeightsEnum, Weights
from torchvision.prototype.models._utils import handle_legacy_interface

run_if_test_with_prototype = run_on_env_var(
"PYTORCH_TEST_WITH_PROTOTYPE",
@@ -164,3 +166,87 @@ def test_old_vs_new_factory(model_fn, dev):

def test_smoke():
import torchvision.prototype.models # noqa: F401


# With this filter, every unexpected warning will be turned into an error
@pytest.mark.filterwarnings("error")
class TestHandleLegacyInterface:
class TestWeights(WeightsEnum):
Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())

@pytest.mark.parametrize(
"kwargs",
[
pytest.param(dict(), id="empty"),
pytest.param(dict(weights=None), id="None"),
pytest.param(dict(weights=TestWeights.Sentinel), id="Weights"),
],
)
def test_no_warn(self, kwargs):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass

builder(**kwargs)

@pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_pos(self, pretrained):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass

with pytest.warns(UserWarning, match="positional"):
builder(pretrained)

@pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_kw(self, pretrained):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass

with pytest.warns(UserWarning, match="deprecated"):
builder(pretrained)

@pytest.mark.parametrize("pretrained", (True, False))
@pytest.mark.parametrize("positional", (True, False))
def test_equivalent_behavior_weights(self, pretrained, positional):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass

args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
with pytest.warns(UserWarning, match=f"weights={self.TestWeights.Sentinel if pretrained else None}"):
builder(*args, **kwargs)

def test_multi_params(self):
weights_params = ("weights", "weights_other")
pretrained_params = [param.replace("weights", "pretrained") for param in weights_params]

@handle_legacy_interface(
**{
weights_param: (pretrained_param, self.TestWeights.Sentinel)
for weights_param, pretrained_param in zip(weights_params, pretrained_params)
}
)
def builder(*, weights=None, weights_other=None):
pass

for pretrained_param in pretrained_params:
with pytest.warns(UserWarning, match="deprecated"):
builder(**{pretrained_param: True})

def test_default_callable(self):
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: self.TestWeights.Sentinel if kwargs["flag"] else None,
)
)
def builder(*, weights=None, flag):
pass

with pytest.warns(UserWarning, match="deprecated"):
builder(pretrained=True, flag=True)

with pytest.raises(ValueError, match="weights"):
builder(pretrained=True, flag=False)
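
Taken together, these tests pin down the caller-facing contract of the decorator. A condensed sketch of that contract, reusing the TestWeights fixture from above (the warning texts are paraphrased from the match patterns, not verbatim):

@handle_legacy_interface(weights=("pretrained", TestWeights.Sentinel))
def builder(*, weights=None):
    pass

builder()                              # no warning: new-style call without weights
builder(weights=None)                  # no warning: None is a valid new-style value
builder(weights=TestWeights.Sentinel)  # no warning: new-style enum value
builder(True)                          # warns about positional use, then behaves like pretrained=True
builder(pretrained=True)               # warns that 'pretrained' is deprecated; becomes weights=TestWeights.Sentinel
builder(pretrained=False)              # warns likewise; becomes weights=None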
101 changes: 82 additions & 19 deletions torchvision/prototype/models/_utils.py
@@ -1,32 +1,95 @@
import functools
import warnings
from typing import Any, Dict, Optional, TypeVar
from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union

from ._api import WeightsEnum
from torch import nn
from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw

from ._api import WeightsEnum

W = TypeVar("W", bound=WeightsEnum)
M = TypeVar("M", bound=nn.Module)
V = TypeVar("V")


def _deprecated_param(
kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: Optional[W]
) -> Optional[W]:
warnings.warn(f"The parameter '{deprecated_param}' is deprecated, please use '{new_param}' instead.")
if kwargs.pop(deprecated_param):
if default_value is not None:
return default_value
else:
raise ValueError("No checkpoint is available for model.")
else:
return None
def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
"""Decorates a model builder with the new interface to make it compatible with the old.
In particular this handles two things:
1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
:func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details.
2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
``weights=Weights`` and emits a deprecation warning with instructions for the new interface.
Args:
**weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter
name and default value for the legacy ``pretrained=True``. The default value can be a callable in which
case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in
the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters
should be accessed with :meth:`~dict.get`.
"""

def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
@kwonly_to_pos_or_kw
@functools.wraps(builder)
def inner_wrapper(*args: Any, **kwargs: Any) -> M:
for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr]
# If neither the weights nor the pretrained parameter was passed, or the weights argument already uses
# the new-style arguments, there is nothing to do. Note that we cannot use `None` as a sentinel for the
# weights argument, since it is a valid value.
sentinel = object()
weights_arg = kwargs.get(weights_param, sentinel)
if (
(weights_param not in kwargs and pretrained_param not in kwargs)
or isinstance(weights_arg, WeightsEnum)
or weights_arg is None
):
continue

# If the pretrained parameter was passed as a positional argument, it is now mapped to
# `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current
# signature to infer the names of positionally passed arguments and thus has no knowledge that there
# used to be a pretrained parameter.
pretrained_positional = weights_arg is not sentinel
if pretrained_positional:
# We put the pretrained argument under its legacy name in the keyword argument dictionary to have
# unified access to the value if the default value is a callable.
kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
else:
pretrained_arg = kwargs[pretrained_param]

if pretrained_arg:
default_weights_arg = default(kwargs) if callable(default) else default
if not isinstance(default_weights_arg, WeightsEnum):
raise ValueError(f"No weights available for model {builder.__name__}")
else:
default_weights_arg = None

if not pretrained_positional:
warnings.warn(
f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead."
)

msg = (
f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. "
f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
)
if pretrained_arg:
msg = (
f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.default` "
f"to get the most up-to-date weights."
)
warnings.warn(msg)

del kwargs[pretrained_param]
kwargs[weights_param] = default_weights_arg

return builder(*args, **kwargs)

return inner_wrapper

def _deprecated_positional(kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: V) -> None:
warnings.warn(
f"The positional parameter '{deprecated_param}' is deprecated, please use keyword parameter '{new_param}'"
+ " instead."
)
kwargs[deprecated_param] = default_value
return outer_wrapper


def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
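
The decorator leans on kwonly_to_pos_or_kw from torchvision.prototype.utils._internal, which is not part of this diff. For orientation, a minimal sketch of the behavior the docstring describes — an illustrative assumption, not the actual implementation, simplified to the case where every parameter of the builder is keyword-only (as in the builders below):

import functools
import inspect
import warnings

def kwonly_to_pos_or_kw(fn):
    # names of the keyword-only parameters, in declaration order
    kwonly_params = [
        name
        for name, param in inspect.signature(fn).parameters.items()
        if param.kind == param.KEYWORD_ONLY
    ]

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if args:
            # re-admit positional use with a deprecation warning by zipping the
            # positional values onto the keyword-only parameter names
            warnings.warn("Using positional arguments is deprecated; please pass them as keyword arguments instead.")
            kwargs.update(zip(kwonly_params, args))
        return fn(**kwargs)

    return wrapper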
9 changes: 3 additions & 6 deletions torchvision/prototype/models/alexnet.py
@@ -7,7 +7,7 @@
from ...models.alexnet import AlexNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
@@ -29,11 +29,8 @@ class AlexNet_Weights(WeightsEnum):
default = ImageNet1K_V1


def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNet_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.ImageNet1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
weights = AlexNet_Weights.verify(weights)

if weights is not None:
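
With the decorator in place, the legacy and new call styles compare as follows (a sketch; warnings paraphrased, and the mapping to AlexNet_Weights.ImageNet1K_V1 comes from the decorator tuple above):

from torchvision.prototype.models import alexnet, AlexNet_Weights

alexnet(weights=AlexNet_Weights.ImageNet1K_V1)  # new interface, no warning
alexnet(weights=None)                           # new interface, no pretrained weights
alexnet(pretrained=True)   # warns; equivalent to weights=AlexNet_Weights.ImageNet1K_V1
alexnet(pretrained=False)  # warns; equivalent to weights=None
alexnet(True)              # additionally warns about positional use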
30 changes: 9 additions & 21 deletions torchvision/prototype/models/densenet.py
@@ -9,7 +9,7 @@
from ...models.densenet import DenseNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = [
@@ -123,41 +123,29 @@ class DenseNet201_Weights(WeightsEnum):
default = ImageNet1K_V1


def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.ImageNet1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet121_Weights.verify(weights)

return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)


def densenet161(weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.ImageNet1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet161_Weights.verify(weights)

return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)


def densenet169(weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.ImageNet1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet169_Weights.verify(weights)

return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)


def densenet201(weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.ImageNet1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet201_Weights.verify(weights)

return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
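
The builders in this diff all use a constant default, but the callable form documented in handle_legacy_interface (and exercised by test_default_callable above) covers cases where the legacy pretrained=True default depends on another argument. A hypothetical sketch — the builder name and weight enums below are illustrative only, not part of this PR:

# hypothetical quantization-style builder; only 'pretrained' is guaranteed to be
# in the kwargs dict, so other parameters are read with .get() as the docstring advises
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: MyQuantizedWeights.ImageNet1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else MyWeights.ImageNet1K_V1,
    )
)
def my_model(*, weights=None, progress=True, quantize=False, **kwargs):
    ...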