Skip to content

Commit

Permalink
simplify model builders (#5001)
Browse files Browse the repository at this point in the history
* simplify model builders

* cleanup

* refactor kwonly to pos or kw handling

* put weight verification back

* revert num categories checks

* fix default weights

* cleanup

* remove manual parameter map

* refactor decorator interface

* address review comments

* cleanup

* refactor callable default

* fix type annotation

* process ungrouped models

* cleanup

* mroe cleanup

* use decorator for detection models

* add decorator for quantization models

* add decorator for segmentation  models

* add decorator for video  models

* remove old helpers

* fix resnet50

* Adding verification back on InceptionV3

* Add kwargs in DeeplabeV3

* Add kwargs on FCN

* Fix typing on Deeplab

* Fix typing on FCN

Co-authored-by: Vasilis Vryniotis <[email protected]>
  • Loading branch information
pmeier and datumbox authored Dec 7, 2021
1 parent 3ceaff1 commit 588e9b5
Show file tree
Hide file tree
Showing 33 changed files with 579 additions and 644 deletions.
86 changes: 86 additions & 0 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch
from common_utils import cpu_and_gpu, run_on_env_var
from torchvision.prototype import models
from torchvision.prototype.models._api import WeightsEnum, Weights
from torchvision.prototype.models._utils import handle_legacy_interface

run_if_test_with_prototype = run_on_env_var(
"PYTORCH_TEST_WITH_PROTOTYPE",
Expand Down Expand Up @@ -164,3 +166,87 @@ def test_old_vs_new_factory(model_fn, dev):

def test_smoke():
import torchvision.prototype.models # noqa: F401


# With this filter, every unexpected warning will be turned into an error
@pytest.mark.filterwarnings("error")
class TestHandleLegacyInterface:
class TestWeights(WeightsEnum):
Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())

@pytest.mark.parametrize(
"kwargs",
[
pytest.param(dict(), id="empty"),
pytest.param(dict(weights=None), id="None"),
pytest.param(dict(weights=TestWeights.Sentinel), id="Weights"),
],
)
def test_no_warn(self, kwargs):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass

builder(**kwargs)

@pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_pos(self, pretrained):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass

with pytest.warns(UserWarning, match="positional"):
builder(pretrained)

@pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_kw(self, pretrained):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass

with pytest.warns(UserWarning, match="deprecated"):
builder(pretrained)

@pytest.mark.parametrize("pretrained", (True, False))
@pytest.mark.parametrize("positional", (True, False))
def test_equivalent_behavior_weights(self, pretrained, positional):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
def builder(*, weights=None):
pass

args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
with pytest.warns(UserWarning, match=f"weights={self.TestWeights.Sentinel if pretrained else None}"):
builder(*args, **kwargs)

def test_multi_params(self):
weights_params = ("weights", "weights_other")
pretrained_params = [param.replace("weights", "pretrained") for param in weights_params]

@handle_legacy_interface(
**{
weights_param: (pretrained_param, self.TestWeights.Sentinel)
for weights_param, pretrained_param in zip(weights_params, pretrained_params)
}
)
def builder(*, weights=None, weights_other=None):
pass

for pretrained_param in pretrained_params:
with pytest.warns(UserWarning, match="deprecated"):
builder(**{pretrained_param: True})

def test_default_callable(self):
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: self.TestWeights.Sentinel if kwargs["flag"] else None,
)
)
def builder(*, weights=None, flag):
pass

with pytest.warns(UserWarning, match="deprecated"):
builder(pretrained=True, flag=True)

with pytest.raises(ValueError, match="weights"):
builder(pretrained=True, flag=False)
101 changes: 82 additions & 19 deletions torchvision/prototype/models/_utils.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,95 @@
import functools
import warnings
from typing import Any, Dict, Optional, TypeVar
from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union

from ._api import WeightsEnum
from torch import nn
from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw

from ._api import WeightsEnum

W = TypeVar("W", bound=WeightsEnum)
M = TypeVar("M", bound=nn.Module)
V = TypeVar("V")


def _deprecated_param(
kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: Optional[W]
) -> Optional[W]:
warnings.warn(f"The parameter '{deprecated_param}' is deprecated, please use '{new_param}' instead.")
if kwargs.pop(deprecated_param):
if default_value is not None:
return default_value
else:
raise ValueError("No checkpoint is available for model.")
else:
return None
def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
"""Decorates a model builder with the new interface to make it compatible with the old.
In particular this handles two things:
1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
:func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details.
2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
``weights=Weights`` and emits a deprecation warning with instructions for the new interface.
Args:
**weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter
name and default value for the legacy ``pretrained=True``. The default value can be a callable in which
case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in
the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters
should be accessed with :meth:`~dict.get`.
"""

def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
@kwonly_to_pos_or_kw
@functools.wraps(builder)
def inner_wrapper(*args: Any, **kwargs: Any) -> M:
for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr]
# If neither the weights nor the pretrained parameter as passed, or the weights argument already use
# the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the
# weight argument, since it is a valid value.
sentinel = object()
weights_arg = kwargs.get(weights_param, sentinel)
if (
(weights_param not in kwargs and pretrained_param not in kwargs)
or isinstance(weights_arg, WeightsEnum)
or weights_arg is None
):
continue

# If the pretrained parameter was passed as positional argument, it is now mapped to
# `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current
# signature to infer the names of positionally passed arguments and thus has no knowledge that there
# used to be a pretrained parameter.
pretrained_positional = weights_arg is not sentinel
if pretrained_positional:
# We put the pretrained argument under its legacy name in the keyword argument dictionary to have a
# unified access to the value if the default value is a callable.
kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
else:
pretrained_arg = kwargs[pretrained_param]

if pretrained_arg:
default_weights_arg = default(kwargs) if callable(default) else default
if not isinstance(default_weights_arg, WeightsEnum):
raise ValueError(f"No weights available for model {builder.__name__}")
else:
default_weights_arg = None

if not pretrained_positional:
warnings.warn(
f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead."
)

msg = (
f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. "
f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
)
if pretrained_arg:
msg = (
f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.default` "
f"to get the most up-to-date weights."
)
warnings.warn(msg)

del kwargs[pretrained_param]
kwargs[weights_param] = default_weights_arg

return builder(*args, **kwargs)

return inner_wrapper

def _deprecated_positional(kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: V) -> None:
warnings.warn(
f"The positional parameter '{deprecated_param}' is deprecated, please use keyword parameter '{new_param}'"
+ " instead."
)
kwargs[deprecated_param] = default_value
return outer_wrapper


def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
Expand Down
9 changes: 3 additions & 6 deletions torchvision/prototype/models/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ...models.alexnet import AlexNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
Expand All @@ -29,11 +29,8 @@ class AlexNet_Weights(WeightsEnum):
default = ImageNet1K_V1


def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNet_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.ImageNet1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
weights = AlexNet_Weights.verify(weights)

if weights is not None:
Expand Down
30 changes: 9 additions & 21 deletions torchvision/prototype/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ...models.densenet import DenseNet
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param
from ._utils import handle_legacy_interface, _ovewrite_named_param


__all__ = [
Expand Down Expand Up @@ -123,41 +123,29 @@ class DenseNet201_Weights(WeightsEnum):
default = ImageNet1K_V1


def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.ImageNet1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet121_Weights.verify(weights)

return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)


def densenet161(weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.ImageNet1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet161_Weights.verify(weights)

return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)


def densenet169(weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.ImageNet1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet169_Weights.verify(weights)

return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)


def densenet201(weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201_Weights.ImageNet1K_V1)
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.ImageNet1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet201_Weights.verify(weights)

return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
Loading

0 comments on commit 588e9b5

Please sign in to comment.