Skip to content

Commit

Permalink
Provide helper methods for parameter checks to reduce duplicate code.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Nov 24, 2021
1 parent 60f5bc5 commit 1172e1f
Show file tree
Hide file tree
Showing 30 changed files with 245 additions and 312 deletions.
20 changes: 20 additions & 0 deletions torchvision/prototype/models/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import warnings
from typing import Any, Dict, Optional, TypeVar

from ._api import Weights


W = TypeVar("W", bound=Weights)


def _deprecated_param(
deprecated_param: str, new_param: str, default_value: Optional[W], kwargs: Dict[str, Any]
) -> Optional[W]:
warnings.warn(f"The parameter '{deprecated_param}' is deprecated, please use {new_param} instead.")
if kwargs.pop(deprecated_param):
if default_value is not None:
return default_value
else:
raise ValueError("No checkpoint is available for model.")
else:
return None
6 changes: 3 additions & 3 deletions torchvision/prototype/models/alexnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from functools import partial
from typing import Any, Optional

Expand All @@ -8,6 +7,7 @@
from ...models.alexnet import AlexNet
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param


__all__ = ["AlexNet", "AlexNetWeights", "alexnet"]
Expand All @@ -30,9 +30,9 @@ class AlexNetWeights(Weights):

def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = AlexNetWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param("pretrained", "weights", AlexNetWeights.ImageNet1K_RefV1, kwargs)
weights = AlexNetWeights.verify(weights)

if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])

Expand Down
14 changes: 5 additions & 9 deletions torchvision/prototype/models/densenet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
import warnings
from functools import partial
from typing import Any, Optional, Tuple

Expand All @@ -10,6 +9,7 @@
from ...models.densenet import DenseNet
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param


__all__ = [
Expand Down Expand Up @@ -121,35 +121,31 @@ class DenseNet201Weights(Weights):

def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet121Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param("pretrained", "weights", DenseNet121Weights.ImageNet1K_Community, kwargs)
weights = DenseNet121Weights.verify(weights)

return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)


def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet161Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param("pretrained", "weights", DenseNet161Weights.ImageNet1K_Community, kwargs)
weights = DenseNet161Weights.verify(weights)

return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)


def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet169Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param("pretrained", "weights", DenseNet169Weights.ImageNet1K_Community, kwargs)
weights = DenseNet169Weights.verify(weights)

return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)


def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = DenseNet201Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
weights = _deprecated_param("pretrained", "weights", DenseNet201Weights.ImageNet1K_Community, kwargs)
weights = DenseNet201Weights.verify(weights)

return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
26 changes: 13 additions & 13 deletions torchvision/prototype/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from typing import Any, Optional, Union

from torchvision.prototype.transforms import CocoEval
Expand All @@ -15,6 +14,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
from ..resnet import ResNet50Weights, resnet50

Expand Down Expand Up @@ -81,12 +81,12 @@ def fasterrcnn_resnet50_fpn(
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param("pretrained", "weights", FasterRCNNResNet50FPNWeights.Coco_RefV1, kwargs)
weights = FasterRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
"pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1, kwargs
)
weights_backbone = ResNet50Weights.verify(weights_backbone)

if weights is not None:
Expand Down Expand Up @@ -160,12 +160,12 @@ def fasterrcnn_mobilenet_v3_large_fpn(
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param("pretrained", "weights", FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1, kwargs)
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
"pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1, kwargs
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)

defaults = {
Expand All @@ -192,12 +192,12 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param("pretrained", "weights", FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1, kwargs)
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
"pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1, kwargs
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)

defaults = {
Expand Down
20 changes: 9 additions & 11 deletions torchvision/prototype/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from typing import Any, Optional

from torchvision.prototype.transforms import CocoEval
Expand All @@ -12,6 +11,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from .._utils import _deprecated_param
from ..resnet import ResNet50Weights, resnet50


Expand Down Expand Up @@ -58,18 +58,16 @@ def keypointrcnn_resnet50_fpn(
**kwargs: Any,
) -> KeypointRCNN:
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
pretrained = kwargs.pop("pretrained")
if type(pretrained) == str and pretrained == "legacy":
weights = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy
elif type(pretrained) == bool and pretrained:
weights = KeypointRCNNResNet50FPNWeights.Coco_RefV1
else:
weights = None
default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1
if kwargs["pretrained"] == "legacy":
default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy
kwargs["pretrained"] = True
weights = _deprecated_param("pretrained", "weights", default_value, kwargs)
weights = KeypointRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
"pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1, kwargs
)
weights_backbone = ResNet50Weights.verify(weights_backbone)

if weights is not None:
Expand Down
10 changes: 5 additions & 5 deletions torchvision/prototype/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from typing import Any, Optional

from torchvision.prototype.transforms import CocoEval
Expand All @@ -13,6 +12,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param
from ..resnet import ResNet50Weights, resnet50


Expand Down Expand Up @@ -46,12 +46,12 @@ def maskrcnn_resnet50_fpn(
**kwargs: Any,
) -> MaskRCNN:
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = MaskRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param("pretrained", "weights", MaskRCNNResNet50FPNWeights.Coco_RefV1, kwargs)
weights = MaskRCNNResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
"pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1, kwargs
)
weights_backbone = ResNet50Weights.verify(weights_backbone)

if weights is not None:
Expand Down
10 changes: 5 additions & 5 deletions torchvision/prototype/models/detection/retinanet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from typing import Any, Optional

from torchvision.prototype.transforms import CocoEval
Expand All @@ -14,6 +13,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param
from ..resnet import ResNet50Weights, resnet50


Expand Down Expand Up @@ -46,12 +46,12 @@ def retinanet_resnet50_fpn(
**kwargs: Any,
) -> RetinaNet:
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = RetinaNetResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param("pretrained", "weights", RetinaNetResNet50FPNWeights.Coco_RefV1, kwargs)
weights = RetinaNetResNet50FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
"pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1, kwargs
)
weights_backbone = ResNet50Weights.verify(weights_backbone)

if weights is not None:
Expand Down
9 changes: 5 additions & 4 deletions torchvision/prototype/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param
from ..vgg import VGG16Weights, vgg16


Expand Down Expand Up @@ -44,12 +45,12 @@ def ssd300_vgg16(
**kwargs: Any,
) -> SSD:
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SSD300VGG16Weights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param("pretrained", "weights", SSD300VGG16Weights.Coco_RefV1, kwargs)
weights = SSD300VGG16Weights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = VGG16Weights.ImageNet1K_Features if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
"pretrained_backbone", "weights_backbone", VGG16Weights.ImageNet1K_Features, kwargs
)
weights_backbone = VGG16Weights.verify(weights_backbone)

if "size" in kwargs:
Expand Down
9 changes: 5 additions & 4 deletions torchvision/prototype/models/detection/ssdlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large


Expand Down Expand Up @@ -50,12 +51,12 @@ def ssdlite320_mobilenet_v3_large(
**kwargs: Any,
) -> SSD:
if "pretrained" in kwargs:
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
weights = SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = _deprecated_param("pretrained", "weights", SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1, kwargs)
weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = _deprecated_param(
"pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1, kwargs
)
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)

if "size" in kwargs:
Expand Down
Loading

0 comments on commit 1172e1f

Please sign in to comment.