Skip to content

Commit

Permalink
Fix: set quote_identifiers in qualify, add normalize flag in schema (#1701)
Browse files Browse the repository at this point in the history

* Fix: set quote_identifiers in qualify, add normalize flag in schema

* import typing as t

* Fixup

* PR feedback

* Use new quote_identifiers rule before annotate_types

* Reset quote_identifiers kwarg to False in optimize

* Formatting

* Set kwargs instead of positional arguments in qualify

* Include quote_identifiers rule in test_canonicalize

* Formatting

* PR feedback

* Remove copy arg from quote_identifiers
  • Loading branch information
georgesittas authored May 30, 2023
1 parent 6045b74 commit 910166c
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 68 deletions.
2 changes: 1 addition & 1 deletion sqlglot/dataframe/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
if optimize:
select_expression = optimize_func(select_expression, identify="always")
select_expression = t.cast(exp.Select, optimize_func(select_expression))
select_expression = df._replace_cte_names_with_hashes(select_expression)
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
if expression_type == exp.Cache:
Expand Down
13 changes: 2 additions & 11 deletions sqlglot/optimizer/canonicalize.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,25 @@
from __future__ import annotations

import itertools
import typing as t

from sqlglot import exp
from sqlglot.optimizer.qualify_columns import quote_identifiers

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType


def canonicalize(
expression: exp.Expression, identify: bool = True, dialect: DialectType = None
) -> exp.Expression:
def canonicalize(expression: exp.Expression) -> exp.Expression:
"""Converts a sql expression into a standard form.
This method relies on annotate_types because many of the
conversions rely on type inference.
Args:
expression: The expression to canonicalize.
identify: Whether or not to force identify identifier.
"""
exp.replace_children(expression, canonicalize, identify=identify, dialect=dialect)
exp.replace_children(expression, canonicalize)

expression = add_text_to_concat(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bool_predicates(expression)
expression = quote_identifiers(expression, dialect=dialect, identify=identify)

return expression

Expand Down
5 changes: 2 additions & 3 deletions sqlglot/optimizer/normalize_identifiers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType


def normalize_identifiers(
expression: exp.Expression, dialect: DialectType = None
) -> exp.Expression:
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
"""
Normalize all unquoted identifiers to either lower or upper case, depending on
the dialect. This essentially makes those identifiers case-insensitive.
Expand Down
14 changes: 8 additions & 6 deletions sqlglot/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify import qualify
from sqlglot.optimizer.qualify_columns import quote_identifiers
from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
Expand All @@ -31,6 +32,7 @@
merge_subqueries,
eliminate_joins,
eliminate_ctes,
quote_identifiers,
annotate_types,
canonicalize,
simplify,
Expand All @@ -45,7 +47,7 @@ def optimize(
dialect: DialectType = None,
rules: t.Sequence[t.Callable] = RULES,
**kwargs,
):
) -> exp.Expression:
"""
Rewrite a sqlglot AST into an optimized form.
Expand All @@ -63,11 +65,11 @@ def optimize(
dialect: The dialect to parse the sql string.
rules: sequence of optimizer rules to use.
Many of the rules require tables and columns to be qualified.
Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know
what you're doing!
Do not remove `qualify` from the sequence of rules unless you know what you're doing!
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
Returns:
sqlglot.Expression: optimized expression
The optimized expression.
"""
schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
possible_kwargs = {
Expand All @@ -79,8 +81,8 @@ def optimize(
"quote_identifiers": False, # this happens in canonicalize
**kwargs,
}
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)

expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
Expand All @@ -89,4 +91,4 @@ def optimize(
}
expression = rule(expression, **rule_kwargs)

return expression
return t.cast(exp.Expression, expression)
25 changes: 14 additions & 11 deletions sqlglot/optimizer/qualify.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@

from sqlglot import exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.optimizer import qualify_columns
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import (
qualify_columns as qualify_columns_func,
quote_identifiers as quote_identifiers_func,
validate_qualify_columns as validate_qualify_columns_func,
)
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.schema import Schema, ensure_schema

Expand All @@ -20,6 +24,7 @@ def qualify(
expand_alias_refs: bool = True,
infer_schema: t.Optional[bool] = None,
isolate_tables: bool = False,
qualify_columns: bool = True,
validate_qualify_columns: bool = True,
quote_identifiers: bool = True,
identify: bool = True,
Expand All @@ -44,11 +49,13 @@ def qualify(
expand_alias_refs: Whether or not to expand references to aliases.
infer_schema: Whether or not to infer the schema if missing.
isolate_tables: Whether or not to isolate table selects.
qualify_columns: Whether or not to qualify columns.
validate_qualify_columns: Whether or not to validate columns.
quote_identifiers: Whether or not to run the quote_identifiers step.
This step is necessary to ensure correctness for case sensitive queries.
But this flag is provided in case this step is performed at a later time.
identify: If True, quote all identifiers, else only necessary ones.
Returns:
The qualified expression.
"""
Expand All @@ -59,19 +66,15 @@ def qualify(
if isolate_tables:
expression = isolate_table_selects(expression, schema=schema)

expression = qualify_columns.qualify_columns(
expression,
schema,
expand_alias_refs=expand_alias_refs,
infer_schema=infer_schema,
)
if qualify_columns:
expression = qualify_columns_func(
expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema
)

if quote_identifiers:
expression = expression.transform(
qualify_columns.quote_identifiers, dialect, identify, copy=False
)
expression = quote_identifiers_func(expression, dialect=dialect, identify=identify)

if validate_qualify_columns:
qualify_columns.validate_qualify_columns(expression)
validate_qualify_columns_func(expression)

return expression
27 changes: 15 additions & 12 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import case_sensitive, seq_get
Expand Down Expand Up @@ -414,19 +415,21 @@ def _qualify_outputs(scope):
scope.expression.set("expressions", new_selections)


def quote_identifiers(
expression: exp.Expression, dialect: DialectType, identify: bool
) -> exp.Expression:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
"""Makes sure all identifiers that need to be quoted are quoted."""
if isinstance(expression, exp.Identifier):
name = expression.this
expression.set(
"quoted",
identify
or case_sensitive(name, dialect=dialect)
or not exp.SAFE_IDENTIFIER_RE.match(name),
)
return expression

def _quote(expression: E) -> E:
if isinstance(expression, exp.Identifier):
name = expression.this
expression.set(
"quoted",
identify
or case_sensitive(name, dialect=dialect)
or not exp.SAFE_IDENTIFIER_RE.match(name),
)
return expression

return expression.transform(_quote, copy=False)


class Resolver:
Expand Down
5 changes: 4 additions & 1 deletion sqlglot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,19 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
2. {db: {table: set(*cols)}}}
3. {catalog: {db: {table: set(*cols)}}}}
dialect: The dialect to be used for custom type mappings & parsing string arguments.
normalize: Whether to normalize identifier names according to the given dialect or not.
"""

def __init__(
self,
schema: t.Optional[t.Dict] = None,
visible: t.Optional[t.Dict] = None,
dialect: DialectType = None,
normalize: bool = True,
) -> None:
self.dialect = dialect
self.visible = visible or {}
self.normalize = normalize
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

super().__init__(self._normalize(schema or {}))
Expand Down Expand Up @@ -333,7 +336,7 @@ def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = Non

name = identifier.name

if identifier.quoted:
if not self.normalize or identifier.quoted:
return name

return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower()
Expand Down
4 changes: 2 additions & 2 deletions tests/fixtures/optimizer/canonicalize.sql
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ SELECT CAST(1 AS VARCHAR) AS "a" FROM "w" AS "w";
SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w;
SELECT 1 + 3.2 AS "a" FROM "w" AS "w";

SELECT CAST("2022-01-01" AS DATE) + INTERVAL '1' day;
SELECT CAST("2022-01-01" AS DATE) + INTERVAL '1' day AS "_col_0";
SELECT CAST('2022-01-01' AS DATE) + INTERVAL '1' day;
SELECT CAST('2022-01-01' AS DATE) + INTERVAL '1' day AS "_col_0";

--------------------------------------
-- Ensure boolean predicates
Expand Down
37 changes: 16 additions & 21 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ def test_canonicalize(self):
optimize = partial(
optimizer.optimize,
rules=[
optimizer.qualify_tables.qualify_tables,
optimizer.qualify_columns.qualify_columns,
optimizer.qualify.qualify,
optimizer.qualify_columns.quote_identifiers,
annotate_types,
optimizer.canonicalize.canonicalize,
],
Expand Down Expand Up @@ -699,23 +699,18 @@ def test_quotes(self):
}
}

self.assertEqual(
optimizer.qualify.qualify(
parse_one(
"""
SELECT * FROM example."source"
"""
),
dialect="snowflake",
schema=schema,
).sql(pretty=True),
parse_one(
"""
SELECT
"source"."ID" AS "ID",
"source"."name" AS "name",
"source"."payload" AS "payload"
FROM "EXAMPLE"."source" AS "source"
expected = parse_one(
"""
).sql(pretty=True),
)
SELECT
"source"."ID" AS "ID",
"source"."name" AS "name",
"source"."payload" AS "payload"
FROM "EXAMPLE"."source" AS "source"
""",
read="snowflake",
).sql(pretty=True, dialect="snowflake")

for func in (optimizer.qualify.qualify, optimizer.optimize):
source_query = parse_one('SELECT * FROM example."source"', read="snowflake")
transformed = func(source_query, dialect="snowflake", schema=schema)
self.assertEqual(transformed.sql(pretty=True, dialect="snowflake"), expected)
4 changes: 4 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,7 @@ def test_schema_normalization(self):
# Check that names are normalized to uppercase for Snowflake
schema = MappingSchema(schema={"x": {"foo": "int", '"bLa"': "int"}}, dialect="snowflake")
self.assertEqual(schema.column_names(exp.Table(this="x")), ["FOO", "bLa"])

# Check that switching off the normalization logic works as expected
schema = MappingSchema(schema={"x": {"foo": "int"}}, normalize=False, dialect="snowflake")
self.assertEqual(schema.column_names(exp.Table(this="x")), ["foo"])

0 comments on commit 910166c

Please sign in to comment.