Skip to content

Commit

Permalink
feat: make sample builders optional (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
egorchakov authored Nov 30, 2024
1 parent 304b5ef commit 34c2f16
Show file tree
Hide file tree
Showing 19 changed files with 210 additions and 193 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repos:
- id: pyupgrade

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.0
rev: v0.8.1
hooks:
- id: ruff
args: [--fix]
Expand Down
9 changes: 0 additions & 9 deletions config/_templates/dataset/carla.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,4 @@ inputs:
_target_: rbyte.io.DataFrameFilter
predicate: |
`control.throttle` > 0.5
- _target_: pipefunc.PipeFunc
renames:
input: data_filtered
output_name: samples
func:
_target_: rbyte.RollingWindowSampleBuilder
index_column: _idx_
period: 1i
#@ end
10 changes: 0 additions & 10 deletions config/_templates/dataset/mimicgen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,5 @@ inputs:
func:
_target_: rbyte.io.DataFrameConcater
method: vertical

- _target_: pipefunc.PipeFunc
renames:
input: data_concated
output_name: samples
func:
_target_: rbyte.RollingWindowSampleBuilder
index_column: _idx_
period: 1i

#@ end
#@ end
9 changes: 0 additions & 9 deletions config/_templates/dataset/nuscenes/mcap.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,4 @@ inputs:
_target_: rbyte.io.DataFrameFilter
predicate: |
`/odom/vel.x` >= 8
- _target_: pipefunc.PipeFunc
renames:
input: data_filtered
output_name: samples
func:
_target_: rbyte.RollingWindowSampleBuilder
index_column: (@=camera_topics.values()[0]@)/_idx_
period: 1i
#@ end
9 changes: 0 additions & 9 deletions config/_templates/dataset/nuscenes/rrd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,4 @@ inputs:
_target_: rbyte.io.DataFrameFilter
predicate: |
`/world/ego_vehicle/CAM_FRONT/timestamp` between '2018-07-24 03:28:48' and '2018-07-24 03:28:50'
- _target_: pipefunc.PipeFunc
renames:
input: data_filtered
output_name: samples
func:
_target_: rbyte.RollingWindowSampleBuilder
index_column: (@=camera_entities.values()[0]@)/_idx_
period: 1i
#@ end
24 changes: 17 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
[project]
name = "rbyte"
version = "0.9.1"
version = "0.10.0"
description = "Multimodal PyTorch dataset library"
authors = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
maintainers = [{ name = "Evgenii Gorchakov", email = "[email protected]" }]
dependencies = [
"tensordict>=0.6.2",
"torch",
"numpy",
"polars>=1.15.0",
"polars>=1.16.0",
"pydantic>=2.10.2",
"more-itertools>=10.5.0",
"hydra-core>=1.3.2",
"optree>=0.13.1",
"cachetools>=5.5.0",
"diskcache>=5.6.3",
"jaxtyping>=0.2.34",
"parse>=1.20.2",
"structlog>=24.4.0",
"xxhash>=3.5.0",
Expand All @@ -40,7 +39,7 @@ repo = "https://github.com/yaak-ai/rbyte"

[project.optional-dependencies]
build = ["hatchling>=1.25.0", "grpcio-tools>=1.62.0", "protoletariat==3.2.19"]
visualize = ["rerun-sdk[notebook]>=0.20.0"]
visualize = ["rerun-sdk[notebook]>=0.20.2"]
mcap = [
"mcap>=1.2.1",
"mcap-ros2-support>=0.5.5",
Expand All @@ -54,7 +53,7 @@ video = [
"video-reader-rs>=0.2.1",
]
hdf5 = ["h5py>=3.12.1"]
rrd = ["rerun-sdk>=0.20.0", "pyarrow-stubs"]
rrd = ["rerun-sdk>=0.20.2", "pyarrow-stubs"]

[project.scripts]
rbyte-visualize = 'rbyte.scripts.visualize:main'
Expand All @@ -72,7 +71,7 @@ dev-dependencies = [
"wat-inspector>=0.4.3",
"lovely-tensors>=0.1.18",
"pudb>=2024.1.2",
"ipython>=8.29.0",
"ipython>=8.30.0",
"ipython-autoimport>=0.5",
"pytest>=8.3.3",
"testbook>=0.4.2",
Expand Down Expand Up @@ -129,7 +128,18 @@ skip-magic-trailing-comma = true
preview = true
select = ["ALL"]
fixable = ["ALL"]
ignore = ["A001", "A002", "D", "CPY", "COM812", "F722", "PD901", "ISC001", "TD"]
ignore = [
"A001",
"A002",
"D",
"CPY",
"COM812",
"F722",
"PD901",
"ISC001",
"TD",
"TC006",
]

[tool.ruff.lint.isort]
split-on-trailing-comma = false
Expand Down
81 changes: 48 additions & 33 deletions src/rbyte/io/_mcap/tensor_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from dataclasses import dataclass
from functools import cached_property
from mmap import ACCESS_READ, mmap
from typing import IO, override
from operator import itemgetter
from typing import IO, final, override

import more_itertools as mit
import numpy.typing as npt
import torch
from jaxtyping import Shaped
from mcap.data_stream import ReadDataStream
from mcap.decoder import DecoderFactory
from mcap.opcode import Opcode
Expand All @@ -32,6 +32,7 @@ class MessageIndex:
message_length: int


@final
class McapTensorSource(TensorSource):
@validate_call(config=BaseModel.model_config)
def __init__(
Expand All @@ -47,8 +48,8 @@ def __init__(
with bound_contextvars(
path=path.as_posix(), topic=topic, message_decoder_factory=decoder_factory
):
self._path: FilePath = path
self._validate_crcs: bool = validate_crcs
self._path = path
self._validate_crcs = validate_crcs

summary = SeekingReader(
stream=self._file, validate_crcs=self._validate_crcs
Expand All @@ -73,13 +74,14 @@ def __init__(
logger.error(msg := "missing message decoder")
raise RuntimeError(msg)

self._message_decoder: Callable[[bytes], object] = message_decoder
self._chunk_indexes: tuple[ChunkIndex, ...] = tuple(
self._message_decoder = message_decoder
self._chunk_indexes = tuple(
chunk_index
for chunk_index in summary.chunk_indexes
if self._channel.id in chunk_index.message_index_offsets
)
self._decoder: Callable[[bytes], npt.ArrayLike] = decoder
self._decoder = decoder
self._mmap = None

@property
def _file(self) -> IO[bytes]:
Expand All @@ -89,42 +91,55 @@ def _file(self) -> IO[bytes]:

case None | mmap(closed=True):
with self._path.open("rb") as f:
self._mmap: mmap = mmap(
fileno=f.fileno(), length=0, access=ACCESS_READ
)
self._mmap = mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)

case _:
raise RuntimeError

return self._mmap # pyright: ignore[reportReturnType]

@override
def __getitem__(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]:
frames: Mapping[int, npt.ArrayLike] = {}

message_indexes_by_chunk_start_offset: Mapping[
int, Iterable[tuple[int, MessageIndex]]
] = mit.map_reduce(
zip(indexes, (self._message_indexes[idx] for idx in indexes), strict=True),
keyfunc=lambda x: x[1].chunk_start_offset,
)
def __getitem__(self, indexes: int | Iterable[int]) -> Tensor:
match indexes:
case Iterable():
arrays: Mapping[int, npt.ArrayLike] = {}
message_indexes = (self._message_indexes[idx] for idx in indexes)
indexes_by_chunk_start_offset = mit.map_reduce(
zip(indexes, message_indexes, strict=True),
keyfunc=lambda x: x[1].chunk_start_offset,
)

for chunk_start_offset, chunk_indexes in sorted(
indexes_by_chunk_start_offset.items(), key=itemgetter(0)
):
_ = self._file.seek(chunk_start_offset + 1 + 8)
chunk = Chunk.read(ReadDataStream(self._file))
stream, _ = get_chunk_data_stream(
chunk, validate_crc=self._validate_crcs
)
for index, message_index in sorted(
chunk_indexes, key=lambda x: x[1].message_start_offset
):
stream.read(message_index.message_start_offset - stream.count) # pyright: ignore[reportUnusedCallResult]
message = Message.read(stream, message_index.message_length)
decoded_message = self._message_decoder(message.data)
arrays[index] = self._decoder(decoded_message.data)

for (
chunk_start_offset,
chunk_message_indexes,
) in message_indexes_by_chunk_start_offset.items():
self._file.seek(chunk_start_offset + 1 + 8) # pyright: ignore[reportUnusedCallResult]
chunk = Chunk.read(ReadDataStream(self._file))
stream, _ = get_chunk_data_stream(chunk, validate_crc=self._validate_crcs)
for frame_index, message_index in sorted(
chunk_message_indexes, key=lambda x: x[1].message_start_offset
):
stream.read(message_index.message_start_offset - stream.count) # pyright: ignore[reportUnusedCallResult]
message = Message.read(stream, message_index.message_length)
tensors = [torch.from_numpy(arrays[idx]) for idx in indexes] # pyright: ignore[reportUnknownMemberType]

return torch.stack(tensors)

case _:
message_index = self._message_indexes[indexes]
_ = self._file.seek(message_index.chunk_start_offset + 1 + 8)
chunk = Chunk.read(ReadDataStream(self._file))
stream, _ = get_chunk_data_stream(chunk, self._validate_crcs)
_ = stream.read(message_index.message_start_offset - stream.count)
message = Message.read(stream, length=message_index.message_length)
decoded_message = self._message_decoder(message.data)
frames[frame_index] = self._decoder(decoded_message.data) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
array = self._decoder(decoded_message.data)

return torch.stack([torch.from_numpy(frames[idx]) for idx in indexes]) # pyright: ignore[reportUnknownMemberType]
return torch.from_numpy(array) # pyright: ignore[reportUnknownMemberType]

@override
def __len__(self) -> int:
Expand Down
33 changes: 18 additions & 15 deletions src/rbyte/io/_numpy/tensor_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import cached_property
from os import PathLike
from pathlib import Path
from typing import TYPE_CHECKING, override
from typing import final, override

import numpy as np
import torch
Expand All @@ -16,34 +16,37 @@
from rbyte.io.base import TensorSource
from rbyte.utils.tensor import pad_sequence

if TYPE_CHECKING:
from types import EllipsisType


@final
class NumpyTensorSource(TensorSource):
@validate_call(config=BaseModel.model_config)
def __init__(
self, path: PathLike[str], select: Sequence[str] | None = None
) -> None:
super().__init__()

self._path: Path = Path(path)
self._select: Sequence[str] | EllipsisType = select or ...
self._path = Path(path)
self._select = select or ...

@cached_property
def _path_posix(self) -> str:
return self._path.resolve().as_posix()

def _getitem(self, index: object) -> Tensor:
path = self._path_posix.format(index)
array = structured_to_unstructured(np.load(path)[self._select]) # pyright: ignore[reportUnknownVariableType]
return torch.from_numpy(np.ascontiguousarray(array)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]

@override
def __getitem__(self, indexes: Iterable[object]) -> Tensor:
tensors: list[Tensor] = []
for index in indexes:
path = self._path_posix.format(index)
array = structured_to_unstructured(np.load(path)[self._select]) # pyright: ignore[reportUnknownVariableType]
tensor = torch.from_numpy(np.ascontiguousarray(array)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
tensors.append(tensor)

return pad_sequence(list(tensors), dim=0, value=torch.nan)
def __getitem__(self, indexes: object | Iterable[object]) -> Tensor:
match indexes:
case Iterable():
tensors = map(self._getitem, indexes) # pyright: ignore[reportUnknownArgumentType]

return pad_sequence(list(tensors), dim=0, value=torch.nan)

case _:
return self._getitem(indexes)

@override
def __len__(self) -> int:
Expand Down
6 changes: 3 additions & 3 deletions src/rbyte/io/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from collections.abc import Iterable
from typing import Any, Protocol, runtime_checkable
from collections.abc import Sequence
from typing import Protocol, runtime_checkable

from torch import Tensor


@runtime_checkable
class TensorSource(Protocol):
def __getitem__(self, indexes: Iterable[Any]) -> Tensor: ...
def __getitem__[T](self, indexes: T | Sequence[T]) -> Tensor: ...
def __len__(self) -> int: ...
2 changes: 1 addition & 1 deletion src/rbyte/io/dataframe/aligner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class MergeConfig(BaseModel):
columns: OrderedDict[str, ColumnMergeConfig] = Field(default_factory=OrderedDict)


type Fields = MergeConfig | OrderedDict[str, "Fields"]
type Fields = MergeConfig | OrderedDict[str, Fields]


@final
Expand Down
2 changes: 1 addition & 1 deletion src/rbyte/io/hdf5/dataframe_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from pydantic import ConfigDict, validate_call

type Fields = Mapping[str, PolarsDataType | None] | Mapping[str, "Fields"]
type Fields = Mapping[str, PolarsDataType | None] | Mapping[str, Fields]


@final
Expand Down
11 changes: 5 additions & 6 deletions src/rbyte/io/hdf5/tensor_source.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
from collections.abc import Iterable
from typing import cast, override
from collections.abc import Sequence
from typing import cast, final, override

import torch
from h5py import Dataset, File
from jaxtyping import UInt8
from pydantic import FilePath, validate_call
from torch import Tensor

from rbyte.io.base import TensorSource


@final
class Hdf5TensorSource(TensorSource):
@validate_call
def __init__(self, path: FilePath, key: str) -> None:
file = File(path)
self._dataset: Dataset = cast(Dataset, file[key])
self._dataset = cast(Dataset, File(path)[key])

@override
def __getitem__(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]:
def __getitem__(self, indexes: int | Sequence[int]) -> Tensor:
return torch.from_numpy(self._dataset[indexes]) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]

@override
Expand Down
Loading

0 comments on commit 34c2f16

Please sign in to comment.