Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the get_weights API #5006

Merged
merged 4 commits into from
Nov 30, 2021

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Nov 29, 2021

Fixes #4652

  • Change the default weights mechanism to sue Enum aliases.
  • Change get_weights to work with full Enum names and make it public.
  • Test that all reference scripts continue to run.

cc @datumbox @bjuncek

@facebook-github-bot
Copy link

facebook-github-bot commented Nov 29, 2021

💊 CI failures summary and remediations

As of commit 6d0cc8e (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

@datumbox datumbox force-pushed the prototype/refactor_get_weight branch 2 times, most recently from 06ec3ca to fdc9d2a Compare November 29, 2021 18:28
@datumbox datumbox requested a review from NicolasHug November 29, 2021 18:31
@datumbox datumbox force-pushed the prototype/refactor_get_weight branch from fdc9d2a to 5c0ec62 Compare November 29, 2021 19:00
@datumbox datumbox force-pushed the prototype/refactor_get_weight branch from 5c0ec62 to 3339918 Compare November 29, 2021 20:30
@datumbox
Copy link
Contributor Author

datumbox commented Nov 29, 2021

The references were validated by running the following commands. All work as expected:

torchrun --nproc_per_node=1 train.py --test-only --weights ResNet50_Weights.ImageNet1K_V2 --model resnet50

python train_quantization.py --device cpu --test-only --weights ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2 --model resnet50

torchrun --nproc_per_node=1 train.py --test-only --weights SSDLite320_MobileNet_V3_Large_Weights.Coco_V1 --model ssdlite320_mobilenet_v3_large

torchrun --nproc_per_node=1 train.py  --dataset coco --model lraspp_mobilenet_v3_large --test-only --weights LRASPP_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1

torchrun --nproc_per_node=1 train.py --data-path /datasets01/kinetics/070618/ --train-dir=val_avi-480p --val-dir=val_avi-480p --batch-size=64 --sync-bn --test-only --weights R2Plus1D_18_Weights.Kinetics400_V1 --cache-dataset

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @datumbox , I made some minor comments and a question, but this LGTM

test/test_prototype_models.py Show resolved Hide resolved
torchvision/prototype/models/_api.py Outdated Show resolved Hide resolved
torchvision/prototype/models/_api.py Outdated Show resolved Hide resolved
torchvision/prototype/models/_api.py Show resolved Hide resolved
torchvision/prototype/models/_api.py Outdated Show resolved Hide resolved
@@ -78,41 +84,37 @@ def __getattr__(self, name):
return super().__getattr__(name)


def get_weight(fn: Callable, weight_name: str) -> WeightsEnum:
def get_weight(name: str) -> WeightsEnum:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering: shoudn't this return a Weight instance, instead of WeightsEnum?
Same for from_str (and in the return section below).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to return a WeightsEnum value. That is the value of an Enum which maintains information about the class it comes from (ResNet50_Weights). Returning a Weights loses the information necessary to validate that the right type of weights were passed to the method.

@datumbox datumbox merged commit 3d8723d into pytorch:main Nov 30, 2021
@datumbox datumbox deleted the prototype/refactor_get_weight branch November 30, 2021 12:56
facebook-github-bot pushed a commit that referenced this pull request Dec 2, 2021
Summary:
* Change the `default` weights mechanism to sue Enum aliases.

* Change `get_weights` to work with full Enum names and make it public.

* Applying improvements from code review.

Reviewed By: NicolasHug

Differential Revision: D32759199

fbshipit-source-id: 13cfa6201125db29f099d2e3a73260d62341a205
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Multi pretrained weights: Cleanups and Refactoring
3 participants