Skip to content

Commit

Permalink
fix: conflicting ranges (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
barakalon authored Jul 20, 2024
1 parent 2a18d5b commit afdad41
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 55 deletions.
2 changes: 1 addition & 1 deletion odex/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Literal(Condition):
value: Any

def __str__(self) -> str:
return str(self.value)
return repr(self.value)


@dataclass
Expand Down
15 changes: 9 additions & 6 deletions odex/index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import abstractmethod
from typing import Generic, TypeVar, Set, Any, Optional, Iterable, List, cast, Dict, Type, Callable
from typing_extensions import Protocol
from typing_extensions import Protocol, runtime_checkable

from sortedcontainers import SortedDict # type: ignore

Expand Down Expand Up @@ -37,6 +37,9 @@ def match(self, condition: BinOp, operand: Condition) -> Optional[Plan]:
`IndexLoop` plan if it can.
"""


@runtime_checkable
class SupportsLookup(Protocol[T]):
@abstractmethod
def lookup(self, value: Any) -> Set[T]:
"""
Expand All @@ -48,6 +51,9 @@ def lookup(self, value: Any) -> Set[T]:
Result set
"""


@runtime_checkable
class SupportsRange(Protocol[T]):
@abstractmethod
def range(self, rng: Range) -> Set[T]:
"""
Expand All @@ -60,7 +66,7 @@ def range(self, rng: Range) -> Set[T]:
"""


class HashIndex(Generic[T], Index[T]):
class HashIndex(Generic[T], Index[T], SupportsLookup[T]):
"""
Hash table index.
Expand Down Expand Up @@ -95,9 +101,6 @@ def remove(self, objs: Set[T], ctx: Context[T]) -> None:
def lookup(self, value: Any) -> Set[T]:
return self.idx.get(value) or set()

def range(self, rng: Range) -> Set[T]:
raise ValueError(f"{self.__class__.__name__} does not support range queries")

def match(self, condition: BinOp, operand: Condition) -> Optional[Plan]:
if isinstance(condition, Eq) and isinstance(operand, Literal):
return IndexLookup(index=self, value=operand.value)
Expand All @@ -121,7 +124,7 @@ def __str__(self) -> str:
return f"{self.__class__.__name__}({self.attr})"


class SortedDictIndex(Generic[T], HashIndex[T]):
class SortedDictIndex(Generic[T], HashIndex[T], SupportsRange[T]):
"""
Same as `HashIndex`, except this uses a `sortedcontainers.SortedDict` as the index
and supports range queries.
Expand Down
70 changes: 42 additions & 28 deletions odex/optimize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Sequence, Dict, TYPE_CHECKING, Any, List
from collections import defaultdict
from typing import Callable, Sequence, Dict, TYPE_CHECKING, List, Union as UnionType
from typing_extensions import Protocol

from odex.condition import and_, BinOp, Attribute
Expand All @@ -13,6 +14,7 @@
IndexRange,
IndexLookup,
Bound,
Empty,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -90,40 +92,52 @@ class CombineRanges(TransformerRule):

def transform(self, plan: Plan, ctx: Context) -> Plan:
if isinstance(plan, Intersect):
ranges: Dict[Index, Range] = {}
# Group the plans by ones that support ranges and by index
plans_by_index: Dict[Index, List[UnionType[IndexLookup, IndexRange]]] = defaultdict(
list
)
others = []

for i in plan.inputs:
if isinstance(i, IndexLookup):
rng: Range[Any] = Range(
left=Bound(i.value, True),
right=Bound(i.value, True),
)
existing = ranges.get(i.index)
ranges[i.index] = existing.combine(rng) if existing else rng
elif isinstance(i, IndexRange):
existing = ranges.get(i.index)
ranges[i.index] = existing.combine(i.range) if existing else i.range
if isinstance(i, (IndexLookup, IndexRange)):
plans_by_index[i.index].append(i)
else:
others.append(i)

inputs: List[Plan] = []
for index, rng in ranges.items():
if (
isinstance(rng.left, Bound)
and isinstance(rng.right, Bound)
and rng.left.value == rng.right.value
and rng.left.inclusive
and rng.right.inclusive
):
inputs.append(IndexLookup(index=index, value=rng.left.value))
else:
inputs.append(
IndexRange(
index=index,
range=rng,
)

for index, plans in plans_by_index.items():
if len(plans) == 1:
inputs.append(plans[0])
continue

ranges = [
# Treat a lookup as a range
Range(
left=Bound(i.value, True),
right=Bound(i.value, True),
)
if isinstance(i, IndexLookup)
else i.range
for i in plans
]

new_range = ranges[0]
for rng in ranges[1:]:
combined = new_range.combine(rng)

# None means there is a range that always evaluates to False
if combined is None:
return Empty()
else:
new_range = combined

inputs.append(
IndexRange(
index=index,
range=new_range,
)
)

inputs.extend(others)

if len(inputs) == 1:
Expand Down
35 changes: 30 additions & 5 deletions odex/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Generic,
TypeVar,
NamedTuple,
Optional,
)
from typing_extensions import Protocol

Expand Down Expand Up @@ -41,10 +42,18 @@ class Comparable(Protocol):
def __lt__(self: "C", other: "C") -> bool:
pass

@abstractmethod
def __le__(self: "C", other: "C") -> bool:
pass

@abstractmethod
def __gt__(self: "C", other: "C") -> bool:
pass

@abstractmethod
def __ge__(self: "C", other: "C") -> bool:
pass


C = TypeVar("C", bound=Comparable)

Expand Down Expand Up @@ -94,6 +103,11 @@ def transform(self, transformer: Transformer) -> "Plan":
return transformer(self)


class Empty(Plan):
def to_s(self, depth: int = 0) -> str:
return "Empty"


@dataclass
class ScanFilter(Plan):
"""Return all objects in the collection, filtering with `condition`"""
Expand Down Expand Up @@ -149,10 +163,19 @@ class Range(Generic[C]):
left: OptionalBound = UNSET
right: OptionalBound = UNSET

def combine(self, other: "Range[C]") -> "Range[C]":
def combine(self, other: "Range[C]") -> "Optional[Range[C]]":
left = self._combine_bounds(self.left, other.left, lambda a, b: a > b)
right = self._combine_bounds(self.right, other.right, lambda a, b: a < b)

# Check for an invalid range
if isinstance(left, Bound) and isinstance(right, Bound):
if left.inclusive and right.inclusive:
if left.value > right.value:
return None
else:
if left.value >= right.value:
return None

return Range(
left=left,
right=right,
Expand Down Expand Up @@ -184,11 +207,13 @@ class IndexRange(Plan):
def to_s(self, depth=0):
if self.range.left is UNSET:
assert isinstance(self.range.right, Bound)
return f"IndexRange: {self.index} {self.range.right.symbol()} {self.range.right.value}"
return f"IndexRange: {self.index} {self.range.right.symbol()} {repr(self.range.right.value)}"
if self.range.right is UNSET:
assert isinstance(self.range.left, Bound)
return f"IndexRange: {self.range.left.value} {self.range.left.symbol()} {self.index}"
return f"IndexRange: {self.range.left.value} {self.range.left.symbol()} {self.index} {self.range.right.symbol()} {self.range.right.value}"
return (
f"IndexRange: {repr(self.range.left.value)} {self.range.left.symbol()} {self.index}"
)
return f"IndexRange: {repr(self.range.left.value)} {self.range.left.symbol()} {self.index} {self.range.right.symbol()} {repr(self.range.right.value)}"

def __deepcopy__(self, memodict):
return IndexRange(
Expand All @@ -205,7 +230,7 @@ class IndexLookup(Plan):
value: Any

def to_s(self, depth=0):
return f"IndexLookup: {self.index} = {self.value}"
return f"IndexLookup: {self.index} = {repr(self.value)}"

def __deepcopy__(self, memodict):
return IndexLookup(index=self.index, value=deepcopy(self.value))
Expand Down
13 changes: 12 additions & 1 deletion odex/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,17 @@
from odex.index import Index, InvertedIndex, SortedDictIndex, HashIndex
from odex.optimize import Chain, Rule
from odex.parse import Parser
from odex.plan import Plan, Union, Intersect, ScanFilter, Filter, Planner, IndexLookup, IndexRange
from odex.plan import (
Plan,
Union,
Intersect,
ScanFilter,
Filter,
Planner,
IndexLookup,
IndexRange,
Empty,
)
from odex import condition as cond
from odex.condition import BinOp, UnaryOp, Attribute, Literal, Condition
from odex.utils import intersect
Expand Down Expand Up @@ -122,6 +132,7 @@ def __init__(
Intersect: lambda plan: intersect(*(self.execute(i) for i in plan.inputs)), # type: ignore
IndexLookup: lambda plan: plan.index.lookup(plan.value), # type: ignore
IndexRange: lambda plan: plan.index.range(plan.range), # type: ignore
Empty: lambda plan: set(), # type: ignore
}

def match_binop(op: Callable[[Any, Any], Any]) -> Callable[[BinOp, T], Any]:
Expand Down
60 changes: 46 additions & 14 deletions tests/fixtures/e2e.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ setups:
- ScanFilter: 1 > a
- ScanFilter: 3 <= a
optimized_plan: |-
IndexRange: 3 <= SortedDictIndex(a) < 1
Empty
result: []
- title: Combining ranges leads to =
condition: a > 1 AND a >= 3 AND a <= 3
Expand All @@ -129,9 +129,18 @@ setups:
- ScanFilter: a >= 3
- ScanFilter: a <= 3
optimized_plan: |-
IndexLookup: SortedDictIndex(a) = 3
IndexRange: 3 <= SortedDictIndex(a) <= 3
result:
- 2
- title: Conflicting equalities
condition: a = 1 AND a = 2
plan: |-
Intersect
- ScanFilter: a = 1
- ScanFilter: a = 2
optimized_plan: |-
Empty
result: []
- objects:
- a: 1
b: 2
Expand Down Expand Up @@ -211,29 +220,52 @@ setups:
optimized_plan: |-
IndexRange: SortedDictIndex(a) <= 2
result:
- 0
- 1
- 3
- 4
- 0
- 1
- 3
- 4
- title: Bisect left (>)
condition: a > 1
plan: |-
ScanFilter: a > 1
optimized_plan: |-
IndexRange: 1 < SortedDictIndex(a)
result:
- 1
- 2
- 4
- 5
- 1
- 2
- 4
- 5
- title: Bisect right (>=)
condition: a >= 2
plan: |-
ScanFilter: a >= 2
optimized_plan: |-
IndexRange: 2 <= SortedDictIndex(a)
result:
- 1
- 2
- 4
- 5
- 1
- 2
- 4
- 5
- objects:
- a: foo
- a: bar
- a: baz
indexes: [a]
tests:
- title: String equality
condition: a = 'foo'
plan: |-
ScanFilter: a = 'foo'
optimized_plan: |-
IndexLookup: HashIndex(a) = 'foo'
result:
- 0
- title: Conflicting equalities, string
condition: a = 'foo' AND a = 'bar'
plan: |-
Intersect
- ScanFilter: a = 'foo'
- ScanFilter: a = 'bar'
optimized_plan: |-
Empty
result: []

0 comments on commit afdad41

Please sign in to comment.