Skip to content

Commit

Permalink
Fix: improve comment handling for several expressions (#2017)
Browse files: browse the repository at this point in the history
* Fix: improve comment handling for several expressions

* More improvements
  • Loading branch information
georgesittas authored Aug 9, 2023
1 parent 95ec5b6 commit baab165
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 3 deletions.
2 changes: 2 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,12 @@ class Generator:

# Expressions whose comments are separated from them for better formatting
WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Create,
exp.Delete,
exp.Drop,
exp.From,
exp.Insert,
exp.Join,
exp.Select,
exp.Update,
exp.Where,
Expand Down
9 changes: 7 additions & 2 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,8 @@ def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:
def _parse_create(self) -> exp.Create | exp.Command:
# Note: this can't be None because we've matched a statement parser
start = self._prev
comments = self._prev_comments

replace = start.text.upper() == "REPLACE" or self._match_pair(
TokenType.OR, TokenType.REPLACE
)
Expand Down Expand Up @@ -1273,6 +1275,7 @@ def extend_props(temp_props: t.Optional[exp.Properties]) -> None:

return self.expression(
exp.Create,
comments=comments,
this=this,
kind=create_token.text,
replace=replace,
Expand Down Expand Up @@ -2338,7 +2341,8 @@ def _parse_join(

kwargs["this"].set("joins", joins)

return self.expression(exp.Join, **kwargs)
comments = [c for token in (method, side, kind) if token for c in token.comments]
return self.expression(exp.Join, comments=comments, **kwargs)

def _parse_index(
self,
Expand Down Expand Up @@ -3738,6 +3742,7 @@ def _parse_case(self) -> t.Optional[exp.Expression]:
ifs = []
default = None

comments = self._prev_comments
expression = self._parse_conjunction()

while self._match(TokenType.WHEN):
Expand All @@ -3753,7 +3758,7 @@ def _parse_case(self) -> t.Optional[exp.Expression]:
self.raise_error("Expected END after CASE", self._prev)

return self._parse_window(
self.expression(exp.Case, this=expression, ifs=ifs, default=default)
self.expression(exp.Case, comments=comments, this=expression, ifs=ifs, default=default)
)

def _parse_if(self) -> t.Optional[exp.Expression]:
Expand Down
5 changes: 5 additions & 0 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,11 @@ def peek(self, i: int = 0) -> str:

def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line

if self._comments and token_type == TokenType.SEMICOLON and self.tokens:
self.tokens[-1].comments.extend(self._comments)
self._comments = []

self.tokens.append(
Token(
token_type,
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/identity.sql
Original file line number Diff line number Diff line change
Expand Up @@ -853,4 +853,5 @@ SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1)
/* comment1 */ INSERT INTO x /* comment2 */ VALUES (1, 2, 3)
/* comment1 */ UPDATE tbl /* comment2 */ SET x = 2 WHERE x < 2
/* comment1 */ DELETE FROM x /* comment2 */ WHERE y > 1
/* comment */ CREATE TABLE foo AS SELECT 1
SELECT next, transform, if
1 change: 1 addition & 0 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def test_comment_attachment(self):
("foo /*comment 1*/ /*comment 2*/", ["comment 1", "comment 2"]),
("foo\n-- comment", [" comment"]),
("1 /*/2 */", ["/2 "]),
("1\n/*comment*/;", ["comment"]),
]

for sql, comment in sql_comment:
Expand Down
56 changes: 55 additions & 1 deletion tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@ def test_space(self):
self.validate("SELECT 3>=3", "SELECT 3 >= 3")

def test_comments(self):
self.validate("SELECT\n foo\n/* comments */\n;", "SELECT foo /* comments */")
self.validate(
"SELECT * FROM a INNER /* comments */ JOIN b",
"SELECT * FROM a /* comments */ INNER JOIN b",
)
self.validate(
"SELECT * FROM a LEFT /* comment 1 */ OUTER /* comment 2 */ JOIN b",
"SELECT * FROM a /* comment 1 */ /* comment 2 */ LEFT OUTER JOIN b",
)
self.validate(
"SELECT CASE /* test */ WHEN a THEN b ELSE c END",
"SELECT CASE WHEN a THEN b ELSE c END /* test */",
)
self.validate("SELECT 1 /*/2 */", "SELECT 1 /* /2 */")
self.validate("SELECT */*comment*/", "SELECT * /* comment */")
self.validate(
Expand Down Expand Up @@ -308,6 +321,7 @@ def test_comments(self):
)
self.validate(
"""
-- comment4
CREATE TABLE db.tba AS
SELECT a, b, c
FROM tb_01
Expand All @@ -316,8 +330,10 @@ def test_comments(self):
a = 1 AND b = 2 --comment6
-- and c = 1
-- comment7
;
""",
"""CREATE TABLE db.tba AS
"""/* comment4 */
CREATE TABLE db.tba AS
SELECT
a,
b,
Expand All @@ -329,6 +345,44 @@ def test_comments(self):
/* comment7 */""",
pretty=True,
)
self.validate(
"""
SELECT
-- This is testing comments
col,
-- 2nd testing comments
CASE WHEN a THEN b ELSE c END as d
FROM t
""",
"""SELECT
col, /* This is testing comments */
CASE WHEN a THEN b ELSE c END /* 2nd testing comments */ AS d
FROM t""",
pretty=True,
)
self.validate(
"""
SELECT * FROM a
-- comments
INNER JOIN b
""",
"""SELECT
*
FROM a
/* comments */
INNER JOIN b""",
pretty=True,
)
self.validate(
"SELECT * FROM a LEFT /* comment 1 */ OUTER /* comment 2 */ JOIN b",
"""SELECT
*
FROM a
/* comment 1 */
/* comment 2 */
LEFT OUTER JOIN b""",
pretty=True,
)

def test_types(self):
self.validate("INT 1", "CAST(1 AS INT)")
Expand Down

0 comments on commit baab165

Please sign in to comment.