Skip to content

Commit

Permalink
Refactor: factor out the name sequence generation logic (#1716)
Browse files Browse the repository at this point in the history
* Refactor: factor out the name sequence generation logic

* Formatting

* Add test

* Formatting
  • Loading branch information
georgesittas authored Jun 2, 2023
1 parent 92dbace commit 5d6fbfe
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
ColumnOrName = t.Union[Column, str]
ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
ColumnOrLiteral = t.Union[
Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime
]
SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]]
OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]
7 changes: 7 additions & 0 deletions sqlglot/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from contextlib import contextmanager
from copy import copy
from enum import Enum
from itertools import count

if t.TYPE_CHECKING:
from sqlglot import exp
Expand Down Expand Up @@ -303,6 +304,12 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:
return new


def name_sequence(prefix: str) -> t.Callable[[], str]:
    """
    Builds a zero-argument callable that produces numbered names sharing a prefix.

    Each call to the returned callable yields the prefix followed by the next
    integer, starting at 0 (e.g. "a0", "a1", "a2", ... for the prefix "a").
    Every call to `name_sequence` creates its own independent counter.

    Args:
        prefix: The string placed in front of each generated number.

    Returns:
        A callable that returns the next name in the sequence on each call.
    """
    counter = count()

    def next_name() -> str:
        return f"{prefix}{next(counter)}"

    return next_name


def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
"""Returns a dictionary created from an object's attributes."""
return {
Expand Down
38 changes: 23 additions & 15 deletions sqlglot/optimizer/qualify_tables.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.helper import csv_reader
from sqlglot._typing import E
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema


def qualify_tables(expression, db=None, catalog=None, schema=None):
def qualify_tables(
expression: E,
db: t.Optional[str] = None,
catalog: t.Optional[str] = None,
schema: t.Optional[Schema] = None,
) -> E:
"""
Rewrite sqlglot AST to have fully qualified tables. Additionally, this
replaces "join constructs" (*) by equivalent SELECT * subqueries.
Expand All @@ -21,19 +29,17 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
Args:
expression (sqlglot.Expression): expression to qualify
db (str): Database name
catalog (str): Catalog name
expression: Expression to qualify
db: Database name
catalog: Catalog name
schema: A schema to populate
Returns:
sqlglot.Expression: qualified expression
The qualified expression.
(*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
"""
sequence = itertools.count()

next_name = lambda: f"_q_{next(sequence)}"
next_alias_name = name_sequence("_q_")

for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
Expand All @@ -44,13 +50,13 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))

if not derived_table.args.get("alias"):
alias_ = next_name()
alias_ = next_alias_name()
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
scope.rename_source(None, alias_)

pivots = derived_table.args.get("pivots")
if pivots and not pivots[0].alias:
pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))
pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))

for name, source in scope.sources.items():
if isinstance(source, exp.Table):
Expand All @@ -64,15 +70,17 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
source = source.replace(
alias(
source,
name or source.name or next_name(),
name or source.name or next_alias_name(),
copy=True,
table=True,
)
)

pivots = source.args.get("pivots")
if pivots and not pivots[0].alias:
pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))
pivots[0].set(
"alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
)

if schema and isinstance(source.this, exp.ReadCSV):
with csv_reader(source.this) as reader:
Expand All @@ -83,11 +91,11 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression
table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name())
udtf.set("alias", table_alias)

if not table_alias.name:
table_alias.set("this", next_name())
table_alias.set("this", next_alias_name())
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
Expand Down
23 changes: 9 additions & 14 deletions sqlglot/optimizer/unnest_subqueries.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import itertools

from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, traverse_scope


Expand All @@ -22,27 +21,27 @@ def unnest_subqueries(expression):
Returns:
sqlglot.Expression: unnested expression
"""
sequence = itertools.count()
next_alias_name = name_sequence("_u_")

for scope in traverse_scope(expression):
select = scope.expression
parent = select.parent_select
if not parent:
continue
if scope.external_columns:
decorrelate(select, parent, scope.external_columns, sequence)
decorrelate(select, parent, scope.external_columns, next_alias_name)
elif scope.scope_type == ScopeType.SUBQUERY:
unnest(select, parent, sequence)
unnest(select, parent, next_alias_name)

return expression


def unnest(select, parent_select, sequence):
def unnest(select, parent_select, next_alias_name):
if len(select.selects) > 1:
return

predicate = select.find_ancestor(exp.Condition)
alias = _alias(sequence)
alias = next_alias_name()

if not predicate or parent_select is not predicate.parent_select:
return
Expand Down Expand Up @@ -87,13 +86,13 @@ def unnest(select, parent_select, sequence):
)


def decorrelate(select, parent_select, external_columns, sequence):
def decorrelate(select, parent_select, external_columns, next_alias_name):
where = select.args.get("where")

if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
return

table_alias = _alias(sequence)
table_alias = next_alias_name()
keys = []

# for all external columns in the where statement, find the relevant predicate
Expand Down Expand Up @@ -136,7 +135,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
group_by.append(key)
else:
if key not in key_aliases:
key_aliases[key] = _alias(sequence)
key_aliases[key] = next_alias_name()
# all predicates that are equalities must also be in the unique
# so that we don't do a many to many join
if isinstance(predicate, exp.EQ) and key not in group_by:
Expand Down Expand Up @@ -244,10 +243,6 @@ def remove_aggs(node):
)


def _alias(sequence):
return f"_u_{next(sequence)}"


def _replace(expression, condition):
    """
    Replaces `expression` with the result of passing `condition` through
    `exp.condition`, returning whatever `expression.replace` returns.
    """
    replacement = exp.condition(condition)
    return expression.replace(replacement)

Expand Down
6 changes: 3 additions & 3 deletions sqlglot/planner.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

import itertools
import math
import typing as t

from sqlglot import alias, exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.eliminate_joins import join_condition


Expand Down Expand Up @@ -121,15 +121,15 @@ def from_expression(
projections = [] # final selects in this chain of steps representing a select
operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
aggregations = []
sequence = itertools.count()
next_operand_name = name_sequence("_a_")

def extract_agg_operands(expression):
for agg in expression.find_all(exp.AggFunc):
for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
operands[operand] = next_operand_name()
operand.replace(exp.column(operands[operand], quoted=True))

for e in expression.expressions:
Expand Down
6 changes: 2 additions & 4 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations

import itertools
import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
from sqlglot.generator import Generator
Expand Down Expand Up @@ -253,8 +252,7 @@ def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expre

def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.With) and expression.recursive:
sequence = itertools.count()
next_name = lambda: f"_c_{next(sequence)}"
next_name = name_sequence("_c_")

for cte in expression.expressions:
if not cte.args["alias"].columns:
Expand Down
13 changes: 12 additions & 1 deletion tests/test_helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest

from sqlglot.dialects import BigQuery, Dialect, Snowflake
from sqlglot.helper import tsort
from sqlglot.helper import name_sequence, tsort


class TestHelper(unittest.TestCase):
Expand Down Expand Up @@ -56,3 +56,14 @@ def test_compare_dialects(self):
self.assertTrue(snowflake_object in {"snowflake", "bigquery"})
self.assertFalse(snowflake_class in {"bigquery", "redshift"})
self.assertFalse(snowflake_object in {"bigquery", "redshift"})

def test_name_sequence(self):
    """Sequences built from different prefixes each count up from 0 independently."""
    seq_a = name_sequence("a")
    seq_b = name_sequence("b")

    # The calls are interleaved on purpose: advancing one sequence must not
    # affect the counter of the other.
    interleaved_calls = [
        (seq_a, "a0"),
        (seq_a, "a1"),
        (seq_b, "b0"),
        (seq_a, "a2"),
        (seq_b, "b1"),
        (seq_b, "b2"),
    ]

    for generate, expected in interleaved_calls:
        self.assertEqual(generate(), expected)

0 comments on commit 5d6fbfe

Please sign in to comment.