[Resumable IterableDataset] Add IterableDataset state_dict #6658
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Would be nice to have this feature in the new datasets release!
Before finalising this, I'd like to make sure this philosophy makes sense for other libs. cc @muellerzr, I'd love your feedback on this one.
Overall I think this looks like a very nice API decision, and super easy for us to bring into Accelerate as part of `load_state`. Will be nice to not have to use `skip_batches` if a user is using an `IterableDataset`.
One design question though: what's the logic behind `self._state_dict` rather than having it all be `state_dict`? Private stuff doesn't exist in Python, so what's the aim in doing that here and having `state_dict` be a passthrough to it? (If this is a common design pattern over in `datasets`, that's okay.)
We need to copy it every time the user accesses it. Otherwise we would get:

```python
state_dict = ds.state_dict()
for x in ds:
    ...
assert ds.state_dict() == state_dict  # and actually `assert ds.state_dict() is state_dict`
```

The state is updated in place since it's made of dictionaries that are shared with the steps in the `IterableDataset` pipeline.
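The point about returning a copy can be illustrated with a toy iterable (hypothetical names, not the actual `datasets` implementation):

```python
# Toy sketch of why state_dict() must return a copy: the internal state is a
# dict mutated in place during iteration, so returning the live reference
# would let earlier snapshots change under the caller's feet.
import copy


class ToyIterable:
    def __init__(self, data):
        self.data = data
        self._state_dict = {"example_idx": 0}  # shared, mutated in place

    def state_dict(self):
        # Deep-copying here is what makes each snapshot independent.
        return copy.deepcopy(self._state_dict)

    def __iter__(self):
        for i in range(self._state_dict["example_idx"], len(self.data)):
            yield self.data[i]
            self._state_dict["example_idx"] = i + 1


ds = ToyIterable(["a", "b", "c"])
snapshot = ds.state_dict()
for _ in ds:
    pass
assert snapshot == {"example_idx": 0}         # the copy did not move
assert ds.state_dict() == {"example_idx": 3}  # the live state did
```

Without the `deepcopy`, both assertions above would fail, because `snapshot` would be the very dict the pipeline keeps mutating.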
What do you think of making it a full property with a docstring explicitly stating users shouldn't call/modify it directly? I can imagine some exploratory users getting curious.
I don't think users read docstrings of properties that often. What about explaining the logic in the docs?
Sure, I can agree with that!
Just a small note mentioning that it returns a copy of the state dict should be enough imo.
Looking forward as well for this PR to be merged!
Hi, I'm experimenting with LLM pretraining using your code. I found that the time to resume an iterable dataset can be reduced to 5% of the original (my streaming process includes tokenization), but I'm not sure if I'm using it correctly. Could you help me check it? Thanks.
It sounds good to me :)
@lhoestq Hi, if I set a prefetch factor on the DataLoader, does resuming still work correctly?
It does work well if you prefetch and then resume from a state, but you might lose the samples that were in the prefetch buffer of the DataLoader (which could be acceptable in some circumstances). Fortunately we're about to ship an integration with the new StatefulDataLoader from torchdata which can help on this matter :)
Yeah, what I meant is that prefetch might drop a few data entries. Really looking forward to the new StatefulDataLoader :)
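The prefetch caveat can be sketched in plain Python (toy code, not the actual DataLoader internals): the source's counter advances as soon as an item is *produced*, not when the training loop actually consumes it, so a state saved mid-stream points past samples still sitting in the prefetch buffer.

```python
# Sketch of why naive resume after a prefetching DataLoader can drop samples:
# the dataset's own counter runs ahead of what the consumer has seen.
from collections import deque


class CountingSource:
    def __init__(self, n):
        self.n = n
        self.idx = 0  # advances as soon as an item is produced

    def next(self):
        item = self.idx
        self.idx += 1
        return item

    def state_dict(self):
        return {"idx": self.idx}


src = CountingSource(10)
prefetch = deque(src.next() for _ in range(2))  # loader prefetches 2 items
consumed = [prefetch.popleft()]                 # training has only seen item 0
state = src.state_dict()                        # but the source says idx == 2
# Resuming from `state` would restart at item 2, silently skipping item 1,
# which is still sitting unconsumed in the prefetch buffer.
assert consumed == [0]
assert state["idx"] == 2
```

A stateful loader fixes this by checkpointing the prefetch buffer (or its indices) together with the source state.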
@lhoestq Hello, I'm wondering if there are any solutions to work with shuffle now. I've noticed the caveats in the docs.
Hi! I haven't experimented with implementing state_dict for the shuffle buffer. I'm not sure it's a good idea to add this, given a shuffle buffer can be quite big and poses serialization challenges. It shouldn't be difficult to experiment with a simple implementation in `BufferShuffledExamplesIterable` though.
@lhoestq thank you for your quick response! I'll try it :}
@lhoestq Hi, I just revised the class as follows:

```python
from copy import deepcopy

import datasets
import numpy as np


class BufferShuffledExamplesIterable(datasets.iterable_dataset.BufferShuffledExamplesIterable):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _init_state_dict(self) -> dict:
        self._state_dict = self.ex_iterable._init_state_dict()
        # wrap the buffer in a tuple so the list reference is shared with the state
        self._state_dict["mem_buffer"] = ([],)
        self._state_dict["global_example_index"] = 0
        return self._state_dict

    def __iter__(self):
        buffer_size = self.buffer_size
        rng = deepcopy(self.generator)
        indices_iterator = self._iter_random_indices(rng, buffer_size)
        # this is the shuffle buffer that we keep in memory
        mem_buffer = self._state_dict["mem_buffer"][0]
        global_example_index_start = self._state_dict["global_example_index"] if self._state_dict else 0
        # skip the random indices that were already consumed before resuming
        for _ in range(global_example_index_start):
            next(indices_iterator)
        for x in self.ex_iterable:
            if len(mem_buffer) == buffer_size:  # if the buffer is full, pick an example from it
                i = next(indices_iterator)
                if self._state_dict:
                    self._state_dict["global_example_index"] += 1
                yield mem_buffer[i]
                mem_buffer[i] = x  # replace the picked example with a new one
            else:  # otherwise, keep filling the buffer
                mem_buffer.append(x)
        # when we run out of examples, we shuffle the remaining examples in the buffer and yield them
        rng.shuffle(mem_buffer)
        yield from mem_buffer

    def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffledExamplesIterable":
        """Shuffle the wrapped examples iterable as well as the shuffling buffer."""
        return BufferShuffledExamplesIterable(
            self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator
        )

    def shard_data_sources(self, worker_id: int, num_workers: int) -> "BufferShuffledExamplesIterable":
        """Keep only the requested shard."""
        return BufferShuffledExamplesIterable(
            self.ex_iterable.shard_data_sources(worker_id, num_workers),
            buffer_size=self.buffer_size,
            generator=self.generator,
        )

    def load_state_dict(self, state_dict: dict) -> dict:
        def _inner_load_state_dict(state, new_state):
            if new_state is not None and isinstance(state, dict):
                for key in state:
                    state[key] = _inner_load_state_dict(state[key], new_state[key])
                return state
            elif new_state is not None and isinstance(state, list):
                for i in range(len(state)):
                    state[i] = _inner_load_state_dict(state[i], new_state[i])
                return state
            return new_state

        return _inner_load_state_dict(self._state_dict, state_dict)
```

I've noticed that it uses significantly more RAM than the original version and experiences a considerable decrease in GPU utilization. Could you offer some suggestions to address this issue? Or is it prohibited to maintain anything other than simple indices small enough for each worker? 😢
Some `ExamplesIterable` objects copy and store old versions of the state_dict of their parent `ExamplesIterable`. This is the case for batched operations, for example. Copying a shuffle buffer takes some RAM and some time, which can slow down the data loading pipeline.
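A toy illustration of that trade-off (illustrative sizes, not the `datasets` internals): a state made of small indices is cheap to snapshot, while a state that embeds the shuffle buffer must duplicate every buffered example on each copy.

```python
# Contrast snapshotting an index-only state vs a buffer-embedding state.
import copy

buffer = [{"text": "x" * 100} for _ in range(1000)]  # stand-in shuffle buffer
index_state = {"example_idx": 123}

cheap = copy.deepcopy(index_state)                  # tiny, constant-size snapshot
expensive = copy.deepcopy({"mem_buffer": buffer})   # copies all 1000 examples

assert cheap == index_state
assert len(expensive["mem_buffer"]) == 1000
assert expensive["mem_buffer"][0] is not buffer[0]  # a full duplicate in RAM
```

If parent iterables keep old copies of this state around (as described above), every snapshot of a buffer-embedding state multiplies that memory cost.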
A simple implementation of a mechanism to resume an IterableDataset.
It works by restarting at the latest shard and skipping samples. This provides fast resuming (though not instantaneous).
An example of usage, including resuming with torchdata's StatefulDataLoader, is in the docs: https://huggingface.co/docs/datasets/main/en/use_with_pytorch#checkpoint-and-resume