-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Test optimizer to device #20062
base: master
Are you sure you want to change the base?
Test optimizer to device #20062
Conversation
|
||
# Try from_dict | ||
# These all pretend that we have an appropriate prototype, I don't think we can actually do this since | ||
# all we may have is a CPU pickle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test we have for pytorch load_state_dict
being able to read a CPU checkpoint into an appropriate GPU optimizer is here: https://github.com/pytorch/pytorch/blob/main/test/test_optim.py#L1545-L1574
The code above is also how I expect checkpointing to happen, without the need of an explicit move to device.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have updated this test to be more explicit about what is going on, please take a look and see if it makes sense since the test you linked doesn't look at thorough as far as I can tell.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea, the case we test is moving from CPU to GPU, and I see you test more combinations.
# Use from_dict with cpu prototype, fused = True | ||
opt_gpu_dict = optimizer_on_device[gpu_device + "_fused_True"].state_dict() | ||
cpu_prototype = copy.deepcopy(optimizer_on_device["cpu"]) | ||
cpu_prototype.load_state_dict(opt_gpu_dict) # This should give an error / refuse to allow fused = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI, for older versions of torch this should indeed be not allowed/would break. But since torch 2.4, there is a fused CPU Adam(W)/SGD/Adagrad, so fused=True on CPU for these optimizers would be valid.
What does this PR do?
Pursuant to #19955
add an extended test for _optimizer_to_device that explicitly tests moving the optimizer across devices.
Fixes #19955
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--20062.org.readthedocs.build/en/20062/