Skip to content

Commit

Permalink
Feat: add RegexpReplace expression (#1925)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Jul 17, 2023
1 parent aaee594 commit 3456bbf
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 8 deletions.
12 changes: 10 additions & 2 deletions sqlglot/dataframe/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,8 +875,16 @@ def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None)
)


def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column:
return Column.invoke_anonymous_function(str, "REGEXP_REPLACE", lit(pattern), lit(replacement))
def regexp_replace(
    str: ColumnOrName, pattern: str, replacement: str, position: t.Optional[int] = None
) -> Column:
    """Build a RegexpReplace expression over the given column.

    Args:
        str: Column (or column name) whose value is searched.
        pattern: Regular expression, wrapped as a literal via ``lit``.
        replacement: Replacement text, wrapped as a literal via ``lit``.
        position: Optional value forwarded unchanged as the ``position`` argument.

    Returns:
        The Column produced by ``Column.invoke_expression_over_column``.
    """
    regexp_args = {
        "expression": lit(pattern),
        "replacement": lit(replacement),
        "position": position,
    }
    return Column.invoke_expression_over_column(str, expression.RegexpReplace, **regexp_args)


def initcap(col: ColumnOrName) -> Column:
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
min_or_least,
no_ilike_sql,
parse_date_delta_with_interval,
regexp_replace_sql,
rename_func,
timestrtotime_sql,
ts_or_ds_to_date_sql,
Expand Down Expand Up @@ -415,6 +416,7 @@ class Generator(generator.Generator):
e.args.get("position"),
e.args.get("occurrence"),
),
exp.RegexpReplace: regexp_replace_sql,
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
exp.ReturnsProperty: _returnsproperty_sql,
exp.Select: transforms.preprocess(
Expand Down
10 changes: 10 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,16 @@ def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
)


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render a three-argument REGEXP_REPLACE call.

    Only the subject, the pattern and the replacement are emitted. If any of
    "position", "occurrence" or "parameters" is set (truthy) on the expression,
    this is reported through ``self.unsupported`` and the argument is omitted
    from the generated call.
    """
    unsupported_args = [
        name for name in ("position", "occurrence", "parameters") if expression.args.get(name)
    ]
    if unsupported_args:
        self.unsupported(
            f"REGEXP_REPLACE does not support the following arg(s): {unsupported_args}"
        )

    subject = expression.this
    pattern = expression.expression
    # Indexed directly (not .get): a missing replacement raises KeyError here.
    replacement = expression.args["replacement"]
    return self.func("REGEXP_REPLACE", subject, pattern, replacement)


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
names = []
for agg in aggregations:
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
no_safe_divide_sql,
pivot_column_names,
regexp_extract_sql,
regexp_replace_sql,
rename_func,
str_position_sql,
str_to_time_sql,
Expand Down Expand Up @@ -219,6 +220,7 @@ class Generator(generator.Generator):
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Properties: no_properties_sql,
exp.RegexpExtract: regexp_extract_sql,
exp.RegexpReplace: regexp_replace_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
exp.SafeDivide: no_safe_divide_sql,
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
no_safe_divide_sql,
no_trycast_sql,
regexp_extract_sql,
regexp_replace_sql,
rename_func,
right_to_substring_sql,
strposition_to_locate_sql,
Expand Down Expand Up @@ -369,6 +370,7 @@ class Generator(generator.Generator):
exp.Quantile: rename_func("PERCENTILE"),
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpExtract: regexp_extract_sql,
exp.RegexpReplace: regexp_replace_sql,
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
exp.RegexpSplit: rename_func("SPLIT"),
exp.Right: right_to_substring_sql,
Expand Down
22 changes: 16 additions & 6 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _check_int(s: str) -> bool:


# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
Expand Down Expand Up @@ -137,21 +137,21 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:


# https://docs.snowflake.com/en/sql-reference/functions/div0
def _div0_to_if(args: t.List) -> exp.Expression:
def _div0_to_if(args: t.List) -> exp.If:
    """Rewrite DIV0(a, b) as IF(b = 0, 0, a / b)."""
    dividend = seq_get(args, 0)
    divisor = seq_get(args, 1)
    return exp.If(
        this=exp.EQ(this=divisor, expression=exp.Literal.number(0)),
        true=exp.Literal.number(0),
        false=exp.Div(this=dividend, expression=divisor),
    )


# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
def _zeroifnull_to_if(args: t.List) -> exp.Expression:
def _zeroifnull_to_if(args: t.List) -> exp.If:
    """Rewrite ZEROIFNULL(x) as IF(x IS NULL, 0, x)."""
    value = seq_get(args, 0)
    is_null = exp.Is(this=value, expression=exp.Null())
    return exp.If(this=is_null, true=exp.Literal.number(0), false=value)


# https://docs.snowflake.com/en/sql-reference/functions/nullifzero
def _nullifzero_to_if(args: t.List) -> exp.Expression:
def _nullifzero_to_if(args: t.List) -> exp.If:
    """Rewrite NULLIFZERO(x) as IF(x = 0, NULL, x)."""
    value = seq_get(args, 0)
    is_zero = exp.EQ(this=value, expression=exp.Literal.number(0))
    return exp.If(this=is_zero, true=exp.Null(), false=value)

Expand All @@ -164,12 +164,21 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)


def _parse_convert_timezone(args: t.List) -> exp.Expression:
def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
    """Build the expression for a CONVERT_TIMEZONE call.

    A three-argument call is kept as an anonymous CONVERT_TIMEZONE function.
    Any other arity becomes AtTimeZone, with the first argument as the zone
    and the second as the value being converted.
    """
    if len(args) != 3:
        return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0))
    return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)


def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace:
    """Build a RegexpReplace node, defaulting a missing replacement to ''."""
    node = exp.RegexpReplace.from_arg_list(args)
    if node.args.get("replacement"):
        return node

    # No (truthy) replacement was supplied: fall back to an empty string literal.
    node.set("replacement", exp.Literal.string(""))
    return node


class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
Expand Down Expand Up @@ -223,13 +232,14 @@ class Parser(parser.Parser):
"IFF": exp.If.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
"OBJECT_CONSTRUCT": _parse_object_construct,
"REGEXP_REPLACE": _parse_regexp_replace,
"REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TIMEDIFF": _parse_datediff,
"TIMESTAMPDIFF": _parse_datediff,
"TO_ARRAY": exp.Array.from_arg_list,
"TO_TIMESTAMP": _snowflake_to_timestamp,
"TO_TIMESTAMP": _parse_to_timestamp,
"TO_VARCHAR": exp.ToChar.from_arg_list,
"ZEROIFNULL": _zeroifnull_to_if,
}
Expand Down
7 changes: 7 additions & 0 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ class Generator(Hive.Generator):
exp.Map: _map_sql,
exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
exp.Reduce: rename_func("AGGREGATE"),
exp.RegexpReplace: lambda self, e: self.func(
"REGEXP_REPLACE",
e.this,
e.expression,
e.args["replacement"],
e.args.get("position"),
),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
Expand Down
11 changes: 11 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4431,6 +4431,17 @@ class RegexpExtract(Func):
}


class RegexpReplace(Func):
    # Argument schema for a regular-expression replace call.
    # The boolean flags presumably mean required (True) vs. optional (False);
    # confirm against the Expression base class, which is not visible here.
    # NOTE(review): key order looks significant. from_arg_list (used by the
    # Snowflake parser) appears to assign positional SQL arguments in this order.
    arg_types = {
        "this": True,  # subject string being searched
        "expression": True,  # regular-expression pattern
        "replacement": True,  # replacement text
        "position": False,  # emitted only by some dialects' generators
        "occurrence": False,
        "parameters": False,
    }


class RegexpLike(Func):
arg_types = {"this": True, "expression": True, "flag": False}

Expand Down
52 changes: 52 additions & 0 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,58 @@ def test_regexp_substr(self, logger):
},
)

@mock.patch("sqlglot.generator.logger")
def test_regexp_replace(self, logger):
    """Check how Snowflake REGEXP_REPLACE calls transpile to and from other dialects."""
    # NOTE(review): the generator logger is patched presumably to silence the
    # "unsupported" messages raised when a target dialect drops arguments.
    two_args = "REGEXP_REPLACE(subject, pattern)"
    default_replacement = "REGEXP_REPLACE(subject, pattern, '')"
    three_args = "REGEXP_REPLACE(subject, pattern, replacement)"
    with_position = "REGEXP_REPLACE(subject, pattern, replacement, position)"
    all_args = "REGEXP_REPLACE(subject, pattern, replacement, position, occurrence, parameters)"

    # Dialects whose expected output never goes beyond three arguments.
    three_arg_only = ("bigquery", "duckdb", "hive")
    every_target = (*three_arg_only, "snowflake", "spark")

    # A missing replacement is written out as an empty string everywhere.
    self.validate_all(
        two_args,
        write=dict.fromkeys(every_target, default_replacement),
    )

    # The plain three-argument form round-trips unchanged.
    self.validate_all(
        three_args,
        read=dict.fromkeys((*three_arg_only, "spark"), three_args),
        write=dict.fromkeys(every_target, three_args),
    )

    # "position" survives only in Snowflake and Spark.
    self.validate_all(
        with_position,
        read={"spark": with_position},
        write={
            **dict.fromkeys(three_arg_only, three_args),
            "snowflake": with_position,
            "spark": with_position,
        },
    )

    # The full signature is kept by Snowflake; Spark keeps up to "position".
    self.validate_all(
        all_args,
        write={
            **dict.fromkeys(three_arg_only, three_args),
            "snowflake": all_args,
            "spark": with_position,
        },
    )

def test_match_recognize(self):
for row in (
"ONE ROW PER MATCH",
Expand Down

0 comments on commit 3456bbf

Please sign in to comment.