Skip to content

Commit

Permalink
Fix(snowflake): rename GenerateSeries, include offset in unnest_sql (#…
Browse files Browse the repository at this point in the history
…2243)

* Fix(snowflake): rename GenerateSeries, include offset in unnest_sql

* Add list constructor back

* Remove unnecessary exp.null() in Generator.filter_sql

* Parse ARRAY_GENERATE_RANGE using inclusive end, fix its generation
  • Loading branch information
georgesittas authored Sep 17, 2023
1 parent 8ebbfe2 commit 829415c
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 14 deletions.
5 changes: 2 additions & 3 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
date_add_interval_sql,
datestrtodate_sql,
format_time_lambda,
if_sql,
inline_array_sql,
json_keyvalue_comma_sql,
max_or_greatest,
Expand Down Expand Up @@ -433,9 +434,7 @@ class Generator(generator.Generator):
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.Hex: rename_func("TO_HEX"),
exp.If: lambda self, e: self.func(
"IF", e.this, e.args.get("true"), e.args.get("false") or "NULL"
),
exp.If: if_sql(false_value="NULL"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
Expand Down
16 changes: 12 additions & 4 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,18 @@ def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -
return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(self: Generator, expression: exp.If) -> str:
return self.func(
"IF", expression.this, expression.args.get("true"), expression.args.get("false")
)
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a transform that renders an ``exp.If`` node as a function call.

    Args:
        name: Function name handed to ``self.func`` (``"IF"`` by default).
        false_value: Fallback used as the last argument whenever the
            expression's ``"false"`` arg is missing or falsy.

    Returns:
        A callable taking ``(generator, expression)`` and returning whatever
        ``generator.func`` produces for the condition, the ``"true"`` arg and
        the resolved ``"false"`` arg.
    """

    def _if_sql(self: Generator, expression: exp.If) -> str:
        branches = expression.args
        # Only fall back to `false_value` when no usable "false" branch exists.
        otherwise = branches.get("false") or false_value
        return self.func(name, expression.this, branches.get("true"), otherwise)

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ class Generator(generator.Generator):
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql,
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
exp.IsNan: rename_func("ISNAN"),
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ class Generator(generator.Generator):
exp.First: _first_last_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
exp.ParseJSON: rename_func("JSON_PARSE"),
Expand Down
27 changes: 24 additions & 3 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
date_trunc_to_time,
datestrtodate_sql,
format_time_lambda,
if_sql,
inline_array_sql,
max_or_greatest,
min_or_least,
Expand Down Expand Up @@ -242,6 +243,12 @@ class Parser(parser.Parser):
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"ARRAY_GENERATE_RANGE": lambda args: exp.GenerateSeries(
# ARRAY_GENERATE_RANGE has an exclusive end; we normalize it to be inclusive
start=seq_get(args, 0),
end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)),
step=seq_get(args, 2),
),
"ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
"BITXOR": binary_from_function(exp.BitwiseXor),
"BIT_XOR": binary_from_function(exp.BitwiseXor),
Expand Down Expand Up @@ -405,8 +412,11 @@ class Generator(generator.Generator):
exp.DataType: _datatype_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.Extract: rename_func("DATE_PART"),
exp.GenerateSeries: lambda self, e: self.func(
"ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step")
),
exp.GroupConcat: rename_func("LISTAGG"),
exp.If: rename_func("IFF"),
exp.If: if_sql(name="IFF", false_value="NULL"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
Expand Down Expand Up @@ -464,12 +474,23 @@ class Generator(generator.Generator):
}

def unnest_sql(self, expression: exp.Unnest) -> str:
selects = ["value"]
unnest_alias = expression.args.get("alias")

offset = expression.args.get("offset")
if offset:
if unnest_alias:
expression = expression.copy()
unnest_alias.append("columns", offset.pop())

selects.append("index")

subquery = exp.Subquery(
this=exp.select("value").from_(
this=exp.select(*selects).from_(
f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))"
),
)
alias = self.sql(expression, "alias")
alias = self.sql(unnest_alias)
alias = f" AS {alias}" if alias else ""
return f"{self.sql(subquery)}{alias}"

Expand Down
2 changes: 1 addition & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ def filter_sql(self, expression: exp.Filter) -> str:
agg = expression.this.copy()
agg_arg = agg.this
cond = expression.expression.this
agg_arg.replace(exp.If(this=cond.copy(), true=agg_arg.copy(), false=exp.null()))
agg_arg.replace(exp.If(this=cond.copy(), true=agg_arg.copy()))
return self.sql(agg)

def hint_sql(self, expression: exp.Hint) -> str:
Expand Down
1 change: 1 addition & 0 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def new_name(names: t.Set[str], name: str) -> str:
table=[series_alias],
)

# we use list here because expression.selects is mutated inside the loop
for select in list(expression.selects):
to_replace = select
pos_alias = ""
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_duckdb(self):
"SELECT UNNEST([1, 2, 3])",
write={
"duckdb": "SELECT UNNEST([1, 2, 3])",
"snowflake": "SELECT IFF(pos = pos_2, col) AS col FROM (SELECT value FROM TABLE(FLATTEN(INPUT => GENERATE_SERIES(0, GREATEST(ARRAY_SIZE([1, 2, 3])) - 1)))) AS _u(pos) CROSS JOIN (SELECT value FROM TABLE(FLATTEN(INPUT => [1, 2, 3]))) AS _u_2(col) WHERE pos = pos_2 OR (pos > (ARRAY_SIZE([1, 2, 3]) - 1) AND pos_2 = (ARRAY_SIZE([1, 2, 3]) - 1))",
"snowflake": "SELECT IFF(pos = pos_2, col, NULL) AS col FROM (SELECT value FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (GREATEST(ARRAY_SIZE([1, 2, 3])) - 1) + 1)))) AS _u(pos) CROSS JOIN (SELECT value, index FROM TABLE(FLATTEN(INPUT => [1, 2, 3]))) AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > (ARRAY_SIZE([1, 2, 3]) - 1) AND pos_2 = (ARRAY_SIZE([1, 2, 3]) - 1))",
},
)
self.validate_all(
Expand Down
17 changes: 17 additions & 0 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,23 @@ def test_snowflake(self):
self.validate_all("CAST(x AS CHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"})
self.validate_all("CAST(x AS CHARACTER VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"})
self.validate_all("CAST(x AS NCHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"})
self.validate_all(
"ARRAY_GENERATE_RANGE(0, 3)",
write={
"bigquery": "GENERATE_ARRAY(0, 3 - 1)",
"postgres": "GENERATE_SERIES(0, 3 - 1)",
"presto": "SEQUENCE(0, 3 - 1)",
"snowflake": "ARRAY_GENERATE_RANGE(0, (3 - 1) + 1)",
},
)
self.validate_all(
"ARRAY_GENERATE_RANGE(0, 3 + 1)",
read={
"bigquery": "GENERATE_ARRAY(0, 3)",
"postgres": "GENERATE_SERIES(0, 3)",
"presto": "SEQUENCE(0, 3)",
},
)
self.validate_all(
"SELECT DATE_PART('year', TIMESTAMP '2020-01-01')",
write={
Expand Down

0 comments on commit 829415c

Please sign in to comment.