Skip to content

Commit

Permalink
Throw errors on silent config overwrites from weight meta-data and l…
Browse files Browse the repository at this point in the history
…egacy builders.
  • Loading branch information
datumbox committed Nov 24, 2021
1 parent 1172e1f commit 9ff811b
Show file tree
Hide file tree
Showing 30 changed files with 109 additions and 94 deletions.
15 changes: 15 additions & 0 deletions torchvision/prototype/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,18 @@ def _deprecated_param(
raise ValueError("No checkpoint is available for model.")
else:
return None


def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: Any) -> None:
if param in kwargs:
if kwargs[param] != new_value:
raise ValueError(f"The parameter {param} expected value {new_value} but got {kwargs[param]} instead.")
else:
kwargs[param] = new_value


def _ovewrite_value_param(param: Any, new_value: Any) -> Any:
if param is not None:
if param != new_value:
raise ValueError(f"The parameter {param} expected value {new_value} but got {param} instead.")
return new_value
4 changes: 2 additions & 2 deletions torchvision/prototype/models/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ...models.alexnet import AlexNet
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param
from ._utils import _deprecated_param, _ovewrite_named_param


__all__ = ["AlexNet", "AlexNetWeights", "alexnet"]
Expand All @@ -34,7 +34,7 @@ def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **k
weights = AlexNetWeights.verify(weights)

if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

model = AlexNet(**kwargs)

Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ...models.densenet import DenseNet
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param
from ._utils import _deprecated_param, _ovewrite_named_param


__all__ = [
Expand Down Expand Up @@ -53,7 +53,7 @@ def _densenet(
**kwargs: Any,
) -> DenseNet:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)

Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param
from .._utils import _deprecated_param, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
from ..resnet import ResNet50Weights, resnet50

Expand Down Expand Up @@ -91,7 +91,7 @@ def fasterrcnn_resnet50_fpn(

if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91

Expand Down Expand Up @@ -121,7 +121,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
) -> FasterRCNN:
if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91

Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from .._utils import _deprecated_param
from .._utils import _deprecated_param, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50


Expand Down Expand Up @@ -72,8 +72,8 @@ def keypointrcnn_resnet50_fpn(

if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_keypoints = len(weights.meta["keypoint_names"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
num_keypoints = _ovewrite_value_param(num_keypoints, len(weights.meta["keypoint_names"]))
else:
if num_classes is None:
num_classes = 2
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param
from .._utils import _deprecated_param, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50


Expand Down Expand Up @@ -56,7 +56,7 @@ def maskrcnn_resnet50_fpn(

if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91

Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param
from .._utils import _deprecated_param, _ovewrite_value_param
from ..resnet import ResNet50Weights, resnet50


Expand Down Expand Up @@ -56,7 +56,7 @@ def retinanet_resnet50_fpn(

if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91

Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param
from .._utils import _deprecated_param, _ovewrite_value_param
from ..vgg import VGG16Weights, vgg16


Expand Down Expand Up @@ -58,7 +58,7 @@ def ssd300_vgg16(

if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91

Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/detection/ssdlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from .._utils import _deprecated_param
from .._utils import _deprecated_param, _ovewrite_value_param
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large


Expand Down Expand Up @@ -64,7 +64,7 @@ def ssdlite320_mobilenet_v3_large(

if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
elif num_classes is None:
num_classes = 91

Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ...models.efficientnet import EfficientNet, MBConvConfig
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param
from ._utils import _deprecated_param, _ovewrite_named_param


__all__ = [
Expand Down Expand Up @@ -41,7 +41,7 @@ def _efficientnet(
**kwargs: Any,
) -> EfficientNet:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult)
inverted_residual_setting = [
Expand Down
10 changes: 5 additions & 5 deletions torchvision/prototype/models/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param
from ._utils import _deprecated_param, _ovewrite_named_param


__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeights", "googlenet"]
Expand Down Expand Up @@ -37,10 +37,10 @@ def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True,
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "transform_input", True)
_ovewrite_named_param(kwargs, "aux_logits", True)
_ovewrite_named_param(kwargs, "init_weights", False)
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

model = GoogLeNet(**kwargs)

Expand Down
10 changes: 5 additions & 5 deletions torchvision/prototype/models/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param
from ._utils import _deprecated_param, _ovewrite_named_param


__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"]
Expand Down Expand Up @@ -36,10 +36,10 @@ def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool =
original_aux_logits = kwargs.get("aux_logits", True)
if weights is not None:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "transform_input", True)
_ovewrite_named_param(kwargs, "aux_logits", True)
_ovewrite_named_param(kwargs, "init_weights", False)
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

model = Inception3(**kwargs)

Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/mnasnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ...models.mnasnet import MNASNet
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param
from ._utils import _deprecated_param, _ovewrite_named_param


__all__ = [
Expand Down Expand Up @@ -67,7 +67,7 @@ class MNASNet1_3Weights(Weights):

def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: Any) -> MNASNet:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

model = MNASNet(alpha, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ...models.mobilenetv2 import MobileNetV2
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param
from ._utils import _deprecated_param, _ovewrite_named_param


__all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"]
Expand All @@ -34,7 +34,7 @@ def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool =
weights = MobileNetV2Weights.verify(weights)

if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

model = MobileNetV2(**kwargs)

Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param
from ._utils import _deprecated_param, _ovewrite_named_param


__all__ = [
Expand All @@ -27,7 +27,7 @@ def _mobilenet_v3(
**kwargs: Any,
) -> MobileNetV3:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)

Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/models/quantization/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param
from .._utils import _deprecated_param, _ovewrite_named_param
from ..googlenet import GoogLeNetWeights


Expand Down Expand Up @@ -60,12 +60,12 @@ def googlenet(
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "transform_input", True)
_ovewrite_named_param(kwargs, "aux_logits", True)
_ovewrite_named_param(kwargs, "init_weights", False)
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"]
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")

model = QuantizableGoogLeNet(**kwargs)
Expand Down
10 changes: 5 additions & 5 deletions torchvision/prototype/models/quantization/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param
from .._utils import _deprecated_param, _ovewrite_named_param
from ..inception import InceptionV3Weights


Expand Down Expand Up @@ -59,11 +59,11 @@ def inception_v3(
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
kwargs["aux_logits"] = True
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "transform_input", True)
_ovewrite_named_param(kwargs, "aux_logits", True)
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"]
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "fbgemm")

model = QuantizableInception3(**kwargs)
Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/models/quantization/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param
from .._utils import _deprecated_param, _ovewrite_named_param
from ..mobilenetv2 import MobileNetV2Weights


Expand Down Expand Up @@ -58,9 +58,9 @@ def mobilenet_v2(
weights = MobileNetV2Weights.verify(weights)

if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"]
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "qnnpack")

model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/models/quantization/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _deprecated_param
from .._utils import _deprecated_param, _ovewrite_named_param
from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf


Expand All @@ -33,9 +33,9 @@ def _mobilenet_v3_model(
**kwargs: Any,
) -> QuantizableMobileNetV3:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"]
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
backend = kwargs.pop("backend", "qnnpack")

model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
Expand Down
Loading

0 comments on commit 9ff811b

Please sign in to comment.