diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index aa56b83ff1..ce36705b55 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,4 +1,5 @@
 import itertools
+import logging
 import typing as t
 from collections import defaultdict
 from enum import Enum, auto
@@ -7,6 +8,8 @@
 from sqlglot.errors import OptimizeError
 from sqlglot.helper import find_new_name
 
+logger = logging.getLogger("sqlglot")
+
 
 class ScopeType(Enum):
     ROOT = auto()
@@ -536,7 +539,11 @@ def _traverse_scope(scope):
     elif isinstance(scope.expression, exp.UDTF):
         pass
     else:
-        raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
+        logger.warning(
+            "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
+        )
+        return
+
     yield scope
 
 
@@ -576,6 +583,8 @@ def _traverse_ctes(scope):
             if isinstance(union, exp.Union):
                 recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
 
+        child_scope = None
+
         for child_scope in _traverse_scope(
             scope.branch(
                 cte.this,
@@ -593,7 +602,8 @@ def _traverse_ctes(scope):
                 child_scope.add_source(alias, recursive_scope)
 
         # append the final child_scope yielded
-        scope.cte_scopes.append(child_scope)
+        if child_scope is not None:
+            scope.cte_scopes.append(child_scope)
 
     scope.sources.update(sources)
 
@@ -634,6 +644,9 @@ def _traverse_tables(scope):
                 sources[source_name] = expression
             continue
 
+        if not isinstance(expression, exp.DerivedTable):
+            continue
+
         if isinstance(expression, exp.UDTF):
             lateral_sources = sources
             scope_type = ScopeType.UDTF
diff --git a/tests/helpers.py b/tests/helpers.py
index 30aeff76ca..cc085b3eff 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -22,7 +22,8 @@ def _extract_meta(sql):
 
 def assert_logger_contains(message, logger, level="error"):
     output = "\n".join(str(args[0][0]) for args in getattr(logger, level).call_args_list)
-    assert message in output
+    if message not in output:
+        raise AssertionError(f"Expected '{message}' not in {output}")
 
 
 def load_sql_fixtures(filename):
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index b7425af518..40eef9fe14 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -1,6 +1,7 @@
 import unittest
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from functools import partial
+from unittest.mock import patch
 
 import duckdb
 from pandas.testing import assert_frame_equal
@@ -14,6 +15,7 @@
 from tests.helpers import (
     TPCDS_SCHEMA,
     TPCH_SCHEMA,
+    assert_logger_contains,
     load_sql_fixture_pairs,
     load_sql_fixtures,
     string_to_bool,
@@ -411,6 +413,17 @@ def test_scope(self):
             {"s.b"},
         )
 
+    @patch("sqlglot.optimizer.scope.logger")
+    def test_scope_warning(self, logger):
+        # The CTE body `@y` cannot be traversed: expect exactly one scope and a logged warning.
+        scopes = traverse_scope(parse_one("WITH q AS (@y) SELECT * FROM q"))
+        self.assertEqual(len(scopes), 1)
+        assert_logger_contains(
+            "Cannot traverse scope %s with type '%s'",
+            logger,
+            level="warning",
+        )
+
     def test_literal_type_annotation(self):
         tests = {
             "SELECT 5": exp.DataType.Type.INT,