Skip to content

Commit

Permalink
[sqlachemy] Improve the performance of get_view_names
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley committed Oct 14, 2022
1 parent 68f3d9d commit a5a6c75
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 3 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"pytest",
"pytest-runner",
"click",
"sqlalchemy-utils",
]

setup(
Expand Down
59 changes: 57 additions & 2 deletions tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from unittest.mock import patch

import pytest
import sqlalchemy as sqla
from sqlalchemy.sql import and_, or_, not_
from sqlalchemy_utils import create_view

from tests.unit.conftest import sqlalchemy_version
from trino.sqlalchemy.datatype import JSON
Expand All @@ -20,8 +23,16 @@
@pytest.fixture
def trino_connection(run_trino, request):
_, host, port = run_trino
engine = sqla.create_engine(f"trino://test@{host}:{port}/{request.param}",
connect_args={"source": "test", "max_attempts": 1})

engine = sqla.create_engine(
f"trino://test@{host}:{port}/{request.param}",
connect_args={
"max_attempts": 1,
"schema": "test",
"source": "test",
},
)

yield engine, engine.connect()


Expand Down Expand Up @@ -368,3 +379,47 @@ def test_get_table_comment(trino_connection):
assert actual['text'] is None
finally:
metadata.drop_all(engine)


@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
@pytest.mark.parametrize('schema', [None, 'test'])
def test_get_view_names(trino_connection, schema):
engine, conn = trino_connection
insp = sqla.inspect(engine)
name = schema or insp.dialect._get_default_schema_name(conn)
metadata = sqla.MetaData(schema=name)

if not engine.dialect.has_schema(engine, name):
engine.execute(sqla.schema.CreateSchema(name))

try:
create_view(
'my_view',
sqla.select(
[
sqla.Table(
'my_table',
metadata,
sqla.Column('id', sqla.Integer),
),
],
),
metadata,
cascade_on_drop=False,
)

metadata.create_all(engine)
assert insp.get_view_names(schema) == ['my_view']
finally:
metadata.drop_all(engine)


@patch('trino.sqlalchemy.dialect.TrinoDialect._get_default_schema_name')
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
def test_get_view_names_raises(mock_get_default_schema_name, trino_connection):
engine, _ = trino_connection
insp = sqla.inspect(engine)
mock_get_default_schema_name.return_value = None

with pytest.raises(sqla.exc.NoSuchTableError):
insp.get_view_names(None)
3 changes: 2 additions & 1 deletion trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,9 @@ def get_view_names(self, connection: Connection, schema: str = None, **kw) -> Li
query = dedent(
"""
SELECT "table_name"
FROM "information_schema"."views"
FROM "information_schema"."tables"
WHERE "table_schema" = :schema
AND "table_type" = 'VIEW'
"""
).strip()
res = connection.execute(sql.text(query), schema=schema)
Expand Down

0 comments on commit a5a6c75

Please sign in to comment.