Add a beam writer that doesn't shuffle #5639

Open · wants to merge 1 commit into base: master
7 changes: 6 additions & 1 deletion tensorflow_datasets/core/dataset_builder.py
@@ -762,6 +762,9 @@ def download_and_prepare(
self.info.download_size = dl_manager.downloaded_size
# Write DatasetInfo to disk, even if we haven't computed statistics.
self.info.write_to_directory(self.data_dir)
# The generated DatasetInfo contains references to `tmp_data_dir`
self.info.update_data_dir(self.data_dir)

@@ -1411,11 +1414,13 @@ def _get_filename_template(
self, split_name: str
) -> naming.ShardedFileTemplate:
"""Returns a filename template for the given split."""
if self.info.file_format is None:
raise ValueError("File format is not set!")
return naming.ShardedFileTemplate(
split=split_name,
dataset_name=self.name,
data_dir=self.data_path,
filetype_suffix=self.info.file_format.file_suffix, # pytype: disable=attribute-error
filetype_suffix=self.info.file_format.file_suffix,
)


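For context, `file_suffix` comes from the `FileFormat` enum, so an unset `info.file_format` previously surfaced as an `AttributeError` on `None` (hence the old pytype suppression); the new guard fails fast with an explicit message instead. A minimal sketch, assuming the standard `FileFormat` enum:

from tensorflow_datasets.core import file_adapters

fmt = file_adapters.FileFormat.TFRECORD
assert fmt.file_suffix == 'tfrecord'  # the value the filename template consumes

fmt = None  # an unset info.file_format
# Before this change: AttributeError ('NoneType' object has no attribute 'file_suffix').
# After this change: ValueError('File format is not set!') is raised up front.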
48 changes: 31 additions & 17 deletions tensorflow_datasets/core/dataset_builder_beam_test.py
@@ -39,14 +39,16 @@ class DummyBeamDataset(dataset_builder.GeneratorBasedBuilder):
'valid_725': 725,
}

FEATURE_DICT = features.FeaturesDict({
'image': features.Image(shape=(16, 16, 1)),
'label': features.ClassLabel(names=['dog', 'cat']),
'id': tf.int32,
})

def _info(self):
return dataset_info.DatasetInfo(
builder=self,
features=features.FeaturesDict({
'image': features.Image(shape=(16, 16, 1)),
'label': features.ClassLabel(names=['dog', 'cat']),
'id': tf.int32,
}),
features=self.FEATURE_DICT,
supervised_keys=('x', 'x'),
metadata=dataset_info.BeamMetadataDict(),
)
@@ -71,6 +73,18 @@ def _generate_examples(self, num_examples):
return examples


class UnshuffledDummyBeamDataset(DummyBeamDataset):

def _info(self) -> dataset_info.DatasetInfo:
return dataset_info.DatasetInfo(
builder=self,
features=self.FEATURE_DICT,
supervised_keys=('x', 'x'),
metadata=dataset_info.BeamMetadataDict(),
disable_shuffling=True,
)


class CommonPipelineDummyBeamDataset(DummyBeamDataset):
EXPECTED_METADATA = {
'label_sum_1000': 500,
@@ -156,7 +170,12 @@ def make_default_config():


@pytest.mark.parametrize(
'dataset_cls', [DummyBeamDataset, CommonPipelineDummyBeamDataset]
'dataset_cls',
[
DummyBeamDataset,
CommonPipelineDummyBeamDataset,
UnshuffledDummyBeamDataset,
],
)
@pytest.mark.parametrize(
'make_dl_config',
@@ -178,17 +197,12 @@ def test_beam_datasets(
assert data_path.exists() # Dataset has been generated

# Check number of shards/generated files
_test_shards(
data_path,
pattern='%s-test.tfrecord-{:05}-of-{:05}' % dataset_name,
# Liquid sharding is not guaranteed to always use the same number.
num_shards=builder.info.splits['test'].num_shards,
)
_test_shards(
data_path,
pattern='%s-train.tfrecord-{:05}-of-{:05}' % dataset_name,
num_shards=1,
)
for split in ['test', 'train']:
_test_shards(
data_path,
pattern='%s-%s.tfrecord-{:05}-of-{:05}' % (dataset_name, split),
num_shards=builder.info.splits[split].num_shards,
)

ds = dataset_utils.as_numpy(builder.as_dataset())

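For reference, the pattern strings the consolidated loop hands to `_test_shards` expand mechanically; a small self-contained illustration (dataset name and shard count are made up):

dataset_name, split, num_shards = 'dummy_beam_dataset', 'test', 2
pattern = '%s-%s.tfrecord-{:05}-of-{:05}' % (dataset_name, split)
expected = [pattern.format(i, num_shards) for i in range(num_shards)]
# ['dummy_beam_dataset-test.tfrecord-00000-of-00002',
#  'dummy_beam_dataset-test.tfrecord-00001-of-00002']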
50 changes: 50 additions & 0 deletions tensorflow_datasets/core/file_adapters.py
@@ -26,14 +26,18 @@
from typing import Any, ClassVar, Type, TypeVar

from etils import epy
from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_io
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_module
from tensorflow_datasets.core.utils.lazy_imports_utils import parquet as pq
from tensorflow_datasets.core.utils.lazy_imports_utils import pyarrow as pa
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
from tensorflow_datasets.core.utils.lazy_imports_utils import tfrecordio

with epy.lazy_imports():
# pylint: disable=g-import-not-at-top
from etils import epath
from tensorflow_datasets.core import naming
from tensorflow_datasets.core.utils import file_utils
from tensorflow_datasets.core.utils import type_utils

@@ -167,6 +171,23 @@ def deserialize(cls, raw_example: bytes) -> Any:
"""
return tf.train.Example.FromString(raw_example)

@classmethod
def beam_sink(
cls,
filename_template: naming.ShardedFileTemplate,
num_shards: int | None = None,
) -> beam.PTransform:
"""Returns a Beam sink for writing examples in the given file format."""
raise NotImplementedError()

@classmethod
def num_examples(cls, filename: epath.PathLike) -> int:
"""Returns the number of examples in the given file."""
n = 0
for _ in cls.make_tf_data(filename):
n += 1
return n


class TfRecordFileAdapter(FileAdapter):
"""File adapter for TFRecord file format."""
@@ -205,6 +226,20 @@ def write_examples(
writer.write(serialized_example)
writer.flush()

@classmethod
def beam_sink(
cls,
filename_template: naming.ShardedFileTemplate,
num_shards: int | None = None,
) -> beam.PTransform:
"""Returns a Beam sink for writing examples in the given file format."""
file_path_prefix = filename_template.sharded_filepaths_pattern(
num_shards=num_shards, use_at_notation=True
).removesuffix('@*')
return tfrecordio.WriteToTFRecord(
file_path_prefix=file_path_prefix, num_shards=num_shards
)


class RiegeliFileAdapter(FileAdapter):
"""File adapter for Riegeli file format."""
@@ -291,6 +326,21 @@ def write_examples(
writer.write(serialized_example)
writer.close()

@classmethod
def beam_sink(
cls,
filename_template: naming.ShardedFileTemplate,
num_shards: int | None = None,
) -> beam.PTransform:
"""Returns a Beam sink for writing examples in the given file format."""
return array_record_io.WriteToArrayRecord(
filename_template.sharded_filepaths_pattern(
num_shards=num_shards, use_at_notation=True
),
num_shards=num_shards,
record_writer_options='group_size:1',
)


class ParquetFileAdapter(FileAdapter):
"""File adapter for the [Parquet](https://parquet.apache.org) file format.
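Taken together, the adapter-level `beam_sink` hooks let a caller stay format-agnostic: look up the adapter for a `FileFormat`, then pipe already-serialized examples into whatever sink that format provides. A rough usage sketch, not taken from this PR (paths and example records are illustrative):

import apache_beam as beam
from etils import epath
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import naming

adapter = file_adapters.ADAPTER_FOR_FORMAT[file_adapters.FileFormat.TFRECORD]
template = naming.ShardedFileTemplate(
    data_dir=epath.Path('/tmp/my_dataset/1.0.0'),
    dataset_name='my_dataset',
    split='train',
    filetype_suffix='tfrecord',
)
with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | beam.Create([b'example-0', b'example-1'])  # already-serialized records
        # num_shards is omitted, so the runner decides the shard count.
        | adapter.beam_sink(filename_template=template)
    )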
6 changes: 5 additions & 1 deletion tensorflow_datasets/core/naming.py
@@ -630,6 +630,7 @@ def sharded_filepaths_pattern(
self,
*,
num_shards: int | None = None,
use_at_notation: bool = False,
) -> str:
"""Returns a pattern describing all the file paths captured by this template.

@@ -641,21 +642,24 @@

Args:
num_shards: optional specification of the number of shards.
use_at_notation: whether to use `@*` as the shard suffix (instead of `*`) when `num_shards` is `None`.

Returns:
the pattern describing all shards captured by this template.
"""
a_filepath = self.sharded_filepath(shard_index=0, num_shards=1)
if num_shards:
replacement = f'@{num_shards}'
elif use_at_notation:
replacement = '@*'
else:
replacement = '*'
return _replace_shard_pattern(os.fspath(a_filepath), replacement)

def sharded_filenames(self, num_shards: int) -> list[str]:
return [path.name for path in self.sharded_filepaths(num_shards=num_shards)]

def replace(self, **kwargs: Any) -> 'ShardedFileTemplate':
def replace(self, **kwargs: Any) -> ShardedFileTemplate:
"""Returns a copy of the `ShardedFileTemplate` with updated attributes."""
return dataclasses.replace(self, **kwargs)

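A quick illustration of the three pattern flavors, assuming the same kind of template used elsewhere in this PR (the exact prefix depends on `data_dir`; the outputs shown as comments are approximate):

from etils import epath
from tensorflow_datasets.core import naming

template = naming.ShardedFileTemplate(
    data_dir=epath.Path('/data/my_dataset/1.0.0'),
    dataset_name='my_dataset',
    split='train',
    filetype_suffix='tfrecord',
)
template.sharded_filepaths_pattern(num_shards=4)
# -> '/data/my_dataset/1.0.0/my_dataset-train.tfrecord@4'
template.sharded_filepaths_pattern(use_at_notation=True)
# -> '/data/my_dataset/1.0.0/my_dataset-train.tfrecord@*'  (new flag, used by the beam sinks)
template.sharded_filepaths_pattern()
# -> '/data/my_dataset/1.0.0/my_dataset-train.tfrecord*'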
35 changes: 24 additions & 11 deletions tensorflow_datasets/core/split_builder.py
@@ -35,6 +35,7 @@
# pylint: disable=g-import-not-at-top
from tensorflow_datasets.core import example_serializer
from tensorflow_datasets.core import features as features_lib
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core import utils
@@ -530,17 +531,29 @@ def _build_from_pcollection(
) -> _SplitInfoFuture:
"""Split generator for `beam.PCollection`."""
# TODO(tfds): Should try to add support to `max_examples_per_split`
beam_writer = writer_lib.BeamWriter(
serializer=example_serializer.ExampleSerializer(
self._features.get_serialized_info()
),
filename_template=filename_template,
hash_salt=split_name,
disable_shuffling=disable_shuffling,
shard_config=self._shard_config,
example_writer=self._example_writer,
ignore_duplicates=self._ignore_duplicates,
)
# Use the non-shuffling Beam writer when the dataset disables example shuffling.
if disable_shuffling:
beam_writer = writer_lib.NoShuffleBeamWriter(
serializer=example_serializer.ExampleSerializer(
self._features.get_serialized_info()
),
file_format=file_adapters.FileFormat.from_value(
filename_template.filetype_suffix
),
filename_template=filename_template,
)
else:
beam_writer = writer_lib.BeamWriter(
serializer=example_serializer.ExampleSerializer(
self._features.get_serialized_info()
),
filename_template=filename_template,
hash_salt=split_name,
disable_shuffling=disable_shuffling,
shard_config=self._shard_config,
example_writer=self._example_writer,
ignore_duplicates=self._ignore_duplicates,
)

def _encode_example(key_ex, encode_fn=self._features.encode_example):
# We do not access self._features in this function to avoid pickling the
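One small detail worth calling out: `filename_template.filetype_suffix` is a plain string (for example 'tfrecord'), and `FileFormat.from_value` maps it back onto the enum the adapters are keyed by. A tiny sketch of that conversion, assuming the standard enum values:

from tensorflow_datasets.core import file_adapters

fmt = file_adapters.FileFormat.from_value('tfrecord')
assert fmt is file_adapters.FileFormat.TFRECORD
adapter = file_adapters.ADAPTER_FOR_FORMAT[fmt]  # TfRecordFileAdapter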
88 changes: 88 additions & 0 deletions tensorflow_datasets/core/writer.py
@@ -717,3 +717,91 @@ def finalize(self) -> tuple[list[int], int]:
split_info_path.unlink()

return self._split_info["shard_lengths"], self._split_info["total_size"]


class NoShuffleBeamWriter:
"""Shuffles / writes Examples beam collection to sharded files."""

_OUTPUT_TAG_BUCKETS_LEN_SIZE = "tag_buckets_len_size"

def __init__(
self,
serializer: example_serializer.Serializer,
filename_template: naming.ShardedFileTemplate,
file_format: file_adapters.FileFormat,
):
"""Init BeamWriter.

Note that file "{filepath_prefix}.shard_lengths.json" is also created. It
contains a list with the number of examples in each final shard. Eg:
"[10,11,10,11]".

Args:
serializer: class that can serialize examples.
filename_template: template to format sharded filenames.
file_format: the file format to use.
"""
self._original_state = dict(
serializer=serializer,
filename_template=filename_template,
file_format=file_format,
)
self._file_format = file_format
self._file_adapter = file_adapters.ADAPTER_FOR_FORMAT[self._file_format]
self._filename_template = filename_template
self._serializer = serializer

@functools.lru_cache()
def _get_counter(self, name: str, namespace: str = "BeamWriter"):
return beam.metrics.Metrics.counter(namespace, name)

def inc_counter(self, name: str, value: int = 1) -> None:
self._get_counter(name).inc(value)

def __getstate__(self):
return self._original_state

def __setstate__(self, state):
self.__init__(**state)

def _serialize_example(
self,
key_example: tuple[hashing.HashKey, Example],
) -> bytes:
"""Returns (serialized_example)."""
_, example = key_example
self.inc_counter(name="serialized_examples")
return self._serializer.serialize_example(example)

def write_from_pcollection(self, examples_pcollection):
"""Returns PTransform to write (key, example) PCollection."""
return (
examples_pcollection
| "Serialize" >> beam.Map(self._serialize_example)
| "Write"
>> self._file_adapter.beam_sink(
filename_template=self._filename_template
)
)

def finalize(self) -> tuple[list[int], int]:
"""Returns the computed shard_lengths and total_size.

Returns:
A list with the number of examples stored in each shard, and the total size
of all shard files in bytes.
"""
# We don't know the number of shards, the length of each shard, nor the
# total size, so we compute them here.
length_per_shard = {}
total_size_bytes = 0
prefix = epath.Path(self._filename_template.filepath_prefix())
for shard in self._filename_template.data_dir.glob(f"{prefix.name}*"):
length = self._file_adapter.num_examples(shard)
length_per_shard[shard] = length
total_size_bytes += shard.stat().length
shard_lengths: list[int] = []
for _, length in sorted(length_per_shard.items()):
shard_lengths.append(length)

return shard_lengths, total_size_bytes
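End to end, the new writer is intended to be driven the same way as `BeamWriter`: feed it a PCollection of `(key, example)` pairs inside a pipeline, then call `finalize()` once the pipeline has run. A rough sketch under those assumptions (the feature dict, paths, and examples are illustrative, not from this PR):

import apache_beam as beam
import tensorflow as tf
from etils import epath
from tensorflow_datasets.core import example_serializer
from tensorflow_datasets.core import features as features_lib
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import writer as writer_lib

features = features_lib.FeaturesDict({'id': tf.int32})
beam_writer = writer_lib.NoShuffleBeamWriter(
    serializer=example_serializer.ExampleSerializer(features.get_serialized_info()),
    filename_template=naming.ShardedFileTemplate(
        data_dir=epath.Path('/tmp/my_dataset/1.0.0'),
        dataset_name='my_dataset',
        split='train',
        filetype_suffix='tfrecord',
    ),
    file_format=file_adapters.FileFormat.TFRECORD,
)
with beam.Pipeline() as pipeline:
    examples = pipeline | beam.Create(
        [(i, features.encode_example({'id': i})) for i in range(3)]
    )
    beam_writer.write_from_pcollection(examples)
# Shard lengths and sizes are only discoverable after the pipeline has written
# the files, which is why finalize() walks the output directory.
shard_lengths, total_size = beam_writer.finalize()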