Skip to content

Commit

Permalink
Merge pull request #430 from Renumics/feature/infer-categories-of-seq…
Browse files Browse the repository at this point in the history
…uences

Feature/infer categories of sequences
  • Loading branch information
neindochoh authored Feb 20, 2024
2 parents 6c9c954 + 6406c3c commit 8466957
Showing 1 changed file with 50 additions and 15 deletions.
65 changes: 50 additions & 15 deletions renumics/spotlight/data_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import hashlib
import io
import itertools
import os
import statistics
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
Expand All @@ -17,7 +18,11 @@
from renumics.spotlight.cache import external_data_cache
from renumics.spotlight.data_source import DataSource
from renumics.spotlight.data_source.data_source import ColumnMetadata
from renumics.spotlight.dtypes.conversion import ConvertedValue, convert_to_dtype
from renumics.spotlight.dtypes.conversion import (
ConvertedValue,
NormalizedValue,
convert_to_dtype,
)
from renumics.spotlight.io import audio
from renumics.spotlight.media.audio import Audio
from renumics.spotlight.media.embedding import Embedding
Expand Down Expand Up @@ -183,20 +188,13 @@ def _update_dtypes(self) -> None:

# determine categories for _automatic_ CategoryDtypes
for column_name, dtype in dtypes.items():
if (
spotlight_dtypes.is_category_dtype(dtype)
and dtype.categories is None
and spotlight_dtypes.is_str_dtype(guessed_dtypes[column_name])
):
normalized_values = self._data_source.get_column_values(column_name)
converted_values = [
convert_to_dtype(
value, spotlight_dtypes.str_dtype, simple=True, check=True
)
for value in normalized_values
]
category_names = sorted(cast(Set[str], set(converted_values)))
dtypes[column_name] = spotlight_dtypes.CategoryDType(category_names)

def values() -> Iterable[NormalizedValue]:
yield from self._data_source.get_column_values(column_name)

dtypes[column_name] = self._refine_dtype(
values(), guessed_dtypes[column_name], dtype
)

self._dtypes = dtypes

Expand All @@ -211,6 +209,43 @@ def _guess_dtype(self, col: str) -> spotlight_dtypes.DType:
sample_dtype = _guess_dtype_from_values(sample_values)
return sample_dtype or spotlight_dtypes.str_dtype

def _refine_dtype(
    self,
    values: Iterable[NormalizedValue],
    guessed_dtype: spotlight_dtypes.DType,
    user_dtype: spotlight_dtypes.DType,
) -> spotlight_dtypes.DType:
    """
    Refine `user_dtype` with the help of `guessed_dtype` and the given values.

    Three cases are handled:

    * `user_dtype` is a category dtype without categories and the guessed
      dtype is a string dtype: every value is converted to the string dtype
      and the sorted unique results become the categories of a new
      `CategoryDType`.
    * Both `user_dtype` and `guessed_dtype` are sequence dtypes: the inner
      dtype is refined recursively over the flattened values, and the
      length is taken from `user_dtype`, falling back to the guessed length
      when the user did not set one.
    * Anything else: `user_dtype` is returned unchanged.

    `values` is iterated at most once, so a one-shot iterator is fine.
    """
    # The type-guard calls stay directly inside the `if` conditions so that
    # static type narrowing of `user_dtype` / `guessed_dtype` keeps working.
    if (
        spotlight_dtypes.is_category_dtype(user_dtype)
        and user_dtype.categories is None
        and spotlight_dtypes.is_str_dtype(guessed_dtype)
    ):
        # Collect the distinct string representations of all values.
        unique_names = {
            convert_to_dtype(
                item, spotlight_dtypes.str_dtype, simple=True, check=True
            )
            for item in values
        }
        return spotlight_dtypes.CategoryDType(
            sorted(cast(Set[str], unique_names))
        )

    if spotlight_dtypes.is_sequence_dtype(
        user_dtype
    ) and spotlight_dtypes.is_sequence_dtype(guessed_dtype):
        # An explicit user-given length wins over the guessed one.
        user_length = user_dtype.length
        seq_length = guessed_dtype.length if user_length is None else user_length

        # Each value is treated as an iterable of items; flatten one level
        # lazily so the item dtype can be refined over all of them.
        # NOTE(review): a `None` entry in `values` would make the flattening
        # raise `TypeError` — confirm whether missing sequences can occur.
        flattened_items = itertools.chain.from_iterable(
            cast(Iterable[Iterable[NormalizedValue]], values)
        )
        item_dtype = self._refine_dtype(
            flattened_items, guessed_dtype.dtype, user_dtype.dtype
        )
        return spotlight_dtypes.SequenceDType(item_dtype, seq_length)

    return user_dtype


def _guess_dtype_from_values(values: Iterable) -> Optional[spotlight_dtypes.DType]:
dtypes: List[spotlight_dtypes.DType] = []
Expand Down

0 comments on commit 8466957

Please sign in to comment.