Skip to content

Commit

Permalink
Fix: tsql hashbytes closes #1508
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Apr 30, 2023
1 parent 49cc9bf commit 2dcbc7f
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 6 deletions.
7 changes: 3 additions & 4 deletions sqlglot/dataframe/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,18 +699,17 @@ def crc32(col: ColumnOrName) -> Column:

def md5(col: ColumnOrName) -> Column:
column = col if isinstance(col, Column) else lit(col)
return Column.invoke_anonymous_function(column, "MD5")
return Column.invoke_expression_over_column(column, expression.MD5)


def sha1(col: ColumnOrName) -> Column:
column = col if isinstance(col, Column) else lit(col)
return Column.invoke_anonymous_function(column, "SHA1")
return Column.invoke_expression_over_column(column, expression.SHA)


def sha2(col: ColumnOrName, numBits: int) -> Column:
column = col if isinstance(col, Column) else lit(col)
num_bits = lit(numBits)
return Column.invoke_anonymous_function(column, "SHA2", num_bits)
return Column.invoke_expression_over_column(column, expression.SHA2, length=lit(numBits))


def hash(*cols: ColumnOrName) -> Column:
Expand Down
23 changes: 23 additions & 0 deletions sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,23 @@ def _parse_eomonth(args):
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))


def _parse_hashbytes(args):
kind, data = args
kind = kind.name.upper() if kind.is_string else ""

if kind == "MD5":
args.pop(0)
return exp.MD5(this=data)
if kind in ("SHA", "SHA1"):
args.pop(0)
return exp.SHA(this=data)
if kind == "SHA2_256":
return exp.SHA2(this=data, length=exp.Literal.number(256))
if kind == "SHA2_512":
return exp.SHA2(this=data, length=exp.Literal.number(512))
return exp.func("HASHBYTES", *args)


def generate_date_delta_with_unit_sql(self, e):
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
return self.func(func, e.text("unit"), e.expression, e.this)
Expand Down Expand Up @@ -288,6 +305,7 @@ class Parser(parser.Parser):
"EOMONTH": _parse_eomonth,
"FORMAT": _parse_format,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
"HASHBYTES": _parse_hashbytes,
"IIF": exp.If.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
Expand Down Expand Up @@ -450,7 +468,12 @@ class Generator(generator.Generator):
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
"HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
),
}

TRANSFORMS.pop(exp.ReturnsProperty)
Expand Down
13 changes: 13 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3779,6 +3779,10 @@ class Max(AggFunc):
is_var_len_args = True


class MD5(Func):
_sql_names = ["MD5"]


class Min(AggFunc):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
Expand Down Expand Up @@ -3885,6 +3889,15 @@ class SetAgg(AggFunc):
pass


class SHA(Func):
_sql_names = ["SHA", "SHA1"]


class SHA2(Func):
_sql_names = ["SHA2"]
arg_types = {"this": True, "length": False}


class SortArray(Func):
arg_types = {"this": True, "asc": False}

Expand Down
4 changes: 2 additions & 2 deletions tests/dataframe/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,9 +960,9 @@ def test_md5(self):

def test_sha1(self):
col_str = SF.sha1("Spark")
self.assertEqual("SHA1('Spark')", col_str.sql())
self.assertEqual("SHA('Spark')", col_str.sql())
col = SF.sha1(SF.col("cola"))
self.assertEqual("SHA1(cola)", col.sql())
self.assertEqual("SHA(cola)", col.sql())

def test_sha2(self):
col_str = SF.sha2("Spark", 256)
Expand Down
48 changes: 48 additions & 0 deletions tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,54 @@ def test_tsql(self):
"postgres": "STRING_AGG(x, '|')",
},
)
self.validate_all(
"SELECT CAST([a].[b] AS SMALLINT) FROM foo",
write={
"tsql": 'SELECT CAST("a"."b" AS SMALLINT) FROM foo',
"spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo",
},
)
self.validate_all(
"HASHBYTES('SHA1', x)",
read={
"spark": "SHA(x)",
},
write={
"tsql": "HASHBYTES('SHA1', x)",
"spark": "SHA(x)",
},
)
self.validate_all(
"HASHBYTES('SHA2_256', x)",
read={
"spark": "SHA2(x, 256)",
},
write={
"tsql": "HASHBYTES('SHA2_256', x)",
"spark": "SHA2(x, 256)",
},
)
self.validate_all(
"HASHBYTES('SHA2_512', x)",
read={
"spark": "SHA2(x, 512)",
},
write={
"tsql": "HASHBYTES('SHA2_512', x)",
"spark": "SHA2(x, 512)",
},
)
self.validate_all(
"HASHBYTES('MD5', 'x')",
read={
"spark": "MD5('x')",
},
write={
"tsql": "HASHBYTES('MD5', 'x')",
"spark": "MD5('x')",
},
)
self.validate_identity("HASHBYTES('MD2', 'x')")

def test_types(self):
self.validate_identity("CAST(x AS XML)")
Expand Down

0 comments on commit 2dcbc7f

Please sign in to comment.