Skip to content

Commit

Permalink
code: stop storing weakref to ExceptionInfo on Traceback and Tracebac…
Browse files Browse the repository at this point in the history
…kEntry

TracebackEntry needs the excinfo for the `__tracebackhide__ = callback`
functionality, where `callback` accepts the excinfo.

Currently it achieves this by storing a weakref to the excinfo which
created it. I think this is not great, mixing layers and bloating the
objects.

Instead, have `ishidden` (and transitively, `Traceback.filter()`) take
the excinfo as a parameter.
  • Loading branch information
bluetech committed Apr 28, 2023
1 parent 11965d1 commit cc23ec9
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 41 deletions.
53 changes: 29 additions & 24 deletions src/_pytest/_code/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from weakref import ref

import pluggy

Expand All @@ -52,7 +51,6 @@
if TYPE_CHECKING:
from typing_extensions import Literal
from typing_extensions import SupportsIndex
from weakref import ReferenceType

_TracebackStyle = Literal["long", "short", "line", "no", "native", "value", "auto"]

Expand Down Expand Up @@ -194,15 +192,13 @@ def getargs(self, var: bool = False):
class TracebackEntry:
"""A single entry in a Traceback."""

__slots__ = ("_rawentry", "_excinfo", "_repr_style")
__slots__ = ("_rawentry", "_repr_style")

def __init__(
self,
rawentry: TracebackType,
excinfo: Optional["ReferenceType[ExceptionInfo[BaseException]]"] = None,
) -> None:
self._rawentry = rawentry
self._excinfo = excinfo
self._repr_style: Optional['Literal["short", "long"]'] = None

@property
Expand Down Expand Up @@ -272,7 +268,7 @@ def getsource(

source = property(getsource)

def ishidden(self) -> bool:
def ishidden(self, excinfo: Optional["ExceptionInfo[BaseException]"]) -> bool:
"""Return True if the current frame has a var __tracebackhide__
resolving to True.
Expand All @@ -296,7 +292,7 @@ def ishidden(self) -> bool:
else:
break
if tbh and callable(tbh):
return tbh(None if self._excinfo is None else self._excinfo())
return tbh(excinfo)
return tbh

def __str__(self) -> str:
Expand Down Expand Up @@ -329,16 +325,14 @@ class Traceback(List[TracebackEntry]):
def __init__(
self,
tb: Union[TracebackType, Iterable[TracebackEntry]],
excinfo: Optional["ReferenceType[ExceptionInfo[BaseException]]"] = None,
) -> None:
"""Initialize from given python traceback object and ExceptionInfo."""
self._excinfo = excinfo
if isinstance(tb, TracebackType):

def f(cur: TracebackType) -> Iterable[TracebackEntry]:
cur_: Optional[TracebackType] = cur
while cur_ is not None:
yield TracebackEntry(cur_, excinfo=excinfo)
yield TracebackEntry(cur_)
cur_ = cur_.tb_next

super().__init__(f(tb))
Expand Down Expand Up @@ -378,7 +372,7 @@ def cut(
continue
if firstlineno is not None and x.frame.code.firstlineno != firstlineno:
continue
return Traceback(x._rawentry, self._excinfo)
return Traceback(x._rawentry)
return self

@overload
Expand All @@ -398,25 +392,36 @@ def __getitem__(
return super().__getitem__(key)

def filter(
self, fn: Callable[[TracebackEntry], bool] = lambda x: not x.ishidden()
self,
# TODO(py38): change to positional only.
_excinfo_or_fn: Union[
"ExceptionInfo[BaseException]",
Callable[[TracebackEntry], bool],
],
) -> "Traceback":
"""Return a Traceback instance with certain items removed
"""Return a Traceback instance with certain items removed.
fn is a function that gets a single argument, a TracebackEntry
instance, and should return True when the item should be added
to the Traceback, False when not.
If the filter is an `ExceptionInfo`, removes all the ``TracebackEntry``s
which are hidden (see ishidden() above).
By default this removes all the TracebackEntries which are hidden
(see ishidden() above).
Otherwise, the filter is a function that gets a single argument, a
``TracebackEntry`` instance, and should return True when the item should
be added to the ``Traceback``, False when not.
"""
return Traceback(filter(fn, self), self._excinfo)
if isinstance(_excinfo_or_fn, ExceptionInfo):
fn = lambda x: not x.ishidden(_excinfo_or_fn) # noqa: E731
else:
fn = _excinfo_or_fn
return Traceback(filter(fn, self))

def getcrashentry(self) -> Optional[TracebackEntry]:
def getcrashentry(
self, excinfo: Optional["ExceptionInfo[BaseException]"]
) -> Optional[TracebackEntry]:
"""Return last non-hidden traceback entry that lead to the exception of
a traceback, or None if all hidden."""
for i in range(-1, -len(self) - 1, -1):
entry = self[i]
if not entry.ishidden():
if not entry.ishidden(excinfo):
return entry
return None

Expand Down Expand Up @@ -583,7 +588,7 @@ def typename(self) -> str:
def traceback(self) -> Traceback:
"""The traceback."""
if self._traceback is None:
self._traceback = Traceback(self.tb, excinfo=ref(self))
self._traceback = Traceback(self.tb)
return self._traceback

@traceback.setter
Expand Down Expand Up @@ -624,7 +629,7 @@ def errisinstance(

def _getreprcrash(self) -> Optional["ReprFileLocation"]:
exconly = self.exconly(tryshort=True)
entry = self.traceback.getcrashentry()
entry = self.traceback.getcrashentry(self)
if entry is None:
return None
path, lineno = entry.frame.code.raw.co_filename, entry.lineno
Expand Down Expand Up @@ -882,7 +887,7 @@ def _makepath(self, path: Union[Path, str]) -> str:
def repr_traceback(self, excinfo: ExceptionInfo[BaseException]) -> "ReprTraceback":
traceback = excinfo.traceback
if self.tbfilter:
traceback = traceback.filter()
traceback = traceback.filter(excinfo)

if isinstance(excinfo.value, RecursionError):
traceback, extraline = self._truncate_recursive_traceback(traceback)
Expand Down
2 changes: 1 addition & 1 deletion src/_pytest/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def _prunetraceback(self, excinfo: ExceptionInfo[BaseException]) -> None:
ntraceback = traceback.cut(path=self.path)
if ntraceback == traceback:
ntraceback = ntraceback.cut(excludepath=tracebackcutdir)
excinfo.traceback = ntraceback.filter()
excinfo.traceback = ntraceback.filter(excinfo)


def _check_initialpaths_for_relpath(session: "Session", path: Path) -> Optional[str]:
Expand Down
2 changes: 1 addition & 1 deletion src/_pytest/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,7 +1814,7 @@ def _prunetraceback(self, excinfo: ExceptionInfo[BaseException]) -> None:
if not ntraceback:
ntraceback = traceback

excinfo.traceback = ntraceback.filter()
excinfo.traceback = ntraceback.filter(excinfo)
# issue364: mark all but first and last frames to
# only show a single-line message for each frame.
if self.config.getoption("tbstyle", "auto") == "auto":
Expand Down
2 changes: 1 addition & 1 deletion src/_pytest/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def _prunetraceback(
) -> None:
super()._prunetraceback(excinfo)
traceback = excinfo.traceback.filter(
lambda x: not x.frame.f_globals.get("__unittest")
lambda x: not x.frame.f_globals.get("__unittest"),
)
if traceback:
excinfo.traceback = traceback
Expand Down
24 changes: 12 additions & 12 deletions testing/code/test_excinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def test_traceback_cut_excludepath(self, pytester: Pytester) -> None:

def test_traceback_filter(self):
traceback = self.excinfo.traceback
ntraceback = traceback.filter()
ntraceback = traceback.filter(self.excinfo)
assert len(ntraceback) == len(traceback) - 1

@pytest.mark.parametrize(
Expand Down Expand Up @@ -217,7 +217,7 @@ def h():

excinfo = pytest.raises(ValueError, h)
traceback = excinfo.traceback
ntraceback = traceback.filter()
ntraceback = traceback.filter(excinfo)
print(f"old: {traceback!r}")
print(f"new: {ntraceback!r}")

Expand Down Expand Up @@ -307,7 +307,7 @@ def f():

excinfo = pytest.raises(ValueError, f)
tb = excinfo.traceback
entry = tb.getcrashentry()
entry = tb.getcrashentry(excinfo)
assert entry is not None
co = _pytest._code.Code.from_function(h)
assert entry.frame.code.path == co.path
Expand All @@ -324,7 +324,7 @@ def f():
g()

excinfo = pytest.raises(ValueError, f)
assert excinfo.traceback.getcrashentry() is None
assert excinfo.traceback.getcrashentry(excinfo) is None


def test_excinfo_exconly():
Expand Down Expand Up @@ -626,7 +626,7 @@ def func1():
"""
)
excinfo = pytest.raises(ValueError, mod.func1)
excinfo.traceback = excinfo.traceback.filter()
excinfo.traceback = excinfo.traceback.filter(excinfo)
p = FormattedExcinfo()
reprtb = p.repr_traceback_entry(excinfo.traceback[-1])

Expand Down Expand Up @@ -659,7 +659,7 @@ def func1(m, x, y, z):
"""
)
excinfo = pytest.raises(ValueError, mod.func1, "m" * 90, 5, 13, "z" * 120)
excinfo.traceback = excinfo.traceback.filter()
excinfo.traceback = excinfo.traceback.filter(excinfo)
entry = excinfo.traceback[-1]
p = FormattedExcinfo(funcargs=True)
reprfuncargs = p.repr_args(entry)
Expand All @@ -686,7 +686,7 @@ def func1(x, *y, **z):
"""
)
excinfo = pytest.raises(ValueError, mod.func1, "a", "b", c="d")
excinfo.traceback = excinfo.traceback.filter()
excinfo.traceback = excinfo.traceback.filter(excinfo)
entry = excinfo.traceback[-1]
p = FormattedExcinfo(funcargs=True)
reprfuncargs = p.repr_args(entry)
Expand Down Expand Up @@ -960,7 +960,7 @@ def f():
"""
)
excinfo = pytest.raises(ValueError, mod.f)
excinfo.traceback = excinfo.traceback.filter()
excinfo.traceback = excinfo.traceback.filter(excinfo)
repr = excinfo.getrepr()
repr.toterminal(tw_mock)
assert tw_mock.lines[0] == ""
Expand Down Expand Up @@ -994,7 +994,7 @@ def f():
)
excinfo = pytest.raises(ValueError, mod.f)
tmp_path.joinpath("mod.py").unlink()
excinfo.traceback = excinfo.traceback.filter()
excinfo.traceback = excinfo.traceback.filter(excinfo)
repr = excinfo.getrepr()
repr.toterminal(tw_mock)
assert tw_mock.lines[0] == ""
Expand Down Expand Up @@ -1026,7 +1026,7 @@ def f():
)
excinfo = pytest.raises(ValueError, mod.f)
tmp_path.joinpath("mod.py").write_text("asdf")
excinfo.traceback = excinfo.traceback.filter()
excinfo.traceback = excinfo.traceback.filter(excinfo)
repr = excinfo.getrepr()
repr.toterminal(tw_mock)
assert tw_mock.lines[0] == ""
Expand Down Expand Up @@ -1123,7 +1123,7 @@ def i():
"""
)
excinfo = pytest.raises(ValueError, mod.f)
excinfo.traceback = excinfo.traceback.filter()
excinfo.traceback = excinfo.traceback.filter(excinfo)
excinfo.traceback[1].set_repr_style("short")
excinfo.traceback[2].set_repr_style("short")
r = excinfo.getrepr(style="long")
Expand Down Expand Up @@ -1391,7 +1391,7 @@ def f():
with pytest.raises(TypeError) as excinfo:
mod.f()
# previously crashed with `AttributeError: list has no attribute get`
excinfo.traceback.filter()
excinfo.traceback.filter(excinfo)


@pytest.mark.parametrize("style", ["short", "long"])
Expand Down
4 changes: 2 additions & 2 deletions testing/python/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,9 +1003,9 @@ def test_skip_simple(self):
with pytest.raises(pytest.skip.Exception) as excinfo:
pytest.skip("xxx")
assert excinfo.traceback[-1].frame.code.name == "skip"
assert excinfo.traceback[-1].ishidden()
assert excinfo.traceback[-1].ishidden(excinfo)
assert excinfo.traceback[-2].frame.code.name == "test_skip_simple"
assert not excinfo.traceback[-2].ishidden()
assert not excinfo.traceback[-2].ishidden(excinfo)

def test_traceback_argsetup(self, pytester: Pytester) -> None:
pytester.makeconftest(
Expand Down

0 comments on commit cc23ec9

Please sign in to comment.