Refactor the get_weights API #5006

Merged · 4 commits · Nov 30, 2021
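For orientation, a minimal sketch of the call pattern the reference scripts move to in this PR (old private helper vs. new public one); the specific weight name is just an example:

```python
from torchvision.prototype import models as PM

# Before: the private helper needed the builder function as well as the weight name,
# e.g. PM._api.get_weight(PM.resnet50, "ImageNet1K_V1").
# After: a single public helper resolves the fully qualified enum name directly.
weights = PM.get_weight("ResNet50_Weights.ImageNet1K_V1")
preprocessing = weights.transforms()  # preset transforms bundled with the weights
```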
3 changes: 1 addition & 2 deletions references/classification/train.py
@@ -158,8 +158,7 @@ def load_data(traindir, valdir, args):
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)
else:
fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
weights = PM.get_weight(args.weights)
preprocessing = weights.transforms()

dataset_test = torchvision.datasets.ImageFolder(
3 changes: 1 addition & 2 deletions references/detection/train.py
@@ -53,8 +53,7 @@ def get_transform(train, args):
elif not args.weights:
return presets.DetectionPresetEval()
else:
fn = PM.detection.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
weights = PM.get_weight(args.weights)
return weights.transforms()


3 changes: 1 addition & 2 deletions references/segmentation/train.py
@@ -38,8 +38,7 @@ def get_transform(train, args):
elif not args.weights:
return presets.SegmentationPresetEval(base_size=520)
else:
fn = PM.segmentation.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
weights = PM.get_weight(args.weights)
return weights.transforms()


3 changes: 1 addition & 2 deletions references/video_classification/train.py
@@ -160,8 +160,7 @@ def main(args):
if not args.weights:
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
else:
fn = PM.video.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
weights = PM.get_weight(args.weights)
transform_test = weights.transforms()

if args.cache_dataset and os.path.exists(cache_path):
6 changes: 5 additions & 1 deletion test/test_models.py
@@ -22,7 +22,11 @@

def get_models_from_module(module):
# TODO add a registration mechanism to torchvision.models
return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
return [
v
for k, v in module.__dict__.items()
if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
]


@pytest.fixture
36 changes: 23 additions & 13 deletions test/test_prototype_models.py
@@ -24,6 +24,19 @@ def _get_parent_module(model_fn):
return module


def _get_model_weights(model_fn):
module = _get_parent_module(model_fn)
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
try:
return next(
v
for k, v in module.__dict__.items()
if k.endswith(weights_name) and k.replace(weights_name, "").lower() == model_fn.__name__
)
except StopIteration:
return None


def _build_model(fn, **kwargs):
try:
model = fn(**kwargs)
@@ -36,24 +49,22 @@ def _build_model(fn, **kwargs):


@pytest.mark.parametrize(
"model_fn, name, weight",
"name, weight",
[
(models.resnet50, "ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
(models.resnet50, "default", models.ResNet50_Weights.ImageNet1K_V2),
("ResNet50_Weights.ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
("ResNet50_Weights.default", models.ResNet50_Weights.ImageNet1K_V2),
(
models.quantization.resnet50,
"default",
"ResNet50_QuantizedWeights.default",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2,
),
(
models.quantization.resnet50,
"ImageNet1K_FBGEMM_V1",
"ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1,
),
],
)
def test_get_weight(model_fn, name, weight):
assert models._api.get_weight(model_fn, name) == weight
def test_get_weight(name, weight):
assert models.get_weight(name) == weight


@pytest.mark.parametrize(
@@ -65,10 +76,9 @@ def test_get_weight(model_fn, name, weight):
+ TM.get_models_from_module(models.video),
)
def test_naming_conventions(model_fn):
model_name = model_fn.__name__
module = _get_parent_module(model_fn)
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
assert model_name in set(x.replace(weights_name, "").lower() for x in module.__dict__ if x.endswith(weights_name))
weights_enum = _get_model_weights(model_fn)
assert weights_enum is not None
assert len(weights_enum) == 0 or hasattr(weights_enum, "default")


@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
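For clarity, the convention the updated test exercises: each model builder (e.g. resnet50) is expected to have a weights enum whose name, stripped of its _Weights or _QuantizedWeights suffix and lowercased, equals the builder's name. A minimal standalone sketch of that check (illustrative helper, not torchvision code):

```python
def _matches_builder(builder_name: str, enum_name: str, quantized: bool = False) -> bool:
    # Mirrors the suffix/lowercase comparison used in _get_model_weights above.
    suffix = "_QuantizedWeights" if quantized else "_Weights"
    return enum_name.endswith(suffix) and enum_name.replace(suffix, "").lower() == builder_name

assert _matches_builder("resnet50", "ResNet50_Weights")
assert _matches_builder("resnet50", "ResNet50_QuantizedWeights", quantized=True)
assert not _matches_builder("resnet50", "ResNet101_Weights")
```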
1 change: 1 addition & 0 deletions torchvision/prototype/models/__init__.py
@@ -15,3 +15,4 @@
from . import quantization
from . import segmentation
from . import video
from ._api import get_weight
57 changes: 26 additions & 31 deletions torchvision/prototype/models/_api.py
@@ -1,7 +1,9 @@
import importlib
import inspect
import sys
from collections import OrderedDict
from dataclasses import dataclass, fields
from enum import Enum
from inspect import signature
from typing import Any, Callable, Dict

from ..._internally_replaced_utils import load_state_dict_from_url
@@ -30,7 +32,6 @@ class Weights:
url: str
transforms: Callable
meta: Dict[str, Any]
default: bool


class WeightsEnum(Enum):
@@ -50,7 +51,7 @@ def __init__(self, value: Weights):
def verify(cls, obj: Any) -> Any:
if obj is not None:
if type(obj) is str:
obj = cls.from_str(obj)
obj = cls.from_str(obj.replace(cls.__name__ + ".", ""))
elif not isinstance(obj, cls):
raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
@@ -59,8 +60,8 @@ def verify(cls, obj: Any) -> Any:

@classmethod
def from_str(cls, value: str) -> "WeightsEnum":
for v in cls:
if v._name_ == value or (value == "default" and v.default):
for k, v in cls.__members__.items():
if k == value:
return v
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")

@@ -78,41 +79,35 @@ def __getattr__(self, name):
return super().__getattr__(name)


def get_weight(fn: Callable, weight_name: str) -> WeightsEnum:
def get_weight(name: str) -> WeightsEnum:
Member: I'm wondering: shouldn't this return a Weights instance, instead of WeightsEnum? Same for from_str (and in the return section below).

Contributor Author: I think we need to return a WeightsEnum value. That is the value of an Enum which maintains information about the class it comes from (ResNet50_Weights). Returning a Weights loses the information necessary to validate that the right type of weights were passed to the method.

"""
Gets the weight enum of a specific model builder method and weight name combination.
Gets the weight enum value by its full name. Example: "ResNet50_Weights.ImageNet1K_V1"

Args:
fn (Callable): The builder method used to create the model.
weight_name (str): The name of the weight enum entry of the specific model.
name (str): The name of the weight enum entry.

Returns:
WeightsEnum: The requested weight enum.
"""
sig = signature(fn)
if "weights" not in sig.parameters:
raise ValueError("The method is missing the 'weights' parameter.")
try:
enum_name, value_name = name.split(".")
except ValueError:
raise ValueError(f"Invalid weight name provided: '{name}'.")

base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1])
base_module = importlib.import_module(base_module_name)
model_modules = [base_module] + [
x[1] for x in inspect.getmembers(base_module, inspect.ismodule) if x[1].__file__.endswith("__init__.py")
]

ann = signature(fn).parameters["weights"].annotation
weights_enum = None
if isinstance(ann, type) and issubclass(ann, WeightsEnum):
weights_enum = ann
else:
# handle cases like Union[Optional, T]
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
for t in ann.__args__: # type: ignore[union-attr]
if isinstance(t, type) and issubclass(t, WeightsEnum):
# ensure the name exists. handles builders with multiple types of weights like in quantization
try:
t.from_str(weight_name)
except ValueError:
continue
weights_enum = t
break
for m in model_modules:
potential_class = m.__dict__.get(enum_name, None)
if potential_class is not None and issubclass(potential_class, WeightsEnum):
weights_enum = potential_class
break

if weights_enum is None:
raise ValueError(
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."
)
raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")

return weights_enum.from_str(weight_name)
return weights_enum.from_str(value_name)
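As a rough illustration of the point raised in the review thread above (a sketch assuming the prototype namespace exports ResNet50_Weights, as the tests suggest): the returned value is an enum member, so its owning class is recoverable, which is what lets a builder reject weights of the wrong type:

```python
from torchvision.prototype import models as PM

w = PM.get_weight("ResNet50_Weights.ImageNet1K_V1")

# The member keeps a reference to its enum class ...
assert isinstance(w, PM.ResNet50_Weights)

# ... while the underlying Weights dataclass (w.value, with its url/transforms/meta fields)
# carries no such type information, so returning it would defeat verify()-style checks.
```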
2 changes: 1 addition & 1 deletion torchvision/prototype/models/alexnet.py
@@ -25,8 +25,8 @@ class AlexNet_Weights(WeightsEnum):
"acc@1": 56.522,
"acc@5": 79.066,
},
default=True,
)
default = ImageNet1K_V1


def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
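One way to read the `default = ImageNet1K_V1` pattern introduced here: the per-entry `default=True` flag on the dataclass is replaced by an enum alias, so the default becomes a property of the enum class itself. A minimal generic sketch of how such an alias behaves (plain Enum, not torchvision code):

```python
from enum import Enum

class Example_Weights(Enum):
    V1 = 1
    V2 = 2
    default = 2  # same value as V2, so this is an alias rather than a new member

assert Example_Weights.default is Example_Weights.V2     # attribute access resolves to V2
assert Example_Weights["default"] is Example_Weights.V2  # name lookup via __members__ works too
assert [m.name for m in Example_Weights] == ["V1", "V2"]  # iteration skips aliases
```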
8 changes: 4 additions & 4 deletions torchvision/prototype/models/densenet.py
@@ -80,8 +80,8 @@ class DenseNet121_Weights(WeightsEnum):
"acc@1": 74.434,
"acc@5": 91.972,
},
default=True,
)
default = ImageNet1K_V1


class DenseNet161_Weights(WeightsEnum):
@@ -93,8 +93,8 @@ class DenseNet161_Weights(WeightsEnum):
"acc@1": 77.138,
"acc@5": 93.560,
},
default=True,
)
default = ImageNet1K_V1


class DenseNet169_Weights(WeightsEnum):
@@ -106,8 +106,8 @@ class DenseNet169_Weights(WeightsEnum):
"acc@1": 75.600,
"acc@5": 92.806,
},
default=True,
)
default = ImageNet1K_V1


class DenseNet201_Weights(WeightsEnum):
@@ -119,8 +119,8 @@ class DenseNet201_Weights(WeightsEnum):
"acc@1": 76.896,
"acc@5": 93.370,
},
default=True,
)
default = ImageNet1K_V1


def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
6 changes: 3 additions & 3 deletions torchvision/prototype/models/detection/faster_rcnn.py
@@ -45,8 +45,8 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0,
},
default=True,
)
default = Coco_V1


class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
@@ -58,8 +58,8 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"map": 32.8,
},
default=True,
)
default = Coco_V1


class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
@@ -71,8 +71,8 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"map": 22.8,
},
default=True,
)
default = Coco_V1


def fasterrcnn_resnet50_fpn(
3 changes: 1 addition & 2 deletions torchvision/prototype/models/detection/keypoint_rcnn.py
@@ -35,7 +35,6 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"box_map": 50.6,
"kp_map": 61.1,
},
default=False,
)
Coco_V1 = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
@@ -46,8 +45,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"box_map": 54.6,
"kp_map": 65.0,
},
default=True,
)
default = Coco_V1


def keypointrcnn_resnet50_fpn(
2 changes: 1 addition & 1 deletion torchvision/prototype/models/detection/mask_rcnn.py
@@ -34,8 +34,8 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
"box_map": 37.9,
"mask_map": 34.6,
},
default=True,
)
default = Coco_V1


def maskrcnn_resnet50_fpn(
2 changes: 1 addition & 1 deletion torchvision/prototype/models/detection/retinanet.py
@@ -34,8 +34,8 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
"map": 36.4,
},
default=True,
)
default = Coco_V1


def retinanet_resnet50_fpn(
2 changes: 1 addition & 1 deletion torchvision/prototype/models/detection/ssd.py
@@ -33,8 +33,8 @@ class SSD300_VGG16_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
"map": 25.1,
},
default=True,
)
default = Coco_V1


def ssd300_vgg16(
2 changes: 1 addition & 1 deletion torchvision/prototype/models/detection/ssdlite.py
@@ -38,8 +38,8 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large",
"map": 21.3,
},
default=True,
)
default = Coco_V1


def ssdlite320_mobilenet_v3_large(