Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

uniformly use skip for both (map-style) Dataset and IterableDataset #521

Merged
merged 2 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
torch >= 2.3.0
torchdata >= 0.8.0
datasets >= 2.19.0
datasets >= 2.21.0
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
sentencepiece
Expand Down
25 changes: 6 additions & 19 deletions torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,12 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

try:
from torchdata.stateful_dataloader import StatefulDataLoader
except ImportError as e:
raise ImportError(
"Please install the latest torchdata nightly to use StatefulDataloader via:"
"pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly"
) from e
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger

from datasets import load_dataset
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node

# map from dataset name to a local directory, or
Expand Down Expand Up @@ -102,7 +96,7 @@ def __init__(
else:
ds = load_dataset(dataset_path, split="train")

# TODO: support shuffling and checkpointing
# TODO: support shuffling
self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
self._tokenizer = tokenizer
Expand Down Expand Up @@ -143,17 +137,10 @@ def _get_data_iter(self):
if self._sample_idx == 0:
return iter(self._data)

# Skip samples
if isinstance(self._data, IterableDataset):
it = iter(self._data)
# Naively iterate through the samples as skip may not be supported
for _ in range(self._sample_idx):
next(it)
return it

# As skipping to the end throws an error in case of map-style dataset, return an empty iterator
if self._sample_idx == len(self._data):
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
return iter([])

return iter(self._data.skip(self._sample_idx))

def load_state_dict(self, state_dict):
Expand All @@ -179,7 +166,7 @@ def state_dict(self) -> Dict[str, Any]:
return {self._rank_id: pickle.dumps(super().state_dict())}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# State being empty is valid, don't log a warning
# State being empty is valid
if not state_dict:
return

Expand Down
Loading