Skip to content

Commit

Permalink
Add "default" weights support that returns always the best weights.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Nov 25, 2021
1 parent 8fa1e47 commit 8aa1c23
Show file tree
Hide file tree
Showing 31 changed files with 105 additions and 6 deletions.
15 changes: 10 additions & 5 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,19 @@ def get_models_with_module_names(module):


@pytest.mark.parametrize(
"model_fn, weight",
"model_fn, name, weight",
[
(models.resnet50, models.ResNet50Weights.ImageNet1K_RefV2),
(models.quantization.resnet50, models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1),
(models.resnet50, "ImageNet1K_RefV1", models.ResNet50Weights.ImageNet1K_RefV1),
(models.resnet50, "default", models.ResNet50Weights.ImageNet1K_RefV2),
(
models.quantization.resnet50,
"ImageNet1K_FBGEMM_RefV1",
models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1,
),
],
)
def test_get_weight(model_fn, weight):
assert models._api.get_weight(model_fn, weight.name) == weight
def test_get_weight(model_fn, name, weight):
assert models._api.get_weight(model_fn, name) == weight


@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
Expand Down
3 changes: 2 additions & 1 deletion torchvision/prototype/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class WeightEntry:
url: str
transforms: Callable
meta: Dict[str, Any]
default: bool


class Weights(Enum):
Expand Down Expand Up @@ -59,7 +60,7 @@ def verify(cls, obj: Any) -> Any:
@classmethod
def from_str(cls, value: str) -> "Weights":
for v in cls:
if v._name_ == value:
if v._name_ == value or (value == "default" and v.default):
return v
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")

Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class AlexNetWeights(Weights):
"acc@1": 56.522,
"acc@5": 79.066,
},
default=True,
)


Expand Down
4 changes: 4 additions & 0 deletions torchvision/prototype/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class DenseNet121Weights(Weights):
"acc@1": 74.434,
"acc@5": 91.972,
},
default=True,
)


Expand All @@ -92,6 +93,7 @@ class DenseNet161Weights(Weights):
"acc@1": 77.138,
"acc@5": 93.560,
},
default=True,
)


Expand All @@ -104,6 +106,7 @@ class DenseNet169Weights(Weights):
"acc@1": 75.600,
"acc@5": 92.806,
},
default=True,
)


Expand All @@ -116,6 +119,7 @@ class DenseNet201Weights(Weights):
"acc@1": 76.896,
"acc@5": 93.370,
},
default=True,
)


Expand Down
3 changes: 3 additions & 0 deletions torchvision/prototype/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class FasterRCNNResNet50FPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0,
},
default=True,
)


Expand All @@ -57,6 +58,7 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"map": 32.8,
},
default=True,
)


Expand All @@ -69,6 +71,7 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"map": 22.8,
},
default=True,
)


Expand Down
2 changes: 2 additions & 0 deletions torchvision/prototype/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class KeypointRCNNResNet50FPNWeights(Weights):
"box_map": 50.6,
"kp_map": 61.1,
},
default=False,
)
Coco_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
Expand All @@ -45,6 +46,7 @@ class KeypointRCNNResNet50FPNWeights(Weights):
"box_map": 54.6,
"kp_map": 65.0,
},
default=True,
)


Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class MaskRCNNResNet50FPNWeights(Weights):
"box_map": 37.9,
"mask_map": 34.6,
},
default=True,
)


Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class RetinaNetResNet50FPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
"map": 36.4,
},
default=True,
)


Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class SSD300VGG16Weights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
"map": 25.1,
},
default=True,
)


Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/detection/ssdlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class SSDlite320MobileNetV3LargeFPNWeights(Weights):
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large",
"map": 21.3,
},
default=True,
)


Expand Down
8 changes: 8 additions & 0 deletions torchvision/prototype/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class EfficientNetB0Weights(Weights):
"acc@1": 77.692,
"acc@5": 93.532,
},
default=True,
)


Expand All @@ -92,6 +93,7 @@ class EfficientNetB1Weights(Weights):
"acc@1": 78.642,
"acc@5": 94.186,
},
default=True,
)


Expand All @@ -105,6 +107,7 @@ class EfficientNetB2Weights(Weights):
"acc@1": 80.608,
"acc@5": 95.310,
},
default=True,
)


Expand All @@ -118,6 +121,7 @@ class EfficientNetB3Weights(Weights):
"acc@1": 82.008,
"acc@5": 96.054,
},
default=True,
)


Expand All @@ -131,6 +135,7 @@ class EfficientNetB4Weights(Weights):
"acc@1": 83.384,
"acc@5": 96.594,
},
default=True,
)


Expand All @@ -144,6 +149,7 @@ class EfficientNetB5Weights(Weights):
"acc@1": 83.444,
"acc@5": 96.628,
},
default=True,
)


Expand All @@ -157,6 +163,7 @@ class EfficientNetB6Weights(Weights):
"acc@1": 84.008,
"acc@5": 96.916,
},
default=True,
)


Expand All @@ -170,6 +177,7 @@ class EfficientNetB7Weights(Weights):
"acc@1": 84.122,
"acc@5": 96.908,
},
default=True,
)


Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class GoogLeNetWeights(Weights):
"acc@1": 69.778,
"acc@5": 89.530,
},
default=True,
)


Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class InceptionV3Weights(Weights):
"acc@1": 77.294,
"acc@5": 93.450,
},
default=True,
)


Expand Down
2 changes: 2 additions & 0 deletions torchvision/prototype/models/mnasnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class MNASNet0_5Weights(Weights):
"acc@1": 67.734,
"acc@5": 87.490,
},
default=True,
)


Expand All @@ -57,6 +58,7 @@ class MNASNet1_0Weights(Weights):
"acc@1": 73.456,
"acc@5": 91.510,
},
default=True,
)


Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class MobileNetV2Weights(Weights):
"acc@1": 71.878,
"acc@5": 90.286,
},
default=True,
)


Expand Down
3 changes: 3 additions & 0 deletions torchvision/prototype/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class MobileNetV3LargeWeights(Weights):
"acc@1": 74.042,
"acc@5": 91.340,
},
default=False,
)
ImageNet1K_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
Expand All @@ -64,6 +65,7 @@ class MobileNetV3LargeWeights(Weights):
"acc@1": 75.274,
"acc@5": 92.566,
},
default=True,
)


Expand All @@ -77,6 +79,7 @@ class MobileNetV3SmallWeights(Weights):
"acc@1": 67.668,
"acc@5": 87.402,
},
default=True,
)


Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/quantization/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class QuantizedGoogLeNetWeights(Weights):
"acc@1": 69.826,
"acc@5": 89.404,
},
default=True,
)


Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/quantization/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class QuantizedInceptionV3Weights(Weights):
"acc@1": 77.176,
"acc@5": 93.354,
},
default=True,
)


Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/quantization/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class QuantizedMobileNetV2Weights(Weights):
"acc@1": 71.658,
"acc@5": 90.150,
},
default=True,
)


Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/quantization/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class QuantizedMobileNetV3LargeWeights(Weights):
"acc@1": 73.004,
"acc@5": 90.858,
},
default=True,
)


Expand Down
5 changes: 5 additions & 0 deletions torchvision/prototype/models/quantization/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class QuantizedResNet18Weights(Weights):
"acc@1": 69.494,
"acc@5": 88.882,
},
default=True,
)


Expand All @@ -86,6 +87,7 @@ class QuantizedResNet50Weights(Weights):
"acc@1": 75.920,
"acc@5": 92.814,
},
default=False,
)
ImageNet1K_FBGEMM_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
Expand All @@ -96,6 +98,7 @@ class QuantizedResNet50Weights(Weights):
"acc@1": 80.282,
"acc@5": 94.976,
},
default=True,
)


Expand All @@ -109,6 +112,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
"acc@1": 78.986,
"acc@5": 94.480,
},
default=False,
)
ImageNet1K_FBGEMM_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
Expand All @@ -119,6 +123,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights):
"acc@1": 82.574,
"acc@5": 96.132,
},
default=True,
)


Expand Down
2 changes: 2 additions & 0 deletions torchvision/prototype/models/quantization/shufflenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class QuantizedShuffleNetV2_x0_5Weights(Weights):
"acc@1": 57.972,
"acc@5": 79.780,
},
default=True,
)


Expand All @@ -82,6 +83,7 @@ class QuantizedShuffleNetV2_x1_0Weights(Weights):
"acc@1": 68.360,
"acc@5": 87.582,
},
default=True,
)


Expand Down
Loading

0 comments on commit 8aa1c23

Please sign in to comment.