diff --git a/docs/source/models.rst b/docs/source/models.rst index 3cf52389e82..410e0e42e99 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -176,6 +176,15 @@ Most pre-trained models can be accessed directly via PyTorch Hub without having weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2") model = torch.hub.load("pytorch/vision", "resnet50", weights=weights) +You can also retrieve all the available weights of a specific model via PyTorch Hub by doing: + +.. code:: python + + import torch + + weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50") + print([weight for weight in weight_enum]) + The only exception to the above are the detection models included on :mod:`torchvision.models.detection`. These models require TorchVision to be installed because they depend on custom C++ operators. diff --git a/hubconf.py b/hubconf.py index 1231b0bbea6..57ce7a0d12a 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,7 +1,7 @@ # Optional list of dependencies required by the package dependencies = ["torch"] -from torchvision.models import get_weight +from torchvision.models import get_model_weights, get_weight from torchvision.models.alexnet import alexnet from torchvision.models.convnext import convnext_base, convnext_large, convnext_small, convnext_tiny from torchvision.models.densenet import densenet121, densenet161, densenet169, densenet201 diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 4df988dcc9a..c2886d2ed99 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -115,7 +115,7 @@ def get_weight(name: str) -> WeightsEnum: W = TypeVar("W", bound=WeightsEnum) -def get_model_weights(model: Union[Callable, str]) -> W: +def get_model_weights(name: Union[Callable, str]) -> W: """ Retuns the weights enum class associated to the given model. @@ -127,8 +127,7 @@ def get_model_weights(model: Union[Callable, str]) -> W: Returns: weights_enum (W): The weights enum class associated with the model. """ - if isinstance(model, str): - model = find_model(model) + model = find_model(name) if isinstance(name, str) else name return cast(W, _get_enum_from_fn(model))