Skip to content

Commit

Permalink
Fix(duckdb): parse DATEDIFF correctly (#1546)
Browse files Browse the repository at this point in the history
* Fix(duckdb): parse DATEDIFF correctly

* Fixup

* Fixup
  • Loading branch information
georgesittas authored May 4, 2023
1 parent 0578d6d commit b7e08cc
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 24 deletions.
57 changes: 33 additions & 24 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import typing as t

from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
Expand All @@ -23,52 +25,61 @@
from sqlglot.tokens import TokenType


def _ts_or_ds_add(self, expression):
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
    """Render a ``TsOrDsAdd`` node as ``CAST(<this> AS DATE) + <interval>``.

    Args:
        self: The generator used to render sub-expressions.
        expression: The node to render; its ``this`` operand is cast to DATE
            and its ``expression`` operand becomes the interval amount.

    Returns:
        The generated SQL string.
    """
    this = self.sql(expression, "this")
    # The rendered unit may be a quoted literal ('day'); drop the quotes and
    # fall back to DAY when no unit was rendered at all.
    unit = self.sql(expression, "unit").strip("'") or "DAY"
    interval = self.sql(exp.Interval(this=expression.expression, unit=unit))
    return f"CAST({this} AS DATE) + {interval}"


def _date_add(self, expression):
def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
    """Render a ``DateAdd`` node as ``<this> + <interval>``.

    Args:
        self: The generator used to render sub-expressions.
        expression: The node to render; its ``expression`` operand becomes
            the interval amount.

    Returns:
        The generated SQL string.
    """
    this = self.sql(expression, "this")
    # The rendered unit may be a quoted literal ('day'); drop the quotes and
    # fall back to DAY when no unit was rendered at all.
    unit = self.sql(expression, "unit").strip("'") or "DAY"
    interval = self.sql(exp.Interval(this=expression.expression, unit=unit))
    return f"{this} + {interval}"


def _array_sort_sql(self, expression):
def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
    """Render an ``ArraySort`` node as a single-argument ``ARRAY_SORT`` call.

    Args:
        self: The generator used to render sub-expressions and report
            unsupported constructs.
        expression: The node to render. If its ``expression`` operand (the
            comparator) is set, it is reported as unsupported and omitted.

    Returns:
        The generated SQL string.
    """
    if expression.expression:
        # The comparator is dropped from the output, so surface that loss.
        self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
    return f"ARRAY_SORT({self.sql(expression, 'this')})"


def _sort_array_sql(self, expression):
def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str:
    """Render a ``SortArray`` node as ``ARRAY_SORT`` or ``ARRAY_REVERSE_SORT``.

    Args:
        self: The generator used to render sub-expressions.
        expression: The node to render. When its ``asc`` arg equals
            ``exp.false()`` the reverse-sort function is emitted.

    Returns:
        The generated SQL string.
    """
    this = self.sql(expression, "this")
    # Only an explicit FALSE selects the reverse variant; a missing ``asc``
    # falls through to the plain ARRAY_SORT.
    if expression.args.get("asc") == exp.false():
        return f"ARRAY_REVERSE_SORT({this})"
    return f"ARRAY_SORT({this})"


def _sort_array_reverse(args):
def _sort_array_reverse(args: t.Sequence) -> exp.Expression:
    """Build a ``SortArray`` node with ``asc`` set to FALSE.

    Args:
        args: The parsed function arguments; the first one is the array.

    Returns:
        An ``exp.SortArray`` over ``args[0]`` with ``asc=exp.false()``.
    """
    return exp.SortArray(this=seq_get(args, 0), asc=exp.false())


def _struct_sql(self, expression):
def _parse_date_diff(args: t.Sequence) -> exp.Expression:
    """Build a ``DateDiff`` node from positional ``(unit, arg1, arg2)`` arguments.

    The positional order is reversed relative to the node's fields: the first
    argument becomes ``unit``, the second becomes ``expression`` and the third
    becomes ``this``.

    Args:
        args: The parsed function arguments.

    Returns:
        The resulting ``exp.DateDiff`` node.
    """
    return exp.DateDiff(
        this=seq_get(args, 2),
        expression=seq_get(args, 1),
        unit=seq_get(args, 0),
    )


def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
    """Render a ``Struct`` node as a ``{'key': value, ...}`` literal.

    Args:
        self: The generator used to render each field's value.
        expression: The node to render; each entry of ``expressions``
            contributes one ``'key': value`` pair.

    Returns:
        The generated SQL string (``{}`` when there are no fields).
    """
    args = [
        # Use the entry's own name as the key, or its ``this`` child's name
        # when the entry's name is empty.
        f"'{e.name or e.this.name}': {self.sql(e, 'expression')}"
        for e in expression.expressions
    ]
    return f"{{{', '.join(args)}}}"


def _datatype_sql(self, expression):
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
    """Render a ``DataType`` node, using the ``<type>[]`` suffix form for arrays.

    Args:
        self: The generator used to render the type.
        expression: The node to render.

    Returns:
        ``<nested types>[]`` when the type is ``ARRAY``; otherwise whatever
        the generator's ``datatype_sql`` produces.
    """
    if expression.this == exp.DataType.Type.ARRAY:
        return f"{self.expressions(expression, flat=True)}[]"
    return self.datatype_sql(expression)


def _regexp_extract_sql(self, expression):
def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract) -> str:
bad_args = list(filter(expression.args.get, ("position", "occurrence")))
if bad_args:
self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}")

return self.func(
"REGEXP_EXTRACT",
expression.args.get("this"),
Expand Down Expand Up @@ -108,25 +119,27 @@ class Parser(parser.Parser):
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
"ARRAY_REVERSE_SORT": _sort_array_reverse,
"DATEDIFF": _parse_date_diff,
"DATE_DIFF": _parse_date_diff,
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
this=exp.Div(
this=seq_get(args, 0),
expression=exp.Literal.number(1000),
)
),
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_VALUE": exp.Array.from_arg_list,
"REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
"STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"),
"STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
"STR_SPLIT": exp.Split.from_arg_list,
"STRING_SPLIT": exp.Split.from_arg_list,
"STRING_TO_ARRAY": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_TO_ARRAY": exp.Split.from_arg_list,
"STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
"STRUCT_PACK": exp.Struct.from_arg_list,
"STR_SPLIT": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
"UNNEST": exp.Explode.from_arg_list,
}
Expand All @@ -142,10 +155,11 @@ class Parser(parser.Parser):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
LIMIT_FETCH = "LIMIT"
STRUCT_DELIMITER = ("(", ")")

TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
if isinstance(seq_get(e.expressions, 0), exp.Select)
Expand All @@ -158,9 +172,9 @@ class Generator(generator.Generator):
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.DataType: _datatype_sql,
exp.DateAdd: _date_add,
exp.DateAdd: _date_add_sql,
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this
"DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
Expand Down Expand Up @@ -192,7 +206,7 @@ class Generator(generator.Generator):
exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add,
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"),
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("TO_TIMESTAMP"),
Expand All @@ -201,7 +215,7 @@ class Generator(generator.Generator):
}

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.CHAR: "TEXT",
exp.DataType.Type.FLOAT: "REAL",
Expand All @@ -212,17 +226,12 @@ class Generator(generator.Generator):
exp.DataType.Type.VARCHAR: "TEXT",
}

STAR_MAPPING = {
**generator.Generator.STAR_MAPPING,
"except": "EXCLUDE",
}
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}

PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}

LIMIT_FETCH = "LIMIT"

def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE")
14 changes: 14 additions & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,20 @@ def test_duckdb(self):
"SELECT a['x space'] FROM (SELECT {'x space': 1, 'y': 2, 'z': 3} AS a)"
)

self.validate_all(
"""SELECT DATEDIFF('day', t1."A", t1."B") FROM "table" AS t1""",
write={
"duckdb": """SELECT DATE_DIFF('day', t1."A", t1."B") FROM "table" AS t1""",
"trino": """SELECT DATE_DIFF('day', t1."A", t1."B") FROM "table" AS t1""",
},
)
self.validate_all(
"SELECT DATE_DIFF('day', DATE '2020-01-01', DATE '2020-01-05')",
write={
"duckdb": "SELECT DATE_DIFF('day', CAST('2020-01-01' AS DATE), CAST('2020-01-05' AS DATE))",
"trino": "SELECT DATE_DIFF('day', CAST('2020-01-01' AS DATE), CAST('2020-01-05' AS DATE))",
},
)
self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"})
self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'})
self.validate_all(
Expand Down

0 comments on commit b7e08cc

Please sign in to comment.