diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index fc342621ae..821266946b 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -60,6 +60,7 @@ class Generator(Generator):
 from sqlglot.dialects.clickhouse import ClickHouse
 from sqlglot.dialects.databricks import Databricks
 from sqlglot.dialects.dialect import Dialect, Dialects
+from sqlglot.dialects.doris import Doris
 from sqlglot.dialects.drill import Drill
 from sqlglot.dialects.duckdb import DuckDB
 from sqlglot.dialects.hive import Hive
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 1d0584c679..e73cfc8bd2 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -39,6 +39,7 @@ class Dialects(str, Enum):
     TERADATA = "teradata"
     TRINO = "trino"
     TSQL = "tsql"
+    DORIS = "doris"
 
 
 class _Dialect(type):
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
new file mode 100644
index 0000000000..50e18ad1f8
--- /dev/null
+++ b/sqlglot/dialects/doris.py
@@ -0,0 +1,118 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp, generator
+from sqlglot.dialects.dialect import (
+    approx_count_distinct_sql,
+    arrow_json_extract_sql,
+    rename_func,
+)
+from sqlglot.dialects.mysql import MySQL
+from sqlglot.helper import seq_get
+
+
+def _to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str:
+    """Render a TsOrDsToDate expression as a one-argument TO_DATE call."""
+    # Only the operand is emitted; the expression's format is not part of
+    # the generated SQL, so it is not computed here.
+    this = self.sql(expression, "this")
+    return f"TO_DATE({this})"
+
+
+def _time_format(
+    self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
+) -> t.Optional[str]:
+    """Return the expression's time format, or None if it is Doris' default."""
+    time_format = self.format_time(expression)
+    return None if time_format == Doris.TIME_FORMAT else time_format
+
+
+def _date_trunc_sql(
+    self: generator.Generator, expression: exp.DateTrunc | exp.TimestampTrunc
+) -> str:
+    """Render a truncation as DATE_TRUNC(value, 'unit')."""
+    # NOTE(review): Doris.Parser reads DATE_TRUNC with the unit first, while
+    # this writes it last -- confirm which argument order Doris expects.
+    unit = f"'{expression.text('unit')}'"
+    return self.func("DATE_TRUNC", expression.this, unit)
+
+
+class Doris(MySQL):
+    """Doris dialect, defined as a specialization of the MySQL dialect."""
+
+    DATE_FORMAT = "'yyyy-MM-dd'"
+    DATEINT_FORMAT = "'yyyyMMdd'"
+    TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
+
+    # NOTE(review): unlike the Parser/Generator tables below, this literal does
+    # not start from the MySQL parent's mapping, and several targets such as
+    # "%%-M" look unusual -- confirm each entry against Doris' specifiers.
+    TIME_MAPPING = {
+        "%M": "%B",
+        "%m": "%%-M",
+        "%c": "%-m",
+        "%e": "%-d",
+        "%h": "%I",
+        "%S": "%S",
+        "%u": "%W",
+        "%k": "%-H",
+        "%l": "%-I",
+        "%W": "%a",
+        "%Y": "%Y",
+        "%d": "%%-d",
+        "%H": "%%-H",
+        "%s": "%%-S",
+        "%D": "%%-j",
+        "%a": "%%p",
+        "%y": "%%Y",
+        "%": "%%",
+    }
+
+    class Parser(MySQL.Parser):
+        FUNCTIONS = {
+            **MySQL.Parser.FUNCTIONS,
+            # Read as DATE_TRUNC(unit, value): args[0] is the unit.
+            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
+                this=seq_get(args, 1), unit=seq_get(args, 0)
+            ),
+            "REGEXP": exp.RegexpLike.from_arg_list,
+        }
+
+    class Generator(MySQL.Generator):
+        CAST_MAPPING = {}
+
+        TYPE_MAPPING = {
+            **MySQL.Generator.TYPE_MAPPING,
+            exp.DataType.Type.TEXT: "STRING",
+            exp.DataType.Type.TIMESTAMP: "DATETIME",
+            exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
+        }
+
+        TRANSFORMS = {
+            **MySQL.Generator.TRANSFORMS,
+            exp.ApproxDistinct: approx_count_distinct_sql,
+            exp.ArrayAgg: rename_func("COLLECT_LIST"),
+            exp.Coalesce: rename_func("NVL"),
+            exp.CurrentTimestamp: lambda *_: "NOW()",
+            exp.DateTrunc: _date_trunc_sql,
+            exp.JSONExtractScalar: arrow_json_extract_sql,
+            exp.JSONExtract: arrow_json_extract_sql,
+            exp.RegexpLike: rename_func("REGEXP"),
+            exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
+            exp.SetAgg: rename_func("COLLECT_SET"),
+            exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.Split: rename_func("SPLIT_BY_STRING"),
+            exp.TimeStrToDate: rename_func("TO_DATE"),
+            exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",  # Only for day level
+            exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
+            exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
+            exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
+            exp.TimestampTrunc: _date_trunc_sql,
+            exp.UnixToStr: lambda self, e: self.func(
+                "FROM_UNIXTIME", e.this, _time_format(self, e)
+            ),
+            exp.UnixToTime: rename_func("FROM_UNIXTIME"),
+            exp.Map: rename_func("ARRAY_MAP"),
+        }
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index aaaffab785..9873038f8f 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -90,6 +90,7 @@ def test_cast(self):
                 "snowflake": "CAST(a AS TEXT)",
                 "spark": "CAST(a AS STRING)",
                 "starrocks": "CAST(a AS STRING)",
+                "doris": "CAST(a AS STRING)",
             },
         )
         self.validate_all(
@@ -169,6 +170,7 @@ def test_cast(self):
                 "snowflake": "CAST(a AS TEXT)",
                 "spark": "CAST(a AS STRING)",
                 "starrocks": "CAST(a AS STRING)",
+                "doris": "CAST(a AS STRING)",
             },
         )
         self.validate_all(
@@ -186,6 +188,7 @@ def test_cast(self):
                 "snowflake": "CAST(a AS VARCHAR)",
                 "spark": "CAST(a AS STRING)",
                 "starrocks": "CAST(a AS VARCHAR)",
+                "doris": "CAST(a AS VARCHAR)",
             },
         )
         self.validate_all(
@@ -203,6 +206,7 @@ def test_cast(self):
                 "snowflake": "CAST(a AS VARCHAR(3))",
                 "spark": "CAST(a AS VARCHAR(3))",
                 "starrocks": "CAST(a AS VARCHAR(3))",
+                "doris": "CAST(a AS VARCHAR(3))",
             },
         )
         self.validate_all(
@@ -221,6 +225,7 @@ def test_cast(self):
                 "spark": "CAST(a AS SMALLINT)",
                 "sqlite": "CAST(a AS INTEGER)",
                 "starrocks": "CAST(a AS SMALLINT)",
+                "doris": "CAST(a AS SMALLINT)",
             },
         )
         self.validate_all(
@@ -234,6 +239,7 @@ def test_cast(self):
                 "drill": "CAST(a AS DOUBLE)",
                 "postgres": "CAST(a AS DOUBLE PRECISION)",
                 "redshift": "CAST(a AS DOUBLE PRECISION)",
+                "doris": "CAST(a AS DOUBLE)",
             },
         )
 
@@ -267,6 +273,7 @@ def test_cast(self):
             write={
                 "starrocks": "CAST(a AS DATETIME)",
                 "redshift": "CAST(a AS TIMESTAMP)",
+                "doris": "CAST(a AS DATETIME)",
             },
         )
         self.validate_all(
@@ -274,6 +281,7 @@ def test_cast(self):
             write={
                 "starrocks": "CAST(a AS DATETIME)",
                 "redshift": "CAST(a AS TIMESTAMPTZ)",
+                "doris": "CAST(a AS DATETIME)",
             },
         )
         self.validate_all("CAST(a AS TINYINT)", write={"oracle": "CAST(a AS NUMBER)"})
@@ -408,6 +416,7 @@ def test_time(self):
                 "hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-mm-dd')",
                 "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%i-%d'))",
                 "starrocks": "UNIX_TIMESTAMP('2020-01-01', '%Y-%i-%d')",
+                "doris": "UNIX_TIMESTAMP('2020-01-01', '%Y-%M-%d')",
             },
         )
         self.validate_all(
@@ -418,6 +427,7 @@ def test_time(self):
                 "hive": "TO_DATE('2020-01-01')",
                 "presto": "CAST('2020-01-01' AS TIMESTAMP)",
                 "starrocks": "TO_DATE('2020-01-01')",
+                "doris": "TO_DATE('2020-01-01')",
             },
         )
         self.validate_all(
@@ -428,6 +438,7 @@ def test_time(self):
                 "hive": "CAST('2020-01-01' AS TIMESTAMP)",
                 "presto": "CAST('2020-01-01' AS TIMESTAMP)",
                 "sqlite": "'2020-01-01'",
+                "doris": "CAST('2020-01-01' AS DATETIME)",
             },
         )
         self.validate_all(
@@ -437,6 +448,7 @@ def test_time(self):
                 "hive": "UNIX_TIMESTAMP('2020-01-01')",
                 "mysql": "UNIX_TIMESTAMP('2020-01-01')",
                 "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %T'))",
+                "doris": "UNIX_TIMESTAMP('2020-01-01')",
             },
         )
         self.validate_all(
@@ -449,6 +461,7 @@ def test_time(self):
                 "postgres": "TO_CHAR(x, 'YYYY-MM-DD')",
                 "presto": "DATE_FORMAT(x, '%Y-%m-%d')",
                 "redshift": "TO_CHAR(x, 'YYYY-MM-DD')",
+                "doris": "DATE_FORMAT(x, '%Y-%m-%d')",
             },
         )
         self.validate_all(
@@ -459,6 +472,7 @@ def test_time(self):
                 "hive": "CAST(x AS STRING)",
                 "presto": "CAST(x AS VARCHAR)",
                 "redshift": "CAST(x AS VARCHAR(MAX))",
+                "doris": "CAST(x AS STRING)",
             },
         )
         self.validate_all(
@@ -468,6 +482,7 @@ def test_time(self):
                 "duckdb": "EPOCH(x)",
                 "hive": "UNIX_TIMESTAMP(x)",
                 "presto": "TO_UNIXTIME(x)",
+                "doris": "UNIX_TIMESTAMP(x)",
             },
         )
         self.validate_all(
@@ -476,6 +491,7 @@ def test_time(self):
                 "duckdb": "SUBSTRING(CAST(x AS TEXT), 1, 10)",
                 "hive": "SUBSTRING(CAST(x AS STRING), 1, 10)",
                 "presto": "SUBSTRING(CAST(x AS VARCHAR), 1, 10)",
+                "doris": "SUBSTRING(CAST(x AS STRING), 1, 10)",
             },
         )
         self.validate_all(
@@ -487,6 +503,7 @@ def test_time(self):
                 "postgres": "CAST(x AS DATE)",
                 "presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)",
                 "snowflake": "CAST(x AS DATE)",
+                "doris": "TO_DATE(x)",
             },
         )
         self.validate_all(
@@ -505,6 +522,7 @@ def test_time(self):
                 "hive": "FROM_UNIXTIME(x, y)",
                 "presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)",
                 "starrocks": "FROM_UNIXTIME(x, y)",
+                "doris": "FROM_UNIXTIME(x, y)",
             },
         )
         self.validate_all(
@@ -516,6 +534,7 @@ def test_time(self):
                 "postgres": "TO_TIMESTAMP(x)",
                 "presto": "FROM_UNIXTIME(x)",
                 "starrocks": "FROM_UNIXTIME(x)",
+                "doris": "FROM_UNIXTIME(x)",
             },
         )
         self.validate_all(
@@ -582,6 +601,7 @@ def test_time(self):
                 "sqlite": "DATE(x, '1 DAY')",
                 "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
                 "tsql": "DATEADD(DAY, 1, x)",
+                "doris": "DATE_ADD(x, INTERVAL 1 DAY)",
             },
         )
         self.validate_all(
@@ -595,6 +615,7 @@ def test_time(self):
                 "presto": "DATE_ADD('day', 1, x)",
                 "spark": "DATE_ADD(x, 1)",
                 "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
+                "doris": "DATE_ADD(x, INTERVAL 1 DAY)",
             },
         )
         self.validate_all(
@@ -612,6 +633,7 @@ def test_time(self):
                 "snowflake": "DATE_TRUNC('day', x)",
                 "starrocks": "DATE_TRUNC('day', x)",
                 "spark": "TRUNC(x, 'day')",
+                "doris": "DATE_TRUNC(x, 'day')",
             },
         )
         self.validate_all(
@@ -624,6 +646,7 @@ def test_time(self):
                 "snowflake": "DATE_TRUNC('day', x)",
                 "starrocks": "DATE_TRUNC('day', x)",
                 "spark": "DATE_TRUNC('day', x)",
+                "doris": "DATE_TRUNC('day', x)",
             },
         )
         self.validate_all(
@@ -684,6 +707,7 @@ def test_time(self):
                 "snowflake": "DATE_TRUNC('year', x)",
                 "starrocks": "DATE_TRUNC('year', x)",
                 "spark": "TRUNC(x, 'year')",
+                "doris": "DATE_TRUNC(x, 'year')",
             },
         )
         self.validate_all(
@@ -698,6 +722,7 @@ def test_time(self):
             write={
                 "bigquery": "TIMESTAMP_TRUNC(x, year)",
                 "spark": "DATE_TRUNC('year', x)",
+                "doris": "DATE_TRUNC(x, 'year')",
             },
         )
         self.validate_all(
@@ -719,6 +744,7 @@ def test_time(self):
                 "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)",
                 "presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%T') AS DATE)",
                 "spark": "TO_DATE(x, 'yyyy-MM-ddTHH:mm:ss')",
+                "doris": "STR_TO_DATE(x, '%Y-%m-%dT%H:%M:%S')",
             },
         )
         self.validate_all(
@@ -730,6 +756,7 @@ def test_time(self):
                 "hive": "CAST(x AS DATE)",
                 "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
                 "spark": "TO_DATE(x)",
+                "doris": "STR_TO_DATE(x, '%Y-%m-%d')",
             },
         )
         self.validate_all(
@@ -784,6 +811,7 @@ def test_time(self):
                 "mysql": "CAST('2022-01-01' AS TIMESTAMP)",
                 "starrocks": "CAST('2022-01-01' AS DATETIME)",
                 "hive": "CAST('2022-01-01' AS TIMESTAMP)",
+                "doris": "CAST('2022-01-01' AS DATETIME)",
             },
         )
         self.validate_all(
@@ -792,6 +820,7 @@ def test_time(self):
                 "mysql": "TIMESTAMP('2022-01-01')",
                 "starrocks": "TIMESTAMP('2022-01-01')",
                 "hive": "TIMESTAMP('2022-01-01')",
+                "doris": "TIMESTAMP('2022-01-01')",
             },
         )
 
@@ -807,6 +836,7 @@ def test_time(self):
                         "mysql",
                         "presto",
                         "starrocks",
+                        "doris",
                     )
                 },
                 write={
@@ -820,6 +850,7 @@ def test_time(self):
                         "hive",
                         "spark",
                         "starrocks",
+                        "doris",
                     )
                 },
             )
@@ -886,6 +917,7 @@ def test_json(self):
                 "postgres": "x->'y'",
                 "presto": "JSON_EXTRACT(x, 'y')",
                 "starrocks": "x -> 'y'",
+                "doris": "x -> 'y'",
             },
             write={
                 "mysql": "JSON_EXTRACT(x, 'y')",
@@ -893,6 +925,7 @@ def test_json(self):
                 "postgres": "x -> 'y'",
                 "presto": "JSON_EXTRACT(x, 'y')",
                 "starrocks": "x -> 'y'",
+                "doris": "x -> 'y'",
             },
         )
         self.validate_all(
@@ -1115,6 +1148,7 @@ def test_operators(self):
                 "sqlite": "LOWER(x) LIKE '%y'",
                 "starrocks": "LOWER(x) LIKE '%y'",
                 "trino": "LOWER(x) LIKE '%y'",
+                "doris": "LOWER(x) LIKE '%y'",
             },
         )
         self.validate_all(
diff --git a/tests/dialects/test_doris.py b/tests/dialects/test_doris.py
new file mode 100644
index 0000000000..63325a6b4d
--- /dev/null
+++ b/tests/dialects/test_doris.py
@@ -0,0 +1,25 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestDoris(Validator):
+    """Validator-based checks for the "doris" dialect."""
+
+    dialect = "doris"
+
+    def test_identity(self):
+        """Run validate_identity over plain SELECT statements."""
+        self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
+        self.validate_identity("SELECT APPROX_COUNT_DISTINCT(a) FROM x")
+
+    def test_time(self):
+        """Run validate_identity over a TIMESTAMP(...) call."""
+        self.validate_identity("TIMESTAMP('2022-01-01')")
+
+    def test_regex(self):
+        """Expect REGEXP_LIKE input to be written as REGEXP for doris."""
+        self.validate_all(
+            "SELECT REGEXP_LIKE(abc, '%foo%')",
+            write={
+                "doris": "SELECT REGEXP(abc, '%foo%')",
+            },
+        )