Skip to content

Commit

Permalink
Chore: cleanup types and add sort by alias for hive
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Aug 19, 2023
1 parent edb9a96 commit a20794a
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 40 deletions.
6 changes: 2 additions & 4 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def _parse_function(

def _parse_func_params(
self, this: t.Optional[exp.Func] = None
) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
) -> t.Optional[t.List[exp.Expression]]:
if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
return self._parse_csv(self._parse_lambda)

Expand All @@ -267,9 +267,7 @@ def _parse_quantile(self) -> exp.Quantile:
return self.expression(exp.Quantile, this=params[0], quantile=this)
return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5))

def _parse_wrapped_id_vars(
self, optional: bool = False
) -> t.List[t.Optional[exp.Expression]]:
def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]:
    # ClickHouse treats the wrapping parentheses as always optional, so the
    # caller-supplied flag is deliberately overridden with optional=True.
    id_vars = super()._parse_wrapped_id_vars(optional=True)
    return id_vars

def _parse_primary_key(
Expand Down
11 changes: 10 additions & 1 deletion sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ class Tokenizer(tokens.Tokenizer):
class Parser(parser.Parser):
LOG_DEFAULTS_TO_LN = True
STRICT_CAST = False
PARTITION_BY_TOKENS = {*parser.Parser.PARTITION_BY_TOKENS, TokenType.DISTRIBUTE_BY}

FUNCTIONS = {
**parser.Parser.FUNCTIONS,
Expand Down Expand Up @@ -351,6 +350,16 @@ def _parse_types(

return this

def _parse_partition_and_order(
    self,
) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
    """Parse a window's partition/order clauses, honoring Hive's aliases.

    Hive accepts DISTRIBUTE BY as an alias for PARTITION BY and SORT BY as an
    alias for ORDER BY inside window specifications (see test_hive below).
    """
    # Must consume the (DISTRIBUTE|PARTITION) BY tokens before looking for
    # SORT BY, to preserve the original left-to-right token consumption.
    if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY}):
        partition = self._parse_csv(self._parse_conjunction)
    else:
        partition = []
    order = super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY))
    return partition, order

class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ class Parser(parser.Parser):

CONCAT_NULL_OUTPUTS_STRING = True

def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]:
def _parse_projections(self) -> t.List[exp.Expression]:
"""
T-SQL supports the syntax alias = expression in the SELECT's projection list,
so we transform all parsed Selects to convert their EQ projections into Aliases.
Expand Down
21 changes: 10 additions & 11 deletions sqlglot/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ def subclasses(

def apply_index_offset(
this: exp.Expression,
expressions: t.List[t.Optional[E]],
expressions: t.List[E],
offset: int,
) -> t.List[t.Optional[E]]:
) -> t.List[E]:
"""
Applies an offset to a given integer literal expression.
Expand Down Expand Up @@ -170,15 +170,14 @@ def apply_index_offset(
):
return expressions

if expression:
if not expression.type:
annotate_types(expression)
if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
logger.warning("Applying array index offset (%s)", offset)
expression = simplify(
exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
)
return [expression]
if not expression.type:
annotate_types(expression)
if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
logger.warning("Applying array index offset (%s)", offset)
expression = simplify(
exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
)
return [expression]

return expressions

Expand Down
49 changes: 27 additions & 22 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,6 @@ class Parser(metaclass=_Parser):
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
WINDOW_SIDES = {"FOLLOWING", "PRECEDING"}
PARTITION_BY_TOKENS = {TokenType.PARTITION_BY}

ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}

Expand Down Expand Up @@ -1417,7 +1416,7 @@ def _parse_volatile_property(self) -> exp.VolatileProperty | exp.StabilityProper

def _parse_with_property(
self,
) -> t.Optional[exp.Expression] | t.List[t.Optional[exp.Expression]]:
) -> t.Optional[exp.Expression] | t.List[exp.Expression]:
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_property)

Expand Down Expand Up @@ -1630,8 +1629,8 @@ def _parse_locking(self) -> exp.LockingProperty:
override=override,
)

def _parse_partition_by(self) -> t.List[t.Optional[exp.Expression]]:
if self._match_set(self.PARTITION_BY_TOKENS):
def _parse_partition_by(self) -> t.List[exp.Expression]:
    """Parse a PARTITION BY expression list, or return [] when absent."""
    if not self._match(TokenType.PARTITION_BY):
        return []
    return self._parse_csv(self._parse_conjunction)

Expand Down Expand Up @@ -1956,7 +1955,7 @@ def _parse_value(self) -> exp.Tuple:
# https://prestodb.io/docs/current/sql/values.html
return self.expression(exp.Tuple, expressions=[self._parse_conjunction()])

def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]:
def _parse_projections(self) -> t.List[exp.Expression]:
    # Thin hook over _parse_expressions: dialects (e.g. T-SQL) override this
    # to post-process the parsed SELECT projection list.
    projections = self._parse_expressions()
    return projections

def _parse_select(
Expand Down Expand Up @@ -2768,7 +2767,7 @@ def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Grou

return self.expression(exp.Group, **elements) # type: ignore

def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
def _parse_grouping_sets(self) -> t.Optional[t.List[exp.Expression]]:
if not self._match(TokenType.GROUPING_SETS):
return None

Expand Down Expand Up @@ -3159,7 +3158,7 @@ def _parse_types(
maybe_func = True

this: t.Optional[exp.Expression] = None
values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
values: t.Optional[t.List[exp.Expression]] = None

if nested and self._match(TokenType.LT):
if is_struct:
Expand Down Expand Up @@ -3450,7 +3449,9 @@ def _parse_lambda(self, alias: bool = False) -> t.Optional[exp.Expression]:
index = self._index

if self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_id_var)
expressions = t.cast(
t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_id_var)
)

if not self._match(TokenType.R_PAREN):
self._retreat(index)
Expand Down Expand Up @@ -3737,7 +3738,7 @@ def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Exp
bracket_kind = self._prev.token_type

if self._match(TokenType.COLON):
expressions: t.List[t.Optional[exp.Expression]] = [
expressions: t.List[exp.Expression] = [
self.expression(exp.Slice, expression=self._parse_conjunction())
]
else:
Expand Down Expand Up @@ -3916,7 +3917,7 @@ def _parse_string_agg(self) -> exp.Expression:
if self._match(TokenType.COMMA):
args.extend(self._parse_csv(self._parse_conjunction))
else:
args = self._parse_csv(self._parse_conjunction)
args = self._parse_csv(self._parse_conjunction) # type: ignore

index = self._index
if not self._match(TokenType.R_PAREN) and args:
Expand Down Expand Up @@ -4124,7 +4125,7 @@ def _parse_substring(self) -> exp.Substring:
# Postgres supports the form: substring(string [from int] [for int])
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6

args = self._parse_csv(self._parse_bitwise)
args = t.cast(t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_bitwise))

if self._match(TokenType.FROM):
args.append(self._parse_bitwise())
Expand Down Expand Up @@ -4157,7 +4158,7 @@ def _parse_trim(self) -> exp.Trim:
exp.Trim, this=this, position=position, expression=expression, collation=collation
)

def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
def _parse_window_clause(self) -> t.Optional[t.List[exp.Expression]]:
return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window)

def _parse_named_window(self) -> t.Optional[exp.Expression]:
Expand Down Expand Up @@ -4224,8 +4225,7 @@ def _parse_window(
if self._match_text_seq("LAST"):
first = False

partition = self._parse_partition_by()
order = self._parse_order()
partition, order = self._parse_partition_and_order()
kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text

if kind:
Expand Down Expand Up @@ -4264,6 +4264,11 @@ def _parse_window(

return window

def _parse_partition_and_order(
    self,
) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
    # Default window parsing: PARTITION BY first, then ORDER BY. Dialects
    # with alternative keywords (e.g. Hive's DISTRIBUTE BY / SORT BY)
    # override this hook.
    partition = self._parse_partition_by()
    order = self._parse_order()
    return partition, order

def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
self._match(TokenType.BETWEEN)

Expand Down Expand Up @@ -4385,14 +4390,14 @@ def _parse_placeholder(self) -> t.Optional[exp.Expression]:
self._advance(-1)
return None

def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
def _parse_except(self) -> t.Optional[t.List[exp.Expression]]:
    """Parse an EXCEPT column list (parenthesized or bare).

    Returns None when no EXCEPT token is present.
    """
    if self._match(TokenType.EXCEPT):
        # advance=False: peek only — the wrapped-csv helper consumes the paren.
        if self._match(TokenType.L_PAREN, advance=False):
            return self._parse_wrapped_csv(self._parse_column)
        return self._parse_csv(self._parse_column)
    return None

def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
def _parse_replace(self) -> t.Optional[t.List[exp.Expression]]:
if not self._match(TokenType.REPLACE):
return None
if self._match(TokenType.L_PAREN, advance=False):
Expand All @@ -4401,7 +4406,7 @@ def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:

def _parse_csv(
self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
) -> t.List[t.Optional[exp.Expression]]:
) -> t.List[exp.Expression]:
parse_result = parse_method()
items = [parse_result] if parse_result is not None else []

Expand All @@ -4428,12 +4433,12 @@ def _parse_tokens(

return this

def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[t.Optional[exp.Expression]]:
def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]:
    # Convenience wrapper: a comma-separated identifier list, optionally
    # requiring surrounding parentheses depending on `optional`.
    id_vars = self._parse_wrapped_csv(self._parse_id_var, optional=optional)
    return id_vars

def _parse_wrapped_csv(
self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA, optional: bool = False
) -> t.List[t.Optional[exp.Expression]]:
) -> t.List[exp.Expression]:
return self._parse_wrapped(
lambda: self._parse_csv(parse_method, sep=sep), optional=optional
)
Expand All @@ -4447,7 +4452,7 @@ def _parse_wrapped(self, parse_method: t.Callable, optional: bool = False) -> t.
self._match_r_paren()
return parse_result

def _parse_expressions(self) -> t.List[t.Optional[exp.Expression]]:
def _parse_expressions(self) -> t.List[exp.Expression]:
    # Comma-separated list of full expressions (aliases included).
    expressions = self._parse_csv(self._parse_expression)
    return expressions

def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]:
Expand Down Expand Up @@ -4557,7 +4562,7 @@ def _parse_add_constraint(self) -> exp.AddConstraint:

return self.expression(exp.AddConstraint, this=this, expression=expression)

def _parse_alter_table_add(self) -> t.List[t.Optional[exp.Expression]]:
def _parse_alter_table_add(self) -> t.List[exp.Expression]:
index = self._index - 1

if self._match_set(self.ADD_CONSTRAINT_TOKENS):
Expand All @@ -4584,7 +4589,7 @@ def _parse_alter_table_alter(self) -> exp.AlterColumn:
using=self._match(TokenType.USING) and self._parse_conjunction(),
)

def _parse_alter_table_drop(self) -> t.List[t.Optional[exp.Expression]]:
def _parse_alter_table_drop(self) -> t.List[exp.Expression]:
index = self._index - 1

partition_exists = self._parse_exists()
Expand Down
3 changes: 2 additions & 1 deletion tests/dialects/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,8 @@ def test_order_by(self):

def test_hive(self):
self.validate_identity(
"SELECT ROW() OVER (DISTRIBUTE BY x)", "SELECT ROW() OVER (PARTITION BY x)"
"SELECT ROW() OVER (DISTRIBUTE BY x SORT BY y)",
"SELECT ROW() OVER (PARTITION BY x ORDER BY y)",
)
self.validate_identity("SELECT transform")
self.validate_identity("SELECT * FROM test DISTRIBUTE BY y SORT BY x DESC ORDER BY l")
Expand Down

0 comments on commit a20794a

Please sign in to comment.