Skip to content

Commit

Permalink
Fix: conditionally quote identifiers that start with a digit (#1729)
Browse files Browse the repository at this point in the history
* Fix: conditionally quote identifiers that start with a digit

* Fixups

* PR feedback

* Comment fixup

* Bring the tokenizer cache back

* Move setting logic to metaclass
  • Loading branch information
georgesittas authored Jun 6, 2023
1 parent 1eb338a commit cad14bd
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 10 deletions.
6 changes: 6 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,18 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[
klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)

klass.tokenizer_class.identifiers_can_start_with_digit = (
klass.identifiers_can_start_with_digit
)

return klass


class Dialect(metaclass=_Dialect):
index_offset = 0
unnest_column_only = False
alias_post_tablesample = False
identifiers_can_start_with_digit = False
normalize_functions: t.Optional[str] = "upper"
null_ordering = "nulls_are_small"

Expand Down Expand Up @@ -231,6 +236,7 @@ def generator(self, **opts) -> Generator:
"time_trie": self.inverse_time_trie,
"unnest_column_only": self.unnest_column_only,
"alias_post_tablesample": self.alias_post_tablesample,
"identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
"normalize_functions": self.normalize_functions,
"null_ordering": self.null_ordering,
**opts,
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str

class Hive(Dialect):
alias_post_tablesample = True
identifiers_can_start_with_digit = True

time_mapping = {
"y": "%Y",
Expand Down Expand Up @@ -190,7 +191,6 @@ class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
IDENTIFIER_CAN_START_WITH_DIGIT = True

KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
Expand Down
6 changes: 6 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class Generator:
Default: "upper"
alias_post_tablesample (bool): if the table alias comes after tablesample
Default: False
identifiers_can_start_with_digit (bool): if an unquoted identifier can start with digit
Default: False
unsupported_level (ErrorLevel): determines the generator's behavior when it encounters
unsupported expressions. Default ErrorLevel.WARN.
null_ordering (str): Indicates the default null ordering method to use if not explicitly set.
Expand Down Expand Up @@ -266,6 +268,7 @@ class Generator:
"index_offset",
"unnest_column_only",
"alias_post_tablesample",
"identifiers_can_start_with_digit",
"normalize_functions",
"unsupported_level",
"unsupported_messages",
Expand Down Expand Up @@ -306,6 +309,7 @@ def __init__(
index_offset=0,
unnest_column_only=False,
alias_post_tablesample=False,
identifiers_can_start_with_digit=False,
normalize_functions="upper",
unsupported_level=ErrorLevel.WARN,
null_ordering=None,
Expand Down Expand Up @@ -339,6 +343,7 @@ def __init__(
self.index_offset = index_offset
self.unnest_column_only = unnest_column_only
self.alias_post_tablesample = alias_post_tablesample
self.identifiers_can_start_with_digit = identifiers_can_start_with_digit
self.normalize_functions = normalize_functions
self.unsupported_level = unsupported_level
self.unsupported_messages = []
Expand Down Expand Up @@ -896,6 +901,7 @@ def identifier_sql(self, expression: exp.Identifier) -> str:
expression.quoted
or should_identify(text, self.identify)
or lower in self.RESERVED_KEYWORDS
or (not self.identifiers_can_start_with_digit and text[:1].isdigit())
):
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
Expand Down
9 changes: 3 additions & 6 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3060,11 +3060,7 @@ def _parse_column_ops(self, this: exp.Expression) -> exp.Expression:
else exp.Literal.string(value)
)
else:
field = (
self._parse_star()
or self._parse_function(anonymous=True)
or self._parse_id_var()
)
field = self._parse_field(anonymous_func=True)

if isinstance(field, exp.Func):
# bigquery allows function calls like x.y.count(...)
Expand Down Expand Up @@ -3135,10 +3131,11 @@ def _parse_field(
self,
any_token: bool = False,
tokens: t.Optional[t.Collection[TokenType]] = None,
anonymous_func: bool = False,
) -> t.Optional[exp.Expression]:
return (
self._parse_primary()
or self._parse_function()
or self._parse_function(anonymous=anonymous_func)
or self._parse_id_var(any_token=any_token, tokens=tokens)
)

Expand Down
5 changes: 2 additions & 3 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,8 +735,6 @@ class Tokenizer(metaclass=_Tokenizer):
COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
KEYWORD_TRIE: t.Dict = {} # autofilled

IDENTIFIER_CAN_START_WITH_DIGIT = False

__slots__ = (
"sql",
"size",
Expand All @@ -750,6 +748,7 @@ class Tokenizer(metaclass=_Tokenizer):
"_end",
"_peek",
"_prev_token_line",
"identifiers_can_start_with_digit",
)

def __init__(self) -> None:
Expand Down Expand Up @@ -1010,7 +1009,7 @@ def _scan_number(self) -> None:
self._add(TokenType.NUMBER, number_text)
self._add(TokenType.DCOLON, "::")
return self._add(token_type, literal)
elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
elif self.identifiers_can_start_with_digit: # type: ignore
return self._add(TokenType.VAR)

self._add(TokenType.NUMBER, number_text)
Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ def test_hive(self):
"SELECT 1_a AS a FROM test_table",
write={
"spark": "SELECT 1_a AS a FROM test_table",
"trino": 'SELECT "1_a" AS a FROM test_table',
},
)
self.validate_all(
Expand Down

0 comments on commit cad14bd

Please sign in to comment.