From 93db47c9c832201e2be7f45e16e735ee5db34286 Mon Sep 17 00:00:00 2001
From: Evgenii Gorchakov <evgenii@yaak.ai>
Date: Mon, 18 Nov 2024 14:33:04 +0100
Subject: [PATCH] feat: Zenseact Open Dataset (zod) support (#23)

- add zod (https://zod.zenseact.com/)
- refactor TableAligner
- flatten batch structure
- more generic table processing: mapping -> pytree
- group sources/readers by filetype
- remove `scripts/build_table.py` + configs
- test data fixes
---
 .gitattributes | 2 -
 .gitmodules | 4 +-
 .pre-commit-config.yaml | 9 +-
 config/_templates/build_table.yaml | 10 -
 config/_templates/dataset/carla.yaml | 62 ++---
 config/_templates/dataset/mimicgen.yaml | 44 ++--
 config/_templates/dataset/nuscenes/mcap.yaml | 102 ++++++++
 config/_templates/dataset/nuscenes/rrd.yaml | 93 +++++++
 config/_templates/dataset/nuscenes_mcap.yaml | 104 --------
 config/_templates/dataset/nuscenes_rrd.yaml | 97 -------
 config/_templates/dataset/yaak.yaml | 194 +++++++-------
 config/_templates/dataset/zod.yaml | 114 +++++++++
 config/_templates/frame_reader/directory.yaml | 10 -
 config/_templates/frame_reader/hdf5.yaml | 5 -
 config/_templates/frame_reader/mcap.yaml | 14 --
 .../_templates/frame_reader/video/ffmpeg.yaml | 6 -
 .../_templates/frame_reader/video/vali.yaml | 5 -
 config/_templates/logger/rerun/carla.yaml | 24 +-
 config/_templates/logger/rerun/mimicgen.yaml | 12 +-
 .../logger/rerun/nuscenes/mcap.yaml | 39 +++
 .../{nuscenes_rrd.yaml => nuscenes/rrd.yaml} | 34 ++-
 .../logger/rerun/nuscenes_mcap.yaml | 40 ---
 config/_templates/logger/rerun/yaak.yaml | 29 ++-
 config/_templates/logger/rerun/zod.yaml | 16 ++
 config/_templates/read_frames.yaml | 17 --
 config/_templates/table_builder/carla.yaml | 28 ---
 config/_templates/table_builder/hdf5.yaml | 20 --
 config/_templates/table_builder/mcap.yaml | 64 -----
 config/_templates/table_builder/rrd.yaml | 53 ----
 config/_templates/table_builder/yaak.yaml | 102 --------
 config/_templates/table_writer/console.yaml | 4 -
 config/_templates/table_writer/csv.yaml | 4 -
 config/_templates/table_writer/parquet.yaml | 4 -
 examples/.gitattributes | 1 +
 examples/nuscenes_mcap.ipynb | 4 +-
 examples/nuscenes_rrd.ipynb | 4 +-
 hatch_build.py | 12 +-
 justfile | 19 +-
 pyproject.toml | 21 +-
 src/rbyte/batch/batch.py | 3 +-
 src/rbyte/config/base.py | 7 +-
 src/rbyte/dataset.py | 165 +++++-------
 src/rbyte/io/__init__.py | 50 ++++
 src/rbyte/io/_json/__init__.py | 3 +
 .../json/reader.py => _json/table_reader.py} | 16 +-
 src/rbyte/io/_mcap/__init__.py | 4 +
 .../mcap/reader.py => _mcap/table_reader.py} | 17 +-
 .../mcap/reader.py => _mcap/tensor_source.py} | 39 +--
 src/rbyte/io/_numpy/__init__.py | 3 +
 src/rbyte/io/_numpy/tensor_source.py | 50 ++++
 src/rbyte/io/base.py | 10 +
 src/rbyte/io/frame/__init__.py | 32 ---
 src/rbyte/io/frame/base.py | 13 -
 src/rbyte/io/frame/directory/__init__.py | 3 -
 src/rbyte/io/frame/directory/reader.py | 47 ----
 src/rbyte/io/frame/hdf5/__init__.py | 3 -
 src/rbyte/io/frame/hdf5/reader.py | 25 --
 src/rbyte/io/frame/mcap/__init__.py | 3 -
 src/rbyte/io/frame/rrd/__init__.py | 3 -
 src/rbyte/io/frame/rrd/reader.py | 78 ------
 src/rbyte/io/frame/video/ffmpeg_reader.py | 46 ----
 src/rbyte/io/hdf5/__init__.py | 4 +
 src/rbyte/io/hdf5/table_reader.py | 97 +++++++
 src/rbyte/io/hdf5/tensor_source.py | 25 ++
 src/rbyte/io/path/__init__.py | 4 +
 src/rbyte/io/path/table_reader.py | 75 ++++++
 src/rbyte/io/path/tensor_source.py | 44 ++++
 src/rbyte/io/rrd/__init__.py | 4 +
 src/rbyte/io/rrd/frame_source.py | 55 ++++
 .../rrd/reader.py => rrd/table_reader.py} | 14 +-
 src/rbyte/io/table/__init__.py | 32 +--
 src/rbyte/io/table/aligner.py | 212 ++++++++--------
 src/rbyte/io/table/base.py | 25 +-
 src/rbyte/io/table/builder.py | 68 +++--
 src/rbyte/io/table/concater.py | 28 ++-
 src/rbyte/io/table/hdf5/__init__.py | 3 -
 src/rbyte/io/table/hdf5/reader.py | 92 -------
 src/rbyte/io/table/json/__init__.py | 3 -
 src/rbyte/io/table/mcap/__init__.py | 3 -
 src/rbyte/io/table/rrd/__init__.py | 3 -
 src/rbyte/io/table/transforms/base.py | 4 +-
 .../io/table/transforms/fps_resampler.py | 10 +-
 src/rbyte/io/table/yaak/__init__.py | 3 -
 src/rbyte/io/{frame => }/video/__init__.py | 0
 src/rbyte/io/video/ffmpeg_source.py | 41 +++
 .../vali_reader.py => video/vali_source.py} | 35 ++-
 src/rbyte/io/yaak/__init__.py | 3 +
 src/rbyte/io/{table => }/yaak/idl-repo | 0
 .../io/{table => }/yaak/message_iterator.py | 10 +-
 .../io/{table => }/yaak/proto/__init__.py | 0
 .../yaak/reader.py => yaak/table_reader.py} | 18 +-
 src/rbyte/sample/__init__.py | 3 +
 src/rbyte/sample/base.py | 2 +-
 .../sample/{builder.py => greedy_builder.py} | 23 +-
 src/rbyte/scripts/build_table.py | 24 --
 src/rbyte/scripts/read_frames.py | 116 ---------
 src/rbyte/utils/dataframe/cache.py | 6 +-
 src/rbyte/utils/functional.py | 41 +++
 .../mcap/decoders/protobuf_decoder_factory.py | 2 +-
 src/rbyte/viz/loggers/rerun_logger.py | 237 +++++++++---------
 tests/data/mimicgen/.gitattributes | 1 +
 tests/data/mimicgen/README.md | 1 +
 tests/data/{ => mimicgen}/coffee.hdf5 | 0
 tests/data/nuscenes/README.md | 1 +
 tests/data/nuscenes/mcap/.gitattributes | 1 +
 tests/data/nuscenes/mcap/README.md | 1 +
 .../nuScenes-v1.0-mini-scene-0061-cut.mcap | 0
 tests/data/nuscenes/rrd/.gitattributes | 1 +
 tests/data/nuscenes/rrd/README.md | 1 +
 tests/data/nuscenes/rrd/nuscenes_dataset.rrd | 3 +
 tests/data/yaak/.gitattributes | 1 +
 .../Niro098-HQ/2024-06-18--13-39-54/ai.mcap | 0
 .../cam_front_left.pii.mp4 | 0
 .../cam_left_backward.pii.mp4 | 0
 .../cam_right_backward.pii.mp4 | 0
 .../2024-06-18--13-39-54/metadata.log | 0
 tests/data/yaak/README.md | 1 +
 tests/data/zod/.gitattributes | 1 +
 tests/data/zod/README.md | 3 +
 ...0002_romeo_2022-06-13T10:50:07.092270Z.jpg | 3 +
 ...0002_romeo_2022-06-13T10:50:07.191297Z.jpg | 3 +
 ...0002_romeo_2022-06-13T10:50:07.290323Z.jpg | 3 +
 ...0002_romeo_2022-06-13T10:50:07.389350Z.jpg | 3 +
 ...0002_romeo_2022-06-13T10:50:07.488377Z.jpg | 3 +
 ...0002_romeo_2022-06-13T10:50:07.587404Z.jpg | 3 +
 ...0002_romeo_2022-06-13T10:50:07.686430Z.jpg | 3 +
 ...0002_romeo_2022-06-13T10:50:07.785457Z.jpg | 3 +
 ...0002_romeo_2022-06-13T10:50:07.884484Z.jpg | 3 +
 ...0002_romeo_2022-06-13T10:50:07.983510Z.jpg | 3 +
 ...0002_romeo_2022-06-13T10:50:07.000063Z.npy | 3 +
 ...0002_romeo_2022-06-13T10:50:07.111227Z.npy | 3 +
 ...0002_romeo_2022-06-13T10:50:07.222375Z.npy | 3 +
 ...0002_romeo_2022-06-13T10:50:07.333497Z.npy | 3 +
 ...0002_romeo_2022-06-13T10:50:07.444642Z.npy | 3 +
 ...0002_romeo_2022-06-13T10:50:07.555767Z.npy | 3 +
 ...0002_romeo_2022-06-13T10:50:07.666867Z.npy | 3 +
 ...0002_romeo_2022-06-13T10:50:07.777966Z.npy | 3 +
 ...0002_romeo_2022-06-13T10:50:07.889058Z.npy | 3 +
 .../sequences/000002_short/vehicle_data.hdf5 | 3 +
 tests/test_dataloader.py | 191 +++++++++-----
 140 files changed, 1741 insertions(+), 1941 deletions(-)
 delete mode 100644 .gitattributes
 delete mode 100644 config/_templates/build_table.yaml
 create mode 100644 config/_templates/dataset/nuscenes/mcap.yaml
 create mode 100644 config/_templates/dataset/nuscenes/rrd.yaml
 delete mode 100644 config/_templates/dataset/nuscenes_mcap.yaml
 delete mode 100644 config/_templates/dataset/nuscenes_rrd.yaml
 create mode 100644 config/_templates/dataset/zod.yaml
 delete mode 100644 config/_templates/frame_reader/directory.yaml
 delete mode 100644 config/_templates/frame_reader/hdf5.yaml
 delete mode 100644 config/_templates/frame_reader/mcap.yaml
 delete mode 100644 config/_templates/frame_reader/video/ffmpeg.yaml
 delete mode 100644 config/_templates/frame_reader/video/vali.yaml
 create mode 100644 config/_templates/logger/rerun/nuscenes/mcap.yaml
 rename config/_templates/logger/rerun/{nuscenes_rrd.yaml => nuscenes/rrd.yaml} (56%)
 delete mode 100644 config/_templates/logger/rerun/nuscenes_mcap.yaml
 create mode 100644 config/_templates/logger/rerun/zod.yaml
 delete mode 100644 config/_templates/read_frames.yaml
 delete mode 100644 config/_templates/table_builder/carla.yaml
 delete mode 100644 config/_templates/table_builder/hdf5.yaml
 delete mode 100644 config/_templates/table_builder/mcap.yaml
 delete mode 100644 config/_templates/table_builder/rrd.yaml
 delete mode 100644 config/_templates/table_builder/yaak.yaml
 delete mode 100644 config/_templates/table_writer/console.yaml
 delete mode 100644 config/_templates/table_writer/csv.yaml
 delete mode 100644 config/_templates/table_writer/parquet.yaml
 create mode 100644 examples/.gitattributes
 create mode 100644 src/rbyte/io/_json/__init__.py
 rename src/rbyte/io/{table/json/reader.py => _json/table_reader.py} (88%)
 create mode 100644 src/rbyte/io/_mcap/__init__.py
 rename src/rbyte/io/{table/mcap/reader.py => _mcap/table_reader.py} (93%)
 rename src/rbyte/io/{frame/mcap/reader.py => _mcap/tensor_source.py} (82%)
 create mode 100644 src/rbyte/io/_numpy/__init__.py
 create mode 100644 src/rbyte/io/_numpy/tensor_source.py
 create mode 100644 src/rbyte/io/base.py
 delete mode 100644 src/rbyte/io/frame/__init__.py
 delete mode 100644 src/rbyte/io/frame/base.py
 delete mode 100644 src/rbyte/io/frame/directory/__init__.py
 delete mode 100644 src/rbyte/io/frame/directory/reader.py
 delete mode 100644 src/rbyte/io/frame/hdf5/__init__.py
 delete mode 100644 src/rbyte/io/frame/hdf5/reader.py
 delete mode 100644 src/rbyte/io/frame/mcap/__init__.py
 delete mode 100644 src/rbyte/io/frame/rrd/__init__.py
 delete mode 100644 src/rbyte/io/frame/rrd/reader.py
 delete mode 100644 src/rbyte/io/frame/video/ffmpeg_reader.py
 create mode 100644 src/rbyte/io/hdf5/__init__.py
 create mode 100644 src/rbyte/io/hdf5/table_reader.py
 create mode 100644 src/rbyte/io/hdf5/tensor_source.py
 create mode 100644 src/rbyte/io/path/__init__.py
 create mode 100644 src/rbyte/io/path/table_reader.py
 create mode 100644 src/rbyte/io/path/tensor_source.py
 create mode 100644 src/rbyte/io/rrd/__init__.py
 create mode 100644 src/rbyte/io/rrd/frame_source.py
 rename src/rbyte/io/{table/rrd/reader.py => rrd/table_reader.py} (90%)
 delete mode 100644 src/rbyte/io/table/hdf5/__init__.py
 delete mode 100644 src/rbyte/io/table/hdf5/reader.py
 delete mode 100644 src/rbyte/io/table/json/__init__.py
 delete mode 100644 src/rbyte/io/table/mcap/__init__.py
 delete mode 100644 src/rbyte/io/table/rrd/__init__.py
 delete mode 100644 src/rbyte/io/table/yaak/__init__.py
 rename src/rbyte/io/{frame => }/video/__init__.py (100%)
 create mode 100644 src/rbyte/io/video/ffmpeg_source.py
 rename src/rbyte/io/{frame/video/vali_reader.py => video/vali_source.py} (80%)
 create mode 100644 src/rbyte/io/yaak/__init__.py
 rename src/rbyte/io/{table => }/yaak/idl-repo (100%)
 rename src/rbyte/io/{table => }/yaak/message_iterator.py (92%)
 rename src/rbyte/io/{table => }/yaak/proto/__init__.py (100%)
 rename src/rbyte/io/{table/yaak/reader.py => yaak/table_reader.py} (87%)
 rename src/rbyte/sample/{builder.py => greedy_builder.py} (75%)
 delete mode 100644 src/rbyte/scripts/build_table.py
 delete mode 100644 src/rbyte/scripts/read_frames.py
 create mode 100644 src/rbyte/utils/functional.py
 create mode 100644 tests/data/mimicgen/.gitattributes
 create mode 100644 tests/data/mimicgen/README.md
 rename tests/data/{ => mimicgen}/coffee.hdf5 (100%)
 create mode 100644 tests/data/nuscenes/README.md
 create mode 100644 tests/data/nuscenes/mcap/.gitattributes
 create mode 100644 tests/data/nuscenes/mcap/README.md
 rename tests/data/{ => nuscenes/mcap}/nuScenes-v1.0-mini-scene-0061-cut.mcap (100%)
 create mode 100644 tests/data/nuscenes/rrd/.gitattributes
 create mode 100644 tests/data/nuscenes/rrd/README.md
 create mode 100644 tests/data/nuscenes/rrd/nuscenes_dataset.rrd
 create mode 100644 tests/data/yaak/.gitattributes
 rename tests/data/{ => yaak}/Niro098-HQ/2024-06-18--13-39-54/ai.mcap (100%)
 rename tests/data/{ => yaak}/Niro098-HQ/2024-06-18--13-39-54/cam_front_left.pii.mp4 (100%)
 rename tests/data/{ => yaak}/Niro098-HQ/2024-06-18--13-39-54/cam_left_backward.pii.mp4 (100%)
 rename tests/data/{ => yaak}/Niro098-HQ/2024-06-18--13-39-54/cam_right_backward.pii.mp4 (100%)
 rename tests/data/{ => yaak}/Niro098-HQ/2024-06-18--13-39-54/metadata.log (100%)
 create mode 100644 tests/data/yaak/README.md
 create mode 100644 tests/data/zod/.gitattributes
 create mode 100644 tests/data/zod/README.md
 create mode 100644 tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.092270Z.jpg
 create mode 100644 tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.191297Z.jpg
 create mode 100644 tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.290323Z.jpg
 create mode 100644 tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.389350Z.jpg
 create mode 100644 tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.488377Z.jpg
 create mode 100644 tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.587404Z.jpg
 create mode 100644 tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.686430Z.jpg
 create mode 100644 tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.785457Z.jpg
 create mode 100644 tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.884484Z.jpg
 create mode 100644 tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.983510Z.jpg
 create mode 100644 tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.000063Z.npy
 create mode 100644 tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.111227Z.npy
 create mode 100644 tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.222375Z.npy
 create mode 100644 tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.333497Z.npy
 create mode 100644 tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.444642Z.npy
 create mode 100644 tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.555767Z.npy
 create mode 100644 tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.666867Z.npy
 create mode 100644 tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.777966Z.npy
 create
mode 100644 tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.889058Z.npy create mode 100644 tests/data/zod/sequences/000002_short/vehicle_data.hdf5 diff --git a/.gitattributes b/.gitattributes deleted file mode 100644 index ea45e55..0000000 --- a/.gitattributes +++ /dev/null @@ -1,2 +0,0 @@ -tests/data/** filter=lfs diff=lfs merge=lfs -text -*.ipynb filter=jupyter-nbconvert-clear-output diff --git a/.gitmodules b/.gitmodules index fd947eb..77e8506 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ -[submodule "src/rbyte/io/table/yaak/idl-repo"] - path = src/rbyte/io/table/yaak/idl-repo +[submodule "src/rbyte/io/yaak/idl-repo"] + path = src/rbyte/io/yaak/idl-repo url = git@github.com:yaak-ai/idl-repo diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 863e329..73cd447 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,25 +5,20 @@ repos: hooks: - id: validate-pyproject - - repo: https://github.com/crate-ci/typos - rev: v1.26.8 - hooks: - - id: typos - - repo: https://github.com/asottile/pyupgrade rev: v3.19.0 hooks: - id: pyupgrade - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.1 + rev: v0.7.2 hooks: - id: ruff args: [--fix] - id: ruff-format - repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror - rev: 1.19.1 + rev: 1.21.0 hooks: - id: basedpyright diff --git a/config/_templates/build_table.yaml b/config/_templates/build_table.yaml deleted file mode 100644 index 755726d..0000000 --- a/config/_templates/build_table.yaml +++ /dev/null @@ -1,10 +0,0 @@ ---- -defaults: - - table_builder: !!null - - table_writer: console - - _self_ - -hydra: - output_subdir: !!null - run: - dir: . diff --git a/config/_templates/dataset/carla.yaml b/config/_templates/dataset/carla.yaml index e565b67..afd107d 100644 --- a/config/_templates/dataset/carla.yaml +++ b/config/_templates/dataset/carla.yaml @@ -59,14 +59,14 @@ _convert_: all inputs: #@ for input_id in drives: (@=input_id@): - frame: + sources: #@ for source_id in cameras: (@=source_id@): index_column: _idx_ - reader: - _target_: rbyte.io.frame.DirectoryFrameReader + source: + _target_: rbyte.io.PathTensorSource path: "${data_dir}/(@=input_id@)/frames/(@=source_id@).defish.mp4/576x324/{:09d}.jpg" - frame_decoder: + decoder: _target_: simplejpeg.decode_jpeg _partial_: true colorspace: rgb @@ -75,40 +75,40 @@ inputs: #@ end - table: - builder: - _target_: rbyte.io.table.TableBuilder - _convert_: all - readers: - - path: ${data_dir}/(@=input_id@)/ego_logs.json - reader: - _target_: rbyte.io.table.JsonTableReader - _recursive_: false - fields: - records: - _idx_: - control.brake: - control.throttle: - control.steer: - state.velocity.value: - state.acceleration.value: + table_builder: + _target_: rbyte.io.table.TableBuilder + _convert_: all + readers: + ego_logs: + path: ${data_dir}/(@=input_id@)/ego_logs.json + reader: + _target_: rbyte.io.JsonTableReader + _recursive_: false + fields: + records: + _idx_: + control.brake: + control.throttle: + control.steer: + state.velocity.value: + state.acceleration.value: - transforms: - - _target_: rbyte.io.table.transforms.FpsResampler - source_fps: 20 - target_fps: 30 + transforms: + - _target_: rbyte.io.FpsResampler + source_fps: 20 + target_fps: 30 - merger: - _target_: rbyte.io.table.TableConcater - method: vertical + merger: + _target_: rbyte.io.TableConcater + method: vertical - filter: | - `control.throttle` > 0.5 + filter: | + `control.throttle` > 0.5 #@ end sample_builder: - _target_: 
rbyte.sample.builder.GreedySampleTableBuilder + _target_: rbyte.sample.GreedySampleBuilder index_column: _idx_ length: 1 stride: 1 diff --git a/config/_templates/dataset/mimicgen.yaml b/config/_templates/dataset/mimicgen.yaml index 680ad25..10cc0b4 100644 --- a/config/_templates/dataset/mimicgen.yaml +++ b/config/_templates/dataset/mimicgen.yaml @@ -1,5 +1,3 @@ -#! https://huggingface.co/datasets/amandlek/mimicgen_datasets/blob/main/source/coffee.hdf5 - #@yaml/text-templated-strings #@ inputs = { @@ -10,7 +8,7 @@ #@ } #@ frame_keys = [ -#@ 'obs/agentview_image', +#@ "obs/agentview_image", #@ ] --- _target_: rbyte.Dataset @@ -20,38 +18,38 @@ inputs: #@ for input_id, input_keys in inputs.items(): #@ for input_key in input_keys: (@=input_id@)(@=input_key@): - frame: + sources: #@ for frame_key in frame_keys: (@=frame_key@): index_column: _idx_ - reader: - _target_: rbyte.io.frame.Hdf5FrameReader + source: + _target_: rbyte.io.Hdf5TensorSource path: "${data_dir}/(@=input_id@).hdf5" key: (@=input_key@)/(@=frame_key@) #@ end - table: - builder: - _target_: rbyte.io.table.TableBuilder - _convert_: all - readers: - - path: "${data_dir}/(@=input_id@).hdf5" - reader: - _target_: rbyte.io.table.Hdf5TableReader - _recursive_: false - fields: - (@=input_key@): - _idx_: - obs/robot0_eef_pos: + table_builder: + _target_: rbyte.io.TableBuilder + _convert_: all + readers: + hdf5: + path: "${data_dir}/(@=input_id@).hdf5" + reader: + _target_: rbyte.io.Hdf5TableReader + _recursive_: false + fields: + (@=input_key@): + _idx_: + obs/robot0_eef_pos: - merger: - _target_: rbyte.io.table.TableConcater - method: vertical + merger: + _target_: rbyte.io.TableConcater + method: vertical #@ end #@ end sample_builder: - _target_: rbyte.sample.builder.GreedySampleTableBuilder + _target_: rbyte.sample.GreedySampleBuilder index_column: _idx_ length: 1 stride: 1 diff --git a/config/_templates/dataset/nuscenes/mcap.yaml b/config/_templates/dataset/nuscenes/mcap.yaml new file mode 100644 index 0000000..4c15504 --- /dev/null +++ b/config/_templates/dataset/nuscenes/mcap.yaml @@ -0,0 +1,102 @@ +#@yaml/text-templated-strings + +#@ inputs = [ +#@ "nuScenes-v1.0-mini-scene-0061-cut", +#@ ] + +#@ camera_topics = { +#@ "CAM_FRONT": "/CAM_FRONT/image_rect_compressed", +#@ "CAM_FRONT_LEFT": "/CAM_FRONT_LEFT/image_rect_compressed", +#@ "CAM_FRONT_RIGHT": "/CAM_FRONT_RIGHT/image_rect_compressed", +#@ } +--- +_target_: rbyte.Dataset +_convert_: all +_recursive_: false +inputs: + #@ for input_id in inputs: + (@=input_id@): + sources: + #@ for camera, topic in camera_topics.items(): + (@=camera@): + index_column: mcap/(@=topic@)/_idx_ + source: + _target_: rbyte.io.McapTensorSource + path: "${data_dir}/(@=input_id@).mcap" + topic: (@=topic@) + decoder_factory: mcap_protobuf.decoder.DecoderFactory + decoder: + _target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true + #@ end + + table_builder: + _target_: rbyte.io.TableBuilder + _convert_: all + readers: + mcap: + path: "${data_dir}/(@=input_id@).mcap" + reader: + _target_: rbyte.io.McapTableReader + _recursive_: false + decoder_factories: + - rbyte.utils.mcap.ProtobufDecoderFactory + - rbyte.utils.mcap.JsonDecoderFactory + + fields: + #@ for topic in camera_topics.values(): + (@=topic@): + log_time: + _target_: polars.Datetime + time_unit: ns + + _idx_: + #@ end + + /odom: + log_time: + _target_: polars.Datetime + time_unit: ns + vel.x: + + merger: + _target_: rbyte.io.TableAligner + separator: "/" + merge: + mcap: + #@ topic = 
camera_topics.values()[0] + (@=topic@): + key: log_time + + #@ for topic in camera_topics.values()[1:]: + (@=topic@): + key: log_time + columns: + _idx_: + method: asof + tolerance: 40ms + strategy: nearest + #@ end + + /odom: + key: log_time + columns: + vel.x: + method: interp + + filter: | + `mcap//odom/vel.x` >= 8 + + cache: + #@ end + +sample_builder: + _target_: rbyte.sample.GreedySampleBuilder + index_column: mcap/(@=camera_topics.values()[0]@)/_idx_ + length: 1 + stride: 1 + min_step: 1 + filter: !!null diff --git a/config/_templates/dataset/nuscenes/rrd.yaml b/config/_templates/dataset/nuscenes/rrd.yaml new file mode 100644 index 0000000..50a5dd5 --- /dev/null +++ b/config/_templates/dataset/nuscenes/rrd.yaml @@ -0,0 +1,93 @@ +#@yaml/text-templated-strings + +#@ inputs = [ +#@ "nuscenes_dataset", +#@ ] + +#@ camera_entities = { +#@ "CAM_FRONT": "/world/ego_vehicle/CAM_FRONT", +#@ "CAM_FRONT_LEFT": "/world/ego_vehicle/CAM_FRONT_LEFT", +#@ "CAM_FRONT_RIGHT": "/world/ego_vehicle/CAM_FRONT_RIGHT", +#@ } +--- +_target_: rbyte.Dataset +_convert_: all +_recursive_: false +inputs: + #@ for input_id in inputs: + (@=input_id@): + sources: + #@ for camera, entity in camera_entities.items(): + (@=camera@): + index_column: rrd/(@=entity@)/_idx_ + source: + _target_: rbyte.io.RrdFrameSource + path: "${data_dir}/(@=input_id@).rrd" + index: timestamp + entity_path: (@=entity@) + decoder: + _target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true + #@ end + + table_builder: + _target_: rbyte.io.TableBuilder + _convert_: all + readers: + rrd: + path: "${data_dir}/(@=input_id@).rrd" + reader: + _target_: rbyte.io.RrdTableReader + _recursive_: false + index: timestamp + contents: + #@ for entity in camera_entities.values(): + (@=entity@): + - _idx_ + #@ end + + /world/ego_vehicle/LIDAR_TOP: + - Position3D + + merger: + _target_: rbyte.io.TableAligner + separator: / + merge: + rrd: + #@ entity = camera_entities.values()[0] + (@=entity@): + key: timestamp + + #@ for entity in camera_entities.values()[1:]: + (@=entity@): + key: timestamp + columns: + _idx_: + method: asof + strategy: nearest + tolerance: 40ms + #@ end + + /world/ego_vehicle/LIDAR_TOP: + key: timestamp + columns: + Position3D: + method: asof + strategy: nearest + tolerance: 40ms + + filter: | + `rrd//world/ego_vehicle/CAM_FRONT/timestamp` between '2018-07-24 03:28:48' and '2018-07-24 03:28:50' + #@ end + +sample_builder: + _target_: rbyte.sample.GreedySampleBuilder + index_column: rrd/(@=camera_entities.values()[0]@)/_idx_ + length: 1 + stride: 1 + min_step: 1 + filter: !!null + diff --git a/config/_templates/dataset/nuscenes_mcap.yaml b/config/_templates/dataset/nuscenes_mcap.yaml deleted file mode 100644 index ff2dff5..0000000 --- a/config/_templates/dataset/nuscenes_mcap.yaml +++ /dev/null @@ -1,104 +0,0 @@ -#! 
https://github.com/foxglove/nuscenes2mcap - -#@yaml/text-templated-strings - -#@ inputs = [ -#@ 'nuScenes-v1.0-mini-scene-0061-cut', -#@ ] - -#@ camera_topics = { -#@ 'CAM_FRONT': '/CAM_FRONT/image_rect_compressed', -#@ 'CAM_FRONT_LEFT': '/CAM_FRONT_LEFT/image_rect_compressed', -#@ 'CAM_FRONT_RIGHT': '/CAM_FRONT_RIGHT/image_rect_compressed', -#@ } ---- -_target_: rbyte.Dataset -_convert_: all -_recursive_: false -inputs: - #@ for input_id in inputs: - (@=input_id@): - frame: - #@ for camera, topic in camera_topics.items(): - (@=camera@): - index_column: (@=topic@)/_idx_ - reader: - _target_: rbyte.io.frame.McapFrameReader - path: "${data_dir}/(@=input_id@).mcap" - topic: (@=topic@) - decoder_factory: mcap_protobuf.decoder.DecoderFactory - frame_decoder: - _target_: simplejpeg.decode_jpeg - _partial_: true - colorspace: rgb - fastdct: true - fastupsample: true - #@ end - - table: - builder: - _target_: rbyte.io.table.TableBuilder - _convert_: all - readers: - - path: "${data_dir}/(@=input_id@).mcap" - reader: - _target_: rbyte.io.table.McapTableReader - _recursive_: false - decoder_factories: - - rbyte.utils.mcap.ProtobufDecoderFactory - - rbyte.utils.mcap.JsonDecoderFactory - - fields: - #@ for topic in camera_topics.values(): - (@=topic@): - log_time: - _target_: polars.Datetime - time_unit: ns - - _idx_: - #@ end - - /odom: - log_time: - _target_: polars.Datetime - time_unit: ns - vel.x: - - merger: - _target_: rbyte.io.table.TableAligner - separator: "/" - merge: - (@=camera_topics.values()[0]@): - log_time: - method: ref - - #@ for topic in camera_topics.values()[1:]: - (@=topic@): - log_time: - method: ref - - _idx_: - method: asof - tolerance: 40ms - strategy: nearest - #@ end - - /odom: - log_time: - method: ref - vel.x: - method: interp - - filter: | - `/odom/vel.x` >= 8 - - cache: - #@ end - -sample_builder: - _target_: rbyte.sample.builder.GreedySampleTableBuilder - index_column: (@=camera_topics.values()[0]@)/_idx_ - length: 1 - stride: 1 - min_step: 1 - filter: !!null diff --git a/config/_templates/dataset/nuscenes_rrd.yaml b/config/_templates/dataset/nuscenes_rrd.yaml deleted file mode 100644 index b99f263..0000000 --- a/config/_templates/dataset/nuscenes_rrd.yaml +++ /dev/null @@ -1,97 +0,0 @@ -#! 
https://app.rerun.io/examples/nuscenes_dataset.rrd - -#@yaml/text-templated-strings - -#@ inputs = [ -#@ 'nuscenes', -#@ ] - -#@ camera_entities = { -#@ 'CAM_FRONT_LEFT': '/world/ego_vehicle/CAM_FRONT_LEFT', -#@ 'CAM_FRONT': '/world/ego_vehicle/CAM_FRONT', -#@ 'CAM_FRONT_RIGHT': '/world/ego_vehicle/CAM_FRONT_RIGHT', -#@ } ---- -_target_: rbyte.Dataset -_convert_: all -_recursive_: false -inputs: - #@ for input_id in inputs: - (@=input_id@): - frame: - #@ for camera, entity in camera_entities.items(): - (@=camera@): - index_column: (@=entity@)/_idx_ - reader: - _target_: rbyte.io.frame.RrdFrameReader - path: "${data_dir}/(@=input_id@).rrd" - index: timestamp - entity_path: (@=entity@) - frame_decoders: - image/jpeg: - _target_: simplejpeg.decode_jpeg - _partial_: true - colorspace: rgb - fastdct: true - fastupsample: true - #@ end - - table: - builder: - _target_: rbyte.io.table.TableBuilder - _convert_: all - readers: - - path: "${data_dir}/(@=input_id@).rrd" - reader: - _target_: rbyte.io.table.RrdTableReader - _recursive_: false - index: timestamp - contents: - #@ for entity in camera_entities.values(): - (@=entity@): - - _idx_ - #@ end - - /world/ego_vehicle/LIDAR_TOP: - - Position3D - - merger: - _target_: rbyte.io.table.TableAligner - separator: / - merge: - #@ entity = camera_entities.values()[0] - (@=entity@): - timestamp: - method: ref - - #@ for entity in camera_entities.values()[1:]: - (@=entity@): - timestamp: - method: ref - - _idx_: - method: asof - strategy: nearest - tolerance: 20ms - #@ end - - /world/ego_vehicle/LIDAR_TOP: - timestamp: - method: ref - - Position3D: - method: asof - strategy: nearest - tolerance: 20ms - - filter: | - `/world/ego_vehicle/CAM_FRONT/timestamp` between '2018-07-24 03:28:48' and '2018-07-24 03:28:50' - #@ end - -sample_builder: - _target_: rbyte.sample.builder.GreedySampleTableBuilder - index_column: (@=camera_entities.values()[0]@)/_idx_ - length: 1 - stride: 1 - min_step: 1 - filter: !!null diff --git a/config/_templates/dataset/yaak.yaml b/config/_templates/dataset/yaak.yaml index 02b6671..ce1acbe 100644 --- a/config/_templates/dataset/yaak.yaml +++ b/config/_templates/dataset/yaak.yaml @@ -16,120 +16,124 @@ _convert_: all inputs: #@ for input_id in drives: (@=input_id@): - frame: + sources: #@ for source_id in cameras: (@=source_id@): - index_column: "ImageMetadata.(@=source_id@).frame_idx" - reader: - _target_: rbyte.io.frame.FfmpegFrameReader + index_column: "meta/ImageMetadata.(@=source_id@)/frame_idx" + source: + _target_: rbyte.io.FfmpegFrameSource _recursive_: true path: "${data_dir}/(@=input_id@)/(@=source_id@).pii.mp4" resize_shorter_side: 324 #@ end - table: - builder: - _target_: rbyte.io.table.TableBuilder - _convert_: all - readers: - - path: ${data_dir}/(@=input_id@)/metadata.log - reader: - _target_: rbyte.io.table.yaak.YaakMetadataTableReader - _recursive_: false - fields: - rbyte.io.table.yaak.proto.sensor_pb2.ImageMetadata: - time_stamp: - _target_: polars.Datetime - time_unit: ns - - frame_idx: polars.UInt32 - camera_name: - _target_: polars.Enum - categories: - - cam_front_center - - cam_front_left - - cam_front_right - - cam_left_forward - - cam_right_forward - - cam_left_backward - - cam_right_backward - - cam_rear - - rbyte.io.table.yaak.proto.can_pb2.VehicleMotion: - time_stamp: - _target_: polars.Datetime - time_unit: ns - - speed: polars.Float32 - gear: - _target_: polars.Enum - categories: ["0", "1", "2", "3"] - - - path: ${data_dir}/(@=input_id@)/ai.mcap - reader: - _target_: rbyte.io.table.McapTableReader - 
_recursive_: false - decoder_factories: [rbyte.utils.mcap.ProtobufDecoderFactory] - fields: - /ai/safety_score: - clip.end_timestamp: - _target_: polars.Datetime - time_unit: ns - - score: polars.Float32 - - merger: - _target_: rbyte.io.table.TableAligner - separator: "." - merge: + table_builder: + _target_: rbyte.io.TableBuilder + _convert_: all + readers: + meta: + path: ${data_dir}/(@=input_id@)/metadata.log + reader: + _target_: rbyte.io.YaakMetadataTableReader + _recursive_: false + fields: + rbyte.io.yaak.proto.sensor_pb2.ImageMetadata: + time_stamp: + _target_: polars.Datetime + time_unit: ns + + frame_idx: polars.UInt32 + camera_name: + _target_: polars.Enum + categories: + - cam_front_center + - cam_front_left + - cam_front_right + - cam_left_forward + - cam_right_forward + - cam_left_backward + - cam_right_backward + - cam_rear + + rbyte.io.yaak.proto.can_pb2.VehicleMotion: + time_stamp: + _target_: polars.Datetime + time_unit: ns + + speed: polars.Float32 + gear: + _target_: polars.Enum + categories: ["0", "1", "2", "3"] + + mcap: + path: ${data_dir}/(@=input_id@)/ai.mcap + reader: + _target_: rbyte.io.McapTableReader + _recursive_: false + decoder_factories: [rbyte.utils.mcap.ProtobufDecoderFactory] + fields: + /ai/safety_score: + clip.end_timestamp: + _target_: polars.Datetime + time_unit: ns + + score: polars.Float32 + + merger: + _target_: rbyte.io.TableAligner + separator: "/" + merge: + meta: ImageMetadata.(@=cameras[0]@): - time_stamp: - method: ref + key: time_stamp #@ for camera in cameras[1:]: ImageMetadata.(@=camera@): - time_stamp: - method: ref - - frame_idx: - method: asof - tolerance: 20ms - strategy: nearest + key: time_stamp + columns: + frame_idx: + method: asof + tolerance: 20ms + strategy: nearest #@ end VehicleMotion: - time_stamp: - method: ref - speed: - method: interp - gear: - method: asof - tolerance: 100ms - strategy: nearest - + key: time_stamp + columns: + speed: + method: interp + gear: + method: asof + tolerance: 100ms + strategy: nearest + + mcap: /ai/safety_score: - clip.end_timestamp: - method: ref - - score: - method: asof - tolerance: 500ms - strategy: nearest - - filter: | - `VehicleMotion.gear` == '3' - - cache: - _target_: rbyte.utils.dataframe.DataframeDiskCache - directory: /tmp/rbyte-cache - size_limit: 1GiB + key: clip.end_timestamp + columns: + clip.end_timestamp: + method: asof + tolerance: 500ms + strategy: nearest + score: + method: asof + tolerance: 500ms + strategy: nearest + + filter: | + `meta/VehicleMotion/gear` == '3' + + cache: + _target_: rbyte.utils.dataframe.DataframeDiskCache + directory: /tmp/rbyte-cache + size_limit: 1GiB #@ end sample_builder: - _target_: rbyte.sample.builder.GreedySampleTableBuilder - index_column: ImageMetadata.(@=cameras[0]@).frame_idx + _target_: rbyte.sample.GreedySampleBuilder + index_column: meta/ImageMetadata.(@=cameras[0]@)/frame_idx length: 1 stride: 1 min_step: 1 filter: | - array_mean(`VehicleMotion.speed`) > 40 + array_mean(`meta/VehicleMotion/speed`) > 40 diff --git a/config/_templates/dataset/zod.yaml b/config/_templates/dataset/zod.yaml new file mode 100644 index 0000000..b1d050c --- /dev/null +++ b/config/_templates/dataset/zod.yaml @@ -0,0 +1,114 @@ +--- +_target_: rbyte.Dataset +_convert_: all +_recursive_: false +inputs: + 000002_short: + sources: + camera_front_blur: + index_column: camera_front_blur/timestamp + source: + _target_: rbyte.io.PathTensorSource + path: "${data_dir}/zod/sequences/000002_short/camera_front_blur/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.jpg" + decoder: + 
_target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true + + lidar_velodyne: + index_column: lidar_velodyne/timestamp + source: + _target_: rbyte.io.NumpyTensorSource + path: "${data_dir}/zod/sequences/000002_short/lidar_velodyne/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.npy" + select: ["x", "y", "z"] + + table_builder: + _target_: rbyte.io.TableBuilder + _convert_: all + readers: + camera_front_blur: + path: "${data_dir}/zod/sequences/000002_short/camera_front_blur/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.jpg" + reader: + _target_: rbyte.io.PathTableReader + _recursive_: false + fields: + timestamp: + _target_: polars.Datetime + time_unit: ns + + lidar_velodyne: + path: "${data_dir}/zod/sequences/000002_short/lidar_velodyne/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.npy" + reader: + _target_: rbyte.io.PathTableReader + _recursive_: false + fields: + timestamp: + _target_: polars.Datetime + time_unit: ns + + vehicle_data: + path: "${data_dir}/zod/sequences/000002_short/vehicle_data.hdf5" + reader: + _target_: rbyte.io.Hdf5TableReader + _recursive_: false + fields: + ego_vehicle_controls: + timestamp/nanoseconds/value: + _target_: polars.Datetime + time_unit: ns + + acceleration_pedal/ratio/unitless/value: + steering_wheel_angle/angle/radians/value: + + satellite: + timestamp/nanoseconds/value: + _target_: polars.Datetime + time_unit: ns + + speed/meters_per_second/value: + + merger: + _target_: rbyte.io.TableAligner + separator: "/" + merge: + camera_front_blur: + key: timestamp + + lidar_velodyne: + key: timestamp + columns: + timestamp: + method: asof + strategy: nearest + tolerance: 50ms + + vehicle_data: + ego_vehicle_controls: + key: timestamp/nanoseconds/value + columns: + timestamp/nanoseconds/value: + method: asof + strategy: nearest + tolerance: 50ms + + acceleration_pedal/ratio/unitless/value: + method: asof + strategy: nearest + tolerance: 50ms + + steering_wheel_angle/angle/radians/value: + method: asof + strategy: nearest + tolerance: 50ms + + satellite: + key: timestamp/nanoseconds/value + columns: + speed/meters_per_second/value: + method: interp + +sample_builder: + _target_: rbyte.sample.GreedySampleBuilder + length: 1 diff --git a/config/_templates/frame_reader/directory.yaml b/config/_templates/frame_reader/directory.yaml deleted file mode 100644 index 7e02999..0000000 --- a/config/_templates/frame_reader/directory.yaml +++ /dev/null @@ -1,10 +0,0 @@ ---- -_target_: rbyte.io.frame.DirectoryFrameReader -_recursive_: true -path: ??? -frame_decoder: - _target_: simplejpeg.decode_jpeg - _partial_: true - colorspace: rgb - fastdct: true - fastupsample: true diff --git a/config/_templates/frame_reader/hdf5.yaml b/config/_templates/frame_reader/hdf5.yaml deleted file mode 100644 index 1a0a769..0000000 --- a/config/_templates/frame_reader/hdf5.yaml +++ /dev/null @@ -1,5 +0,0 @@ ---- -_target_: rbyte.io.frame.Hdf5FrameReader -_recursive_: true -path: ??? -key: ??? diff --git a/config/_templates/frame_reader/mcap.yaml b/config/_templates/frame_reader/mcap.yaml deleted file mode 100644 index f19e906..0000000 --- a/config/_templates/frame_reader/mcap.yaml +++ /dev/null @@ -1,14 +0,0 @@ ---- -_target_: rbyte.io.frame.McapFrameReader -_recursive_: true -path: ??? -topic: ??? 
-decoder_factory: mcap_protobuf.decoder.DecoderFactory -frame_decoder: - _target_: simplejpeg.decode_jpeg - _partial_: true - colorspace: rgb - fastdct: true - fastupsample: true - -validate_crcs: false diff --git a/config/_templates/frame_reader/video/ffmpeg.yaml b/config/_templates/frame_reader/video/ffmpeg.yaml deleted file mode 100644 index 5107b03..0000000 --- a/config/_templates/frame_reader/video/ffmpeg.yaml +++ /dev/null @@ -1,6 +0,0 @@ ---- -_target_: rbyte.io.frame.FfmpegFrameReader -path: ??? -threads: !!null -resize_shorter_side: !!null -with_fallback: !!null diff --git a/config/_templates/frame_reader/video/vali.yaml b/config/_templates/frame_reader/video/vali.yaml deleted file mode 100644 index b65be70..0000000 --- a/config/_templates/frame_reader/video/vali.yaml +++ /dev/null @@ -1,5 +0,0 @@ ---- -_target_: rbyte.io.frame.video.vali_reader.ValiGpuFrameReader -_convert_: all -path: ??? -pixel_format_chain: [NV12] diff --git a/config/_templates/logger/rerun/carla.yaml b/config/_templates/logger/rerun/carla.yaml index 51fb3dd..58a8b2a 100644 --- a/config/_templates/logger/rerun/carla.yaml +++ b/config/_templates/logger/rerun/carla.yaml @@ -8,17 +8,15 @@ _target_: rbyte.viz.loggers.RerunLogger _recursive_: true _convert_: all schema: - frame: - #@ for camera in cameras: - (@=camera@): - Image: - color_model: RGB - #@ end + #@ for camera in cameras: + (@=camera@): + Image: + color_model: RGB + #@ end - table: - _idx_: TimeSequenceColumn - control.brake: Scalar - control.steer: Scalar - control.throttle: Scalar - state.acceleration.value: Scalar - state.velocity.value: Scalar + _idx_: TimeSequenceColumn + control.brake: Scalar + control.steer: Scalar + control.throttle: Scalar + state.acceleration.value: Scalar + state.velocity.value: Scalar diff --git a/config/_templates/logger/rerun/mimicgen.yaml b/config/_templates/logger/rerun/mimicgen.yaml index e0093ce..43b752c 100644 --- a/config/_templates/logger/rerun/mimicgen.yaml +++ b/config/_templates/logger/rerun/mimicgen.yaml @@ -1,11 +1,9 @@ --- _target_: rbyte.viz.loggers.RerunLogger schema: - frame: - obs/agentview_image: - Image: - color_model: RGB + obs/agentview_image: + Image: + color_model: RGB - table: - _idx_: TimeSequenceColumn - obs/robot0_eef_pos: Points3D + _idx_: TimeSequenceColumn + obs/robot0_eef_pos: Points3D diff --git a/config/_templates/logger/rerun/nuscenes/mcap.yaml b/config/_templates/logger/rerun/nuscenes/mcap.yaml new file mode 100644 index 0000000..98389ee --- /dev/null +++ b/config/_templates/logger/rerun/nuscenes/mcap.yaml @@ -0,0 +1,39 @@ +#@yaml/text-templated-strings + +#@ camera_topics = { +#@ "CAM_FRONT": "/CAM_FRONT/image_rect_compressed", +#@ "CAM_FRONT_LEFT": "/CAM_FRONT_LEFT/image_rect_compressed", +#@ "CAM_FRONT_RIGHT": "/CAM_FRONT_RIGHT/image_rect_compressed", +#@ } +--- +_target_: rbyte.viz.loggers.RerunLogger +schema: + #@ for camera in camera_topics.keys(): + (@=camera@): + Image: + color_model: RGB + #@ end + + #@ topic = camera_topics.values()[0] + mcap/(@=topic@)/_idx_: TimeSequenceColumn + mcap/(@=topic@)/log_time: TimeNanosColumn + #@ for topic in camera_topics.values()[1:]: + mcap/(@=topic@)/_idx_: TimeSequenceColumn + #@ end + mcap//odom/vel.x: Scalar + +spawn: true +blueprint: + _target_: rerun.blueprint.Blueprint + _args_: + - _target_: rerun.blueprint.Vertical + contents: + - _target_: rerun.blueprint.Horizontal + contents: + #@ for camera in camera_topics.keys(): + - _target_: rerun.blueprint.Spatial2DView + origin: #@ camera + #@ end + + - _target_: 
rerun.blueprint.TimeSeriesView + origin: mcap/ diff --git a/config/_templates/logger/rerun/nuscenes_rrd.yaml b/config/_templates/logger/rerun/nuscenes/rrd.yaml similarity index 56% rename from config/_templates/logger/rerun/nuscenes_rrd.yaml rename to config/_templates/logger/rerun/nuscenes/rrd.yaml index d3b6d9e..babdb29 100644 --- a/config/_templates/logger/rerun/nuscenes_rrd.yaml +++ b/config/_templates/logger/rerun/nuscenes/rrd.yaml @@ -1,41 +1,39 @@ #@yaml/text-templated-strings #@ camera_entities = { -#@ "CAM_FRONT_LEFT": "/world/ego_vehicle/CAM_FRONT_LEFT", #@ "CAM_FRONT": "/world/ego_vehicle/CAM_FRONT", +#@ "CAM_FRONT_LEFT": "/world/ego_vehicle/CAM_FRONT_LEFT", #@ "CAM_FRONT_RIGHT": "/world/ego_vehicle/CAM_FRONT_RIGHT", #@ } --- _target_: rbyte.viz.loggers.RerunLogger +spawn: true schema: - frame: - #@ for camera in camera_entities.keys(): - (@=camera@): - Image: - color_model: RGB - #@ end + #@ for camera in camera_entities.keys(): + (@=camera@): + Image: + color_model: RGB + #@ end - table: - #@ for camera_entity in camera_entities.values(): - (@=camera_entity@)/_idx_: TimeSequenceColumn - (@=camera_entity@)/timestamp: TimeNanosColumn - #@ end - /world/ego_vehicle/LIDAR_TOP/timestamp: TimeNanosColumn - /world/ego_vehicle/LIDAR_TOP/Position3D: Points3D + #@ entity = camera_entities.values()[0] + rrd/(@=entity@)/_idx_: TimeSequenceColumn + rrd/(@=entity@)/timestamp: TimeNanosColumn + #@ for entity in camera_entities.values()[1:]: + rrd/(@=entity@)/_idx_: TimeSequenceColumn + #@ end + #! rrd//world/ego_vehicle/LIDAR_TOP/Position3D: Points3D -spawn: true blueprint: _target_: rerun.blueprint.Blueprint _args_: - _target_: rerun.blueprint.Vertical contents: - _target_: rerun.blueprint.Spatial3DView - origin: table + origin: rrd/ - _target_: rerun.blueprint.Horizontal contents: #@ for camera in camera_entities.keys(): - _target_: rerun.blueprint.Spatial2DView - name: #@ camera - origin: #@ "frame/{}".format(camera) + origin : #@ camera #@ end diff --git a/config/_templates/logger/rerun/nuscenes_mcap.yaml b/config/_templates/logger/rerun/nuscenes_mcap.yaml deleted file mode 100644 index 60ec90a..0000000 --- a/config/_templates/logger/rerun/nuscenes_mcap.yaml +++ /dev/null @@ -1,40 +0,0 @@ -#@yaml/text-templated-strings - -#@ camera_topics = { -#@ 'CAM_FRONT': '/CAM_FRONT/image_rect_compressed', -#@ 'CAM_FRONT_LEFT': '/CAM_FRONT_LEFT/image_rect_compressed', -#@ 'CAM_FRONT_RIGHT': '/CAM_FRONT_RIGHT/image_rect_compressed', -#@ } ---- -_target_: rbyte.viz.loggers.RerunLogger -schema: - frame: - #@ for camera in camera_topics.keys(): - (@=camera@): - Image: - color_model: RGB - #@ end - - table: - #@ for topic in camera_topics.values(): - (@=topic@)/_idx_: TimeSequenceColumn - (@=topic@)/log_time: TimeNanosColumn - #@ end - /odom/vel.x: Scalar - -spawn: true -blueprint: - _target_: rerun.blueprint.Blueprint - _args_: - - _target_: rerun.blueprint.Vertical - contents: - - _target_: rerun.blueprint.Horizontal - contents: - #@ for camera in camera_topics.keys(): - - _target_: rerun.blueprint.Spatial2DView - name: #@ camera - origin: #@ "frame/{}".format(camera) - #@ end - - - _target_: rerun.blueprint.TimeSeriesView - origin: table diff --git a/config/_templates/logger/rerun/yaak.yaml b/config/_templates/logger/rerun/yaak.yaml index ba81bb9..686265f 100644 --- a/config/_templates/logger/rerun/yaak.yaml +++ b/config/_templates/logger/rerun/yaak.yaml @@ -8,20 +8,19 @@ --- _target_: rbyte.viz.loggers.RerunLogger schema: - frame: - #@ for camera in cameras: - (@=camera@): - Image: - color_model: RGB 
- #@ end + #@ for camera in cameras: + (@=camera@): + Image: + color_model: RGB + #@ end - table: - #@ for camera in cameras: - ImageMetadata.(@=camera@).frame_idx: TimeSequenceColumn - ImageMetadata.(@=camera@).time_stamp: TimeNanosColumn - #@ end - VehicleMotion.time_stamp: TimeNanosColumn - VehicleMotion.speed: Scalar + #@ camera = cameras[0] + meta/ImageMetadata.(@=camera@)/frame_idx: TimeSequenceColumn + meta/ImageMetadata.(@=camera@)/time_stamp: TimeNanosColumn + #@ for camera in cameras[1:]: + meta/ImageMetadata.(@=camera@)/frame_idx: TimeSequenceColumn + #@ end + meta/VehicleMotion/speed: Scalar - /ai/safety_score.clip.end_timestamp: TimeNanosColumn - /ai/safety_score.score: Scalar + mcap//ai/safety_score/clip.end_timestamp: TimeNanosColumn + mcap//ai/safety_score/score: Scalar diff --git a/config/_templates/logger/rerun/zod.yaml b/config/_templates/logger/rerun/zod.yaml new file mode 100644 index 0000000..ad21664 --- /dev/null +++ b/config/_templates/logger/rerun/zod.yaml @@ -0,0 +1,16 @@ +--- +_target_: rbyte.viz.loggers.RerunLogger +spawn: true +schema: + camera_front_blur/timestamp: TimeNanosColumn + camera_front_blur: + Image: + color_model: RGB + + lidar_velodyne/timestamp: TimeNanosColumn + lidar_velodyne: Points3D + + vehicle_data/ego_vehicle_controls/timestamp/nanoseconds/value: TimeNanosColumn + vehicle_data/ego_vehicle_controls/acceleration_pedal/ratio/unitless/value: Scalar + vehicle_data/ego_vehicle_controls/steering_wheel_angle/angle/radians/value: Scalar + vehicle_data/satellite/speed/meters_per_second/value: Scalar diff --git a/config/_templates/read_frames.yaml b/config/_templates/read_frames.yaml deleted file mode 100644 index 1cf3fde..0000000 --- a/config/_templates/read_frames.yaml +++ /dev/null @@ -1,17 +0,0 @@ ---- -defaults: - - frame_reader: !!null - - _self_ - -batch_size: 1 -application_id: rbyte -entity_path: frames -frame_config: - Image: - pixel_format: !!null - color_model: !!null - -hydra: - output_subdir: !!null - run: - dir: . diff --git a/config/_templates/table_builder/carla.yaml b/config/_templates/table_builder/carla.yaml deleted file mode 100644 index 9e29ea4..0000000 --- a/config/_templates/table_builder/carla.yaml +++ /dev/null @@ -1,28 +0,0 @@ ---- -_target_: rbyte.io.table.TableBuilder -_convert_: all -readers: - - path: ??? - reader: - _target_: rbyte.io.table.JsonTableReader - _recursive_: false - fields: - records: - _idx_: - control.brake: - control.throttle: - control.steer: - state.velocity.value: - state.acceleration.value: - - transforms: - - _target_: rbyte.io.table.transforms.FpsResampler - source_fps: 20 - target_fps: 30 - -merger: - _target_: rbyte.io.table.TableConcater - method: vertical - -filter: |- - `control.throttle` > 0.5 diff --git a/config/_templates/table_builder/hdf5.yaml b/config/_templates/table_builder/hdf5.yaml deleted file mode 100644 index 5bc81f5..0000000 --- a/config/_templates/table_builder/hdf5.yaml +++ /dev/null @@ -1,20 +0,0 @@ ---- -_target_: rbyte.io.table.TableBuilder -_convert_: all -readers: - - path: ??? 
- reader: - _target_: rbyte.io.table.Hdf5TableReader - _recursive_: false - fields: - /data/demo_0: - _idx_: - actions: - dones: - obs/robot0_eef_pos: - rewards: - states: - -merger: - _target_: rbyte.io.table.TableConcater - method: vertical diff --git a/config/_templates/table_builder/mcap.yaml b/config/_templates/table_builder/mcap.yaml deleted file mode 100644 index 2564115..0000000 --- a/config/_templates/table_builder/mcap.yaml +++ /dev/null @@ -1,64 +0,0 @@ -#@yaml/text-templated-strings - -#@ camera_topics = [ -#@ '/CAM_FRONT_LEFT/image_rect_compressed', -#@ '/CAM_FRONT_RIGHT/image_rect_compressed', -#@ ] ---- -_target_: rbyte.io.table.TableBuilder -_convert_: all -readers: - - path: ??? - reader: - _target_: rbyte.io.table.McapTableReader - _recursive_: false - decoder_factories: - - rbyte.utils.mcap.ProtobufDecoderFactory - - rbyte.utils.mcap.JsonDecoderFactory - - mcap_ros2.decoder.DecoderFactory - - fields: - #@ for topic in camera_topics: - (@=topic@): - log_time: - _target_: polars.Datetime - time_unit: ns - - _idx_: - #@ end - - /odom: - log_time: - _target_: polars.Datetime - time_unit: ns - vel.x: - -merger: - _target_: rbyte.io.table.TableAligner - separator: / - merge: - (@=camera_topics[0]@): - log_time: - method: ref - - #@ for topic in camera_topics[1:]: - (@=topic@): - log_time: - method: ref - _idx_: - method: asof - tolerance: 10ms - strategy: nearest - #@ end - - /odom: - log_time: - method: ref - vel.x: - method: interp - -filter: !!null -cache: - _target_: rbyte.utils.dataframe.DataframeDiskCache - directory: /tmp/rbyte-cache - size_limit: 1GiB diff --git a/config/_templates/table_builder/rrd.yaml b/config/_templates/table_builder/rrd.yaml deleted file mode 100644 index 17adf25..0000000 --- a/config/_templates/table_builder/rrd.yaml +++ /dev/null @@ -1,53 +0,0 @@ -#@yaml/text-templated-strings - -#@ camera_entities = { -#@ "CAM_FRONT_LEFT": "/world/ego_vehicle/CAM_FRONT_LEFT", -#@ "CAM_FRONT": "/world/ego_vehicle/CAM_FRONT", -#@ "CAM_FRONT_RIGHT": "/world/ego_vehicle/CAM_FRONT_RIGHT", -#@ } ---- -_target_: rbyte.io.table.TableBuilder -_convert_: all -readers: - - path: ??? - reader: - _target_: rbyte.io.table.RrdTableReader - _recursive_: false - index: timestamp - contents: - #@ for entity in camera_entities.values(): - (@=entity@): - - _idx_ - #@ end - - /world/ego_vehicle/LIDAR_TOP: - - Position3D - -merger: - _target_: rbyte.io.table.TableAligner - separator: / - merge: - #@ entity = camera_entities.values()[0] - (@=entity@): - timestamp: - method: ref - - #@ for entity in camera_entities.values()[1:]: - (@=entity@): - timestamp: - method: ref - - _idx_: - method: asof - strategy: nearest - tolerance: 20ms - #@ end - - /world/ego_vehicle/LIDAR_TOP: - timestamp: - method: ref - - Position3D: - method: asof - strategy: nearest - tolerance: 20ms diff --git a/config/_templates/table_builder/yaak.yaml b/config/_templates/table_builder/yaak.yaml deleted file mode 100644 index a3b0034..0000000 --- a/config/_templates/table_builder/yaak.yaml +++ /dev/null @@ -1,102 +0,0 @@ -#@yaml/text-templated-strings - -#@ cameras = [ -#@ 'cam_front_left', -#@ 'cam_left_forward', -#@ 'cam_right_forward', -#@ ] ---- -_target_: rbyte.io.table.TableBuilder -_convert_: all -readers: - - path: ??? 
- reader: - _target_: rbyte.io.table.yaak.YaakMetadataTableReader - _recursive_: false - fields: - rbyte.io.table.yaak.proto.sensor_pb2.ImageMetadata: - time_stamp: - _target_: polars.Datetime - time_unit: ns - - frame_idx: polars.UInt32 - camera_name: - _target_: polars.Enum - categories: - - cam_front_center - - cam_front_left - - cam_front_right - - cam_left_forward - - cam_right_forward - - cam_left_backward - - cam_right_backward - - cam_rear - - rbyte.io.table.yaak.proto.can_pb2.VehicleMotion: - time_stamp: - _target_: polars.Datetime - time_unit: ns - - speed: polars.Float32 - gear: - _target_: polars.Enum - categories: ["0", "1", "2", "3"] - - - path: ??? - reader: - _target_: rbyte.io.table.McapTableReader - _recursive_: false - decoder_factories: [rbyte.utils.mcap.ProtobufDecoderFactory] - fields: - /ai/safety_score: - clip.end_timestamp: - _target_: polars.Datetime - time_unit: ns - - score: polars.Float32 - -merger: - _target_: rbyte.io.table.TableAligner - separator: "." - merge: - #@ camera = cameras[0] - ImageMetadata.(@=camera@): - time_stamp: - method: ref - - #@ for camera in cameras[1:]: - ImageMetadata.(@=camera@): - time_stamp: - method: ref - - frame_idx: - method: asof - tolerance: 10ms - strategy: nearest - #@ end - - VehicleMotion: - time_stamp: - method: ref - speed: - method: interp - gear: - method: asof - tolerance: 100ms - - /ai/safety_score: - clip.end_timestamp: - method: ref - - score: - method: asof - tolerance: 100ms - strategy: nearest - -filter: | - `VehicleMotion.gear` == '3' - -cache: - _target_: rbyte.utils.dataframe.DataframeDiskCache - directory: /tmp/rbyte-cache - size_limit: 1GiB diff --git a/config/_templates/table_writer/console.yaml b/config/_templates/table_writer/console.yaml deleted file mode 100644 index fba8b54..0000000 --- a/config/_templates/table_writer/console.yaml +++ /dev/null @@ -1,4 +0,0 @@ ---- -_target_: polars.DataFrame.glimpse -_partial_: true -max_items_per_column: 3 diff --git a/config/_templates/table_writer/csv.yaml b/config/_templates/table_writer/csv.yaml deleted file mode 100644 index 8aca7e2..0000000 --- a/config/_templates/table_writer/csv.yaml +++ /dev/null @@ -1,4 +0,0 @@ ---- -_target_: polars.DataFrame.write_csv -_partial_: true -file: ??? diff --git a/config/_templates/table_writer/parquet.yaml b/config/_templates/table_writer/parquet.yaml deleted file mode 100644 index a58f047..0000000 --- a/config/_templates/table_writer/parquet.yaml +++ /dev/null @@ -1,4 +0,0 @@ ---- -_target_: polars.DataFrame.write_parquet -_partial_: true -file: ??? 
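
Note: the standalone table_builder/table_writer templates deleted above are superseded by the per-dataset `table_builder` blocks, whose reworked TableAligner `merger` now keys merge rules by reader id and a per-table `key` column, prefixing output columns with the reader id and the configured separator (hence filters like `meta/VehicleMotion/gear`). A minimal polars sketch of the asof alignment this config expresses — column names are illustrative and this is not rbyte's actual TableAligner implementation:

import polars as pl

# Reference table (reader id "meta"): one row per camera frame.
ref = pl.DataFrame(
    {"time_stamp": [0, 40_000_000, 80_000_000], "frame_idx": [0, 1, 2]}
).rename(lambda c: f"meta/ImageMetadata.cam_front_left/{c}")

# Secondary table (reader id "mcap"): sparse safety scores.
scores = pl.DataFrame(
    {"clip.end_timestamp": [1_000_000, 39_000_000], "score": [0.2, 0.9]}
).rename(lambda c: f"mcap//ai/safety_score/{c}")

# `method: asof` + `strategy: nearest` + `tolerance: 500ms`, expressed directly.
aligned = ref.join_asof(
    scores,
    left_on="meta/ImageMetadata.cam_front_left/time_stamp",
    right_on="mcap//ai/safety_score/clip.end_timestamp",
    strategy="nearest",
    tolerance=500_000_000,  # 500ms in nanoseconds
)
print(aligned)
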
diff --git a/examples/.gitattributes b/examples/.gitattributes new file mode 100644 index 0000000..0b15a56 --- /dev/null +++ b/examples/.gitattributes @@ -0,0 +1 @@ +*.ipynb filter=jupyter-nbconvert-clear-output diff --git a/examples/nuscenes_mcap.ipynb b/examples/nuscenes_mcap.ipynb index bf78df7..4b872a3 100644 --- a/examples/nuscenes_mcap.ipynb +++ b/examples/nuscenes_mcap.ipynb @@ -39,8 +39,8 @@ "outputs": [], "source": [ "CONFIG_PATH = \"../config\"\n", - "DATA_DIR = Path.cwd().parent.resolve() / \"tests\" / \"data\"\n", - "DATASET = \"nuscenes_mcap\"\n", + "DATA_DIR = Path.cwd().parent.resolve() / \"tests\" / \"data\" / \"nuscenes\" / \"mcap\"\n", + "DATASET = \"nuscenes/mcap\"\n", "LOGGER = f\"rerun/{DATASET}\"\n", "\n", "with initialize(version_base=None, config_path=CONFIG_PATH):\n", diff --git a/examples/nuscenes_rrd.ipynb b/examples/nuscenes_rrd.ipynb index 7cf98cb..2f5b7c9 100644 --- a/examples/nuscenes_rrd.ipynb +++ b/examples/nuscenes_rrd.ipynb @@ -39,8 +39,8 @@ "outputs": [], "source": [ "CONFIG_PATH = \"../config\"\n", - "DATA_DIR = Path.cwd().parent.resolve() / \"tests\" / \"data\"\n", - "DATASET = \"nuscenes_rrd\"\n", + "DATA_DIR = Path.cwd().parent.resolve() / \"tests\" / \"data\" / \"nuscenes\" / \"rrd\"\n", + "DATASET = \"nuscenes/rrd\"\n", "LOGGER = f\"rerun/{DATASET}\"\n", "\n", "with initialize(version_base=None, config_path=CONFIG_PATH):\n", diff --git a/hatch_build.py b/hatch_build.py index 6e6edb0..a6aeb0e 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -3,7 +3,7 @@ from importlib import resources from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Any, override +from typing import Any, final, override from grpc_tools import protoc from hatchling.builders.hooks.plugin.interface import BuildHookInterface @@ -13,6 +13,7 @@ logging.basicConfig(level=logging.DEBUG) +@final class BuildYaakIdlProtosHook(BuildHookInterface): # pyright: ignore[reportMissingTypeArgument] PLUGIN_NAME = "build-yaak-idl-protos" @@ -21,20 +22,13 @@ class BuildYaakIdlProtosHook(BuildHookInterface): # pyright: ignore[reportMissi / "src" / "rbyte" / "io" - / "table" / "yaak" / "idl-repo" / "intercom" / "proto" ) YAAK_IDL_PYTHON_OUT = ( - Path(__file__).resolve().parent - / "src" - / "rbyte" - / "io" - / "table" - / "yaak" - / "proto" + Path(__file__).resolve().parent / "src" / "rbyte" / "io" / "yaak" / "proto" ) YAAK_IDL_PROTOS = ("can.proto", "sensor.proto") diff --git a/justfile b/justfile index d3e502e..faf88e1 100644 --- a/justfile +++ b/justfile @@ -1,5 +1,6 @@ export PYTHONOPTIMIZE := "1" export HATCH_BUILD_CLEAN := "1" +export HYDRA_FULL_ERROR := "1" _default: @just --list --unsorted @@ -60,24 +61,6 @@ visualize *ARGS: generate-config hydra/job_logging=disabled \ {{ ARGS }} -[group('scripts')] -build-table *ARGS: generate-config - uv run rbyte-build-table \ - --config-path {{ justfile_directory() }}/config \ - --config-name build_table.yaml \ - hydra/hydra_logging=disabled \ - hydra/job_logging=disabled \ - {{ ARGS }} - -[group('scripts')] -read-frames *ARGS: generate-config - uv run rbyte-read-frames \ - --config-path {{ justfile_directory() }}/config \ - --config-name read_frames.yaml \ - hydra/hydra_logging=disabled \ - hydra/job_logging=disabled \ - {{ ARGS }} - # rerun server and viewer rerun bind="0.0.0.0" port="9876" ws-server-port="9877" web-viewer-port="9090": RUST_LOG=debug uv run rerun \ diff --git a/pyproject.toml b/pyproject.toml index d718517..41a5158 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,17 +1,18 @@ [project] name = 
"rbyte" -version = "0.6.0" +version = "0.7.0" description = "Multimodal PyTorch dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] dependencies = [ - "tensordict>=0.6.0", + "tensordict>=0.6.2", "torch", - "polars>=1.12.0", + "numpy", + "polars>=1.14.0", "pydantic>=2.9.2", "more-itertools>=10.5.0", "hydra-core>=1.3.2", - "optree>=0.13.0", + "optree>=0.13.1", "cachetools>=5.5.0", "diskcache>=5.6.3", "jaxtyping>=0.2.34", @@ -38,7 +39,7 @@ repo = "https://github.com/yaak-ai/rbyte" [project.optional-dependencies] build = ["hatchling>=1.25.0", "grpcio-tools>=1.62.0", "protoletariat==3.2.19"] -visualize = ["rerun-sdk[notebook]>=0.19.0"] +visualize = ["rerun-sdk[notebook]>=0.20.0"] mcap = [ "mcap>=1.2.1", "mcap-ros2-support>=0.5.5", @@ -49,14 +50,12 @@ yaak = ["protobuf", "ptars>=0.0.3"] jpeg = ["simplejpeg>=1.7.6"] video = [ "python-vali>=4.2.0.post0; sys_platform == 'linux'", - "video-reader-rs>=0.1.9", + "video-reader-rs>=0.2.1", ] hdf5 = ["h5py>=3.12.1"] -rrd = ["rerun-sdk>=0.19.0", "pyarrow-stubs"] +rrd = ["rerun-sdk>=0.20.0", "pyarrow-stubs"] [project.scripts] -rbyte-build-table = 'rbyte.scripts.build_table:main' -rbyte-read-frames = 'rbyte.scripts.read_frames:main' rbyte-visualize = 'rbyte.scripts.visualize:main' [build-system] @@ -69,10 +68,10 @@ build-backend = "hatchling.build" [tool.uv] dev-dependencies = [ - "wat-inspector>=0.4.0", + "wat-inspector>=0.4.2", "lovely-tensors>=0.1.17", "pudb>=2024.1.2", - "ipython>=8.28.0", + "ipython>=8.29.0", "ipython-autoimport>=0.5", "pytest>=8.3.3", "testbook>=0.4.2", diff --git a/src/rbyte/batch/batch.py b/src/rbyte/batch/batch.py index 8065461..a61056a 100644 --- a/src/rbyte/batch/batch.py +++ b/src/rbyte/batch/batch.py @@ -14,6 +14,5 @@ class BatchMeta: @tensorclass(autocast=True) # pyright: ignore[reportUntypedClassDecorator] class Batch: + data: TensorDict # pyright: ignore[reportUninitializedInstanceVariable] meta: BatchMeta # pyright: ignore[reportUninitializedInstanceVariable] - frame: TensorDict # pyright: ignore[reportUninitializedInstanceVariable] - table: TensorDict # pyright: ignore[reportUninitializedInstanceVariable] diff --git a/src/rbyte/config/base.py b/src/rbyte/config/base.py index 30e21ef..4935432 100644 --- a/src/rbyte/config/base.py +++ b/src/rbyte/config/base.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Literal, TypeVar +from typing import ClassVar, Literal, TypeVar from hydra.utils import instantiate from pydantic import BaseModel as _BaseModel @@ -7,7 +7,8 @@ class BaseModel(_BaseModel): - model_config = ConfigDict( + model_config: ClassVar[ConfigDict] = ConfigDict( + arbitrary_types_allowed=True, frozen=True, extra="forbid", validate_assignment=True, @@ -19,7 +20,7 @@ class BaseModel(_BaseModel): class HydraConfig[T](BaseModel): - model_config = ConfigDict(extra="allow") + model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow") target: ImportString[type[T]] = Field(alias="_target_") recursive: bool = Field(alias="_recursive_", default=True) diff --git a/src/rbyte/dataset.py b/src/rbyte/dataset.py index 8a4cd0b..ad90d21 100644 --- a/src/rbyte/dataset.py +++ b/src/rbyte/dataset.py @@ -3,20 +3,20 @@ from functools import cache from typing import Annotated -import more_itertools as mit import polars as pl import torch -from pydantic import ConfigDict, Field, StringConstraints, validate_call +from pydantic import Field, StringConstraints, validate_call from structlog import 
get_logger from structlog.contextvars import bound_contextvars from tensordict import TensorDict from torch.utils.data import Dataset as TorchDataset from rbyte.batch import Batch, BatchMeta -from rbyte.config.base import BaseModel, HydraConfig -from rbyte.io.frame.base import FrameReader -from rbyte.io.table.base import TableBuilderBase -from rbyte.sample.base import SampleTableBuilder +from rbyte.config import BaseModel, HydraConfig +from rbyte.io.base import TensorSource +from rbyte.io.table.base import TableBuilder +from rbyte.sample.base import SampleBuilder +from rbyte.utils.functional import pad_sequence __all__ = ["Dataset"] @@ -27,36 +27,32 @@ ] -class FrameSourceConfig(BaseModel): - reader: HydraConfig[FrameReader] +class SourceConfig(BaseModel): + source: HydraConfig[TensorSource] index_column: str -class TableSourceConfig(BaseModel): - builder: HydraConfig[TableBuilderBase] - - -class SourcesConfig(BaseModel): - frame: Mapping[Id, FrameSourceConfig] = Field(min_length=1) - table: TableSourceConfig | None = None +class InputConfig(BaseModel): + sources: Mapping[Id, SourceConfig] = Field(min_length=1) + table_builder: HydraConfig[TableBuilder] @unique class Column(StrEnum): input_id = "__input_id" sample_idx = "__sample_idx" - frame_idx = "__frame_idx" - source_id = "source.id" - source_reader = "source.reader" - source_index_column = "source.index_column" + source_idxs = "__source_idxs" + source_id = "__source.id" + source_config = "__source.config" + source_index_column = "__source.index_column" class Dataset(TorchDataset[TensorDict]): - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + @validate_call(config=BaseModel.model_config) def __init__( self, - inputs: Annotated[Mapping[Id, SourcesConfig], Field(min_length=1)], - sample_builder: HydraConfig[SampleTableBuilder], + inputs: Annotated[Mapping[Id, InputConfig], Field(min_length=1)], + sample_builder: HydraConfig[SampleBuilder], ) -> None: logger.debug("initializing dataset") @@ -66,7 +62,7 @@ def __init__( samples: Mapping[str, pl.LazyFrame] = {} for input_id, input_cfg in inputs.items(): with bound_contextvars(input_id=input_id): - table = self._build_table(input_cfg) + table = input_cfg.table_builder.instantiate().build().lazy() samples[input_id] = _sample_builder.build(table) logger.debug( "built samples", @@ -76,7 +72,7 @@ def __init__( input_id_enum = pl.Enum(sorted(samples)) - self._samples = ( + self._samples: pl.DataFrame = ( pl.concat( [ df.select( @@ -93,20 +89,20 @@ def __init__( .rechunk() ) - self._frame_sources = ( + self._sources: pl.DataFrame = ( pl.LazyFrame( [ { Column.input_id: input_id, - (k := "source"): [ - source_cfg.model_dump(exclude={"reader"}) + (k := "__source"): [ + source_cfg.model_dump(exclude={"source"}) | { "id": source_id, - "reader": source_cfg.reader.model_dump_json( + "config": source_cfg.source.model_dump_json( by_alias=True ), } - for source_id, source_cfg in input_cfg.frame.items() + for source_id, source_cfg in input_cfg.sources.items() ], } for input_id, input_cfg in inputs.items() @@ -120,109 +116,64 @@ def __init__( .rechunk() ) - @classmethod - def _build_table(cls, sources: SourcesConfig) -> pl.LazyFrame: - logger.debug("building table") - - match sources: - case SourcesConfig(frame=frame_sources, table=None) if ( - len(frame_sources) == 1 - ): - frame_source = mit.one(frame_sources.values()) - frame_reader = frame_source.reader.instantiate() - frame_idxs = pl.Series( - name=frame_source.index_column, - values=frame_reader.get_available_indexes(), - 
dtype=pl.UInt32, - ).sort() - - return pl.LazyFrame(frame_idxs) - - case SourcesConfig( - frame=frame_sources, table=TableSourceConfig(builder=builder) - ): - table_builder = builder.instantiate() - table = table_builder.build().lazy() - schema = table.collect_schema() - - for frame_source_id, frame_source in frame_sources.items(): - logger.debug("pruning table", frame_source=frame_source_id) - frame_reader = frame_source.reader.instantiate() - frame_idxs = pl.Series( - name=(col := frame_source.index_column), - values=frame_reader.get_available_indexes(), - dtype=schema[col], - ).sort() - - table = table.join( - pl.LazyFrame(frame_idxs), on=frame_idxs.name, how="semi" - ) - - return table - - case _: - logger.error("not implemented") - - raise NotImplementedError - @property def samples(self) -> pl.DataFrame: return self._samples @property - def frame_sources(self) -> pl.DataFrame: - return self._frame_sources + def sources(self) -> pl.DataFrame: + return self._sources @cache # noqa: B019 - def _get_frame_reader(self, reader_json: str) -> FrameReader: # noqa: PLR6301 - return HydraConfig[FrameReader].model_validate_json(reader_json).instantiate() + def _get_source(self, config: str) -> TensorSource: # noqa: PLR6301 + return HydraConfig[TensorSource].model_validate_json(config).instantiate() def __getitems__(self, indexes: Sequence[int]) -> Batch: # noqa: PLW3201 samples = self.samples[indexes] batch_size = [samples.height] - meta = BatchMeta( - sample_idx=samples[Column.sample_idx].to_torch(), # pyright: ignore[reportCallIssue] - input_id=samples[Column.input_id].to_list(), # pyright: ignore[reportCallIssue] - batch_size=batch_size, # pyright: ignore[reportCallIssue] - ) + source_idx_cols = self._sources[Column.source_index_column].unique() - frame_source_idx_cols = self._frame_sources[Column.source_index_column].unique() - - frame_sources = ( + sources = ( samples.lazy() - .join(self._frame_sources.lazy(), on=Column.input_id, how="left") + .join(self.sources.lazy(), on=Column.input_id, how="left") .with_columns( pl.coalesce( pl.when(pl.col(Column.source_index_column) == idx_col).then(idx_col) - for idx_col in frame_source_idx_cols - ).alias(Column.frame_idx) + for idx_col in source_idx_cols + ).alias(Column.source_idxs) ) .group_by(Column.source_id) - .agg(Column.source_reader, Column.frame_idx) + .agg(Column.source_config, Column.source_idxs) ) - frames = TensorDict( - { - row[Column.source_id]: torch.stack([ - self._get_frame_reader(reader).read(frame_idxs) - for (reader, frame_idxs) in zip( - row[Column.source_reader], row[Column.frame_idx], strict=True + tensors: Mapping[str, torch.Tensor] = { + row[Column.source_id]: pad_sequence( + [ + self._get_source(source)[idxs] + for (source, idxs) in zip( + row[Column.source_config], row[Column.source_idxs], strict=True ) - ]) - for row in frame_sources.collect().iter_rows(named=True) - }, - batch_size=batch_size, - ) + ], + dim=1, + value=torch.nan, + ) + for row in sources.collect().iter_rows(named=True) + } + + table: Mapping[str, Sequence[object]] = samples.select( + pl.exclude(Column.sample_idx, Column.input_id).to_physical() + ).to_dict(as_series=False) - table = TensorDict( - samples.select( # pyright: ignore[reportArgumentType] - pl.exclude(Column.sample_idx, Column.input_id).to_physical() - ).to_dict(as_series=False), - batch_size=batch_size, + data = TensorDict(tensors | table, batch_size=batch_size) # pyright: ignore[reportArgumentType] + + meta = BatchMeta( + sample_idx=samples[Column.sample_idx].to_torch(), # pyright: 
ignore[reportCallIssue] + input_id=samples[Column.input_id].to_list(), # pyright: ignore[reportCallIssue] + batch_size=batch_size, # pyright: ignore[reportCallIssue] ) - return Batch(meta=meta, frame=frames, table=table, batch_size=batch_size) # pyright: ignore[reportCallIssue] + return Batch(data=data, meta=meta, batch_size=batch_size) # pyright: ignore[reportCallIssue] def __len__(self) -> int: return len(self.samples) diff --git a/src/rbyte/io/__init__.py b/src/rbyte/io/__init__.py index e69de29..34e7d26 100644 --- a/src/rbyte/io/__init__.py +++ b/src/rbyte/io/__init__.py @@ -0,0 +1,50 @@ +from ._json import JsonTableReader +from ._numpy import NumpyTensorSource +from .path import PathTableReader, PathTensorSource +from .table import FpsResampler, TableAligner, TableBuilder, TableConcater + +__all__: list[str] = [ + "FpsResampler", + "JsonTableReader", + "NumpyTensorSource", + "PathTableReader", + "PathTensorSource", + "TableAligner", + "TableBuilder", + "TableConcater", +] + +try: + from .hdf5 import Hdf5TableReader, Hdf5TensorSource +except ImportError: + pass +else: + __all__ += ["Hdf5TableReader", "Hdf5TensorSource"] + +try: + from ._mcap import McapTableReader, McapTensorSource +except ImportError: + pass +else: + __all__ += ["McapTableReader", "McapTensorSource"] + +try: + from .rrd import RrdFrameSource, RrdTableReader +except ImportError: + pass +else: + __all__ += ["RrdFrameSource", "RrdTableReader"] + +try: + from .video.ffmpeg_source import FfmpegFrameSource +except ImportError: + pass +else: + __all__ += ["FfmpegFrameSource"] + +try: + from .yaak import YaakMetadataTableReader +except ImportError: + pass +else: + __all__ += ["YaakMetadataTableReader"] diff --git a/src/rbyte/io/_json/__init__.py b/src/rbyte/io/_json/__init__.py new file mode 100644 index 0000000..2c19de4 --- /dev/null +++ b/src/rbyte/io/_json/__init__.py @@ -0,0 +1,3 @@ +from .table_reader import JsonTableReader + +__all__ = ["JsonTableReader"] diff --git a/src/rbyte/io/table/json/reader.py b/src/rbyte/io/_json/table_reader.py similarity index 88% rename from src/rbyte/io/table/json/reader.py rename to src/rbyte/io/_json/table_reader.py index 4d23573..3ecd60a 100644 --- a/src/rbyte/io/table/json/reader.py +++ b/src/rbyte/io/_json/table_reader.py @@ -7,17 +7,17 @@ from typing import override import polars as pl -from optree import tree_map +from optree import PyTree, tree_map from polars._typing import PolarsDataType from polars.datatypes import ( DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 ) -from pydantic import ConfigDict, Field +from pydantic import Field from xxhash import xxh3_64_intdigest as digest from rbyte.config.base import BaseModel, HydraConfig -from rbyte.io.table.base import TableReaderBase +from rbyte.io.table.base import TableReader from rbyte.io.table.transforms.base import TableTransform from rbyte.utils.dataframe.misc import unnest_all @@ -28,18 +28,16 @@ class SpecialField(StrEnum): class Config(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - fields: Mapping[str, Mapping[str, HydraConfig[PolarsDataType] | None]] transforms: Sequence[HydraConfig[TableTransform]] = Field(default=()) -class JsonTableReader(TableReaderBase, Hashable): +class JsonTableReader(TableReader, Hashable): def __init__(self, **kwargs: object) -> None: - self._config = Config.model_validate(kwargs) + self._config: Config = Config.model_validate(kwargs) @override - def read(self, path: PathLike[str]) -> Mapping[str, 
pl.DataFrame]: + def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: dfs: Mapping[str, pl.DataFrame] = {} for k, series in ( @@ -68,7 +66,7 @@ def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: dfs[k] = df - return dfs + return dfs # pyright: ignore[reportReturnType] @override def __hash__(self) -> int: diff --git a/src/rbyte/io/_mcap/__init__.py b/src/rbyte/io/_mcap/__init__.py new file mode 100644 index 0000000..aba5317 --- /dev/null +++ b/src/rbyte/io/_mcap/__init__.py @@ -0,0 +1,4 @@ +from .table_reader import McapTableReader +from .tensor_source import McapTensorSource + +__all__ = ["McapTableReader", "McapTensorSource"] diff --git a/src/rbyte/io/table/mcap/reader.py b/src/rbyte/io/_mcap/table_reader.py similarity index 93% rename from src/rbyte/io/table/mcap/reader.py rename to src/rbyte/io/_mcap/table_reader.py index 89eb9ce..f7a32f6 100644 --- a/src/rbyte/io/table/mcap/reader.py +++ b/src/rbyte/io/_mcap/table_reader.py @@ -13,14 +13,13 @@ import polars as pl from mcap.decoder import DecoderFactory from mcap.reader import SeekingReader -from optree import tree_map +from optree import PyTree, tree_map from polars._typing import PolarsDataType from polars.datatypes import ( DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 ) from pydantic import ( - ConfigDict, ImportString, SerializationInfo, SerializerFunctionWrapHandler, @@ -32,15 +31,13 @@ from xxhash import xxh3_64_intdigest as digest from rbyte.config.base import BaseModel, HydraConfig -from rbyte.io.table.base import TableReaderBase +from rbyte.io.table.base import TableReader from rbyte.utils.dataframe import unnest_all logger = get_logger(__name__) class Config(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - decoder_factories: frozenset[ImportString[type[DecoderFactory]]] fields: Mapping[str, Mapping[str, HydraConfig[PolarsDataType] | None]] validate_crcs: bool = False @@ -67,12 +64,12 @@ class SpecialField(StrEnum): idx = "_idx_" -class McapTableReader(TableReaderBase, Hashable): +class McapTableReader(TableReader, Hashable): def __init__(self, **kwargs: object) -> None: - self._config = Config.model_validate(kwargs) + self._config: Config = Config.model_validate(kwargs) @override - def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: + def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: with ( bound_contextvars(path=str(path)), Path(path).open("rb") as _f, @@ -128,7 +125,7 @@ def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: row_df = pl.DataFrame( [getattr(dmt.message, field) for field in special_fields], - schema=special_fields, # pyright: ignore[reportArgumentType] + schema=special_fields, ) if ( @@ -150,7 +147,7 @@ def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: dfs[topic] = df.rechunk() - return dfs + return dfs # pyright: ignore[reportReturnType] @override def __hash__(self) -> int: diff --git a/src/rbyte/io/frame/mcap/reader.py b/src/rbyte/io/_mcap/tensor_source.py similarity index 82% rename from src/rbyte/io/frame/mcap/reader.py rename to src/rbyte/io/_mcap/tensor_source.py index 57e51f6..f815fb8 100644 --- a/src/rbyte/io/frame/mcap/reader.py +++ b/src/rbyte/io/_mcap/tensor_source.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from functools import cached_property from mmap import ACCESS_READ, mmap -from typing import IO, override +from typing import IO, Any, override import more_itertools as mit import numpy.typing as npt @@ 
-12,14 +12,15 @@ from mcap.decoder import DecoderFactory from mcap.opcode import Opcode from mcap.reader import SeekingReader -from mcap.records import Chunk, ChunkIndex, Message +from mcap.records import Channel, Chunk, ChunkIndex, Message from mcap.stream_reader import get_chunk_data_stream -from pydantic import ConfigDict, FilePath, ImportString, validate_call +from pydantic import FilePath, ImportString, validate_call from structlog import get_logger from structlog.contextvars import bound_contextvars from torch import Tensor -from rbyte.io.frame.base import FrameReader +from rbyte.config.base import BaseModel +from rbyte.io.base import TensorSource logger = get_logger(__name__) @@ -31,14 +32,14 @@ class MessageIndex: message_length: int -class McapFrameReader(FrameReader): - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) +class McapTensorSource(TensorSource): + @validate_call(config=BaseModel.model_config) def __init__( self, path: FilePath, topic: str, decoder_factory: ImportString[type[DecoderFactory]], - frame_decoder: Callable[[bytes], npt.ArrayLike], + decoder: Callable[[bytes], npt.ArrayLike], validate_crcs: bool = False, # noqa: FBT001, FBT002 ) -> None: super().__init__() @@ -46,8 +47,8 @@ def __init__( with bound_contextvars( path=path.as_posix(), topic=topic, message_decoder_factory=decoder_factory ): - self._path = path - self._validate_crcs = validate_crcs + self._path: FilePath = path + self._validate_crcs: bool = validate_crcs summary = SeekingReader( stream=self._file, validate_crcs=self._validate_crcs @@ -57,7 +58,7 @@ def __init__( logger.error(msg := "missing summary") raise ValueError(msg) - self._channel = mit.one( + self._channel: Channel = mit.one( channel for channel in summary.channels.values() if channel.topic == topic @@ -72,13 +73,13 @@ def __init__( logger.error(msg := "missing message decoder") raise RuntimeError(msg) - self._message_decoder = message_decoder - self._chunk_indexes = tuple( + self._message_decoder: Callable[[bytes], Any] = message_decoder + self._chunk_indexes: tuple[ChunkIndex, ...] 
= tuple( chunk_index for chunk_index in summary.chunk_indexes if self._channel.id in chunk_index.message_index_offsets ) - self._frame_decoder = frame_decoder + self._decoder: Callable[[bytes], npt.ArrayLike] = decoder @property def _file(self) -> IO[bytes]: @@ -88,7 +89,9 @@ def _file(self) -> IO[bytes]: case None | mmap(closed=True): with self._path.open("rb") as f: - self._mmap = mmap(fileno=f.fileno(), length=0, access=ACCESS_READ) + self._mmap: mmap = mmap( + fileno=f.fileno(), length=0, access=ACCESS_READ + ) case _: raise RuntimeError @@ -96,7 +99,7 @@ def _file(self) -> IO[bytes]: return self._mmap # pyright: ignore[reportReturnType] @override - def read(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]: + def __getitem__(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]: frames: Mapping[int, npt.ArrayLike] = {} message_indexes_by_chunk_start_offset: Mapping[ @@ -119,13 +122,13 @@ def read(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]: stream.read(message_index.message_start_offset - stream.count) # pyright: ignore[reportUnusedCallResult] message = Message.read(stream, message_index.message_length) decoded_message = self._message_decoder(message.data) - frames[frame_index] = self._frame_decoder(decoded_message.data) + frames[frame_index] = self._decoder(decoded_message.data) return torch.stack([torch.from_numpy(frames[idx]) for idx in indexes]) # pyright: ignore[reportUnknownMemberType] @override - def get_available_indexes(self) -> Sequence[int]: - return range(len(self._message_indexes)) + def __len__(self) -> int: + return len(self._message_indexes) @cached_property def _message_indexes(self) -> Sequence[MessageIndex]: diff --git a/src/rbyte/io/_numpy/__init__.py b/src/rbyte/io/_numpy/__init__.py new file mode 100644 index 0000000..b5b27f0 --- /dev/null +++ b/src/rbyte/io/_numpy/__init__.py @@ -0,0 +1,3 @@ +from .tensor_source import NumpyTensorSource + +__all__ = ["NumpyTensorSource"] diff --git a/src/rbyte/io/_numpy/tensor_source.py b/src/rbyte/io/_numpy/tensor_source.py new file mode 100644 index 0000000..a952996 --- /dev/null +++ b/src/rbyte/io/_numpy/tensor_source.py @@ -0,0 +1,50 @@ +from collections.abc import Iterable, Sequence +from functools import cached_property +from os import PathLike +from pathlib import Path +from typing import TYPE_CHECKING, Any, override + +import numpy as np +import torch +from numpy.lib.recfunctions import ( + structured_to_unstructured, # pyright: ignore[reportUnknownVariableType] +) +from pydantic import validate_call +from torch import Tensor + +from rbyte.config.base import BaseModel +from rbyte.io.base import TensorSource +from rbyte.utils.functional import pad_sequence + +if TYPE_CHECKING: + from types import EllipsisType + + +class NumpyTensorSource(TensorSource): + @validate_call(config=BaseModel.model_config) + def __init__( + self, path: PathLike[str], select: Sequence[str] | None = None + ) -> None: + super().__init__() + + self._path: Path = Path(path) + self._select: Sequence[str] | EllipsisType = select or ... 
+ + @cached_property + def _path_posix(self) -> str: + return self._path.resolve().as_posix() + + @override + def __getitem__(self, indexes: Iterable[Any]) -> Tensor: + tensors: list[Tensor] = [] + for index in indexes: + path = self._path_posix.format(index) + array = structured_to_unstructured(np.load(path)[self._select]) # pyright: ignore[reportUnknownVariableType] + tensor = torch.from_numpy(np.ascontiguousarray(array)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] + tensors.append(tensor) + + return pad_sequence(list(tensors), dim=0, value=torch.nan) + + @override + def __len__(self) -> int: + raise NotImplementedError diff --git a/src/rbyte/io/base.py b/src/rbyte/io/base.py new file mode 100644 index 0000000..a67a9d6 --- /dev/null +++ b/src/rbyte/io/base.py @@ -0,0 +1,10 @@ +from collections.abc import Iterable +from typing import Any, Protocol, runtime_checkable + +from torch import Tensor + + +@runtime_checkable +class TensorSource(Protocol): + def __getitem__(self, indexes: Iterable[Any]) -> Tensor: ... + def __len__(self) -> int: ... diff --git a/src/rbyte/io/frame/__init__.py b/src/rbyte/io/frame/__init__.py deleted file mode 100644 index 1be0a47..0000000 --- a/src/rbyte/io/frame/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -from .directory import DirectoryFrameReader - -__all__ = ["DirectoryFrameReader"] - - -try: - from .mcap import McapFrameReader -except ImportError: - pass -else: - __all__ += ["McapFrameReader"] - -try: - from .hdf5 import Hdf5FrameReader -except ImportError: - pass -else: - __all__ += ["Hdf5FrameReader"] - -try: - from .rrd import RrdFrameReader -except ImportError: - pass -else: - __all__ += ["RrdFrameReader"] - -try: - from .video.ffmpeg_reader import FfmpegFrameReader -except ImportError: - pass -else: - __all__ += ["FfmpegFrameReader"] diff --git a/src/rbyte/io/frame/base.py b/src/rbyte/io/frame/base.py deleted file mode 100644 index 2b2d1e3..0000000 --- a/src/rbyte/io/frame/base.py +++ /dev/null @@ -1,13 +0,0 @@ -from collections.abc import Iterable, Sequence -from typing import Protocol, runtime_checkable - -from jaxtyping import Shaped -from torch import Tensor - - -@runtime_checkable -class FrameReader(Protocol): - def read( - self, indexes: Iterable[int] - ) -> Shaped[Tensor, "b h w c"] | Shaped[Tensor, "b c h w"]: ... - def get_available_indexes(self) -> Sequence[int]: ... 
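The `FrameReader` protocol deleted here (`read`/`get_available_indexes`) is superseded by the two-method `TensorSource` protocol added in `src/rbyte/io/base.py` above: a source is anything indexable by a batch of indexes that returns a stacked tensor, plus `__len__`. A minimal sketch of a conforming source — `ConstantTensorSource` and its shapes are illustrative, not part of the patch:

from collections.abc import Iterable
from typing import Any, override

import torch
from torch import Tensor

from rbyte.io.base import TensorSource


class ConstantTensorSource(TensorSource):
    """Hypothetical source: returns the same frame for every index."""

    def __init__(self, frame: Tensor, length: int) -> None:
        super().__init__()
        self._frame = frame
        self._length = length

    @override
    def __getitem__(self, indexes: Iterable[Any]) -> Tensor:
        # one copy of the frame per requested index -> (b, h, w, c)
        return torch.stack([self._frame for _ in indexes])

    @override
    def __len__(self) -> int:
        return self._length


source = ConstantTensorSource(torch.zeros(4, 6, 3, dtype=torch.uint8), length=10)
assert isinstance(source, TensorSource)  # runtime_checkable protocol
batch = source[[0, 3, 7]]  # shape (3, 4, 6, 3)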
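`NumpyTensorSource` above follows the same protocol for per-index `.npy` files: the path is a `str.format` template filled with each index, `select` picks structured-array fields, and ragged results are NaN-padded (assuming `pad_sequence` from `rbyte.utils.functional`, not shown in this section, pads the ragged dimension before stacking). A sketch with made-up files and field names:

from pathlib import Path

import numpy as np

from rbyte.io import NumpyTensorSource

# hypothetical data: two structured point clouds with different point counts
for name, n in [("a", 5), ("b", 3)]:
    arr = np.zeros(n, dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
    np.save(f"/tmp/{name}.npy", arr)

source = NumpyTensorSource(path=Path("/tmp/{}.npy"), select=["x", "y", "z"])
points = source[["a", "b"]]  # (2, 5, 3); the shorter cloud is NaN-padded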
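Downstream of the `dataset.py` changes above, `Batch` is now flat: a single `data` `TensorDict` carries both source tensors and table columns, alongside `meta`. A sketch of consuming it, assuming a `Dataset` configured with a `camera` source and a `speed` table column (both names hypothetical):

batch = dataset.__getitems__([0, 1, 2])  # as invoked by torch's DataLoader

frames = batch.data["camera"]       # stacked source tensors, NaN-padded when ragged
speed = batch.data["speed"]         # table column, cast to its physical dtype
sample_idx = batch.meta.sample_idx  # tensor of per-input sample indexes
input_id = batch.meta.input_id      # list of originating input ids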
diff --git a/src/rbyte/io/frame/directory/__init__.py b/src/rbyte/io/frame/directory/__init__.py deleted file mode 100644 index 08a85f4..0000000 --- a/src/rbyte/io/frame/directory/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .reader import DirectoryFrameReader - -__all__ = ["DirectoryFrameReader"] diff --git a/src/rbyte/io/frame/directory/reader.py b/src/rbyte/io/frame/directory/reader.py deleted file mode 100644 index f160383..0000000 --- a/src/rbyte/io/frame/directory/reader.py +++ /dev/null @@ -1,47 +0,0 @@ -from collections.abc import Callable, Iterable, Sequence -from functools import cached_property -from os import PathLike -from pathlib import Path -from typing import override - -import numpy.typing as npt -import parse -import torch -from jaxtyping import UInt8 -from pydantic import validate_call -from torch import Tensor - -from rbyte.io.frame.base import FrameReader - - -class DirectoryFrameReader(FrameReader): - @validate_call - def __init__( - self, path: PathLike[str], frame_decoder: Callable[[bytes], npt.ArrayLike] - ) -> None: - super().__init__() - - self._path = Path(path) - self._frame_decoder = frame_decoder - - @cached_property - def _path_posix(self) -> str: - return self._path.resolve().as_posix() - - def _decode(self, path: str) -> npt.ArrayLike: - with Path(path).open("rb") as f: - return self._frame_decoder(f.read()) - - @override - def read(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]: - paths = map(self._path_posix.format, indexes) - frames_np = map(self._decode, paths) - frames_tch = map(torch.from_numpy, frames_np) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] - - return torch.stack(list(frames_tch)) - - @override - def get_available_indexes(self) -> Sequence[int]: - parser = parse.compile(self._path.name) # pyright: ignore[reportUnknownMemberType] - filenames = (path.name for path in self._path.parent.iterdir()) - return [res[0] for res in map(parser.parse, filenames) if res] # pyright: ignore[reportUnknownVariableType, reportIndexIssue, reportUnknownArgumentType, reportUnknownMemberType] diff --git a/src/rbyte/io/frame/hdf5/__init__.py b/src/rbyte/io/frame/hdf5/__init__.py deleted file mode 100644 index 125cf3d..0000000 --- a/src/rbyte/io/frame/hdf5/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .reader import Hdf5FrameReader - -__all__ = ["Hdf5FrameReader"] diff --git a/src/rbyte/io/frame/hdf5/reader.py b/src/rbyte/io/frame/hdf5/reader.py deleted file mode 100644 index 3613642..0000000 --- a/src/rbyte/io/frame/hdf5/reader.py +++ /dev/null @@ -1,25 +0,0 @@ -from collections.abc import Iterable, Sequence -from typing import cast, override - -import h5py -import torch -from jaxtyping import UInt8 -from pydantic import FilePath, validate_call -from torch import Tensor - -from rbyte.io.frame.base import FrameReader - - -class Hdf5FrameReader(FrameReader): - @validate_call - def __init__(self, path: FilePath, key: str) -> None: - file = h5py.File(path) - self._dataset = cast(h5py.Dataset, file[key]) - - @override - def read(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]: - return torch.from_numpy(self._dataset[indexes]) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] - - @override - def get_available_indexes(self) -> Sequence[int]: - return range(len(self._dataset)) diff --git a/src/rbyte/io/frame/mcap/__init__.py b/src/rbyte/io/frame/mcap/__init__.py deleted file mode 100644 index 1f57037..0000000 --- a/src/rbyte/io/frame/mcap/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .reader import 
McapFrameReader - -__all__ = ["McapFrameReader"] diff --git a/src/rbyte/io/frame/rrd/__init__.py b/src/rbyte/io/frame/rrd/__init__.py deleted file mode 100644 index c7265b0..0000000 --- a/src/rbyte/io/frame/rrd/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .reader import RrdFrameReader - -__all__ = ["RrdFrameReader"] diff --git a/src/rbyte/io/frame/rrd/reader.py b/src/rbyte/io/frame/rrd/reader.py deleted file mode 100644 index c547127..0000000 --- a/src/rbyte/io/frame/rrd/reader.py +++ /dev/null @@ -1,78 +0,0 @@ -from collections.abc import Callable, Iterable, Mapping, Sequence -from enum import StrEnum, unique -from typing import cast, override - -import numpy.typing as npt -import polars as pl -import rerun as rr -import torch -from jaxtyping import UInt8 -from pydantic import ConfigDict, FilePath, validate_call -from rerun.components import Blob, MediaType -from torch import Tensor - -from rbyte.io.frame.base import FrameReader - - -@unique -class Column(StrEnum): - blob = Blob.__name__ - media_type = MediaType.__name__ - - -class RrdFrameReader(FrameReader): - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) - def __init__( - self, - path: FilePath, - *, - index: str, - entity_path: str, - frame_decoders: Mapping[str, Callable[[bytes], npt.ArrayLike]], - ) -> None: - recording = rr.dataframe.load_recording(path) # pyright: ignore[reportUnknownMemberType] - view = recording.view(index=index, contents={entity_path: (Blob, MediaType)}) - reader = view.select( - columns=[ - index, - *(f"{entity_path}:{ct.__name__}" for ct in (Blob, MediaType)), - ] - ) - - # WARN: RecordBatchReader does not support random seeking => storing in memory - df = ( - cast( - pl.DataFrame, - pl.from_arrow(reader.read_all(), rechunk=True), # pyright: ignore[reportUnknownMemberType] - ) - .sort(index) - .drop(index) - .select( - pl.all().explode().name.map(lambda x: x.removeprefix(f"{entity_path}:")) - ) - ) - - self._df = df.cast({ - (col := Column.media_type): pl.Enum(df.select(col).unique().to_series()) - }) - - self._frame_decoders = frame_decoders - - @override - def read(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]: - df = self._df[list(indexes)] - - frames_np = ( - self._frame_decoders[media_type](blob.to_numpy(allow_copy=False)) - for media_type, blob in zip( - df[Column.media_type], df[Column.blob], strict=True - ) - ) - - frames_tch = map(torch.from_numpy, frames_np) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] - - return torch.stack(list(frames_tch)) - - @override - def get_available_indexes(self) -> Sequence[int]: - return range(len(self._df)) diff --git a/src/rbyte/io/frame/video/ffmpeg_reader.py b/src/rbyte/io/frame/video/ffmpeg_reader.py deleted file mode 100644 index 55109cb..0000000 --- a/src/rbyte/io/frame/video/ffmpeg_reader.py +++ /dev/null @@ -1,46 +0,0 @@ -from collections.abc import Callable, Iterable, Sequence -from functools import partial -from pathlib import Path -from typing import override - -import torch -import video_reader as vr -from jaxtyping import UInt8 -from pydantic import FilePath, NonNegativeInt, validate_call -from torch import Tensor - -from rbyte.io.frame.base import FrameReader - - -class FfmpegFrameReader(FrameReader): - @validate_call - def __init__( - self, - path: FilePath, - threads: NonNegativeInt | None = None, - resize_shorter_side: NonNegativeInt | None = None, - with_fallback: bool | None = None, # noqa: FBT001 - ) -> None: - super().__init__() - self._path = Path(path).resolve().as_posix() - - 
self._get_batch: Callable[[str, Iterable[int]], UInt8[Tensor, "b h w c"]] = ( - partial( - vr.get_batch, # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] - threads=threads, - resize_shorter_side=resize_shorter_side, - with_fallback=with_fallback, - ) - ) - - @override - def read(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]: - batch = self._get_batch(self._path, indexes) - - return torch.from_numpy(batch) # pyright: ignore[reportUnknownMemberType] - - @override - def get_available_indexes(self) -> Sequence[int]: - num_frames = int(vr.get_info(self._path)["frame_count"]) # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportUnknownMemberType] - - return range(num_frames) diff --git a/src/rbyte/io/hdf5/__init__.py b/src/rbyte/io/hdf5/__init__.py new file mode 100644 index 0000000..aa57ae7 --- /dev/null +++ b/src/rbyte/io/hdf5/__init__.py @@ -0,0 +1,4 @@ +from .table_reader import Hdf5TableReader +from .tensor_source import Hdf5TensorSource + +__all__ = ["Hdf5TableReader", "Hdf5TensorSource"] diff --git a/src/rbyte/io/hdf5/table_reader.py b/src/rbyte/io/hdf5/table_reader.py new file mode 100644 index 0000000..f50da8e --- /dev/null +++ b/src/rbyte/io/hdf5/table_reader.py @@ -0,0 +1,97 @@ +import json +from collections.abc import Hashable, Mapping, Sequence +from enum import StrEnum, unique +from functools import cached_property +from os import PathLike +from typing import cast, override + +import numpy.typing as npt +import polars as pl +from h5py import Dataset, File +from optree import PyTree, tree_map, tree_map_with_path +from polars._typing import PolarsDataType # noqa: PLC2701 +from polars.datatypes import ( + DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 + DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 +) +from xxhash import xxh3_64_intdigest as digest + +from rbyte.config import BaseModel +from rbyte.config.base import HydraConfig +from rbyte.io.table.base import TableReader + +type Fields = Mapping[str, HydraConfig[PolarsDataType] | None] | Mapping[str, "Fields"] + + +class Config(BaseModel): + fields: Fields + + +Config.model_rebuild() # pyright: ignore[reportUnusedCallResult] + + +@unique +class SpecialField(StrEnum): + idx = "_idx_" + + +class Hdf5TableReader(TableReader, Hashable): + def __init__(self, **kwargs: object) -> None: + self._config: Config = Config.model_validate(kwargs) + + @override + def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: + with File(path) as f: + + def build_series( + path: Sequence[str], dtype: PolarsDataType | None + ) -> pl.Series | None: + key = "/".join(path) + match obj := f.get(key): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + case Dataset(): + values = cast(npt.ArrayLike, obj[:]) + return pl.Series(values=values, dtype=dtype) + + case None: + return None + + case _: # pyright: ignore[reportUnknownVariableType] + raise NotImplementedError + + series = tree_map_with_path(build_series, self._fields, none_is_leaf=True) + + dfs = tree_map( + pl.DataFrame, + series, + is_leaf=lambda obj: isinstance(obj, dict) + and all(isinstance(v, pl.Series) or v is None for v in obj.values()), # pyright: ignore[reportUnknownVariableType] + ) + + def maybe_add_index( + df: pl.DataFrame, schema: Mapping[str, PolarsDataType | None] + ) -> pl.DataFrame: + match schema: + case {SpecialField.idx: dtype}: + return df.select( + pl.int_range(pl.len(), dtype=dtype or pl.UInt32).alias( # pyright: ignore[reportArgumentType] + SpecialField.idx + ), + 
pl.exclude(SpecialField.idx), + ) + + case _: + return df + + return tree_map(maybe_add_index, dfs, self._fields) + + @override + def __hash__(self) -> int: + config = self._config.model_dump_json() + # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 + config_str = json.dumps(json.loads(config), sort_keys=True) + + return digest(config_str) + + @cached_property + def _fields(self) -> PyTree[PolarsDataType | None]: + return tree_map(HydraConfig.instantiate, self._config.fields) # pyright: ignore[reportArgumentType, reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType] diff --git a/src/rbyte/io/hdf5/tensor_source.py b/src/rbyte/io/hdf5/tensor_source.py new file mode 100644 index 0000000..592bb1c --- /dev/null +++ b/src/rbyte/io/hdf5/tensor_source.py @@ -0,0 +1,25 @@ +from collections.abc import Iterable +from typing import cast, override + +import torch +from h5py import Dataset, File +from jaxtyping import UInt8 +from pydantic import FilePath, validate_call +from torch import Tensor + +from rbyte.io.base import TensorSource + + +class Hdf5TensorSource(TensorSource): + @validate_call + def __init__(self, path: FilePath, key: str) -> None: + file = File(path) + self._dataset: Dataset = cast(Dataset, file[key]) + + @override + def __getitem__(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]: + return torch.from_numpy(self._dataset[indexes]) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] + + @override + def __len__(self) -> int: + return len(self._dataset) diff --git a/src/rbyte/io/path/__init__.py b/src/rbyte/io/path/__init__.py new file mode 100644 index 0000000..c316858 --- /dev/null +++ b/src/rbyte/io/path/__init__.py @@ -0,0 +1,4 @@ +from .table_reader import PathTableReader +from .tensor_source import PathTensorSource + +__all__ = ["PathTableReader", "PathTensorSource"] diff --git a/src/rbyte/io/path/table_reader.py b/src/rbyte/io/path/table_reader.py new file mode 100644 index 0000000..b2cdb36 --- /dev/null +++ b/src/rbyte/io/path/table_reader.py @@ -0,0 +1,75 @@ +import os +from collections.abc import Mapping +from enum import StrEnum, unique +from functools import cached_property +from os import PathLike +from pathlib import Path +from typing import override + +import parse +import polars as pl +from optree import PyTree, tree_map +from polars._typing import PolarsDataType +from polars.datatypes import ( + DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 + DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 +) +from structlog import get_logger + +from rbyte.config.base import BaseModel, HydraConfig +from rbyte.io.table.base import TableReader + +logger = get_logger(__name__) + + +class Config(BaseModel): + fields: Mapping[str, HydraConfig[PolarsDataType] | None] = {} + + +@unique +class SpecialField(StrEnum): + idx = "_idx_" + + +class PathTableReader(TableReader): + def __init__(self, **kwargs: object) -> None: + self._config: Config = Config.model_validate(kwargs) + + @override + def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: + parser = parse.compile(Path(path).resolve().as_posix()) # pyright: ignore[reportUnknownMemberType] + match parser.named_fields, parser.fixed_fields: # pyright: ignore[reportUnknownMemberType] + case ([_, *_], []): # pyright: ignore[reportUnknownVariableType] + pass + + case (named_fields, fixed_fields): # pyright: ignore[reportUnknownVariableType] + logger.error( + msg := "parser not supported", + named_fields=named_fields, + 
fixed_fields=fixed_fields, + ) + raise RuntimeError(msg) + + parent = Path(os.path.commonpath([path, parser._expression])) # pyright: ignore[reportPrivateUsage] # noqa: SLF001 + results = (parser.parse(p.as_posix()) for p in parent.rglob("*") if p.is_file()) # pyright: ignore[reportUnknownMemberType] + + df = pl.DataFrame( + result.named # pyright: ignore[reportUnknownMemberType] + for result in results + if isinstance(result, parse.Result) + ) + + if (idx_name := SpecialField.idx) in self._fields: + df = df.with_row_index(idx_name).cast({ + idx_name: self._fields[idx_name] or pl.UInt32 + }) + + df_schema = { + name: dtype for name, dtype in self._fields.items() if dtype is not None + } + + return df.cast(df_schema) # pyright: ignore[reportArgumentType, reportReturnType] + + @cached_property + def _fields(self) -> Mapping[str, PolarsDataType | None]: + return tree_map(HydraConfig.instantiate, self._config.fields) # pyright: ignore[reportArgumentType, reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType, reportReturnType] diff --git a/src/rbyte/io/path/tensor_source.py b/src/rbyte/io/path/tensor_source.py new file mode 100644 index 0000000..8f74a19 --- /dev/null +++ b/src/rbyte/io/path/tensor_source.py @@ -0,0 +1,44 @@ +from collections.abc import Callable, Iterable +from functools import cached_property +from os import PathLike +from pathlib import Path +from typing import Any, override + +import numpy.typing as npt +import torch +from jaxtyping import UInt8 +from pydantic import validate_call +from torch import Tensor + +from rbyte.io.base import TensorSource + + +class PathTensorSource(TensorSource): + @validate_call + def __init__( + self, path: PathLike[str], decoder: Callable[[bytes], npt.ArrayLike] + ) -> None: + super().__init__() + + self._path: Path = Path(path) + self._decoder: Callable[[bytes], npt.ArrayLike] = decoder + + @cached_property + def _path_posix(self) -> str: + return self._path.resolve().as_posix() + + def _decode(self, path: str) -> npt.ArrayLike: + with Path(path).open("rb") as f: + return self._decoder(f.read()) + + @override + def __getitem__(self, indexes: Iterable[Any]) -> UInt8[Tensor, "b h w c"]: + paths = map(self._path_posix.format, indexes) + arrays = map(self._decode, paths) + tensors = map(torch.from_numpy, arrays) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] + + return torch.stack(list(tensors)) + + @override + def __len__(self) -> int: + raise NotImplementedError diff --git a/src/rbyte/io/rrd/__init__.py b/src/rbyte/io/rrd/__init__.py new file mode 100644 index 0000000..bacca6e --- /dev/null +++ b/src/rbyte/io/rrd/__init__.py @@ -0,0 +1,4 @@ +from .frame_source import RrdFrameSource +from .table_reader import RrdTableReader + +__all__ = ["RrdFrameSource", "RrdTableReader"] diff --git a/src/rbyte/io/rrd/frame_source.py b/src/rbyte/io/rrd/frame_source.py new file mode 100644 index 0000000..1857fb3 --- /dev/null +++ b/src/rbyte/io/rrd/frame_source.py @@ -0,0 +1,55 @@ +from collections.abc import Callable, Iterable +from typing import cast, override + +import numpy.typing as npt +import polars as pl +import rerun as rr +import torch +from jaxtyping import UInt8 +from pydantic import FilePath, validate_call +from rerun.components import Blob +from torch import Tensor + +from rbyte.config.base import BaseModel +from rbyte.io.base import TensorSource + + +class RrdFrameSource(TensorSource): + @validate_call(config=BaseModel.model_config) + def __init__( + self, + path: FilePath, + *, + index: str, + 
entity_path: str, + decoder: Callable[[bytes], npt.ArrayLike], + ) -> None: + recording = rr.dataframe.load_recording(path) # pyright: ignore[reportUnknownMemberType] + view = recording.view(index=index, contents={entity_path: [Blob]}) + reader = view.select(columns=[index, f"{entity_path}:{Blob.__name__}"]) + + # WARN: RecordBatchReader does not support random seeking => storing in memory + self._series: pl.Series = ( + cast( + pl.DataFrame, + pl.from_arrow(reader.read_all(), rechunk=True), # pyright: ignore[reportUnknownMemberType] + ) + .sort(index) + .drop(index) + .select(pl.all().explode()) + .to_series(0) + ) + + self._decoder: Callable[[bytes], npt.ArrayLike] = decoder + + @override + def __getitem__(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]: + arrays = (self._series[i].to_numpy(allow_copy=False) for i in indexes) + frames_np = map(self._decoder, arrays) + frames_tch = map(torch.from_numpy, frames_np) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] + + return torch.stack(list(frames_tch)) + + @override + def __len__(self) -> int: + return len(self._series) diff --git a/src/rbyte/io/table/rrd/reader.py b/src/rbyte/io/rrd/table_reader.py similarity index 90% rename from src/rbyte/io/table/rrd/reader.py rename to src/rbyte/io/rrd/table_reader.py index 3876a3e..97b6c8a 100644 --- a/src/rbyte/io/table/rrd/reader.py +++ b/src/rbyte/io/rrd/table_reader.py @@ -7,16 +7,14 @@ import more_itertools as mit import polars as pl import rerun.dataframe as rrd -from pydantic import ConfigDict +from optree import PyTree from xxhash import xxh3_64_intdigest as digest from rbyte.config.base import BaseModel -from rbyte.io.table.base import TableReaderBase +from rbyte.io.table.base import TableReader class Config(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - index: str contents: Mapping[str, Sequence[str]] @@ -28,12 +26,12 @@ class Column(StrEnum): idx = "_idx_" -class RrdTableReader(TableReaderBase, Hashable): +class RrdTableReader(TableReader, Hashable): def __init__(self, **kwargs: object) -> None: - self._config = Config.model_validate(kwargs) + self._config: Config = Config.model_validate(kwargs) @override - def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: + def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: recording = rrd.load_recording(path) # pyright: ignore[reportUnknownMemberType] schema = recording.schema() @@ -97,7 +95,7 @@ def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: dfs[entity] = entity_df - return dfs + return dfs # pyright: ignore[reportReturnType] @override def __hash__(self) -> int: diff --git a/src/rbyte/io/table/__init__.py b/src/rbyte/io/table/__init__.py index 45c7207..20c3101 100644 --- a/src/rbyte/io/table/__init__.py +++ b/src/rbyte/io/table/__init__.py @@ -1,34 +1,6 @@ from .aligner import TableAligner from .builder import TableBuilder from .concater import TableConcater -from .json import JsonTableReader +from .transforms import FpsResampler -__all__ = ["JsonTableReader", "TableAligner", "TableBuilder", "TableConcater"] - -try: - from .hdf5 import Hdf5TableReader -except ImportError: - pass -else: - __all__ += ["Hdf5TableReader"] - -try: - from .mcap import McapTableReader -except ImportError: - pass -else: - __all__ += ["McapTableReader"] - -try: - from .yaak import YaakMetadataTableReader -except ImportError: - pass -else: - __all__ += ["YaakMetadataTableReader"] - -try: - from .rrd import RrdTableReader -except ImportError: - pass -else: - __all__ += 
["RrdTableReader"] +__all__ = ["FpsResampler", "TableAligner", "TableBuilder", "TableConcater"] diff --git a/src/rbyte/io/table/aligner.py b/src/rbyte/io/table/aligner.py index b51a4a7..0a880c7 100644 --- a/src/rbyte/io/table/aligner.py +++ b/src/rbyte/io/table/aligner.py @@ -1,26 +1,34 @@ import json from collections import OrderedDict -from collections.abc import Hashable, Mapping, Sequence +from collections.abc import Hashable from datetime import timedelta from functools import cached_property -from operator import itemgetter -from typing import Annotated, Literal, Self, override +from typing import Annotated, Literal, override +from uuid import uuid4 -import more_itertools as mit import polars as pl +from optree import ( + PyTree, + PyTreeAccessor, + tree_accessors, + tree_map_with_accessor, + tree_map_with_path, +) from polars._typing import AsofJoinStrategy -from pydantic import StringConstraints, model_validator +from pydantic import Field, StringConstraints from structlog import get_logger +from structlog.contextvars import bound_contextvars from xxhash import xxh3_64_intdigest as digest from rbyte.config.base import BaseModel -from rbyte.io.table.base import TableMergerBase + +from .base import TableMerger logger = get_logger(__name__) -class RefColumnMergeConfig(BaseModel): - method: Literal["ref"] = "ref" +class InterpColumnMergeConfig(BaseModel): + method: Literal["interp"] = "interp" class AsofColumnMergeConfig(BaseModel): @@ -29,129 +37,113 @@ class AsofColumnMergeConfig(BaseModel): tolerance: str | int | float | timedelta | None = None -class InterpColumnMergeConfig(BaseModel): - method: Literal["interp"] = "interp" - +ColumnMergeConfig = InterpColumnMergeConfig | AsofColumnMergeConfig -MergeConfig = RefColumnMergeConfig | AsofColumnMergeConfig | InterpColumnMergeConfig +class TableMergeConfig(BaseModel): + key: str + columns: OrderedDict[str, ColumnMergeConfig] = Field(default_factory=OrderedDict) -class Config(BaseModel): - merge: OrderedDict[str, Mapping[str, MergeConfig]] - separator: Annotated[str, StringConstraints(strip_whitespace=True)] = "/" - @model_validator(mode="after") - def validate_refs(self) -> Self: - ref_config = RefColumnMergeConfig() - for k, v in self.columns_by_merge_config.items(): - match v.get(ref_config, None): - case [_column]: - pass +type MergeConfig = TableMergeConfig | OrderedDict[str, "MergeConfig"] - case _: - msg = f"merge `{k}` must have exactly one column with {ref_config}" - raise ValueError(msg) - return self +class Config(BaseModel): + merge: MergeConfig + separator: Annotated[str, StringConstraints(strip_whitespace=True)] = "/" @cached_property - def columns_by_merge_config( - self, - ) -> Mapping[str, Mapping[MergeConfig, Sequence[str]]]: - return { - k: mit.map_reduce(v.items(), keyfunc=itemgetter(1), valuefunc=itemgetter(0)) - for k, v in self.merge.items() - } + def merge_fqn(self) -> PyTree[TableMergeConfig]: + # fully qualified key/column names + def fqn(path: tuple[str, ...], cfg: TableMergeConfig) -> TableMergeConfig: + key = self.separator.join((*path, cfg.key)) + columns = OrderedDict({ + self.separator.join((*path, k)): v for k, v in cfg.columns.items() + }) - @cached_property - def ref_columns(self) -> Mapping[str, str]: - return { - k: mit.one(v[RefColumnMergeConfig()]) - for k, v in self.columns_by_merge_config.items() - } + return TableMergeConfig(key=key, columns=columns) + return tree_map_with_path(fqn, self.merge) # pyright: ignore[reportArgumentType] -class TableAligner(TableMergerBase, Hashable): - def 
__init__(self, **kwargs: object) -> None: - self._config = Config.model_validate(kwargs) - def _col_name(self, *args: str) -> str: - return self._config.separator.join(args) +class TableAligner(TableMerger, Hashable): + def __init__(self, **kwargs: object) -> None: + self._config: Config = Config.model_validate(kwargs) @override - def merge(self, src: Mapping[str, pl.DataFrame]) -> pl.DataFrame: - if unused_keys := src.keys() - self._config.merge.keys(): - logger.warning("unused", keys=sorted(unused_keys)) - - dfs = { - k: src[k] - .sort(self._config.ref_columns[k]) - .rename(lambda col, k=k: self._col_name(k, col)) - for k in self._config.merge - } - k_df_ref = mit.first(self._config.merge.keys()) - df_ref = dfs.pop(k_df_ref) - df_ref_col_ref = self._col_name(k_df_ref, self._config.ref_columns[k_df_ref]) - - logger.debug( - "merging", merge_ref=f"{k_df_ref}[{self._config.ref_columns[k_df_ref]}]" - ) - - for k_merge, df_merge in dfs.items(): - cols_by_merge_config = self._config.columns_by_merge_config[k_merge] - df_merge_col_ref = self._col_name( - k_merge, self._config.ref_columns[k_merge] + def merge(self, src: PyTree[pl.DataFrame]) -> pl.DataFrame: + merge_configs = self._config.merge_fqn + + def get_df(accessor: PyTreeAccessor, cfg: TableMergeConfig) -> pl.DataFrame: + return ( + accessor(src) + .rename(lambda col: self._config.separator.join((*accessor.path, col))) # pyright: ignore[reportUnknownLambdaType, reportUnknownArgumentType] + .sort(cfg.key) ) - for merge_cfg, _df_merge_cols in cols_by_merge_config.items(): - if isinstance(merge_cfg, RefColumnMergeConfig): - continue - - df_merge_cols = tuple( - self._col_name(k_merge, col) for col in _df_merge_cols - ) - - df_ref_height_pre = df_ref.height - match merge_cfg: - case AsofColumnMergeConfig(strategy=strategy, tolerance=tolerance): - df_ref = df_ref.join_asof( - other=df_merge.select(df_merge_col_ref, *df_merge_cols), - left_on=df_ref_col_ref, - right_on=df_merge_col_ref, - strategy=strategy, - tolerance=tolerance, - ).drop_nulls(df_merge_cols) - - case InterpColumnMergeConfig(): - df_ref = ( - # take a union of timestamps - df_ref.join( - df_merge.select(df_merge_col_ref, *df_merge_cols), - how="full", - left_on=df_ref_col_ref, - right_on=df_merge_col_ref, - coalesce=True, - ) - # interpolate - .with_columns( - pl.col(df_merge_cols).interpolate_by(df_ref_col_ref) + dfs = tree_map_with_accessor(get_df, merge_configs) + accessor, *accessors_rest = tree_accessors(merge_configs) + df: pl.DataFrame = accessor(dfs) + left_on = accessor(merge_configs).key + + for accessor in accessors_rest: + other: pl.DataFrame = accessor(dfs) + merge_config: TableMergeConfig = accessor(merge_configs) + key = merge_config.key + + for column, config in merge_config.columns.items(): + df_height_pre = df.height + + with bound_contextvars(key=key, column=column, config=config): + match config: + case AsofColumnMergeConfig( + strategy=strategy, tolerance=tolerance + ): + right_on = key if key == column else uuid4().hex + + df = ( + df.join_asof( + other=other.select({key, column}).rename({ + key: right_on + }), + left_on=left_on, + right_on=right_on, + strategy=strategy, + tolerance=tolerance, + ) + .drop_nulls(column) + .drop({right_on} - {key}) ) - # narrow back to original ref col - .join( - df_ref.select(df_ref_col_ref), - on=df_ref_col_ref, - how="semi", - ) - .sort(df_ref_col_ref) - ).drop_nulls(df_merge_cols) + + case InterpColumnMergeConfig(): + if key == column: + logger.error(msg := "cannot interpolate key") + + raise ValueError(msg) + + 
right_on = key + + df = ( + # take a union of timestamps + df.join( + other.select(right_on, column), + how="full", + left_on=left_on, + right_on=right_on, + coalesce=True, + ) + # interpolate + .with_columns(pl.col(column).interpolate_by(left_on)) + # narrow back to original ref col + .join(df.select(left_on), on=left_on, how="semi") + .sort(left_on) + ).drop_nulls(column) logger.debug( - "merged", - merge_rows=f"{df_ref_height_pre}->{df_ref.height}", - merge_other=f"{k_merge}[{', '.join(_df_merge_cols)}]", + "merged", column=column, height=f"{df_height_pre}->{df.height}" ) - return df_ref + return df @override def __hash__(self) -> int: diff --git a/src/rbyte/io/table/base.py b/src/rbyte/io/table/base.py index 0b1cfba..0ba3f5a 100644 --- a/src/rbyte/io/table/base.py +++ b/src/rbyte/io/table/base.py @@ -1,29 +1,28 @@ -from collections.abc import Hashable, Mapping +from collections.abc import Hashable from os import PathLike from typing import Protocol, runtime_checkable -import polars as pl - -Table = pl.DataFrame +from optree import PyTree +from polars import DataFrame @runtime_checkable -class TableBuilderBase(Protocol): - def build(self) -> Table: ... +class TableBuilder(Protocol): + def build(self) -> DataFrame: ... @runtime_checkable -class TableReaderBase(Hashable, Protocol): - def read(self, path: PathLike[str]) -> Mapping[str, Table]: ... +class TableReader(Protocol): + def read(self, path: PathLike[str]) -> PyTree[DataFrame]: ... @runtime_checkable -class TableMergerBase(Hashable, Protocol): - def merge(self, src: Mapping[str, Table]) -> Table: ... +class TableMerger(Protocol): + def merge(self, src: PyTree[DataFrame]) -> DataFrame: ... @runtime_checkable -class TableCacheBase(Protocol): +class TableCache(Protocol): def __contains__(self, key: Hashable) -> bool: ... - def get(self, key: Hashable) -> Table | None: ... - def set(self, key: Hashable, value: Table) -> bool: ... + def get(self, key: Hashable) -> DataFrame | None: ... + def set(self, key: Hashable, value: DataFrame) -> bool: ... 
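`TableAligner`'s config is now a pytree of `TableMergeConfig`s: each table names its `key` column, the first table in the tree is the reference, and every other column is merged onto it via `asof` or `interp` (the old `ref` method is gone — the key itself plays that role). A minimal sketch with illustrative names and data, using the default `/` separator:

from collections import OrderedDict

import polars as pl

from rbyte.io.table import TableAligner

src = {
    "camera": pl.DataFrame({"ts": [0, 10, 20], "frame_idx": [0, 1, 2]}),
    "speed": pl.DataFrame({"ts": [0, 5, 15, 25], "value": [0.0, 1.0, 2.0, 3.0]}),
}

aligner = TableAligner(
    merge=OrderedDict(
        camera={"key": "ts"},  # first entry: the reference table
        speed={
            "key": "ts",
            "columns": {"value": {"method": "interp"}},
        },
    ),
)

df = aligner.merge(src)
# columns: camera/ts, camera/frame_idx, speed/value (interpolated at camera/ts)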
diff --git a/src/rbyte/io/table/builder.py b/src/rbyte/io/table/builder.py index da72343..16e87a6 100644 --- a/src/rbyte/io/table/builder.py +++ b/src/rbyte/io/table/builder.py @@ -1,74 +1,67 @@ -import operator -from collections import Counter -from collections.abc import Hashable, Sequence -from functools import reduce +from collections.abc import Hashable, Mapping from mmap import ACCESS_READ, mmap +from os import PathLike from pathlib import Path from typing import Annotated, Any, override -import more_itertools as mit import polars as pl -from pydantic import ConfigDict, Field, FilePath, StringConstraints, validate_call +from optree import PyTree, tree_map +from pydantic import Field, StringConstraints, validate_call from structlog import get_logger from xxhash import xxh3_64_intdigest as digest -from rbyte.config.base import BaseModel -from rbyte.io.table.base import ( - TableBuilderBase, - TableCacheBase, - TableMergerBase, - TableReaderBase, -) +from rbyte.config import BaseModel + +from .base import TableBuilder as _TableBuilder +from .base import TableCache, TableMerger, TableReader logger = get_logger(__name__) class TableReaderConfig(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - path: FilePath - reader: TableReaderBase + path: PathLike[str] + reader: TableReader -class TableBuilder(TableBuilderBase): - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) +class TableBuilder(_TableBuilder): + @validate_call(config=BaseModel.model_config) def __init__( self, - readers: Annotated[Sequence[TableReaderConfig], Field(min_length=1)], - merger: TableMergerBase, + readers: Annotated[Mapping[str, TableReaderConfig], Field(min_length=1)], + merger: TableMerger, filter: Annotated[str, StringConstraints(strip_whitespace=True)] | None = None, # noqa: A002 - cache: TableCacheBase | None = None, + cache: TableCache | None = None, ) -> None: super().__init__() - self._readers = readers - self._merger = merger - self._filter = filter - self._cache = cache + self._readers: Mapping[str, TableReaderConfig] = readers + self._merger: TableMerger = merger + self._filter: str | None = filter + self._cache: TableCache | None = cache def _build_cache_key(self) -> Hashable: - from rbyte import __version__ as rbyte_version # noqa: PLC0415 + from rbyte import __version__ # noqa: PLC0415 - key: list[Any] = [rbyte_version, hash(self._merger)] + key: list[Any] = [__version__, hash(self._merger)] if self._filter is not None: key.append(digest(self._filter)) - for reader_config in self._readers: + for reader_name, reader_config in sorted(self._readers.items()): with ( Path(reader_config.path).open("rb") as _f, mmap(_f.fileno(), 0, access=ACCESS_READ) as f, ): file_hash = digest(f) # pyright: ignore[reportArgumentType] - key.append((file_hash, hash(reader_config.reader))) + key.append((file_hash, digest(reader_name), hash(reader_config.reader))) return tuple(key) @override def build(self) -> pl.DataFrame: match self._cache: - case TableCacheBase(): + case TableCache(): key = self._build_cache_key() if key in self._cache: logger.debug("reading table from cache") @@ -88,15 +81,10 @@ def build(self) -> pl.DataFrame: return self._build() def _build(self) -> pl.DataFrame: - reader_dfs = [cfg.reader.read(cfg.path) for cfg in self._readers] - if duplicate_keys := { - k for k, count in Counter(mit.flatten(reader_dfs)).items() if count > 1 - }: - logger.error(msg := "readers produced duplicate keys", keys=duplicate_keys) - - raise RuntimeError(msg) - - dfs = reduce(operator.or_, 
reader_dfs) + dfs: PyTree[pl.DataFrame] = tree_map( + lambda cfg: cfg.reader.read(cfg.path), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportUnknownLambdaType] + self._readers, # pyright: ignore[reportArgumentType] + ) df = self._merger.merge(dfs) return df.sql(f"select * from self where ({self._filter or True})") # noqa: S608 diff --git a/src/rbyte/io/table/concater.py b/src/rbyte/io/table/concater.py index c8bab57..9d3afb0 100644 --- a/src/rbyte/io/table/concater.py +++ b/src/rbyte/io/table/concater.py @@ -1,13 +1,15 @@ import json -from collections.abc import Hashable, Mapping +from collections.abc import Hashable from typing import override import polars as pl +from optree import PyTree, tree_leaves, tree_map_with_path from polars._typing import ConcatMethod from xxhash import xxh3_64_intdigest as digest from rbyte.config import BaseModel -from rbyte.io.table.base import TableMergerBase + +from .base import TableMerger class Config(BaseModel): @@ -15,19 +17,21 @@ class Config(BaseModel): method: ConcatMethod = "horizontal" -class TableConcater(TableMergerBase, Hashable): +class TableConcater(TableMerger, Hashable): def __init__(self, **kwargs: object) -> None: - self._config = Config.model_validate(kwargs) + self._config: Config = Config.model_validate(kwargs) @override - def merge(self, src: Mapping[str, pl.DataFrame]) -> pl.DataFrame: - if (separator := self._config.separator) is not None: - src = { - k: df.select(pl.all().name.prefix(f"{k}{separator}")) - for k, df in src.items() - } - - return pl.concat(src.values(), how=self._config.method, rechunk=True) + def merge(self, src: PyTree[pl.DataFrame]) -> pl.DataFrame: + if (sep := self._config.separator) is not None: + src = tree_map_with_path( + lambda path, df: df.rename( # pyright: ignore[reportUnknownArgumentType,reportUnknownLambdaType, reportUnknownMemberType] + lambda col: f"{sep.join([*path, col])}" # pyright: ignore[reportUnknownArgumentType,reportUnknownLambdaType] + ), + src, + ) + + return pl.concat(tree_leaves(src), how=self._config.method, rechunk=True) @override def __hash__(self) -> int: diff --git a/src/rbyte/io/table/hdf5/__init__.py b/src/rbyte/io/table/hdf5/__init__.py deleted file mode 100644 index ccff9e3..0000000 --- a/src/rbyte/io/table/hdf5/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .reader import Hdf5TableReader - -__all__ = ["Hdf5TableReader"] diff --git a/src/rbyte/io/table/hdf5/reader.py b/src/rbyte/io/table/hdf5/reader.py deleted file mode 100644 index 7772634..0000000 --- a/src/rbyte/io/table/hdf5/reader.py +++ /dev/null @@ -1,92 +0,0 @@ -import json -from collections.abc import Hashable, Mapping -from enum import StrEnum, unique -from functools import cached_property -from os import PathLike -from typing import Any, cast, override - -import numpy.typing as npt -import polars as pl -from h5py import Dataset, File, Group -from optree import tree_map -from polars._typing import PolarsDataType -from polars.datatypes import ( - DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 - DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 -) -from pydantic import ConfigDict -from xxhash import xxh3_64_intdigest as digest - -from rbyte.config import BaseModel -from rbyte.config.base import HydraConfig -from rbyte.io.table.base import TableReaderBase - - -class Config(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - fields: Mapping[str, Mapping[str, HydraConfig[PolarsDataType] | None]] - - -@unique -class SpecialField(StrEnum): - idx = 
"_idx_" - - -class Hdf5TableReader(TableReaderBase, Hashable): - def __init__(self, **kwargs: object) -> None: - self._config = Config.model_validate(kwargs) - - @override - def __hash__(self) -> int: - config = self._config.model_dump_json() - # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 - config_str = json.dumps(json.loads(config), sort_keys=True) - - return digest(config_str) - - @override - def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: - dfs: Mapping[str, pl.DataFrame] = {} - - with File(path) as f: - for group_key, schema in self._fields.items(): - match group := f[group_key]: - case Group(): - series: list[pl.Series] = [] - for name, dtype in schema.items(): - match name: - case SpecialField.idx: - pass - - case _: - match dataset := group[name]: - case Dataset(): - values = cast(npt.NDArray[Any], dataset[:]) - series.append( - pl.Series( - name=name, - values=values, - dtype=dtype, - ) - ) - - case _: - raise NotImplementedError - - df = pl.DataFrame(data=series) # pyright: ignore[reportGeneralTypeIssues] - if (idx_name := SpecialField.idx) in schema: - df = df.with_row_index(idx_name).cast({ - idx_name: schema[idx_name] or pl.UInt32 - }) - - dfs[group_key] = df.rechunk() - - case _: - raise NotImplementedError - - return dfs - - @cached_property - def _fields(self) -> Mapping[str, Mapping[str, PolarsDataType | None]]: - return tree_map(HydraConfig.instantiate, self._config.fields) # pyright: ignore[reportArgumentType, reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType, reportReturnType] diff --git a/src/rbyte/io/table/json/__init__.py b/src/rbyte/io/table/json/__init__.py deleted file mode 100644 index 8296e67..0000000 --- a/src/rbyte/io/table/json/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .reader import JsonTableReader - -__all__ = ["JsonTableReader"] diff --git a/src/rbyte/io/table/mcap/__init__.py b/src/rbyte/io/table/mcap/__init__.py deleted file mode 100644 index fef5bc6..0000000 --- a/src/rbyte/io/table/mcap/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .reader import McapTableReader - -__all__ = ["McapTableReader"] diff --git a/src/rbyte/io/table/rrd/__init__.py b/src/rbyte/io/table/rrd/__init__.py deleted file mode 100644 index 46db84c..0000000 --- a/src/rbyte/io/table/rrd/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .reader import RrdTableReader - -__all__ = ["RrdTableReader"] diff --git a/src/rbyte/io/table/transforms/base.py b/src/rbyte/io/table/transforms/base.py index 51180a4..a287bb2 100644 --- a/src/rbyte/io/table/transforms/base.py +++ b/src/rbyte/io/table/transforms/base.py @@ -1,8 +1,8 @@ from typing import Protocol, runtime_checkable -from rbyte.io.table.base import Table +import polars as pl @runtime_checkable class TableTransform(Protocol): - def __call__(self, src: Table) -> Table: ... + def __call__(self, src: pl.DataFrame) -> pl.DataFrame: ... 
diff --git a/src/rbyte/io/table/transforms/fps_resampler.py b/src/rbyte/io/table/transforms/fps_resampler.py index 1b2c128..62878bd 100644 --- a/src/rbyte/io/table/transforms/fps_resampler.py +++ b/src/rbyte/io/table/transforms/fps_resampler.py @@ -1,16 +1,16 @@ from math import lcm -from typing import override +from typing import final, override +from uuid import uuid4 import polars as pl from pydantic import PositiveInt, validate_call -from rbyte.io.table.base import Table - from .base import TableTransform +@final class FpsResampler(TableTransform): - IDX_COL = "__idx" + IDX_COL = uuid4().hex @validate_call def __init__(self, source_fps: PositiveInt, target_fps: PositiveInt) -> None: @@ -21,7 +21,7 @@ def __init__(self, source_fps: PositiveInt, target_fps: PositiveInt) -> None: self._fps_lcm = lcm(source_fps, target_fps) @override - def __call__(self, src: Table) -> Table: + def __call__(self, src: pl.DataFrame) -> pl.DataFrame: return ( src.with_row_index(self.IDX_COL) .with_columns(pl.col(self.IDX_COL) * (self._fps_lcm // self._source_fps)) diff --git a/src/rbyte/io/table/yaak/__init__.py b/src/rbyte/io/table/yaak/__init__.py deleted file mode 100644 index e83fb67..0000000 --- a/src/rbyte/io/table/yaak/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .reader import YaakMetadataTableReader - -__all__ = ["YaakMetadataTableReader"] diff --git a/src/rbyte/io/frame/video/__init__.py b/src/rbyte/io/video/__init__.py similarity index 100% rename from src/rbyte/io/frame/video/__init__.py rename to src/rbyte/io/video/__init__.py diff --git a/src/rbyte/io/video/ffmpeg_source.py b/src/rbyte/io/video/ffmpeg_source.py new file mode 100644 index 0000000..754352e --- /dev/null +++ b/src/rbyte/io/video/ffmpeg_source.py @@ -0,0 +1,41 @@ +from collections.abc import Iterable +from pathlib import Path +from typing import cast, override + +import numpy.typing as npt +import torch +from jaxtyping import UInt8 +from pydantic import FilePath, NonNegativeInt, validate_call +from torch import Tensor +from video_reader import ( + PyVideoReader, # pyright: ignore[reportAttributeAccessIssue, reportUnknownVariableType] +) + +from rbyte.io.base import TensorSource + + +class FfmpegFrameSource(TensorSource): + @validate_call + def __init__( + self, + path: FilePath, + threads: NonNegativeInt | None = None, + resize_shorter_side: NonNegativeInt | None = None, + ) -> None: + super().__init__() + + self._reader: PyVideoReader = PyVideoReader( + filename=Path(path).resolve().as_posix(), + threads=threads, + resize_shorter_side=resize_shorter_side, + ) + + @override + def __getitem__(self, indexes: Iterable[int]) -> UInt8[Tensor, "b h w c"]: + batch = cast(npt.ArrayLike, self._reader.get_batch(indexes)) # pyright: ignore[reportUnknownMemberType] + + return torch.from_numpy(batch) # pyright: ignore[reportUnknownMemberType] + + @override + def __len__(self) -> int: + return int(self._reader.get_info()["frame_count"]) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] diff --git a/src/rbyte/io/frame/video/vali_reader.py b/src/rbyte/io/video/vali_source.py similarity index 80% rename from src/rbyte/io/frame/video/vali_reader.py rename to src/rbyte/io/video/vali_source.py index 154f7a4..8f95a3a 100644 --- a/src/rbyte/io/frame/video/vali_reader.py +++ b/src/rbyte/io/video/vali_source.py @@ -1,4 +1,4 @@ -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Iterable, Mapping from functools import cached_property from itertools import pairwise from typing import Annotated, 
override @@ -7,17 +7,12 @@ import python_vali as vali import torch from jaxtyping import Shaped -from pydantic import ( - BeforeValidator, - ConfigDict, - FilePath, - NonNegativeInt, - validate_call, -) +from pydantic import BeforeValidator, FilePath, NonNegativeInt, validate_call from structlog import get_logger from torch import Tensor -from rbyte.io.frame.base import FrameReader +from rbyte.config.base import BaseModel +from rbyte.io.base import TensorSource logger = get_logger(__name__) @@ -29,8 +24,8 @@ ] -class ValiGpuFrameReader(FrameReader): - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) +class ValiGpuFrameSource(TensorSource): + @validate_call(config=BaseModel.model_config) def __init__( self, path: FilePath, @@ -42,13 +37,13 @@ def __init__( ) -> None: super().__init__() - self._gpu_id = gpu_id + self._gpu_id: int = gpu_id - self._decoder = vali.PyDecoder( + self._decoder: vali.PyDecoder = vali.PyDecoder( input=path.resolve().as_posix(), opts={}, gpu_id=self._gpu_id ) - self._pixel_format_chain = ( + self._pixel_format_chain: tuple[PixelFormat, ...] = ( (self._decoder.Format, *pixel_format_chain) if mit.first(pixel_format_chain, default=None) != self._decoder.Format else pixel_format_chain @@ -77,7 +72,9 @@ def _surfaces(self) -> Mapping[vali.PixelFormat, vali.Surface]: for pixel_format in self._pixel_format_chain } - def _read(self, index: int) -> Shaped[Tensor, "c h w"] | Shaped[Tensor, "h w c"]: + def _read_frame( + self, index: int + ) -> Shaped[Tensor, "c h w"] | Shaped[Tensor, "h w c"]: seek_ctx = vali.SeekContext(seek_frame=index) success, details = self._decoder.DecodeSingleSurface( # pyright: ignore[reportUnknownMemberType] self._surfaces[self._decoder.Format], seek_ctx @@ -106,11 +103,11 @@ def _read(self, index: int) -> Shaped[Tensor, "c h w"] | Shaped[Tensor, "h w c"] return torch.from_dlpack(surface).clone().detach() # pyright: ignore[reportPrivateImportUsage] @override - def read( + def __getitem__( self, indexes: Iterable[int] ) -> Shaped[Tensor, "b h w c"] | Shaped[Tensor, "b c h w"]: - return torch.stack([self._read(index) for index in indexes]) + return torch.stack([self._read_frame(index) for index in indexes]) @override - def get_available_indexes(self) -> Sequence[int]: - return range(self._decoder.NumFrames) + def __len__(self) -> int: + return self._decoder.NumFrames diff --git a/src/rbyte/io/yaak/__init__.py b/src/rbyte/io/yaak/__init__.py new file mode 100644 index 0000000..8cfd72d --- /dev/null +++ b/src/rbyte/io/yaak/__init__.py @@ -0,0 +1,3 @@ +from .table_reader import YaakMetadataTableReader + +__all__ = ["YaakMetadataTableReader"] diff --git a/src/rbyte/io/table/yaak/idl-repo b/src/rbyte/io/yaak/idl-repo similarity index 100% rename from src/rbyte/io/table/yaak/idl-repo rename to src/rbyte/io/yaak/idl-repo diff --git a/src/rbyte/io/table/yaak/message_iterator.py b/src/rbyte/io/yaak/message_iterator.py similarity index 92% rename from src/rbyte/io/table/yaak/message_iterator.py rename to src/rbyte/io/yaak/message_iterator.py index 4c9203f..6efda9a 100644 --- a/src/rbyte/io/table/yaak/message_iterator.py +++ b/src/rbyte/io/yaak/message_iterator.py @@ -26,9 +26,9 @@ class YaakMetadataMessageIterator(Iterator[tuple[type[Message], bytes]]): 8: can_pb2.VehicleState, } - FILE_HEADER_VERSION = 1 - FILE_HEADER_LEN = 12 - MESSAGE_HEADER_LEN = 8 + FILE_HEADER_VERSION: int = 1 + FILE_HEADER_LEN: int = 12 + MESSAGE_HEADER_LEN: int = 8 def __init__( self, @@ -37,7 +37,7 @@ def __init__( ) -> None: super().__init__() - self._file = file + 
self._file: BinaryIO | mmap = file for expected_val, desc in ( (self.FILE_HEADER_LEN, "file header length"), @@ -49,7 +49,7 @@ def __init__( raise ValueError(msg) if message_types is None: - self._message_types = self.MESSAGE_TYPES + self._message_types: Mapping[int, type[Message]] = self.MESSAGE_TYPES else: if unknown_message_types := set(message_types) - set( self.MESSAGE_TYPES.values() diff --git a/src/rbyte/io/table/yaak/proto/__init__.py b/src/rbyte/io/yaak/proto/__init__.py similarity index 100% rename from src/rbyte/io/table/yaak/proto/__init__.py rename to src/rbyte/io/yaak/proto/__init__.py diff --git a/src/rbyte/io/table/yaak/reader.py b/src/rbyte/io/yaak/table_reader.py similarity index 87% rename from src/rbyte/io/table/yaak/reader.py rename to src/rbyte/io/yaak/table_reader.py index dd75054..2208e57 100644 --- a/src/rbyte/io/table/yaak/reader.py +++ b/src/rbyte/io/yaak/table_reader.py @@ -10,20 +10,20 @@ import more_itertools as mit import polars as pl from google.protobuf.message import Message -from optree import tree_map +from optree import PyTree, tree_map from polars._typing import PolarsDataType from polars.datatypes import ( DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 ) from ptars import HandlerPool -from pydantic import ConfigDict, ImportString +from pydantic import ImportString from structlog import get_logger from tqdm import tqdm from xxhash import xxh3_64_intdigest as digest from rbyte.config.base import BaseModel, HydraConfig -from rbyte.io.table.base import TableReaderBase +from rbyte.io.table.base import TableReader from .message_iterator import YaakMetadataMessageIterator from .proto import sensor_pb2 @@ -32,21 +32,19 @@ class Config(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - fields: Mapping[ ImportString[type[Message]], Mapping[str, HydraConfig[PolarsDataType] | None] ] -class YaakMetadataTableReader(TableReaderBase): +class YaakMetadataTableReader(TableReader): def __init__(self, **kwargs: object) -> None: super().__init__() - self._config = Config.model_validate(kwargs) + self._config: Config = Config.model_validate(kwargs) @override - def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: + def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: with Path(path).open("rb") as _f, mmap(_f.fileno(), 0, access=ACCESS_READ) as f: handler_pool = HandlerPool() @@ -68,7 +66,7 @@ def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: ) ]) .select(schema), - schema=schema, # pyright: ignore[reportArgumentType] + schema=schema, rechunk=True, ), ) @@ -83,7 +81,7 @@ def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: ).items() } - return dfs + return dfs # pyright: ignore[reportReturnType] @override def __hash__(self) -> int: diff --git a/src/rbyte/sample/__init__.py b/src/rbyte/sample/__init__.py index e69de29..d60e7c1 100644 --- a/src/rbyte/sample/__init__.py +++ b/src/rbyte/sample/__init__.py @@ -0,0 +1,3 @@ +from .greedy_builder import GreedySampleBuilder + +__all__ = ["GreedySampleBuilder"] diff --git a/src/rbyte/sample/base.py b/src/rbyte/sample/base.py index f5d08f5..67d3e07 100644 --- a/src/rbyte/sample/base.py +++ b/src/rbyte/sample/base.py @@ -4,5 +4,5 @@ @runtime_checkable -class SampleTableBuilder(Protocol): +class SampleBuilder(Protocol): def build(self, source: pl.LazyFrame) -> pl.LazyFrame: ... 
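SampleBuilder is likewise a @runtime_checkable protocol: any object exposing build(source: pl.LazyFrame) -> pl.LazyFrame conforms. A minimal sketch, assuming only polars (EveryNthSampleBuilder is illustrative, not part of this patch):

import polars as pl

class EveryNthSampleBuilder:
    """Toy builder: keep every n-th row of the source table."""

    def __init__(self, n: int) -> None:
        self._n = n

    def build(self, source: pl.LazyFrame) -> pl.LazyFrame:
        # LazyFrame.gather_every(n) lazily keeps rows 0, n, 2n, ...
        return source.gather_every(self._n)

The renamed GreedySampleBuilder below is the in-tree implementation: rather than subsampling, it windows the index column into fixed-length, strided samples.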
diff --git a/src/rbyte/sample/builder.py b/src/rbyte/sample/greedy_builder.py similarity index 75% rename from src/rbyte/sample/builder.py rename to src/rbyte/sample/greedy_builder.py index 72e8f32..eb4cc32 100644 --- a/src/rbyte/sample/builder.py +++ b/src/rbyte/sample/greedy_builder.py @@ -4,14 +4,14 @@ import polars as pl from pydantic import PositiveInt, StringConstraints, validate_call -from .base import SampleTableBuilder +from .base import SampleBuilder -class GreedySampleTableBuilder(SampleTableBuilder): +class GreedySampleBuilder(SampleBuilder): @validate_call def __init__( self, - index_column: str, + index_column: str | None = None, length: PositiveInt = 1, min_step: PositiveInt = 1, stride: PositiveInt = 1, @@ -20,15 +20,18 @@ def __init__( ) -> None: super().__init__() - self._index_column = index_column - self._length = length - self._min_step = min_step - self._stride = stride - self._filter = filter + self._index_column: str | None = index_column + self._length: int = length + self._min_step: int = min_step + self._stride: int = stride + self._filter: str | None = filter @override def build(self, source: pl.LazyFrame) -> pl.LazyFrame: - idx_col = self._index_column + if (idx_col := self._index_column) is None: + idx_col = uuid4().hex + source = source.with_row_index(idx_col) + idx_dtype = source.select(idx_col).collect_schema()[idx_col] return ( @@ -56,6 +59,6 @@ def build(self, source: pl.LazyFrame) -> pl.LazyFrame: .filter(pl.col(idx_col).list.len() == self._length) .sql(f"select * from self where ({self._filter or True})") # noqa: S608 .sort(sample_idx_col) - .select(pl.exclude(sample_idx_col)) + .drop([sample_idx_col, *([idx_col] if self._index_column is None else [])]) .select(pl.all().list.to_array(self._length)) ) diff --git a/src/rbyte/scripts/build_table.py b/src/rbyte/scripts/build_table.py deleted file mode 100644 index 0144279..0000000 --- a/src/rbyte/scripts/build_table.py +++ /dev/null @@ -1,24 +0,0 @@ -from collections.abc import Callable -from typing import cast - -import hydra -from hydra.utils import instantiate -from omegaconf import DictConfig -from structlog import get_logger - -from rbyte.io.table.base import Table, TableBuilderBase - -logger = get_logger(__name__) - - -@hydra.main(version_base=None) -def main(config: DictConfig) -> None: - table_builder = cast(TableBuilderBase, instantiate(config.table_builder)) - table_writer = cast(Callable[[Table], None], instantiate(config.table_writer)) - table = table_builder.build() - - return table_writer(table) - - -if __name__ == "__main__": - main() diff --git a/src/rbyte/scripts/read_frames.py b/src/rbyte/scripts/read_frames.py deleted file mode 100644 index f370f24..0000000 --- a/src/rbyte/scripts/read_frames.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import Any, cast - -import hydra -import more_itertools as mit -import numpy as np -import numpy.typing as npt -import rerun as rr -from omegaconf import DictConfig, OmegaConf -from pydantic import ConfigDict, NonNegativeInt -from structlog import get_logger -from structlog.contextvars import bound_contextvars -from tqdm import tqdm - -from rbyte.config.base import BaseModel, HydraConfig -from rbyte.io.frame.base import FrameReader -from rbyte.viz.loggers.rerun_logger import FrameConfig - -logger = get_logger(__name__) - - -class Config(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - frame_reader: HydraConfig[FrameReader] - frame_config: FrameConfig - application_id: str - entity_path: str - batch_size: NonNegativeInt = 1 - - 
-@hydra.main(version_base=None) -def main(_config: DictConfig) -> None: - config = Config.model_validate(OmegaConf.to_object(_config)) - frame_reader = config.frame_reader.instantiate() - frame_config = config.frame_config - - rr.init(config.application_id, spawn=True) - - for frame_indexes in mit.chunked( - tqdm(sorted(frame_reader.get_available_indexes())), - config.batch_size, - strict=False, - ): - tensor = frame_reader.read(frame_indexes) - - with bound_contextvars(frame_config=frame_config, shape=tensor.shape): - match frame_config: - case {rr.Image: image_format} | {rr.DepthImage: image_format}: - match ( - image_format.pixel_format, - image_format.color_model, - tensor.shape, - ): - case None, color_model, shape: - match color_model, shape: - case ( - (rr.ColorModel.L, (batch, height, width, 1)) - | (rr.ColorModel.RGB, (batch, height, width, 3)) - | (rr.ColorModel.RGBA, (batch, height, width, 4)) - ): - pass - - case ( - (rr.ColorModel.L, (batch, 1, height, width)) - | (rr.ColorModel.RGB, (batch, 3, height, width)) - | (rr.ColorModel.RGBA, (batch, 4, height, width)) - ): - tensor = tensor.permute(0, 2, 3, 1) - - case _: - logger.error("not implemented") - - raise NotImplementedError - - case rr.PixelFormat.NV12, _, (batch, dim, width): - height = int(dim / 1.5) - - case _: - logger.error("not implemented") - - raise NotImplementedError - - arr = cast(npt.NDArray[Any], tensor.cpu().numpy()) # pyright: ignore[reportUnknownMemberType] - image_format = rr.components.ImageFormat( - height=height, - width=width, - pixel_format=image_format.pixel_format, - color_model=image_format.color_model, - channel_datatype=rr.ChannelDatatype.from_np_dtype(arr.dtype), - ) - - components = [ - mit.one(frame_config).indicator(), - rr.components.ImageFormatBatch([image_format] * batch), - rr.components.ImageBufferBatch( - arr.reshape(batch, -1).view(np.uint8) - ), - ] - - case _: - logger.error("not implemented") - - raise NotImplementedError - - times = [rr.TimeSequenceColumn("frame_index", frame_indexes)] - - rr.send_columns( - entity_path=config.entity_path, - times=times, - components=components, # pyright: ignore[reportArgumentType] - strict=True, - ) - - -if __name__ == "__main__": - main() diff --git a/src/rbyte/utils/dataframe/cache.py b/src/rbyte/utils/dataframe/cache.py index 359981b..93fa76d 100644 --- a/src/rbyte/utils/dataframe/cache.py +++ b/src/rbyte/utils/dataframe/cache.py @@ -7,16 +7,16 @@ from diskcache import Cache from pydantic import ByteSize, DirectoryPath, NewPath, validate_call -from rbyte.io.table.base import TableCacheBase +from rbyte.io.table.base import TableCache -class DataframeDiskCache(TableCacheBase): +class DataframeDiskCache(TableCache): @validate_call def __init__( self, directory: DirectoryPath | NewPath, size_limit: ByteSize | None = None ) -> None: super().__init__() - self._cache = Cache(directory=directory, size_limit=size_limit) + self._cache: Cache = Cache(directory=directory, size_limit=size_limit) @override def __contains__(self, key: Hashable) -> bool: diff --git a/src/rbyte/utils/functional.py b/src/rbyte/utils/functional.py new file mode 100644 index 0000000..343d24f --- /dev/null +++ b/src/rbyte/utils/functional.py @@ -0,0 +1,41 @@ +from collections.abc import Sequence + +import more_itertools as mit +import torch +import torch.nn.functional as F # noqa: N812 +from jaxtyping import Float +from torch import Tensor + + +def pad_dim( + input: Float[Tensor, "..."], # noqa: A002 + *, + pad: tuple[int, int], + dim: int, + mode: str = "constant", + value: float 
| None = None, +) -> Float[Tensor, "..."]: + _pad = [(0, 0) for _ in input.shape] + _pad[dim] = pad + _pad = list(mit.flatten(reversed(_pad))) + + return F.pad(input, _pad, mode=mode, value=value) + + +def pad_sequence( + sequences: Sequence[Float[Tensor, "..."]], dim: int, value: float = 0.0 +) -> Float[Tensor, "..."]: + max_length = max(sequence.shape[dim] for sequence in sequences) + + padded = ( + pad_dim( + sequence, + pad=(0, max_length - sequence.shape[dim]), + dim=dim, + mode="constant", + value=value, + ) + for sequence in sequences + ) + + return torch.stack(list(padded)) diff --git a/src/rbyte/utils/mcap/decoders/protobuf_decoder_factory.py b/src/rbyte/utils/mcap/decoders/protobuf_decoder_factory.py index 151651d..14e5506 100644 --- a/src/rbyte/utils/mcap/decoders/protobuf_decoder_factory.py +++ b/src/rbyte/utils/mcap/decoders/protobuf_decoder_factory.py @@ -21,7 +21,7 @@ class ProtobufDecoderFactory(McapDecoderFactory): def __init__(self) -> None: - self._handler_pool = HandlerPool() + self._handler_pool: HandlerPool = HandlerPool() @override def decoder_for( diff --git a/src/rbyte/viz/loggers/rerun_logger.py b/src/rbyte/viz/loggers/rerun_logger.py index 1888b73..588320e 100644 --- a/src/rbyte/viz/loggers/rerun_logger.py +++ b/src/rbyte/viz/loggers/rerun_logger.py @@ -1,15 +1,6 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Iterable, Mapping, Sequence from functools import cache, cached_property -from typing import ( - Annotated, - Any, - Literal, - Protocol, - Self, - cast, - override, - runtime_checkable, -) +from typing import Annotated, Any, Protocol, Self, cast, override, runtime_checkable import more_itertools as mit import numpy as np @@ -18,14 +9,17 @@ import rerun.blueprint as rrb from pydantic import ( BeforeValidator, - ConfigDict, Field, ImportString, + RootModel, model_validator, validate_call, ) from pydantic.types import AnyType -from rerun._baseclasses import Archetype # noqa: PLC2701 +from rerun._baseclasses import ( + Archetype, # noqa: PLC2701 + ComponentBatchLike, +) from rerun._send_columns import TimeColumnLike # noqa: PLC2701 from structlog import get_logger from structlog.contextvars import bound_contextvars @@ -62,37 +56,41 @@ def validate_model(self: Self) -> Self: RerunImportString = Annotated[ ImportString[AnyType], - BeforeValidator(lambda x: f"rerun.{x}" if not x.startswith("rerun.") else x), -] - -FrameConfig = Annotated[ - Mapping[RerunImportString[type[Archetype]], ImageFormat], Field(max_length=1) + BeforeValidator( + lambda x: f"rerun.{x}" + if isinstance(x, str) and not x.startswith("rerun.") + else x + ), ] -TableConfig = RerunImportString[type[TimeColumn | Archetype]] +TimeConfig = RerunImportString[type[TimeColumn]] -class Schema(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) +ComponentConfig = ( + RerunImportString[type[Archetype]] + | Annotated[ + Mapping[RerunImportString[type[rr.Image | rr.DepthImage]], ImageFormat], + Field(max_length=1), + ] +) - frame: Mapping[str, FrameConfig] = Field(default_factory=dict) - table: Mapping[str, TableConfig] = Field(default_factory=dict) +class Schema(RootModel[Mapping[str, TimeConfig | ComponentConfig]]): @cached_property - def times(self) -> Mapping[tuple[Literal["table"], str], TimeColumn]: - return { - ("table", k): v for k, v in self.table.items() if isinstance(v, TimeColumn) - } + def times(self) -> Mapping[str, TimeColumn]: + return {k: v for k, v in self.root.items() if isinstance(v, TimeColumn)} @cached_property - def components(self) 
-> Mapping[tuple[str, str], FrameConfig | type[Archetype]]: - return {("frame", k): v for k, v in self.frame.items()} | { - ("table", k): v for k, v in self.table.items() if issubclass(v, Archetype) - } + def components( + self, + ) -> Mapping[ + str, type[Archetype] | Mapping[type[rr.Image | rr.DepthImage], ImageFormat] + ]: + return {k: v for k, v in self.root.items() if not isinstance(v, TimeColumn)} # pyright: ignore[reportReturnType] class RerunLogger(Logger[Batch]): - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + @validate_call(config=BaseModel.model_config) def __init__( self, *, @@ -102,9 +100,9 @@ def __init__( ) -> None: super().__init__() - self._schema = schema - self._spawn = spawn - self._blueprint = blueprint + self._schema: Schema = schema + self._spawn: bool = spawn + self._blueprint: rrb.BlueprintLike | None = blueprint @cache # noqa: B019 def _get_recording(self, application_id: str) -> rr.RecordingStream: @@ -116,102 +114,93 @@ def _get_recording(self, application_id: str) -> rr.RecordingStream: return recording + @classmethod + def _build_components( + cls, + array: npt.NDArray[Any], + schema: type[Archetype] | Mapping[type[rr.Image | rr.DepthImage], ImageFormat], + ) -> Iterable[ComponentBatchLike]: + match schema: + case rr.Scalar: + return [schema.indicator(), rr.components.ScalarBatch(array)] + + case rr.Points3D: + match shape := array.shape: + case (n, 3): + batch = rr.components.Position3DBatch(array) + + case (s, n, 3): + batch = rr.components.Position3DBatch( + array.reshape(-1, 3) + ).partition([n] * s) + + case _: + logger.debug("not implemented", shape=shape) + + raise NotImplementedError + + return [schema.indicator(), batch] + + case rr.Tensor: + return [schema.indicator(), rr.components.TensorDataBatch(array)] + + case {rr.Image: image_format} | {rr.DepthImage: image_format}: + with bound_contextvars(image_format=image_format, shape=array.shape): + match ( + image_format.pixel_format, + image_format.color_model, + array.shape, + ): + case None, rr.ColorModel(), (_batch, height, width, _): + pass + + case rr.PixelFormat.NV12, None, (_batch, dim, width): + height = int(dim / 1.5) + + case _: + logger.error("not implemented") + + raise NotImplementedError + + image_format = rr.components.ImageFormat( + height=height, + width=width, + pixel_format=image_format.pixel_format, + color_model=image_format.color_model, + channel_datatype=rr.ChannelDatatype.from_np_dtype(array.dtype), + ) + return [ + mit.one(schema).indicator(), + rr.components.ImageFormatBatch([image_format] * _batch), + rr.components.ImageBufferBatch( + array.reshape(_batch, -1).view(np.uint8) + ), + ] + + case _: + logger.error("not implemented") + + raise NotImplementedError + @override def log(self, batch_idx: int, batch: Batch) -> None: - # NOTE: zip because batch.meta.input_id is NonTensorData and isn't indexed - for input_id, sample in zip( # pyright: ignore[reportUnknownVariableType] - batch.get(path := ("meta", "input_id")), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportAttributeAccessIssue] - batch.exclude(path), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportAttributeAccessIssue] - strict=True, - ): - with self._get_recording(input_id): # pyright: ignore[reportUnknownArgumentType] + for i, sample in enumerate(batch.data): # pyright: ignore[reportUnknownVariableType] + with self._get_recording(batch.meta.input_id[i]): # pyright: ignore[reportUnknownArgumentType, reportIndexIssue] times: Sequence[TimeColumn] = 
[ - v(timeline="/".join(k), times=sample.get(k).numpy()) # pyright: ignore[reportUnknownMemberType, reportCallIssue] - for k, v in self._schema.times.items() + column(timeline=timeline, times=sample.get(timeline).numpy()) # pyright: ignore[reportUnknownMemberType, reportCallIssue] + for timeline, column in self._schema.times.items() ] - for path, schema in self._schema.components.items(): - with bound_contextvars(path=path, schema=schema): - arr = cast(npt.NDArray[Any], sample.get(path).cpu().numpy()) # pyright: ignore[reportUnknownMemberType] - match schema: - case rr.Scalar: - components = [ - schema.indicator(), - rr.components.ScalarBatch(arr), - ] - - case rr.Points3D: - s, n, *_ = arr.shape - components = [ - schema.indicator(), - rr.components.Position3DBatch( - arr.reshape(s * n, -1) - ).partition([n for _ in range(s)]), - ] - - case rr.Tensor: - components = [ - schema.indicator(), - rr.components.TensorDataBatch(arr), - ] - - case {rr.Image: image_format} | { - rr.DepthImage: image_format - }: - with bound_contextvars( - image_format=image_format, shape=arr.shape - ): - match ( - image_format.pixel_format, - image_format.color_model, - arr.shape, - ): - case None, rr.ColorModel(), ( - _batch, - height, - width, - _, - ): - pass - - case rr.PixelFormat.NV12, None, ( - _batch, - dim, - width, - ): - height = int(dim / 1.5) - - case _: - logger.error("not implemented") - - raise NotImplementedError - - image_format = rr.components.ImageFormat( - height=height, - width=width, - pixel_format=image_format.pixel_format, - color_model=image_format.color_model, - channel_datatype=rr.ChannelDatatype.from_np_dtype( - arr.dtype - ), - ) - components = [ - mit.one(schema).indicator(), - rr.components.ImageFormatBatch( - [image_format] * _batch - ), - rr.components.ImageBufferBatch( - arr.reshape(_batch, -1).view(np.uint8) - ), - ] - - case _: - logger.error("not implemented") - - raise NotImplementedError + for entity_path, schema in self._schema.components.items(): + with bound_contextvars(path=entity_path, schema=schema): + array = cast( + npt.NDArray[Any], + sample.get(entity_path).cpu().numpy(), # pyright: ignore[reportUnknownMemberType] + ) + components = self._build_components(array, schema) rr.send_columns( - entity_path="/".join(path), + entity_path=entity_path, times=times, components=components, # pyright: ignore[reportArgumentType] strict=True, diff --git a/tests/data/mimicgen/.gitattributes b/tests/data/mimicgen/.gitattributes new file mode 100644 index 0000000..0820b3d --- /dev/null +++ b/tests/data/mimicgen/.gitattributes @@ -0,0 +1 @@ +*.hdf5 filter=lfs diff=lfs merge=lfs -text diff --git a/tests/data/mimicgen/README.md b/tests/data/mimicgen/README.md new file mode 100644 index 0000000..0929fdf --- /dev/null +++ b/tests/data/mimicgen/README.md @@ -0,0 +1 @@ +https://huggingface.co/datasets/amandlek/mimicgen_datasets/blob/main/source/coffee.hdf5 diff --git a/tests/data/coffee.hdf5 b/tests/data/mimicgen/coffee.hdf5 similarity index 100% rename from tests/data/coffee.hdf5 rename to tests/data/mimicgen/coffee.hdf5 diff --git a/tests/data/nuscenes/README.md b/tests/data/nuscenes/README.md new file mode 100644 index 0000000..9354b5a --- /dev/null +++ b/tests/data/nuscenes/README.md @@ -0,0 +1 @@ +https://www.nuscenes.org/terms-of-use diff --git a/tests/data/nuscenes/mcap/.gitattributes b/tests/data/nuscenes/mcap/.gitattributes new file mode 100644 index 0000000..91ccd51 --- /dev/null +++ b/tests/data/nuscenes/mcap/.gitattributes @@ -0,0 +1 @@ +*.mcap filter=lfs diff=lfs merge=lfs 
-text diff --git a/tests/data/nuscenes/mcap/README.md b/tests/data/nuscenes/mcap/README.md new file mode 100644 index 0000000..7da5b9d --- /dev/null +++ b/tests/data/nuscenes/mcap/README.md @@ -0,0 +1 @@ +https://github.com/foxglove/nuscenes2mcap diff --git a/tests/data/nuScenes-v1.0-mini-scene-0061-cut.mcap b/tests/data/nuscenes/mcap/nuScenes-v1.0-mini-scene-0061-cut.mcap similarity index 100% rename from tests/data/nuScenes-v1.0-mini-scene-0061-cut.mcap rename to tests/data/nuscenes/mcap/nuScenes-v1.0-mini-scene-0061-cut.mcap diff --git a/tests/data/nuscenes/rrd/.gitattributes b/tests/data/nuscenes/rrd/.gitattributes new file mode 100644 index 0000000..ddfc47e --- /dev/null +++ b/tests/data/nuscenes/rrd/.gitattributes @@ -0,0 +1 @@ +*.rrd filter=lfs diff=lfs merge=lfs -text diff --git a/tests/data/nuscenes/rrd/README.md b/tests/data/nuscenes/rrd/README.md new file mode 100644 index 0000000..00cfbde --- /dev/null +++ b/tests/data/nuscenes/rrd/README.md @@ -0,0 +1 @@ +https://app.rerun.io/examples/nuscenes_dataset.rrd diff --git a/tests/data/nuscenes/rrd/nuscenes_dataset.rrd b/tests/data/nuscenes/rrd/nuscenes_dataset.rrd new file mode 100644 index 0000000..7b93001 --- /dev/null +++ b/tests/data/nuscenes/rrd/nuscenes_dataset.rrd @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c4ed29b0b8da16cd37bc534823934c16b3bf21f660a9c097111e9e8487342a5 +size 94129588 diff --git a/tests/data/yaak/.gitattributes b/tests/data/yaak/.gitattributes new file mode 100644 index 0000000..dc5353e --- /dev/null +++ b/tests/data/yaak/.gitattributes @@ -0,0 +1 @@ +Niro098-HQ filter=lfs diff=lfs merge=lfs -text diff --git a/tests/data/Niro098-HQ/2024-06-18--13-39-54/ai.mcap b/tests/data/yaak/Niro098-HQ/2024-06-18--13-39-54/ai.mcap similarity index 100% rename from tests/data/Niro098-HQ/2024-06-18--13-39-54/ai.mcap rename to tests/data/yaak/Niro098-HQ/2024-06-18--13-39-54/ai.mcap diff --git a/tests/data/Niro098-HQ/2024-06-18--13-39-54/cam_front_left.pii.mp4 b/tests/data/yaak/Niro098-HQ/2024-06-18--13-39-54/cam_front_left.pii.mp4 similarity index 100% rename from tests/data/Niro098-HQ/2024-06-18--13-39-54/cam_front_left.pii.mp4 rename to tests/data/yaak/Niro098-HQ/2024-06-18--13-39-54/cam_front_left.pii.mp4 diff --git a/tests/data/Niro098-HQ/2024-06-18--13-39-54/cam_left_backward.pii.mp4 b/tests/data/yaak/Niro098-HQ/2024-06-18--13-39-54/cam_left_backward.pii.mp4 similarity index 100% rename from tests/data/Niro098-HQ/2024-06-18--13-39-54/cam_left_backward.pii.mp4 rename to tests/data/yaak/Niro098-HQ/2024-06-18--13-39-54/cam_left_backward.pii.mp4 diff --git a/tests/data/Niro098-HQ/2024-06-18--13-39-54/cam_right_backward.pii.mp4 b/tests/data/yaak/Niro098-HQ/2024-06-18--13-39-54/cam_right_backward.pii.mp4 similarity index 100% rename from tests/data/Niro098-HQ/2024-06-18--13-39-54/cam_right_backward.pii.mp4 rename to tests/data/yaak/Niro098-HQ/2024-06-18--13-39-54/cam_right_backward.pii.mp4 diff --git a/tests/data/Niro098-HQ/2024-06-18--13-39-54/metadata.log b/tests/data/yaak/Niro098-HQ/2024-06-18--13-39-54/metadata.log similarity index 100% rename from tests/data/Niro098-HQ/2024-06-18--13-39-54/metadata.log rename to tests/data/yaak/Niro098-HQ/2024-06-18--13-39-54/metadata.log diff --git a/tests/data/yaak/README.md b/tests/data/yaak/README.md new file mode 100644 index 0000000..707a649 --- /dev/null +++ b/tests/data/yaak/README.md @@ -0,0 +1 @@ +https://www.yaak.ai/privacy-policy/product diff --git a/tests/data/zod/.gitattributes b/tests/data/zod/.gitattributes new file mode 100644 index 
0000000..9d570c1 --- /dev/null +++ b/tests/data/zod/.gitattributes @@ -0,0 +1 @@ +sequences filter=lfs diff=lfs merge=lfs -text diff --git a/tests/data/zod/README.md b/tests/data/zod/README.md new file mode 100644 index 0000000..1c51fa7 --- /dev/null +++ b/tests/data/zod/README.md @@ -0,0 +1,3 @@ +https://zod.zenseact.com/license/ + +For this dataset, Zenseact AB has taken all reasonable measures to remove all personally identifiable information, including faces and license plates. To the extent that you like to request removal of specific images from the dataset, please contact privacy@zenseact.com. diff --git a/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.092270Z.jpg b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.092270Z.jpg new file mode 100644 index 0000000..99c2daa --- /dev/null +++ b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.092270Z.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04fcdb250573897705e43964d27eb5532d7eddf8079d27ac692377f6d92543c0 +size 433884 diff --git a/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.191297Z.jpg b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.191297Z.jpg new file mode 100644 index 0000000..0815a58 --- /dev/null +++ b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.191297Z.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1547e031cbe24ca69a966512117e58ab1d3620344957f9f58575c35637680f30 +size 427586 diff --git a/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.290323Z.jpg b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.290323Z.jpg new file mode 100644 index 0000000..2f2b8c7 --- /dev/null +++ b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.290323Z.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f99dd4421f275590efd4cefdde263223143d13ef5e48bd5088d4c2ed8039d3f3 +size 426448 diff --git a/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.389350Z.jpg b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.389350Z.jpg new file mode 100644 index 0000000..403b483 --- /dev/null +++ b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.389350Z.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ea9af14b184caf55a2e9ba5610c8635b0e12ed2f4e44ffcf33017b754f69216 +size 421772 diff --git a/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.488377Z.jpg b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.488377Z.jpg new file mode 100644 index 0000000..d49f40b --- /dev/null +++ b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.488377Z.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc3aebcd9612954f8b2fc843b1b1df01850cd603d6dc398c01cb23b235e91e32 +size 419310 diff --git a/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.587404Z.jpg b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.587404Z.jpg new file mode 100644 index 0000000..8a5cad0 --- /dev/null +++ 
b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.587404Z.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ef59dccd9cf35a81c1d11918cdb8e4da5606deb52545fd60104cf922cb86db7 +size 416026 diff --git a/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.686430Z.jpg b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.686430Z.jpg new file mode 100644 index 0000000..1940409 --- /dev/null +++ b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.686430Z.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43d0ff964116a8e75acfd5ec488eec4d1db1b746eb59b7d8b940e03379727255 +size 409570 diff --git a/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.785457Z.jpg b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.785457Z.jpg new file mode 100644 index 0000000..7c61106 --- /dev/null +++ b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.785457Z.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb7dc979c7538f8d1531cff2642a1413fcf1f0e301b1cd043e134b8d5687037f +size 409658 diff --git a/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.884484Z.jpg b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.884484Z.jpg new file mode 100644 index 0000000..a0e831e --- /dev/null +++ b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.884484Z.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ab862bfa5ede05b3634a3409a1660eea2a24f28b0aa612d94cae0c4a116b20f +size 404761 diff --git a/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.983510Z.jpg b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.983510Z.jpg new file mode 100644 index 0000000..b8f8071 --- /dev/null +++ b/tests/data/zod/sequences/000002_short/camera_front_blur/000002_romeo_2022-06-13T10:50:07.983510Z.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e89e497bdc07664d6ab538fa9f063063680382a0d415153f2c9a3c816e39461 +size 396713 diff --git a/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.000063Z.npy b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.000063Z.npy new file mode 100644 index 0000000..b783791 --- /dev/null +++ b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.000063Z.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9de6b306f0ce9f0915410842a1dcab81ee676eb19d51f2f6230960d88410589 +size 5083006 diff --git a/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.111227Z.npy b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.111227Z.npy new file mode 100644 index 0000000..5243f73 --- /dev/null +++ b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.111227Z.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee44d41ff0128fa32e6e00059dcf2f55b402e4841341a6096a12a7eb28717049 +size 5109098 diff --git a/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.222375Z.npy 
b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.222375Z.npy new file mode 100644 index 0000000..c1f411c --- /dev/null +++ b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.222375Z.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:713350918e3589afe62581091623f8f4a63fd2601907c7c82fd68297b22b5982 +size 5135828 diff --git a/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.333497Z.npy b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.333497Z.npy new file mode 100644 index 0000000..96415a5 --- /dev/null +++ b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.333497Z.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b75f4d1a12263fb935b82014fa15745b893203094961314b2213b42848517e8 +size 5155034 diff --git a/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.444642Z.npy b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.444642Z.npy new file mode 100644 index 0000000..084f8be --- /dev/null +++ b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.444642Z.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:525ff6b8aaa90b0265a72a89f82e236fa6d45c7639be1b813aed567153c81503 +size 5170786 diff --git a/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.555767Z.npy b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.555767Z.npy new file mode 100644 index 0000000..ea93e6e --- /dev/null +++ b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.555767Z.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5fe915fe078bbf53479009a47242736bb351bb7eb6a006002607d727d18e3c56 +size 5196878 diff --git a/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.666867Z.npy b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.666867Z.npy new file mode 100644 index 0000000..ef1e45c --- /dev/null +++ b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.666867Z.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba88d99b3ba897d38d99d3f2686d589ebc474a437442578bd4ef13ef3c5861e6 +size 5239162 diff --git a/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.777966Z.npy b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.777966Z.npy new file mode 100644 index 0000000..4949c79 --- /dev/null +++ b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.777966Z.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26fcfe409de8d1fd5bf3a8cacdc31d24ba81494021c177d2a4266167ffcb1884 +size 5275022 diff --git a/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.889058Z.npy b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.889058Z.npy new file mode 100644 index 0000000..9405e43 --- /dev/null +++ b/tests/data/zod/sequences/000002_short/lidar_velodyne/000002_romeo_2022-06-13T10:50:07.889058Z.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fac623d623c10a5517007a5d988fc6f4caa9f8dc7fe1b21e13cae728ae373f08 +size 
5326788 diff --git a/tests/data/zod/sequences/000002_short/vehicle_data.hdf5 b/tests/data/zod/sequences/000002_short/vehicle_data.hdf5 new file mode 100644 index 0000000..7c4594b --- /dev/null +++ b/tests/data/zod/sequences/000002_short/vehicle_data.hdf5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d9c42208d362b04271db71e87ba591a3b9bdea16a14a978c704f1d356e807b6 +size 8458969 diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index bebc1af..6875172 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -15,7 +15,12 @@ def test_mimicgen() -> None: with initialize(version_base=None, config_path=CONFIG_PATH): cfg = compose( - "visualize", overrides=["dataset=mimicgen", f"+data_dir={DATA_DIR}"] + "visualize", + overrides=[ + "dataset=mimicgen", + "logger=rerun/mimicgen", + f"+data_dir={DATA_DIR}/mimicgen", + ], ) dataloader = instantiate(cfg.dataloader) @@ -27,14 +32,11 @@ def test_mimicgen() -> None: batch = next(iter(dataloader)) match batch.to_dict(): case { - "frame": { + "data": { "obs/agentview_image": Tensor(shape=[c.B, c.S, *_]), - **frame_rest, - }, - "table": { "_idx_": Tensor(shape=[c.B, c.S]), "obs/robot0_eef_pos": Tensor(shape=[c.B, c.S, *_]), - **table_rest, + **data_rest, }, "meta": { "input_id": input_id, @@ -44,8 +46,7 @@ def test_mimicgen() -> None: **batch_rest, } if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any(( batch_rest, - frame_rest, - table_rest, + data_rest, meta_rest, )): pass @@ -55,11 +56,19 @@ def test_mimicgen() -> None: raise AssertionError(msg) + batch_logger = instantiate(cfg.logger, spawn=False) + batch_logger.log(0, batch) + def test_nuscenes_mcap() -> None: with initialize(version_base=None, config_path=CONFIG_PATH): cfg = compose( - "visualize", overrides=["dataset=nuscenes_mcap", f"+data_dir={DATA_DIR}"] + "visualize", + overrides=[ + "dataset=nuscenes/mcap", + "logger=rerun/nuscenes/mcap", + f"+data_dir={DATA_DIR}/nuscenes/mcap", + ], ) dataloader = instantiate(cfg.dataloader) @@ -71,27 +80,22 @@ def test_nuscenes_mcap() -> None: batch = next(iter(dataloader)) match batch.to_dict(): case { - "frame": { + "data": { "CAM_FRONT": Tensor(shape=[c.B, c.S, *_]), "CAM_FRONT_LEFT": Tensor(shape=[c.B, c.S, *_]), "CAM_FRONT_RIGHT": Tensor(shape=[c.B, c.S, *_]), - **frame_rest, - }, - "table": { - "/CAM_FRONT/image_rect_compressed/_idx_": Tensor(shape=[c.B, c.S]), - "/CAM_FRONT_LEFT/image_rect_compressed/_idx_": Tensor(shape=[c.B, c.S]), - "/CAM_FRONT_RIGHT/image_rect_compressed/_idx_": Tensor( + "mcap//CAM_FRONT/image_rect_compressed/_idx_": Tensor(shape=[c.B, c.S]), + "mcap//CAM_FRONT/image_rect_compressed/log_time": Tensor( shape=[c.B, c.S] ), - "/CAM_FRONT/image_rect_compressed/log_time": Tensor(shape=[c.B, c.S]), - "/CAM_FRONT_LEFT/image_rect_compressed/log_time": Tensor( + "mcap//CAM_FRONT_LEFT/image_rect_compressed/_idx_": Tensor( shape=[c.B, c.S] ), - "/CAM_FRONT_RIGHT/image_rect_compressed/log_time": Tensor( + "mcap//CAM_FRONT_RIGHT/image_rect_compressed/_idx_": Tensor( shape=[c.B, c.S] ), - "/odom/vel.x": Tensor(shape=[c.B, c.S]), - **table_rest, + "mcap//odom/vel.x": Tensor(shape=[c.B, c.S]), + **data_rest, }, "meta": { "input_id": input_id, @@ -101,8 +105,7 @@ def test_nuscenes_mcap() -> None: **batch_rest, } if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any(( batch_rest, - frame_rest, - table_rest, + data_rest, meta_rest, )): pass @@ -112,11 +115,19 @@ def test_nuscenes_mcap() -> None: raise AssertionError(msg) + batch_logger = instantiate(cfg.logger, 
spawn=False) + batch_logger.log(0, batch) + def test_nuscenes_rrd() -> None: with initialize(version_base=None, config_path=CONFIG_PATH): cfg = compose( - "visualize", overrides=["dataset=nuscenes_rrd", f"+data_dir={DATA_DIR}"] + "visualize", + overrides=[ + "dataset=nuscenes/rrd", + "logger=rerun/nuscenes/rrd", + f"+data_dir={DATA_DIR}/nuscenes/rrd", + ], ) dataloader = instantiate(cfg.dataloader) @@ -128,28 +139,24 @@ def test_nuscenes_rrd() -> None: batch = next(iter(dataloader)) match batch.to_dict(): case { - "frame": { + "data": { "CAM_FRONT": Tensor(shape=[c.B, c.S, *_]), "CAM_FRONT_LEFT": Tensor(shape=[c.B, c.S, *_]), "CAM_FRONT_RIGHT": Tensor(shape=[c.B, c.S, *_]), - **frame_rest, - }, - "table": { - "/world/ego_vehicle/CAM_FRONT/timestamp": Tensor(shape=[c.B, c.S, *_]), - "/world/ego_vehicle/CAM_FRONT/_idx_": Tensor(shape=[c.B, c.S, *_]), - "/world/ego_vehicle/CAM_FRONT_LEFT/timestamp": Tensor( + "rrd//world/ego_vehicle/CAM_FRONT/timestamp": Tensor( + shape=[c.B, c.S, *_] + ), + "rrd//world/ego_vehicle/CAM_FRONT/_idx_": Tensor(shape=[c.B, c.S, *_]), + "rrd//world/ego_vehicle/CAM_FRONT_LEFT/_idx_": Tensor( shape=[c.B, c.S, *_] ), - "/world/ego_vehicle/CAM_FRONT_LEFT/_idx_": Tensor(shape=[c.B, c.S, *_]), - "/world/ego_vehicle/CAM_FRONT_RIGHT/timestamp": Tensor( + "rrd//world/ego_vehicle/CAM_FRONT_RIGHT/_idx_": Tensor( shape=[c.B, c.S, *_] ), - "/world/ego_vehicle/CAM_FRONT_RIGHT/_idx_": Tensor( + "rrd//world/ego_vehicle/LIDAR_TOP/Position3D": Tensor( shape=[c.B, c.S, *_] ), - "/world/ego_vehicle/LIDAR_TOP/timestamp": Tensor(shape=[c.B, c.S, *_]), - "/world/ego_vehicle/LIDAR_TOP/Position3D": Tensor(shape=[c.B, c.S, *_]), - **table_rest, + **data_rest, }, "meta": { "input_id": input_id, @@ -159,8 +166,7 @@ def test_nuscenes_rrd() -> None: **batch_rest, } if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any(( batch_rest, - frame_rest, - table_rest, + data_rest, meta_rest, )): pass @@ -170,10 +176,20 @@ def test_nuscenes_rrd() -> None: raise AssertionError(msg) + batch_logger = instantiate(cfg.logger, spawn=False) + batch_logger.log(0, batch) + def test_yaak() -> None: with initialize(version_base=None, config_path=CONFIG_PATH): - cfg = compose("visualize", overrides=["dataset=yaak", f"+data_dir={DATA_DIR}"]) + cfg = compose( + "visualize", + overrides=[ + "dataset=yaak", + "logger=rerun/yaak", + f"+data_dir={DATA_DIR}/yaak", + ], + ) dataloader = instantiate(cfg.dataloader) @@ -184,34 +200,92 @@ def test_yaak() -> None: batch = next(iter(dataloader)) match batch.to_dict(): case { - "frame": { + "data": { "cam_front_left": Tensor(shape=[c.B, c.S, *_]), "cam_left_backward": Tensor(shape=[c.B, c.S, *_]), "cam_right_backward": Tensor(shape=[c.B, c.S, *_]), - **frame_rest, + "meta/ImageMetadata.cam_front_left/frame_idx": Tensor(shape=[c.B, c.S]), + "meta/ImageMetadata.cam_front_left/time_stamp": Tensor( + shape=[c.B, c.S] + ), + "meta/ImageMetadata.cam_left_backward/frame_idx": Tensor( + shape=[c.B, c.S] + ), + "meta/ImageMetadata.cam_right_backward/frame_idx": Tensor( + shape=[c.B, c.S] + ), + "meta/VehicleMotion/gear": Tensor(shape=[c.B, c.S]), + "meta/VehicleMotion/speed": Tensor(shape=[c.B, c.S]), + "mcap//ai/safety_score/clip.end_timestamp": Tensor(shape=[c.B, c.S]), + "mcap//ai/safety_score/score": Tensor(shape=[c.B, c.S]), + **data_rest, }, - "table": { - "ImageMetadata.cam_front_left.frame_idx": Tensor(shape=[c.B, c.S]), - "ImageMetadata.cam_front_left.time_stamp": Tensor(shape=[c.B, c.S]), - "ImageMetadata.cam_left_backward.frame_idx": Tensor(shape=[c.B, c.S]), - 
"ImageMetadata.cam_left_backward.time_stamp": Tensor(shape=[c.B, c.S]), - "ImageMetadata.cam_right_backward.frame_idx": Tensor(shape=[c.B, c.S]), - "ImageMetadata.cam_right_backward.time_stamp": Tensor(shape=[c.B, c.S]), - "VehicleMotion.gear": Tensor(shape=[c.B, c.S]), - "VehicleMotion.speed": Tensor(shape=[c.B, c.S]), - "VehicleMotion.time_stamp": Tensor(shape=[c.B, c.S]), - "/ai/safety_score.clip.end_timestamp": Tensor(shape=[c.B, c.S]), - "/ai/safety_score.score": Tensor(shape=[c.B, c.S]), - **table_rest, + "meta": { + "input_id": input_id, + "sample_idx": Tensor(shape=[c.B]), + **meta_rest, + }, + **batch_rest, + } if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any(( + batch_rest, + data_rest, + meta_rest, + )): + pass + + case _: + logger.error(msg := "invalid batch structure", batch=batch) + + raise AssertionError(msg) + + batch_logger = instantiate(cfg.logger, spawn=False) + batch_logger.log(0, batch) + + +def test_zod() -> None: + with initialize(version_base=None, config_path=CONFIG_PATH): + cfg = compose( + "visualize", + overrides=["dataset=zod", "logger=rerun/zod", f"+data_dir={DATA_DIR}"], + ) + + dataloader = instantiate(cfg.dataloader) + + c = SimpleNamespace( + B=cfg.dataloader.batch_size, S=cfg.dataloader.dataset.sample_builder.length + ) + + batch = next(iter(dataloader)) + match batch.to_dict(): + case { + "data": { + "camera_front_blur": Tensor(shape=[c.B, c.S, *_]), + "camera_front_blur/timestamp": Tensor(shape=[c.B, c.S, *_]), + "lidar_velodyne": Tensor(shape=[c.B, c.S, *_]), + "lidar_velodyne/timestamp": Tensor(shape=[c.B, c.S, *_]), + "vehicle_data/ego_vehicle_controls/acceleration_pedal/ratio/unitless/value": Tensor( # noqa: E501 + shape=[c.B, c.S, *_] + ), + "vehicle_data/ego_vehicle_controls/steering_wheel_angle/angle/radians/value": Tensor( # noqa: E501 + shape=[c.B, c.S, *_] + ), + "vehicle_data/ego_vehicle_controls/timestamp/nanoseconds/value": Tensor( + shape=[c.B, c.S, *_] + ), + "vehicle_data/satellite/speed/meters_per_second/value": Tensor( + shape=[c.B, c.S, *_] + ), + **data_rest, }, "meta": { "input_id": input_id, "sample_idx": Tensor(shape=[c.B]), **meta_rest, }, + **batch_rest, } if set(input_id).issubset(cfg.dataloader.dataset.inputs) and not any(( - frame_rest, - table_rest, + batch_rest, + data_rest, meta_rest, )): pass @@ -220,3 +294,6 @@ def test_yaak() -> None: logger.error(msg := "invalid batch structure", batch=batch) raise AssertionError(msg) + + batch_logger = instantiate(cfg.logger, spawn=False) + batch_logger.log(0, batch)