Fix!: multi threading issues with simplify
tobymao committed May 19, 2023
1 parent f5adb87 commit 8c9e5ec
Showing 3 changed files with 28 additions and 21 deletions.
9 changes: 9 additions & 0 deletions sqlglot/generator.py
@@ -2313,3 +2313,12 @@ def tochar_sql(self, expression: exp.ToChar) -> str:
self.unsupported("Format argument unsupported for TO_CHAR/TO_VARCHAR function")

return self.sql(exp.cast(expression.this, "text"))


+ def cached_generator(
+     cache: t.Optional[t.Dict[int, str]] = None
+ ) -> t.Callable[[exp.Expression], str]:
+     """Returns a cached generator."""
+     cache = {} if cache is None else cache
+     generator = Generator(normalize=True, identify="safe")
+     return lambda e: generator.generate(e, cache)
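
The helper above replaces the module-level GENERATOR that simplify and normalize previously shared: every call to cached_generator() builds its own Generator and its own cache dict, so concurrent callers never touch shared mutable state. A minimal sketch of how the closure might be exercised from several threads; the worker function, SQL strings, and thread count are illustrative, not part of the commit:

    import threading

    import sqlglot
    from sqlglot.generator import cached_generator

    def worker(sql: str) -> None:
        # Each thread builds its own generator closure, so the cache dict and
        # Generator instance are never shared across threads.
        generate = cached_generator()
        print(generate(sqlglot.parse_one(sql)))

    threads = [threading.Thread(target=worker, args=(s,)) for s in ("a AND b", "x OR y")]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

Sharing a single closure between threads would reintroduce the problem, since the captured cache dict is plain mutable state.
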
28 changes: 14 additions & 14 deletions sqlglot/optimizer/normalize.py
@@ -1,10 +1,10 @@
from __future__ import annotations

import logging
- import typing as t

from sqlglot import exp
from sqlglot.errors import OptimizeError
+ from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
from sqlglot.optimizer.simplify import flatten, uniq_sort

@@ -28,7 +28,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
- cache: t.Dict[int, str] = {}
+ generate = cached_generator()

for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
@@ -47,7 +47,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
original = node.copy()
try:
node = node.replace(
- while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+ while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
)
except OptimizeError as e:
logger.info(e)
@@ -111,7 +111,7 @@ def _predicate_lengths(expression, dnf):
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)


- def distributive_law(expression, dnf, max_distance, cache=None):
+ def distributive_law(expression, dnf, max_distance, generate):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@@ -124,7 +124,7 @@ def distributive_law(expression, dnf, max_distance, cache=None):
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

- exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)

if isinstance(expression, from_exp):
@@ -135,30 +135,30 @@

if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
- return _distribute(a, b, from_func, to_func, cache)
- return _distribute(b, a, from_func, to_func, cache)
+ return _distribute(a, b, from_func, to_func, generate)
+ return _distribute(b, a, from_func, to_func, generate)
if isinstance(a, to_exp):
- return _distribute(b, a, from_func, to_func, cache)
+ return _distribute(b, a, from_func, to_func, generate)
if isinstance(b, to_exp):
- return _distribute(a, b, from_func, to_func, cache)
+ return _distribute(a, b, from_func, to_func, generate)

return expression


- def _distribute(a, b, from_func, to_func, cache):
+ def _distribute(a, b, from_func, to_func, generate):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
- uniq_sort(flatten(from_func(c, b.left)), cache),
- uniq_sort(flatten(from_func(c, b.right)), cache),
+ uniq_sort(flatten(from_func(c, b.left)), generate),
+ uniq_sort(flatten(from_func(c, b.right)), generate),
copy=False,
),
)
else:
a = to_func(
- uniq_sort(flatten(from_func(a, b.left)), cache),
- uniq_sort(flatten(from_func(a, b.right)), cache),
+ uniq_sort(flatten(from_func(a, b.left)), generate),
+ uniq_sort(flatten(from_func(a, b.right)), generate),
copy=False,
)

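With the parameter rename, distributive_law and _distribute no longer thread a raw cache dict down to uniq_sort; they receive the generate callable that normalize() now builds once per invocation. A rough sketch of the new call shape, where the input expression and the max_distance value are illustrative:

    from sqlglot import parse_one
    from sqlglot.generator import cached_generator
    from sqlglot.helper import while_changing
    from sqlglot.optimizer.normalize import distributive_law

    # One closure per normalization run, instead of one cache dict shared by all runs.
    generate = cached_generator()
    node = parse_one("a OR (b AND c)")
    cnf = while_changing(node, lambda e: distributive_law(e, dnf=False, max_distance=128, generate=generate))
    print(cnf.sql())  # roughly: (a OR b) AND (a OR c)
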
12 changes: 5 additions & 7 deletions sqlglot/optimizer/simplify.py
@@ -5,11 +5,9 @@
from decimal import Decimal

from sqlglot import exp
- from sqlglot.generator import Generator
+ from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing

GENERATOR = Generator(normalize=True, identify="safe")


def simplify(expression):
"""
@@ -27,12 +25,12 @@ def simplify(expression):
sqlglot.Expression: simplified expression
"""

- cache = {}
+ generate = cached_generator()

def _simplify(expression, root=True):
node = expression
node = rewrite_between(node)
- node = uniq_sort(node, cache, root)
+ node = uniq_sort(node, generate, root)
node = absorb_and_eliminate(node, root)
exp.replace_children(node, lambda e: _simplify(e, False))
node = simplify_not(node)
@@ -247,7 +245,7 @@ def remove_compliments(expression, root=True):
return expression


- def uniq_sort(expression, cache=None, root=True):
+ def uniq_sort(expression, generate, root=True):
"""
Uniq and sort a connector.
@@ -256,7 +254,7 @@ def uniq_sort(expression, cache=None, root=True):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
- deduped = {GENERATOR.generate(e, cache): e for e in flattened}
+ deduped = {generate(e): e for e in flattened}
arr = tuple(deduped.items())

# check if the operands are already sorted, if not sort them
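
In simplify.py the shared module-level GENERATOR (and the cache it carried) is gone; every simplify() call now builds its own generate closure. A short sketch of the kind of concurrent use this commit is meant to make safe; the SQL strings and worker count are illustrative:

    from concurrent.futures import ThreadPoolExecutor

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    # Each simplify() call constructs its own cached generator internally, so
    # parallel calls no longer contend on a single module-level Generator.
    sqls = ["a AND a", "1 = 1 AND b", "NOT (x <> 1)"]
    with ThreadPoolExecutor(max_workers=4) as pool:
        results = list(pool.map(lambda s: simplify(parse_one(s)).sql(), sqls))
    print(results)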
