Skip to content

Commit

Permalink
support ray arrays in arrow dataset source
Browse files Browse the repository at this point in the history
  • Loading branch information
neindochoh committed Feb 23, 2024
1 parent 2c8619c commit 1a20fff
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 10 deletions.
81 changes: 80 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ pandas-stubs = "^2.0.2.230605"
ruff = "^0.2.1"
check-wheel-contents = "^0.6.0"
torch = { version = "^2.1.1+cpu", source = "torch-cpu" }
ray = {extras = ["data"], version = "^2.9.3"}

[tool.poetry.group.playbook.dependencies]
towhee = "^0.9.0"
Expand Down
55 changes: 46 additions & 9 deletions renumics/spotlight_plugins/core/arrow_dataset_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
import pyarrow.dataset
import pyarrow.types

try:
# let ray register its extension types if it is installed
import ray.data # noqa
import ray.air.util.tensor_extensions.arrow # noqa
except ModuleNotFoundError:
pass


import renumics.spotlight.dtypes as spotlight_dtypes
from renumics.spotlight.data_source import DataSource
from renumics.spotlight.data_source.data_source import ColumnMetadata
Expand All @@ -20,6 +28,12 @@ class UnknownArrowType(Exception):
"""


class UnknownArrowExtensionType(Exception):
"""
Raised during type conversion when an arrow Extension Type is encountered
that is not one of the known (explicitly handled) extensions.
"""


EMPTY_MAP: spotlight_dtypes.DTypeMap = {}


Expand All @@ -36,14 +50,16 @@ def __init__(self, source: pyarrow.dataset.Dataset):
self._intermediate_dtypes: spotlight_dtypes.DTypeMap = self._convert_schema()

self._semantic_dtypes = {}
# support hf metadata (only images for now)
if hf_metadata := orjson.loads(
source.schema.metadata.get(b"huggingface", "null")
):
features = hf_metadata.get("info", {}).get("features", {})
for name, feat in features.items():
if feat.get("_type") == "Image":
self._semantic_dtypes[name] = spotlight_dtypes.image_dtype

if source.schema.metadata:
# support hf metadata (only images for now)
if hf_metadata := orjson.loads(
source.schema.metadata.get(b"huggingface", "null")
):
features = hf_metadata.get("info", {}).get("features", {})
for name, feat in features.items():
if feat.get("_type") == "Image":
self._semantic_dtypes[name] = spotlight_dtypes.image_dtype

@property
def column_names(self) -> List[str]:
Expand Down Expand Up @@ -74,6 +90,12 @@ def get_column_values(
column_name: str,
indices: Union[List[int], np.ndarray, slice] = slice(None),
) -> np.ndarray:
try:
# Import these arrow extension types to ensure that they are registered.
import ray.air.util.tensor_extensions.arrow # noqa
except ModuleNotFoundError:
pass

if indices == slice(None):
table = self._dataset.to_table(columns=[column_name])
else:
Expand All @@ -90,6 +112,14 @@ def get_column_values(

raw_values = table[column_name]

dtype = self._intermediate_dtypes.get(column_name)
if isinstance(dtype, spotlight_dtypes.ArrayDType):
if dtype.shape is not None:
shape = [-1 if x is None else x for x in dtype.shape]
return np.array([np.array(arr).reshape(shape) for arr in raw_values])
else:
return raw_values.to_numpy()

# convert hf image values
if self._semantic_dtypes.get(column_name) == spotlight_dtypes.image_dtype:
return np.array(
Expand All @@ -115,6 +145,7 @@ def _convert_schema(self) -> spotlight_dtypes.DTypeMap:
schema: spotlight_dtypes.DTypeMap = {}
for field in self._dataset.schema:
schema[field.name] = _convert_dtype(field)
print(schema)
return schema


Expand Down Expand Up @@ -189,5 +220,11 @@ def _convert_dtype(field: pa.Field) -> spotlight_dtypes.DType:
return spotlight_dtypes.SequenceDType(
_convert_dtype(pa.field("", field.type.value_type))
)
if isinstance(field.type, pa.ExtensionType):
# handle known extensions
if field.type.extension_name == "ray.data.arrow_tensor":
return spotlight_dtypes.ArrayDType(shape=field.type.shape)

raise UnknownArrowExtensionType(field.type.extension_name)

raise UnknownArrowType()
raise UnknownArrowType(field.type)

0 comments on commit 1a20fff

Please sign in to comment.