Skip to content

Commit

Permalink
Fix: presto sequence to unnest closes #1600
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed May 14, 2023
1 parent 6875d07 commit 2f7473b
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 28 deletions.
67 changes: 39 additions & 28 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,33 +127,6 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
)


def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step")

target_type = None

if isinstance(start, exp.Cast):
target_type = start.to
elif isinstance(end, exp.Cast):
target_type = end.to

if target_type and target_type.this == exp.DataType.Type.TIMESTAMP:
to = target_type.copy()

if target_type is start.to:
end = exp.Cast(this=end, to=to)
else:
start = exp.Cast(this=start, to=to)

sql = self.func("SEQUENCE", start, end, step)
if isinstance(expression.parent, exp.Table):
sql = f"UNNEST({sql})"

return sql


def _ensure_utf8(charset: exp.Literal) -> None:
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
Expand Down Expand Up @@ -191,6 +164,22 @@ def _from_unixtime(args: t.Sequence) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)


def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Table):
if isinstance(expression.this, exp.GenerateSeries):
unnest = exp.Unnest(expressions=[expression.this])

if expression.alias:
return exp.alias_(
unnest,
alias="_u",
table=[expression.alias],
copy=False,
)
return unnest
return expression


class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
Expand Down Expand Up @@ -294,7 +283,6 @@ class Generator(generator.Generator):
exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
exp.GenerateSeries: _sequence_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
Expand Down Expand Up @@ -323,6 +311,7 @@ class Generator(generator.Generator):
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
Expand Down Expand Up @@ -352,3 +341,25 @@ def transaction_sql(self, expression: exp.Transaction) -> str:
modes = expression.args.get("modes")
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"

def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step")

if isinstance(start, exp.Cast):
target_type = start.to
elif isinstance(end, exp.Cast):
target_type = end.to
else:
target_type = None

if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP):
to = target_type.copy()

if target_type is start.to:
end = exp.Cast(this=end, to=to)
else:
start = exp.Cast(this=start, to=to)

return self.func("SEQUENCE", start, end, step)
9 changes: 9 additions & 0 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,15 @@ def test_postgres(self):
"tsql": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)",
},
)
self.validate_all(
"SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4) AS s",
write={
"postgres": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4) AS s",
"presto": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4)) AS _u(s)",
"trino": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4)) AS _u(s)",
"tsql": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4) AS s",
},
)
self.validate_all(
"END WORK AND NO CHAIN",
write={"postgres": "COMMIT AND NO CHAIN"},
Expand Down

0 comments on commit 2f7473b

Please sign in to comment.