From 9ff811ba6ae4af939aa8fb6f44b0e7b5eaef818a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 24 Nov 2021 20:10:32 +0000 Subject: [PATCH] Throw errors on silent config overwrites from weight meta-data and legacy builders. --- torchvision/prototype/models/_utils.py | 15 +++++++++++++++ torchvision/prototype/models/alexnet.py | 4 ++-- torchvision/prototype/models/densenet.py | 4 ++-- .../prototype/models/detection/faster_rcnn.py | 6 +++--- .../prototype/models/detection/keypoint_rcnn.py | 6 +++--- .../prototype/models/detection/mask_rcnn.py | 4 ++-- .../prototype/models/detection/retinanet.py | 4 ++-- torchvision/prototype/models/detection/ssd.py | 4 ++-- .../prototype/models/detection/ssdlite.py | 4 ++-- torchvision/prototype/models/efficientnet.py | 4 ++-- torchvision/prototype/models/googlenet.py | 10 +++++----- torchvision/prototype/models/inception.py | 10 +++++----- torchvision/prototype/models/mnasnet.py | 4 ++-- torchvision/prototype/models/mobilenetv2.py | 4 ++-- torchvision/prototype/models/mobilenetv3.py | 4 ++-- .../prototype/models/quantization/googlenet.py | 12 ++++++------ .../prototype/models/quantization/inception.py | 10 +++++----- .../prototype/models/quantization/mobilenetv2.py | 6 +++--- .../prototype/models/quantization/mobilenetv3.py | 6 +++--- .../prototype/models/quantization/resnet.py | 10 +++++----- .../models/quantization/shufflenetv2.py | 6 +++--- torchvision/prototype/models/regnet.py | 4 ++-- torchvision/prototype/models/resnet.py | 16 ++++++++-------- .../prototype/models/segmentation/deeplabv3.py | 14 +++++++------- torchvision/prototype/models/segmentation/fcn.py | 10 +++++----- .../prototype/models/segmentation/lraspp.py | 4 ++-- torchvision/prototype/models/shufflenetv2.py | 4 ++-- torchvision/prototype/models/squeezenet.py | 6 +++--- torchvision/prototype/models/vgg.py | 4 ++-- torchvision/prototype/models/video/resnet.py | 4 ++-- 30 files changed, 109 insertions(+), 94 deletions(-) diff --git a/torchvision/prototype/models/_utils.py b/torchvision/prototype/models/_utils.py index c8827ba79be..7decd029736 100644 --- a/torchvision/prototype/models/_utils.py +++ b/torchvision/prototype/models/_utils.py @@ -18,3 +18,18 @@ def _deprecated_param( raise ValueError("No checkpoint is available for model.") else: return None + + +def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: Any) -> None: + if param in kwargs: + if kwargs[param] != new_value: + raise ValueError(f"The parameter {param} expected value {new_value} but got {kwargs[param]} instead.") + else: + kwargs[param] = new_value + + +def _ovewrite_value_param(param: Any, new_value: Any) -> Any: + if param is not None: + if param != new_value: + raise ValueError(f"The parameter {param} expected value {new_value} but got {param} instead.") + return new_value diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py index 3d03b0b1421..3b6cba289ac 100644 --- a/torchvision/prototype/models/alexnet.py +++ b/torchvision/prototype/models/alexnet.py @@ -7,7 +7,7 @@ from ...models.alexnet import AlexNet from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param +from ._utils import _deprecated_param, _ovewrite_named_param __all__ = ["AlexNet", "AlexNetWeights", "alexnet"] @@ -34,7 +34,7 @@ def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **k weights = AlexNetWeights.verify(weights) if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = AlexNet(**kwargs) diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py index 8c4cedbf716..c8d4d311ec3 100644 --- a/torchvision/prototype/models/densenet.py +++ b/torchvision/prototype/models/densenet.py @@ -9,7 +9,7 @@ from ...models.densenet import DenseNet from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param +from ._utils import _deprecated_param, _ovewrite_named_param __all__ = [ @@ -53,7 +53,7 @@ def _densenet( **kwargs: Any, ) -> DenseNet: if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index 69710636ccf..e9db66831cd 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -14,7 +14,7 @@ ) from .._api import Weights, WeightEntry from .._meta import _COCO_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_value_param from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..resnet import ResNet50Weights, resnet50 @@ -91,7 +91,7 @@ def fasterrcnn_resnet50_fpn( if weights is not None: weights_backbone = None - num_classes = len(weights.meta["categories"]) + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 @@ -121,7 +121,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ) -> FasterRCNN: if weights is not None: weights_backbone = None - num_classes = len(weights.meta["categories"]) + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index 7b54f7719e3..4294bbaa21a 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -11,7 +11,7 @@ ) from .._api import Weights, WeightEntry from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_value_param from ..resnet import ResNet50Weights, resnet50 @@ -72,8 +72,8 @@ def keypointrcnn_resnet50_fpn( if weights is not None: weights_backbone = None - num_classes = len(weights.meta["categories"]) - num_keypoints = len(weights.meta["keypoint_names"]) + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_keypoints = _ovewrite_value_param(num_keypoints, len(weights.meta["keypoint_names"])) else: if num_classes is None: num_classes = 2 diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index 872d57ff420..f756ab1a600 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -12,7 +12,7 @@ ) from .._api import Weights, WeightEntry from .._meta import _COCO_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_value_param from ..resnet import ResNet50Weights, resnet50 @@ -56,7 +56,7 @@ def maskrcnn_resnet50_fpn( if weights is not None: weights_backbone = None - num_classes = len(weights.meta["categories"]) + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index 29e83fdbf41..6e87053ff8d 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -13,7 +13,7 @@ ) from .._api import Weights, WeightEntry from .._meta import _COCO_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_value_param from ..resnet import ResNet50Weights, resnet50 @@ -56,7 +56,7 @@ def retinanet_resnet50_fpn( if weights is not None: weights_backbone = None - num_classes = len(weights.meta["categories"]) + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py index 45a93835bd9..f4eb4a84956 100644 --- a/torchvision/prototype/models/detection/ssd.py +++ b/torchvision/prototype/models/detection/ssd.py @@ -12,7 +12,7 @@ ) from .._api import Weights, WeightEntry from .._meta import _COCO_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_value_param from ..vgg import VGG16Weights, vgg16 @@ -58,7 +58,7 @@ def ssd300_vgg16( if weights is not None: weights_backbone = None - num_classes = len(weights.meta["categories"]) + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py index 7e3fa4b8268..616821a6067 100644 --- a/torchvision/prototype/models/detection/ssdlite.py +++ b/torchvision/prototype/models/detection/ssdlite.py @@ -17,7 +17,7 @@ ) from .._api import Weights, WeightEntry from .._meta import _COCO_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_value_param from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large @@ -64,7 +64,7 @@ def ssdlite320_mobilenet_v3_large( if weights is not None: weights_backbone = None - num_classes = len(weights.meta["categories"]) + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py index 7740a3ef376..acb9296d536 100644 --- a/torchvision/prototype/models/efficientnet.py +++ b/torchvision/prototype/models/efficientnet.py @@ -8,7 +8,7 @@ from ...models.efficientnet import EfficientNet, MBConvConfig from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param +from ._utils import _deprecated_param, _ovewrite_named_param __all__ = [ @@ -41,7 +41,7 @@ def _efficientnet( **kwargs: Any, ) -> EfficientNet: if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult) inverted_residual_setting = [ diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py index fbc43cc740a..45cf13ce54f 100644 --- a/torchvision/prototype/models/googlenet.py +++ b/torchvision/prototype/models/googlenet.py @@ -8,7 +8,7 @@ from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param +from ._utils import _deprecated_param, _ovewrite_named_param __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeights", "googlenet"] @@ -37,10 +37,10 @@ def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, original_aux_logits = kwargs.get("aux_logits", False) if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - kwargs["aux_logits"] = True - kwargs["init_weights"] = False - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = GoogLeNet(**kwargs) diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py index dc6413795fd..5aa8606b560 100644 --- a/torchvision/prototype/models/inception.py +++ b/torchvision/prototype/models/inception.py @@ -7,7 +7,7 @@ from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param +from ._utils import _deprecated_param, _ovewrite_named_param __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"] @@ -36,10 +36,10 @@ def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = original_aux_logits = kwargs.get("aux_logits", True) if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - kwargs["aux_logits"] = True - kwargs["init_weights"] = False - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = Inception3(**kwargs) diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py index d26ac9112a8..06726ec073e 100644 --- a/torchvision/prototype/models/mnasnet.py +++ b/torchvision/prototype/models/mnasnet.py @@ -7,7 +7,7 @@ from ...models.mnasnet import MNASNet from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param +from ._utils import _deprecated_param, _ovewrite_named_param __all__ = [ @@ -67,7 +67,7 @@ class MNASNet1_3Weights(Weights): def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: Any) -> MNASNet: if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = MNASNet(alpha, **kwargs) diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py index 1252275ea38..7dde50f3daa 100644 --- a/torchvision/prototype/models/mobilenetv2.py +++ b/torchvision/prototype/models/mobilenetv2.py @@ -7,7 +7,7 @@ from ...models.mobilenetv2 import MobileNetV2 from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param +from ._utils import _deprecated_param, _ovewrite_named_param __all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"] @@ -34,7 +34,7 @@ def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = weights = MobileNetV2Weights.verify(weights) if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = MobileNetV2(**kwargs) diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py index 166b1c58dfc..e1eabd1d670 100644 --- a/torchvision/prototype/models/mobilenetv3.py +++ b/torchvision/prototype/models/mobilenetv3.py @@ -7,7 +7,7 @@ from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param +from ._utils import _deprecated_param, _ovewrite_named_param __all__ = [ @@ -27,7 +27,7 @@ def _mobilenet_v3( **kwargs: Any, ) -> MobileNetV3: if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py index da5003f6a76..738feff103a 100644 --- a/torchvision/prototype/models/quantization/googlenet.py +++ b/torchvision/prototype/models/quantization/googlenet.py @@ -12,7 +12,7 @@ ) from .._api import Weights, WeightEntry from .._meta import _IMAGENET_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_named_param from ..googlenet import GoogLeNetWeights @@ -60,12 +60,12 @@ def googlenet( original_aux_logits = kwargs.get("aux_logits", False) if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - kwargs["aux_logits"] = True - kwargs["init_weights"] = False - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) if "backend" in weights.meta: - kwargs["backend"] = weights.meta["backend"] + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) backend = kwargs.pop("backend", "fbgemm") model = QuantizableGoogLeNet(**kwargs) diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py index babb9232f8c..e2cae702aa4 100644 --- a/torchvision/prototype/models/quantization/inception.py +++ b/torchvision/prototype/models/quantization/inception.py @@ -11,7 +11,7 @@ ) from .._api import Weights, WeightEntry from .._meta import _IMAGENET_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_named_param from ..inception import InceptionV3Weights @@ -59,11 +59,11 @@ def inception_v3( original_aux_logits = kwargs.get("aux_logits", False) if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - kwargs["aux_logits"] = True - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) if "backend" in weights.meta: - kwargs["backend"] = weights.meta["backend"] + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) backend = kwargs.pop("backend", "fbgemm") model = QuantizableInception3(**kwargs) diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py index 00134751639..0aa798266c1 100644 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ b/torchvision/prototype/models/quantization/mobilenetv2.py @@ -12,7 +12,7 @@ ) from .._api import Weights, WeightEntry from .._meta import _IMAGENET_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_named_param from ..mobilenetv2 import MobileNetV2Weights @@ -58,9 +58,9 @@ def mobilenet_v2( weights = MobileNetV2Weights.verify(weights) if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) if "backend" in weights.meta: - kwargs["backend"] = weights.meta["backend"] + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) backend = kwargs.pop("backend", "qnnpack") model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs) diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py index 72ec4dbbe89..35b97a2be6d 100644 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ b/torchvision/prototype/models/quantization/mobilenetv3.py @@ -13,7 +13,7 @@ ) from .._api import Weights, WeightEntry from .._meta import _IMAGENET_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_named_param from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf @@ -33,9 +33,9 @@ def _mobilenet_v3_model( **kwargs: Any, ) -> QuantizableMobileNetV3: if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) if "backend" in weights.meta: - kwargs["backend"] = weights.meta["backend"] + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) backend = kwargs.pop("backend", "qnnpack") model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs) diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py index 898355b295b..8ea978d0379 100644 --- a/torchvision/prototype/models/quantization/resnet.py +++ b/torchvision/prototype/models/quantization/resnet.py @@ -13,7 +13,7 @@ ) from .._api import Weights, WeightEntry from .._meta import _IMAGENET_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_named_param from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights @@ -37,9 +37,9 @@ def _resnet( **kwargs: Any, ) -> QuantizableResNet: if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) if "backend" in weights.meta: - kwargs["backend"] = weights.meta["backend"] + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) backend = kwargs.pop("backend", "fbgemm") model = QuantizableResNet(block, layers, **kwargs) @@ -158,6 +158,6 @@ def resnext101_32x8d( else: weights = ResNeXt101_32x8dWeights.verify(weights) - kwargs["groups"] = 32 - kwargs["width_per_group"] = 8 + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 8) return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs) diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py index 765f4ec0b97..4a106136fd8 100644 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ b/torchvision/prototype/models/quantization/shufflenetv2.py @@ -11,7 +11,7 @@ ) from .._api import Weights, WeightEntry from .._meta import _IMAGENET_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_named_param from ..shufflenetv2 import ShuffleNetV2_x0_5Weights, ShuffleNetV2_x1_0Weights @@ -33,9 +33,9 @@ def _shufflenetv2( **kwargs: Any, ) -> QuantizableShuffleNetV2: if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) if "backend" in weights.meta: - kwargs["backend"] = weights.meta["backend"] + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) backend = kwargs.pop("backend", "fbgemm") model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs) diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py index 1206a150709..8bf53832637 100644 --- a/torchvision/prototype/models/regnet.py +++ b/torchvision/prototype/models/regnet.py @@ -8,7 +8,7 @@ from ...models.regnet import RegNet, BlockParams from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param +from ._utils import _deprecated_param, _ovewrite_named_param __all__ = [ @@ -53,7 +53,7 @@ def _regnet( **kwargs: Any, ) -> RegNet: if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)) model = RegNet(block_params, norm_layer=norm_layer, **kwargs) diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py index 06a60089614..6a935e15271 100644 --- a/torchvision/prototype/models/resnet.py +++ b/torchvision/prototype/models/resnet.py @@ -7,7 +7,7 @@ from ...models.resnet import BasicBlock, Bottleneck, ResNet from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param +from ._utils import _deprecated_param, _ovewrite_named_param __all__ = [ @@ -41,7 +41,7 @@ def _resnet( **kwargs: Any, ) -> ResNet: if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = ResNet(block, layers, **kwargs) @@ -286,8 +286,8 @@ def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: weights = _deprecated_param("pretrained", "weights", ResNeXt50_32x4dWeights.ImageNet1K_RefV1, kwargs) weights = ResNeXt50_32x4dWeights.verify(weights) - kwargs["groups"] = 32 - kwargs["width_per_group"] = 4 + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 4) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) @@ -296,8 +296,8 @@ def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress weights = _deprecated_param("pretrained", "weights", ResNeXt101_32x8dWeights.ImageNet1K_RefV1, kwargs) weights = ResNeXt101_32x8dWeights.verify(weights) - kwargs["groups"] = 32 - kwargs["width_per_group"] = 8 + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 8) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) @@ -306,7 +306,7 @@ def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: b weights = _deprecated_param("pretrained", "weights", WideResNet50_2Weights.ImageNet1K_Community, kwargs) weights = WideResNet50_2Weights.verify(weights) - kwargs["width_per_group"] = 64 * 2 + _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) @@ -315,5 +315,5 @@ def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: weights = _deprecated_param("pretrained", "weights", WideResNet101_2Weights.ImageNet1K_Community, kwargs) weights = WideResNet101_2Weights.verify(weights) - kwargs["width_per_group"] = 64 * 2 + _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index d9fdc0af9d9..9ba4dc02827 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -7,7 +7,7 @@ from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet from .._api import Weights, WeightEntry from .._meta import _VOC_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_value_param from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..resnet import resnet50, resnet101 from ..resnet import ResNet50Weights, ResNet101Weights @@ -88,8 +88,8 @@ def deeplabv3_resnet50( if weights is not None: weights_backbone = None - aux_loss = True - num_classes = len(weights.meta["categories"]) + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) elif num_classes is None: num_classes = 21 @@ -121,8 +121,8 @@ def deeplabv3_resnet101( if weights is not None: weights_backbone = None - aux_loss = True - num_classes = len(weights.meta["categories"]) + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) elif num_classes is None: num_classes = 21 @@ -156,8 +156,8 @@ def deeplabv3_mobilenet_v3_large( if weights is not None: weights_backbone = None - aux_loss = True - num_classes = len(weights.meta["categories"]) + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) elif num_classes is None: num_classes = 21 diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py index 6459aa566ef..881d63eea51 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -7,7 +7,7 @@ from ....models.segmentation.fcn import FCN, _fcn_resnet from .._api import Weights, WeightEntry from .._meta import _VOC_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_value_param from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101 @@ -64,9 +64,9 @@ def fcn_resnet50( weights_backbone = ResNet50Weights.verify(weights_backbone) if weights is not None: - aux_loss = True weights_backbone = None - num_classes = len(weights.meta["categories"]) + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) elif num_classes is None: num_classes = 21 @@ -97,9 +97,9 @@ def fcn_resnet101( weights_backbone = ResNet101Weights.verify(weights_backbone) if weights is not None: - aux_loss = True weights_backbone = None - num_classes = len(weights.meta["categories"]) + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) elif num_classes is None: num_classes = 21 diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index d65abc50892..037d38a88dd 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ b/torchvision/prototype/models/segmentation/lraspp.py @@ -7,7 +7,7 @@ from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 from .._api import Weights, WeightEntry from .._meta import _VOC_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_value_param from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large @@ -51,7 +51,7 @@ def lraspp_mobilenet_v3_large( if weights is not None: weights_backbone = None - num_classes = len(weights.meta["categories"]) + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 21 diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py index 45ecde22c80..7a3a8db2158 100644 --- a/torchvision/prototype/models/shufflenetv2.py +++ b/torchvision/prototype/models/shufflenetv2.py @@ -7,7 +7,7 @@ from ...models.shufflenetv2 import ShuffleNetV2 from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param +from ._utils import _deprecated_param, _ovewrite_named_param __all__ = [ @@ -30,7 +30,7 @@ def _shufflenetv2( **kwargs: Any, ) -> ShuffleNetV2: if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = ShuffleNetV2(*args, **kwargs) diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py index 61353bd0d9f..d3d56eda62e 100644 --- a/torchvision/prototype/models/squeezenet.py +++ b/torchvision/prototype/models/squeezenet.py @@ -7,7 +7,7 @@ from ...models.squeezenet import SqueezeNet from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param +from ._utils import _deprecated_param, _ovewrite_named_param __all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"] @@ -51,7 +51,7 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool weights = SqueezeNet1_0Weights.verify(weights) if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = SqueezeNet("1_0", **kwargs) @@ -67,7 +67,7 @@ def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool weights = SqueezeNet1_1Weights.verify(weights) if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = SqueezeNet("1_1", **kwargs) diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py index a01abc2fca0..0f67fabf750 100644 --- a/torchvision/prototype/models/vgg.py +++ b/torchvision/prototype/models/vgg.py @@ -7,7 +7,7 @@ from ...models.vgg import VGG, make_layers, cfgs from ._api import Weights, WeightEntry from ._meta import _IMAGENET_CATEGORIES -from ._utils import _deprecated_param +from ._utils import _deprecated_param, _ovewrite_named_param __all__ = [ @@ -33,7 +33,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG: if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py index 5c725486dc5..9d42e567ff2 100644 --- a/torchvision/prototype/models/video/resnet.py +++ b/torchvision/prototype/models/video/resnet.py @@ -17,7 +17,7 @@ ) from .._api import Weights, WeightEntry from .._meta import _KINETICS400_CATEGORIES -from .._utils import _deprecated_param +from .._utils import _deprecated_param, _ovewrite_named_param __all__ = [ @@ -41,7 +41,7 @@ def _video_resnet( **kwargs: Any, ) -> VideoResNet: if weights is not None: - kwargs["num_classes"] = len(weights.meta["categories"]) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = VideoResNet(block, conv_makers, layers, stem, **kwargs)