diff --git a/Changelog.md b/Changelog.md index 491eceb..ae6fa53 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,8 +1,9 @@ # Changelog -## 4.0.0 (next) +## 4.0.0 - Removed parameter `driver_returns_memory_garbage_for_indicators` from `read_arrow_batches_from_odbc` as it was intended as a workaround for IBM/DB2 drivers. Turns out IBM offers drivers which work correctly with 64Bit driver managers. Look for file names ending in 'o'. +- Add support for mapping inferred schemas via the `map_schema` parameter. It accepts a callable taking and returning an Arrow Schema. This allows you to avoid data types which are not supported in downstream operations and map them e.g. to string. It also enables you to work around quirks of your ODBC driver and map float 32 to float 64 if precisions inferred by the driver are too small. ## 3.0.1 diff --git a/doc/source/conf.py b/doc/source/conf.py index 35c71a3..4619ba2 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -22,7 +22,7 @@ author = "Markus Klein" # The full version, including alpha/beta/rc tags -release = "3.0.1" +release = "4.0.0" # -- General configuration --------------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 5ab8bcf..c96f1b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "arrow-odbc" authors = [{name = "Markus Klein"}] description="Read the data of an ODBC data source as sequence of Apache Arrow record batches." readme = "README.md" -version = "3.0.1" +version = "4.0.0" dependencies = ["cffi", "pyarrow >= 8.0.0"] [project.license] diff --git a/python/arrow_odbc/reader.py b/python/arrow_odbc/reader.py index 2429db9..62e7946 100644 --- a/python/arrow_odbc/reader.py +++ b/python/arrow_odbc/reader.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Callable from cffi.api import FFI # type: ignore from pyarrow.cffi import ffi as arrow_ffi # type: ignore @@ -40,7 +40,7 @@ def __del__(self): # Free the resources associated with this handle. lib.arrow_odbc_reader_free(self.handle) - def schema(self): + def schema(self) -> Schema: return _schema_from_handle(self.handle) def next_batch(self): @@ -72,7 +72,11 @@ def bind_buffers( max_binary_size: int, falliable_allocations: bool = False, schema: Optional[Schema] = None, + map_schema: Optional[Callable[[Schema], Schema]] = None, ): + if map_schema is not None: + schema = map_schema(self.schema()) + ptr_schema = _export_schema_to_c(schema) error = lib.arrow_odbc_reader_bind_buffers( @@ -144,6 +148,7 @@ def more_results( max_binary_size: Optional[int] = None, falliable_allocations: bool = False, schema: Optional[Schema] = None, + map_schema: Optional[Callable[[Schema], Schema]] = None, ) -> bool: """ Move the reader to the next result set returned by the data source. @@ -231,6 +236,7 @@ def more_results( max_binary_size=max_binary_size, falliable_allocations=falliable_allocations, schema=schema, + map_schema=map_schema, ) # Every result set can have its own schema, so we must update our member @@ -299,6 +305,7 @@ def read_arrow_batches_from_odbc( falliable_allocations: bool = False, login_timeout_sec: Optional[int] = None, schema: Optional[Schema] = None, + map_schema: Optional[Callable[[Schema], Schema]] = None, ) -> BatchReader: """ Execute the query and read the result as an iterator over Arrow batches. @@ -388,6 +395,11 @@ def read_arrow_batches_from_odbc( make sense to decide the type based on what you want to do with it, rather than its source. E.g. if you simply want to put everything into a CSV file it can make perfect sense to fetch everything as string independent of its source type. + :param map_schema: Allows you to provide a custom schema based on the schema inferred from the + metainformation of the query. This would allow you to e.g. map every column type to string + or replace any float32 with a float64, or anything else you might want to customize, for + various reasons while still staying generic over the input schema. If both ``map_schema`` + and ``schema`` are specified ``map_schema`` takes priority. :return: A ``BatchReader`` is returned, which implements the iterator protocol and iterates over individual arrow batches. """ @@ -463,6 +475,7 @@ def read_arrow_batches_from_odbc( max_binary_size=max_binary_size, falliable_allocations=falliable_allocations, schema=schema, + map_schema=map_schema, ) return BatchReader(reader) diff --git a/tests/test_arrow_odbc.py b/tests/test_arrow_odbc.py index cc3a1d9..16e4400 100644 --- a/tests/test_arrow_odbc.py +++ b/tests/test_arrow_odbc.py @@ -28,6 +28,7 @@ log_to_stderr() enable_odbc_connection_pooling() + def setup_table(table: str, column_type: str, values: List[Any]): connection = pyodbc.connect(MSSQL) connection.execute(f"DROP TABLE IF EXISTS {table};") @@ -138,7 +139,7 @@ def test_custom_schema_for_second_result_set(): schema = pa.schema([pa.field("a", pa.string())]) reader.more_results(batch_size=1, schema=schema) batch = next(iter(reader)) - + expected = pa.RecordBatch.from_pydict({"a": ["2"]}, schema) assert batch == expected @@ -264,7 +265,7 @@ def test_concurrent_reader_into_concurrent(): query=query, batch_size=100, connection_string=MSSQL ) reader.fetch_concurrently() - reader.fetch_concurrently() # Transforming already concurrent reader into concurrent reader + reader.fetch_concurrently() # Transforming already concurrent reader into concurrent reader it = iter(reader) actual = next(it) @@ -323,7 +324,9 @@ def test_timestamp_us(): Query a table with one row. Should return one batch """ table = "TimestampUs" - setup_table(table=table, column_type="DATETIME2(6)", values=["2014-04-14 21:25:42.074841"]) + setup_table( + table=table, column_type="DATETIME2(6)", values=["2014-04-14 21:25:42.074841"] + ) query = f"SELECT * FROM {table}" reader = read_arrow_batches_from_odbc( @@ -347,7 +350,9 @@ def test_timestamp_ns(): Query a table with one row. Should return one batch """ table = "TimestampNs" - setup_table(table=table, column_type="DATETIME2(7)", values=["2014-04-14 21:25:42.0748412"]) + setup_table( + table=table, column_type="DATETIME2(7)", values=["2014-04-14 21:25:42.0748412"] + ) query = f"SELECT * FROM {table}" reader = read_arrow_batches_from_odbc( @@ -371,7 +376,9 @@ def test_out_of_range_timestamp_ns(): Query a table with one row. Should return one batch """ table = "OutOfRangeTimestampNs" - setup_table(table=table, column_type="DATETIME2(7)", values=["2300-04-14 21:25:42.0748412"]) + setup_table( + table=table, column_type="DATETIME2(7)", values=["2300-04-14 21:25:42.0748412"] + ) query = f"SELECT * FROM {table}" @@ -474,7 +481,9 @@ def test_query_zero_sized_column(): characters. """ query = "SELECT CAST('a' AS VARCHAR(MAX)) as a" - with raises(Error, match="ODBC driver did not specify a sensible upper bound for the column"): + with raises( + Error, match="ODBC driver did not specify a sensible upper bound for the column" + ): read_arrow_batches_from_odbc( query=query, batch_size=100, connection_string=MSSQL ) @@ -489,7 +498,9 @@ def test_query_with_string_parameter(): connection = pyodbc.connect(MSSQL) connection.execute(f"DROP TABLE IF EXISTS {table};") connection.execute(f"CREATE TABLE {table} (a CHAR(1), b INTEGER);") - connection.execute(f"INSERT INTO {table} (a,b) VALUES ('A', 1),('B',2),('C',3),('D',4);") + connection.execute( + f"INSERT INTO {table} (a,b) VALUES ('A', 1),('B',2),('C',3),('D',4);" + ) connection.commit() connection.close() query = f"SELECT b FROM {table} WHERE a=?;" @@ -518,7 +529,9 @@ def test_query_with_none_parameter(): connection = pyodbc.connect(MSSQL) connection.execute(f"DROP TABLE IF EXISTS {table};") connection.execute(f"CREATE TABLE {table} (a CHAR(1), b INTEGER);") - connection.execute(f"INSERT INTO {table} (a,b) VALUES ('A', 1),('B',2),('C',3),('D',4);") + connection.execute( + f"INSERT INTO {table} (a,b) VALUES ('A', 1),('B',2),('C',3),('D',4);" + ) connection.commit() connection.close() @@ -541,7 +554,9 @@ def test_query_with_int_parameter(): connection = pyodbc.connect(MSSQL) connection.execute(f"DROP TABLE IF EXISTS {table};") connection.execute(f"CREATE TABLE {table} (a CHAR(1), b INTEGER);") - connection.execute(f"INSERT INTO {table} (a,b) VALUES ('A', 1),('B',2),('C',3),('D',4);") + connection.execute( + f"INSERT INTO {table} (a,b) VALUES ('A', 1),('B',2),('C',3),('D',4);" + ) connection.commit() connection.close() @@ -569,7 +584,7 @@ def test_query_timestamp_as_date(): batch = next(it) value = batch.to_pydict() - assert value == { "a": [datetime.date(2023, 12, 24)] } + assert value == {"a": [datetime.date(2023, 12, 24)]} def test_allocation_erros(): @@ -676,6 +691,44 @@ def test_support_varbinary_max(): next(it) +def test_map_f32_to_f64(): + """ + ODBC drivers for PostgreSQL seem to have some trouble reporting the precision of floating point + types correctly. Using schema mapping users of this wheel which know this quirk can adopt to it + while still staying generic over the database schema. + + See issue: https://github.com/pacman82/arrow-odbc-py/issues/73 + """ + # Given + table = "MapF32ToF64" + # MS driver is pretty good, so we actually create a 32Bit float by setting precision to 17. This + # way we simulate a driver reporting a too small floating point. + setup_table(table=table, column_type="Float(17)", values=[]) + query = f"SELECT (a) FROM {table}" + + # When + map_schema = lambda schema: pa.schema( + [ + ( + name, + ( + pa.float64() + if schema.field(name).type == pa.float32() + else schema.field(name).type + ), + ) + for name in schema.names + ] + ) + reader = read_arrow_batches_from_odbc( + query=query, batch_size=1, connection_string=MSSQL, map_schema=map_schema + ) + + # Then + expected = pa.schema([("a", pa.float64())]) + assert expected == reader.schema + + def test_insert_should_raise_on_invalid_connection_string(): """ Insert should raise on invalid connection string @@ -763,7 +816,9 @@ def test_insert_from_parquet(): table = "InsertFromParquet" connection = pyodbc.connect(MSSQL) connection.execute(f"DROP TABLE IF EXISTS {table};") - connection.execute(f"CREATE TABLE {table} (sepal_length REAL, sepal_width REAL, petal_length REAL, petal_width REAL, variety VARCHAR(20) );") + connection.execute( + f"CREATE TABLE {table} (sepal_length REAL, sepal_width REAL, petal_length REAL, petal_width REAL, variety VARCHAR(20) );" + ) connection.commit() connection.close() @@ -805,7 +860,16 @@ def iter_record_batches(): # Then actual = check_output( - ["odbcsv", "fetch", "-c", MSSQL, "--max-str-len", "2000","-q", f"SELECT a FROM {table}"] + [ + "odbcsv", + "fetch", + "-c", + MSSQL, + "--max-str-len", + "2000", + "-q", + f"SELECT a FROM {table}", + ] ) assert f"a\n{large_string}\n" == actual.decode("utf8")