Skip to content

Commit

Permalink
Refactor!: hashable args is now more efficient and identifiers no lon…
Browse files Browse the repository at this point in the history
…ger accommodate case insensitivity because that is dialect-specific
  • Loading branch information
tobymao committed Jul 3, 2023
1 parent df4448d commit f747260
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 43 deletions.
16 changes: 8 additions & 8 deletions benchmarks/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

import numpy as np

import sqlfluff
import moz_sql_parser
import sqloxide
import sqlparse
#import sqlfluff
#import moz_sql_parser
#import sqloxide
#import sqlparse
import sqltree

import sqlglot
Expand Down Expand Up @@ -199,11 +199,11 @@ def diff(row, column):

libs = [
"sqlglot",
"sqlfluff",
#"sqlfluff",
"sqltree",
"sqlparse",
"moz_sql_parser",
"sqloxide",
#"sqlparse",
#"moz_sql_parser",
#"sqloxide",
]
table = []

Expand Down
25 changes: 24 additions & 1 deletion sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,29 @@ def _to_timestamp(args: t.List) -> exp.Expression:
return format_time_lambda(exp.StrToTime, "postgres")(args)


def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """Remove target-table qualifiers from columns in a MERGE's WHEN clauses.

    For an ``exp.Merge`` node, collects the normalized name of the merge target
    (``expression.this.this``) and, when present, of its alias. Every
    ``exp.Column`` inside the WHEN expressions whose table qualifier normalizes
    to one of those names is replaced by an unqualified column of the same name.
    Any other expression type is returned untouched.

    Args:
        expression: The expression to process.

    Returns:
        The same ``expression`` object that was passed in.
    """
    if not isinstance(expression, exp.Merge):
        return expression

    def normalize(identifier):
        # NOTE(review): normalize_identifier is defined on the dialect outside
        # this block; presumably it applies Postgres identifier case rules and
        # may modify the identifier it is given -- confirm against its definition.
        return Postgres.normalize_identifier(identifier).name

    target = expression.this
    targets = {normalize(target.this)}

    alias = target.args.get("alias")
    if alias:
        targets.add(normalize(alias.this))

    def strip_target(node):
        if not isinstance(node, exp.Column):
            return node

        table = node.args.get("table")
        # Unqualified columns have no "table" arg. Skip them instead of passing
        # None to normalize(), which would dereference `.name` on a non-identifier.
        if table and normalize(table) in targets:
            return exp.column(node.name)
        return node

    for when in expression.expressions:
        # The return value of transform() is discarded, so this relies on the
        # copy=False call rewriting the WHEN expression itself.
        when.transform(strip_target, copy=False)

    return expression


class Postgres(Dialect):
INDEX_OFFSET = 1
NULL_ORDERING = "nulls_are_large"
Expand Down Expand Up @@ -352,7 +375,7 @@ class Generator(generator.Generator):
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]),
exp.Merge: transforms.preprocess([_remove_target_from_merge]),
exp.Pivot: no_pivot_sql,
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
Expand Down
15 changes: 5 additions & 10 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,10 @@ def __eq__(self, other) -> bool:

@property
def hashable_args(self) -> t.Any:
args = (self.args.get(k) for k in self.arg_types)

return tuple(
(tuple(_norm_arg(a) for a in arg) if arg else None)
if type(arg) is list
else (_norm_arg(arg) if arg is not None and arg is not False else None)
for arg in args
return frozenset(
(k, tuple(_norm_arg(a) for a in v) if type(v) is list else _norm_arg(v))
for k, v in self.args.items()
if not (v is None or v is False or (type(v) is list and not v))
)

def __hash__(self) -> int:
Expand Down Expand Up @@ -1490,9 +1487,7 @@ def quoted(self) -> bool:

@property
def hashable_args(self) -> t.Any:
if self.quoted and any(char.isupper() for char in self.this):
return (self.this, self.quoted)
return self.this.lower()
return (self.this, self.quoted)

@property
def output_name(self) -> str:
Expand Down
19 changes: 0 additions & 19 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,25 +215,6 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
return expression


def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
    """Strip target-table qualifiers from columns inside MERGE WHEN clauses.

    Only ``exp.Merge`` nodes are processed; anything else is returned as-is.
    The set of targets holds the merge target's ``this`` and, if the target is
    aliased, the alias's ``this``. Columns whose ``table`` arg is a member of
    that set are swapped for an unqualified column with the same name.

    Args:
        expression: The expression to process.

    Returns:
        The same ``expression`` object that was passed in.
    """
    if not isinstance(expression, exp.Merge):
        return expression

    merge_target = expression.this
    alias = merge_target.args.get("alias")

    targets = {merge_target.this}
    if alias:
        targets.add(alias.this)

    def _unqualify(node):
        # Membership is decided by the equality/hash of the table arg itself.
        if isinstance(node, exp.Column) and node.args.get("table") in targets:
            return exp.column(node.name)
        return node

    for when in expression.expressions:
        # The result of transform() is ignored; copy=False is passed so the
        # WHEN expression is rewritten rather than a copy.
        when.transform(_unqualify, copy=False)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
if (
isinstance(expression, exp.WithinGroup)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ def test_depth(self):
self.assertEqual(parse_one("x(1)").find(exp.Literal).depth, 1)

def test_eq(self):
self.assertEqual(exp.to_identifier("a"), exp.to_identifier("A"))
self.assertNotEqual(exp.to_identifier("a"), exp.to_identifier("A"))

self.assertEqual(
exp.Column(table=exp.to_identifier("b"), this=exp.to_identifier("b")),
exp.Column(this=exp.to_identifier("b"), table=exp.to_identifier("b")),
)

self.assertEqual(exp.to_identifier("a", quoted=True), exp.to_identifier("A"))
self.assertNotEqual(exp.to_identifier("a", quoted=True), exp.to_identifier("A"))
self.assertNotEqual(exp.to_identifier("A", quoted=True), exp.to_identifier("A"))
self.assertNotEqual(
exp.to_identifier("A", quoted=True), exp.to_identifier("a", quoted=True)
Expand All @@ -31,9 +31,9 @@ def test_eq(self):
self.assertNotEqual(parse_one("'1'"), parse_one("1"))
self.assertEqual(parse_one("`a`", read="hive"), parse_one('"a"'))
self.assertEqual(parse_one("`a`", read="hive"), parse_one('"a" '))
self.assertEqual(parse_one("`a`.b", read="hive"), parse_one('"a"."b"'))
self.assertEqual(parse_one("`a`.`b`", read="hive"), parse_one('"a"."b"'))
self.assertEqual(parse_one("select a, b+1"), parse_one("SELECT a, b + 1"))
self.assertEqual(parse_one("`a`.`b`.`c`", read="hive"), parse_one("a.b.c"))
self.assertNotEqual(parse_one("`a`.`b`.`c`", read="hive"), parse_one("a.b.c"))
self.assertNotEqual(parse_one("a.b.c.d", read="hive"), parse_one("a.b.c"))
self.assertEqual(parse_one("a.b.c.d", read="hive"), parse_one("a.b.c.d"))
self.assertEqual(parse_one("a + b * c - 1.0"), parse_one("a+b*c-1.0"))
Expand Down Expand Up @@ -338,7 +338,7 @@ def test_hash(self):
{
parse_one("select a.b"),
parse_one("1+2"),
parse_one('"a".b'),
parse_one('"a"."b"'),
parse_one("a.b.c.d"),
},
{
Expand Down

0 comments on commit f747260

Please sign in to comment.