From 1a20fff83e9dab1b8db3cf1f695e1f877bddb8c0 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Fri, 23 Feb 2024 13:53:21 +0100 Subject: [PATCH] support ray arrays in arrow dataset source --- poetry.lock | 81 ++++++++++++++++++- pyproject.toml | 1 + .../core/arrow_dataset_source.py | 55 ++++++++++--- 3 files changed, 127 insertions(+), 10 deletions(-) diff --git a/poetry.lock b/poetry.lock index cd71b97c..e53829ff 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4031,6 +4031,26 @@ files = [ [package.dependencies] wcwidth = "*" +[[package]] +name = "protobuf" +version = "4.25.3" +description = "" +optional = false +python-versions = ">=3.8" +files = [ + {file = "protobuf-4.25.3-cp310-abi3-win32.whl", hash = "sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa"}, + {file = "protobuf-4.25.3-cp310-abi3-win_amd64.whl", hash = "sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8"}, + {file = "protobuf-4.25.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f1279ab38ecbfae7e456a108c5c0681e4956d5b1090027c1de0f934dfdb4b35c"}, + {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:e7cb0ae90dd83727f0c0718634ed56837bfeeee29a5f82a7514c03ee1364c019"}, + {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:7c8daa26095f82482307bc717364e7c13f4f1c99659be82890dcfc215194554d"}, + {file = "protobuf-4.25.3-cp38-cp38-win32.whl", hash = "sha256:f4f118245c4a087776e0a8408be33cf09f6c547442c00395fbfb116fac2f8ac2"}, + {file = "protobuf-4.25.3-cp38-cp38-win_amd64.whl", hash = "sha256:c053062984e61144385022e53678fbded7aea14ebb3e0305ae3592fb219ccfa4"}, + {file = "protobuf-4.25.3-cp39-cp39-win32.whl", hash = "sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4"}, + {file = "protobuf-4.25.3-cp39-cp39-win_amd64.whl", hash = "sha256:e3c97a1555fd6388f857770ff8b9703083de6bf1f9274a002a332d65fbb56c8c"}, + {file = "protobuf-4.25.3-py3-none-any.whl", hash = "sha256:f0700d54bcf45424477e46a9f0944155b46fb0639d69728739c0e47bab83f2b9"}, + {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"}, +] + [[package]] name = "psutil" version = "5.9.8" @@ -4795,6 +4815,65 @@ packaging = "*" [package.extras] test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] +[[package]] +name = "ray" +version = "2.9.3" +description = "Ray provides a simple, universal API for building distributed applications." +optional = false +python-versions = ">=3.8" +files = [ + {file = "ray-2.9.3-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:09b4d3f3cacc66f256695a5f72960111815cee3986bdcf7a9c3f6f0fac144100"}, + {file = "ray-2.9.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:287eed74fa536651aa799c4295e1b27eee1650f29236fa94487985b76bffff35"}, + {file = "ray-2.9.3-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d5334fb43468f56a52ebd8fb30f39bbc6d2a6a16ecf3d9f78be59952aa533b6a"}, + {file = "ray-2.9.3-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:c54e35eb78816c722a58f31d75f5ec82834433fa639ecf70daee0d7b182598ca"}, + {file = "ray-2.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:266f890ea8bb6ce417a4890ae495082eece45ac1c1ad0db92a5f6fb52792a3bc"}, + {file = "ray-2.9.3-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:8e72b92122c612f54224ffb33ef34f437aec59f370382882f4519b6fd55bb349"}, + {file = "ray-2.9.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:615a5b8d17a69713178cdb2184c4f6d11c5d3a1a5a358bd3617f9404d782323e"}, + {file = "ray-2.9.3-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b493412cf3f38861f517664312da40d622baa7deb8b5a9811ca1b1fb60bd444a"}, + {file = "ray-2.9.3-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:747343a1115f7b851da287e0e2b1cd3c703c843c9dd1f522c1e47bfc76e14c9e"}, + {file = "ray-2.9.3-cp311-cp311-win_amd64.whl", hash = "sha256:606dded40b17350b2d29b1fc0cb7be7085a8f39c9576a63e450d86fc5670f01a"}, + {file = "ray-2.9.3-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:d3219d341b4f32ff9cb747783615fbdabe45a202d6e50f9a8510470d117ba40e"}, + {file = "ray-2.9.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:fb4bb8db188155671125facc8ed89d1d70314959c66f2bf8dba6f087ab3024e2"}, + {file = "ray-2.9.3-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:cc064f1760775600a2edd281fcbe70f2b84ec09c9b6fd3f0cf21cbe6e0e34269"}, + {file = "ray-2.9.3-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:41f3b8d6c8ff57875dbf8b2b1c9bb8bbd7c6fc0b6c2048772ddd704f53eec653"}, + {file = "ray-2.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:06fedfd0bfb875cd504870a9960a244f41d202a61388edd23b7a8513bb007de2"}, + {file = "ray-2.9.3-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:0b892cdbc7bdd3cebb5ee71811c468b922b3c99e65aeb890a522af36f1933350"}, + {file = "ray-2.9.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f597662dafd3c5b91b41f892acb1ef12e69732ced845f40060c3455192e1bd29"}, + {file = "ray-2.9.3-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:febae4acb05b132f9c49cd3b2a9dd8bfaa1cb8a52ef75f734659469956efe9f1"}, + {file = "ray-2.9.3-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:859e7be3cfcc1eb52762aa0065a3c7c57002e67e23f2858b40cf5f3081e13391"}, + {file = "ray-2.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:2befd5f928c896357170bf46ac1ab197509561dce1cc733db9b235e02039dfe7"}, +] + +[package.dependencies] +aiosignal = "*" +click = ">=7.0" +filelock = "*" +frozenlist = "*" +fsspec = {version = "*", optional = true, markers = "extra == \"data\""} +jsonschema = "*" +msgpack = ">=1.0.0,<2.0.0" +numpy = {version = ">=1.20", optional = true, markers = "extra == \"data\""} +packaging = "*" +pandas = {version = ">=1.3", optional = true, markers = "extra == \"data\""} +protobuf = ">=3.15.3,<3.19.5 || >3.19.5" +pyarrow = {version = ">=6.0.1", optional = true, markers = "extra == \"data\""} +pyyaml = "*" +requests = "*" + +[package.extras] +air = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi (<=0.108.0)", "fsspec", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "numpy (>=1.20)", "opencensus", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart-open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] +all = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "dm-tree", "fastapi (<=0.108.0)", "fsspec", "gpustat (>=1.0.0)", "grpcio (!=1.56.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "gymnasium (==0.28.1)", "lz4", "numpy (>=1.20)", "opencensus", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk", "pandas", "pandas (>=1.3)", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "pyyaml", "ray-cpp (==2.9.3)", "requests", "rich", "scikit-image", "scipy", "smart-open", "starlette", "tensorboardX (>=1.9)", "typer", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] +client = ["grpcio (!=1.56.0)"] +cpp = ["ray-cpp (==2.9.3)"] +data = ["fsspec", "numpy (>=1.20)", "pandas (>=1.3)", "pyarrow (>=6.0.1)"] +default = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart-open", "virtualenv (>=20.0.24,!=20.21.1)"] +observability = ["opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk"] +rllib = ["dm-tree", "fsspec", "gymnasium (==0.28.1)", "lz4", "pandas", "pyarrow (>=6.0.1)", "pyyaml", "requests", "rich", "scikit-image", "scipy", "tensorboardX (>=1.9)", "typer"] +serve = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi (<=0.108.0)", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] +serve-grpc = ["aiohttp (>=3.7)", "aiohttp-cors", "aiorwlock", "colorful", "fastapi (<=0.108.0)", "gpustat (>=1.0.0)", "grpcio (>=1.32.0)", "grpcio (>=1.42.0)", "opencensus", "prometheus-client (>=0.7.1)", "py-spy (>=0.2.0)", "pydantic (<2.0.dev0 || >=2.5.dev0,<3)", "requests", "smart-open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"] +train = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] +tune = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] + [[package]] name = "readme-renderer" version = "42.0" @@ -7184,4 +7263,4 @@ descriptors = ["pycatch22"] [metadata] lock-version = "2.0" python-versions = ">=3.8, <3.12" -content-hash = "26190274151ccc122f23e3a5612bd6e6f871e7039ab1c3b905dcad1b8a8bcb62" +content-hash = "bf4727d613077cbff3ceb8ac0a682fd6d1759b01c990f47a64d95c3a0460a61f" diff --git a/pyproject.toml b/pyproject.toml index e2a8fdae..3ad3ea9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,6 +127,7 @@ pandas-stubs = "^2.0.2.230605" ruff = "^0.2.1" check-wheel-contents = "^0.6.0" torch = { version = "^2.1.1+cpu", source = "torch-cpu" } +ray = {extras = ["data"], version = "^2.9.3"} [tool.poetry.group.playbook.dependencies] towhee = "^0.9.0" diff --git a/renumics/spotlight_plugins/core/arrow_dataset_source.py b/renumics/spotlight_plugins/core/arrow_dataset_source.py index 4ff9ecfe..3b2f79c5 100644 --- a/renumics/spotlight_plugins/core/arrow_dataset_source.py +++ b/renumics/spotlight_plugins/core/arrow_dataset_source.py @@ -8,6 +8,14 @@ import pyarrow.dataset import pyarrow.types +try: + # let ray register it's extension types if it is installed + import ray.data # noqa + import ray.air.util.tensor_extensions.arrow # noqa +except ModuleNotFoundError: + pass + + import renumics.spotlight.dtypes as spotlight_dtypes from renumics.spotlight.data_source import DataSource from renumics.spotlight.data_source.data_source import ColumnMetadata @@ -20,6 +28,12 @@ class UnknownArrowType(Exception): """ +class UnknownArrowExtensionType(Exception): + """ + We encountered an unknown arrow Extension Type during type conversion + """ + + EMPTY_MAP: spotlight_dtypes.DTypeMap = {} @@ -36,14 +50,16 @@ def __init__(self, source: pyarrow.dataset.Dataset): self._intermediate_dtypes: spotlight_dtypes.DTypeMap = self._convert_schema() self._semantic_dtypes = {} - # support hf metadata (only images for now) - if hf_metadata := orjson.loads( - source.schema.metadata.get(b"huggingface", "null") - ): - features = hf_metadata.get("info", {}).get("features", {}) - for name, feat in features.items(): - if feat.get("_type") == "Image": - self._semantic_dtypes[name] = spotlight_dtypes.image_dtype + + if source.schema.metadata: + # support hf metadata (only images for now) + if hf_metadata := orjson.loads( + source.schema.metadata.get(b"huggingface", "null") + ): + features = hf_metadata.get("info", {}).get("features", {}) + for name, feat in features.items(): + if feat.get("_type") == "Image": + self._semantic_dtypes[name] = spotlight_dtypes.image_dtype @property def column_names(self) -> List[str]: @@ -74,6 +90,12 @@ def get_column_values( column_name: str, indices: Union[List[int], np.ndarray, slice] = slice(None), ) -> np.ndarray: + try: + # Import these arrow extension types to ensure that they are registered. + import ray.air.util.tensor_extensions.arrow # noqa + except ModuleNotFoundError: + pass + if indices == slice(None): table = self._dataset.to_table(columns=[column_name]) else: @@ -90,6 +112,14 @@ def get_column_values( raw_values = table[column_name] + dtype = self._intermediate_dtypes.get(column_name) + if isinstance(dtype, spotlight_dtypes.ArrayDType): + if dtype.shape is not None: + shape = [-1 if x is None else x for x in dtype.shape] + return np.array([np.array(arr).reshape(shape) for arr in raw_values]) + else: + return raw_values.to_numpy() + # convert hf image values if self._semantic_dtypes.get(column_name) == spotlight_dtypes.image_dtype: return np.array( @@ -115,6 +145,7 @@ def _convert_schema(self) -> spotlight_dtypes.DTypeMap: schema: spotlight_dtypes.DTypeMap = {} for field in self._dataset.schema: schema[field.name] = _convert_dtype(field) + print(schema) return schema @@ -189,5 +220,11 @@ def _convert_dtype(field: pa.Field) -> spotlight_dtypes.DType: return spotlight_dtypes.SequenceDType( _convert_dtype(pa.field("", field.type.value_type)) ) + if isinstance(field.type, pa.ExtensionType): + # handle known extensions + if field.type.extension_name == "ray.data.arrow_tensor": + return spotlight_dtypes.ArrayDType(shape=field.type.shape) + + raise UnknownArrowExtensionType(field.type.extension_name) - raise UnknownArrowType() + raise UnknownArrowType(field.type)