From e2e7f15b719f480c4d2a3aea028c55f2dc3f0b75 Mon Sep 17 00:00:00 2001 From: ibriquem Date: Tue, 2 Jun 2020 15:38:41 +0200 Subject: [PATCH] Make dataclasses/attrs comparison recursive, fixes #4675 --- changelog/4675.bugfix.rst | 1 + src/_pytest/assertion/util.py | 48 ++++++----- .../test_compare_recursive_dataclasses.py | 34 ++++++++ testing/test_assertion.py | 80 +++++++++++++++++++ 4 files changed, 142 insertions(+), 21 deletions(-) create mode 100644 changelog/4675.bugfix.rst create mode 100644 testing/example_scripts/dataclasses/test_compare_recursive_dataclasses.py diff --git a/changelog/4675.bugfix.rst b/changelog/4675.bugfix.rst new file mode 100644 index 00000000000..9f857622f08 --- /dev/null +++ b/changelog/4675.bugfix.rst @@ -0,0 +1 @@ +Make dataclasses/attrs comparison recursive. diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index 7d525aa4c42..c2f0431d479 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -148,26 +148,7 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[ explanation = None try: if op == "==": - if istext(left) and istext(right): - explanation = _diff_text(left, right, verbose) - else: - if issequence(left) and issequence(right): - explanation = _compare_eq_sequence(left, right, verbose) - elif isset(left) and isset(right): - explanation = _compare_eq_set(left, right, verbose) - elif isdict(left) and isdict(right): - explanation = _compare_eq_dict(left, right, verbose) - elif type(left) == type(right) and (isdatacls(left) or isattrs(left)): - type_fn = (isdatacls, isattrs) - explanation = _compare_eq_cls(left, right, verbose, type_fn) - elif verbose > 0: - explanation = _compare_eq_verbose(left, right) - if isiterable(left) and isiterable(right): - expl = _compare_eq_iterable(left, right, verbose) - if explanation is not None: - explanation.extend(expl) - else: - explanation = expl + explanation = _compare_eq_any(left, right, verbose) elif op == "not in": if istext(left) and istext(right): explanation = _notin_text(left, right, verbose) @@ -187,6 +168,28 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[ return [summary] + explanation +def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]: + explanation = [] # type: List[str] + if istext(left) and istext(right): + explanation = _diff_text(left, right, verbose) + else: + if issequence(left) and issequence(right): + explanation = _compare_eq_sequence(left, right, verbose) + elif isset(left) and isset(right): + explanation = _compare_eq_set(left, right, verbose) + elif isdict(left) and isdict(right): + explanation = _compare_eq_dict(left, right, verbose) + elif type(left) == type(right) and (isdatacls(left) or isattrs(left)): + type_fn = (isdatacls, isattrs) + explanation = _compare_eq_cls(left, right, verbose, type_fn) + elif verbose > 0: + explanation = _compare_eq_verbose(left, right) + if isiterable(left) and isiterable(right): + expl = _compare_eq_iterable(left, right, verbose) + explanation.extend(expl) + return explanation + + def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]: """Return the explanation for the diff between text. @@ -439,7 +442,10 @@ def _compare_eq_cls( explanation += ["Differing attributes:"] for field in diff: explanation += [ - ("%s: %r != %r") % (field, getattr(left, field), getattr(right, field)) + ("%s: %r != %r") % (field, getattr(left, field), getattr(right, field)), + "", + "Drill down into differing attribute %s:" % field, + *_compare_eq_any(getattr(left, field), getattr(right, field), verbose), ] return explanation diff --git a/testing/example_scripts/dataclasses/test_compare_recursive_dataclasses.py b/testing/example_scripts/dataclasses/test_compare_recursive_dataclasses.py new file mode 100644 index 00000000000..98385379ead --- /dev/null +++ b/testing/example_scripts/dataclasses/test_compare_recursive_dataclasses.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from dataclasses import field + + +@dataclass +class SimpleDataObject: + field_a: int = field() + field_b: int = field() + + +@dataclass +class ComplexDataObject2: + field_a: SimpleDataObject = field() + field_b: SimpleDataObject = field() + + +@dataclass +class ComplexDataObject: + field_a: SimpleDataObject = field() + field_b: ComplexDataObject2 = field() + + +def test_recursive_dataclasses(): + + left = ComplexDataObject( + SimpleDataObject(1, "b"), + ComplexDataObject2(SimpleDataObject(1, "b"), SimpleDataObject(2, "c"),), + ) + right = ComplexDataObject( + SimpleDataObject(1, "b"), + ComplexDataObject2(SimpleDataObject(1, "b"), SimpleDataObject(3, "c"),), + ) + + assert left == right diff --git a/testing/test_assertion.py b/testing/test_assertion.py index f28876edcc7..4b1df89c93f 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -781,6 +781,48 @@ def test_dataclasses(self, testdir): "*Omitting 1 identical items, use -vv to show*", "*Differing attributes:*", "*field_b: 'b' != 'c'*", + "*- c*", + "*+ b*", + ] + ) + + @pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+") + def test_recursive_dataclasses(self, testdir): + p = testdir.copy_example("dataclasses/test_compare_recursive_dataclasses.py") + result = testdir.runpytest(p) + result.assert_outcomes(failed=1, passed=0) + result.stdout.fnmatch_lines( + [ + "*Omitting 1 identical items, use -vv to show*", + "*Differing attributes:*", + "*field_b: ComplexDataObject2(*SimpleDataObject(field_a=2, field_b='c')) != ComplexDataObject2(*SimpleDataObject(field_a=3, field_b='c'))*", # noqa + "*Drill down into differing attribute field_b:*", + "*Omitting 1 identical items, use -vv to show*", + "*Differing attributes:*", + "*Full output truncated*", + ] + ) + + @pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+") + def test_recursive_dataclasses_verbose(self, testdir): + p = testdir.copy_example("dataclasses/test_compare_recursive_dataclasses.py") + result = testdir.runpytest(p, "-vv") + result.assert_outcomes(failed=1, passed=0) + result.stdout.fnmatch_lines( + [ + "*Matching attributes:*", + "*['field_a']*", + "*Differing attributes:*", + "*field_b: ComplexDataObject2(*SimpleDataObject(field_a=2, field_b='c')) != ComplexDataObject2(*SimpleDataObject(field_a=3, field_b='c'))*", # noqa + "*Matching attributes:*", + "*['field_a']*", + "*Differing attributes:*", + "*field_b: SimpleDataObject(field_a=2, field_b='c') " + "!= SimpleDataObject(field_a=3, field_b='c')*", + "*Matching attributes:*", + "*['field_b']*", + "*Differing attributes:*", + "*field_a: 2 != 3", ] ) @@ -832,6 +874,44 @@ class SimpleDataObject: for line in lines[1:]: assert "field_a" not in line + def test_attrs_recursive(self) -> None: + @attr.s + class OtherDataObject: + field_c = attr.ib() + field_d = attr.ib() + + @attr.s + class SimpleDataObject: + field_a = attr.ib() + field_b = attr.ib() + + left = SimpleDataObject(OtherDataObject(1, "a"), "b") + right = SimpleDataObject(OtherDataObject(1, "b"), "b") + + lines = callequal(left, right) + assert "Matching attributes" not in lines + for line in lines[1:]: + assert "field_b:" not in line + assert "field_c:" not in line + + def test_attrs_recursive_verbose(self) -> None: + @attr.s + class OtherDataObject: + field_c = attr.ib() + field_d = attr.ib() + + @attr.s + class SimpleDataObject: + field_a = attr.ib() + field_b = attr.ib() + + left = SimpleDataObject(OtherDataObject(1, "a"), "b") + right = SimpleDataObject(OtherDataObject(1, "b"), "b") + + lines = callequal(left, right) + assert "field_d: 'a' != 'b'" in lines + print("\n".join(lines)) + def test_attrs_verbose(self) -> None: @attr.s class SimpleDataObject: