Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add repeat method to datasets #7198

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
32 changes: 32 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4078,6 +4078,38 @@ def skip(self, n: int) -> "Dataset":
"""
return self.select(range(n, len(self)))

def repeat(self, num_times: int) -> "Dataset":
    """
    Create a new [`Dataset`] that repeats the underlying dataset `num_times` times.

    Like itertools.repeat, repeating once just returns the full dataset.

    Args:
        num_times (`int`):
            Number of times to repeat the dataset. A value `<= 0` returns an empty dataset.

    Raises:
        ValueError: If `num_times` is `None` — map-style datasets are finite and
            cannot be repeated indefinitely (use an [`IterableDataset`] for that).

    Example:
    ```py
    >>> from datasets import load_dataset
    >>> ds = load_dataset("rotten_tomatoes", split="train")
    >>> ds = ds.take(3).repeat(2)
    >>> list(ds)
    [{'label': 1,
     'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
     {'label': 1,
     'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
     {'label': 1, 'text': 'effective but too-tepid biopic'},
     {'label': 1,
     'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
     {'label': 1,
     'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
     {'label': 1, 'text': 'effective but too-tepid biopic'}]
    ```
    """
    if num_times is None:
        raise ValueError("Map style datasets do not support indefinite repetition.")
    # Concatenate num_times copies of self; non-positive counts yield an empty selection.
    return _concatenate_map_style_datasets([self] * num_times) if num_times > 0 else self.select([])

def take(self, n: int) -> "Dataset":
"""
Create a new [`Dataset`] with only the first `n` elements.
Expand Down
91 changes: 91 additions & 0 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,6 +1622,54 @@ def num_shards(self) -> int:
return self.ex_iterable.num_shards


class RepeatExamplesIterable(_BaseExamplesIterable):
    """
    Iterable that repeats the underlying iterable a given number of times.

    Args:
        ex_iterable (`_BaseExamplesIterable`):
            The examples iterable to repeat.
        num_times (`int` or `None`):
            Number of repetitions. `None` repeats indefinitely; a value `<= 0`
            yields no examples at all.
    """

    def __init__(
        self,
        ex_iterable: _BaseExamplesIterable,
        num_times: Optional[int],
    ):
        super().__init__()
        self.ex_iterable = ex_iterable
        self.num_times = num_times

    def _init_state_dict(self) -> dict:
        # Track how many full passes have completed plus the inner iterable's own state,
        # so iteration can resume mid-pass.
        self._state_dict = {
            "repeat_index": 0,
            "ex_iterable": self.ex_iterable._init_state_dict(),
        }
        return self._state_dict

    def __iter__(self):
        # Resume from the saved pass count if a state dict is active.
        repeat_index = self._state_dict["repeat_index"] if self._state_dict else 0
        while True:
            if self.num_times is not None and repeat_index >= max(self.num_times, 0):
                break
            yield from self.ex_iterable
            repeat_index += 1
            if self._state_dict:
                self._state_dict["repeat_index"] = repeat_index
                # Reset the inner iterable's state so the next pass starts from the beginning.
                self._state_dict["ex_iterable"] = self.ex_iterable._init_state_dict()

    def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExamplesIterable":
        """Shuffle the underlying iterable, then repeat."""
        return RepeatExamplesIterable(self.ex_iterable.shuffle_data_sources(generator), num_times=self.num_times)

    def shard_data_sources(self, worker_id: int, num_workers: int) -> "RepeatExamplesIterable":
        """Shard, then repeat shards."""
        return RepeatExamplesIterable(
            self.ex_iterable.shard_data_sources(worker_id, num_workers),
            num_times=self.num_times,
        )

    @property
    def num_shards(self) -> int:
        # NOTE(review): was `n_shards`, but the file's convention (see the sibling
        # iterables' `num_shards` property) is `num_shards` — renamed for consistency
        # so shard-count delegation resolves correctly.
        return self.ex_iterable.num_shards


class TakeExamplesIterable(_BaseExamplesIterable):
def __init__(
self,
Expand Down Expand Up @@ -2762,6 +2810,49 @@ def skip(self, n: int) -> "IterableDataset":
token_per_repo_id=self._token_per_repo_id,
)

def repeat(self, num_times: Optional[int]) -> "IterableDataset":
    """
    Create a new [`IterableDataset`] that repeats the underlying dataset `num_times` times.

    N.B. The effect of calling shuffle after repeat depends significantly on buffer size.
    With buffer_size 1, duplicate data is never seen in the same iteration, even after shuffling:
    ds.repeat(n).shuffle(seed=42, buffer_size=1) is equivalent to ds.shuffle(seed=42, buffer_size=1).repeat(n),
    and only shuffles shard orders within each iteration.
    With buffer size >= (num samples in the dataset * num_times), we get full shuffling of the repeated data, i.e. we can observe duplicates in
    the same iteration.

    Args:
        num_times (`int` or `None`):
            Number of times to repeat the dataset. If `None`, the dataset will be repeated indefinitely.

    Example:
    ```py
    >>> from datasets import load_dataset
    >>> ds = load_dataset("rotten_tomatoes", split="train")
    >>> ds = ds.take(3).repeat(2)
    >>> list(ds)
    [{'label': 1,
     'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
     {'label': 1,
     'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
     {'label': 1, 'text': 'effective but too-tepid biopic'},
     {'label': 1,
     'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
     {'label': 1,
     'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
     {'label': 1, 'text': 'effective but too-tepid biopic'}]
    ```
    """
    # Wrap the current examples iterable; everything else about the dataset is preserved.
    return IterableDataset(
        ex_iterable=RepeatExamplesIterable(self._ex_iterable, num_times=num_times),
        info=self._info,
        split=self._split,
        formatting=self._formatting,
        shuffling=copy.deepcopy(self._shuffling),
        distributed=copy.deepcopy(self._distributed),
        token_per_repo_id=self._token_per_repo_id,
    )

def take(self, n: int) -> "IterableDataset":
"""
Create a new [`IterableDataset`] with only the first `n` elements.
Expand Down
26 changes: 26 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,32 @@ def test_concatenate_pickle(self, in_memory):
self.assertEqual(dset_concat.info.description, "Dataset2\n\nDataset1")
del dset1, dset2, dset3

def test_repeat(self, in_memory):
    # Repeating 3 times should triple every column, preserving order.
    with tempfile.TemporaryDirectory() as tmp_dir:
        with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
            tripled = dset.repeat(3)
            for col in dset.column_names:
                self.assertListEqual(tripled[col], dset[col] * 3)
            del tripled

    # Map-style datasets reject indefinite repetition.
    with tempfile.TemporaryDirectory() as tmp_dir:
        with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
            with pytest.raises(ValueError):
                dset.repeat(None)

    # Zero and negative counts both produce an empty dataset.
    for count in (0, -1):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                emptied = dset.repeat(count)
                self.assertEqual(len(emptied), 0)
                del emptied

def test_flatten(self, in_memory):
with tempfile.TemporaryDirectory() as tmp_dir:
with Dataset.from_dict(
Expand Down
31 changes: 31 additions & 0 deletions tests/test_iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
MappedExamplesIterable,
RandomlyCyclingMultiSourcesExamplesIterable,
RebatchedArrowExamplesIterable,
RepeatExamplesIterable,
SelectColumnsIterable,
ShuffledDataSourcesArrowExamplesIterable,
ShuffledDataSourcesExamplesIterable,
Expand Down Expand Up @@ -1165,6 +1166,28 @@ def test_take_examples_iterable():
assert_load_state_dict_resumes_iteration(take_ex_iterable)


@pytest.mark.parametrize(
    "n, num_times",
    [
        (3, None),
        (3, 3),
        (3, 0),
    ],
)
def test_repeat_examples_iterable(n, num_times):
    """RepeatExamplesIterable yields the base examples `num_times` times (forever when None)."""
    base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
    ex_iterable = RepeatExamplesIterable(base_ex_iterable, num_times=num_times)
    all_examples = [x for _, x in generate_examples_fn(n=n)]
    if num_times is None:
        # Indefinite repetition: spot-check a bounded prefix cycles through the base examples.
        max_iters = 135
        it = iter(ex_iterable)
        for i in range(max_iters):
            assert next(it)[1] == all_examples[i % len(all_examples)], f"iteration {i} failed,"
    else:
        assert [x for _, x in ex_iterable] == all_examples * max(num_times, 0)

def test_vertically_concatenated_examples_iterable():
ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label": 10})
ex_iterable2 = ExamplesIterable(generate_examples_fn, {"label": 5})
Expand Down Expand Up @@ -1735,6 +1758,14 @@ def test_iterable_dataset_take(dataset: IterableDataset, n):
assert list(take_dataset) == list(dataset)[:n]


@pytest.mark.parametrize("n", [0, 2])
def test_iterable_dataset_repeat(dataset: IterableDataset, n):
    """repeat(n) wraps the examples iterable in RepeatExamplesIterable and yields the data n times."""
    repeated = dataset.repeat(n)
    underlying = repeated._ex_iterable
    assert isinstance(underlying, RepeatExamplesIterable)
    assert underlying.num_times == n
    assert list(repeated) == n * list(dataset)


def test_iterable_dataset_shard():
num_examples = 20
num_shards = 5
Expand Down