Skip to content

Commit

Permalink
Fix!: preserve ascending order of sorting when present (#2256)
Browse files Browse the repository at this point in the history
* Feat(generator): enable overridding of default sorting order in ORDER BY

* Fixup

* Refactor in order to preserve ASC if present

* Convert ORDER BY a ASC to ORDER BY a in optimizer's canonicalize
  • Loading branch information
georgesittas authored Sep 18, 2023
1 parent f76ebb2 commit 4badd91
Show file tree
Hide file tree
Showing 18 changed files with 65 additions and 57 deletions.
2 changes: 1 addition & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,7 +1917,7 @@ class Sort(Order):


class Ordered(Expression):
arg_types = {"this": True, "desc": True, "nulls_first": True}
arg_types = {"this": True, "desc": False, "nulls_first": True}


class Property(Expression):
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,7 +1665,7 @@ def ordered_sql(self, expression: exp.Ordered) -> str:
nulls_are_small = self.NULL_ORDERING == "nulls_are_small"
nulls_are_last = self.NULL_ORDERING == "nulls_are_last"

sort_order = " DESC" if desc else ""
sort_order = " DESC" if desc else (" ASC" if desc is False else "")
nulls_sort_change = ""
if nulls_first and (
(asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
Expand Down
9 changes: 9 additions & 0 deletions sqlglot/optimizer/canonicalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bool_predicates(expression)
expression = remove_ascending_order(expression)

return expression

Expand Down Expand Up @@ -63,6 +64,14 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
return expression


def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False:
# Convert ORDER BY a ASC to ORDER BY a
expression.set("desc", None)

return expression


def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
if (
Expand Down
10 changes: 5 additions & 5 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2944,20 +2944,20 @@ def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]:

def _parse_ordered(self) -> exp.Ordered:
this = self._parse_conjunction()
self._match(TokenType.ASC)

is_desc = self._match(TokenType.DESC)
asc = self._match(TokenType.ASC)
desc = self._match(TokenType.DESC) or (asc and False)

is_nulls_first = self._match_text_seq("NULLS", "FIRST")
is_nulls_last = self._match_text_seq("NULLS", "LAST")
desc = is_desc or False
asc = not desc

nulls_first = is_nulls_first or False
explicitly_null_ordered = is_nulls_first or is_nulls_last

if (
not explicitly_null_ordered
and (
(asc and self.NULL_ORDERING == "nulls_are_small")
(not desc and self.NULL_ORDERING == "nulls_are_small")
or (desc and self.NULL_ORDERING != "nulls_are_small")
)
and self.NULL_ORDERING != "nulls_are_last"
Expand Down
6 changes: 3 additions & 3 deletions tests/dataframe/unit/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,16 @@ def test_isin(self):
self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin(1, 2, 3).sql())

def test_asc(self):
self.assertEqual("cola", F.col("cola").asc().sql())
self.assertEqual("cola ASC", F.col("cola").asc().sql())

def test_desc(self):
self.assertEqual("cola DESC", F.col("cola").desc().sql())

def test_asc_nulls_first(self):
self.assertEqual("cola", F.col("cola").asc_nulls_first().sql())
self.assertEqual("cola ASC", F.col("cola").asc_nulls_first().sql())

def test_asc_nulls_last(self):
self.assertEqual("cola NULLS LAST", F.col("cola").asc_nulls_last().sql())
self.assertEqual("cola ASC NULLS LAST", F.col("cola").asc_nulls_last().sql())

def test_desc_nulls_first(self):
self.assertEqual("cola DESC NULLS FIRST", F.col("cola").desc_nulls_first().sql())
Expand Down
8 changes: 4 additions & 4 deletions tests/dataframe/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,18 +335,18 @@ def test_bitwise_not(self):
def test_asc_nulls_first(self):
col_str = SF.asc_nulls_first("cola")
self.assertIsInstance(col_str.expression, exp.Ordered)
self.assertEqual("cola", col_str.sql())
self.assertEqual("cola ASC", col_str.sql())
col = SF.asc_nulls_first(SF.col("cola"))
self.assertIsInstance(col.expression, exp.Ordered)
self.assertEqual("cola", col.sql())
self.assertEqual("cola ASC", col.sql())

def test_asc_nulls_last(self):
col_str = SF.asc_nulls_last("cola")
self.assertIsInstance(col_str.expression, exp.Ordered)
self.assertEqual("cola NULLS LAST", col_str.sql())
self.assertEqual("cola ASC NULLS LAST", col_str.sql())
col = SF.asc_nulls_last(SF.col("cola"))
self.assertIsInstance(col.expression, exp.Ordered)
self.assertEqual("cola NULLS LAST", col.sql())
self.assertEqual("cola ASC NULLS LAST", col.sql())

def test_desc_nulls_first(self):
col_str = SF.desc_nulls_first("cola")
Expand Down
6 changes: 3 additions & 3 deletions tests/dialects/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ def test_clickhouse(self):
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
"clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST",
},
)
self.validate_all(
Expand Down Expand Up @@ -216,7 +216,7 @@ def test_clickhouse(self):
""",
write={
"clickhouse": "SELECT loyalty, count() FROM hits LEFT SEMI JOIN users USING (UserID)"
+ " GROUP BY loyalty ORDER BY loyalty"
" GROUP BY loyalty ORDER BY loyalty ASC"
},
)
self.validate_identity("SELECT s, arr FROM arrays_test ARRAY JOIN arr")
Expand Down
12 changes: 6 additions & 6 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,12 +964,12 @@ def test_order_by(self):
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
"oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST",
"oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
},
)

Expand Down
5 changes: 3 additions & 2 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def test_duckdb(self):
},
)

self.validate_identity("SELECT i FROM RANGE(5) AS _(i) ORDER BY i ASC")
self.validate_identity("[x.STRING_SPLIT(' ')[1] FOR x IN ['1', '2', 3] IF x.CONTAINS('1')]")
self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y")
self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x")
Expand Down Expand Up @@ -384,8 +385,8 @@ def test_duckdb(self):
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
"": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST",
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname",
},
)
self.validate_all(
Expand Down
8 changes: 4 additions & 4 deletions tests/dialects/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,10 @@ def test_order_by(self):
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
},
)

Expand Down
8 changes: 4 additions & 4 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,10 @@ def test_postgres(self):
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
"postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST",
},
)
self.validate_all(
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,8 @@ def test_ddl(self):
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST",
},
)

Expand Down
12 changes: 6 additions & 6 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,12 +385,12 @@ def test_snowflake(self):
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
"postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
"snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname",
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname",
"postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST",
"snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname",
},
)
self.validate_all(
Expand Down
14 changes: 7 additions & 7 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,13 @@ def test_spark(self):
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
"postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname NULLS FIRST",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname NULLS FIRST",
"clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST",
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST",
"postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname NULLS FIRST",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
"snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname NULLS FIRST",
},
)
self.validate_all(
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def test_sqlite(self):
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
write={
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
"sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
},
)
self.validate_all("x", read={"snowflake": "LEAST(x)"})
Expand Down
8 changes: 4 additions & 4 deletions tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_tsql(self):
WITH (PAD_INDEX = ON, STATISTICS_NORECOMPUTE = OFF) ON [PRIMARY]
) ON [PRIMARY]
""",
'CREATE TABLE x ("zip_cd" VARCHAR(5) NULL NOT FOR REPLICATION CONSTRAINT "pk_mytable" PRIMARY KEY CLUSTERED ("zip_cd_mkey") WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON "PRIMARY") ON "PRIMARY"',
'CREATE TABLE x ("zip_cd" VARCHAR(5) NULL NOT FOR REPLICATION CONSTRAINT "pk_mytable" PRIMARY KEY CLUSTERED ("zip_cd_mkey" ASC) WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON "PRIMARY") ON "PRIMARY"',
)

self.validate_identity(
Expand Down Expand Up @@ -123,10 +123,10 @@ def test_tsql(self):
self.validate_all(
"STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z ASC)",
write={
"tsql": "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z)",
"mysql": "GROUP_CONCAT(x ORDER BY z SEPARATOR '|')",
"tsql": "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z ASC)",
"mysql": "GROUP_CONCAT(x ORDER BY z ASC SEPARATOR '|')",
"sqlite": "GROUP_CONCAT(x, '|')",
"postgres": "STRING_AGG(x, '|' ORDER BY z NULLS FIRST)",
"postgres": "STRING_AGG(x, '|' ORDER BY z ASC NULLS FIRST)",
},
)
self.validate_all(
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/identity.sql
Original file line number Diff line number Diff line change
Expand Up @@ -860,3 +860,4 @@ SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1)
/* comment */ CREATE TABLE foo AS SELECT 1
SELECT next, transform, if
SELECT "any", "case", "if", "next"
SELECT x FROM y ORDER BY x ASC
3 changes: 0 additions & 3 deletions tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ def test_alias(self):
with self.assertRaises(ParseError):
self.validate(f"SELECT x {key}", "")

def test_asc(self):
self.validate("SELECT x FROM y ORDER BY x ASC", "SELECT x FROM y ORDER BY x")

def test_unary(self):
self.validate("+++1", "1")
self.validate("+-1", "-1")
Expand Down

0 comments on commit 4badd91

Please sign in to comment.