Skip to content

Commit

Permalink
Change get_weights to work with full Enum names and make it public.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Nov 29, 2021
1 parent 761959e commit 3339918
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 43 deletions.
3 changes: 1 addition & 2 deletions references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ def load_data(traindir, valdir, args):
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)
else:
fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
weights = PM.get_weight(args.weights)
preprocessing = weights.transforms()

dataset_test = torchvision.datasets.ImageFolder(
Expand Down
3 changes: 1 addition & 2 deletions references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ def get_transform(train, args):
elif not args.weights:
return presets.DetectionPresetEval()
else:
fn = PM.detection.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
weights = PM.get_weight(args.weights)
return weights.transforms()


Expand Down
3 changes: 1 addition & 2 deletions references/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def get_transform(train, args):
elif not args.weights:
return presets.SegmentationPresetEval(base_size=520)
else:
fn = PM.segmentation.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
weights = PM.get_weight(args.weights)
return weights.transforms()


Expand Down
3 changes: 1 addition & 2 deletions references/video_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,7 @@ def main(args):
if not args.weights:
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
else:
fn = PM.video.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
weights = PM.get_weight(args.weights)
transform_test = weights.transforms()

if args.cache_dataset and os.path.exists(cache_path):
Expand Down
6 changes: 5 additions & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@

def get_models_from_module(module):
# TODO add a registration mechanism to torchvision.models
return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
return [
v
for k, v in module.__dict__.items()
if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
]


@pytest.fixture
Expand Down
16 changes: 7 additions & 9 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,22 @@ def _build_model(fn, **kwargs):


@pytest.mark.parametrize(
"model_fn, name, weight",
"name, weight",
[
(models.resnet50, "ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
(models.resnet50, "default", models.ResNet50_Weights.ImageNet1K_V2),
("ResNet50_Weights.ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
("ResNet50_Weights.default", models.ResNet50_Weights.ImageNet1K_V2),
(
models.quantization.resnet50,
"default",
"ResNet50_QuantizedWeights.default",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2,
),
(
models.quantization.resnet50,
"ImageNet1K_FBGEMM_V1",
"ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1,
),
],
)
def test_get_weight(model_fn, name, weight):
assert models._api.get_weight(model_fn, name) == weight
def test_get_weight(name, weight):
assert models.get_weight(name) == weight


@pytest.mark.parametrize(
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
from . import quantization
from . import segmentation
from . import video
from ._api import get_weight
48 changes: 23 additions & 25 deletions torchvision/prototype/models/_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import importlib
import inspect
import sys
from collections import OrderedDict
from dataclasses import dataclass, fields
from enum import Enum
from inspect import signature
from typing import Any, Callable, Dict

from ..._internally_replaced_utils import load_state_dict_from_url
Expand Down Expand Up @@ -49,7 +51,7 @@ def __init__(self, value: Weights):
def verify(cls, obj: Any) -> Any:
if obj is not None:
if type(obj) is str:
obj = cls.from_str(obj)
obj = cls.from_str(obj.replace(cls.__name__ + ".", ""))
elif not isinstance(obj, cls):
raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
Expand Down Expand Up @@ -82,41 +84,37 @@ def __getattr__(self, name):
return super().__getattr__(name)


def get_weight(fn: Callable, weight_name: str) -> WeightsEnum:
def get_weight(name: str) -> WeightsEnum:
"""
Gets the weight enum of a specific model builder method and weight name combination.
Gets the weight enum value by its full name. Example: ``ResNet50_Weights.ImageNet1K_V1``
Args:
fn (Callable): The builder method used to create the model.
weight_name (str): The name of the weight enum entry of the specific model.
name (str): The name of the weight enum entry.
Returns:
WeightsEnum: The requested weight enum.
"""
sig = signature(fn)
if "weights" not in sig.parameters:
raise ValueError("The method is missing the 'weights' parameter.")
try:
enum_name, value_name = name.split(".")
except ValueError:
raise ValueError(f"Invalid weight name provided: '{name}'.")

base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1])
base_module = importlib.import_module(base_module_name)
model_modules = [base_module] + [
x[1] for x in inspect.getmembers(base_module, inspect.ismodule) if x[1].__file__.endswith("__init__.py")
]

ann = signature(fn).parameters["weights"].annotation
weights_enum = None
if isinstance(ann, type) and issubclass(ann, WeightsEnum):
weights_enum = ann
else:
# handle cases like Union[Optional, T]
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
for t in ann.__args__: # type: ignore[union-attr]
if isinstance(t, type) and issubclass(t, WeightsEnum):
# ensure the name exists. handles builders with multiple types of weights like in quantization
try:
t.from_str(weight_name)
except ValueError:
continue
weights_enum = t
break
for m in model_modules:
potential_class = m.__dict__.get(enum_name, None)
if potential_class is not None and issubclass(potential_class, WeightsEnum):
weights_enum = potential_class
break

if weights_enum is None:
raise ValueError(
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."
)

return weights_enum.from_str(weight_name)
return weights_enum.from_str(value_name)

0 comments on commit 3339918

Please sign in to comment.