More Multiweight support cleanups (#4948)
* Updated the link for the densenet recipe.

* Set the default value of `num_classes` and `num_keypoints` to `None`.

* Provided helper methods for parameter checks to reduce duplicate code.

* Raised errors on silent config overwrites from weight meta-data and legacy builders.

* Changed the order of arguments and fixed mypy errors.

* Made the builders fully backwards compatible (BC).

* Added "default" weights support that always returns the best available weights.
datumbox authored Nov 25, 2021
1 parent 09e759e commit 18cf5ab
Showing 32 changed files with 703 additions and 420 deletions.
15 changes: 10 additions & 5 deletions test/test_prototype_models.py
@@ -31,14 +31,19 @@ def get_models_with_module_names(module):


 @pytest.mark.parametrize(
-    "model_fn, weight",
+    "model_fn, name, weight",
     [
-        (models.resnet50, models.ResNet50Weights.ImageNet1K_RefV2),
-        (models.quantization.resnet50, models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1),
+        (models.resnet50, "ImageNet1K_RefV1", models.ResNet50Weights.ImageNet1K_RefV1),
+        (models.resnet50, "default", models.ResNet50Weights.ImageNet1K_RefV2),
+        (
+            models.quantization.resnet50,
+            "ImageNet1K_FBGEMM_RefV1",
+            models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1,
+        ),
     ],
 )
-def test_get_weight(model_fn, weight):
-    assert models._api.get_weight(model_fn, weight.name) == weight
+def test_get_weight(model_fn, name, weight):
+    assert models._api.get_weight(model_fn, name) == weight


 @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
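The updated test exercises the new name-based lookup, including the "default" alias. A minimal usage sketch based on these tests, assuming the prototype namespace is importable as in the test file:

from torchvision.prototype import models

# Name-based lookup, as exercised by test_get_weight above.
w = models._api.get_weight(models.resnet50, "ImageNet1K_RefV1")

# The new "default" alias resolves to the entry flagged default=True,
# i.e. ResNet50Weights.ImageNet1K_RefV2 per the parametrization above.
best = models._api.get_weight(models.resnet50, "default")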
3 changes: 2 additions & 1 deletion torchvision/prototype/models/_api.py
@@ -30,6 +30,7 @@ class WeightEntry:
     url: str
     transforms: Callable
     meta: Dict[str, Any]
+    default: bool


 class Weights(Enum):
@@ -59,7 +60,7 @@ def verify(cls, obj: Any) -> Any:
     @classmethod
     def from_str(cls, value: str) -> "Weights":
         for v in cls:
-            if v._name_ == value:
+            if v._name_ == value or (value == "default" and v.default):
                 return v
         raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")
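A self-contained sketch of the resolution logic added here, using hypothetical names; the real Weights enum reaches the entry's fields through its value, so `v.default` in the diff corresponds to `v.value.default` below:

from dataclasses import dataclass
from enum import Enum

@dataclass(frozen=True)
class Entry:
    url: str
    default: bool

class MyWeights(Enum):
    RefV1 = Entry(url="https://example.com/v1.pth", default=False)
    RefV2 = Entry(url="https://example.com/v2.pth", default=True)

    @classmethod
    def from_str(cls, value: str) -> "MyWeights":
        # Match either the literal member name or the "default" alias.
        for v in cls:
            if v._name_ == value or (value == "default" and v.value.default):
                return v
        raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")

assert MyWeights.from_str("default") is MyWeights.RefV2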
44 changes: 44 additions & 0 deletions torchvision/prototype/models/_utils.py
@@ -0,0 +1,44 @@
import warnings
from typing import Any, Dict, Optional, TypeVar

from ._api import Weights


W = TypeVar("W", bound=Weights)
V = TypeVar("V")


def _deprecated_param(
    kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: Optional[W]
) -> Optional[W]:
    warnings.warn(f"The parameter '{deprecated_param}' is deprecated, please use '{new_param}' instead.")
    if kwargs.pop(deprecated_param):
        if default_value is not None:
            return default_value
        else:
            raise ValueError("No checkpoint is available for model.")
    else:
        return None


def _deprecated_positional(kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: V) -> None:
    warnings.warn(
        f"The positional parameter '{deprecated_param}' is deprecated, please use keyword parameter '{new_param}'"
        + " instead."
    )
    kwargs[deprecated_param] = default_value


def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
    if param in kwargs:
        if kwargs[param] != new_value:
            raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
    else:
        kwargs[param] = new_value


def _ovewrite_value_param(param: Optional[V], new_value: V) -> V:
    if param is not None:
        if param != new_value:
            raise ValueError(f"The parameter '{param}' expected value {new_value} but got {param} instead.")
    return new_value
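How these helpers behave at a call site, with illustrative values only:

kwargs = {"num_classes": 1000}

# Agreeing value: no-op, kwargs are left untouched.
_ovewrite_named_param(kwargs, "num_classes", 1000)

# Conflicting value: what used to be a silent overwrite now raises.
try:
    _ovewrite_named_param(kwargs, "num_classes", 10)
except ValueError as e:
    print(e)  # The parameter 'num_classes' expected value 10 but got 1000 instead.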
11 changes: 7 additions & 4 deletions torchvision/prototype/models/alexnet.py
@@ -1,4 +1,3 @@
-import warnings
 from functools import partial
 from typing import Any, Optional

@@ -8,6 +7,7 @@
 from ...models.alexnet import AlexNet
 from ._api import Weights, WeightEntry
 from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param


 __all__ = ["AlexNet", "AlexNetWeights", "alexnet"]
@@ -25,16 +25,19 @@ class AlexNetWeights(Weights):
             "acc@1": 56.522,
             "acc@5": 79.066,
         },
+        default=True,
     )


 def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = AlexNetWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNetWeights.ImageNet1K_RefV1)
     weights = AlexNetWeights.verify(weights)

     if weights is not None:
-        kwargs["num_classes"] = len(weights.meta["categories"])
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

     model = AlexNet(**kwargs)
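With the helpers wired in, the legacy surface keeps working but is no longer silent; a sketch of the expected behavior, assuming the prototype builder is called directly:

from torchvision.prototype import models

# Legacy kwarg: still builds the model, now warning via _deprecated_param.
m1 = models.alexnet(pretrained=True)  # resolves to AlexNetWeights.ImageNet1K_RefV1

# Conflicting meta-data: previously overwritten silently, now a ValueError,
# because the weights imply num_classes=1000.
m2 = models.alexnet(weights=models.AlexNetWeights.ImageNet1K_RefV1, num_classes=10)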
30 changes: 19 additions & 11 deletions torchvision/prototype/models/densenet.py
@@ -1,5 +1,4 @@
 import re
-import warnings
 from functools import partial
 from typing import Any, Optional, Tuple

@@ -10,6 +9,7 @@
 from ...models.densenet import DenseNet
 from ._api import Weights, WeightEntry
 from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param


 __all__ = [
@@ -53,7 +53,7 @@ def _densenet(
     **kwargs: Any,
 ) -> DenseNet:
     if weights is not None:
-        kwargs["num_classes"] = len(weights.meta["categories"])
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

     model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)

@@ -67,7 +67,7 @@ def _densenet(
     "size": (224, 224),
     "categories": _IMAGENET_CATEGORIES,
     "interpolation": InterpolationMode.BILINEAR,
-    "recipe": None,  # TODO: add here a URL to documentation stating that the weights were ported from LuaTorch
+    "recipe": "https://github.com/pytorch/vision/pull/116",
 }

@@ -80,6 +80,7 @@ class DenseNet121Weights(Weights):
             "acc@1": 74.434,
             "acc@5": 91.972,
         },
+        default=True,
     )

@@ -92,6 +93,7 @@ class DenseNet161Weights(Weights):
             "acc@1": 77.138,
             "acc@5": 93.560,
         },
+        default=True,
     )

@@ -104,6 +106,7 @@ class DenseNet169Weights(Weights):
             "acc@1": 75.600,
             "acc@5": 92.806,
         },
+        default=True,
     )

@@ -116,40 +119,45 @@ class DenseNet201Weights(Weights):
             "acc@1": 76.896,
             "acc@5": 93.370,
         },
+        default=True,
     )


 def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = DenseNet121Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121Weights.ImageNet1K_Community)
     weights = DenseNet121Weights.verify(weights)

     return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)


 def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = DenseNet161Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161Weights.ImageNet1K_Community)
     weights = DenseNet161Weights.verify(weights)

     return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)


 def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = DenseNet169Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169Weights.ImageNet1K_Community)
     weights = DenseNet169Weights.verify(weights)

     return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)


 def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+    if type(weights) == bool and weights:
+        _deprecated_positional(kwargs, "pretrained", "weights", True)
     if "pretrained" in kwargs:
-        warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
-        weights = DenseNet201Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
+        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201Weights.ImageNet1K_Community)
     weights = DenseNet201Weights.verify(weights)

     return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
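Since each DenseNet enum above currently has a single entry flagged default=True, the "default" alias resolves unambiguously; a hedged end-to-end sketch, assuming the prototype namespace exports these names as in this commit:

from torchvision.prototype import models

# Resolve the default entry by name, then build with it.
weights = models.DenseNet121Weights.from_str("default")  # -> ImageNet1K_Community
model = models.densenet121(weights=weights, progress=True)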
