Skip to content

Commit

Permalink
Merge pull request #47 from alex-mirkin/allow-disable-duplicate-rows-…
Browse files Browse the repository at this point in the history
…support

Allow disabling duplicate-rows support with --assume-unique-key
  • Loading branch information
erezsh authored Sep 23, 2024
2 parents a0047e8 + 3855d20 commit 31ea069
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 17 deletions.
4 changes: 3 additions & 1 deletion reladiff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
@click.option(
"--assume-unique-key",
is_flag=True,
help="Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs.",
help="Skip validating the uniqueness of the key column during joindiff, which is costly in non-cloud dbs."
"Also, disables support for duplicate rows in hashdiff, offering a small performance gain.",
)
@click.option(
"--skip-sort-results",
Expand Down Expand Up @@ -367,6 +368,7 @@ def _main(
max_threadpool_size=threads and threads * 2,
allow_empty_tables=allow_empty_tables,
skip_sort_results=skip_sort_results,
duplicate_rows_support=not assume_unique_key,
)

table_names = table1, table2
Expand Down
28 changes: 15 additions & 13 deletions reladiff/hashdiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Iterator
from operator import attrgetter
from collections import Counter
from itertools import repeat
from itertools import chain

from dataclasses import dataclass, field

Expand All @@ -27,17 +27,17 @@
logger = logging.getLogger("hashdiff_tables")


def diff_sets(a: list, b: list, skip_sort_results: bool) -> Iterator:
c = Counter(b)
c.subtract(a)
x = c.items() if skip_sort_results else sorted(c.items(), key=lambda i: i[0]) # sort by key
for k, count in x:
if count < 0:
sign = "-"
count = -count
else:
sign = "+"
yield from repeat((sign, k), count)
def diff_sets(a: list, b: list, skip_sort_results: bool, duplicate_rows_support: bool) -> Iterator:
    """Compute the difference between two row collections as ('+'/'-', row) pairs.

    '-' marks rows present (or over-represented) in ``a``; '+' marks rows
    present (or over-represented) in ``b``.  When ``duplicate_rows_support``
    is true, multiset semantics apply: a row repeated more times on one side
    yields one entry per surplus occurrence.  Otherwise plain set semantics
    are used (cheaper, but duplicates collapse).  The result is sorted by row
    unless ``skip_sort_results`` is set, in which case a lazy iterable may be
    returned.
    """
    if not duplicate_rows_support:
        rows_a, rows_b = set(a), set(b)
        diff = chain(
            (("-", row) for row in rows_a - rows_b),
            (("+", row) for row in rows_b - rows_a),
        )
    else:
        # Positive count => surplus in b ('+'); negative => surplus in a ('-').
        counts = Counter(b)
        counts.subtract(a)
        diff = (
            ("+" if surplus > 0 else "-", row)
            for row, surplus in counts.items()
            for _ in range(abs(surplus))
        )

    if skip_sort_results:
        return diff
    return sorted(diff, key=lambda entry: entry[1])  # sort by key


@dataclass(frozen=True)
Expand All @@ -58,11 +58,13 @@ class HashDiffer(TableDiffer):
There may be many pools, so number of actual threads can be a lot higher.
skip_sort_results (bool): Skip sorting the hashdiff output by key for better performance.
Entries with the same key but different column values may not appear adjacent in the output.
duplicate_rows_support (bool): If ``True``, the algorithm will support duplicate rows in the tables.
"""

bisection_factor: int = DEFAULT_BISECTION_FACTOR
bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests
skip_sort_results: bool = False
duplicate_rows_support: bool = True

stats: dict = field(default_factory=dict)

Expand Down Expand Up @@ -204,7 +206,7 @@ def _bisect_and_diff_segments(
# This saves time, as bisection speed is limited by ping and query performance.
if max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2:
rows1, rows2 = self._threaded_call("get_values", [table1, table2])
diff = list(diff_sets(rows1, rows2, self.skip_sort_results))
diff = list(diff_sets(rows1, rows2, self.skip_sort_results, self.duplicate_rows_support))

info_tree.info.set_diff(diff)
info_tree.info.rowcounts = {1: len(rows1), 2: len(rows2)}
Expand Down
16 changes: 13 additions & 3 deletions tests/test_diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ def setUp(self):
(6, "ABCDE")
]

self.expected_output = [
self.expected_output_with_dups = [
('+', ("6", 'ABCDE')),
('+', ("4", 'ABCDE')),
('+', ("4", 'ABCDE')),
Expand All @@ -850,6 +850,12 @@ def setUp(self):
('-', ("12", 'ABCDE')),
]

self.expected_output_no_dups = [
('+', ("4", 'ABCDE')),
('+', ("4", 'ABCDEF')),
('-', ("12", 'ABCDE')),
]

self.connection.query([self.src_table.insert_rows(a), self.dst_table.insert_rows(b), commit])

self.a = table_segment(
Expand All @@ -862,9 +868,13 @@ def setUp(self):
def test_duplicates2(self):
    """Diff tables containing duplicate rows, with and without duplicate-rows support.

    With duplicate_rows_support=True, every surplus occurrence of a repeated
    row must appear in the diff; with False, duplicates collapse to set
    semantics and only distinct differing rows are reported.
    """
    # Multiset semantics: repeated rows are reported once per surplus occurrence.
    differ = HashDiffer(bisection_factor=2, bisection_threshold=4, duplicate_rows_support=True)
    diff = list(differ.diff_tables(self.a, self.b))
    self.assertEqual(sorted(diff), sorted(self.expected_output_with_dups))

    # Set semantics: duplicates collapse, so only distinct rows differ.
    differ = HashDiffer(bisection_factor=2, bisection_threshold=4, duplicate_rows_support=False)
    diff = list(differ.diff_tables(self.a, self.b))
    self.assertEqual(sorted(diff), sorted(self.expected_output_no_dups))


class TestSkipSortResults(DiffTestCase):
Expand Down

0 comments on commit 31ea069

Please sign in to comment.