Skip to content

Commit

Permalink
feat: add progress bar in record.py and visualize.py
Browse files Browse the repository at this point in the history
* run `poetry lock --no-update`

* add alive-progress via poetry and in code

* add progress bar in visualization

* add a check for MAX_EVENT = None

* update the title for the Progress bAr
 (better for USer pov)

* update the requirement.txt

* ran ` black --line-length 80 <file>`
on record.py and visualize.py

* remove all progress bar from record

* add tqdm progress bar in recrod.py

* add tqdm for visualiztion

* remove alive-progress

* consistent tqdm api

--add dynamic_cols: to enable adjustments when window is resized

Order:
--total
-description
--unit
--Optional[bar_format]
--colour
--dynamic_ncols

* Update requirements.txt

Co-authored-by: Aaron <[email protected]>

* Address comemnt:
#318 (comment)

* remove incorrect indent

* remove rows

* try to fix distorted table in html

* add custom queue class

* lint --line-length 80

* fix `NotImplementedError` for MacOs
-- using custom MyQueue class

* rename custom -> thirdparty_customization

* rename to something useful

* address comments

* rename dir to customized_imports

* rename to extensions
#318 (comment)

---------

Co-authored-by: Aaron <[email protected]>
  • Loading branch information
KrishPatel13 and 0dm authored Jul 3, 2023
1 parent d15f683 commit 3e12fd4
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 83 deletions.
99 changes: 99 additions & 0 deletions openadapt/extensions/synchronized_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
Module for customizing multiprocessing.Queue to avoid NotImplementedError in MacOS
"""


from multiprocessing.queues import Queue
import multiprocessing

# Credit: https://gist.github.com/FanchenBao/d8577599c46eab1238a81857bb7277c9

# The following implementation of custom SynchronizedQueue to avoid NotImplementedError
# when calling queue.qsize() in MacOS X comes almost entirely from this github
# discussion: https://github.com/keras-team/autokeras/issues/368
# Necessary modification is made to make the code compatible with Python3.


class SharedCounter(object):
""" A synchronized shared counter.
The locking done by multiprocessing.Value ensures that only a single
process or thread may read or write the in-memory ctypes object. However,
in order to do n += 1, Python performs a read followed by a write, so a
second process may read the old value before the new one is written by the
first process. The solution is to use a multiprocessing.Lock to guarantee
the atomicity of the modifications to Value.
This class comes almost entirely from Eli Bendersky's blog:
http://eli.thegreenplace.net/2012/01/04/
shared-counter-with-pythons-multiprocessing/
"""

def __init__(self, n=0):
self.count = multiprocessing.Value('i', n)

def increment(self, n=1):
""" Increment the counter by n (default = 1) """
with self.count.get_lock():
self.count.value += n

@property
def value(self):
""" Return the value of the counter """
return self.count.value


class SynchronizedQueue(Queue):
""" A portable implementation of multiprocessing.Queue.
Because of multithreading / multiprocessing semantics, Queue.qsize() may
raise the NotImplementedError exception on Unix platforms like Mac OS X
where sem_getvalue() is not implemented. This subclass addresses this
problem by using a synchronized shared counter (initialized to zero) and
increasing / decreasing its value every time the put() and get() methods
are called, respectively. This not only prevents NotImplementedError from
being raised, but also allows us to implement a reliable version of both
qsize() and empty().
Note the implementation of __getstate__ and __setstate__ which help to
serialize SynchronizedQueue when it is passed between processes. If these functions
are not defined, SynchronizedQueue cannot be serialized, which will lead to the error
of "AttributeError: 'SynchronizedQueue' object has no attribute 'size'".
See the answer provided here: https://stackoverflow.com/a/65513291/9723036
For documentation of using __getstate__ and __setstate__
to serialize objects, refer to here:
https://docs.python.org/3/library/pickle.html#pickling-class-instances
"""

def __init__(self):
super().__init__(ctx=multiprocessing.get_context())
self.size = SharedCounter(0)

def __getstate__(self):
"""Help to make SynchronizedQueue instance serializable.
Note that we record the parent class state, which is the state of the
actual queue, and the size of the queue, which is the state of SynchronizedQueue.
self.size is a SharedCounter instance. It is itself serializable.
"""
return {
'parent_state': super().__getstate__(),
'size': self.size,
}

def __setstate__(self, state):
super().__setstate__(state['parent_state'])
self.size = state['size']

def put(self, *args, **kwargs):
super().put(*args, **kwargs)
self.size.increment(1)

def get(self, *args, **kwargs):
item = super().get(*args, **kwargs)
self.size.increment(-1)
return item

def qsize(self):
""" Reliable implementation of multiprocessing.Queue.qsize() """
return self.size.value

def empty(self):
""" Reliable implementation of multiprocessing.Queue.empty() """
return not self.qsize()
87 changes: 67 additions & 20 deletions openadapt/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
from loguru import logger
from pympler import tracker
from pynput import keyboard, mouse
from tqdm import tqdm
import fire
import mss.tools
import psutil

from openadapt import config, crud, utils, window
from openadapt.extensions import synchronized_queue as sq

Event = namedtuple("Event", ("timestamp", "type", "data"))

Expand Down Expand Up @@ -86,7 +88,9 @@ def wrapper_logging(*args, **kwargs):
func_kwargs = kwargs_to_str(**kwargs)

if func_kwargs != "":
logger.info(f" -> Enter: {func_name}({func_args}, {func_kwargs})")
logger.info(
f" -> Enter: {func_name}({func_args}, {func_kwargs})"
)
else:
logger.info(f" -> Enter: {func_name}({func_args})")

Expand All @@ -110,10 +114,10 @@ def process_event(event, write_q, write_fn, recording_timestamp, perf_q):
@trace(logger)
def process_events(
event_q: queue.Queue,
screen_write_q: multiprocessing.Queue,
action_write_q: multiprocessing.Queue,
window_write_q: multiprocessing.Queue,
perf_q: multiprocessing.Queue,
screen_write_q: sq.SynchronizedQueue,
action_write_q: sq.SynchronizedQueue,
window_write_q: sq.SynchronizedQueue,
perf_q: sq.SynchronizedQueue,
recording_timestamp: float,
terminate_event: multiprocessing.Event,
):
Expand Down Expand Up @@ -193,7 +197,7 @@ def process_events(
def write_action_event(
recording_timestamp: float,
event: Event,
perf_q: multiprocessing.Queue,
perf_q: sq.SynchronizedQueue,
):
"""
Write an action event to the database and update the performance queue.
Expand All @@ -212,7 +216,7 @@ def write_action_event(
def write_screen_event(
recording_timestamp: float,
event: Event,
perf_q: multiprocessing.Queue,
perf_q: sq.SynchronizedQueue,
):
"""
Write a screen event to the database and update the performance queue.
Expand All @@ -234,7 +238,7 @@ def write_screen_event(
def write_window_event(
recording_timestamp: float,
event: Event,
perf_q: multiprocessing.Queue,
perf_q: sq.SynchronizedQueue,
):
"""
Write a window event to the database and update the performance queue.
Expand All @@ -254,10 +258,11 @@ def write_window_event(
def write_events(
event_type: str,
write_fn: Callable,
write_q: multiprocessing.Queue,
perf_q: multiprocessing.Queue,
write_q: sq.SynchronizedQueue,
perf_q: sq.SynchronizedQueue,
recording_timestamp: float,
terminate_event: multiprocessing.Event,
term_pipe: multiprocessing.Pipe,
):
"""
Write events of a specific type to the db using the provided write function.
Expand All @@ -269,20 +274,48 @@ def write_events(
perf_q: A queue for collecting performance data.
recording_timestamp: The timestamp of the recording.
terminate_event: An event to signal the termination of the process.
term_pipe: A pipe for communicating \
the number of events left to be written.
"""

utils.configure_logging(logger, LOG_LEVEL)
utils.set_start_time(recording_timestamp)
logger.info(f"{event_type=} starting")
signal.signal(signal.SIGINT, signal.SIG_IGN)
while not terminate_event.is_set() or not write_q.empty():

num_left = 0
progress = None
while (
not terminate_event.is_set() or
not write_q.empty()
):
if term_pipe.poll():
num_left = term_pipe.recv()
if num_left != 0 and progress is None:
progress = tqdm(
total=num_left,
desc="Writing to Database",
unit="event",
colour="green",
dynamic_ncols=True,
)
if (
terminate_event.is_set() and
num_left != 0 and
progress is not None
):
progress.update()
try:
event = write_q.get_nowait()
except queue.Empty:
continue
assert event.type == event_type, (event_type, event)
write_fn(recording_timestamp, event, perf_q)
logger.debug(f"{event_type=} written")

if progress is not None:
progress.close()

logger.info(f"{event_type=} done")


Expand Down Expand Up @@ -375,15 +408,18 @@ def handle_key(
"vk",
]
attrs = {
f"key_{attr_name}": getattr(key, attr_name, None) for attr_name in attr_names
f"key_{attr_name}": getattr(key, attr_name, None)
for attr_name in attr_names
}
logger.debug(f"{attrs=}")
canonical_attrs = {
f"canonical_key_{attr_name}": getattr(canonical_key, attr_name, None)
for attr_name in attr_names
}
logger.debug(f"{canonical_attrs=}")
trigger_action_event(event_q, {"name": event_name, **attrs, **canonical_attrs})
trigger_action_event(
event_q, {"name": event_name, **attrs, **canonical_attrs}
)


def read_screen_events(
Expand Down Expand Up @@ -463,7 +499,7 @@ def read_window_events(

@trace(logger)
def performance_stats_writer(
perf_q: multiprocessing.Queue,
perf_q: sq.SynchronizedQueue,
recording_timestamp: float,
terminate_event: multiprocessing.Event,
):
Expand Down Expand Up @@ -660,13 +696,17 @@ def record(
recording_timestamp = recording.timestamp

event_q = queue.Queue()
screen_write_q = multiprocessing.Queue()
action_write_q = multiprocessing.Queue()
window_write_q = multiprocessing.Queue()
screen_write_q = sq.SynchronizedQueue()
action_write_q = sq.SynchronizedQueue()
window_write_q = sq.SynchronizedQueue()
# TODO: save write times to DB; display performance plot in visualize.py
perf_q = multiprocessing.Queue()
perf_q = sq.SynchronizedQueue()
terminate_event = multiprocessing.Event()


term_pipe_parent_window, term_pipe_child_window = multiprocessing.Pipe()
term_pipe_parent_screen, term_pipe_child_screen = multiprocessing.Pipe()
term_pipe_parent_action, term_pipe_child_action = multiprocessing.Pipe()

window_event_reader = threading.Thread(
target=read_window_events,
args=(event_q, terminate_event, recording_timestamp),
Expand Down Expand Up @@ -714,6 +754,7 @@ def record(
perf_q,
recording_timestamp,
terminate_event,
term_pipe_child_screen,
),
)
screen_event_writer.start()
Expand All @@ -727,6 +768,7 @@ def record(
perf_q,
recording_timestamp,
terminate_event,
term_pipe_child_action,
),
)
action_event_writer.start()
Expand All @@ -740,6 +782,7 @@ def record(
perf_q,
recording_timestamp,
terminate_event,
term_pipe_child_window,
),
)
window_event_writer.start()
Expand Down Expand Up @@ -776,9 +819,14 @@ def record(
except KeyboardInterrupt:
terminate_event.set()


collect_stats()
log_memory_usage()

term_pipe_parent_window.send(window_write_q.qsize())
term_pipe_parent_action.send(action_write_q.qsize())
term_pipe_parent_screen.send(screen_write_q.qsize())

logger.info(f"joining...")
keyboard_event_reader.join()
mouse_event_reader.join()
Expand All @@ -788,7 +836,6 @@ def record(
screen_event_writer.join()
action_event_writer.join()
window_event_writer.join()

terminate_perf_event.set()

if PLOT_PERFORMANCE:
Expand Down
1 change: 1 addition & 0 deletions openadapt/scripts/scrub.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def scrub_mp4(
unit="frame",
bar_format=progress_bar_format,
colour="green",
dynamic_ncols=True,
)
progress_interval = 0.1 # Print progress every 10% of frames
progress_threshold = math.floor(frame_count * progress_interval)
Expand Down
Loading

0 comments on commit 3e12fd4

Please sign in to comment.