Commit

map_schema
pacman82 committed Feb 2, 2024
1 parent 274c534 commit fb7a7d9
Showing 5 changed files with 95 additions and 17 deletions.
3 changes: 2 additions & 1 deletion Changelog.md
@@ -1,8 +1,9 @@
# Changelog

## 4.0.0 (next)
## 4.0.0

- Removed the parameter `driver_returns_memory_garbage_for_indicators` from `read_arrow_batches_from_odbc`. It was intended as a workaround for IBM/DB2 drivers, but it turns out IBM offers drivers which work correctly with 64-bit driver managers. Look for driver file names ending in 'o'.
- Added support for mapping inferred schemas via the `map_schema` parameter. It accepts a callable taking and returning an Arrow schema. This allows you to avoid data types which are not supported in downstream operations and map them e.g. to string. It also lets you work around quirks of your ODBC driver, e.g. mapping `float32` to `float64` if the precision inferred by the driver is too small (see the sketch below).

## 3.0.1

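For illustration, a minimal sketch of how the new `map_schema` callable can be passed to `read_arrow_batches_from_odbc`, widening `float32` columns to `float64` as described in the changelog entry above. The `widen_floats` helper, the query, and the connection string are placeholders, and the sketch assumes `read_arrow_batches_from_odbc` is importable from the top-level `arrow_odbc` package.

import pyarrow as pa

from arrow_odbc import read_arrow_batches_from_odbc


def widen_floats(schema: pa.Schema) -> pa.Schema:
    # Hypothetical helper: replace every float32 field with float64, keep the rest as is.
    return pa.schema(
        [
            pa.field(
                field.name,
                pa.float64() if field.type == pa.float32() else field.type,
            )
            for field in schema
        ]
    )


reader = read_arrow_batches_from_odbc(
    query="SELECT * FROM MyTable",  # placeholder query
    batch_size=1000,
    connection_string="DSN=MyDataSource;",  # placeholder connection string
    map_schema=widen_floats,
)
for batch in reader:
    ...  # each item is a pyarrow.RecordBatch with the mapped schema
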
2 changes: 1 addition & 1 deletion doc/source/conf.py
@@ -22,7 +22,7 @@
author = "Markus Klein"

# The full version, including alpha/beta/rc tags
release = "3.0.1"
release = "4.0.0"


# -- General configuration ---------------------------------------------------
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -7,7 +7,7 @@ name = "arrow-odbc"
authors = [{name = "Markus Klein"}]
description="Read the data of an ODBC data source as sequence of Apache Arrow record batches."
readme = "README.md"
version = "3.0.1"
version = "4.0.0"
dependencies = ["cffi", "pyarrow >= 8.0.0"]

[project.license]
17 changes: 15 additions & 2 deletions python/arrow_odbc/reader.py
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Callable
from cffi.api import FFI # type: ignore

from pyarrow.cffi import ffi as arrow_ffi # type: ignore
@@ -40,7 +40,7 @@ def __del__(self):
# Free the resources associated with this handle.
lib.arrow_odbc_reader_free(self.handle)

def schema(self):
def schema(self) -> Schema:
return _schema_from_handle(self.handle)

def next_batch(self):
@@ -72,7 +72,11 @@ def bind_buffers(
max_binary_size: int,
falliable_allocations: bool = False,
schema: Optional[Schema] = None,
map_schema: Optional[Callable[[Schema], Schema]] = None,
):
if map_schema is not None:
schema = map_schema(self.schema())

ptr_schema = _export_schema_to_c(schema)

error = lib.arrow_odbc_reader_bind_buffers(
@@ -144,6 +148,7 @@ def more_results(
max_binary_size: Optional[int] = None,
falliable_allocations: bool = False,
schema: Optional[Schema] = None,
map_schema: Optional[Callable[[Schema], Schema]] = None,
) -> bool:
"""
Move the reader to the next result set returned by the data source.
@@ -231,6 +236,7 @@ def more_results(
max_binary_size=max_binary_size,
falliable_allocations=falliable_allocations,
schema=schema,
map_schema=map_schema,
)

# Every result set can have its own schema, so we must update our member
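
Since `more_results` now also accepts `map_schema` (see the hunks above), the same callable can be reapplied when a statement yields several result sets. A minimal sketch of that usage, reusing the hypothetical `widen_floats` helper from the earlier sketch; the multi-statement query and connection string are placeholders, and the sketch assumes the bool returned by `more_results` indicates whether another result set exists.

reader = read_arrow_batches_from_odbc(
    query="SELECT a FROM First; SELECT b FROM Second;",  # placeholder batch of statements
    batch_size=1000,
    connection_string="DSN=MyDataSource;",  # placeholder connection string
    map_schema=widen_floats,
)
for batch in reader:
    ...  # first result set, schema already mapped
# Advance to the next result set and apply the same mapping to its inferred schema.
if reader.more_results(batch_size=1000, map_schema=widen_floats):
    for batch in reader:
        ...  # second result set
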
@@ -299,6 +305,7 @@ def read_arrow_batches_from_odbc(
falliable_allocations: bool = False,
login_timeout_sec: Optional[int] = None,
schema: Optional[Schema] = None,
map_schema: Optional[Callable[[Schema], Schema]] = None,
) -> BatchReader:
"""
Execute the query and read the result as an iterator over Arrow batches.
@@ -388,6 +395,11 @@ def read_arrow_batches_from_odbc(
make sense to decide the type based on what you want to do with it, rather than its source.
E.g. if you simply want to put everything into a CSV file it can make perfect sense to fetch
everything as string independent of its source type.
:param map_schema: Allows you to derive a custom schema from the schema inferred from the
metainformation of the query. You can e.g. map every column type to string, or replace any
float32 with a float64, or apply any other customization, while still staying generic over
the input schema. If both ``map_schema`` and ``schema`` are specified, ``map_schema`` takes
priority.
:return: A ``BatchReader`` is returned, which implements the iterator protocol and iterates over
individual arrow batches.
"""
@@ -463,6 +475,7 @@ def read_arrow_batches_from_odbc(
max_binary_size=max_binary_size,
falliable_allocations=falliable_allocations,
schema=schema,
map_schema=map_schema,
)

return BatchReader(reader)
88 changes: 76 additions & 12 deletions tests/test_arrow_odbc.py
@@ -28,6 +28,7 @@
log_to_stderr()
enable_odbc_connection_pooling()


def setup_table(table: str, column_type: str, values: List[Any]):
connection = pyodbc.connect(MSSQL)
connection.execute(f"DROP TABLE IF EXISTS {table};")
@@ -138,7 +139,7 @@ def test_custom_schema_for_second_result_set():
schema = pa.schema([pa.field("a", pa.string())])
reader.more_results(batch_size=1, schema=schema)
batch = next(iter(reader))

expected = pa.RecordBatch.from_pydict({"a": ["2"]}, schema)
assert batch == expected

@@ -264,7 +265,7 @@ def test_concurrent_reader_into_concurrent():
query=query, batch_size=100, connection_string=MSSQL
)
reader.fetch_concurrently()
reader.fetch_concurrently() # Transforming already concurrent reader into concurrent reader
reader.fetch_concurrently() # Transforming already concurrent reader into concurrent reader
it = iter(reader)

actual = next(it)
@@ -323,7 +324,9 @@ def test_timestamp_us():
Query a table with one row. Should return one batch
"""
table = "TimestampUs"
setup_table(table=table, column_type="DATETIME2(6)", values=["2014-04-14 21:25:42.074841"])
setup_table(
table=table, column_type="DATETIME2(6)", values=["2014-04-14 21:25:42.074841"]
)

query = f"SELECT * FROM {table}"
reader = read_arrow_batches_from_odbc(
@@ -347,7 +350,9 @@ def test_timestamp_ns():
Query a table with one row. Should return one batch
"""
table = "TimestampNs"
setup_table(table=table, column_type="DATETIME2(7)", values=["2014-04-14 21:25:42.0748412"])
setup_table(
table=table, column_type="DATETIME2(7)", values=["2014-04-14 21:25:42.0748412"]
)

query = f"SELECT * FROM {table}"
reader = read_arrow_batches_from_odbc(
@@ -371,7 +376,9 @@ def test_out_of_range_timestamp_ns():
Query a table with one row. Should return one batch
"""
table = "OutOfRangeTimestampNs"
setup_table(table=table, column_type="DATETIME2(7)", values=["2300-04-14 21:25:42.0748412"])
setup_table(
table=table, column_type="DATETIME2(7)", values=["2300-04-14 21:25:42.0748412"]
)

query = f"SELECT * FROM {table}"

@@ -474,7 +481,9 @@ def test_query_zero_sized_column():
characters.
"""
query = "SELECT CAST('a' AS VARCHAR(MAX)) as a"
with raises(Error, match="ODBC driver did not specify a sensible upper bound for the column"):
with raises(
Error, match="ODBC driver did not specify a sensible upper bound for the column"
):
read_arrow_batches_from_odbc(
query=query, batch_size=100, connection_string=MSSQL
)
@@ -489,7 +498,9 @@ def test_query_with_string_parameter():
connection = pyodbc.connect(MSSQL)
connection.execute(f"DROP TABLE IF EXISTS {table};")
connection.execute(f"CREATE TABLE {table} (a CHAR(1), b INTEGER);")
connection.execute(f"INSERT INTO {table} (a,b) VALUES ('A', 1),('B',2),('C',3),('D',4);")
connection.execute(
f"INSERT INTO {table} (a,b) VALUES ('A', 1),('B',2),('C',3),('D',4);"
)
connection.commit()
connection.close()
query = f"SELECT b FROM {table} WHERE a=?;"
@@ -518,7 +529,9 @@ def test_query_with_none_parameter():
connection = pyodbc.connect(MSSQL)
connection.execute(f"DROP TABLE IF EXISTS {table};")
connection.execute(f"CREATE TABLE {table} (a CHAR(1), b INTEGER);")
connection.execute(f"INSERT INTO {table} (a,b) VALUES ('A', 1),('B',2),('C',3),('D',4);")
connection.execute(
f"INSERT INTO {table} (a,b) VALUES ('A', 1),('B',2),('C',3),('D',4);"
)
connection.commit()
connection.close()

@@ -541,7 +554,9 @@ def test_query_with_int_parameter():
connection = pyodbc.connect(MSSQL)
connection.execute(f"DROP TABLE IF EXISTS {table};")
connection.execute(f"CREATE TABLE {table} (a CHAR(1), b INTEGER);")
connection.execute(f"INSERT INTO {table} (a,b) VALUES ('A', 1),('B',2),('C',3),('D',4);")
connection.execute(
f"INSERT INTO {table} (a,b) VALUES ('A', 1),('B',2),('C',3),('D',4);"
)
connection.commit()
connection.close()

@@ -569,7 +584,7 @@ def test_query_timestamp_as_date():
batch = next(it)
value = batch.to_pydict()

assert value == { "a": [datetime.date(2023, 12, 24)] }
assert value == {"a": [datetime.date(2023, 12, 24)]}


def test_allocation_erros():
@@ -676,6 +691,44 @@ def test_support_varbinary_max():
next(it)


def test_map_f32_to_f64():
"""
ODBC drivers for PostgreSQL seem to have some trouble reporting the precision of floating point
types correctly. Using schema mapping, users of this wheel who know about this quirk can adapt
to it while still staying generic over the database schema.
See issue: https://github.com/pacman82/arrow-odbc-py/issues/73
"""
# Given
table = "MapF32ToF64"
# The MS driver is pretty good, so we actually create a 32-bit float by setting the precision to
# 17. This way we simulate a driver reporting a floating point type that is too small.
setup_table(table=table, column_type="Float(17)", values=[])
query = f"SELECT (a) FROM {table}"

# When
map_schema = lambda schema: pa.schema(
[
(
name,
(
pa.float64()
if schema.field(name).type == pa.float32()
else schema.field(name).type
),
)
for name in schema.names
]
)
reader = read_arrow_batches_from_odbc(
query=query, batch_size=1, connection_string=MSSQL, map_schema=map_schema
)

# Then
expected = pa.schema([("a", pa.float64())])
assert expected == reader.schema


def test_insert_should_raise_on_invalid_connection_string():
"""
Insert should raise on invalid connection string
@@ -763,7 +816,9 @@ def test_insert_from_parquet():
table = "InsertFromParquet"
connection = pyodbc.connect(MSSQL)
connection.execute(f"DROP TABLE IF EXISTS {table};")
connection.execute(f"CREATE TABLE {table} (sepal_length REAL, sepal_width REAL, petal_length REAL, petal_width REAL, variety VARCHAR(20) );")
connection.execute(
f"CREATE TABLE {table} (sepal_length REAL, sepal_width REAL, petal_length REAL, petal_width REAL, variety VARCHAR(20) );"
)
connection.commit()
connection.close()

@@ -805,7 +860,16 @@ def iter_record_batches():

# Then
actual = check_output(
["odbcsv", "fetch", "-c", MSSQL, "--max-str-len", "2000","-q", f"SELECT a FROM {table}"]
[
"odbcsv",
"fetch",
"-c",
MSSQL,
"--max-str-len",
"2000",
"-q",
f"SELECT a FROM {table}",
]
)
assert f"a\n{large_string}\n" == actual.decode("utf8")
