Skip to content

Commit

Permalink
Feat(hive): add support for the query TRANSFORM clause (#1935)
Browse files Browse the repository at this point in the history
* Feat(hive): add support for the query TRANSFORM clause

* Add exp.Transform

* Get rid of the _retreat call in _parse_query_transform

* Rename function parser
  • Loading branch information
georgesittas authored Jul 20, 2023
1 parent 4b7e9f1 commit 1d2b5e0
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 8 deletions.
4 changes: 3 additions & 1 deletion sqlglot/dataframe/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,9 @@ def transform(
f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
return Column.invoke_expression_over_column(
col, expression.Transform, expression=Column(f_expression)
)


def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
Expand Down
45 changes: 44 additions & 1 deletion sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,13 +272,52 @@ class Parser(parser.Parser):
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}

# Extend the base parser's function-parser table: TRANSFORM(...) is routed to
# a dedicated parser because it may be either the higher-order array function
# or the Hive/Spark query TRANSFORM clause (disambiguated by a USING keyword).
FUNCTION_PARSERS = {
    **parser.Parser.FUNCTION_PARSERS,
    "TRANSFORM": lambda self: self._parse_transform(),
}

# Extend the base parser's property-parser table with Hive's
# WITH SERDEPROPERTIES ('key'='value', ...) clause, parsed into an
# exp.SerdeProperties node holding the wrapped, comma-separated properties.
PROPERTY_PARSERS = {
    **parser.Parser.PROPERTY_PARSERS,
    "WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties(
        expressions=self._parse_wrapped_csv(self._parse_property)
    ),
}

def _parse_transform(self) -> exp.Transform | exp.QueryTransform:
    """Parse the tokens following ``TRANSFORM(`` (opening paren already consumed).

    Disambiguates between the higher-order array function
    ``TRANSFORM(col, lambda)`` and the Hive/Spark query clause
    ``TRANSFORM(exprs) [ROW FORMAT ...] [RECORDWRITER ...] USING 'script'
    [AS schema] [ROW FORMAT ...] [RECORDREADER ...]``: the presence of the
    USING keyword selects the query-clause form.

    Returns:
        exp.Transform when no USING keyword follows, exp.QueryTransform otherwise.
    """
    exprs = self._parse_csv(self._parse_lambda)
    self._match_r_paren()

    # Optional clauses that may appear before USING.
    fmt_before = self._parse_row_format(match_row=True)
    writer = self._parse_string() if self._match_text_seq("RECORDWRITER") else None

    if not self._match(TokenType.USING):
        # No USING keyword: this was the plain array function after all.
        return exp.Transform.from_arg_list(exprs)

    script = self._parse_string()

    self._match(TokenType.ALIAS)
    out_schema = self._parse_schema()

    # Optional clauses that may appear after the command script.
    fmt_after = self._parse_row_format(match_row=True)
    reader = self._parse_string() if self._match_text_seq("RECORDREADER") else None

    return self.expression(
        exp.QueryTransform,
        expressions=exprs,
        command_script=script,
        schema=out_schema,
        row_format_before=fmt_before,
        record_writer=writer,
        row_format_after=fmt_after,
        record_reader=reader,
    )

def _parse_types(
self, check_func: bool = False, schema: bool = False
) -> t.Optional[exp.Expression]:
Expand Down Expand Up @@ -400,7 +439,6 @@ class Generator(generator.Generator):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.LastDateOfMonth: rename_func("LAST_DAY"),
Expand All @@ -414,6 +452,11 @@ class Generator(generator.Generator):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}

def rowformatserdeproperty_sql(self, expression: exp.RowFormatSerdeProperty) -> str:
    """Render ``ROW FORMAT SERDE '<class>'`` with an optional trailing
    ``WITH SERDEPROPERTIES (...)`` segment when serde_properties is set."""
    props = self.sql(expression, "serde_properties")
    serde = self.sql(expression, "this")
    if props:
        return f"ROW FORMAT SERDE {serde} {props}"
    return f"ROW FORMAT SERDE {serde}"

def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
return self.func(
"COLLECT_LIST",
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import typing as t

from sqlglot import exp, parser, transforms
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
create_with_partitions_sql,
format_time_lambda,
Expand Down Expand Up @@ -142,7 +142,7 @@ class Parser(Hive.Parser):
}

FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
**Hive.Parser.FUNCTION_PARSERS,
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
Expand Down
20 changes: 19 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2031,7 +2031,20 @@ class RowFormatDelimitedProperty(Property):


class RowFormatSerdeProperty(Property):
arg_types = {"this": True}
arg_types = {"this": True, "serde_properties": False}


# https://spark.apache.org/docs/3.1.2/sql-ref-syntax-qry-select-transform.html
class QueryTransform(Expression):
    """AST node for the Hive/Spark query TRANSFORM clause:
    TRANSFORM(exprs) [ROW FORMAT ...] [RECORDWRITER ...] USING 'script'
    [AS schema] [ROW FORMAT ...] [RECORDREADER ...].
    """

    arg_types = {
        "expressions": True,  # expressions inside TRANSFORM(...)
        "command_script": True,  # the string after USING
        "schema": False,  # optional AS (...) output schema
        "row_format_before": False,  # ROW FORMAT clause before USING
        "record_writer": False,  # RECORDWRITER string before USING
        "row_format_after": False,  # ROW FORMAT clause after the schema
        "record_reader": False,  # RECORDREADER string after the schema
    }


class SchemaCommentProperty(Property):
Expand Down Expand Up @@ -3876,6 +3889,11 @@ class Abs(Func):
pass


# https://spark.apache.org/docs/latest/api/sql/index.html#transform
class Transform(Func):
    """Higher-order TRANSFORM(array, lambda) function (Spark-style)."""

    arg_types = {"this": True, "expression": True}


class Anonymous(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
Expand Down
15 changes: 15 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2508,6 +2508,21 @@ def anyvalue_sql(self, expression: exp.AnyValue) -> str:

return self.func("ANY_VALUE", this)

def querytransform_sql(self, expression: exp.QueryTransform) -> str:
    """Generate SQL for a query TRANSFORM clause.

    Emits, in order: TRANSFORM(exprs), optional pre-USING ROW FORMAT and
    RECORDWRITER clauses, the mandatory USING 'script', then optional
    AS schema, post-USING ROW FORMAT, and RECORDREADER clauses. Each
    optional segment is omitted entirely when its argument is unset.
    """

    def optional(arg: str, keyword: str = "") -> str:
        # Render an optional argument with a leading space (and keyword), or "".
        rendered = self.sql(expression, arg)
        return f" {keyword}{rendered}" if rendered else ""

    segments = [
        self.func("TRANSFORM", *expression.expressions),
        optional("row_format_before"),
        optional("record_writer", "RECORDWRITER "),
        f" USING {self.sql(expression, 'command_script')}",
        optional("schema", "AS "),
        optional("row_format_after"),
        optional("record_reader", "RECORDREADER "),
    ]
    return "".join(segments)


def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
Expand Down
16 changes: 13 additions & 3 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ class Parser(metaclass=_Parser):

FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}

FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
FUNCTION_PARSERS = {
"ANY_VALUE": lambda self: self._parse_any_value(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONCAT": lambda self: self._parse_concat(),
Expand Down Expand Up @@ -1784,7 +1784,17 @@ def _parse_row_format(
return None

if self._match_text_seq("SERDE"):
return self.expression(exp.RowFormatSerdeProperty, this=self._parse_string())
this = self._parse_string()

serde_properties = None
if self._match(TokenType.SERDE_PROPERTIES):
serde_properties = self.expression(
exp.SerdeProperties, expressions=self._parse_wrapped_csv(self._parse_property)
)

return self.expression(
exp.RowFormatSerdeProperty, this=this, serde_properties=serde_properties
)

self._match_text_seq("DELIMITED")

Expand Down Expand Up @@ -3331,7 +3341,7 @@ def _parse_function(
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)

self._match_r_paren(this)
self._match(TokenType.R_PAREN, expression=this)
return self._parse_window(this)

def _parse_function_parameter(self) -> t.Optional[exp.Expression]:
Expand Down
20 changes: 20 additions & 0 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ def test_hint(self, logger):
)

def test_spark(self):
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)")
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)")
self.validate_identity("REFRESH table a.b.c")
self.validate_identity("INTERVAL -86 days")
self.validate_identity("SELECT UNIX_TIMESTAMP()")
Expand Down Expand Up @@ -502,3 +504,21 @@ def test_current_user(self):
"CURRENT_USER()",
write={"spark": "CURRENT_USER()"},
)

def test_transform_query(self):
    """Round-trip (parse then regenerate) Spark query TRANSFORM clauses,
    covering bare USING, typed/untyped AS schemas, ROW FORMAT DELIMITED,
    ROW FORMAT SERDE with SERDEPROPERTIES, and the schema-less form."""
    for sql in (
        "SELECT TRANSFORM(x) USING 'x' AS (x INT) FROM t",
        "SELECT TRANSFORM(zip_code, name, age) USING 'cat' AS (a, b, c) FROM person WHERE zip_code > 94511",
        "SELECT TRANSFORM(zip_code, name, age) USING 'cat' AS (a STRING, b STRING, c STRING) FROM person WHERE zip_code > 94511",
        "SELECT TRANSFORM(name, age) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' NULL DEFINED AS 'NULL' USING 'cat' AS (name_age STRING) ROW FORMAT DELIMITED FIELDS TERMINATED BY '@' LINES TERMINATED BY '\n' NULL DEFINED AS 'NULL' FROM person",
        "SELECT TRANSFORM(zip_code, name, age) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES ('field.delim'='\t') USING 'cat' AS (a STRING, b STRING, c STRING) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES ('field.delim'='\t') FROM person WHERE zip_code > 94511",
        "SELECT TRANSFORM(zip_code, name, age) USING 'cat' FROM person WHERE zip_code > 94500",
    ):
        self.validate_identity(sql)
1 change: 1 addition & 0 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ def test_functions(self):
self.assertIsInstance(parse_one("HEX(foo)"), exp.Hex)
self.assertIsInstance(parse_one("TO_HEX(foo)", read="bigquery"), exp.Hex)
self.assertIsInstance(parse_one("TO_HEX(MD5(foo))", read="bigquery"), exp.MD5)
self.assertIsInstance(parse_one("TRANSFORM(a, b)", read="spark"), exp.Transform)

def test_column(self):
column = parse_one("a.b.c.d")
Expand Down

0 comments on commit 1d2b5e0

Please sign in to comment.