Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add apache doris dialect #2006

Merged
merged 14 commits into from
Aug 9, 2023
1 change: 1 addition & 0 deletions sqlglot/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class Generator(Generator):
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.dialects.doris import Doris
from sqlglot.dialects.drill import Drill
from sqlglot.dialects.duckdb import DuckDB
from sqlglot.dialects.hive import Hive
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Dialects(str, Enum):
TERADATA = "teradata"
TRINO = "trino"
TSQL = "tsql"
Doris = "doris"


class _Dialect(type):
Expand Down
105 changes: 105 additions & 0 deletions sqlglot/dialects/doris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from __future__ import annotations

import typing as t

from sqlglot import exp, generator
from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_sql,
rename_func,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import seq_get


def _to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str:
this = self.sql(expression, "this")
self.format_time(expression)
return f"TO_DATE({this})"
Comment on lines +15 to +18
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unnecessary, let's get rid of it.



def _time_format(
    self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
) -> t.Optional[str]:
    """Return the expression's time format, or None when it equals Doris's default.

    Returning None lets callers omit the format argument entirely when it would
    be redundant (i.e. it matches ``Doris.TIME_FORMAT``).
    """
    fmt = self.format_time(expression)
    return None if fmt == Doris.TIME_FORMAT else fmt
Comment on lines +21 to +27
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should factor this out into a helper in dialect.py; we do exactly the same thing in Hive.



class Doris(MySQL):
    """Apache Doris dialect.

    Doris builds on the MySQL dialect, overriding only type names, time
    formats, and a set of function renames.
    """

    # Default date/time literal formats (Java-style patterns, as in Hive).
    DATE_FORMAT = "'yyyy-MM-dd'"
    DATEINT_FORMAT = "'yyyyMMdd'"
    TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"

    # NOTE(review): this fully replaces MySQL's TIME_MAPPING rather than
    # extending it (``**MySQL.TIME_MAPPING`` is not spread here) — confirm
    # that dropping MySQL's remaining entries is intended.
    TIME_MAPPING = {
        "%M": "%B",
        "%m": "%%-M",
        "%c": "%-m",
        "%e": "%-d",
        "%h": "%I",
        "%S": "%S",
        "%u": "%W",
        "%k": "%-H",
        "%l": "%-I",
        "%W": "%a",
        "%Y": "%Y",
        "%d": "%%-d",
        "%H": "%%-H",
        "%s": "%%-S",
        "%D": "%%-j",
        "%a": "%%p",
        "%y": "%%Y",
        "%": "%%",
    }

    class Parser(MySQL.Parser):
        FUNCTIONS = {
            **MySQL.Parser.FUNCTIONS,
            # Parses DATE_TRUNC(<unit>, <expr>) — unit first, as in Postgres.
            # NOTE(review): the Generator below emits DATE_TRUNC(<expr>, <unit>),
            # so Doris output may not round-trip through this parser — confirm
            # which argument order Doris actually uses.
            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
                this=seq_get(args, 1), unit=seq_get(args, 0)
            ),
            "REGEXP": exp.RegexpLike.from_arg_list,
        }

    class Generator(MySQL.Generator):
        # Doris does not need MySQL's CAST type renames.
        CAST_MAPPING = {}

        TYPE_MAPPING = {
            **MySQL.Generator.TYPE_MAPPING,
            exp.DataType.Type.TEXT: "STRING",
            exp.DataType.Type.TIMESTAMP: "DATETIME",
            exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
        }

        def _trunc_sql(self, expression: exp.Expression) -> str:
            """Render ``DATE_TRUNC(<expr>, '<unit>')``.

            Shared by DateTrunc and TimestampTrunc to avoid the previous
            copy-pasted lambdas.
            """
            return self.func("DATE_TRUNC", expression.this, f"'{expression.text('unit')}'")

        TRANSFORMS = {
            **MySQL.Generator.TRANSFORMS,
            exp.ApproxDistinct: approx_count_distinct_sql,
            exp.ArrayAgg: rename_func("COLLECT_LIST"),
            exp.Coalesce: rename_func("NVL"),
            exp.CurrentTimestamp: lambda *_: "NOW()",
            exp.DateTrunc: _trunc_sql,
            exp.JSONExtractScalar: arrow_json_extract_sql,
            exp.JSONExtract: arrow_json_extract_sql,
            exp.Map: rename_func("ARRAY_MAP"),
            exp.RegexpLike: rename_func("REGEXP"),
            exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
            exp.SetAgg: rename_func("COLLECT_SET"),
            exp.Split: rename_func("SPLIT_BY_STRING"),
            exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimeStrToDate: rename_func("TO_DATE"),
            # NOTE(review): reviewer says this duplicates MySQL's TRANSFORMS
            # entry — confirm and drop if so.
            exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
            exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
            exp.TimestampTrunc: _trunc_sql,
            exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
            # Only supports day-level addition.
            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
            exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
            # _time_format returns None when the format equals Doris's default,
            # omitting the redundant second argument.
            exp.UnixToStr: lambda self, e: self.func("FROM_UNIXTIME", e.this, _time_format(self, e)),
            exp.UnixToTime: rename_func("FROM_UNIXTIME"),
        }
34 changes: 34 additions & 0 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def test_cast(self):
"snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)",
"doris": "CAST(a AS STRING)",
},
)
self.validate_all(
Expand Down Expand Up @@ -169,6 +170,7 @@ def test_cast(self):
"snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)",
"doris": "CAST(a AS STRING)",
},
)
self.validate_all(
Expand All @@ -186,6 +188,7 @@ def test_cast(self):
"snowflake": "CAST(a AS VARCHAR)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS VARCHAR)",
"doris": "CAST(a AS VARCHAR)",
},
)
self.validate_all(
Expand All @@ -203,6 +206,7 @@ def test_cast(self):
"snowflake": "CAST(a AS VARCHAR(3))",
"spark": "CAST(a AS VARCHAR(3))",
"starrocks": "CAST(a AS VARCHAR(3))",
"doris": "CAST(a AS VARCHAR(3))",
},
)
self.validate_all(
Expand All @@ -221,6 +225,7 @@ def test_cast(self):
"spark": "CAST(a AS SMALLINT)",
"sqlite": "CAST(a AS INTEGER)",
"starrocks": "CAST(a AS SMALLINT)",
"doris": "CAST(a AS SMALLINT)",
},
)
self.validate_all(
Expand All @@ -234,6 +239,7 @@ def test_cast(self):
"drill": "CAST(a AS DOUBLE)",
"postgres": "CAST(a AS DOUBLE PRECISION)",
"redshift": "CAST(a AS DOUBLE PRECISION)",
"doris": "CAST(a AS DOUBLE)",
},
)

Expand Down Expand Up @@ -267,13 +273,15 @@ def test_cast(self):
write={
"starrocks": "CAST(a AS DATETIME)",
"redshift": "CAST(a AS TIMESTAMP)",
"doris": "CAST(a AS DATETIME)",
},
)
self.validate_all(
"CAST(a AS TIMESTAMPTZ)",
write={
"starrocks": "CAST(a AS DATETIME)",
"redshift": "CAST(a AS TIMESTAMPTZ)",
"doris": "CAST(a AS DATETIME)",
},
)
self.validate_all("CAST(a AS TINYINT)", write={"oracle": "CAST(a AS NUMBER)"})
Expand Down Expand Up @@ -408,6 +416,7 @@ def test_time(self):
"hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-mm-dd')",
"presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%i-%d'))",
"starrocks": "UNIX_TIMESTAMP('2020-01-01', '%Y-%i-%d')",
"doris": "UNIX_TIMESTAMP('2020-01-01', '%Y-%M-%d')",
},
)
self.validate_all(
Expand All @@ -418,6 +427,7 @@ def test_time(self):
"hive": "TO_DATE('2020-01-01')",
"presto": "CAST('2020-01-01' AS TIMESTAMP)",
"starrocks": "TO_DATE('2020-01-01')",
"doris": "TO_DATE('2020-01-01')",
},
)
self.validate_all(
Expand All @@ -428,6 +438,7 @@ def test_time(self):
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "CAST('2020-01-01' AS TIMESTAMP)",
"sqlite": "'2020-01-01'",
"doris": "CAST('2020-01-01' AS DATETIME)",
},
)
self.validate_all(
Expand All @@ -437,6 +448,7 @@ def test_time(self):
"hive": "UNIX_TIMESTAMP('2020-01-01')",
"mysql": "UNIX_TIMESTAMP('2020-01-01')",
"presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %T'))",
"doris": "UNIX_TIMESTAMP('2020-01-01')",
},
)
self.validate_all(
Expand All @@ -449,6 +461,7 @@ def test_time(self):
"postgres": "TO_CHAR(x, 'YYYY-MM-DD')",
"presto": "DATE_FORMAT(x, '%Y-%m-%d')",
"redshift": "TO_CHAR(x, 'YYYY-MM-DD')",
"doris": "DATE_FORMAT(x, '%Y-%m-%d')",
},
)
self.validate_all(
Expand All @@ -459,6 +472,7 @@ def test_time(self):
"hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)",
"redshift": "CAST(x AS VARCHAR(MAX))",
"doris": "CAST(x AS STRING)",
},
)
self.validate_all(
Expand All @@ -468,6 +482,7 @@ def test_time(self):
"duckdb": "EPOCH(x)",
"hive": "UNIX_TIMESTAMP(x)",
"presto": "TO_UNIXTIME(x)",
"doris": "UNIX_TIMESTAMP(x)",
},
)
self.validate_all(
Expand All @@ -476,6 +491,7 @@ def test_time(self):
"duckdb": "SUBSTRING(CAST(x AS TEXT), 1, 10)",
"hive": "SUBSTRING(CAST(x AS STRING), 1, 10)",
"presto": "SUBSTRING(CAST(x AS VARCHAR), 1, 10)",
"doris": "SUBSTRING(CAST(x AS STRING), 1, 10)",
},
)
self.validate_all(
Expand All @@ -487,6 +503,7 @@ def test_time(self):
"postgres": "CAST(x AS DATE)",
"presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)",
"snowflake": "CAST(x AS DATE)",
"doris": "TO_DATE(x)",
},
)
self.validate_all(
Expand All @@ -505,6 +522,7 @@ def test_time(self):
"hive": "FROM_UNIXTIME(x, y)",
"presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)",
"starrocks": "FROM_UNIXTIME(x, y)",
"doris": "FROM_UNIXTIME(x, y)",
},
)
self.validate_all(
Expand All @@ -516,6 +534,7 @@ def test_time(self):
"postgres": "TO_TIMESTAMP(x)",
"presto": "FROM_UNIXTIME(x)",
"starrocks": "FROM_UNIXTIME(x)",
"doris": "FROM_UNIXTIME(x)",
},
)
self.validate_all(
Expand Down Expand Up @@ -582,6 +601,7 @@ def test_time(self):
"sqlite": "DATE(x, '1 DAY')",
"starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
"tsql": "DATEADD(DAY, 1, x)",
"doris": "DATE_ADD(x, INTERVAL 1 DAY)",
},
)
self.validate_all(
Expand All @@ -595,6 +615,7 @@ def test_time(self):
"presto": "DATE_ADD('day', 1, x)",
"spark": "DATE_ADD(x, 1)",
"starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
"doris": "DATE_ADD(x, INTERVAL 1 DAY)",
},
)
self.validate_all(
Expand All @@ -612,6 +633,7 @@ def test_time(self):
"snowflake": "DATE_TRUNC('day', x)",
"starrocks": "DATE_TRUNC('day', x)",
"spark": "TRUNC(x, 'day')",
"doris": "DATE_TRUNC(x, 'day')",
},
)
self.validate_all(
Expand All @@ -624,6 +646,7 @@ def test_time(self):
"snowflake": "DATE_TRUNC('day', x)",
"starrocks": "DATE_TRUNC('day', x)",
"spark": "DATE_TRUNC('day', x)",
"doris": "DATE_TRUNC('day', x)",
},
)
self.validate_all(
Expand Down Expand Up @@ -684,6 +707,7 @@ def test_time(self):
"snowflake": "DATE_TRUNC('year', x)",
"starrocks": "DATE_TRUNC('year', x)",
"spark": "TRUNC(x, 'year')",
"doris": "DATE_TRUNC(x, 'year')",
},
)
self.validate_all(
Expand All @@ -698,6 +722,7 @@ def test_time(self):
write={
"bigquery": "TIMESTAMP_TRUNC(x, year)",
"spark": "DATE_TRUNC('year', x)",
"doris": "DATE_TRUNC(x, 'year')",
},
)
self.validate_all(
Expand All @@ -719,6 +744,7 @@ def test_time(self):
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%T') AS DATE)",
"spark": "TO_DATE(x, 'yyyy-MM-ddTHH:mm:ss')",
"doris": "STR_TO_DATE(x, '%Y-%m-%dT%H:%M:%S')",
},
)
self.validate_all(
Expand All @@ -730,6 +756,7 @@ def test_time(self):
"hive": "CAST(x AS DATE)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
"spark": "TO_DATE(x)",
"doris": "STR_TO_DATE(x, '%Y-%m-%d')",
},
)
self.validate_all(
Expand Down Expand Up @@ -784,6 +811,7 @@ def test_time(self):
"mysql": "CAST('2022-01-01' AS TIMESTAMP)",
"starrocks": "CAST('2022-01-01' AS DATETIME)",
"hive": "CAST('2022-01-01' AS TIMESTAMP)",
"doris": "CAST('2022-01-01' AS DATETIME)",
},
)
self.validate_all(
Expand All @@ -792,6 +820,7 @@ def test_time(self):
"mysql": "TIMESTAMP('2022-01-01')",
"starrocks": "TIMESTAMP('2022-01-01')",
"hive": "TIMESTAMP('2022-01-01')",
"doris": "TIMESTAMP('2022-01-01')",
},
)

Expand All @@ -807,6 +836,7 @@ def test_time(self):
"mysql",
"presto",
"starrocks",
"doris",
)
},
write={
Expand All @@ -820,6 +850,7 @@ def test_time(self):
"hive",
"spark",
"starrocks",
"doris",
)
},
)
Expand Down Expand Up @@ -886,13 +917,15 @@ def test_json(self):
"postgres": "x->'y'",
"presto": "JSON_EXTRACT(x, 'y')",
"starrocks": "x -> 'y'",
"doris": "x -> 'y'",
},
write={
"mysql": "JSON_EXTRACT(x, 'y')",
"oracle": "JSON_EXTRACT(x, 'y')",
"postgres": "x -> 'y'",
"presto": "JSON_EXTRACT(x, 'y')",
"starrocks": "x -> 'y'",
"doris": "x -> 'y'",
},
)
self.validate_all(
Expand Down Expand Up @@ -1115,6 +1148,7 @@ def test_operators(self):
"sqlite": "LOWER(x) LIKE '%y'",
"starrocks": "LOWER(x) LIKE '%y'",
"trino": "LOWER(x) LIKE '%y'",
"doris": "LOWER(x) LIKE '%y'",
},
)
self.validate_all(
Expand Down
Loading