Skip to content

Commit

Permalink
Fix spurious warnings and bogus index when reflecting Iceberg tables
Browse files Browse the repository at this point in the history
  • Loading branch information
metadaddy committed Jan 24, 2025
1 parent 24cc388 commit 264ea95
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
8 changes: 8 additions & 0 deletions tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def trino_connection(run_trino, request):
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.skipif(trino_version() == 351, reason="connector name is not available in older Trino versions")
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_select_query(trino_connection):
_, conn = trino_connection
Expand Down Expand Up @@ -71,6 +72,7 @@ def assert_column(table, column_name, column_type):
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.skipif(trino_version() == 351, reason="connector name is not available in older Trino versions")
@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
def test_select_specific_columns(trino_connection):
_, conn = trino_connection
Expand Down Expand Up @@ -240,6 +242,7 @@ def test_insert_multiple_statements(trino_connection):
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.skipif(trino_version() == 351, reason="connector name is not available in older Trino versions")
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_operators(trino_connection):
_, conn = trino_connection
Expand All @@ -260,6 +263,7 @@ def test_operators(trino_connection):
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.skipif(trino_version() == 351, reason="connector name is not available in older Trino versions")
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_conjunctions(trino_connection):
_, conn = trino_connection
Expand Down Expand Up @@ -300,6 +304,7 @@ def test_textual_sql(trino_connection):
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.skipif(trino_version() == 351, reason="connector name is not available in older Trino versions")
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_alias(trino_connection):
_, conn = trino_connection
Expand All @@ -323,6 +328,7 @@ def test_alias(trino_connection):
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.skipif(trino_version() == 351, reason="connector name is not available in older Trino versions")
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_subquery(trino_connection):
_, conn = trino_connection
Expand All @@ -341,6 +347,7 @@ def test_subquery(trino_connection):
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.skipif(trino_version() == 351, reason="connector name is not available in older Trino versions")
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_joins(trino_connection):
_, conn = trino_connection
Expand All @@ -360,6 +367,7 @@ def test_joins(trino_connection):
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.skipif(trino_version() == 351, reason="connector name is not available in older Trino versions")
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_cte(trino_connection):
_, conn = trino_connection
Expand Down
23 changes: 22 additions & 1 deletion trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,18 @@ def _get_partitions(
partition_names = [desc[0] for desc in res.cursor.description]
return partition_names

def _get_connector_name(self, connection: Connection, catalog_name: str):
query = dedent(
"""
SELECT
"connector_name"
FROM "system"."metadata"."catalogs"
WHERE "catalog_name" = :catalog_name
"""
).strip()
res = connection.execute(sql.text(query), {"catalog_name": catalog_name})
return res.scalar()

def get_pk_constraint(self, connection: Connection, table_name: str, schema: str = None, **kw) -> Dict[str, Any]:
"""Trino has no support for primary keys. Returns a dummy"""
return dict(name=None, constrained_columns=[])
Expand Down Expand Up @@ -322,11 +334,20 @@ def get_indexes(self, connection: Connection, table_name: str, schema: str = Non
if not self.has_table(connection, table_name, schema):
raise exc.NoSuchTableError(f"schema={schema}, table={table_name}")

catalog_name = self._get_default_catalog_name(connection)
if catalog_name is None:
raise exc.NoSuchTableError("catalog is required in connection")
connector_name = self._get_connector_name(connection, catalog_name)
if connector_name is None:
raise exc.NoSuchTableError("connector name is required")
if connector_name != "hive":
return []

partitioned_columns = None
try:
partitioned_columns = self._get_partitions(connection, f"{table_name}", schema)
except Exception as e:
# e.g. it's not a Hive table or an unpartitioned Hive table
# e.g. it's an unpartitioned Hive table
logger.debug("Couldn't fetch partition columns. schema: %s, table: %s, error: %s", schema, table_name, e)
if not partitioned_columns:
return []
Expand Down

0 comments on commit 264ea95

Please sign in to comment.