Skip to content

Commit

Permalink
feat: add redshift concat_ws support (#2194)
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq authored Sep 12, 2023
1 parent 8c51275 commit 11d95ff
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 9 deletions.
16 changes: 12 additions & 4 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import typing as t
from enum import Enum
from functools import reduce

from sqlglot import exp
from sqlglot._typing import E
Expand Down Expand Up @@ -656,11 +657,18 @@ def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:

def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
expression = expression.copy()
this, *rest_args = expression.expressions
for arg in rest_args:
this = exp.DPipe(this=this, expression=arg)
return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))

return self.sql(this)

def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
expression = expression.copy()
delim, *rest_args = expression.expressions
return self.sql(
reduce(
lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
rest_args,
)
)


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/dialects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
concat_to_dpipe_sql,
concat_ws_to_dpipe_sql,
rename_func,
ts_or_ds_to_date_sql,
)
Expand Down Expand Up @@ -123,6 +124,7 @@ class Generator(Postgres.Generator):
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS,
exp.Concat: concat_to_dpipe_sql,
exp.ConcatWs: concat_ws_to_dpipe_sql,
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
Expand Down
24 changes: 19 additions & 5 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,7 @@ class Parser(metaclass=_Parser):
"ANY_VALUE": lambda self: self._parse_any_value(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONCAT": lambda self: self._parse_concat(),
"CONCAT_WS": lambda self: self._parse_concat_ws(),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"DECODE": lambda self: self._parse_decode(),
"EXTRACT": lambda self: self._parse_extract(),
Expand Down Expand Up @@ -4073,11 +4074,7 @@ def _parse_cast(self, strict: bool) -> exp.Expression:
def _parse_concat(self) -> t.Optional[exp.Expression]:
args = self._parse_csv(self._parse_conjunction)
if self.CONCAT_NULL_OUTPUTS_STRING:
args = [
exp.func("COALESCE", exp.cast(arg, "text"), exp.Literal.string(""))
for arg in args
if arg
]
args = self._ensure_string_if_null(args)

# Some dialects (e.g. Trino) don't allow a single-argument CONCAT call, so when
# we find such a call we replace it with its argument.
Expand All @@ -4088,6 +4085,16 @@ def _parse_concat(self) -> t.Optional[exp.Expression]:
exp.Concat if self.STRICT_STRING_CONCAT else exp.SafeConcat, expressions=args
)

def _parse_concat_ws(self) -> t.Optional[exp.Expression]:
args = self._parse_csv(self._parse_conjunction)
if len(args) < 2:
return self.expression(exp.ConcatWs, expressions=args)
delim, *values = args
if self.CONCAT_NULL_OUTPUTS_STRING:
values = self._ensure_string_if_null(values)

return self.expression(exp.ConcatWs, expressions=[delim] + values)

def _parse_string_agg(self) -> exp.Expression:
if self._match(TokenType.DISTINCT):
args: t.List[t.Optional[exp.Expression]] = [
Expand Down Expand Up @@ -5145,3 +5152,10 @@ def _replace_lambda(
else:
column.replace(dot_or_id)
return node

def _ensure_string_if_null(self, values: t.List[exp.Expression]) -> t.List[exp.Expression]:
return [
exp.func("COALESCE", exp.cast(value, "text"), exp.Literal.string(""))
for value in values
if value
]
14 changes: 14 additions & 0 deletions tests/dialects/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,17 @@ def test_no_schema_binding(self):
"redshift": "CREATE OR REPLACE VIEW v1 AS SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING",
},
)

def test_concat(self):
self.validate_all(
"SELECT CONCAT('abc', 'def')",
write={
"redshift": "SELECT COALESCE(CAST('abc' AS VARCHAR(MAX)), '') || COALESCE(CAST('def' AS VARCHAR(MAX)), '')",
},
)
self.validate_all(
"SELECT CONCAT_WS('DELIM', 'abc', 'def', 'ghi')",
write={
"redshift": "SELECT COALESCE(CAST('abc' AS VARCHAR(MAX)), '') || 'DELIM' || COALESCE(CAST('def' AS VARCHAR(MAX)), '') || 'DELIM' || COALESCE(CAST('ghi' AS VARCHAR(MAX)), '')",
},
)
15 changes: 15 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,3 +719,18 @@ def test_parse_intervals(self):

self.assertEqual(ast.find(exp.Interval).this.sql(), "'71'")
self.assertEqual(ast.find(exp.Interval).unit.assert_is(exp.Var).sql(), "days")

def test_parse_concat_ws(self):
ast = parse_one("CONCAT_WS(' ', 'John', 'Doe')")

self.assertEqual(ast.sql(), "CONCAT_WS(' ', 'John', 'Doe')")
self.assertEqual(ast.expressions[0].sql(), "' '")
self.assertEqual(ast.expressions[1].sql(), "'John'")
self.assertEqual(ast.expressions[2].sql(), "'Doe'")

# Ensure we can parse without argument when error level is ignore
ast = parse(
"CONCAT_WS()",
error_level=ErrorLevel.IGNORE,
)
self.assertEqual(ast[0].sql(), "CONCAT_WS()")

0 comments on commit 11d95ff

Please sign in to comment.