-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Data] Add support for objects to Arrow blocks (#45272)
<!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? Currently, Ray does not support blocks/batches with objects and multi-dimensional arrays in different columns. This causes Ray Data to throw exceptions when these are provided because: 1. Since there's an arbitrary object in the batch, the Arrow block format fails with ArrowNotImplemented with dtype 17. This falls back to `return pd.DataFrame(dict(batch))` in `BlockAccessor.batch_to_block`. 2. However, this particular DataFrame constructor does not support columns with numpy.ndarray objects, so it throws the exception listed in the linked issue. This change enables Python object storage in the Arrow blocks by defining an Arrow extension type that simply represents the Python objects as a variable-sized large binary. I suppose the alleged performance benefits listed in the comments are an extra benefit. I'm not sure that this is the correct approach or that I've properly patched all of the places, so some help would be appreciated! ## Related issue number Resolves #45235 ## Checks - [X] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [X] I've run `scripts/format.sh` to lint the changes in this PR. - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [ ] I've made sure the tests are passing. 
Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [X] Unit tests - [ ] Release tests - [ ] This PR is not tested :( --------- Signed-off-by: Peter Wang <[email protected]> Signed-off-by: Hao Chen <[email protected]> Co-authored-by: Peter Wang <[email protected]> Co-authored-by: Hao Chen <[email protected]>
- Loading branch information
1 parent
4bcd381
commit ea452c9
Showing
17 changed files
with
669 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import types | ||
|
||
import numpy as np | ||
import pyarrow as pa | ||
import pytest | ||
|
||
from ray.air.util.object_extensions.arrow import ( | ||
ArrowPythonObjectArray, | ||
ArrowPythonObjectType, | ||
object_extension_type_allowed, | ||
) | ||
from ray.air.util.object_extensions.pandas import PythonObjectArray | ||
|
||
|
||
@pytest.mark.skipif(
    not object_extension_type_allowed(), reason="Object extension not supported."
)
def test_object_array_validation():
    """Check which constructor inputs ``PythonObjectArray`` accepts."""
    # A bare, non-iterable object must be rejected with a TypeError.
    not_iterable = object()
    with pytest.raises(TypeError):
        PythonObjectArray(not_iterable)

    # A NumPy array of objects and a plain list of objects are both valid.
    from_ndarray = np.array([object(), object()])
    PythonObjectArray(from_ndarray)
    from_list = [object(), object()]
    PythonObjectArray(from_list)
|
||
|
||
@pytest.mark.skipif(
    not object_extension_type_allowed(), reason="Object extension not supported."
)
def test_arrow_scalar_object_array_roundtrip():
    """Round-trip heterogeneous Python objects through the Arrow object array."""
    originals = np.array(
        ["test", 20, False, {"some": "value"}, None, np.zeros((10, 10))], dtype=object
    )
    arrow_array = ArrowPythonObjectArray.from_objects(originals)

    # Construction must yield the extension type/array and keep the length.
    assert isinstance(arrow_array.type, ArrowPythonObjectType)
    assert isinstance(arrow_array, ArrowPythonObjectArray)
    assert len(arrow_array) == len(originals)

    restored = arrow_array.to_numpy()
    # The trailing element is itself an ndarray, so it is compared on its own
    # with an element-wise check instead of through assert_array_equal.
    np.testing.assert_array_equal(restored[:-1], originals[:-1])
    last_restored, last_original = restored[-1], originals[-1]
    assert np.all(last_restored == last_original)
|
||
|
||
@pytest.mark.skipif(
    not object_extension_type_allowed(), reason="Object extension not supported."
)
def test_arrow_python_object_array_slice():
    """Slices of the Arrow object array convert back to the right objects."""
    objects = np.array(["test", 20, "test2", 40, "test3", 60], dtype=object)
    arrow_array = ArrowPythonObjectArray.from_objects(objects)

    # Slice -> pandas conversion.
    pandas_slice = arrow_array[1:3].to_pandas()
    assert list(pandas_slice) == [20, "test2"]

    # Slice -> plain Python list conversion.
    pylist_slice = arrow_array[2:4].to_pylist()
    assert pylist_slice == ["test2", 40]
|
||
|
||
@pytest.mark.skipif(
    not object_extension_type_allowed(), reason="Object extension not supported."
)
def test_arrow_pandas_roundtrip():
    """An Arrow table with an object column survives Arrow -> pandas -> Arrow."""
    obj = types.SimpleNamespace(a=1, b="test")
    columns = {"a": ArrowPythonObjectArray.from_objects([obj, obj]), "b": [0, 1]}
    original = pa.table(columns)

    frame = original.to_pandas()
    restored = pa.Table.from_pandas(frame)

    assert original.equals(restored)
|
||
|
||
if __name__ == "__main__":
    import sys

    # Run this file's tests verbosely, stopping at the first failure, and use
    # pytest's return value as the process exit status.
    exit_code = pytest.main(["-v", "-x", __file__])
    sys.exit(exit_code)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import pickle | ||
import typing | ||
|
||
import numpy as np | ||
import pyarrow as pa | ||
from packaging.version import parse as parse_version | ||
|
||
import ray.air.util.object_extensions.pandas | ||
from ray._private.serialization import pickle_dumps | ||
from ray._private.utils import _get_pyarrow_version | ||
from ray.util.annotations import PublicAPI | ||
|
||
# Lowest PyArrow version for which object_extension_type_allowed() returns
# True. Judging by the name, this is the first release supporting extension
# scalar subclasses (assumption - confirm against the PyArrow changelog).
MIN_PYARROW_VERSION_SCALAR_SUBCLASS = parse_version("9.0.0")

# Parsed PyArrow version, or None when the version helper returns None.
_VER = _get_pyarrow_version()
PYARROW_VERSION = parse_version(_VER) if _VER is not None else None
|
||
|
||
def object_extension_type_allowed() -> bool:
    """Return whether the detected PyArrow version meets the required minimum.

    Returns:
        ``False`` when no PyArrow version was detected; otherwise whether
        ``PYARROW_VERSION`` is at least ``MIN_PYARROW_VERSION_SCALAR_SUBCLASS``.
    """
    if PYARROW_VERSION is None:
        return False
    return PYARROW_VERSION >= MIN_PYARROW_VERSION_SCALAR_SUBCLASS
|
||
|
||
# Please see https://arrow.apache.org/docs/python/extending_types.html for more info
@PublicAPI(stability="alpha")
class ArrowPythonObjectType(pa.ExtensionType):
    """Arrow extension type used to hold Python objects.

    The type carries no parameters, which is why its constructor takes no
    arguments.
    """

    def __init__(self) -> None:
        # The underlying storage is PyArrow's LargeBinary type.
        storage_type = pa.large_binary()
        super().__init__(storage_type, "ray.data.arrow_pickled_object")

    def __arrow_ext_serialize__(self) -> bytes:
        # There are no type parameters to record, so the payload is empty.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(
        cls, storage_type: pa.DataType, serialized: bytes
    ) -> "ArrowPythonObjectType":
        # Both arguments are unused: a parameterless type is simply rebuilt.
        return ArrowPythonObjectType()

    def __arrow_ext_scalar_class__(self) -> type:
        """Return the scalar class of this extension type.

        Indexing into the PyArrow extension array yields instances of the
        returned class.
        """
        return ArrowPythonObjectScalar

    def __arrow_ext_class__(self) -> type:
        """Return the array class of this extension type.

        Picking one array out of the ChunkedArray that backs a Table column
        of this type yields an instance of the returned class.
        """
        return ArrowPythonObjectArray

    def to_pandas_dtype(self):
        """Return the Pandas dtype that is the counterpart of this Arrow type.

        See https://pandas.pydata.org/docs/development/extending.html
        for more information.
        """
        return ray.air.util.object_extensions.pandas.PythonObjectDtype()

    def __reduce__(self):
        # Earlier PyArrow versions require custom pickling behavior.
        rebuild_args = (self.storage_type, self.__arrow_ext_serialize__())
        return self.__arrow_ext_deserialize__, rebuild_args
|
||
|
||
@PublicAPI(stability="alpha")
class ArrowPythonObjectScalar(pa.ExtensionScalar):
    """Scalar class for ArrowPythonObjectType."""

    def as_py(self, **kwargs) -> typing.Any:
        """Unpickle and return the Python object stored in this scalar.

        Args:
            **kwargs: Accepted and ignored. PyArrow releases whose
                ``as_py``/``to_pylist`` take conversion options (for example
                ``maps_as_pydicts``) may forward them to this override; those
                options have no meaning for a pickled object.

        Returns:
            The object obtained by unpickling the storage bytes.

        Raises:
            RuntimeError: If the storage value is not a ``LargeBinaryScalar``.
        """
        storage = self.value
        if not isinstance(storage, pa.LargeBinaryScalar):
            raise RuntimeError(
                f"{type(storage)} is not the expected LargeBinaryScalar"
            )
        # NOTE: unpickling can execute arbitrary code. Only data produced by a
        # trusted writer should ever be converted through this scalar.
        return pickle.load(pa.BufferReader(storage.as_buffer()))
|
||
|
||
@PublicAPI(stability="alpha")
class ArrowPythonObjectArray(pa.ExtensionArray):
    """Array class for ArrowPythonObjectType."""

    # Declared static: the function takes no ``self``, so without the decorator
    # it only worked when called on the class, never on an instance.
    @staticmethod
    def from_objects(
        objects: typing.Union[np.ndarray, typing.Iterable[typing.Any]]
    ) -> "ArrowPythonObjectArray":
        """Build an extension array by pickling each of ``objects``.

        Args:
            objects: The objects to store. A NumPy array is first converted
                with ``ndarray.tolist()``.

        Returns:
            An ``ArrowPythonObjectArray`` with one pickled entry per object.
        """
        if isinstance(objects, np.ndarray):
            objects = objects.tolist()
        type_ = ArrowPythonObjectType()
        pickled = [
            pickle_dumps(obj, "Error pickling object to convert to Arrow")
            for obj in objects
        ]
        storage = pa.array(pickled, type=type_.storage_type)
        return ArrowPythonObjectArray.from_storage(type_, storage)

    def to_numpy(self, zero_copy_only: bool = False) -> np.ndarray:
        """Return the stored objects as a 1-D NumPy array of dtype ``object``.

        Args:
            zero_copy_only: Unused by this implementation; the objects are
                always materialised through ``to_pylist()``.
        """
        objects = self.to_pylist()
        arr = np.empty(len(objects), dtype=object)
        # Fill element by element. ``arr[:] = objects`` would let NumPy try to
        # broadcast equal-shaped nested arrays/lists and raise ValueError.
        for i, obj in enumerate(objects):
            arr[i] = obj
        return arr
|
||
|
||
# Register the extension type with PyArrow when this module is imported.
try:
    pa.register_extension_type(ArrowPythonObjectType())
except pa.ArrowKeyError:
    # The type was registered before (already registered); nothing to do.
    pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import collections.abc | ||
import typing | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pyarrow as pa | ||
from pandas._libs import lib | ||
from pandas._typing import ArrayLike, Dtype, PositionalIndexer, npt | ||
|
||
import ray.air.util.object_extensions.arrow | ||
from ray.util.annotations import PublicAPI | ||
|
||
|
||
# See https://pandas.pydata.org/docs/development/extending.html for more information.
@PublicAPI(stability="alpha")
class PythonObjectArray(pd.api.extensions.ExtensionArray):
    """Implements the Pandas extension array interface for the Arrow object array.

    The elements are kept in ``self.values``, a one-dimensional NumPy array of
    dtype ``object``.
    """

    def __init__(self, values: collections.abc.Iterable[typing.Any]):
        """Store ``values`` in a 1-D object array.

        Args:
            values: Iterable of arbitrary Python objects.

        Raises:
            TypeError: If ``values`` is not iterable.
        """
        vals = list(values)
        self.values = np.empty(len(vals), dtype=object)
        # Fill element by element. ``self.values[:] = vals`` would let NumPy
        # try to broadcast equal-shaped nested arrays/lists and raise
        # ValueError instead of storing each one as a single element.
        for i, val in enumerate(vals):
            self.values[i] = val

    @classmethod
    def _from_sequence(
        cls,
        scalars: collections.abc.Sequence[typing.Any],
        *,
        dtype: typing.Union[Dtype, None] = None,
        copy: bool = False,
    ) -> "PythonObjectArray":
        """Build an array from ``scalars``; ``dtype`` and ``copy`` are unused."""
        return cls(scalars)

    @classmethod
    def _from_factorized(
        cls, values: collections.abc.Sequence[typing.Any], original: "PythonObjectArray"
    ) -> "PythonObjectArray":
        """Rebuild an array from factorized ``values``; ``original`` is unused."""
        return cls(values)

    def __getitem__(self, item: PositionalIndexer) -> typing.Any:
        # NOTE(review): non-scalar indexers return a plain ndarray here, while
        # the pandas interface asks for ``type(self)`` in that case. Left as-is
        # to keep the existing return type for current callers.
        return self.values[item]

    def __setitem__(self, key, value) -> None:
        self.values[key] = value

    def __len__(self) -> int:
        return len(self.values)

    def __eq__(self, other: object) -> ArrayLike:
        """Compare element-wise against another object array or an ndarray."""
        if isinstance(other, PythonObjectArray):
            return self.values == other.values
        elif isinstance(other, np.ndarray):
            return self.values == other
        else:
            return NotImplemented

    def isna(self) -> np.ndarray:
        """Return a boolean mask marking the missing (NA) elements.

        Required by the pandas ExtensionArray interface and by ``to_numpy``
        when an ``na_value`` is supplied.
        """
        return pd.isna(self.values)

    def take(
        self,
        indices: typing.Sequence[int],
        *,
        allow_fill: bool = False,
        fill_value: typing.Any = None,
    ) -> "PythonObjectArray":
        """Return a new array with the elements at ``indices``.

        Args:
            indices: Positions to take.
            allow_fill: When True, ``-1`` in ``indices`` marks a missing
                element that is filled with ``fill_value``.
            fill_value: Fill for missing elements; defaults to the dtype's
                ``na_value`` when ``allow_fill`` is True.
        """
        if allow_fill and fill_value is None:
            fill_value = self.dtype.na_value
        taken = pd.api.extensions.take(
            self.values, indices, allow_fill=allow_fill, fill_value=fill_value
        )
        return type(self)(taken)

    def copy(self) -> "PythonObjectArray":
        """Return a shallow copy: a new array holding the same objects."""
        return type(self)(self.values)

    @classmethod
    def _concat_same_type(
        cls, to_concat: collections.abc.Sequence["PythonObjectArray"]
    ) -> "PythonObjectArray":
        """Concatenate several ``PythonObjectArray`` instances into one."""
        return cls(np.concatenate([arr.values for arr in to_concat]))

    def to_numpy(
        self,
        dtype: typing.Union["npt.DTypeLike", None] = None,
        copy: bool = False,
        na_value: object = lib.no_default,
    ) -> np.ndarray:
        """Return the elements as an object ndarray.

        Args:
            dtype: Unused by this implementation; the result keeps dtype
                ``object``.
            copy: When True, return a copy rather than the backing array.
            na_value: When given, missing elements in the (copied) result are
                replaced by this value.
        """
        result = self.values
        # Replacing NA values must not mutate the backing array, so a copy is
        # also taken whenever ``na_value`` is supplied.
        if copy or na_value is not lib.no_default:
            result = result.copy()
        if na_value is not lib.no_default:
            result[self.isna()] = na_value
        return result

    @property
    def dtype(self) -> pd.api.extensions.ExtensionDtype:
        return PythonObjectDtype()

    @property
    def nbytes(self) -> int:
        # Size of the backing object array as reported by NumPy.
        return self.values.nbytes

    def __arrow_array__(self, type=None):
        """Convert to the Arrow object extension array; ``type`` is unused."""
        return ray.air.util.object_extensions.arrow.ArrowPythonObjectArray.from_objects(
            self.values
        )
|
||
|
||
@PublicAPI(stability="alpha")
@pd.api.extensions.register_extension_dtype
class PythonObjectDtype(pd.api.extensions.ExtensionDtype):
    """Pandas extension dtype whose array type is ``PythonObjectArray``."""

    @classmethod
    def construct_from_string(cls, string: str):
        """Return an instance when ``string`` is this dtype's name.

        Raises:
            TypeError: If ``string`` is anything other than the dtype name.
        """
        if string == "python_object()":
            return cls()
        raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")

    @property
    def type(self):
        """
        Scalar type of the array's elements.

        ``ExtensionArray[item]`` is expected to return an instance of
        ``ExtensionDtype.type`` for a scalar ``item`` whose value is valid
        (not NA); NA values need not be instances of it.
        """
        return object

    @property
    def name(self) -> str:
        # Same string that construct_from_string accepts.
        return "python_object()"

    @classmethod
    def construct_array_type(cls: type) -> type:
        """
        Return the array class that goes with this dtype.
        """
        return PythonObjectArray

    def __from_arrow__(
        self, array: typing.Union[pa.Array, pa.ChunkedArray]
    ) -> PythonObjectArray:
        # Materialise the Arrow data as Python objects, then wrap them.
        objects = array.to_pylist()
        return PythonObjectArray(objects)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.