Skip to content

Commit

Permalink
Update polars to fix PanicException (#2585)
Browse files Browse the repository at this point in the history
* update polars

* fix column name with counts

* fix operations on null type (when all values are null)

* add testcase with parquet file that throws this error
  • Loading branch information
polinaeterna authored Mar 14, 2024
1 parent 4a9503f commit f2e3c63
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 37 deletions.
59 changes: 47 additions & 12 deletions services/worker/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion services/worker/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ transformers = "^4.36.0"
typer = "^0.4.2"
uvicorn = "^0.20.0"
wget = "^3.2"
polars = "^0.19.15"
polars = ">=0.20.0"

[tool.poetry.group.dev.dependencies]
bandit = "^1.7.4"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def compute_histogram(
hist_df_reverted = df.with_columns(pl.col(column_name).mul(-1).alias("reverse"))["reverse"].hist(
bins=bins_edges_reverted
)
hist_reverted = hist_df_reverted["reverse_count"].cast(int).to_list()
hist_reverted = hist_df_reverted["count"].cast(int).to_list()
hist = hist_reverted[::-1]
hist = [hist[0] + hist[1]] + hist[2:-2] + [hist[-2] + hist[-1]]
else:
Expand All @@ -219,9 +219,7 @@ def compute_histogram(
)


def min_max_median_std_nan_count_proportion(
data: pl.DataFrame, column_name: str, n_samples: int
) -> tuple[float, float, float, float, float, int, float]:
def min_max_mean_median_std(data: pl.DataFrame, column_name: str) -> tuple[float, float, float, float, float]:
"""
Compute minimum, maximum, median, standard deviation, number of nan samples and their proportion in column data.
"""
Expand All @@ -231,7 +229,6 @@ def min_max_median_std_nan_count_proportion(
mean=pl.all().mean(),
median=pl.all().median(),
std=pl.all().std(),
nan_count=pl.all().null_count(),
)
stats_names = pl.Series(col_stats.keys())
stats_expressions = [pl.struct(stat) for stat in col_stats.values()]
Expand All @@ -240,26 +237,19 @@ def min_max_median_std_nan_count_proportion(
.select(name=stats_names, stats=pl.concat_list(stats_expressions).flatten())
.unnest("stats")
)
minimum, maximum, mean, median, std, nan_count = stats[column_name].to_list()
minimum, maximum, mean, median, std = stats[column_name].to_list()
if any(statistic is None for statistic in [minimum, maximum, mean, median, std]):
# this should only be possible when all values in the column are None
if not all(statistic is None for statistic in [minimum, maximum, mean, median, std]):
raise StatisticsComputationError(
f"Unexpected result for {column_name=}: "
f"Some measures among {minimum=}, {maximum=}, {mean=}, {median=}, {std=} are None but not all of them. "
)
if nan_count != n_samples:
raise StatisticsComputationError(
f"Unexpected result for {column_name=}: "
f"{minimum=}, {maximum=}, {mean=}, {median=}, {std=} are None but not all values in column are None. "
)
return minimum, maximum, mean, median, std, nan_count, 1.0
return minimum, maximum, mean, median, std

minimum, maximum, mean, median, std = np.round([minimum, maximum, mean, median, std], DECIMALS).tolist()
nan_proportion = np.round(nan_count / n_samples, DECIMALS).item() if nan_count else 0.0
nan_count = int(nan_count)

return minimum, maximum, mean, median, std, nan_count, nan_proportion
return minimum, maximum, mean, median, std


def value_counts(data: pl.DataFrame, column_name: str) -> dict[Any, Any]:
Expand Down Expand Up @@ -385,10 +375,7 @@ def _compute_statistics(
data: pl.DataFrame, column_name: str, n_samples: int, n_bins: int
) -> NumericalStatisticsItem:
logging.info(f"Compute statistics for float column {column_name} with polars. ")
minimum, maximum, mean, median, std, nan_count, nan_proportion = min_max_median_std_nan_count_proportion(
data, column_name, n_samples
)
logging.debug(f"{minimum=}, {maximum=}, {mean=}, {median=}, {std=}, {nan_count=} {nan_proportion=}")
nan_count, nan_proportion = nan_count_proportion(data, column_name, n_samples)
if nan_count == n_samples: # all values are None
return NumericalStatisticsItem(
nan_count=n_samples,
Expand All @@ -400,6 +387,8 @@ def _compute_statistics(
std=None,
histogram=None,
)
minimum, maximum, mean, median, std = min_max_mean_median_std(data, column_name)
logging.debug(f"{minimum=}, {maximum=}, {mean=}, {median=}, {std=}, {nan_count=} {nan_proportion=}")

hist = compute_histogram(
data,
Expand Down Expand Up @@ -442,11 +431,8 @@ def _compute_statistics(
data: pl.DataFrame, column_name: str, n_samples: int, n_bins: int
) -> NumericalStatisticsItem:
logging.info(f"Compute statistics for integer column {column_name} with polars. ")
minimum, maximum, mean, median, std, nan_count, nan_proportion = min_max_median_std_nan_count_proportion(
data, column_name, n_samples
)
logging.debug(f"{minimum=}, {maximum=}, {mean=}, {median=}, {std=}, {nan_count=} {nan_proportion=}")
if nan_count == n_samples: # all values are None
nan_count, nan_proportion = nan_count_proportion(data, column_name, n_samples=n_samples)
if nan_count == n_samples:
return NumericalStatisticsItem(
nan_count=n_samples,
nan_proportion=1.0,
Expand All @@ -458,6 +444,9 @@ def _compute_statistics(
histogram=None,
)

minimum, maximum, mean, median, std = min_max_mean_median_std(data, column_name)
logging.debug(f"{minimum=}, {maximum=}, {mean=}, {median=}, {std=}, {nan_count=} {nan_proportion=}")

minimum, maximum = int(minimum), int(maximum)
hist = compute_histogram(
data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import polars as pl
import pytest
from datasets import ClassLabel, Dataset
from huggingface_hub.hf_api import HfApi
from libcommon.dtos import Priority
from libcommon.exceptions import StatisticsComputationError
from libcommon.resources import CacheMongoResource, QueueMongoResource
Expand Down Expand Up @@ -38,6 +39,7 @@
)
from worker.resources import LibrariesResource

from ...constants import CI_HUB_ENDPOINT, CI_USER_TOKEN
from ...fixtures.hub import HubDatasetTest
from ..utils import REVISION_NAME

Expand Down Expand Up @@ -662,6 +664,38 @@ def test_list_statistics(
assert computed == expected


@pytest.fixture
def struct_thread_panic_error_parquet_file(tmp_path_factory: pytest.TempPathFactory) -> str:
    """Download the parquet file that used to make polars raise a PanicException.

    Fetches ``test_polars_panic_error.parquet`` from the CI hub into a fresh
    temporary directory and returns the local path to it as a string.
    """
    filename = "test_polars_panic_error.parquet"
    target_dir = tmp_path_factory.mktemp("data")
    api = HfApi(endpoint=CI_HUB_ENDPOINT)
    api.hf_hub_download(
        repo_id="__DUMMY_TRANSFORMERS_USER__/test_polars_panic_error",
        filename=filename,
        repo_type="dataset",
        local_dir=target_dir,
        token=CI_USER_TOKEN,
    )
    return str(target_dir / filename)


def test_polars_struct_thread_panic_error(struct_thread_panic_error_parquet_file: str) -> None:
    """Regression test: reading this parquet file must not raise a PanicException.

    The file's ``conversations`` column is a list-of-struct type that crashed
    older polars versions; with the pinned version the read must succeed and
    the schema must round-trip intact.
    """
    # The read itself is the regression check — it must not raise.
    df = pl.read_parquet(struct_thread_panic_error_parquet_file)

    expected_schema = pl.List(
        pl.Struct({"from": pl.String, "value": pl.String, "weight": pl.Float64})
    )
    assert "conversations" in df
    assert "conversations" in df.schema
    assert df.schema["conversations"] == expected_schema


@pytest.mark.parametrize(
"hub_dataset_name,expected_error_code",
[
Expand Down

0 comments on commit f2e3c63

Please sign in to comment.