Skip to content

Commit

Permalink
Fix: remove several mutations in Generator methods (#2009)
Browse files Browse the repository at this point in the history
* Fix: remove several mutations in Generator methods

* Make mypy happy

* Fixup

* Fixup

* Fixup
  • Loading branch information
georgesittas authored Aug 8, 2023
1 parent ae080cb commit 289493b
Show file tree
Hide file tree
Showing 13 changed files with 61 additions and 79 deletions.
12 changes: 4 additions & 8 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def func(self, expression):
this = self.sql(expression, "this")
unit = expression.args.get("unit")
unit = exp.var(unit.name.upper() if unit else "DAY")
interval = exp.Interval(this=expression.expression, unit=unit)
interval = exp.Interval(this=expression.expression.copy(), unit=unit)
return f"{data_type}_{kind}({this}, {self.sql(interval)})"

return func
Expand Down Expand Up @@ -76,16 +76,12 @@ def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsPrope
def _create_sql(self: generator.Generator, expression: exp.Create) -> str:
kind = expression.args["kind"]
returns = expression.find(exp.ReturnsProperty)

if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
expression = expression.copy()
expression.set("kind", "TABLE FUNCTION")
if isinstance(
expression.expression,
(
exp.Subquery,
exp.Literal,
),
):

if isinstance(expression.expression, (exp.Subquery, exp.Literal)):
expression.set("expression", expression.expression.this)

return self.create_sql(expression)
Expand Down
8 changes: 3 additions & 5 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ class Generator(generator.Generator):

def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
# Clickhouse errors out if we try to cast a NULL value to TEXT
expression = expression.copy()
return self.func(
"CONCAT",
*[
Expand Down Expand Up @@ -389,11 +390,7 @@ def placeholder_sql(self, expression: exp.Placeholder) -> str:
def oncluster_sql(self, expression: exp.OnCluster) -> str:
return f"ON CLUSTER {self.sql(expression, 'this')}"

def createable_sql(
self,
expression: exp.Create,
locations: dict[exp.Properties.Location, list[exp.Property]],
) -> str:
def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
kind = self.sql(expression, "kind").upper()
if kind in self.ON_CLUSTER_TARGETS and locations.get(exp.Properties.Location.POST_NAME):
this_name = self.sql(expression.this, "this")
Expand All @@ -402,4 +399,5 @@ def createable_sql(
)
this_schema = self.schema_columns_sql(expression.this)
return f"{this_name}{self.sep()}{this_properties}{self.sep()}{this_schema}"

return super().createable_sql(expression, locations)
12 changes: 7 additions & 5 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,9 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:

def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
return self.like_sql(
exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
exp.Like(
this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy()
)
)


Expand Down Expand Up @@ -410,7 +412,7 @@ def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:

def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
this = self.sql(expression, "this")
struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
struct_key = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True))
return f"{this}.{struct_key}"


Expand Down Expand Up @@ -599,7 +601,7 @@ def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
cond = expression.this.expressions[0]
self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

return self.func("sum", exp.func("if", cond, 1, 0))
return self.func("sum", exp.func("if", cond.copy(), 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
Expand Down Expand Up @@ -636,6 +638,7 @@ def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
expression = expression.copy()
this, *rest_args = expression.expressions
for arg in rest_args:
this = exp.DPipe(this=this, expression=arg)
Expand Down Expand Up @@ -685,11 +688,10 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
return names


def simplify_literal(expression: E, copy: bool = True) -> E:
def simplify_literal(expression: E) -> E:
if not isinstance(expression.expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify

expression = exp.maybe_copy(expression, copy)
simplify(expression.expression)

return expression
Expand Down
6 changes: 2 additions & 4 deletions sqlglot/dialects/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = exp.var(expression.text("unit").upper() or "DAY")
return (
f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
)
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"

return func

Expand Down Expand Up @@ -145,7 +143,7 @@ class Generator(generator.Generator):
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression.copy(), unit=exp.var('DAY')))})",
exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
Expand Down
10 changes: 5 additions & 5 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"


def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
op = "+" if isinstance(expression, exp.DateAdd) else "-"
return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
return f"{this} {op} {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"


# BigQuery -> DuckDB conversion for the DATE function
Expand Down Expand Up @@ -241,8 +241,8 @@ class Generator(generator.Generator):
exp.MonthsBetween: lambda self, e: self.func(
"DATEDIFF",
"'month'",
exp.cast(e.expression, "timestamp"),
exp.cast(e.this, "timestamp"),
exp.cast(e.expression, "timestamp", copy=True),
exp.cast(e.this, "timestamp", copy=True),
),
exp.Properties: no_properties_sql,
exp.RegexpExtract: regexp_extract_sql,
Expand Down Expand Up @@ -303,7 +303,7 @@ def interval_sql(self, expression: exp.Interval) -> str:
multiplier = 90

if multiplier:
return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"
return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this.copy(), unit=exp.var('day')))})"

return super().interval_sql(expression)

Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateS
if expression.expression.is_number:
modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
else:
modified_increment = expression.expression
modified_increment = expression.expression.copy()
if multiplier != 1:
modified_increment = exp.Mul( # type: ignore
this=modified_increment, expression=exp.Literal.number(multiplier)
Expand Down
10 changes: 4 additions & 6 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
return (
f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
)
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"

return func

Expand Down Expand Up @@ -522,7 +520,7 @@ class Generator(generator.Generator):
exp.StrToTime: _str_to_date_sql,
exp.TableSample: no_tablesample_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime")),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
Expand Down Expand Up @@ -556,12 +554,12 @@ class Generator(generator.Generator):

def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
# MySQL requires simple literal values for its LIMIT clause.
expression = simplify_literal(expression)
expression = simplify_literal(expression.copy())
return super().limit_sql(expression, top=top)

def offset_sql(self, expression: exp.Offset) -> str:
# MySQL requires simple literal values for its OFFSET clause.
expression = simplify_literal(expression)
expression = simplify_literal(expression.copy())
return super().offset_sql(expression)

def xor_sql(self, expression: exp.Xor) -> str:
Expand Down
4 changes: 3 additions & 1 deletion sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@

def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
expression = expression.copy()

this = self.sql(expression, "this")
unit = expression.args.get("unit")

expression = simplify_literal(expression.copy(), copy=False).expression
expression = simplify_literal(expression).expression
if not isinstance(expression, exp.Literal):
self.unsupported("Cannot add non literal")

Expand Down
5 changes: 3 additions & 2 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:

def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
expression = expression.copy()
return self.sql(
exp.Join(
this=exp.Unnest(
Expand Down Expand Up @@ -96,14 +97,14 @@ def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDat
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto")
return exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE").sql(dialect="presto")
return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto")


def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = expression.this

if not isinstance(this, exp.CurrentDate):
this = exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE")
this = exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE")

return self.func(
"DATE_ADD",
Expand Down
9 changes: 4 additions & 5 deletions sqlglot/dialects/teradata.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import typing as t

from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
from sqlglot.tokens import TokenType
Expand Down Expand Up @@ -194,11 +196,7 @@ def rangen_sql(self, expression: exp.RangeN) -> str:

return f"RANGE_N({this} BETWEEN {expressions_sql}{each_sql})"

def createable_sql(
self,
expression: exp.Create,
locations: dict[exp.Properties.Location, list[exp.Property]],
) -> str:
def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
kind = self.sql(expression, "kind").upper()
if kind == "TABLE" and locations.get(exp.Properties.Location.POST_NAME):
this_name = self.sql(expression.this, "this")
Expand All @@ -209,4 +207,5 @@ def createable_sql(
)
this_schema = self.schema_columns_sql(expression.this)
return f"{this_name}{this_properties}{self.sep()}{this_schema}"

return super().createable_sql(expression, locations)
6 changes: 1 addition & 5 deletions sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,11 +639,7 @@ class Generator(generator.Generator):

LIMIT_FETCH = "FETCH"

def createable_sql(
self,
expression: exp.Create,
locations: dict[exp.Properties.Location, list[exp.Property]],
) -> str:
def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
sql = self.sql(expression, "this")
properties = expression.args.get("properties")

Expand Down
12 changes: 11 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4860,8 +4860,18 @@ def maybe_parse(
return sqlglot.parse_one(sql, read=dialect, into=into, **opts)


@t.overload
def maybe_copy(instance: None, copy: bool = True) -> None:
...


@t.overload
def maybe_copy(instance: E, copy: bool = True) -> E:
return instance.copy() if copy else instance
...


def maybe_copy(instance, copy=True):
return instance.copy() if copy and instance else instance


def _is_wrong_expression(expression, into):
Expand Down
Loading

0 comments on commit 289493b

Please sign in to comment.