Skip to content

Commit

Permalink
Make Metadata abstract base class generic
Browse files Browse the repository at this point in the history
Redefine `Metadata` class to be generic and refactor downstream classes
  • Loading branch information
ahmadjiha committed Oct 11, 2024
1 parent 44a7a74 commit e56547c
Show file tree
Hide file tree
Showing 19 changed files with 149 additions and 126 deletions.
18 changes: 9 additions & 9 deletions src/zarr/abc/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Generic, NotRequired, TypedDict, TypeVar

from zarr.abc.metadata import Metadata
from zarr.abc.metadata import Metadata, T
from zarr.core.buffer import Buffer, NDBuffer
from zarr.core.common import ChunkCoords, concurrent_map
from zarr.core.config import config
Expand Down Expand Up @@ -35,7 +35,7 @@
CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer)


class BaseCodec(Metadata, Generic[CodecInput, CodecOutput]):
class BaseCodec(Generic[CodecInput, CodecOutput, T], Metadata[T]):
"""Generic base class for codecs.
Codecs can be registered via zarr.codecs.registry.
Expand Down Expand Up @@ -153,25 +153,25 @@ async def encode(
return await _batching_helper(self._encode_single, chunks_and_specs)


class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]):
class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer, T]):
"""Base class for array-to-array codecs."""

...


class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]):
class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer, T]):
"""Base class for array-to-bytes codecs."""

...


class BytesBytesCodec(BaseCodec[Buffer, Buffer]):
class BytesBytesCodec(BaseCodec[Buffer, Buffer, T]):
"""Base class for bytes-to-bytes codecs."""

...


Codec = ArrayArrayCodec | ArrayBytesCodec | BytesBytesCodec
Codec = ArrayArrayCodec[Any] | ArrayBytesCodec[Any] | BytesBytesCodec[Any]


class CodecConfigDict(TypedDict):
Expand All @@ -180,14 +180,14 @@ class CodecConfigDict(TypedDict):
...


T = TypeVar("T", bound=CodecConfigDict)
CodecConfigDictType = TypeVar("CodecConfigDictType", bound=CodecConfigDict)


class CodecDict(TypedDict, Generic[T]):
class CodecDict(Generic[CodecConfigDictType], TypedDict):
"""A generic dictionary representing a codec."""

name: str
configuration: NotRequired[T]
configuration: NotRequired[CodecConfigDictType]


class ArrayBytesCodecPartialDecodeMixin:
Expand Down
15 changes: 8 additions & 7 deletions src/zarr/abc/metadata.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Generic, TypeVar, cast

if TYPE_CHECKING:
from typing import Self

from zarr.core.common import JSON

from dataclasses import dataclass, fields

__all__ = ["Metadata"]

T = TypeVar("T", bound=Mapping[str, object])


@dataclass(frozen=True)
class Metadata:
def to_dict(self) -> dict[str, JSON]:
class Metadata(Generic[T]):
def to_dict(self) -> T:
"""
Recursively serialize this model to a dictionary.
This method inspects the fields of self and calls `x.to_dict()` for any fields that
Expand All @@ -35,10 +36,10 @@ def to_dict(self) -> dict[str, JSON]:
else:
out_dict[key] = value

return out_dict
return cast(T, out_dict)

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
def from_dict(cls, data: T) -> Self:
"""
Create an instance of the model from a dictionary
"""
Expand Down
6 changes: 3 additions & 3 deletions src/zarr/codecs/_v2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numcodecs
from numcodecs.compat import ensure_bytes, ensure_ndarray
Expand All @@ -18,7 +18,7 @@


@dataclass(frozen=True)
class V2Compressor(ArrayBytesCodec):
class V2Compressor(ArrayBytesCodec[Any]):
compressor: numcodecs.abc.Codec | None

is_fixed_size = False
Expand Down Expand Up @@ -66,7 +66,7 @@ def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec)


@dataclass(frozen=True)
class V2Filters(ArrayArrayCodec):
class V2Filters(ArrayArrayCodec[Any]):
filters: tuple[numcodecs.abc.Codec, ...] | None

is_fixed_size = False
Expand Down
4 changes: 2 additions & 2 deletions src/zarr/codecs/blosc.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def parse_blocksize(data: JSON) -> int:


@dataclass(frozen=True)
class BloscCodec(BytesBytesCodec):
class BloscCodec(BytesBytesCodec[BloscCodecDict]):
is_fixed_size = False

typesize: int | None
Expand Down Expand Up @@ -130,7 +130,7 @@ def __init__(
object.__setattr__(self, "blocksize", blocksize_parsed)

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
def from_dict(cls, data: BloscCodecDict) -> Self:
_, configuration_parsed = parse_named_configuration(data, "blosc")
return cls(**configuration_parsed) # type: ignore[arg-type]

Expand Down
12 changes: 7 additions & 5 deletions src/zarr/codecs/bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

from zarr.abc.codec import ArrayBytesCodec, CodecConfigDict, CodecDict
from zarr.core.buffer import Buffer, NDArrayLike, NDBuffer
from zarr.core.common import JSON, parse_enum, parse_named_configuration
from zarr.core.common import parse_enum, parse_named_configuration
from zarr.registry import register_codec

if TYPE_CHECKING:
from typing import Self
from typing import Literal, Self

from zarr.core.array_spec import ArraySpec

Expand All @@ -33,7 +33,8 @@ class Endian(Enum):
class BytesCodecConfigDict(CodecConfigDict):
"""A dictionary representing a bytes codec configuration."""

endian: Endian
# TODO: Why not type this w/ the Endian Enum
endian: Literal["big", "little"]


class BytesCodecDict(CodecDict[BytesCodecConfigDict]):
Expand All @@ -43,7 +44,7 @@ class BytesCodecDict(CodecDict[BytesCodecConfigDict]):


@dataclass(frozen=True)
class BytesCodec(ArrayBytesCodec):
class BytesCodec(ArrayBytesCodec[BytesCodecDict]):
is_fixed_size = True

endian: Endian | None
Expand All @@ -54,10 +55,11 @@ def __init__(self, *, endian: Endian | str | None = default_system_endian) -> No
object.__setattr__(self, "endian", endian_parsed)

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
def from_dict(cls, data: BytesCodecDict) -> Self:
_, configuration_parsed = parse_named_configuration(
data, "bytes", require_configuration=False
)

configuration_parsed = configuration_parsed or {}
return cls(**configuration_parsed) # type: ignore[arg-type]

Expand Down
6 changes: 3 additions & 3 deletions src/zarr/codecs/crc32c_.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from crc32c import crc32c

from zarr.abc.codec import BytesBytesCodec, CodecConfigDict, CodecDict
from zarr.core.common import JSON, parse_named_configuration
from zarr.core.common import parse_named_configuration
from zarr.registry import register_codec

if TYPE_CHECKING:
Expand All @@ -25,11 +25,11 @@ class Crc32cCodecDict(CodecDict[CodecConfigDict]):


@dataclass(frozen=True)
class Crc32cCodec(BytesBytesCodec):
class Crc32cCodec(BytesBytesCodec[Crc32cCodecDict]):
is_fixed_size = True

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
def from_dict(cls, data: Crc32cCodecDict) -> Self:
parse_named_configuration(data, "crc32c", require_configuration=False)
return cls()

Expand Down
4 changes: 2 additions & 2 deletions src/zarr/codecs/gzip.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def parse_gzip_level(data: JSON) -> int:


@dataclass(frozen=True)
class GzipCodec(BytesBytesCodec):
class GzipCodec(BytesBytesCodec[GzipCodecDict]):
is_fixed_size = False

level: int = 5
Expand All @@ -51,7 +51,7 @@ def __init__(self, *, level: int = 5) -> None:
object.__setattr__(self, "level", level_parsed)

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
def from_dict(cls, data: GzipCodecDict) -> Self:
_, configuration_parsed = parse_named_configuration(data, "gzip")
return cls(**configuration_parsed) # type: ignore[arg-type]

Expand Down
27 changes: 15 additions & 12 deletions src/zarr/codecs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def resolve_batched(codec: Codec, chunk_specs: Iterable[ArraySpec]) -> Iterable[
return [codec.resolve_metadata(chunk_spec) for chunk_spec in chunk_specs]


# TODO: Double-check whether `CodecDict[Any]` is appropriate
@dataclass(frozen=True)
class BatchedCodecPipeline(CodecPipeline):
"""Default codec pipeline.
Expand All @@ -65,9 +66,9 @@ class BatchedCodecPipeline(CodecPipeline):
lock step for each mini-batch. Multiple mini-batches are processing concurrently.
"""

array_array_codecs: tuple[ArrayArrayCodec, ...]
array_bytes_codec: ArrayBytesCodec
bytes_bytes_codecs: tuple[BytesBytesCodec, ...]
array_array_codecs: tuple[ArrayArrayCodec[Any], ...]
array_bytes_codec: ArrayBytesCodec[Any]
bytes_bytes_codecs: tuple[BytesBytesCodec[Any], ...]
batch_size: int

def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
Expand Down Expand Up @@ -134,11 +135,11 @@ def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
def _codecs_with_resolved_metadata_batched(
self, chunk_specs: Iterable[ArraySpec]
) -> tuple[
list[tuple[ArrayArrayCodec, list[ArraySpec]]],
tuple[ArrayBytesCodec, list[ArraySpec]],
list[tuple[BytesBytesCodec, list[ArraySpec]]],
list[tuple[ArrayArrayCodec[Any], list[ArraySpec]]],
tuple[ArrayBytesCodec[Any], list[ArraySpec]],
list[tuple[BytesBytesCodec[Any], list[ArraySpec]]],
]:
aa_codecs_with_spec: list[tuple[ArrayArrayCodec, list[ArraySpec]]] = []
aa_codecs_with_spec: list[tuple[ArrayArrayCodec[Any], list[ArraySpec]]] = []
chunk_specs = list(chunk_specs)
for aa_codec in self.array_array_codecs:
aa_codecs_with_spec.append((aa_codec, chunk_specs))
Expand All @@ -149,7 +150,7 @@ def _codecs_with_resolved_metadata_batched(
self.array_bytes_codec.resolve_metadata(chunk_spec) for chunk_spec in chunk_specs
]

bb_codecs_with_spec: list[tuple[BytesBytesCodec, list[ArraySpec]]] = []
bb_codecs_with_spec: list[tuple[BytesBytesCodec[Any], list[ArraySpec]]] = []
for bb_codec in self.bytes_bytes_codecs:
bb_codecs_with_spec.append((bb_codec, chunk_specs))
chunk_specs = [bb_codec.resolve_metadata(chunk_spec) for chunk_spec in chunk_specs]
Expand Down Expand Up @@ -465,12 +466,14 @@ async def write(

def codecs_from_list(
codecs: Iterable[Codec],
) -> tuple[tuple[ArrayArrayCodec, ...], ArrayBytesCodec, tuple[BytesBytesCodec, ...]]:
) -> tuple[
tuple[ArrayArrayCodec[Any], ...], ArrayBytesCodec[Any], tuple[BytesBytesCodec[Any], ...]
]:
from zarr.codecs.sharding import ShardingCodec

array_array: tuple[ArrayArrayCodec, ...] = ()
array_bytes_maybe: ArrayBytesCodec | None = None
bytes_bytes: tuple[BytesBytesCodec, ...] = ()
array_array: tuple[ArrayArrayCodec[Any], ...] = ()
array_bytes_maybe: ArrayBytesCodec[Any] | None = None
bytes_bytes: tuple[BytesBytesCodec[Any], ...] = ()

if any(isinstance(codec, ShardingCodec) for codec in codecs) and len(tuple(codecs)) > 1:
warn(
Expand Down
8 changes: 5 additions & 3 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,9 @@ class ShardingCodecDict(CodecDict[ShardingCodecConfigDict]):

@dataclass(frozen=True)
class ShardingCodec(
ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
ArrayBytesCodec[ShardingCodecDict],
ArrayBytesCodecPartialDecodeMixin,
ArrayBytesCodecPartialEncodeMixin,
):
chunk_shape: ChunkCoords
codecs: tuple[Codec, ...]
Expand Down Expand Up @@ -370,7 +372,7 @@ def __init__(
object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))

# todo: typedict return type
def __getstate__(self) -> dict[str, Any]:
def __getstate__(self) -> ShardingCodecDict:
return self.to_dict()

def __setstate__(self, state: dict[str, Any]) -> None:
Expand All @@ -386,7 +388,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
def from_dict(cls, data: ShardingCodecDict) -> Self:
_, configuration_parsed = parse_named_configuration(data, "sharding_indexed")
return cls(**configuration_parsed) # type: ignore[arg-type]

Expand Down
4 changes: 2 additions & 2 deletions src/zarr/codecs/transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class TransposeCodecDict(CodecDict[TransposeCodecConfigDict]):


@dataclass(frozen=True)
class TransposeCodec(ArrayArrayCodec):
class TransposeCodec(ArrayArrayCodec[TransposeCodecDict]):
is_fixed_size = True

order: tuple[int, ...]
Expand All @@ -50,7 +50,7 @@ def __init__(self, *, order: ChunkCoordsLike) -> None:
object.__setattr__(self, "order", order_parsed)

@classmethod
def from_dict(cls, data: dict[str, JSON]) -> Self:
def from_dict(cls, data: TransposeCodecDict) -> Self:
_, configuration_parsed = parse_named_configuration(data, "transpose")
return cls(**configuration_parsed) # type: ignore[arg-type]

Expand Down
Loading

0 comments on commit e56547c

Please sign in to comment.