Skip to content

Commit

Permalink
feat(record): memory profiling
Browse files Browse the repository at this point in the history
* tracemalloc

* pympler

* todo

* changed position of tracemalloc stats collection

* updated requirements.txt

* memory leak fix and cleanup

* removed todo

* changed printing to logging

* alphabetical order

* changes to tracemalloc usage

* plot memory usage

* memory writer terminates with performance writer

* add MemoryStat table to database

* remove todo

* switch from writing/reading memory using file to saving/retrieving from database

* add memory legend to performance plot

* prevent error from child processes terminating

* style changes

* moved PLOT_PERFORMANCE to config.py

* only display memory legend if there is memory data

* moved memory logging into function

* removed unnecessary call to row2dicts

* rename memory_usage to memory_usage_bytes

* replaced alembic revision

* remove start_time_deltas; minor refactor

* fix indent

---------

Co-authored-by: Krish Patel <[email protected]>
Co-authored-by: Richard Abrich <[email protected]>
Co-authored-by: Richard Abrich <[email protected]>
  • Loading branch information
4 people authored Jul 2, 2023
1 parent 2bb8814 commit ef0d5eb
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 56 deletions.
35 changes: 35 additions & 0 deletions alembic/versions/607d1380b5ae_add_memorystat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""add MemoryStat
Revision ID: 607d1380b5ae
Revises: 104d4a614d95
Create Date: 2023-06-28 11:54:36.749072
"""
from alembic import op
import sqlalchemy as sa
from openadapt.models import ForceFloat


# revision identifiers, used by Alembic.
revision = '607d1380b5ae'
down_revision = '104d4a614d95'
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('memory_stat',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('recording_timestamp', sa.Integer(), nullable=True),
sa.Column('memory_usage_bytes', ForceFloat(precision=10, scale=2, asdecimal=False), nullable=True),
sa.Column('timestamp', ForceFloat(precision=10, scale=2, asdecimal=False), nullable=True),
sa.PrimaryKeyConstraint('id', name=op.f('pk_memory_stat'))
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('memory_stat')
# ### end Alembic commands ###
1 change: 1 addition & 0 deletions openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
"key_vk",
"children",
],
"PLOT_PERFORMANCE": True,
}

# each string in STOP_STRS should only contain strings that don't contain special characters
Expand Down
30 changes: 30 additions & 0 deletions openadapt/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Recording,
WindowEvent,
PerformanceStat,
MemoryStat
)
from openadapt.config import STOP_SEQUENCES

Expand All @@ -18,6 +19,8 @@
screenshots = []
window_events = []
performance_stats = []
memory_stats = []



def _insert(event_data, table, buffer=None):
Expand Down Expand Up @@ -100,6 +103,33 @@ def get_perf_stats(recording_timestamp):
)


def insert_memory_stat(recording_timestamp, memory_usage_bytes, timestamp):
"""
Insert memory stat into db
"""

memory_stat = {
"recording_timestamp": recording_timestamp,
"memory_usage_bytes": memory_usage_bytes,
"timestamp": timestamp,
}
_insert(memory_stat, MemoryStat, memory_stats)


def get_memory_stats(recording_timestamp):
"""
return memory stats for a given recording
"""

return (
db
.query(MemoryStat)
.filter(MemoryStat.recording_timestamp == recording_timestamp)
.order_by(MemoryStat.timestamp)
.all()
)


def insert_recording(recording_data):
db_obj = Recording(**recording_data)
db.add(db_obj)
Expand Down
13 changes: 11 additions & 2 deletions openadapt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,11 @@ def take_screenshot(cls):
sct_img = utils.take_screenshot()
screenshot = Screenshot(sct_img=sct_img)
return screenshot

def crop_active_window(self, action_event):
window_event = action_event.window_event
width_ratio, height_ratio = utils.get_scale_ratios(action_event)

x0 = window_event.left * width_ratio
y0 = window_event.top * height_ratio
x1 = x0 + window_event.width * width_ratio
Expand Down Expand Up @@ -314,3 +314,12 @@ class PerformanceStat(db.Base):
start_time = sa.Column(sa.Integer)
end_time = sa.Column(sa.Integer)
window_id = sa.Column(sa.String)


class MemoryStat(db.Base):
__tablename__ = "memory_stat"

id = sa.Column(sa.Integer, primary_key=True)
recording_timestamp = sa.Column(sa.Integer)
memory_usage_bytes = sa.Column(ForceFloat)
timestamp = sa.Column(ForceFloat)
101 changes: 87 additions & 14 deletions openadapt/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

from collections import namedtuple
from functools import partial
from functools import partial, wraps
from typing import Any, Callable, Dict
import multiprocessing
import os
Expand All @@ -16,16 +16,18 @@
import sys
import threading
import time
import tracemalloc

from loguru import logger
from pympler import tracker
from pynput import keyboard, mouse
import fire
import mss.tools
import psutil

from openadapt import config, crud, utils, window

import functools

Event = namedtuple("Event", ("timestamp", "type", "data"))

EVENT_TYPES = ("screen", "action", "window")
LOG_LEVEL = "INFO"
Expand All @@ -34,16 +36,39 @@
"action": True,
"window": True,
}
PLOT_PERFORMANCE = False

Event = namedtuple("Event", ("timestamp", "type", "data"))

global sequence_detected # Flag to indicate if a stop sequence is detected
PLOT_PERFORMANCE = config.PLOT_PERFORMANCE
NUM_MEMORY_STATS_TO_LOG = 3
STOP_SEQUENCES = config.STOP_SEQUENCES

stop_sequence_detected = False
performance_snapshots = []
tracker = tracker.SummaryTracker()
tracemalloc.start()
utils.configure_logging(logger, LOG_LEVEL)


def collect_stats():
performance_snapshots.append(tracemalloc.take_snapshot())


def log_memory_usage():
assert len(performance_snapshots) == 2, performance_snapshots
first_snapshot, last_snapshot = performance_snapshots
stats = last_snapshot.compare_to(first_snapshot, "lineno")

for stat in stats[:NUM_MEMORY_STATS_TO_LOG]:
new_KiB = stat.size_diff / 1024
total_KiB = stat.size / 1024
new_blocks = stat.count_diff
total_blocks = stat.count
source = stat.traceback.format()[0].strip()
logger.info(f"{source=}")
logger.info(f"\t{new_KiB=} {total_KiB=} {new_blocks=} {total_blocks=}")

trace_str = "\n".join(list(tracker.format_diff()))
logger.info(f"trace_str=\n{trace_str}")


def args_to_str(*args):
return ", ".join(map(str, args))

Expand All @@ -54,7 +79,7 @@ def kwargs_to_str(**kwargs):

def trace(logger):
def decorator(func):
@functools.wraps(func)
@wraps(func)
def wrapper_logging(*args, **kwargs):
func_name = func.__qualname__
func_args = args_to_str(*args)
Expand Down Expand Up @@ -160,6 +185,7 @@ def process_events(
prev_saved_window_timestamp = prev_window_event.timestamp
else:
raise Exception(f"unhandled {event.type=}")
del prev_event
prev_event = event
logger.info("done")

Expand Down Expand Up @@ -470,6 +496,41 @@ def performance_stats_writer(
logger.info("performance stats writer done")


def memory_writer(
recording_timestamp: float, terminate_event: multiprocessing.Event, record_pid: int
):
utils.configure_logging(logger, LOG_LEVEL)
utils.set_start_time(recording_timestamp)
logger.info("Memory writer starting")
signal.signal(signal.SIGINT, signal.SIG_IGN)
process = psutil.Process(record_pid)

while not terminate_event.is_set():
memory_usage_bytes = 0

memory_info = process.memory_info()
rss = memory_info.rss # Resident Set Size: non-swapped physical memory
memory_usage_bytes += rss

for child in process.children(recursive=True):
# after ctrl+c, children may terminate before the next line
try:
child_memory_info = child.memory_info()
except psutil.NoSuchProcess:
continue
child_rss = child_memory_info.rss
rss += child_rss

timestamp = utils.get_timestamp()

crud.insert_memory_stat(
recording_timestamp,
rss,
timestamp,
)
logger.info("Memory writer done")


@trace(logger)
def create_recording(
task_description: str,
Expand Down Expand Up @@ -521,7 +582,7 @@ def on_press(event_q, key, injected):

# stop sequence code
nonlocal stop_sequence_indices
global sequence_detected
global stop_sequence_detected
canonical_key_name = getattr(canonical_key, "name", None)

for i in range(0, len(STOP_SEQUENCES)):
Expand All @@ -547,7 +608,7 @@ def on_press(event_q, key, injected):
# Check if the entire sequence has been entered correctly
if stop_sequence_indices[i] == len(stop_sequence):
logger.info("Stop sequence entered! Stopping recording now.")
sequence_detected = True # Set global flag to end recording
stop_sequence_detected = True

def on_release(event_q, key, injected):
canonical_key = keyboard_listener.canonical(key)
Expand Down Expand Up @@ -694,19 +755,30 @@ def record(
)
perf_stat_writer.start()

if PLOT_PERFORMANCE:
record_pid = os.getpid()
mem_plotter = multiprocessing.Process(
target=memory_writer,
args=(recording_timestamp, terminate_perf_event, record_pid),
)
mem_plotter.start()

# TODO: discard events until everything is ready

global sequence_detected
sequence_detected = False
collect_stats()
global stop_sequence_detected

try:
while not sequence_detected:
while not stop_sequence_detected:
time.sleep(1)

terminate_event.set()
except KeyboardInterrupt:
terminate_event.set()

collect_stats()
log_memory_usage()

logger.info(f"joining...")
keyboard_event_reader.join()
mouse_event_reader.join()
Expand All @@ -720,6 +792,7 @@ def record(
terminate_perf_event.set()

if PLOT_PERFORMANCE:
mem_plotter.join()
utils.plot_performance(recording_timestamp)

logger.info(f"saved {recording_timestamp=}")
Expand Down
Loading

0 comments on commit ef0d5eb

Please sign in to comment.