fix for pyarrow
xoolive committed Nov 1, 2024
1 parent 8b2d374 commit fb0c20b
Showing 2 changed files with 17 additions and 8 deletions.
pyproject.toml: 2 changes (1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 [project]
 name = "pyopensky"
-version = "2.10"
+version = "2.11"
 description = "A Python interface for OpenSky database"
 authors = [
     { name = "Xavier Olive", email = "[email protected]" },
src/pyopensky/trino.py: 23 changes (16 additions & 7 deletions)
@@ -166,9 +166,15 @@ def query(
         if (cache_file := (cache_path / digest).with_suffix(suffix)).exists():
             if cached:
                 _log.info(f"Reading results from {cache_file}")
-                return pd.read_parquet(cache_file).convert_dtypes(
-                    dtype_backend="pyarrow"
-                )
+                df = pd.read_parquet(cache_file)
+                df = df.convert_dtypes(dtype_backend="pyarrow")
+
+                for column in df.select_dtypes(include=["datetime"]):
+                    if df[column].dtype.pyarrow_dtype.tz is None:
+                        df = df.assign(
+                            **{column: df[column].dt.tz_localize("UTC")}
+                        )
+                return df
             else:
                 cache_file.unlink(missing_ok=True)

@@ -196,6 +202,11 @@ def query(
 
         df = pd.concat(self.process_result(res))
 
+        df = df.convert_dtypes(dtype_backend="pyarrow")
+        for column in df.select_dtypes(include=["datetime"]):
+            if df[column].dtype.pyarrow_dtype.tz is None:
+                df = df.assign(**{column: df[column].dt.tz_localize("UTC")})
+
         if cached:
             _log.info(f"Saving results to {cache_file}")
             df.to_parquet(cache_file)
@@ -233,16 +244,14 @@ def process_result(
         ) as download_bar:
             sequence_rows = async_result.get()
             download_bar.update(len(sequence_rows))
-            yield pd.DataFrame.from_records(
-                sequence_rows, columns=res.keys()
-            ).convert_dtypes(dtype_backend="pyarrow")
+            yield pd.DataFrame.from_records(sequence_rows, columns=res.keys())
 
             while len(sequence_rows) == batch_size:
                 sequence_rows = res.fetchmany(batch_size)
                 download_bar.update(len(sequence_rows))
                 yield pd.DataFrame.from_records(
                     sequence_rows, columns=res.keys()
-                ).convert_dtypes(dtype_backend="pyarrow")
+                )
 
     ## Specific queries

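Aside (not part of the commit): a minimal, self-contained sketch of the pattern applied above in trino.py, assuming pandas 2.x with pyarrow installed. It converts a DataFrame to pyarrow-backed dtypes, then localizes every timezone-naive timestamp column to UTC. The sample columns "icao24" and "firstseen" are made up for illustration; only the conversion and localization loop mirrors the code added in this commit.

import pandas as pd

# made-up sample data with a timezone-naive timestamp column,
# similar to what may come out of a parquet cache file
df = pd.DataFrame(
    {
        "icao24": ["3c6444", "4ca7b5"],
        "firstseen": pd.to_datetime(
            ["2024-11-01 10:00:00", "2024-11-01 11:30:00"]
        ),
    }
)

# switch to the pyarrow dtype backend:
# datetime64[ns] becomes timestamp[ns][pyarrow]
df = df.convert_dtypes(dtype_backend="pyarrow")

# same loop as the one added in trino.py:
# localize timezone-naive timestamp columns to UTC
for column in df.select_dtypes(include=["datetime"]):
    if df[column].dtype.pyarrow_dtype.tz is None:
        df = df.assign(**{column: df[column].dt.tz_localize("UTC")})

print(df.dtypes)  # firstseen is now timestamp[ns, tz=UTC][pyarrow]

The commit also moves the convert_dtypes call out of process_result, so the dtype conversion and UTC localization now happen once on the concatenated result rather than once per fetched batch.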