Skip to content

Commit

Permalink
Fix: remove unconditional expression copy (#1611)
Browse files: browse the repository at this point in the history.
  • Loading branch information
tobymao authored May 14, 2023
1 parent 31a82cc commit 29e5af2
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 20 deletions.
6 changes: 3 additions & 3 deletions sqlglot/dataframe/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,9 +570,9 @@ def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
r_expressions.append(l_column)
r_columns_unused.remove(l_column)
else:
- r_expressions.append(exp.alias_(exp.Null(), l_column))
+ r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
for r_column in r_columns_unused:
- l_expressions.append(exp.alias_(exp.Null(), r_column))
+ l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
r_expressions.append(r_column)
r_df = (
other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
Expand Down Expand Up @@ -761,7 +761,7 @@ def withColumnRenamed(self, existing: str, new: str):
raise ValueError("Tried to rename a column that doesn't exist")
for existing_column in existing_columns:
if isinstance(existing_column, exp.Column):
- existing_column.replace(exp.alias_(existing_column.copy(), new))
+ existing_column.replace(exp.alias_(existing_column, new))
else:
existing_column.set("alias", exp.to_identifier(new))
return self.copy(expression=expression)
Expand Down
7 changes: 3 additions & 4 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4951,6 +4951,7 @@ def alias_(
table: bool | t.Sequence[str | Identifier] = False,
quoted: t.Optional[bool] = None,
dialect: DialectType = None,
+ copy: bool = True,
**opts,
):
"""Create an Alias expression.
Expand All @@ -4970,18 +4971,17 @@ def alias_(
table: Whether or not to create a table alias, can also be a list of columns.
quoted: whether or not to quote the alias
dialect: the dialect used to parse the input expression.
+ copy: Whether or not to copy the expression.
**opts: other options to use to parse the input expressions.
Returns:
Alias: the aliased expression
"""
- exp = maybe_parse(expression, dialect=dialect, **opts)
+ exp = maybe_parse(expression, dialect=dialect, copy=copy, **opts)
alias = to_identifier(alias, quoted=quoted)

if table:
table_alias = TableAlias(this=alias)

- exp = exp.copy() if isinstance(expression, Expression) else exp
exp.set("alias", table_alias)

if not isinstance(table, bool):
Expand All @@ -4997,7 +4997,6 @@ def alias_(
# [1]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls

if "alias" in exp.arg_types and not isinstance(exp, Window):
- exp = exp.copy()
exp.set("alias", alias)
return exp
return Alias(this=exp, alias=alias)
Expand Down
8 changes: 5 additions & 3 deletions sqlglot/optimizer/eliminate_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,17 @@ def _eliminate_union(scope, existing_ctes, taken):
# Try to maintain the selections
expressions = scope.selects
selects = [
- exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
+ exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False)
for e in expressions
if e.alias_or_name
]
# If not all selections have an alias, just select *
if len(selects) != len(expressions):
selects = ["*"]

- scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias)))
+ scope.expression.replace(
+ exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False))
+ )

if not duplicate_cte_alias:
existing_ctes[scope.expression] = alias
Expand Down Expand Up @@ -153,7 +155,7 @@ def _eliminate_cte(scope, existing_ctes, taken):
for child_scope in scope.parent.traverse():
for table, source in child_scope.selected_sources.values():
if source is scope:
- new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name)
+ new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name, copy=False)
table.replace(new_table)

return cte
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/isolate_table_selects.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None):
source.replace(
exp.select("*")
.from_(
- alias(source.copy(), source.name or source.alias, table=True),
+ alias(source, source.name or source.alias, table=True),
copy=False,
)
.subquery(source.alias, copy=False)
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/merge_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
elif isinstance(source, exp.Table) and source.alias:
source.set("alias", new_alias)
elif isinstance(source, exp.Table):
- source.replace(exp.alias_(source.copy(), new_alias))
+ source.replace(exp.alias_(source, new_alias))

for column in inner_scope.source_columns(conflict):
column.set("table", exp.to_identifier(new_name))
Expand Down
4 changes: 3 additions & 1 deletion sqlglot/optimizer/pushdown_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def _remove_unused_selections(scope, parent_selections, schema):

for name in sorted(parent_selections):
if name not in names:
- new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name))
+ new_selections.append(
+ alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
+ )

# If there are no remaining selections, just select a single constant
if not new_selections:
Expand Down
12 changes: 8 additions & 4 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _expand_using(scope, resolver):

# Ensure selects keep their output name
if isinstance(column.parent, exp.Select):
- replacement = exp.alias_(replacement, alias=column.name)
+ replacement = alias(replacement, alias=column.name, copy=False)

scope.replace(column, replacement)

Expand Down Expand Up @@ -311,14 +311,18 @@ def _expand_stars(scope, resolver, using_column_tables):
coalesce = [exp.column(name, table=table) for table in tables]

new_selections.append(
- exp.alias_(
- exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
+ alias(
+ exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
+ alias=name,
+ copy=False,
)
)
elif name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table)
- new_selections.append(alias(column, alias_) if alias_ != name else column)
+ new_selections.append(
+ alias(column, alias_, copy=False) if alias_ != name else column
+ )
else:
return
scope.expression.set("expressions", new_selections)
Expand Down
3 changes: 2 additions & 1 deletion sqlglot/optimizer/qualify_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
if not source.alias:
source = source.replace(
alias(
- source.copy(),
+ source,
  name if name else next_name(),
+ copy=True,
table=True,
)
)
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
for select in expression.selects:
if not select.alias_or_name:
alias = find_new_name(taken, "_c")
- select.replace(exp.alias_(select.copy(), alias))
+ select.replace(exp.alias_(select, alias))
taken.add(alias)

outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
Expand All @@ -102,7 +102,7 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
for expr in qualify_filters.find_all((exp.Window, exp.Column)):
if isinstance(expr, exp.Window):
alias = find_new_name(expression.named_selects, "_w")
- expression.select(exp.alias_(expr.copy(), alias), copy=False)
+ expression.select(exp.alias_(expr, alias), copy=False)
column = exp.column(alias)
if isinstance(expr.parent, exp.Qualify):
qualify_filters = column
Expand Down

0 comments on commit 29e5af2

Please sign in to comment.