From d19f8c65ba7c46fbc3632ba06c9f7ccefa10b4ba Mon Sep 17 00:00:00 2001 From: ibriquem Date: Fri, 28 Feb 2020 12:32:19 +0100 Subject: [PATCH] Make dataclasses/attrs comparison recursive, fixes #4675 --- changelog/4675.bugfix.rst | 1 + src/_pytest/assertion/util.py | 46 ++++++++++--------- .../test_compare_recursive_dataclasses.py | 22 +++++++++ testing/test_assertion.py | 18 ++++++++ 4 files changed, 66 insertions(+), 21 deletions(-) create mode 100644 changelog/4675.bugfix.rst create mode 100644 testing/example_scripts/dataclasses/test_compare_recursive_dataclasses.py diff --git a/changelog/4675.bugfix.rst b/changelog/4675.bugfix.rst new file mode 100644 index 0000000000..9f857622f0 --- /dev/null +++ b/changelog/4675.bugfix.rst @@ -0,0 +1 @@ +Make dataclasses/attrs comparison recursive. diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index 7d525aa4c4..c704591bdd 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -148,26 +148,7 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[ explanation = None try: if op == "==": - if istext(left) and istext(right): - explanation = _diff_text(left, right, verbose) - else: - if issequence(left) and issequence(right): - explanation = _compare_eq_sequence(left, right, verbose) - elif isset(left) and isset(right): - explanation = _compare_eq_set(left, right, verbose) - elif isdict(left) and isdict(right): - explanation = _compare_eq_dict(left, right, verbose) - elif type(left) == type(right) and (isdatacls(left) or isattrs(left)): - type_fn = (isdatacls, isattrs) - explanation = _compare_eq_cls(left, right, verbose, type_fn) - elif verbose > 0: - explanation = _compare_eq_verbose(left, right) - if isiterable(left) and isiterable(right): - expl = _compare_eq_iterable(left, right, verbose) - if explanation is not None: - explanation.extend(expl) - else: - explanation = expl + explanation = _compare_eq_any(left, right, verbose) elif op == "not in": if istext(left) and istext(right): explanation = _notin_text(left, right, verbose) @@ -187,6 +168,28 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[ return [summary] + explanation +def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]: + explanation = [] # type: List[str] + if istext(left) and istext(right): + explanation = _diff_text(left, right, verbose) + else: + if issequence(left) and issequence(right): + explanation = _compare_eq_sequence(left, right, verbose) + elif isset(left) and isset(right): + explanation = _compare_eq_set(left, right, verbose) + elif isdict(left) and isdict(right): + explanation = _compare_eq_dict(left, right, verbose) + elif type(left) == type(right) and (isdatacls(left) or isattrs(left)): + type_fn = (isdatacls, isattrs) + explanation = _compare_eq_cls(left, right, verbose, type_fn) + elif verbose > 0: + explanation = _compare_eq_verbose(left, right) + if isiterable(left) and isiterable(right): + expl = _compare_eq_iterable(left, right, verbose) + explanation.extend(expl) + return explanation + + def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]: """Return the explanation for the diff between text. @@ -439,7 +442,8 @@ def _compare_eq_cls( explanation += ["Differing attributes:"] for field in diff: explanation += [ - ("%s: %r != %r") % (field, getattr(left, field), getattr(right, field)) + ("%s: %r != %r") % (field, getattr(left, field), getattr(right, field)), + *_compare_eq_any(getattr(left, field), getattr(right, field), verbose), ] return explanation diff --git a/testing/example_scripts/dataclasses/test_compare_recursive_dataclasses.py b/testing/example_scripts/dataclasses/test_compare_recursive_dataclasses.py new file mode 100644 index 0000000000..7fa32b62f0 --- /dev/null +++ b/testing/example_scripts/dataclasses/test_compare_recursive_dataclasses.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from dataclasses import field + + +@dataclass +class SimpleDataObject: + field_a: int = field() + field_b: int = field() + + +@dataclass +class ComplexDataObject: + field_a: SimpleDataObject = field() + field_b: SimpleDataObject = field() + + +def test_recursive_dataclasses(): + + left = ComplexDataObject(SimpleDataObject(1, "b"), SimpleDataObject(2, "c"),) + right = ComplexDataObject(SimpleDataObject(1, "b"), SimpleDataObject(3, "c"),) + + assert left == right diff --git a/testing/test_assertion.py b/testing/test_assertion.py index b12b3b119b..3b47db2d21 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -752,6 +752,24 @@ def test_dataclasses(self, testdir): "*Omitting 1 identical items, use -vv to show*", "*Differing attributes:*", "*field_b: 'b' != 'c'*", + "*- c*", + "*+ b*", + ] + ) + + @pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+") + def test_recursive_dataclasses(self, testdir): + p = testdir.copy_example("dataclasses/test_compare_recursive_dataclasses.py") + result = testdir.runpytest(p) + result.assert_outcomes(failed=1, passed=0) + result.stdout.fnmatch_lines( + [ + "*Omitting 1 identical items, use -vv to show*", + "*Differing attributes:*", + "*field_b: SimpleDataObject(field_a=2, field_b='c') != SimpleDataObject(field_a=3, field_b='c')*", + "*Omitting 1 identical items, use -vv to show*", + "*Differing attributes:*", + "*field_a: 2 != 3*", ] )