Fix: use preprocess instead of expanding transform dicts BREAKING (#1525)

* Fix: use preprocess instead of expanding transform dicts

* Formatting

* Remove transform variables
georgesittas authored May 3, 2023
1 parent 52c80e0 commit 3d964c6
Showing 16 changed files with 57 additions and 39 deletions.
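In practice, each dialect that previously spliced a module-level transform dict (ELIMINATE_DISTINCT_ON, UNALIAS_GROUP, ELIMINATE_QUALIFY, REMOVE_PRECISION_PARAMETERIZED_TYPES) into its TRANSFORMS mapping now maps the target expression type to transforms.preprocess() with an explicit list of transform functions. A minimal sketch of the pattern change, simplified from the dialect diffs below (the real mappings live inside each dialect's Generator class):

from sqlglot import exp, generator, transforms

# Old pattern (removed by this commit): splice a prebuilt dict of transforms.
#     TRANSFORMS = {
#         **generator.Generator.TRANSFORMS,
#         **transforms.ELIMINATE_DISTINCT_ON,  # this constant no longer exists
#     }

# New pattern: register the transform functions explicitly via preprocess().
TRANSFORMS = {
    **generator.Generator.TRANSFORMS,
    exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
}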
7 changes: 4 additions & 3 deletions sqlglot/dialects/bigquery.py
@@ -217,12 +217,11 @@ class Generator(generator.Generator):

TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
**transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.AtTimeZone: lambda self, e: self.func(
"TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
@@ -235,7 +234,9 @@ class Generator(generator.Generator):
exp.IntDiv: rename_func("DIV"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess([_unqualify_unnest]),
exp.Select: transforms.preprocess(
[_unqualify_unnest, transforms.eliminate_distinct_on]
),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
9 changes: 7 additions & 2 deletions sqlglot/dialects/databricks.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from sqlglot import exp
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
@@ -29,9 +29,14 @@ class Generator(Spark.Generator):
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
]
),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation

PARAMETER_TOKEN = "$"

2 changes: 1 addition & 1 deletion sqlglot/dialects/drill.py
@@ -128,7 +128,6 @@ class Generator(generator.Generator):

TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
@@ -146,6 +145,7 @@ class Generator(generator.Generator):
exp.StrPosition: str_position_sql,
exp.StrToDate: _str_to_date,
exp.Pow: rename_func("POW"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
10 changes: 6 additions & 4 deletions sqlglot/dialects/hive.py
@@ -280,11 +280,13 @@ class Generator(generator.Generator):

TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Select: transforms.preprocess(
[transforms.eliminate_qualify, transforms.unnest_to_explode]
[
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
]
),
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
2 changes: 1 addition & 1 deletion sqlglot/dialects/mysql.py
@@ -387,7 +387,6 @@ class Generator(generator.Generator):

TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.CurrentDate: no_paren_current_date_sql,
exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
exp.DateAdd: _date_add_sql("ADD"),
@@ -404,6 +403,7 @@ class Generator(generator.Generator):
exp.Min: min_or_least,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
6 changes: 3 additions & 3 deletions sqlglot/dialects/oracle.py
@@ -121,13 +121,14 @@ class Generator(generator.Generator):

TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.DateStrToDate: lambda self, e: self.func(
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
exp.IfNull: rename_func("NVL"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
@@ -136,7 +137,6 @@ class Generator(generator.Generator):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
exp.IfNull: rename_func("NVL"),
}

PROPERTIES_LOCATION = {
7 changes: 3 additions & 4 deletions sqlglot/dialects/postgres.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@@ -20,7 +20,6 @@
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
from sqlglot.transforms import preprocess, remove_target_from_merge

DATE_DIFF_FACTOR = {
"MICROSECOND": " * 1000000",
@@ -316,7 +315,7 @@ class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: preprocess(
exp.ColumnDef: transforms.preprocess(
[
_auto_increment_to_serial,
_serial_to_generated,
@@ -341,7 +340,7 @@
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.Merge: preprocess([remove_target_from_merge]),
exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]),
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
9 changes: 6 additions & 3 deletions sqlglot/dialects/presto.py
@@ -269,8 +269,6 @@ class Generator(generator.Generator):

TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
@@ -296,6 +294,7 @@
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
exp.GenerateSeries: _sequence_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
@@ -309,7 +308,11 @@
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_qualify, transforms.explode_to_unnest]
[
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
transforms.explode_to_unnest,
]
),
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
2 changes: 1 addition & 1 deletion sqlglot/dialects/redshift.py
@@ -90,7 +90,6 @@ class Generator(Postgres.Generator):

TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
@@ -102,6 +101,7 @@ class Generator(Postgres.Generator):
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}

2 changes: 1 addition & 1 deletion sqlglot/dialects/snowflake.py
@@ -285,7 +285,6 @@ class Generator(generator.Generator):

TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
@@ -306,6 +305,7 @@
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
5 changes: 3 additions & 2 deletions sqlglot/dialects/sqlite.py
@@ -65,8 +65,6 @@ class Generator(generator.Generator):

TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.CountIf: count_if_to_sum,
exp.CurrentDate: lambda *_: "CURRENT_DATE",
exp.CurrentTime: lambda *_: "CURRENT_TIME",
@@ -81,6 +79,9 @@
exp.Levenshtein: rename_func("EDITDIST3"),
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_qualify]
),
exp.TableSample: no_tablesample_sql,
exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
exp.TryCast: no_trycast_sql,
2 changes: 1 addition & 1 deletion sqlglot/dialects/tableau.py
@@ -26,10 +26,10 @@ class Generator(generator.Generator):

TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.If: _if_sql,
exp.Coalesce: _coalesce_sql,
exp.Count: _count_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
}

PROPERTIES_LOCATION = {
2 changes: 1 addition & 1 deletion sqlglot/dialects/teradata.py
@@ -146,9 +146,9 @@ class Generator(generator.Generator):

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
8 changes: 4 additions & 4 deletions sqlglot/dialects/tsql.py
@@ -459,22 +459,22 @@ class Generator(generator.Generator):

TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.If: rename_func("IIF"),
exp.NumberToStr: _format_sql,
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
exp.NumberToStr: _format_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
"HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
),
exp.TimeToStr: _format_sql,
}

TRANSFORMS.pop(exp.ReturnsProperty)
8 changes: 0 additions & 8 deletions sqlglot/transforms.py
@@ -261,11 +261,3 @@ def _to_sql(self, expression: exp.Expression) -> str:
return getattr(self, expression.key + "_sql")(expression)

return _to_sql


UNALIAS_GROUP = {exp.Group: preprocess([unalias_group])}
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on])}
ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify])}
REMOVE_PRECISION_PARAMETERIZED_TYPES = {
exp.Cast: preprocess([remove_precision_parameterized_types])
}
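Deleting UNALIAS_GROUP, ELIMINATE_DISTINCT_ON, ELIMINATE_QUALIFY and REMOVE_PRECISION_PARAMETERIZED_TYPES is the breaking part of the change: any external dialect that merged these dicts has to switch to the preprocess() form. A hypothetical custom dialect (not part of this commit, shown only as a migration sketch) would update roughly as follows:

from sqlglot import exp, transforms
from sqlglot.dialects.postgres import Postgres

class MyDialect(Postgres):  # hypothetical example dialect, illustration only
    class Generator(Postgres.Generator):
        TRANSFORMS = {
            **Postgres.Generator.TRANSFORMS,
            # was: **transforms.ELIMINATE_DISTINCT_ON,  (constant removed above)
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
        }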
15 changes: 15 additions & 0 deletions tests/dialects/test_redshift.py
@@ -101,7 +101,22 @@ def test_redshift(self):
self.validate_all(
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
write={
"bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
"databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
"drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
"hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
"mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1",
"oracle": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
"presto": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
"redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
"snowflake": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
"spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
"sqlite": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
"starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1",
"tableau": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
"teradata": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
"trino": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
"tsql": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
},
)
self.validate_all(
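The new Redshift test above exercises the eliminate_distinct_on rewrite end to end. The same conversion can be reproduced through the public transpile API (a usage sketch; the expected output is taken from the "snowflake" entry in the test):

import sqlglot

# Snowflake has no DISTINCT ON, so the Select preprocess transform rewrites it
# into a ROW_NUMBER() subquery during generation.
sql = sqlglot.transpile(
    "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
    read="redshift",
    write="snowflake",
)[0]
print(sql)
# SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1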
