diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 1dc883528ef..3286785f60a 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -5,6 +5,8 @@ import torch from common_utils import cpu_and_gpu, run_on_env_var from torchvision.prototype import models +from torchvision.prototype.models._api import WeightsEnum, Weights +from torchvision.prototype.models._utils import handle_legacy_interface run_if_test_with_prototype = run_on_env_var( "PYTORCH_TEST_WITH_PROTOTYPE", @@ -164,3 +166,87 @@ def test_old_vs_new_factory(model_fn, dev): def test_smoke(): import torchvision.prototype.models # noqa: F401 + + +# With this filter, every unexpected warning will be turned into an error +@pytest.mark.filterwarnings("error") +class TestHandleLegacyInterface: + class TestWeights(WeightsEnum): + Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict()) + + @pytest.mark.parametrize( + "kwargs", + [ + pytest.param(dict(), id="empty"), + pytest.param(dict(weights=None), id="None"), + pytest.param(dict(weights=TestWeights.Sentinel), id="Weights"), + ], + ) + def test_no_warn(self, kwargs): + @handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel)) + def builder(*, weights=None): + pass + + builder(**kwargs) + + @pytest.mark.parametrize("pretrained", (True, False)) + def test_pretrained_pos(self, pretrained): + @handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel)) + def builder(*, weights=None): + pass + + with pytest.warns(UserWarning, match="positional"): + builder(pretrained) + + @pytest.mark.parametrize("pretrained", (True, False)) + def test_pretrained_kw(self, pretrained): + @handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel)) + def builder(*, weights=None): + pass + + with pytest.warns(UserWarning, match="deprecated"): + builder(pretrained) + + @pytest.mark.parametrize("pretrained", (True, False)) + @pytest.mark.parametrize("positional", (True, False)) + def test_equivalent_behavior_weights(self, pretrained, positional): + @handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel)) + def builder(*, weights=None): + pass + + args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained)) + with pytest.warns(UserWarning, match=f"weights={self.TestWeights.Sentinel if pretrained else None}"): + builder(*args, **kwargs) + + def test_multi_params(self): + weights_params = ("weights", "weights_other") + pretrained_params = [param.replace("weights", "pretrained") for param in weights_params] + + @handle_legacy_interface( + **{ + weights_param: (pretrained_param, self.TestWeights.Sentinel) + for weights_param, pretrained_param in zip(weights_params, pretrained_params) + } + ) + def builder(*, weights=None, weights_other=None): + pass + + for pretrained_param in pretrained_params: + with pytest.warns(UserWarning, match="deprecated"): + builder(**{pretrained_param: True}) + + def test_default_callable(self): + @handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: self.TestWeights.Sentinel if kwargs["flag"] else None, + ) + ) + def builder(*, weights=None, flag): + pass + + with pytest.warns(UserWarning, match="deprecated"): + builder(pretrained=True, flag=True) + + with pytest.raises(ValueError, match="weights"): + builder(pretrained=True, flag=False) diff --git a/torchvision/prototype/models/_utils.py b/torchvision/prototype/models/_utils.py index e2ee9034953..6286d7b19b1 100644 --- a/torchvision/prototype/models/_utils.py +++ b/torchvision/prototype/models/_utils.py @@ -1,32 +1,95 @@ +import functools import warnings -from typing import Any, Dict, Optional, TypeVar +from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union -from ._api import WeightsEnum +from torch import nn +from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw +from ._api import WeightsEnum W = TypeVar("W", bound=WeightsEnum) +M = TypeVar("M", bound=nn.Module) V = TypeVar("V") -def _deprecated_param( - kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: Optional[W] -) -> Optional[W]: - warnings.warn(f"The parameter '{deprecated_param}' is deprecated, please use '{new_param}' instead.") - if kwargs.pop(deprecated_param): - if default_value is not None: - return default_value - else: - raise ValueError("No checkpoint is available for model.") - else: - return None +def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): + """Decorates a model builder with the new interface to make it compatible with the old. + + In particular this handles two things: + + 1. Allows positional parameters again, but emits a deprecation warning in case they are used. See + :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details. + 2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to + ``weights=Weights`` and emits a deprecation warning with instructions for the new interface. + + Args: + **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter + name and default value for the legacy ``pretrained=True``. The default value can be a callable in which + case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in + the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters + should be accessed with :meth:`~dict.get`. + """ + + def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]: + @kwonly_to_pos_or_kw + @functools.wraps(builder) + def inner_wrapper(*args: Any, **kwargs: Any) -> M: + for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr] + # If neither the weights nor the pretrained parameter as passed, or the weights argument already use + # the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the + # weight argument, since it is a valid value. + sentinel = object() + weights_arg = kwargs.get(weights_param, sentinel) + if ( + (weights_param not in kwargs and pretrained_param not in kwargs) + or isinstance(weights_arg, WeightsEnum) + or weights_arg is None + ): + continue + + # If the pretrained parameter was passed as positional argument, it is now mapped to + # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current + # signature to infer the names of positionally passed arguments and thus has no knowledge that there + # used to be a pretrained parameter. + pretrained_positional = weights_arg is not sentinel + if pretrained_positional: + # We put the pretrained argument under its legacy name in the keyword argument dictionary to have a + # unified access to the value if the default value is a callable. + kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param) + else: + pretrained_arg = kwargs[pretrained_param] + + if pretrained_arg: + default_weights_arg = default(kwargs) if callable(default) else default + if not isinstance(default_weights_arg, WeightsEnum): + raise ValueError(f"No weights available for model {builder.__name__}") + else: + default_weights_arg = None + + if not pretrained_positional: + warnings.warn( + f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead." + ) + + msg = ( + f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. " + f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`." + ) + if pretrained_arg: + msg = ( + f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.default` " + f"to get the most up-to-date weights." + ) + warnings.warn(msg) + + del kwargs[pretrained_param] + kwargs[weights_param] = default_weights_arg + + return builder(*args, **kwargs) + return inner_wrapper -def _deprecated_positional(kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: V) -> None: - warnings.warn( - f"The positional parameter '{deprecated_param}' is deprecated, please use keyword parameter '{new_param}'" - + " instead." - ) - kwargs[deprecated_param] = default_value + return outer_wrapper def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None: diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py index 28b0fa60504..623aa8c3a01 100644 --- a/torchvision/prototype/models/alexnet.py +++ b/torchvision/prototype/models/alexnet.py @@ -7,7 +7,7 @@ from ...models.alexnet import AlexNet from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] @@ -29,11 +29,8 @@ class AlexNet_Weights(WeightsEnum): default = ImageNet1K_V1 -def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNet_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.ImageNet1K_V1)) +def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: weights = AlexNet_Weights.verify(weights) if weights is not None: diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py index b8abbdde947..f25ad2f64c2 100644 --- a/torchvision/prototype/models/densenet.py +++ b/torchvision/prototype/models/densenet.py @@ -9,7 +9,7 @@ from ...models.densenet import DenseNet from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ @@ -123,41 +123,29 @@ class DenseNet201_Weights(WeightsEnum): default = ImageNet1K_V1 -def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.ImageNet1K_V1)) +def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: weights = DenseNet121_Weights.verify(weights) return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) -def densenet161(weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.ImageNet1K_V1)) +def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: weights = DenseNet161_Weights.verify(weights) return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) -def densenet169(weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.ImageNet1K_V1)) +def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: weights = DenseNet169_Weights.verify(weights) return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) -def densenet201(weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.ImageNet1K_V1)) +def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: weights = DenseNet201_Weights.verify(weights) return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index 1f5c6461698..ea462ee0758 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -14,7 +14,7 @@ ) from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param +from .._utils import handle_legacy_interface, _ovewrite_value_param from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large from ..resnet import ResNet50_Weights, resnet50 @@ -75,7 +75,12 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): default = Coco_V1 +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.Coco_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1), +) def fasterrcnn_resnet50_fpn( + *, weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -83,17 +88,7 @@ def fasterrcnn_resnet50_fpn( trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_ResNet50_FPN_Weights.Coco_V1) weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 - ) weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: @@ -119,6 +114,7 @@ def fasterrcnn_resnet50_fpn( def _fasterrcnn_mobilenet_v3_large_fpn( + *, weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]], progress: bool, num_classes: Optional[int], @@ -158,7 +154,12 @@ def _fasterrcnn_mobilenet_v3_large_fpn( return model +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1), +) def fasterrcnn_mobilenet_v3_large_fpn( + *, weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -166,17 +167,7 @@ def fasterrcnn_mobilenet_v3_large_fpn( trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1) weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 - ) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) defaults = { @@ -185,16 +176,21 @@ def fasterrcnn_mobilenet_v3_large_fpn( kwargs = {**defaults, **kwargs} return _fasterrcnn_mobilenet_v3_large_fpn( - weights, - progress, - num_classes, - weights_backbone, - trainable_backbone_layers, + weights=weights, + progress=progress, + num_classes=num_classes, + weights_backbone=weights_backbone, + trainable_backbone_layers=trainable_backbone_layers, **kwargs, ) +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1), +) def fasterrcnn_mobilenet_v3_large_320_fpn( + *, weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -202,19 +198,8 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param( - kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1 - ) + weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 - ) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) defaults = { @@ -227,10 +212,10 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( kwargs = {**defaults, **kwargs} return _fasterrcnn_mobilenet_v3_large_fpn( - weights, - progress, - num_classes, - weights_backbone, - trainable_backbone_layers, + weights=weights, + progress=progress, + num_classes=num_classes, + weights_backbone=weights_backbone, + trainable_backbone_layers=trainable_backbone_layers, **kwargs, ) diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index a811999681d..b5e8e8267ff 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -11,7 +11,7 @@ ) from .._api import WeightsEnum, Weights from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param +from .._utils import handle_legacy_interface, _ovewrite_value_param from ..resnet import ResNet50_Weights, resnet50 @@ -49,7 +49,17 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): default = Coco_V1 +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy + if kwargs["pretrained"] == "legacy" + else KeypointRCNN_ResNet50_FPN_Weights.Coco_V1, + ), + weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1), +) def keypointrcnn_resnet50_fpn( + *, weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -58,21 +68,7 @@ def keypointrcnn_resnet50_fpn( trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> KeypointRCNN: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_V1 - if kwargs["pretrained"] == "legacy": - default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy - kwargs["pretrained"] = True - weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 - ) weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index 4eb285fac0d..9dbf14249c7 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -12,7 +12,7 @@ ) from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param +from .._utils import handle_legacy_interface, _ovewrite_value_param from ..resnet import ResNet50_Weights, resnet50 @@ -38,7 +38,12 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): default = Coco_V1 +@handle_legacy_interface( + weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.Coco_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1), +) def maskrcnn_resnet50_fpn( + *, weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -46,17 +51,7 @@ def maskrcnn_resnet50_fpn( trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> MaskRCNN: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNN_ResNet50_FPN_Weights.Coco_V1) weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 - ) weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index 799bc21c379..234f80dace2 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -13,7 +13,7 @@ ) from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param +from .._utils import handle_legacy_interface, _ovewrite_value_param from ..resnet import ResNet50_Weights, resnet50 @@ -38,7 +38,12 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): default = Coco_V1 +@handle_legacy_interface( + weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.Coco_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1), +) def retinanet_resnet50_fpn( + *, weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -46,17 +51,7 @@ def retinanet_resnet50_fpn( trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> RetinaNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNet_ResNet50_FPN_Weights.Coco_V1) weights = RetinaNet_ResNet50_FPN_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 - ) weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py index f57b47c00d6..53d3a996d9f 100644 --- a/torchvision/prototype/models/detection/ssd.py +++ b/torchvision/prototype/models/detection/ssd.py @@ -12,7 +12,7 @@ ) from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param +from .._utils import handle_legacy_interface, _ovewrite_value_param from ..vgg import VGG16_Weights, vgg16 @@ -37,7 +37,12 @@ class SSD300_VGG16_Weights(WeightsEnum): default = Coco_V1 +@handle_legacy_interface( + weights=("pretrained", SSD300_VGG16_Weights.Coco_V1), + weights_backbone=("pretrained_backbone", VGG16_Weights.ImageNet1K_Features), +) def ssd300_vgg16( + *, weights: Optional[SSD300_VGG16_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -45,17 +50,7 @@ def ssd300_vgg16( trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> SSD: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300_VGG16_Weights.Coco_V1) weights = SSD300_VGG16_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", VGG16_Weights.ImageNet1K_Features - ) weights_backbone = VGG16_Weights.verify(weights_backbone) if "size" in kwargs: diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py index 4a61c50101a..e95dda01ee7 100644 --- a/torchvision/prototype/models/detection/ssdlite.py +++ b/torchvision/prototype/models/detection/ssdlite.py @@ -17,7 +17,7 @@ ) from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param +from .._utils import handle_legacy_interface, _ovewrite_value_param from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large @@ -42,7 +42,12 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): default = Coco_V1 +@handle_legacy_interface( + weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.Coco_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1), +) def ssdlite320_mobilenet_v3_large( + *, weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -51,17 +56,7 @@ def ssdlite320_mobilenet_v3_large( norm_layer: Optional[Callable[..., nn.Module]] = None, **kwargs: Any, ) -> SSD: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SSDLite320_MobileNet_V3_Large_Weights.Coco_V1) weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 - ) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) if "size" in kwargs: diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py index f4a69aac70c..e9ec0767e7a 100644 --- a/torchvision/prototype/models/efficientnet.py +++ b/torchvision/prototype/models/efficientnet.py @@ -8,7 +8,7 @@ from ...models.efficientnet import EfficientNet, MBConvConfig from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ @@ -181,73 +181,55 @@ class EfficientNet_B7_Weights(WeightsEnum): default = ImageNet1K_V1 +@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.ImageNet1K_V1)) def efficientnet_b0( - weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B0_Weights.ImageNet1K_V1) weights = EfficientNet_B0_Weights.verify(weights) return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs) +@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.ImageNet1K_V1)) def efficientnet_b1( - weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B1_Weights.ImageNet1K_V1) weights = EfficientNet_B1_Weights.verify(weights) return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs) +@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.ImageNet1K_V1)) def efficientnet_b2( - weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B2_Weights.ImageNet1K_V1) weights = EfficientNet_B2_Weights.verify(weights) return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs) +@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.ImageNet1K_V1)) def efficientnet_b3( - weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B3_Weights.ImageNet1K_V1) weights = EfficientNet_B3_Weights.verify(weights) return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs) +@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.ImageNet1K_V1)) def efficientnet_b4( - weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B4_Weights.ImageNet1K_V1) weights = EfficientNet_B4_Weights.verify(weights) return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs) +@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.ImageNet1K_V1)) def efficientnet_b5( - weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B5_Weights.ImageNet1K_V1) weights = EfficientNet_B5_Weights.verify(weights) return _efficientnet( @@ -261,13 +243,10 @@ def efficientnet_b5( ) +@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.ImageNet1K_V1)) def efficientnet_b6( - weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B6_Weights.ImageNet1K_V1) weights = EfficientNet_B6_Weights.verify(weights) return _efficientnet( @@ -281,13 +260,10 @@ def efficientnet_b6( ) +@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.ImageNet1K_V1)) def efficientnet_b7( - weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B7_Weights.ImageNet1K_V1) weights = EfficientNet_B7_Weights.verify(weights) return _efficientnet( diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py index f62c5a96e15..06639321110 100644 --- a/torchvision/prototype/models/googlenet.py +++ b/torchvision/prototype/models/googlenet.py @@ -8,7 +8,7 @@ from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"] @@ -30,11 +30,8 @@ class GoogLeNet_Weights(WeightsEnum): default = ImageNet1K_V1 -def googlenet(weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNet_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.ImageNet1K_V1)) +def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: weights = GoogLeNet_Weights.verify(weights) original_aux_logits = kwargs.get("aux_logits", False) diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py index 4814fa76c5c..c4fd8c0ca8a 100644 --- a/torchvision/prototype/models/inception.py +++ b/torchvision/prototype/models/inception.py @@ -7,7 +7,7 @@ from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"] @@ -29,11 +29,8 @@ class Inception_V3_Weights(WeightsEnum): default = ImageNet1K_V1 -def inception_v3(weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", Inception_V3_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.ImageNet1K_V1)) +def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: weights = Inception_V3_Weights.verify(weights) original_aux_logits = kwargs.get("aux_logits", True) diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py index 554057a9ba1..066a84e41aa 100644 --- a/torchvision/prototype/models/mnasnet.py +++ b/torchvision/prototype/models/mnasnet.py @@ -7,7 +7,7 @@ from ...models.mnasnet import MNASNet from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ @@ -79,41 +79,29 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa return model -def mnasnet0_5(weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.ImageNet1K_V1)) +def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: weights = MNASNet0_5_Weights.verify(weights) return _mnasnet(0.5, weights, progress, **kwargs) -def mnasnet0_75(weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", None) +@handle_legacy_interface(weights=("pretrained", None)) +def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: weights = MNASNet0_75_Weights.verify(weights) return _mnasnet(0.75, weights, progress, **kwargs) -def mnasnet1_0(weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.ImageNet1K_V1)) +def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: weights = MNASNet1_0_Weights.verify(weights) return _mnasnet(1.0, weights, progress, **kwargs) -def mnasnet1_3(weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", None) +@handle_legacy_interface(weights=("pretrained", None)) +def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: weights = MNASNet1_3_Weights.verify(weights) return _mnasnet(1.3, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py index 64c7221da6d..6436b6f504d 100644 --- a/torchvision/prototype/models/mobilenetv2.py +++ b/torchvision/prototype/models/mobilenetv2.py @@ -7,7 +7,7 @@ from ...models.mobilenetv2 import MobileNetV2 from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] @@ -29,11 +29,10 @@ class MobileNet_V2_Weights(WeightsEnum): default = ImageNet1K_V1 -def mobilenet_v2(weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V2_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.ImageNet1K_V1)) +def mobilenet_v2( + *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any +) -> MobileNetV2: weights = MobileNet_V2_Weights.verify(weights) if weights is not None: diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py index a92c7667aab..c5fd1f0b54d 100644 --- a/torchvision/prototype/models/mobilenetv3.py +++ b/torchvision/prototype/models/mobilenetv3.py @@ -7,7 +7,7 @@ from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ @@ -82,26 +82,20 @@ class MobileNet_V3_Small_Weights(WeightsEnum): default = ImageNet1K_V1 +@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.ImageNet1K_V1)) def mobilenet_v3_large( - weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any ) -> MobileNetV3: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Large_Weights.ImageNet1K_V1) weights = MobileNet_V3_Large_Weights.verify(weights) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) +@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.ImageNet1K_V1)) def mobilenet_v3_small( - weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any ) -> MobileNetV3: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Small_Weights.ImageNet1K_V1) weights = MobileNet_V3_Small_Weights.verify(weights) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py index dc3c875b79a..2bda6b946e4 100644 --- a/torchvision/prototype/models/quantization/googlenet.py +++ b/torchvision/prototype/models/quantization/googlenet.py @@ -12,7 +12,7 @@ ) from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from .._utils import handle_legacy_interface, _ovewrite_named_param from ..googlenet import GoogLeNet_Weights @@ -42,21 +42,22 @@ class GoogLeNet_QuantizedWeights(WeightsEnum): default = ImageNet1K_FBGEMM_V1 +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: GoogLeNet_QuantizedWeights.ImageNet1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else GoogLeNet_Weights.ImageNet1K_V1, + ) +) def googlenet( + *, weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, ) -> QuantizableGoogLeNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - default_value = GoogLeNet_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else GoogLeNet_Weights.ImageNet1K_V1 - weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] - if quantize: - weights = GoogLeNet_QuantizedWeights.verify(weights) - else: - weights = GoogLeNet_Weights.verify(weights) + weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights) original_aux_logits = kwargs.get("aux_logits", False) if weights is not None: diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py index d1d5d4ca8fe..e9f48d097f6 100644 --- a/torchvision/prototype/models/quantization/inception.py +++ b/torchvision/prototype/models/quantization/inception.py @@ -11,7 +11,7 @@ ) from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from .._utils import handle_legacy_interface, _ovewrite_named_param from ..inception import Inception_V3_Weights @@ -41,23 +41,22 @@ class Inception_V3_QuantizedWeights(WeightsEnum): default = ImageNet1K_FBGEMM_V1 +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: Inception_V3_QuantizedWeights.ImageNet1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else Inception_V3_Weights.ImageNet1K_V1, + ) +) def inception_v3( + *, weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, ) -> QuantizableInception3: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - default_value = ( - Inception_V3_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else Inception_V3_Weights.ImageNet1K_V1 - ) - weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] - if quantize: - weights = Inception_V3_QuantizedWeights.verify(weights) - else: - weights = Inception_V3_Weights.verify(weights) + weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights) original_aux_logits = kwargs.get("aux_logits", False) if weights is not None: diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py index 81540f2f840..50a94d4b5b6 100644 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ b/torchvision/prototype/models/quantization/mobilenetv2.py @@ -12,7 +12,7 @@ ) from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from .._utils import handle_legacy_interface, _ovewrite_named_param from ..mobilenetv2 import MobileNet_V2_Weights @@ -42,23 +42,22 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum): default = ImageNet1K_QNNPACK_V1 +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: MobileNet_V2_QuantizedWeights.ImageNet1K_QNNPACK_V1 + if kwargs.get("quantize", False) + else MobileNet_V2_Weights.ImageNet1K_V1, + ) +) def mobilenet_v2( + *, weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, ) -> QuantizableMobileNetV2: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - default_value = ( - MobileNet_V2_QuantizedWeights.ImageNet1K_QNNPACK_V1 if quantize else MobileNet_V2_Weights.ImageNet1K_V1 - ) - weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] - if quantize: - weights = MobileNet_V2_QuantizedWeights.verify(weights) - else: - weights = MobileNet_V2_Weights.verify(weights) + weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights) if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py index 9d29484c18f..1a68c29c2c2 100644 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ b/torchvision/prototype/models/quantization/mobilenetv3.py @@ -13,7 +13,7 @@ ) from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from .._utils import handle_legacy_interface, _ovewrite_named_param from ..mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf @@ -75,25 +75,22 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): default = ImageNet1K_QNNPACK_V1 +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: MobileNet_V3_Large_QuantizedWeights.ImageNet1K_QNNPACK_V1 + if kwargs.get("quantize", False) + else MobileNet_V3_Large_Weights.ImageNet1K_V1, + ) +) def mobilenet_v3_large( + *, weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, ) -> QuantizableMobileNetV3: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - default_value = ( - MobileNet_V3_Large_QuantizedWeights.ImageNet1K_QNNPACK_V1 - if quantize - else MobileNet_V3_Large_Weights.ImageNet1K_V1 - ) - weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] - if quantize: - weights = MobileNet_V3_Large_QuantizedWeights.verify(weights) - else: - weights = MobileNet_V3_Large_Weights.verify(weights) + weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs) diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py index c6bd530f393..aea52dc0641 100644 --- a/torchvision/prototype/models/quantization/resnet.py +++ b/torchvision/prototype/models/quantization/resnet.py @@ -13,7 +13,7 @@ ) from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from .._utils import handle_legacy_interface, _ovewrite_named_param from ..resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights @@ -125,63 +125,62 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): default = ImageNet1K_FBGEMM_V2 +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNet18_QuantizedWeights.ImageNet1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNet18_Weights.ImageNet1K_V1, + ) +) def resnet18( + *, weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, ) -> QuantizableResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - default_value = ResNet18_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNet18_Weights.ImageNet1K_V1 - weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] - if quantize: - weights = ResNet18_QuantizedWeights.verify(weights) - else: - weights = ResNet18_Weights.verify(weights) + weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights) return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNet50_Weights.ImageNet1K_V1, + ) +) def resnet50( + *, weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, ) -> QuantizableResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - default_value = ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNet50_Weights.ImageNet1K_V1 - weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] - if quantize: - weights = ResNet50_QuantizedWeights.verify(weights) - else: - weights = ResNet50_Weights.verify(weights) + weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights) return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.ImageNet1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNeXt101_32X8D_Weights.ImageNet1K_V1, + ) +) def resnext101_32x8d( + *, weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, ) -> QuantizableResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - default_value = ( - ResNeXt101_32X8D_QuantizedWeights.ImageNet1K_FBGEMM_V1 - if quantize - else ResNeXt101_32X8D_Weights.ImageNet1K_V1 - ) - weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] - if quantize: - weights = ResNeXt101_32X8D_QuantizedWeights.verify(weights) - else: - weights = ResNeXt101_32X8D_Weights.verify(weights) + weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights) _ovewrite_named_param(kwargs, "groups", 32) _ovewrite_named_param(kwargs, "width_per_group", 8) diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py index 111763f2614..00c1a673eb4 100644 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ b/torchvision/prototype/models/quantization/shufflenetv2.py @@ -11,7 +11,7 @@ ) from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from .._utils import handle_legacy_interface, _ovewrite_named_param from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights @@ -27,6 +27,7 @@ def _shufflenetv2( stages_repeats: List[int], stages_out_channels: List[int], + *, weights: Optional[WeightsEnum], progress: bool, quantize: bool, @@ -87,47 +88,43 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): default = ImageNet1K_FBGEMM_V1 +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.ImageNet1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1, + ) +) def shufflenet_v2_x0_5( + *, weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, ) -> QuantizableShuffleNetV2: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - default_value = ( - ShuffleNet_V2_X0_5_QuantizedWeights.ImageNet1K_FBGEMM_V1 - if quantize - else ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1 - ) - weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] - if quantize: - weights = ShuffleNet_V2_X0_5_QuantizedWeights.verify(weights) - else: - weights = ShuffleNet_V2_X0_5_Weights.verify(weights) - - return _shufflenetv2([4, 8, 4], [24, 48, 96, 192, 1024], weights, progress, quantize, **kwargs) + weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights) + return _shufflenetv2( + [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs + ) +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.ImageNet1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1, + ) +) def shufflenet_v2_x1_0( + *, weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, ) -> QuantizableShuffleNetV2: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - default_value = ( - ShuffleNet_V2_X1_0_QuantizedWeights.ImageNet1K_FBGEMM_V1 - if quantize - else ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1 - ) - weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] - if quantize: - weights = ShuffleNet_V2_X1_0_QuantizedWeights.verify(weights) - else: - weights = ShuffleNet_V2_X1_0_Weights.verify(weights) - - return _shufflenetv2([4, 8, 4], [24, 116, 232, 464, 1024], weights, progress, quantize, **kwargs) + weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights) + return _shufflenetv2( + [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs + ) diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py index d810a0d1300..bee53ba5920 100644 --- a/torchvision/prototype/models/regnet.py +++ b/torchvision/prototype/models/regnet.py @@ -8,7 +8,7 @@ from ...models.regnet import RegNet, BlockParams from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ @@ -260,33 +260,24 @@ class RegNet_X_32GF_Weights(WeightsEnum): default = ImageNet1K_V1 -def regnet_y_400mf(weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_400MF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.ImageNet1K_V1)) +def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_Y_400MF_Weights.verify(weights) params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_y_800mf(weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_800MF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.ImageNet1K_V1)) +def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_Y_800MF_Weights.verify(weights) params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_y_1_6gf(weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_1_6GF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.ImageNet1K_V1)) +def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_Y_1_6GF_Weights.verify(weights) params = BlockParams.from_init_params( @@ -295,11 +286,8 @@ def regnet_y_1_6gf(weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: b return _regnet(params, weights, progress, **kwargs) -def regnet_y_3_2gf(weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_3_2GF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.ImageNet1K_V1)) +def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_Y_3_2GF_Weights.verify(weights) params = BlockParams.from_init_params( @@ -308,11 +296,8 @@ def regnet_y_3_2gf(weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: b return _regnet(params, weights, progress, **kwargs) -def regnet_y_8gf(weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_8GF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.ImageNet1K_V1)) +def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_Y_8GF_Weights.verify(weights) params = BlockParams.from_init_params( @@ -321,11 +306,8 @@ def regnet_y_8gf(weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool return _regnet(params, weights, progress, **kwargs) -def regnet_y_16gf(weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_16GF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.ImageNet1K_V1)) +def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_Y_16GF_Weights.verify(weights) params = BlockParams.from_init_params( @@ -334,11 +316,8 @@ def regnet_y_16gf(weights: Optional[RegNet_Y_16GF_Weights] = None, progress: boo return _regnet(params, weights, progress, **kwargs) -def regnet_y_32gf(weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_32GF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.ImageNet1K_V1)) +def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_Y_32GF_Weights.verify(weights) params = BlockParams.from_init_params( @@ -347,77 +326,56 @@ def regnet_y_32gf(weights: Optional[RegNet_Y_32GF_Weights] = None, progress: boo return _regnet(params, weights, progress, **kwargs) -def regnet_x_400mf(weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_400MF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.ImageNet1K_V1)) +def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_X_400MF_Weights.verify(weights) params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_800mf(weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_800MF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.ImageNet1K_V1)) +def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_X_800MF_Weights.verify(weights) params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_1_6gf(weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_1_6GF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.ImageNet1K_V1)) +def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_X_1_6GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_3_2gf(weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_3_2GF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.ImageNet1K_V1)) +def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_X_3_2GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_8gf(weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_8GF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.ImageNet1K_V1)) +def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_X_8GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_16gf(weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_16GF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.ImageNet1K_V1)) +def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_X_16GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_32gf(weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_32GF_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.ImageNet1K_V1)) +def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: weights = RegNet_X_32GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py index 3c68f0a430c..8d0266c0a33 100644 --- a/torchvision/prototype/models/resnet.py +++ b/torchvision/prototype/models/resnet.py @@ -7,7 +7,7 @@ from ...models.resnet import BasicBlock, Bottleneck, ResNet from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ @@ -250,61 +250,45 @@ class Wide_ResNet101_2_Weights(WeightsEnum): default = ImageNet1K_V2 -def resnet18(weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet18_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.ImageNet1K_V1)) +def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: weights = ResNet18_Weights.verify(weights) return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) -def resnet34(weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.ImageNet1K_V1)) +def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: weights = ResNet34_Weights.verify(weights) return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) -def resnet50(weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.ImageNet1K_V1)) +def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: weights = ResNet50_Weights.verify(weights) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def resnet101(weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.ImageNet1K_V1)) +def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: weights = ResNet101_Weights.verify(weights) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) -def resnet152(weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.ImageNet1K_V1)) +def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: weights = ResNet152_Weights.verify(weights) return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) -def resnext50_32x4d(weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32X4D_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.ImageNet1K_V1)) +def resnext50_32x4d( + *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: weights = ResNeXt50_32X4D_Weights.verify(weights) _ovewrite_named_param(kwargs, "groups", 32) @@ -312,13 +296,10 @@ def resnext50_32x4d(weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) +@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.ImageNet1K_V1)) def resnext101_32x8d( - weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any ) -> ResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32X8D_Weights.ImageNet1K_V1) weights = ResNeXt101_32X8D_Weights.verify(weights) _ovewrite_named_param(kwargs, "groups", 32) @@ -326,24 +307,20 @@ def resnext101_32x8d( return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) -def wide_resnet50_2(weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", Wide_ResNet50_2_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.ImageNet1K_V1)) +def wide_resnet50_2( + *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: weights = Wide_ResNet50_2_Weights.verify(weights) _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) +@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.ImageNet1K_V1)) def wide_resnet101_2( - weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any ) -> ResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", Wide_ResNet101_2_Weights.ImageNet1K_V1) weights = Wide_ResNet101_2_Weights.verify(weights) _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index 30c90013c9b..1144dafd173 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -7,7 +7,7 @@ from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param +from .._utils import handle_legacy_interface, _ovewrite_value_param from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large from ..resnet import resnet50, resnet101 from ..resnet import ResNet50_Weights, ResNet101_Weights @@ -72,7 +72,12 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): default = CocoWithVocLabels_V1 +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_ResNet50_Weights.CocoWithVocLabels_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1), +) def deeplabv3_resnet50( + *, weights: Optional[DeepLabV3_ResNet50_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -80,17 +85,7 @@ def deeplabv3_resnet50( weights_backbone: Optional[ResNet50_Weights] = None, **kwargs: Any, ) -> DeepLabV3: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet50_Weights.CocoWithVocLabels_V1) weights = DeepLabV3_ResNet50_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 - ) weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: @@ -109,7 +104,12 @@ def deeplabv3_resnet50( return model +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_ResNet101_Weights.CocoWithVocLabels_V1), + weights_backbone=("pretrained_backbone", ResNet101_Weights.ImageNet1K_V1), +) def deeplabv3_resnet101( + *, weights: Optional[DeepLabV3_ResNet101_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -117,17 +117,7 @@ def deeplabv3_resnet101( weights_backbone: Optional[ResNet101_Weights] = None, **kwargs: Any, ) -> DeepLabV3: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet101_Weights.CocoWithVocLabels_V1) weights = DeepLabV3_ResNet101_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1 - ) weights_backbone = ResNet101_Weights.verify(weights_backbone) if weights is not None: @@ -146,7 +136,12 @@ def deeplabv3_resnet101( return model +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1), +) def deeplabv3_mobilenet_v3_large( + *, weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -154,19 +149,7 @@ def deeplabv3_mobilenet_v3_large( weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, **kwargs: Any, ) -> DeepLabV3: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param( - kwargs, "pretrained", "weights", DeepLabV3_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1 - ) weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 - ) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) if weights is not None: diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py index 42d15a0c3cf..108e6c13860 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -7,7 +7,7 @@ from ....models.segmentation.fcn import FCN, _fcn_resnet from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param +from .._utils import handle_legacy_interface, _ovewrite_value_param from ..resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 @@ -48,7 +48,12 @@ class FCN_ResNet101_Weights(WeightsEnum): default = CocoWithVocLabels_V1 +@handle_legacy_interface( + weights=("pretrained", FCN_ResNet50_Weights.CocoWithVocLabels_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1), +) def fcn_resnet50( + *, weights: Optional[FCN_ResNet50_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -56,17 +61,7 @@ def fcn_resnet50( weights_backbone: Optional[ResNet50_Weights] = None, **kwargs: Any, ) -> FCN: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet50_Weights.CocoWithVocLabels_V1) weights = FCN_ResNet50_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 - ) weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: @@ -85,7 +80,12 @@ def fcn_resnet50( return model +@handle_legacy_interface( + weights=("pretrained", FCN_ResNet101_Weights.CocoWithVocLabels_V1), + weights_backbone=("pretrained_backbone", ResNet101_Weights.ImageNet1K_V1), +) def fcn_resnet101( + *, weights: Optional[FCN_ResNet101_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -93,17 +93,7 @@ def fcn_resnet101( weights_backbone: Optional[ResNet101_Weights] = None, **kwargs: Any, ) -> FCN: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet101_Weights.CocoWithVocLabels_V1) weights = FCN_ResNet101_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1 - ) weights_backbone = ResNet101_Weights.verify(weights_backbone) if weights is not None: diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index f80e1079c87..9f74fde5e1c 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ b/torchvision/prototype/models/segmentation/lraspp.py @@ -7,7 +7,7 @@ from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param +from .._utils import handle_legacy_interface, _ovewrite_value_param from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large @@ -29,7 +29,12 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): default = CocoWithVocLabels_V1 +@handle_legacy_interface( + weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1), +) def lraspp_mobilenet_v3_large( + *, weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, @@ -39,19 +44,7 @@ def lraspp_mobilenet_v3_large( if kwargs.pop("aux_loss", False): raise NotImplementedError("This model does not use auxiliary loss") - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param( - kwargs, "pretrained", "weights", LRASPP_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1 - ) weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights) - if type(weights_backbone) == bool and weights_backbone: - _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) - if "pretrained_backbone" in kwargs: - weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 - ) weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) if weights is not None: diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py index a8857c2996e..06651a479a4 100644 --- a/torchvision/prototype/models/shufflenetv2.py +++ b/torchvision/prototype/models/shufflenetv2.py @@ -7,7 +7,7 @@ from ...models.shufflenetv2 import ShuffleNetV2 from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ @@ -82,49 +82,37 @@ class ShuffleNet_V2_X2_0_Weights(WeightsEnum): pass +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1)) def shufflenet_v2_x0_5( - weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1) weights = ShuffleNet_V2_X0_5_Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1)) def shufflenet_v2_x1_0( - weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1) weights = ShuffleNet_V2_X1_0_Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) +@handle_legacy_interface(weights=("pretrained", None)) def shufflenet_v2_x1_5( - weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = ShuffleNet_V2_X1_5_Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) +@handle_legacy_interface(weights=("pretrained", None)) def shufflenet_v2_x2_0( - weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any + *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = ShuffleNet_V2_X2_0_Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py index 77c9a1629d4..6868569097b 100644 --- a/torchvision/prototype/models/squeezenet.py +++ b/torchvision/prototype/models/squeezenet.py @@ -7,7 +7,7 @@ from ...models.squeezenet import SqueezeNet from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"] @@ -47,11 +47,10 @@ class SqueezeNet1_1_Weights(WeightsEnum): default = ImageNet1K_V1 -def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.ImageNet1K_V1)) +def squeezenet1_0( + *, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any +) -> SqueezeNet: weights = SqueezeNet1_0_Weights.verify(weights) if weights is not None: @@ -65,11 +64,10 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: boo return model -def squeezenet1_1(weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.ImageNet1K_V1)) +def squeezenet1_1( + *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any +) -> SqueezeNet: weights = SqueezeNet1_1_Weights.verify(weights) if weights is not None: diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py index 708608826e0..78e976d5fb3 100644 --- a/torchvision/prototype/models/vgg.py +++ b/torchvision/prototype/models/vgg.py @@ -7,7 +7,7 @@ from ...models.vgg import VGG, make_layers, cfgs from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ @@ -169,81 +169,57 @@ class VGG19_BN_Weights(WeightsEnum): default = ImageNet1K_V1 -def vgg11(weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", VGG11_Weights.ImageNet1K_V1)) +def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: weights = VGG11_Weights.verify(weights) return _vgg("A", False, weights, progress, **kwargs) -def vgg11_bn(weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_BN_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.ImageNet1K_V1)) +def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: weights = VGG11_BN_Weights.verify(weights) return _vgg("A", True, weights, progress, **kwargs) -def vgg13(weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", VGG13_Weights.ImageNet1K_V1)) +def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: weights = VGG13_Weights.verify(weights) return _vgg("B", False, weights, progress, **kwargs) -def vgg13_bn(weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_BN_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.ImageNet1K_V1)) +def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: weights = VGG13_BN_Weights.verify(weights) return _vgg("B", True, weights, progress, **kwargs) -def vgg16(weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", VGG16_Weights.ImageNet1K_V1)) +def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: weights = VGG16_Weights.verify(weights) return _vgg("D", False, weights, progress, **kwargs) -def vgg16_bn(weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_BN_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.ImageNet1K_V1)) +def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: weights = VGG16_BN_Weights.verify(weights) return _vgg("D", True, weights, progress, **kwargs) -def vgg19(weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", VGG19_Weights.ImageNet1K_V1)) +def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: weights = VGG19_Weights.verify(weights) return _vgg("E", False, weights, progress, **kwargs) -def vgg19_bn(weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_BN_Weights.ImageNet1K_V1) +@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.ImageNet1K_V1)) +def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: weights = VGG19_BN_Weights.verify(weights) return _vgg("E", True, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py index 48c4293f0e1..af1a3963fdf 100644 --- a/torchvision/prototype/models/video/resnet.py +++ b/torchvision/prototype/models/video/resnet.py @@ -17,7 +17,7 @@ ) from .._api import WeightsEnum, Weights from .._meta import _KINETICS400_CATEGORIES -from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param +from .._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ @@ -98,11 +98,8 @@ class R2Plus1D_18_Weights(WeightsEnum): default = Kinetics400_V1 -def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18_Weights.Kinetics400_V1) +@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.Kinetics400_V1)) +def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: weights = R3D_18_Weights.verify(weights) return _video_resnet( @@ -116,11 +113,8 @@ def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kw ) -def mc3_18(weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18_Weights.Kinetics400_V1) +@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.Kinetics400_V1)) +def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: weights = MC3_18_Weights.verify(weights) return _video_resnet( @@ -134,11 +128,8 @@ def mc3_18(weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kw ) -def r2plus1d_18(weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18_Weights.Kinetics400_V1) +@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.Kinetics400_V1)) +def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: weights = R2Plus1D_18_Weights.verify(weights) return _video_resnet( diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py index bbe5aba262c..9794559745d 100644 --- a/torchvision/prototype/models/vision_transformer.py +++ b/torchvision/prototype/models/vision_transformer.py @@ -12,7 +12,7 @@ from torch import Tensor from ._api import WeightsEnum -from ._utils import _deprecated_param, _deprecated_positional +from ._utils import handle_legacy_interface __all__ = [ @@ -279,7 +279,8 @@ def _vision_transformer( return model -def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: +@handle_legacy_interface(weights=("pretrained", None)) +def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_16 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. @@ -289,10 +290,6 @@ def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, Default: None. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. """ - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = ViT_B_16_Weights.verify(weights) return _vision_transformer( @@ -307,7 +304,8 @@ def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, ) -def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: +@handle_legacy_interface(weights=("pretrained", None)) +def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_32 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. @@ -317,10 +315,6 @@ def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, Default: None. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. """ - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = ViT_B_32_Weights.verify(weights) return _vision_transformer( @@ -335,7 +329,8 @@ def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, ) -def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: +@handle_legacy_interface(weights=("pretrained", None)) +def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_16 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. @@ -345,10 +340,6 @@ def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, Default: None. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. """ - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = ViT_L_16_Weights.verify(weights) return _vision_transformer( @@ -363,7 +354,8 @@ def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, ) -def vit_l_32(weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: +@handle_legacy_interface(weights=("pretrained", None)) +def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_32 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. @@ -373,10 +365,6 @@ def vit_l_32(weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, Default: None. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. """ - if type(weights) == bool and weights: - _deprecated_positional(kwargs, "pretrained", "weights", True) - if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", None) weights = ViT_L_32_Weights.verify(weights) return _vision_transformer( diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index afed3e9d279..68128a2b381 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -1,9 +1,12 @@ import collections.abc import difflib import enum +import functools +import inspect import os import os.path import textwrap +import warnings from typing import Collection, Sequence, Callable, Any, Iterator, NoReturn, Mapping, TypeVar, Iterable, Tuple, cast __all__ = [ @@ -13,6 +16,7 @@ "FrozenMapping", "make_repr", "FrozenBunch", + "kwonly_to_pos_or_kw", ] @@ -126,3 +130,54 @@ def __delattr__(self, item: Any) -> NoReturn: def __repr__(self) -> str: return make_repr(type(self).__name__, self.items()) + + +def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]: + """Decorates a function that uses keyword only parameters to also allow them being passed as positionals. + + For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``: + + .. code:: + + def old_fn(foo, bar, baz=None): + ... + + def new_fn(foo, *, bar, baz=None): + ... + + Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC + and at the same time warn the user of the deprecation, this decorator can be used: + + .. code:: + + @kwonly_to_pos_or_kw + def new_fn(foo, *, bar, baz=None): + ... + + new_fn("foo", "bar, "baz") + """ + params = inspect.signature(fn).parameters + + try: + keyword_only_start_idx = next( + idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY + ) + except StopIteration: + raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None + + keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:] + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> D: + args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:] + if keyword_only_args: + keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args)) + warnings.warn( + f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional " + f"parameter(s) is deprecated. Please use keyword parameter(s) instead." + ) + kwargs.update(keyword_only_kwargs) + + return fn(*args, **kwargs) + + return wrapper