Skip to content

Commit

Permalink
Feat: improve support for percentiles in duckdb, postgres (#2219)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Sep 14, 2023
1 parent fd1ed25 commit 0378325
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 7 deletions.
7 changes: 7 additions & 0 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ class Parser(parser.Parser):
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_VALUE": exp.Array.from_arg_list,
"MEDIAN": lambda args: exp.PercentileCont(
this=seq_get(args, 0), expression=exp.Literal.number(0.5)
),
"QUANTILE_CONT": exp.PercentileCont.from_arg_list,
"QUANTILE_DISC": exp.PercentileDisc.from_arg_list,
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
),
Expand Down Expand Up @@ -266,6 +271,8 @@ class Generator(generator.Generator):
exp.cast(e.expression, "timestamp", copy=True),
exp.cast(e.this, "timestamp", copy=True),
),
exp.PercentileCont: rename_func("QUANTILE_CONT"),
exp.PercentileDisc: rename_func("QUANTILE_DISC"),
exp.Properties: no_properties_sql,
exp.RegexpExtract: regexp_extract_sql,
exp.RegexpReplace: regexp_replace_sql,
Expand Down
12 changes: 9 additions & 3 deletions sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,9 @@ class Generator(generator.Generator):
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.Explode: rename_func("UNNEST"),
Expand All @@ -401,10 +404,13 @@ class Generator(generator.Generator):
exp.Max: max_or_greatest,
exp.MapFromEntries: no_map_from_entries_sql,
exp.Min: min_or_least,
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.Merge: transforms.preprocess([_remove_target_from_merge]),
exp.PercentileCont: transforms.preprocess(
[transforms.add_within_group_for_percentiles]
),
exp.PercentileDisc: transforms.preprocess(
[transforms.add_within_group_for_percentiles]
),
exp.Pivot: no_pivot_sql,
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
Expand Down
28 changes: 24 additions & 4 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,27 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
return expression


PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
if (
isinstance(expression, PERCENTILES)
and not isinstance(expression.parent, exp.WithinGroup)
and expression.expression
):
column = expression.this.pop()
expression.set("this", expression.expression.pop())
order = exp.Order(expressions=[exp.Ordered(this=column)])
expression = exp.WithinGroup(this=expression, expression=order)

return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
if (
isinstance(expression, exp.WithinGroup)
and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
and isinstance(expression.this, PERCENTILES)
and isinstance(expression.expression, exp.Order)
):
quantile = expression.this.this
Expand Down Expand Up @@ -294,10 +311,13 @@ def _to_sql(self, expression: exp.Expression) -> str:

transforms_handler = self.TRANSFORMS.get(type(expression))
if transforms_handler:
# Ensures we don't enter an infinite loop. This can happen when the original expression
# has the same type as the final expression and there's no _sql method available for it,
# because then it'd re-enter _to_sql.
if expression_type is type(expression):
if isinstance(expression, exp.Func):
return self.function_fallback_sql(expression)

# Ensures we don't enter an infinite loop. This can happen when the original expression
# has the same type as the final expression and there's no _sql method available for it,
# because then it'd re-enter _to_sql.
raise ValueError(
f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
)
Expand Down
21 changes: 21 additions & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,27 @@ def test_duckdb(self):
"SELECT CAST('2020-05-06' AS DATE) + INTERVAL 5 DAY",
read={"bigquery": "SELECT DATE_ADD(CAST('2020-05-06' AS DATE), INTERVAL 5 DAY)"},
)
self.validate_all(
"SELECT QUANTILE_CONT(x, q) FROM t",
write={
"duckdb": "SELECT QUANTILE_CONT(x, q) FROM t",
"postgres": "SELECT PERCENTILE_CONT(q) WITHIN GROUP (ORDER BY x) FROM t",
},
)
self.validate_all(
"SELECT QUANTILE_DISC(x, q) FROM t",
write={
"duckdb": "SELECT QUANTILE_DISC(x, q) FROM t",
"postgres": "SELECT PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x) FROM t",
},
)
self.validate_all(
"SELECT MEDIAN(x) FROM t",
write={
"duckdb": "SELECT QUANTILE_CONT(x, 0.5) FROM t",
"postgres": "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t",
},
)

with self.assertRaises(UnsupportedError):
transpile(
Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def test_postgres(self):
"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY amount)",
write={
"databricks": "SELECT PERCENTILE_APPROX(amount, 0.5)",
"postgres": "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY amount)",
"presto": "SELECT APPROX_PERCENTILE(amount, 0.5)",
"spark": "SELECT PERCENTILE_APPROX(amount, 0.5)",
"trino": "SELECT APPROX_PERCENTILE(amount, 0.5)",
Expand Down

0 comments on commit 0378325

Please sign in to comment.