diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 5f890ac74b..620578cd10 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1221,6 +1221,7 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): # this: True -> ALWAYS, this: False -> BY DEFAULT arg_types = { "this": False, + "expression": False, "on_null": False, "start": False, "increment": False, diff --git a/sqlglot/generator.py b/sqlglot/generator.py index bb94a51fc5..42fce8bb4c 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -595,14 +595,20 @@ def generatedasidentitycolumnconstraint_sql( maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else "" cycle = expression.args.get("cycle") cycle_sql = "" + if cycle is not None: cycle_sql = f"{' NO' if not cycle else ''} CYCLE" cycle_sql = cycle_sql.strip() if not start and not increment else cycle_sql + sequence_opts = "" if start or increment or cycle_sql: sequence_opts = f"{start}{increment}{minvalue}{maxvalue}{cycle_sql}" sequence_opts = f" ({sequence_opts.strip()})" - return f"GENERATED{this}AS IDENTITY{sequence_opts}" + + expr = self.sql(expression, "expression") + expr = f"({expr})" if expr else "IDENTITY" + + return f"GENERATED{this}AS {expr}{sequence_opts}" def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str: return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 4ffefcbc7d..f41e8355b3 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -3302,7 +3302,9 @@ def _parse_generated_as_identity(self) -> exp.Expression: self._match_text_seq("ALWAYS") this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) - self._match_text_seq("AS", "IDENTITY") + self._match(TokenType.ALIAS) + identity = self._match_text_seq("IDENTITY") + if self._match(TokenType.L_PAREN): if self._match_text_seq("START", "WITH"): this.set("start", self._parse_bitwise()) @@ -3318,6 +3320,9 @@ def _parse_generated_as_identity(self) -> exp.Expression: elif self._match_text_seq("NO", "CYCLE"): this.set("cycle", False) + if not identity: + this.set("expression", self._parse_bitwise()) + self._match_r_paren() return this diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 46191086f1..8239dec386 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -9,6 +9,14 @@ def test_databricks(self): self.validate_identity("CREATE FUNCTION a.b(x INT) RETURNS INT RETURN x + 1") self.validate_identity("CREATE FUNCTION a AS b") self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1") + self.validate_identity("CREATE TABLE foo (x DATE GENERATED ALWAYS AS (CAST(y AS DATE)))") + + self.validate_all( + "CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))", + write={ + "databricks": "CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(TO_DATE(y))))", + }, + ) # https://docs.databricks.com/sql/language-manual/functions/colonsign.html def test_json(self):