diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 40ba88e6e8..410a2daf49 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -265,9 +265,12 @@ class Generator: # Expressions whose comments are separated from them for better formatting WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = ( + exp.Delete, exp.Drop, exp.From, + exp.Insert, exp.Select, + exp.Update, exp.Where, exp.With, ) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index fcb54d1b4b..fe41c020fb 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -1692,6 +1692,7 @@ def _parse_describe(self) -> exp.Describe: return self.expression(exp.Describe, this=this, kind=kind) def _parse_insert(self) -> exp.Insert: + comments = ensure_list(self._prev_comments) overwrite = self._match(TokenType.OVERWRITE) ignore = self._match(TokenType.IGNORE) local = self._match_text_seq("LOCAL") @@ -1709,6 +1710,7 @@ def _parse_insert(self) -> exp.Insert: alternative = self._match_texts(self.INSERT_ALTERNATIVES) and self._prev.text self._match(TokenType.INTO) + comments += ensure_list(self._prev_comments) self._match(TokenType.TABLE) this = self._parse_table(schema=True) @@ -1716,6 +1718,7 @@ def _parse_insert(self) -> exp.Insert: return self.expression( exp.Insert, + comments=comments, this=this, exists=self._parse_exists(), partition=self._parse_partition(), @@ -1840,6 +1843,7 @@ def _parse_delete(self) -> exp.Delete: # This handles MySQL's "Multiple-Table Syntax" # https://dev.mysql.com/doc/refman/8.0/en/delete.html tables = None + comments = self._prev_comments if not self._match(TokenType.FROM, advance=False): tables = self._parse_csv(self._parse_table) or None @@ -1847,6 +1851,7 @@ def _parse_delete(self) -> exp.Delete: return self.expression( exp.Delete, + comments=comments, tables=tables, this=self._match(TokenType.FROM) and self._parse_table(joins=True), using=self._match(TokenType.USING) and self._parse_table(joins=True), @@ -1856,11 +1861,13 @@ def _parse_delete(self) -> exp.Delete: ) def _parse_update(self) -> exp.Update: + comments = self._prev_comments this = self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS) expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality) returning = self._parse_returning() return self.expression( exp.Update, + comments=comments, **{ # type: ignore "this": this, "expressions": expressions, diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index b460c1598f..f0fc01b054 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -848,3 +848,6 @@ SELECT * FROM current_date SELECT * FROM schema.current_date SELECT /*+ SOME_HINT(foo) */ 1 SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1) +/* comment1 */ INSERT INTO x /* comment2 */ VALUES (1, 2, 3) +/* comment1 */ UPDATE tbl /* comment2 */ SET x = 2 WHERE x < 2 +/* comment1 */ DELETE FROM x /* comment2 */ WHERE y > 1 diff --git a/tests/test_parser.py b/tests/test_parser.py index 07686af979..235e26f502 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -253,7 +253,7 @@ def test_var(self): self.assertIsInstance(parse_one("INTERVAL '1' DAY").args["unit"], exp.Var) self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'") - def test_comments(self): + def test_comments_select(self): expression = parse_one( """ --comment1.1 @@ -277,6 +277,120 @@ def test_comments(self): self.assertEqual(expression.expressions[4].comments, [""]) self.assertEqual(expression.expressions[5].comments, [" space"]) + def test_comments_select_cte(self): + expression = parse_one( + """ + /*comment1.1*/ + /*comment1.2*/ + WITH a AS (SELECT 1) + SELECT /*comment2*/ + a.* + FROM /*comment3*/ + a + """ + ) + + self.assertEqual(expression.comments, ["comment2"]) + self.assertEqual(expression.args.get("from").comments, ["comment3"]) + self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + + def test_comments_insert(self): + expression = parse_one( + """ + --comment1.1 + --comment1.2 + INSERT INTO /*comment1.3*/ + x /*comment2*/ + VALUES /*comment3*/ + (1, 'a', 2.0) + """ + ) + + self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"]) + self.assertEqual(expression.this.comments, ["comment2"]) + + def test_comments_insert_cte(self): + expression = parse_one( + """ + /*comment1.1*/ + /*comment1.2*/ + WITH a AS (SELECT 1) + INSERT INTO /*comment2*/ + b /*comment3*/ + SELECT * FROM a + """ + ) + + self.assertEqual(expression.comments, ["comment2"]) + self.assertEqual(expression.this.comments, ["comment3"]) + self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + + def test_comments_update(self): + expression = parse_one( + """ + --comment1.1 + --comment1.2 + UPDATE /*comment1.3*/ + tbl /*comment2*/ + SET /*comment3*/ + x = 2 + WHERE /*comment4*/ + x <> 2 + """ + ) + + self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"]) + self.assertEqual(expression.this.comments, ["comment2"]) + self.assertEqual(expression.args.get("where").comments, ["comment4"]) + + def test_comments_update_cte(self): + expression = parse_one( + """ + /*comment1.1*/ + /*comment1.2*/ + WITH a AS (SELECT * FROM b) + UPDATE /*comment2*/ + a /*comment3*/ + SET col = 1 + """ + ) + + self.assertEqual(expression.comments, ["comment2"]) + self.assertEqual(expression.this.comments, ["comment3"]) + self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + + def test_comments_delete(self): + expression = parse_one( + """ + --comment1.1 + --comment1.2 + DELETE /*comment1.3*/ + FROM /*comment2*/ + x /*comment3*/ + WHERE /*comment4*/ + y > 1 + """ + ) + + self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"]) + self.assertEqual(expression.this.comments, ["comment3"]) + self.assertEqual(expression.args.get("where").comments, ["comment4"]) + + def test_comments_delete_cte(self): + expression = parse_one( + """ + /*comment1.1*/ + /*comment1.2*/ + WITH a AS (SELECT * FROM b) + --comment2 + DELETE FROM a /*comment3*/ + """ + ) + + self.assertEqual(expression.comments, ["comment2"]) + self.assertEqual(expression.this.comments, ["comment3"]) + self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + def test_type_literals(self): self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)")) self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))