-
Notifications
You must be signed in to change notification settings - Fork 172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix GradScaler import on torch >= 2.3.0 #620
base: master
Are you sure you want to change the base?
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #620 +/- ##
==========================================
+ Coverage 85.39% 86.07% +0.67%
==========================================
Files 81 81
Lines 8006 8014 +8
==========================================
+ Hits 6837 6898 +61
+ Misses 1169 1116 -53 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! I have one small concern about compatibility, but apart from that we should be good to go
benchmarks/benchmark_optimizer.py
Outdated
@@ -98,7 +98,7 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: | |||
grad_scaler = hivemind.GradScaler() | |||
else: | |||
# check that hivemind.Optimizer supports regular PyTorch grad scaler as well | |||
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp) | |||
grad_scaler = torch.amp.GradScaler(enabled=args.use_amp) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One small question just to make sure: if the user has torch==1.9.0 (earliest supported version in requirements.txt), does this version already have torch.amp? Maybe we need to bump the version in requirements to make sure it won't break existing code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This part used to exist, but got lost somewhere in splitting all of this work into multiple PRs. I've restored the logic that imports this GradScalar differently here, depending on the installed Torch version.
To answer your question, no, torch.amp
does not exist in 1.9.
d31ca2c
to
4b5c777
Compare
4b5c777
to
58d95fc
Compare
No description provided.