Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Print the thread name in the live TUI #562

Merged
merged 5 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions news/562.feature.2.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Capture the name attribute of Python `threading.Thread` objects.
1 change: 1 addition & 0 deletions news/562.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Include the thread name in the live TUI
21 changes: 14 additions & 7 deletions src/memray/_memray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ from ._destination import FileDestination
from ._destination import SocketDestination
from ._metadata import Metadata
from ._stats import Stats
from ._thread_name_interceptor import ThreadNameInterceptor


def set_log_level(int level):
Expand Down Expand Up @@ -306,9 +307,7 @@ cdef class AllocationRecord:
if self.tid == -1:
return "merged thread"
assert self._reader.get() != NULL, "Cannot get thread name without reader."
cdef object name = self._reader.get().getThreadName(self.tid)
thread_id = hex(self.tid)
return f"{thread_id} ({name})" if name else f"{thread_id}"
return self._reader.get().getThreadName(self.tid)

def stack_trace(self, max_stacks=None):
cache_key = ("python", max_stacks)
Expand Down Expand Up @@ -441,9 +440,7 @@ cdef class TemporalAllocationRecord:
@property
def thread_name(self):
assert self._reader.get() != NULL, "Cannot get thread name without reader."
cdef object name = self._reader.get().getThreadName(self.tid)
thread_id = hex(self.tid)
return f"{thread_id} ({name})" if name else f"{thread_id}"
return self._reader.get().getThreadName(self.tid)

def stack_trace(self, max_stacks=None):
cache_key = ("python", max_stacks)
Expand Down Expand Up @@ -681,7 +678,6 @@ cdef class Tracker:

@cython.profile(False)
def __enter__(self):

if NativeTracker.getTracker() != NULL:
raise RuntimeError("No more than one Tracker instance can be active at the same time")

Expand All @@ -690,6 +686,14 @@ cdef class Tracker:
raise RuntimeError("Attempting to use stale output handle")
writer = move(self._writer)

for attr in ("_name", "_ident"):
assert not hasattr(threading.Thread, attr)
setattr(
threading.Thread,
attr,
ThreadNameInterceptor(attr, NativeTracker.registerThreadNameById),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good old friend :P

)

self._previous_profile_func = sys.getprofile()
self._previous_thread_profile_func = threading._profile_hook
threading.setprofile(start_thread_trace)
Expand All @@ -712,6 +716,9 @@ cdef class Tracker:
sys.setprofile(self._previous_profile_func)
threading.setprofile(self._previous_thread_profile_func)

for attr in ("_name", "_ident"):
delattr(threading.Thread, attr)


def start_thread_trace(frame, event, arg):
if event in {"call", "c_call"}:
Expand Down
27 changes: 27 additions & 0 deletions src/memray/_memray/tracking_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,7 @@ Tracker::trackAllocationImpl(
hooks::Allocator func,
const std::optional<NativeTrace>& trace)
{
registerCachedThreadName();
PythonStackTracker::get().emitPendingPushesAndPops();

if (d_unwind_native_frames) {
Expand Down Expand Up @@ -871,6 +872,7 @@ Tracker::trackAllocationImpl(
void
Tracker::trackDeallocationImpl(void* ptr, size_t size, hooks::Allocator func)
{
registerCachedThreadName();
AllocationRecord record{reinterpret_cast<uintptr_t>(ptr), size, func};
if (!d_writer->writeThreadSpecificRecord(thread_id(), record)) {
std::cerr << "Failed to write output, deactivating tracking" << std::endl;
Expand Down Expand Up @@ -963,12 +965,37 @@ void
Tracker::registerThreadNameImpl(const char* name)
{
RecursionGuard guard;
dropCachedThreadName();
if (!d_writer->writeThreadSpecificRecord(thread_id(), ThreadRecord{name})) {
std::cerr << "memray: Failed to write output, deactivating tracking" << std::endl;
deactivate();
}
}

void
Tracker::registerCachedThreadName()
{
if (d_cached_thread_names.empty()) {
return;
}

auto it = d_cached_thread_names.find((uint64_t)(pthread_self()));
pablogsal marked this conversation as resolved.
Show resolved Hide resolved
if (it != d_cached_thread_names.end()) {
auto& name = it->second;
if (!d_writer->writeThreadSpecificRecord(thread_id(), ThreadRecord{name.c_str()})) {
std::cerr << "memray: Failed to write output, deactivating tracking" << std::endl;
deactivate();
}
d_cached_thread_names.erase(it);
}
}

void
Tracker::dropCachedThreadName()
{
d_cached_thread_names.erase((uint64_t)(pthread_self()));
pablogsal marked this conversation as resolved.
Show resolved Hide resolved
}

frame_id_t
Tracker::registerFrame(const RawFrame& frame)
{
Expand Down
24 changes: 24 additions & 0 deletions src/memray/_memray/tracking_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,27 @@ class Tracker
}
}

inline static void registerThreadNameById(uint64_t thread, const char* name)
{
if (RecursionGuard::isActive || !Tracker::isActive()) {
return;
}
RecursionGuard guard;

std::unique_lock<std::mutex> lock(*s_mutex);
Tracker* tracker = getTracker();
if (tracker) {
if (thread == (uint64_t)(pthread_self())) {
pablogsal marked this conversation as resolved.
Show resolved Hide resolved
tracker->registerThreadNameImpl(name);
} else {
// We've got a different thread's name, but don't know what id
// has been assigned to that thread (if any!). Set this update
// aside to be handled later, from that thread.
tracker->d_cached_thread_names.emplace(thread, name);
}
}
}

// RawFrame stack interface
bool pushFrame(const RawFrame& frame);
bool popFrames(uint32_t count);
Expand Down Expand Up @@ -359,6 +380,7 @@ class Tracker
const bool d_trace_python_allocators;
linker::SymbolPatcher d_patcher;
std::unique_ptr<BackgroundThread> d_background_thread;
std::unordered_map<uint64_t, std::string> d_cached_thread_names;

// Methods
static size_t computeMainTidSkip();
Expand All @@ -373,6 +395,8 @@ class Tracker
void invalidate_module_cache_impl();
void updateModuleCacheImpl();
void registerThreadNameImpl(const char* name);
void registerCachedThreadName();
void dropCachedThreadName();
void registerPymallocHooks() const noexcept;
void unregisterPymallocHooks() const noexcept;

Expand Down
4 changes: 4 additions & 0 deletions src/memray/_memray/tracking_api.pxd
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from _memray.record_writer cimport RecordWriter
from libc.stdint cimport uint64_t
from libcpp cimport bool
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string
Expand Down Expand Up @@ -31,3 +32,6 @@ cdef extern from "tracking_api.h" namespace "memray::tracking_api":

@staticmethod
void handleGreenletSwitch(object, object) except+

@staticmethod
void registerThreadNameById(uint64_t, const char*) except+
23 changes: 23 additions & 0 deletions src/memray/_thread_name_interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import threading
from typing import Callable


class ThreadNameInterceptor:
"""Record the name of each threading.Thread for Memray's reports.

The name can be set either before or after the thread is started, and from
either the same thread or a different thread. Whenever an assignment to
either `Thread._name` or `Thread._ident` is performed and the other has
already been set, we call a callback with the thread's ident and name.
"""

def __init__(self, attr: str, callback: Callable[[int, str], None]) -> None:
self._attr = attr
self._callback = callback

def __set__(self, instance: threading.Thread, value: object) -> None:
instance.__dict__[self._attr] = value
ident = instance.__dict__.get("_ident")
name = instance.__dict__.get("_name")
if ident is not None and name is not None:
self._callback(ident, name)
14 changes: 14 additions & 0 deletions src/memray/reporters/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Union

from memray._memray import AllocationRecord
from memray._memray import TemporalAllocationRecord


def format_thread_name(
record: Union[AllocationRecord, TemporalAllocationRecord]
) -> str:
if record.tid == -1:
return "merged thread"
name = record.thread_name
thread_id = hex(record.tid)
return f"{thread_id} ({name})" if name else f"{thread_id}"
7 changes: 4 additions & 3 deletions src/memray/reporters/flamegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from memray import Metadata
from memray._memray import Interval
from memray._memray import TemporalAllocationRecord
from memray.reporters.common import format_thread_name
from memray.reporters.frame_tools import StackFrame
from memray.reporters.frame_tools import is_cpython_internal
from memray.reporters.frame_tools import is_frame_from_import_system
Expand Down Expand Up @@ -263,21 +264,21 @@ def _from_any_snapshot(

unique_threads: Set[str] = set()
for record in allocations:
unique_threads.add(record.thread_name)
unique_threads.add(format_thread_name(record))

record_data: RecordData
if temporal:
assert isinstance(record, TemporalAllocationRecord)
record_data = {
"thread_name": record.thread_name,
"thread_name": format_thread_name(record),
"intervals": record.intervals,
"size": None,
"n_allocations": None,
}
else:
assert not isinstance(record, TemporalAllocationRecord)
record_data = {
"thread_name": record.thread_name,
"thread_name": format_thread_name(record),
"intervals": None,
"size": record.size,
"n_allocations": record.n_allocations,
Expand Down
3 changes: 2 additions & 1 deletion src/memray/reporters/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from memray import AllocatorType
from memray import MemorySnapshot
from memray import Metadata
from memray.reporters.common import format_thread_name
from memray.reporters.templates import render_report


Expand Down Expand Up @@ -47,7 +48,7 @@ def from_snapshot(
allocator = AllocatorType(record.allocator)
result.append(
{
"tid": record.thread_name,
"tid": format_thread_name(record),
"size": record.size,
"allocator": allocator.name.lower(),
"n_allocations": record.n_allocations,
Expand Down
3 changes: 2 additions & 1 deletion src/memray/reporters/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from memray import AllocatorType
from memray import MemorySnapshot
from memray import Metadata
from memray.reporters.common import format_thread_name

Location = Tuple[str, str]

Expand Down Expand Up @@ -117,7 +118,7 @@ def render_as_csv(
record.n_allocations,
record.size,
record.tid,
record.thread_name,
format_thread_name(record),
"|".join(f"{func};{mod};{line}" for func, mod, line in stack_trace),
]
)
3 changes: 2 additions & 1 deletion src/memray/reporters/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from memray.reporters._textual_hacks import Bindings
from memray.reporters._textual_hacks import redraw_footer
from memray.reporters._textual_hacks import update_key_description
from memray.reporters.common import format_thread_name
from memray.reporters.frame_tools import is_cpython_internal
from memray.reporters.frame_tools import is_frame_from_import_system
from memray.reporters.frame_tools import is_frame_interesting
Expand Down Expand Up @@ -476,7 +477,7 @@ def from_snapshot(
current_frame = current_frame.children[stack_frame]
current_frame.value += size
current_frame.n_allocations += record.n_allocations
current_frame.thread_id = record.thread_name
current_frame.thread_id = format_thread_name(record)

if index > MAX_STACKS:
break
Expand Down
10 changes: 7 additions & 3 deletions src/memray/reporters/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def __init__(self, pid: Optional[int], cmd_line: Optional[str], native: bool):
self.pid = pid
self.cmd_line = cmd_line
self.native = native
self._seen_threads: Set[int] = set()
self._name_by_tid: Dict[int, str] = {}
self._max_memory_seen = 0
self._merge_threads = True
super().__init__()
Expand Down Expand Up @@ -575,6 +575,9 @@ def _populate_header_thread_labels(self, thread_idx: int) -> None:
else:
tid_label = f"[b]TID[/]: {hex(self.current_thread)}"
thread_label = f"[b]Thread[/] {thread_idx + 1} of {len(self.threads)}"
thread_name = self._name_by_tid.get(self.current_thread)
if thread_name:
thread_label += f" ({thread_name})"

self.query_one("#tid", Label).update(tid_label)
self.query_one("#thread", Label).update(thread_label)
Expand Down Expand Up @@ -662,8 +665,9 @@ def display_snapshot(self) -> None:
if self.paused:
return

new_tids = {record.tid for record in snapshot.records} - self._seen_threads
self._seen_threads.update(new_tids)
name_by_tid = {record.tid: record.thread_name for record in snapshot.records}
new_tids = name_by_tid.keys() - self._name_by_tid.keys()
self._name_by_tid.update(name_by_tid)

if new_tids:
threads = self.threads
Expand Down
52 changes: 51 additions & 1 deletion tests/integration/test_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,54 @@ def allocating_function():
assert len(vallocs) == 1
(valloc,) = vallocs
assert valloc.size == 1234
assert "my thread name" in valloc.thread_name
assert "my thread name" == valloc.thread_name


def test_setting_python_thread_name(tmpdir):
# GIVEN
output = Path(tmpdir) / "test.bin"
allocator = MemoryAllocator()
name_set_inside_thread = threading.Event()
name_set_outside_thread = threading.Event()
prctl_rc = -1

def allocating_function():
allocator.valloc(1234)
allocator.free()

threading.current_thread().name = "set inside thread"
allocator.valloc(1234)
allocator.free()

name_set_inside_thread.set()
name_set_outside_thread.wait()
allocator.valloc(1234)
allocator.free()

nonlocal prctl_rc
prctl_rc = set_thread_name("set by prctl")
allocator.valloc(1234)
allocator.free()

# WHEN
with Tracker(output):
t = threading.Thread(target=allocating_function, name="set before start")
t.start()
name_set_inside_thread.wait()
t.name = "set outside running thread"
name_set_outside_thread.set()
t.join()

# THEN
expected_names = [
"set before start",
"set inside thread",
"set outside running thread",
"set by prctl" if prctl_rc == 0 else "set outside running thread",
]
names = [
rec.thread_name
for rec in FileReader(output).get_allocation_records()
if rec.allocator == AllocatorType.VALLOC
]
assert names == expected_names
Loading
Loading