Skip to content

Commit

Permalink
Fix(parser, duckdb): decode/encode in duckdb don't take charset (#1993)
Browse files Browse the repository at this point in the history
* Fix(parser, duckdb): decode/encode in duckdb don't take charset

* refactor from code review

* prefer Generator.unsupported
  • Loading branch information
charsmith authored Aug 4, 2023
1 parent 2abc84f commit c9dd971
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 18 deletions.
14 changes: 14 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,20 @@ def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(self: Generator, expression: exp.Expression, name: str) -> str:
if "charset" in expression.args:
charset = expression.args["charset"]

if charset.name.lower() != "utf-8":
self.unsupported(f"Unsupported charset {charset}")

expression = expression.copy()
del expression.args["charset"]

return rename_func(name)(self, expression)


def min_or_least(self: Generator, expression: exp.Min) -> str:
name = "LEAST" if expression.expressions else "MIN"
return rename_func(name)(self, expression)
Expand Down
18 changes: 18 additions & 0 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
binary_from_function,
date_trunc_to_time,
datestrtodate_sql,
encode_decode_sql,
format_time_lambda,
no_comment_column_constraint_sql,
no_properties_sql,
Expand All @@ -28,6 +29,9 @@
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType

if t.TYPE_CHECKING:
from sqlglot._typing import E


def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
Expand Down Expand Up @@ -167,6 +171,12 @@ class Parser(parser.Parser):
"XOR": binary_from_function(exp.BitwiseXor),
}

FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
"ENCODE": lambda self: self._parse_encode_decode(exp.Encode),
"DECODE": lambda self: self._parse_encode_decode(exp.Decode),
}

TYPE_TOKENS = {
*parser.Parser.TYPE_TOKENS,
TokenType.UBIGINT,
Expand All @@ -175,6 +185,12 @@ class Parser(parser.Parser):
TokenType.UTINYINT,
}

def _parse_encode_decode(self, expression: t.Type[E]) -> E:
args = self._parse_csv(self._parse_conjunction)
return self.expression(
expression, this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
)

def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
if len(aggregations) == 1:
return super()._pivot_column_names(aggregations)
Expand Down Expand Up @@ -215,7 +231,9 @@ class Generator(generator.Generator):
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE"),
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE"),
exp.Explode: rename_func("UNNEST"),
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.JSONExtract: arrow_json_extract_sql,
Expand Down
21 changes: 3 additions & 18 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Dialect,
binary_from_function,
date_trunc_to_time,
encode_decode_sql,
format_time_lambda,
if_sql,
left_to_substring_sql,
Expand All @@ -21,7 +22,6 @@
timestrtotime_sql,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.errors import UnsupportedError
from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType

Expand Down Expand Up @@ -59,16 +59,6 @@ def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"


def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str:
_ensure_utf8(expression.args["charset"])
return self.func("FROM_UTF8", expression.this, expression.args.get("replace"))


def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str:
_ensure_utf8(expression.args["charset"])
return f"TO_UTF8({self.sql(expression, 'this')})"


def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
Expand Down Expand Up @@ -123,11 +113,6 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
)


def _ensure_utf8(charset: exp.Literal) -> None:
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")


def _approx_percentile(args: t.List) -> exp.Expression:
if len(args) == 4:
return exp.ApproxQuantile(
Expand Down Expand Up @@ -288,9 +273,9 @@ class Generator(generator.Generator):
),
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.Decode: _decode_sql,
exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"),
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
exp.Encode: _encode_sql,
exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
Expand Down
26 changes: 26 additions & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,32 @@ def test_bool_or(self):
write={"duckdb": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"},
)

def test_encode_decode(self):
self.validate_all(
"ENCODE(x)",
read={
"spark": "ENCODE(x, 'utf-8')",
"presto": "TO_UTF8(x)",
},
write={
"duckdb": "ENCODE(x)",
"spark": "ENCODE(x, 'utf-8')",
"presto": "TO_UTF8(x)",
},
)
self.validate_all(
"DECODE(x)",
read={
"spark": "DECODE(x, 'utf-8')",
"presto": "FROM_UTF8(x)",
},
write={
"duckdb": "DECODE(x)",
"spark": "DECODE(x, 'utf-8')",
"presto": "FROM_UTF8(x)",
},
)

def test_rename_table(self):
self.validate_all(
"ALTER TABLE db.t1 RENAME TO db.t2",
Expand Down
4 changes: 4 additions & 0 deletions tests/dialects/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,12 +744,14 @@ def test_encode_decode(self):
"TO_UTF8(x)",
write={
"spark": "ENCODE(x, 'utf-8')",
"duckdb": "ENCODE(x)",
},
)
self.validate_all(
"FROM_UTF8(x)",
write={
"spark": "DECODE(x, 'utf-8')",
"duckdb": "DECODE(x)",
},
)
self.validate_all(
Expand All @@ -774,12 +776,14 @@ def test_encode_decode(self):
"ENCODE(x, 'invalid')",
write={
"presto": UnsupportedError,
"duckdb": UnsupportedError,
},
)
self.validate_all(
"DECODE(x, 'invalid')",
write={
"presto": UnsupportedError,
"duckdb": UnsupportedError,
},
)

Expand Down

0 comments on commit c9dd971

Please sign in to comment.