Skip to content

Commit

Permalink
Fix\!: convert left and right closes #1733
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Jun 6, 2023
1 parent e058513 commit 6ad00ca
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 21 deletions.
19 changes: 19 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,25 @@ def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> s
)


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
expression = expression.copy()
return self.sql(
exp.Substring(
this=expression.this, start=exp.Literal.number(1), length=expression.expression
)
)


def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
expression = expression.copy()
return self.sql(
exp.Substring(
this=expression.this,
start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
)
)


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"

Expand Down
4 changes: 4 additions & 0 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
create_with_partitions_sql,
format_time_lambda,
if_sql,
left_to_substring_sql,
locate_to_strposition,
max_or_greatest,
min_or_least,
Expand All @@ -17,6 +18,7 @@
no_safe_divide_sql,
no_trycast_sql,
rename_func,
right_to_substring_sql,
strposition_to_locate_sql,
struct_extract_sql,
timestrtotime_sql,
Expand Down Expand Up @@ -356,6 +358,7 @@ class Generator(generator.Generator):
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.JSONFormat: _json_format_sql,
exp.Left: left_to_substring_sql,
exp.Map: var_map_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
Expand All @@ -365,6 +368,7 @@ class Generator(generator.Generator):
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
exp.RegexpSplit: rename_func("SPLIT"),
exp.Right: right_to_substring_sql,
exp.SafeDivide: no_safe_divide_sql,
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),
Expand Down
3 changes: 0 additions & 3 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,6 @@ class Parser(parser.Parser):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
"LEFT": lambda args: exp.Substring(
this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
),
"LOCATE": locate_to_strposition,
"STR_TO_DATE": _str_to_date,
}
Expand Down
4 changes: 4 additions & 0 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
date_trunc_to_time,
format_time_lambda,
if_sql,
left_to_substring_sql,
no_ilike_sql,
no_pivot_sql,
no_safe_divide_sql,
rename_func,
right_to_substring_sql,
struct_extract_sql,
timestamptrunc_sql,
timestrtotime_sql,
Expand Down Expand Up @@ -293,11 +295,13 @@ class Generator(generator.Generator):
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Left: left_to_substring_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Pivot: no_pivot_sql,
exp.Quantile: _quantile_sql,
exp.Right: right_to_substring_sql,
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
Expand Down
15 changes: 2 additions & 13 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,6 @@ class Parser(Hive.Parser):
**Hive.Parser.FUNCTIONS,
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
this=seq_get(args, 0),
start=exp.Literal.number(1),
length=seq_get(args, 1),
),
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
this=seq_get(args, 0),
expression=seq_get(args, 1),
Expand All @@ -123,14 +118,6 @@ class Parser(Hive.Parser):
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"RIGHT": lambda args: exp.Substring(
this=seq_get(args, 0),
start=exp.Sub(
this=exp.Length(this=seq_get(args, 0)),
expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
),
length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
"AGGREGATE": exp.Reduce.from_arg_list,
Expand Down Expand Up @@ -240,6 +227,8 @@ class Generator(Hive.Generator):
TRANSFORMS.pop(exp.ArrayJoin)
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
TRANSFORMS.pop(exp.Left)
TRANSFORMS.pop(exp.Right)

WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_RETURN_AS = False
Expand Down
8 changes: 8 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4158,6 +4158,14 @@ class Least(Func):
is_var_len_args = True


class Left(Func):
arg_types = {"this": True, "expression": True}


class Right(Func):
arg_types = {"this": True, "expression": True}


class Length(Func):
pass

Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_types(self):
)

def test_canonical_functions(self):
self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)")
self.validate_identity("SELECT LEFT('str', 2)", "SELECT LEFT('str', 2)")
self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')")
self.validate_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')")
self.validate_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')")
Expand Down
6 changes: 6 additions & 0 deletions tests/dialects/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,12 @@ def test_presto(self):
self.validate_all("INTERVAL '1 day'", write={"trino": "INTERVAL '1' day"})
self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' week"})
self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' WEEKS"})
self.validate_all(
"SELECT SUBSTRING(a, 1, 3), SUBSTRING(a, LENGTH(a) - (3 - 1))",
read={
"redshift": "SELECT LEFT(a, 3), RIGHT(a, 3)",
},
)
self.validate_all(
"WITH RECURSIVE t(n) AS (SELECT 1 AS n UNION ALL SELECT n + 1 AS n FROM t WHERE n < 4) SELECT SUM(n) FROM t",
read={
Expand Down
8 changes: 4 additions & 4 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,10 @@ def test_spark(self):
self.validate_all(
"SELECT LEFT(x, 2), RIGHT(x, 2)",
write={
"duckdb": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)",
"presto": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)",
"hive": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)",
"spark": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)",
"duckdb": "SELECT LEFT(x, 2), RIGHT(x, 2)",
"presto": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - (2 - 1))",
"hive": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - (2 - 1))",
"spark": "SELECT LEFT(x, 2), RIGHT(x, 2)",
},
)
self.validate_all(
Expand Down

0 comments on commit 6ad00ca

Please sign in to comment.