Support skip_first_batches for XLA #2966
Conversation
Thanks! Overall this looks fine, just one suggestion :)
src/accelerate/data_loader.py
Outdated
@@ -1083,6 +1088,12 @@ def skip_first_batches(dataloader, num_batches=0):
     """
     Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
     """
+    is_xla_dataloader = False
+    if is_torch_xla_available() and isinstance(dataloader, MpDeviceLoaderWrapper):
At this point I believe we can use `PartialState().distributed_type == DistributedType.XLA`
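A minimal sketch of what that suggestion could look like inside `skip_first_batches` (the unwrap/re-wrap handling of the XLA loader is elided, and the import locations shown are assumptions, not necessarily what the final PR uses):

```python
from accelerate.state import PartialState
from accelerate.utils import DistributedType


def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
    """
    # Use the shared process state rather than an isinstance check on the wrapper.
    is_xla_dataloader = False
    if PartialState().distributed_type == DistributedType.XLA:
        is_xla_dataloader = True
        # ... unwrap the MpDeviceLoaderWrapper here, rebuild the inner dataloader,
        # and re-wrap it before returning (rest of the function unchanged).
    ...
```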
OK, I changed it to `AcceleratorState` because that class has already been imported and it aligns with its usage in the `prepare_data_loader` method. I believe the `distributed_type` of these two states is shared.
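For what it's worth, a quick sketch of why the two should agree (assuming the shared state has already been set up, e.g. by constructing an `Accelerator`):

```python
from accelerate import Accelerator
from accelerate.state import AcceleratorState, PartialState

# Constructing the Accelerator initializes the shared process state once.
accelerator = Accelerator()

# AcceleratorState is built on top of PartialState, so both report the same
# distributed_type (e.g. DistributedType.XLA when running on TPUs).
assert AcceleratorState().distributed_type == PartialState().distributed_type
```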
`PartialState` is safer and better for these types of situations.
Thank you for the correction. I have changed it to `PartialState`. I'm not very familiar with these states :)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for enabling!
Hi @muellerzr, can you merge this PR? The error in the tests seems unrelated to this PR.
@yitongh yes indeed!
What does this PR do?
At present, using the `resume_from_checkpoint` feature in the Transformers `Trainer` results in an error because `skip_first_batches` does not support the `MpDeviceLoaderWrapper` used by XLA. This PR adds support for it.
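For context, a rough sketch of the resume flow that previously failed on XLA (the dataset and `num_batches_seen` below are illustrative; on TPUs the dataloader returned by `accelerator.prepare` is an `MpDeviceLoaderWrapper`):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.data_loader import skip_first_batches

accelerator = Accelerator()

# Toy dataloader; on XLA, `prepare` returns an MpDeviceLoaderWrapper around it.
dataset = TensorDataset(torch.arange(1000, dtype=torch.float32).unsqueeze(1))
train_dataloader = accelerator.prepare(DataLoader(dataset, batch_size=8))

# Batches already consumed before the checkpoint was saved (illustrative value).
num_batches_seen = 100

# Before this PR, this call raised an error on XLA because skip_first_batches
# did not recognize the MpDeviceLoaderWrapper.
skipped_dataloader = skip_first_batches(train_dataloader, num_batches_seen)

for batch in skipped_dataloader:
    pass  # resume training from the first unseen batch
```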
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@muellerzr