Skip to content

Commit

Permalink
Feat(parser): improved comment parsing (#1956)
Browse files Browse the repository at this point in the history
* improved comment parsing

improved comment parsing for INSERT, UPDATE, and DELETE

added tests with and w/o CTEs

* make mypy happy

make mypy happy

* Modified comment generation for Delete, Insert, Update

Added identity examples:

/* comment1 */ INSERT INTO x /* comment2 */ VALUES (1, 2, 3)
/* comment1 */ UPDATE tbl /* comment2 */ SET x = 2 WHERE x < 2
/* comment1 */ DELETE FROM x /* comment2 */ WHERE y > 1
  • Loading branch information
mpf82 authored Jul 25, 2023
1 parent cbd5099 commit 59847f5
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 1 deletion.
3 changes: 3 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,12 @@ class Generator:

# Expression types whose comments are separated from the expression itself for
# better formatting. Covers the statement-level nodes (Select, Insert, Update,
# Delete, Drop) as well as the From, Where and With clauses.
WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Delete,
exp.Drop,
exp.From,
exp.Insert,
exp.Select,
exp.Update,
exp.Where,
exp.With,
)
Expand Down
7 changes: 7 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,6 +1692,7 @@ def _parse_describe(self) -> exp.Describe:
return self.expression(exp.Describe, this=this, kind=kind)

def _parse_insert(self) -> exp.Insert:
comments = ensure_list(self._prev_comments)
overwrite = self._match(TokenType.OVERWRITE)
ignore = self._match(TokenType.IGNORE)
local = self._match_text_seq("LOCAL")
Expand All @@ -1709,13 +1710,15 @@ def _parse_insert(self) -> exp.Insert:
alternative = self._match_texts(self.INSERT_ALTERNATIVES) and self._prev.text

self._match(TokenType.INTO)
comments += ensure_list(self._prev_comments)
self._match(TokenType.TABLE)
this = self._parse_table(schema=True)

returning = self._parse_returning()

return self.expression(
exp.Insert,
comments=comments,
this=this,
exists=self._parse_exists(),
partition=self._parse_partition(),
Expand Down Expand Up @@ -1840,13 +1843,15 @@ def _parse_delete(self) -> exp.Delete:
# This handles MySQL's "Multiple-Table Syntax"
# https://dev.mysql.com/doc/refman/8.0/en/delete.html
tables = None
comments = self._prev_comments
if not self._match(TokenType.FROM, advance=False):
tables = self._parse_csv(self._parse_table) or None

returning = self._parse_returning()

return self.expression(
exp.Delete,
comments=comments,
tables=tables,
this=self._match(TokenType.FROM) and self._parse_table(joins=True),
using=self._match(TokenType.USING) and self._parse_table(joins=True),
Expand All @@ -1856,11 +1861,13 @@ def _parse_delete(self) -> exp.Delete:
)

def _parse_update(self) -> exp.Update:
comments = self._prev_comments
this = self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS)
expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality)
returning = self._parse_returning()
return self.expression(
exp.Update,
comments=comments,
**{ # type: ignore
"this": this,
"expressions": expressions,
Expand Down
3 changes: 3 additions & 0 deletions tests/fixtures/identity.sql
Original file line number Diff line number Diff line change
Expand Up @@ -848,3 +848,6 @@ SELECT * FROM current_date
SELECT * FROM schema.current_date
SELECT /*+ SOME_HINT(foo) */ 1
SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1)
/* comment1 */ INSERT INTO x /* comment2 */ VALUES (1, 2, 3)
/* comment1 */ UPDATE tbl /* comment2 */ SET x = 2 WHERE x < 2
/* comment1 */ DELETE FROM x /* comment2 */ WHERE y > 1
116 changes: 115 additions & 1 deletion tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def test_var(self):
self.assertIsInstance(parse_one("INTERVAL '1' DAY").args["unit"], exp.Var)
self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'")

def test_comments(self):
def test_comments_select(self):
expression = parse_one(
"""
--comment1.1
Expand All @@ -277,6 +277,120 @@ def test_comments(self):
self.assertEqual(expression.expressions[4].comments, [""])
self.assertEqual(expression.expressions[5].comments, [" space"])

def test_comments_select_cte(self):
    """Check comment placement for a SELECT that is preceded by a CTE.

    The two leading block comments are expected on the WITH node, the
    comment after SELECT on the statement itself, and the comment after
    FROM on the FROM node.
    """
    sql = """
/*comment1.1*/
/*comment1.2*/
WITH a AS (SELECT 1)
SELECT /*comment2*/
a.*
FROM /*comment3*/
a
"""
    select = parse_one(sql)
    from_clause = select.args.get("from")
    with_clause = select.args.get("with")

    self.assertEqual(select.comments, ["comment2"])
    self.assertEqual(from_clause.comments, ["comment3"])
    self.assertEqual(with_clause.comments, ["comment1.1", "comment1.2"])

def test_comments_insert(self):
    """Check comment placement for a plain INSERT statement.

    Comments before INSERT and the one right after INTO are expected on the
    Insert node; the comment following the table name is expected on the
    target (`this`).
    """
    sql = """
--comment1.1
--comment1.2
INSERT INTO /*comment1.3*/
x /*comment2*/
VALUES /*comment3*/
(1, 'a', 2.0)
"""
    insert = parse_one(sql)
    target = insert.this

    self.assertEqual(insert.comments, ["comment1.1", "comment1.2", "comment1.3"])
    self.assertEqual(target.comments, ["comment2"])

def test_comments_insert_cte(self):
    """Check comment placement for an INSERT that is preceded by a CTE.

    The leading block comments are expected on the WITH node, the comment
    after INTO on the Insert node, and the comment after the table name on
    the target (`this`).
    """
    sql = """
/*comment1.1*/
/*comment1.2*/
WITH a AS (SELECT 1)
INSERT INTO /*comment2*/
b /*comment3*/
SELECT * FROM a
"""
    insert = parse_one(sql)
    target = insert.this
    with_clause = insert.args.get("with")

    self.assertEqual(insert.comments, ["comment2"])
    self.assertEqual(target.comments, ["comment3"])
    self.assertEqual(with_clause.comments, ["comment1.1", "comment1.2"])

def test_comments_update(self):
    """Check comment placement for a plain UPDATE statement.

    Comments before UPDATE and the one right after it are expected on the
    Update node, the comment after the table name on the target (`this`),
    and the comment after WHERE on the WHERE node.
    """
    sql = """
--comment1.1
--comment1.2
UPDATE /*comment1.3*/
tbl /*comment2*/
SET /*comment3*/
x = 2
WHERE /*comment4*/
x <> 2
"""
    update = parse_one(sql)
    target = update.this
    where_clause = update.args.get("where")

    self.assertEqual(update.comments, ["comment1.1", "comment1.2", "comment1.3"])
    self.assertEqual(target.comments, ["comment2"])
    self.assertEqual(where_clause.comments, ["comment4"])

def test_comments_update_cte(self):
    """Check comment placement for an UPDATE that is preceded by a CTE.

    The leading block comments are expected on the WITH node, the comment
    after UPDATE on the Update node, and the comment after the table name
    on the target (`this`).
    """
    sql = """
/*comment1.1*/
/*comment1.2*/
WITH a AS (SELECT * FROM b)
UPDATE /*comment2*/
a /*comment3*/
SET col = 1
"""
    update = parse_one(sql)
    target = update.this
    with_clause = update.args.get("with")

    self.assertEqual(update.comments, ["comment2"])
    self.assertEqual(target.comments, ["comment3"])
    self.assertEqual(with_clause.comments, ["comment1.1", "comment1.2"])

def test_comments_delete(self):
    """Check comment placement for a plain DELETE statement.

    Comments before DELETE and the one right after it are expected on the
    Delete node, the comment after the table name on the target (`this`),
    and the comment after WHERE on the WHERE node.
    """
    sql = """
--comment1.1
--comment1.2
DELETE /*comment1.3*/
FROM /*comment2*/
x /*comment3*/
WHERE /*comment4*/
y > 1
"""
    delete = parse_one(sql)
    target = delete.this
    where_clause = delete.args.get("where")

    self.assertEqual(delete.comments, ["comment1.1", "comment1.2", "comment1.3"])
    self.assertEqual(target.comments, ["comment3"])
    self.assertEqual(where_clause.comments, ["comment4"])

def test_comments_delete_cte(self):
    """Check comment placement for a DELETE that is preceded by a CTE.

    The leading block comments are expected on the WITH node, the line
    comment before DELETE on the Delete node, and the comment after the
    table name on the target (`this`).
    """
    sql = """
/*comment1.1*/
/*comment1.2*/
WITH a AS (SELECT * FROM b)
--comment2
DELETE FROM a /*comment3*/
"""
    delete = parse_one(sql)
    target = delete.this
    with_clause = delete.args.get("with")

    self.assertEqual(delete.comments, ["comment2"])
    self.assertEqual(target.comments, ["comment3"])
    self.assertEqual(with_clause.comments, ["comment1.1", "comment1.2"])

def test_type_literals(self):
self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))
Expand Down

0 comments on commit 59847f5

Please sign in to comment.