Skip to content

Commit

Permalink
Fix: convert JSONArrayContains to a Func expression
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas committed Jun 30, 2023
1 parent 58e1683 commit 0197119
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 9 deletions.
3 changes: 3 additions & 0 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,9 @@ class Generator(generator.Generator):

LIMIT_FETCH = "LIMIT"

def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str:
return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"

def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
"""(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead."""
if expression.to.this == exp.DataType.Type.BIGINT:
Expand Down
10 changes: 5 additions & 5 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3624,11 +3624,6 @@ class Is(Binary, Predicate):
pass


# https://dev.mysql.com/doc/refman/8.0/en/json-search-functions.html#operator_member-of
class JSONArrayContains(Binary, Predicate):
pass


class Kwarg(Binary):
"""Kwarg in special functions like func(kwarg => y)."""

Expand Down Expand Up @@ -4252,6 +4247,11 @@ class JSONFormat(Func):
_sql_names = ["JSON_FORMAT"]


# https://dev.mysql.com/doc/refman/8.0/en/json-search-functions.html#operator_member-of
class JSONArrayContains(Binary, Predicate, Func):
_sql_names = ["JSON_ARRAY_CONTAINS"]


class Least(Func):
arg_types = {"expressions": False}
is_var_len_args = True
Expand Down
3 changes: 0 additions & 3 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2440,9 +2440,6 @@ def anyvalue_sql(self, expression: exp.AnyValue) -> str:

return self.func("ANY_VALUE", this)

def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str:
return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"


def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
Expand Down
8 changes: 7 additions & 1 deletion tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def test_identity(self):
self.validate_identity("SELECT JSON_ARRAY(4, 5) MEMBER OF('[[3,4],[4,5]]')")
self.validate_identity("SELECT CAST('[4,5]' AS JSON) MEMBER OF('[[3,4],[4,5]]')")
self.validate_identity("""SELECT 'ab' MEMBER OF('[23, "abc", 17, "ab", 10]')""")
self.validate_identity("""SELECT 17 MEMBER OF('[23, "abc", 17, "ab", 10]')""")
self.validate_identity("CAST(x AS ENUM('a', 'b'))")
self.validate_identity("CAST(x AS SET('a', 'b'))")
self.validate_identity("SELECT CURRENT_TIMESTAMP(6)")
Expand Down Expand Up @@ -419,6 +418,13 @@ def test_mysql(self):
self.validate_all("CAST(x AS SIGNED INTEGER)", write={"mysql": "CAST(x AS SIGNED)"})
self.validate_all("CAST(x AS UNSIGNED)", write={"mysql": "CAST(x AS UNSIGNED)"})
self.validate_all("CAST(x AS UNSIGNED INTEGER)", write={"mysql": "CAST(x AS UNSIGNED)"})
self.validate_all(
"""SELECT 17 MEMBER OF('[23, "abc", 17, "ab", 10]')""",
write={
"": """SELECT JSON_ARRAY_CONTAINS(17, '[23, "abc", 17, "ab", 10]')""",
"mysql": """SELECT 17 MEMBER OF('[23, "abc", 17, "ab", 10]')""",
},
)
self.validate_all(
"SELECT DATE_ADD('2023-06-23 12:00:00', INTERVAL 2 * 2 MONTH) FROM foo",
write={
Expand Down

0 comments on commit 0197119

Please sign in to comment.