Skip to content

Commit

Permalink
Fix: need to differentiate between peek and curr tokenizers
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed May 8, 2023
1 parent 4744742 commit 23cf246
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 14 deletions.
43 changes: 29 additions & 14 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ def _chars(self, size: int) -> str:
return self.sql[start:end]
return ""

def _advance(self, i: int = 1, alnum=False) -> None:
def _advance(self, i: int = 1, alnum: str = "") -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
self._col = 1
self._line += 1
Expand All @@ -858,19 +858,33 @@ def _advance(self, i: int = 1, alnum=False) -> None:
_col = self._col
_current = self._current
_end = self._end
_peek = self._peek

while _peek.isalnum():
_col += 1
_current += 1
_end = _current >= self.size
_peek = "" if _end else self.sql[_current]
if alnum == "curr":
_char = self._char

while not _end and _char.isalnum():
_char = self.sql[_current]
_col += 1
_current += 1
_end = _current >= self.size

self._char = _char
self._peek = "" if _end else self.sql[_current]
else:
_peek = self._peek

while _peek.isalnum():
_col += 1
_current += 1
_end = _current >= self.size
_peek = "" if _end else self.sql[_current]

self._char = self.sql[_current - 1]
self._peek = _peek

self._col = _col
self._current = _current
self._end = _end
self._peek = _peek
self._char = self.sql[_current - 1]

@property
def _text(self) -> str:
Expand Down Expand Up @@ -978,13 +992,13 @@ def _scan_comment(self, comment_start: str) -> bool:

comment_end_size = len(comment_end)
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance(alnum=True)
self._advance(alnum="peek")

self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
self._advance(comment_end_size - 1)
else:
while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
self._advance(alnum=True)
self._advance(alnum="peek")
self._comments.append(self._text[comment_start_size:])

# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
Expand Down Expand Up @@ -1065,7 +1079,7 @@ def _extract_value(self) -> str:
while True:
char = self._peek.strip()
if char and char not in self.SINGLE_TOKENS:
self._advance(alnum=True)
self._advance(alnum="peek")
else:
break

Expand Down Expand Up @@ -1123,9 +1137,10 @@ def _scan_var(self) -> None:
while True:
char = self._peek.strip()
if char and (char in self.VAR_SINGLE_TOKENS or char not in self.SINGLE_TOKENS):
self._advance(alnum=True)
self._advance(alnum="peek")
else:
break

self._add(
TokenType.VAR
if self.tokens and self.tokens[-1].token_type == TokenType.PARAMETER
Expand Down Expand Up @@ -1158,7 +1173,7 @@ def _extract_string(self, delimiter: str, escapes=None) -> str:
raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")

current = self._current - 1
self._advance(alnum=True)
self._advance(alnum="curr")
text += self.sql[current : self._current - 1]

return text
1 change: 1 addition & 0 deletions tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def test_hexadecimal_literal(self):
self.validate_all("SELECT X'1A'", write={"mysql": "SELECT x'1A'"})
self.validate_all("SELECT 0xz", write={"mysql": "SELECT `0xz`"})
self.validate_all("SELECT 0xCC", write=write_CC)
self.validate_all("SELECT 0xCC ", write=write_CC)
self.validate_all("SELECT x'CC'", write=write_CC)
self.validate_all("SELECT 0x0000CC", write=write_CC_with_leading_zeros)
self.validate_all("SELECT x'0000CC'", write=write_CC_with_leading_zeros)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,14 @@ def test_jinja(self):
(TokenType.SEMICOLON, ";"),
],
)

tokens = tokenizer.tokenize("""'{{ var('x') }}'""")
tokens = [(token.token_type, token.text) for token in tokens]
self.assertEqual(
tokens,
[
(TokenType.STRING, "{{ var("),
(TokenType.VAR, "x"),
(TokenType.STRING, ") }}"),
],
)

0 comments on commit 23cf246

Please sign in to comment.