Skip to content

Commit

Permalink
Feat: improve support for NVL2 function (#2042)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Aug 11, 2023
1 parent cf038ee commit c817e19
Show file tree
Hide file tree
Showing 17 changed files with 85 additions and 7 deletions.
1 change: 1 addition & 0 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
RENAME_TABLE_WITH_DB = False
ESCAPE_LINE_BREAK = True
NVL2_SUPPORTED = False

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def _parse_on_property(self) -> t.Optional[exp.Expression]:
class Generator(generator.Generator):
QUERY_HINTS = False
STRUCT_DELIMITER = ("(", ")")
NVL2_SUPPORTED = False

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
Expand Down
1 change: 0 additions & 1 deletion sqlglot/dialects/doris.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class Generator(MySQL.Generator):
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.Coalesce: rename_func("NVL"),
exp.CurrentTimestamp: lambda *_: "NOW()",
exp.DateTrunc: lambda self, e: self.func(
"DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
STRUCT_DELIMITER = ("(", ")")
RENAME_TABLE_WITH_DB = False
NVL2_SUPPORTED = False

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ class Generator(generator.Generator):
QUERY_HINTS = False
INDEX_ON = "ON TABLE"
EXTRACT_ALLOWS_QUOTES = False
NVL2_SUPPORTED = False

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ class Generator(generator.Generator):
DUPLICATE_KEY_UPDATE_WITH_SET = False
QUERY_HINT_SEP = " "
VALUES_AS_TABLE = False
NVL2_SUPPORTED = False

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False
PARAMETER_TOKEN = "$"

TYPE_MAPPING = {
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ class Generator(generator.Generator):
QUERY_HINTS = False
IS_BOOL_ALLOWED = False
TZ_TO_WITH_TIME_ZONE = True
NVL2_SUPPORTED = False
STRUCT_DELIMITER = ("(", ")")

PROPERTIES_LOCATION = {
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class Generator(Postgres.Generator):
QUERY_HINTS = False
VALUES_AS_TABLE = False
TZ_TO_WITH_TIME_ZONE = True
NVL2_SUPPORTED = True

TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[st

class Generator(Hive.Generator):
QUERY_HINTS = True
NVL2_SUPPORTED = True

TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ class Generator(generator.Generator):
LIMIT_IS_TOP = True
QUERY_HINTS = False
RETURNING_END = False
NVL2_SUPPORTED = False

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
Expand Down
18 changes: 18 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ class Generator:
# Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax
TZ_TO_WITH_TIME_ZONE = False

# Whether or not the NVL2 function is supported
NVL2_SUPPORTED = True

# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")

Expand Down Expand Up @@ -2631,6 +2634,21 @@ def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> st
options = f" {options}" if options else ""
return f"{kind}{this}{type_}{schema}{options}"

def nvl2_sql(self, expression: exp.Nvl2) -> str:
if self.NVL2_SUPPORTED:
return self.function_fallback_sql(expression)

case = exp.Case().when(
expression.this.is_(exp.null()).not_(copy=False),
expression.args["true"].copy(),
copy=False,
)
else_cond = expression.args.get("false")
if else_cond:
case.else_(else_cond.copy(), copy=False)

return self.sql(case)


def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
Expand Down
54 changes: 54 additions & 0 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,60 @@ def test_if_null(self):
},
)

def test_nvl2(self):
self.validate_all(
"SELECT NVL2(a, b, c)",
write={
"": "SELECT NVL2(a, b, c)",
"bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
"clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
"databricks": "SELECT NVL2(a, b, c)",
"doris": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
"drill": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
"duckdb": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
"hive": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
"mysql": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
"oracle": "SELECT NVL2(a, b, c)",
"postgres": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
"presto": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
"redshift": "SELECT NVL2(a, b, c)",
"snowflake": "SELECT NVL2(a, b, c)",
"spark": "SELECT NVL2(a, b, c)",
"spark2": "SELECT NVL2(a, b, c)",
"sqlite": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
"starrocks": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
"teradata": "SELECT NVL2(a, b, c)",
"trino": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
"tsql": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
},
)
self.validate_all(
"SELECT NVL2(a, b)",
write={
"": "SELECT NVL2(a, b)",
"bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b END",
"clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b END",
"databricks": "SELECT NVL2(a, b)",
"doris": "SELECT CASE WHEN NOT a IS NULL THEN b END",
"drill": "SELECT CASE WHEN NOT a IS NULL THEN b END",
"duckdb": "SELECT CASE WHEN NOT a IS NULL THEN b END",
"hive": "SELECT CASE WHEN NOT a IS NULL THEN b END",
"mysql": "SELECT CASE WHEN NOT a IS NULL THEN b END",
"oracle": "SELECT NVL2(a, b)",
"postgres": "SELECT CASE WHEN NOT a IS NULL THEN b END",
"presto": "SELECT CASE WHEN NOT a IS NULL THEN b END",
"redshift": "SELECT NVL2(a, b)",
"snowflake": "SELECT NVL2(a, b)",
"spark": "SELECT NVL2(a, b)",
"spark2": "SELECT NVL2(a, b)",
"sqlite": "SELECT CASE WHEN NOT a IS NULL THEN b END",
"starrocks": "SELECT CASE WHEN NOT a IS NULL THEN b END",
"teradata": "SELECT NVL2(a, b)",
"trino": "SELECT CASE WHEN NOT a IS NULL THEN b END",
"tsql": "SELECT CASE WHEN NOT a IS NULL THEN b END",
},
)

def test_time(self):
self.validate_all(
"STR_TO_TIME(x, '%Y-%m-%dT%H:%M:%S')",
Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_doris.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ class TestDoris(Validator):
dialect = "doris"

def test_identity(self):
self.validate_identity("COALECSE(a, b, c, d)")
self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
self.validate_identity("SELECT APPROX_COUNT_DISTINCT(a) FROM x")

Expand Down
6 changes: 0 additions & 6 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,12 +382,6 @@ def test_snowflake(self):
"snowflake": "SELECT ARRAY_UNION_AGG(a)",
},
)
self.validate_all(
"SELECT NVL2(a, b, c)",
write={
"snowflake": "SELECT NVL2(a, b, c)",
},
)
self.validate_all(
"SELECT $$a$$",
write={
Expand Down

0 comments on commit c817e19

Please sign in to comment.