Skip to content

Commit

Permalink
Refactor(optimizer): make the type annotator more dry (#1777)
Browse files Browse the repository at this point in the history
* Refactor(optimizer): make the type annotator more dry, add DATE function

* Test fixup

* Minor fixup

* Add another test
  • Loading branch information
georgesittas authored Jun 15, 2023
1 parent 0a1362b commit 1dbed85
Show file tree
Hide file tree
Showing 7 changed files with 272 additions and 269 deletions.
2 changes: 1 addition & 1 deletion sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,13 @@ class Parser(Hive.Parser):
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1),
unit=exp.var(seq_get(args, 0)),
),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"BOOLEAN": _parse_as_cast("boolean"),
"DATE": _parse_as_cast("date"),
"DOUBLE": _parse_as_cast("double"),
"FLOAT": _parse_as_cast("float"),
"INT": _parse_as_cast("int"),
Expand Down
5 changes: 5 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4095,6 +4095,11 @@ class DateToDi(Func):
pass


class Date(Func):
arg_types = {"expressions": True}
is_var_len_args = True


class Day(Func):
pass

Expand Down
518 changes: 253 additions & 265 deletions sqlglot/optimizer/annotate_types.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def qualify_columns(
expression: exp.Expression,
schema: dict | Schema,
schema: t.Dict | Schema,
expand_alias_refs: bool = True,
infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
Expand Down
5 changes: 5 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ class TestBigQuery(Validator):
dialect = "bigquery"

def test_bigquery(self):
self.validate_identity("DATE(2016, 12, 25)")
self.validate_identity("DATE(CAST('2016-12-25 23:59:59' AS DATETIME))")
self.validate_identity("SELECT foo IN UNNEST(bar) AS bla")
self.validate_identity("SELECT * FROM x-0.a")
self.validate_identity("SELECT * FROM pivot CROSS JOIN foo")
Expand All @@ -28,6 +30,9 @@ def test_bigquery(self):
self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))")
self.validate_identity("""CREATE TABLE x (a STRUCT<values ARRAY<INT64>>)""")
self.validate_identity("""CREATE TABLE x (a STRUCT<b STRING OPTIONS (description='b')>)""")
self.validate_identity(
"DATE(CAST('2016-12-25 05:30:00+07' AS DATETIME), 'America/Los_Angeles')"
)
self.validate_identity(
"""CREATE TABLE x (a STRING OPTIONS (description='x')) OPTIONS (table_expiration_days=1)"""
)
Expand Down
1 change: 1 addition & 0 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def test_functions(self):
self.assertIsInstance(parse_one("HLL(a)"), exp.Hll)
self.assertIsInstance(parse_one("ARRAY(time, foo)"), exp.Array)
self.assertIsInstance(parse_one("STANDARD_HASH('hello', 'sha256')"), exp.StandardHash)
self.assertIsInstance(parse_one("DATE(foo)"), exp.Date)

def test_column(self):
column = parse_one("a.b.c.d")
Expand Down
8 changes: 6 additions & 2 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,9 +553,10 @@ def test_cte_column_annotation(self):

def test_function_annotation(self):
schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}}
sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x"
sql = "SELECT x.cola || TRIM(x.colb) AS col, DATE(x.colb) FROM x AS x"

concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
expression = annotate_types(parse_one(sql), schema=schema)
concat_expr_alias = expression.expressions[0]
self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.VARCHAR)

concat_expr = concat_expr_alias.this
Expand All @@ -564,6 +565,9 @@ def test_function_annotation(self):
self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
self.assertEqual(concat_expr.right.this.type.this, exp.DataType.Type.CHAR) # x.colb

date_expr = expression.expressions[1]
self.assertEqual(date_expr.type.this, exp.DataType.Type.DATE)

sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x"

case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
Expand Down

0 comments on commit 1dbed85

Please sign in to comment.