diff --git a/setup.py b/setup.py index 2695a530..f056c841 100755 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ "pytest", "pytest-runner", "click", + "sqlalchemy-utils", ] setup( diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 8fca22c1..c6de2eef 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -12,6 +12,7 @@ import pytest import sqlalchemy as sqla from sqlalchemy.sql import and_, or_, not_ +from sqlalchemy_utils import create_view from tests.unit.conftest import sqlalchemy_version from trino.sqlalchemy.datatype import JSON @@ -368,3 +369,43 @@ def test_get_table_comment(trino_connection): assert actual['text'] is None finally: metadata.drop_all(engine) + + +@pytest.mark.parametrize('trino_connection', ['memory/test'], indirect=True) +@pytest.mark.parametrize('schema', [None, 'test']) +def test_get_view_names(trino_connection, schema): + engine, conn = trino_connection + name = schema or engine.dialect._get_default_schema_name(conn) + metadata = sqla.MetaData(schema=name) + + if not engine.dialect.has_schema(conn, name): + engine.execute(sqla.schema.CreateSchema(name)) + + try: + create_view( + 'my_view', + sqla.select( + [ + sqla.Table( + 'my_table', + metadata, + sqla.Column('id', sqla.Integer), + ), + ], + ), + metadata, + cascade_on_drop=False, + ) + + metadata.create_all(engine) + assert sqla.inspect(engine).get_view_names(schema) == ['my_view'] + finally: + metadata.drop_all(engine) + + +@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +def test_get_view_names_raises(trino_connection): + engine, _ = trino_connection + + with pytest.raises(sqla.exc.NoSuchTableError): + sqla.inspect(engine).get_view_names(None) diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index b200771c..ffe928ed 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -216,8 +216,9 @@ def get_view_names(self, connection: Connection, schema: str = None, **kw) -> Li query = dedent( """ SELECT "table_name" - FROM "information_schema"."views" + FROM "information_schema"."tables" WHERE "table_schema" = :schema + AND "table_type" = 'VIEW' """ ).strip() res = connection.execute(sql.text(query), schema=schema)