diff --git a/README.md b/README.md index da3fec700..ce00e9d60 100644 --- a/README.md +++ b/README.md @@ -35,9 +35,8 @@ with the power of Large Multimodal Models (LMMs) by: - Recording screenshots and associated user input - Aggregating and visualizing user input and recordings for development - Converting screenshots and user input into tokenized format -- Generating synthetic input via transformer model completions -- Generating task trees by analyzing recordings (work-in-progress) -- Replaying synthetic input to complete tasks (work-in-progress) +- Generating and replaying synthetic input via transformer model completions +- Generating process graphs by analyzing recording logs (work-in-progress) The goal is similar to that of [Robotic Process Automation](https://en.wikipedia.org/wiki/Robotic_process_automation), @@ -165,37 +164,6 @@ pointing the cursor and left or right clicking, as described in this [open issue](https://github.com/OpenAdaptAI/OpenAdapt/issues/145) -### Capturing Browser Events - -To capture (record) browser events in Chrome, follow these steps: - -1. Go to: [Chrome Extension Page](chrome://extensions/) - -2. Enable `Developer mode` (located at the top right): - -![image](https://github.com/OpenAdaptAI/OpenAdapt/assets/65433817/c97eb9fb-05d6-465d-85b3-332694556272) - -3. Click `Load unpacked` (located at the top left). - -![image](https://github.com/OpenAdaptAI/OpenAdapt/assets/65433817/00c8adf5-074a-4655-b132-fd87644007fc) - -4. Select the `chrome_extension` directory: - -![image](https://github.com/OpenAdaptAI/OpenAdapt/assets/65433817/71610ed3-f8d4-431a-9a22-d901127b7b0c) - -5. You should see the following confirmation, indicating that the extension is loaded: - -![image](https://github.com/OpenAdaptAI/OpenAdapt/assets/65433817/7ee19da9-37e0-448f-b9ab-08ef99110e85) - -6. Set the flag to `true` if it is currently `false`: - -![image](https://github.com/user-attachments/assets/8eba24a3-7c68-4deb-8fbe-9d03cece1482) - -7. Start recording. 
Once recording begins, navigate to the Chrome browser, browse some pages, and perform a few clicks. Then, stop the recording and let it complete successfully. - -8. After recording, check the `openadapt.db` table `browser_event`. It should contain all your browser activity logs. You can verify the data's correctness using the `sqlite3` CLI or an extension like `SQLite Viewer` in VS Code to open `data/openadapt.db`. - - ### Visualize Quickly visualize the latest recording you created by running the following command: @@ -243,6 +211,7 @@ Other replay strategies include: - [`StatefulReplayStrategy`](https://github.com/OpenAdaptAI/OpenAdapt/blob/main/openadapt/strategies/stateful.py): Early proof-of-concept which uses the OpenAI GPT-4 API with prompts constructed via OS-level window data. - (*)[`VisualReplayStrategy`](https://github.com/OpenAdaptAI/OpenAdapt/blob/main/openadapt/strategies/visual.py): Uses [Fast Segment Anything Model (FastSAM)](https://github.com/CASIA-IVA-Lab/FastSAM) to segment active window. - (*)[`VanillaReplayStrategy`](https://github.com/OpenAdaptAI/OpenAdapt/blob/main/openadapt/strategies/vanilla.py): Assumes the model is capable of directly reasoning on states and actions accurately. With future frontier models, we hope that this script will suddenly work a lot better. +- (*)[`VisualBrowserReplayStrategy`](https://github.com/OpenAdaptAI/OpenAdapt/blob/main/openadapt/strategies/visual_browser.py): Like VisualReplayStrategy but generates segments from the visible DOM read by the browser extension. The (*) prefix indicates strategies which accept an "instructions" parameter that is used to modify the recording, e.g.: @@ -253,6 +222,22 @@ python -m openadapt.replay VanillaReplayStrategy --instructions "calculate 9-8" See https://github.com/OpenAdaptAI/OpenAdapt/tree/main/openadapt/strategies for a complete list. More ReplayStrategies coming soon! (see [Contributing](#Contributing)). 
+### Browser integration + +To record browser events in Google Chrome (required by the `VisualBrowserReplayStrategy`), follow these steps: + +1. Go to your Chrome extensions page by entering [chrome://extensions](chrome://extensions/) in your address bar. + +2. Enable `Developer mode` (located at the top right). + +3. Click `Load unpacked` (located at the top left). + +4. Select the `chrome_extension` directory in the OpenAdapt repo. + +5. Make sure the Chrome extension is enabled (the switch to the right of the OpenAdapt extension widget is turned on). + +6. Set the `RECORD_BROWSER_EVENTS` flag to `true` in `openadapt/data/config.json`. + ## Features ### State-of-the-art GUI understanding via [Segment Anything in High Quality](https://github.com/SysCV/sam-hq): @@ -306,13 +291,6 @@ We're looking forward to your contributions. Let's build the future 🚀 ## Contributing -### Notable Works-in-progress (incomplete, see https://github.com/OpenAdaptAI/OpenAdapt/pulls and https://github.com/OpenAdaptAI/OpenAdapt/issues/ for more) - -- [Video Recording Hardware Acceleration](https://github.com/OpenAdaptAI/OpenAdapt/issues/570) (help wanted) - -- [Audio Narration](https://github.com/OpenAdaptAI/OpenAdapt/pull/346) (help wanted) - -- [Chrome Extension](https://github.com/OpenAdaptAI/OpenAdapt/pull/364) (help wanted) - -- [Gemini Vision](https://github.com/OpenAdaptAI/OpenAdapt/issues/551) (help wanted) - ### Replay Problem Statement Our goal is to automate the task described and demonstrated in a `Recording`. diff --git a/chrome_extension/background.js b/chrome_extension/background.js index a747b8669..24e6203fb 100644 --- a/chrome_extension/background.js +++ b/chrome_extension/background.js @@ -1,33 +1,28 @@ /** * @file background.js - * @description Creates a new background script that listens for messages from the content script - * and sends them to a WebSocket server. -*/ + * @description Background script that maintains the current mode and communicates with content scripts. 
+ */ let socket; +let currentMode = null; // Maintain the current mode here let timeOffset = 0; // Global variable to store the time offset -/* - * TODO: - * Ideally we read `WS_SERVER_PORT`, `WS_SERVER_ADDRESS` and - * `RECONNECT_TIMEOUT_INTERVAL` from config.py, - * or it gets passed in somehow. -*/ +/* + * Note: these need to match the corresponding values in config[.defaults].json + */ let RECONNECT_TIMEOUT_INTERVAL = 1000; // ms let WS_SERVER_PORT = 8765; let WS_SERVER_ADDRESS = "localhost"; let WS_SERVER_URL = "ws://" + WS_SERVER_ADDRESS + ":" + WS_SERVER_PORT; - function socketSend(socket, message) { console.log({ message }); socket.send(JSON.stringify(message)); } - /* * Function to connect to the WebSocket server. -*/ + */ function connectWebSocket() { socket = new WebSocket(WS_SERVER_URL); @@ -38,11 +33,34 @@ function connectWebSocket() { socket.onmessage = function(event) { console.log("Message from server:", event.data); const message = JSON.parse(event.data); + + // Handle mode messages + if (message.type === 'SET_MODE') { + currentMode = message.mode; // Update the current mode + console.log(`Mode set to: ${currentMode}`); + + // Send the mode to all active tabs + chrome.tabs.query( + { + active: true, + }, + function(tabs) { + tabs.forEach(function(tab) { + chrome.tabs.sendMessage(tab.id, message, function(response) { + if (chrome.runtime.lastError) { + console.error("Error sending message to content script in tab " + tab.id, chrome.runtime.lastError.message); + } else { + console.log("Message sent to content script in tab " + tab.id, response); + } + }); + }); + } + ); + } }; socket.onclose = function(event) { console.log("WebSocket connection closed", event); - // Reconnect after 5 seconds if the connection is lost setTimeout(connectWebSocket, RECONNECT_TIMEOUT_INTERVAL); }; @@ -66,3 +84,32 @@ chrome.runtime.onMessage.addListener((message, sender, sendResponse) => { sendResponse({ status: "WebSocket connection not open" }); } }); + +/* Listen for 
tab activation */ +chrome.tabs.onActivated.addListener((activeInfo) => { + // Send current mode to the newly active tab if it's not null + if (currentMode) { + const message = { type: 'SET_MODE', mode: currentMode }; + chrome.tabs.sendMessage(activeInfo.tabId, message, function(response) { + if (chrome.runtime.lastError) { + console.error("Error sending message to content script in tab " + activeInfo.tabId, chrome.runtime.lastError.message); + } else { + console.log("Message sent to content script in tab " + activeInfo.tabId, response); + } + }); + } +}); + +/* Listen for tab updates to handle new pages or reloading */ +chrome.tabs.onUpdated.addListener((tabId, changeInfo, tab) => { + if (changeInfo.status === 'complete' && currentMode) { + const message = { type: 'SET_MODE', mode: currentMode }; + chrome.tabs.sendMessage(tabId, message, function(response) { + if (chrome.runtime.lastError) { + console.error("Error sending message to content script in tab " + tabId, chrome.runtime.lastError.message); + } else { + console.log("Message sent to content script in tab " + tabId, response); + } + }); + } +}); diff --git a/chrome_extension/content.js b/chrome_extension/content.js index a08daabb8..79c95c42e 100644 --- a/chrome_extension/content.js +++ b/chrome_extension/content.js @@ -1,4 +1,151 @@ const DEBUG = true; + +if (!DEBUG) { + console.debug = function() {}; +} + +let currentMode = "idle"; // Default mode is 'idle' +let recordListenersAttached = false; // Track if record listeners are currently attached +let replayObserversAttached = false; // Track if replay observers are currently attached + +function setMode(mode) { + currentMode = mode; + console.log(`Mode set to: ${currentMode}`); + + // Attach or detach listeners based on mode + if (currentMode === 'record') { + if (!recordListenersAttached) attachRecordListeners(); + if (replayObserversAttached) disconnectReplayObservers(); // Detach replay observers if needed + } else if (currentMode === 'replay') { + 
debounceSendVisibleHTML('setmode'); + if (!replayObserversAttached) attachReplayObservers(); + if (recordListenersAttached) detachRecordListeners(); // Detach record listeners if needed + } else if (currentMode === 'idle') { + if (recordListenersAttached) detachRecordListeners(); + if (replayObserversAttached) disconnectReplayObservers(); + } +} + +// Listen for messages from the background script or Python process +chrome.runtime.onMessage.addListener((message, sender, sendResponse) => { + console.log("Received message:", message); + if (message.type === 'SET_MODE') { + setMode(message.mode); + } +}); + +// Attach event listeners for recording mode +function attachRecordListeners() { + if (!recordListenersAttached) { + attachUserEventListeners(); + attachInstrumentationEventListeners(); + recordListenersAttached = true; + } +} + +/** + * Attach event listeners for user-generated events, with specific capturing behavior. + */ +function attachUserEventListeners() { + const eventTargetMap = { + 'click': document.body, + 'keydown': document.body, + 'keyup': document.body, + 'mousemove': document.body, + 'scroll': document, + }; + + const eventDebounceDelayMap = { + 'click': 0, // No debounce + 'keydown': 0, // No debounce + 'keyup': 0, // No debounce + 'mousemove': 100, // 100ms debounce + 'scroll': 100, // 100ms debounce + }; + + const lastEventTimeMap = new Map(); + + // Attach event listeners + Object.entries(eventTargetMap).forEach(([eventType, target]) => { + target.addEventListener(eventType, (event) => { + const debounceDelay = eventDebounceDelayMap[eventType]; + const lastEventTime = lastEventTimeMap.get(eventType) || 0; + const now = Date.now(); + + if (now - lastEventTime >= debounceDelay) { + console.log({ eventType }); + handleUserEvent(event); + lastEventTimeMap.set(eventType, now); + } + }, true); + }); +} + +// Attach instrumentation event listeners +function attachInstrumentationEventListeners() { + const eventsToCapture = [ + 'mousedown', + 'mouseup', 
+ 'mousemove', + ]; + eventsToCapture.forEach(eventType => { + document.body.addEventListener(eventType, trackMouseEvent, true); + }); +} + + +// Detach all event listeners for recording mode +function detachRecordListeners() { + const eventsToCapture = [ + 'click', 'keydown', 'keyup', 'mousedown', 'mouseup', 'mousemove', + ]; + + eventsToCapture.forEach(eventType => { + document.body.removeEventListener(eventType, handleUserEvent, true); + document.body.removeEventListener(eventType, trackMouseEvent, true); + }); + + recordListenersAttached = false; +} + +// Attach observers for replay mode +function attachReplayObservers() { + if (!replayObserversAttached) { + setupIntersectionObserver(); + setupMutationObserver(); + setupScrollAndResizeListeners(); + replayObserversAttached = true; + } +} + +// Disconnect observers for replay mode +function disconnectReplayObservers() { + if (window.intersectionObserverInstance) { + window.intersectionObserverInstance.disconnect(); + } + if (window.mutationObserverInstance) { + window.mutationObserverInstance.disconnect(); + } + window.removeEventListener('scroll', handleScrollEvent, { passive: true }); + window.removeEventListener('resize', handleResizeEvent, { passive: true }); + + replayObserversAttached = false; +} + +// Handle scroll events +function handleScrollEvent(event) { + debounceSendVisibleHTML(event.type); +} + +// Handle resize events +function handleResizeEvent(event) { + debounceSendVisibleHTML(event.type); +} + +/* + * Record + */ + const RETURN_FULL_DOCUMENT = false; const MAX_COORDS = 3; const SET_SCREEN_COORDS = false; @@ -123,19 +270,25 @@ function sendMessageToBackgroundScript(message) { } function generateElementIdAndBbox(element) { + console.debug(`[generateElementIdAndBbox] Processing element: ${element.tagName}`); + // ignore invisible elements if (!isVisible(element)) { + console.debug(`[generateElementIdAndBbox] Element is not visible: ${element.tagName}`); return; } // set id if 
(!elementIdMap.has(element)) { const newId = `elem-${elementIdCounter++}`; + console.debug(`[generateElementIdAndBbox] Generated new ID: ${newId} for element: ${element.tagName}`); elementIdMap.set(element, newId); idToElementMap.set(newId, element); // Reverse mapping element.setAttribute('data-id', newId); } + // TODO: store bounding boxes in a map instead of in DOM attributes + // set client bbox let { top, left, bottom, right } = element.getBoundingClientRect(); let bboxClient = `${top},${left},${bottom},${right}`; @@ -143,6 +296,7 @@ function generateElementIdAndBbox(element) { // set screen bbox if (SET_SCREEN_COORDS) { + // XXX TODO: support in replay mode, or remove altogether ({ top, left, bottom, right } = getScreenCoordinates(element)); if (top == null) { // not enough data points to get screen coordinates @@ -214,17 +368,17 @@ function cleanDomTree(node) { } } -function getVisibleHtmlString() { +function getVisibleHTMLString() { const startTime = performance.now(); // Step 1: Instrument the live DOM with data-id and data-bbox attributes instrumentLiveDomWithBbox(); if (RETURN_FULL_DOCUMENT) { - const visibleHtmlDuration = performance.now() - startTime; - console.log({ visibleHtmlDuration }); - const visibleHtmlString = document.body.outerHTML; - return { visibleHtmlString, visibleHtmlDuration }; + const visibleHTMLDuration = performance.now() - startTime; + console.log({ visibleHTMLDuration }); + const visibleHTMLString = document.body.outerHTML; + return { visibleHTMLString, visibleHTMLDuration }; } // Step 2: Clone the body @@ -234,12 +388,12 @@ function getVisibleHtmlString() { cleanDomTree(clonedBody); // Step 4: Serialize the modified clone to a string - const visibleHtmlString = clonedBody.outerHTML; + const visibleHTMLString = clonedBody.outerHTML; - const visibleHtmlDuration = performance.now() - startTime; - console.log({ visibleHtmlDuration }); + const visibleHTMLDuration = performance.now() - startTime; + console.debug({ visibleHTMLDuration 
}); - return { visibleHtmlString, visibleHtmlDuration }; + return { visibleHTMLString, visibleHTMLDuration }; } /** @@ -277,20 +431,43 @@ function validateCoordinates(event, eventTarget, attrType, coordX, coordY) { } } -function handleUserGeneratedEvent(event) { - const eventTarget = event.target; +let lastScrollPosition = { x: window.scrollX, y: window.scrollY }; + +function handleUserEvent(event) { + let eventTarget = event.target; + + // Fallback to eventTarget.activeElement if event.target is not an HTMLElement + if (!(eventTarget instanceof HTMLElement)) { + console.warn(`Event target is not an HTMLElement: ${eventTarget}, using eventTarget.activeElement instead.`); + eventTarget = eventTarget.activeElement; + } + const eventTargetId = generateElementIdAndBbox(eventTarget); const timestamp = Date.now() / 1000; // Convert to Python-compatible seconds - const { visibleHtmlString, visibleHtmlDuration } = getVisibleHtmlString(); + const { visibleHTMLString, visibleHTMLDuration } = getVisibleHTMLString(); + + // Calculate scroll displacement + const currentScrollX = window.scrollX; + const currentScrollY = window.scrollY; + const scrollDeltaX = currentScrollX - lastScrollPosition.x; + const scrollDeltaY = currentScrollY - lastScrollPosition.y; + + // Update last scroll position + lastScrollPosition = { x: currentScrollX, y: currentScrollY }; + + // Retrieve the last recorded mouse coordinates from coordMappings + const lastMouseClientX = coordMappings.x.client[coordMappings.x.client.length - 1] || -1; + const lastMouseClientY = coordMappings.y.client[coordMappings.y.client.length - 1] || -1; const eventData = { type: 'USER_EVENT', eventType: event.type, targetId: eventTargetId, timestamp: timestamp, - visibleHtmlString, - visibleHtmlDuration, + visibleHTMLString, + visibleHTMLDuration, + devicePixelRatio, }; if (event instanceof KeyboardEvent) { @@ -307,38 +484,196 @@ function handleUserGeneratedEvent(event) { if (SET_SCREEN_COORDS) { validateCoordinates(event, 
eventTarget, 'screen', 'screenX', 'screenY'); } + } else if (event.type == 'scroll') { + eventData.scrollDeltaX = scrollDeltaX; + // negative to match pynput + eventData.scrollDeltaY = -scrollDeltaY; + // Use last known mouse coordinates for scroll events + eventData.clientX = lastMouseClientX; + eventData.clientY = lastMouseClientY; + console.log(JSON.stringify(coordMappings)); + console.log("scroll", { eventData }); } + sendMessageToBackgroundScript(eventData); } -// Attach event listeners for user-generated events -function attachUserEventListeners() { - const eventsToCapture = [ - 'click', - // input events are triggered after the DOM change is written, so we can't use them - // (since the resulting HTML would not look as the DOM was at the time the - // user took the action, i.e. immediately before) - //'input', - 'keydown', - 'keyup', - ]; - eventsToCapture.forEach(eventType => { - document.body.addEventListener(eventType, handleUserGeneratedEvent, true); +/* + * Replay + */ + +let debounceTimeoutId = null; // Timeout ID for debouncing +const DEBOUNCE_DELAY = 10; + +function setupIntersectionObserver() { + const observer = new IntersectionObserver(handleIntersection, { + root: null, // Use the viewport as the root + threshold: 0 // Consider an element visible if any part of it is in view }); + + document.querySelectorAll('*').forEach(element => observer.observe(element)); } -function attachInstrumentationEventListeners() { - const eventsToCapture = [ - 'mousedown', - 'mouseup', - 'mousemove', - ]; - eventsToCapture.forEach(eventType => { - document.body.addEventListener(eventType, trackMouseEvent, true); +function handleIntersection(entries) { + let shouldSendUpdate = false; + entries.forEach(entry => { + if (entry.isIntersecting) { + shouldSendUpdate = true; + } }); + if (shouldSendUpdate) { + debounceSendVisibleHTML('intersection'); + } +} + +function setupMutationObserver() { + const observer = new MutationObserver(handleMutations); + 
observer.observe(document.body, { + childList: true, + // XXX this results in continuous DOM_EVENT messages on some websites (e.g. ChatGPT) + subtree: true, + attributes: true + }); +} + +function handleMutations(mutationsList) { + const startTime = performance.now(); // Capture start time for the instrumentation + console.debug(`[handleMutations] Start handling ${mutationsList.length} mutations at ${startTime}`); + + let shouldSendUpdate = false; + + for (const mutation of mutationsList) { + console.debug(`[handleMutations] Mutation type: ${mutation.type}, target: ${mutation.target.tagName}`); + for (const node of mutation.addedNodes) { + if (node.nodeType === Node.ELEMENT_NODE) { + console.debug(`[handleMutations] Added node: ${node.tagName}`); + + // Uncommenting this freezes some websites (e.g. ChatGPT). + // It should not be necessary to call this here since it is also called in + // getVisibleHTMLString. + //generateElementIdAndBbox(node); // Generate a new ID and bbox for the added node + + if (isVisible(node)) { + shouldSendUpdate = true; + break; // Exit the loop early + } + } + } + if (shouldSendUpdate) break; // Exit outer loop if update is needed + + for (const node of mutation.removedNodes) { + console.log(`[handleMutations] Removed node: ${node.tagName}`); + if (node.nodeType === Node.ELEMENT_NODE && idToElementMap.has(node.getAttribute('data-id'))) { + shouldSendUpdate = true; + break; // Exit the loop early + } + } + if (shouldSendUpdate) break; // Exit outer loop if update is needed + } + + const endTime = performance.now(); + console.debug(`[handleMutations] Finished handling mutations. 
Duration: ${endTime - startTime}ms`); + + if (shouldSendUpdate) { + debounceSendVisibleHTML('mutation'); + } } -// Initial setup -attachUserEventListeners(); -attachInstrumentationEventListeners(); +function debounceSendVisibleHTML(eventType) { + // Clear the previous timeout, if any + if (debounceTimeoutId) { + clearTimeout(debounceTimeoutId); + } + + console.debug(`[debounceSendVisibleHTML] Debouncing visible HTML send for event: ${eventType}`); + // Set a new timeout + debounceTimeoutId = setTimeout(() => { + sendVisibleHTML(eventType); + }, DEBOUNCE_DELAY); +} + +function sendVisibleHTML(eventType) { + console.debug(`Handling DOM event: ${eventType}`); + const timestamp = Date.now() / 1000; // Convert to Python-compatible seconds + + const { visibleHTMLString, visibleHTMLDuration } = getVisibleHTMLString(); + + const eventData = { + type: 'DOM_EVENT', + eventType: eventType, + timestamp: timestamp, + visibleHTMLString, + visibleHTMLDuration, + }; + + sendMessageToBackgroundScript(eventData); +} + +function setupScrollAndResizeListeners() { + window.addEventListener('scroll', handleScrollEvent, { passive: true }); + window.addEventListener('resize', handleResizeEvent, { passive: true }); +} + +/* Debugging */ + +const DEBUG_DRAW = false; // Flag for drawing bounding boxes + +// Start continuous drawing if DEBUG_DRAW is enabled +if (DEBUG_DRAW) { + startDrawingBoundingBoxes(); +} + +/** + * Start continuously drawing bounding boxes for visible elements. 
+ */ +function startDrawingBoundingBoxes() { + function drawBoundingBoxesLoop() { + // Clean up existing bounding boxes before drawing new ones + cleanUpBoundingBoxes(); + + // Query all visible elements and draw their bounding boxes + document.querySelectorAll('*').forEach(element => { + if (isVisible(element)) { + drawBoundingBoxForElement(element); + } + }); + + // Use requestAnimationFrame for continuous updates without performance impact + requestAnimationFrame(drawBoundingBoxesLoop); + } + + // Kick off the loop + drawBoundingBoxesLoop(); +} + +/** + * Draw a bounding box for the given element. + * Uses client coordinates. + * @param {HTMLElement} element - The DOM element to draw the bounding box for. + */ +function drawBoundingBoxForElement(element) { + const { top, left, bottom, right } = element.getBoundingClientRect(); + + // Create and style the overlay to represent the bounding box + let bboxOverlay = document.createElement('div'); + bboxOverlay.style.position = 'absolute'; + bboxOverlay.style.border = '2px solid red'; + bboxOverlay.style.top = `${top + window.scrollY}px`; // Adjust for scrolling + bboxOverlay.style.left = `${left + window.scrollX}px`; // Adjust for scrolling + bboxOverlay.style.width = `${right - left}px`; + bboxOverlay.style.height = `${bottom - top}px`; + bboxOverlay.style.pointerEvents = 'none'; // Prevent interference with normal element interactions + bboxOverlay.style.zIndex = '9999'; // Ensure it's drawn on top + bboxOverlay.setAttribute('data-debug-bbox', element.getAttribute('data-id') || ''); + + // Append the overlay to the body + document.body.appendChild(bboxOverlay); +} + +/** + * Clean up all existing bounding boxes to prevent overlapping or lingering overlays. 
+ */ +function cleanUpBoundingBoxes() { + document.querySelectorAll('[data-debug-bbox]').forEach(overlay => overlay.remove()); +} diff --git a/openadapt/browser.py b/openadapt/browser.py index aa8f1cd6b..f55e734f3 100644 --- a/openadapt/browser.py +++ b/openadapt/browser.py @@ -1,23 +1,34 @@ """Utilities for working with BrowserEvents.""" from statistics import mean, median, stdev +import json from bs4 import BeautifulSoup -from copy import deepcopy from dtaidistance import dtw, dtw_ndim -from loguru import logger from sqlalchemy.orm import Session as SaSession from tqdm import tqdm import numpy as np +import websockets.sync.server from openadapt import models, utils +from openadapt.custom_logger import logger from openadapt.db import crud # action to browser -MOUSE_BUTTON_MAPPING = {"left": 0, "right": 2, "middle": 1} +MOUSE_BUTTON_MAPPING = { + "left": 0, + "right": 2, + "middle": 1, +} # action to browser -EVENT_TYPE_MAPPING = {"click": "click", "press": "keydown", "release": "keyup"} +EVENT_TYPE_MAPPING = { + "click": "click", + "press": "keydown", + "release": "keyup", + "move": "mousemove", + "scroll": "scroll", +} SPATIAL = True @@ -79,95 +90,107 @@ ] -def add_screen_tlbr(browser_events: list[models.BrowserEvent]) -> None: +def set_browser_mode( + mode: str, websocket: websockets.sync.server.ServerConnection +) -> None: + """Send a message to the browser extension to set the mode.""" + logger.info(f"{type(websocket)=}") + VALID_MODES = ("idle", "record", "replay") + assert mode in VALID_MODES, f"{mode=} not in {VALID_MODES=}" + message = json.dumps({"type": "SET_MODE", "mode": mode}) + logger.info(f"sending {message=}") + websocket.send(message) + + +def add_screen_tlbr( + browser_events: list[models.BrowserEvent], target_element_only: bool = False +) -> None: """Computes and adds the 'data-tlbr-screen' attribute for each element. Uses coordMappings provided by JavaScript events. 
If 'data-tlbr-screen' already - exists, compute the values again and assert equality. Reuse the most recent valid - mappings if none exist for the current event by iterating over the events in - reverse order. + exists, compute the values again and assert equality. Reuse the closest valid + mappings (before or after) if none exist for the current event. Args: browser_events (list[models.BrowserEvent]): list of browser events to process. + target_element_only (bool): if True, only process the target element. If False, + process all elements with the 'data-tlbr-client' property. """ - # Initialize variables to store the most recent valid mappings - latest_valid_x_mappings = None - latest_valid_y_mappings = None - - # Iterate over the events in reverse order - for event in reversed(browser_events): - message = event.message - - event_type = message.get("eventType") - if event_type != "click": - continue - - visible_html_string = message.get("visibleHtmlString") - if not visible_html_string: - logger.warning("No visible HTML data available for event.") - continue + n = len(browser_events) + prev_valid = [None] * n + next_valid = [None] * n - # Parse the visible HTML using BeautifulSoup - soup = BeautifulSoup(visible_html_string, "html.parser") - - # Fetch the target element using its data-id - target_id = message.get("targetId") - target_element = soup.find(attrs={"data-id": target_id}) - - if not target_element: - logger.warning(f"No target element found for targetId: {target_id}") - continue - - # Extract coordMappings from the message + def update_valid_mappings(index: int) -> tuple[dict, dict] | None: + """Helper to check if the event at the given index has valid coordMappings.""" + message = browser_events[index].message coord_mappings = message.get("coordMappings", {}) x_mappings = coord_mappings.get("x", {}) y_mappings = coord_mappings.get("y", {}) - # Check if there are sufficient data points; if not, reuse latest valid mappings - if ( + if 
are_mappings_valid(x_mappings, y_mappings): + return (x_mappings, y_mappings) + + return None + + def are_mappings_valid(x_mappings: dict, y_mappings: dict) -> bool: + """Check if the mappings contain sufficient data points.""" + return ( "client" in x_mappings and len(x_mappings["client"]) >= 2 + and "client" in y_mappings and len(y_mappings["client"]) >= 2 - ): - # Update the latest valid mappings - latest_valid_x_mappings = x_mappings - latest_valid_y_mappings = y_mappings - else: - # Reuse the most recent valid mappings - if latest_valid_x_mappings is None or latest_valid_y_mappings is None: - logger.warning( - f"No valid coordinate mappings available for element: {target_id}" - ) - continue # No valid mappings available, skip this event - - x_mappings = latest_valid_x_mappings - y_mappings = latest_valid_y_mappings - - # Compute the scale and offset for both x and y axes - sx_scale, sx_offset = fit_linear_transformation( - x_mappings["client"], x_mappings["screen"] - ) - sy_scale, sy_offset = fit_linear_transformation( - y_mappings["client"], y_mappings["screen"] ) - # Only process the screen coordinates + # Forward pass: Track the closest previous valid mapping + last_valid = None + for i in range(n): + last_valid = update_valid_mappings(i) or last_valid + prev_valid[i] = last_valid + + # Reverse pass: Track the closest next valid mapping + last_valid = None + for i in range(n - 1, -1, -1): + last_valid = update_valid_mappings(i) or last_valid + next_valid[i] = last_valid + + # Forward pass: set clientX/clientY on scroll events that don't have them + prev_client_x = None + prev_client_y = None + for event in browser_events: + if "clientX" not in event.message: + continue + client_x = event.message.get("clientX") + client_y = event.message.get("clientY") + if client_x == -1: + assert prev_client_x is not None + logger.info(f"updating {client_x=} to {prev_client_x=}") + event.message["clientX"] = prev_client_x + else: + prev_client_x = client_x + if client_y == 
-1: + assert prev_client_y is not None + logger.info(f"updating {client_y=} to {prev_client_y=}") + event.message["clientY"] = prev_client_y + else: + prev_client_y = client_y + + def process_element( + element: BeautifulSoup, + sx_scale: float, + sx_offset: float, + sy_scale: float, + sy_offset: float, + ) -> None: + """Helper to compute and update 'data-tlbr-screen' attribute for an element.""" tlbr_attr = "data-tlbr-screen" - try: - # Get existing screen coordinates if present - existing_screen_coords = ( - target_element[tlbr_attr] if tlbr_attr in target_element.attrs else None - ) - except KeyError: - existing_screen_coords = None + existing_screen_coords = element.get(tlbr_attr, None) - # Compute screen coordinates - client_coords = target_element.get("data-tlbr-client") + client_coords = element.get("data-tlbr-client") if not client_coords: logger.warning( - f"Missing client coordinates for element with id: {target_id}" + f"Missing client coordinates for element with id: {element.get('id')}" ) - continue + return # Extract client coordinates client_top, client_left, client_bottom, client_right = map( @@ -182,7 +205,7 @@ def add_screen_tlbr(browser_events: list[models.BrowserEvent]) -> None: # New computed screen coordinates new_screen_coords = f"{screen_top},{screen_left},{screen_bottom},{screen_right}" - logger.info(f"{client_coords=} {existing_screen_coords=} {new_screen_coords=}") + logger.trace(f"{client_coords=} {existing_screen_coords=} {new_screen_coords=}") # Check for existing data-tlbr-screen attribute if existing_screen_coords: @@ -192,10 +215,85 @@ def add_screen_tlbr(browser_events: list[models.BrowserEvent]) -> None: ) # Update the attribute with the new value - target_element["data-tlbr-screen"] = new_screen_coords + element["data-tlbr-screen"] = new_screen_coords + + def process_event( + event: models.BrowserEvent, + sx_scale: float, + sx_offset: float, + sy_scale: float, + sy_offset: float, + ) -> None: + """Helper to process a single 
browser event.""" + try: + soup, target_element = event.parse() + except AssertionError as exc: + logger.warning(exc) + return + + if target_element_only: + process_element(target_element, sx_scale, sx_offset, sy_scale, sy_offset) + else: + elements_to_process = soup.find_all(attrs={"data-tlbr-client": True}) + for element in elements_to_process: + process_element(element, sx_scale, sx_offset, sy_scale, sy_offset) - # Write the updated element back to the message - message["visibleHtmlString"] = str(soup) + # Compute and assign screen coordinates for scroll events + if event.message.get("eventType") == "scroll": + client_x = event.message["clientX"] + client_y = event.message["clientY"] + + screen_x = sx_scale * client_x + sx_offset + screen_y = sy_scale * client_y + sy_offset + + logger.info(f"scroll {client_x=} {client_y=} {screen_x=} {screen_y=}") + + # Assign screen coordinates to the event message + event.message["screenX"] = screen_x + event.message["screenY"] = screen_y + + # Write the updated elements back to the message + event.message["visibleHTMLString"] = str(soup) + + # Process each event, choosing the closest valid mappings + for idx, event in enumerate(browser_events): + # Extract coordMappings from the message + message = event.message + coord_mappings = message.get("coordMappings", {}) + x_mappings = coord_mappings.get("x", {}) + y_mappings = coord_mappings.get("y", {}) + + # Check if there are sufficient data points; if not, use closest valid mappings + if not are_mappings_valid(x_mappings, y_mappings): + # Determine which of prev_valid or next_valid to use + closest_mappings = ( + prev_valid[idx] if prev_valid[idx] is not None else next_valid[idx] + ) + + if closest_mappings is None: + # this means that no mappings are available anywhere, which means + # we can't handle browser events at all + logger.error( + "No valid coordinate mappings available for element in event at" + f" index {idx}" + ) + import ipdb + + ipdb.set_trace() + continue # No 
valid mappings available, skip this event + + x_mappings, y_mappings = closest_mappings + + # Compute the scale and offset for both x and y axes + sx_scale, sx_offset = fit_linear_transformation( + x_mappings["client"], x_mappings["screen"] + ) + sy_scale, sy_offset = fit_linear_transformation( + y_mappings["client"], y_mappings["screen"] + ) + + # Process the event + process_event(event, sx_scale, sx_offset, sy_scale, sy_offset) logger.info("Finished processing all browser events for screen coordinates.") @@ -235,7 +333,7 @@ def identify_and_log_smallest_clicked_element( Args: browser_event: The browser event containing the click details. """ - visible_html_string = browser_event.message.get("visibleHtmlString") + visible_html_string = browser_event.message.get("visibleHTMLString") message_id = browser_event.message.get("id") logger.info("*" * 10) logger.info(f"{message_id=}") @@ -246,8 +344,7 @@ def identify_and_log_smallest_clicked_element( logger.warning("No visible HTML data available for click event.") return - # Parse the visible HTML using BeautifulSoup - soup = BeautifulSoup(visible_html_string, "html.parser") + soup = utils.parse_html(visible_html_string, "html.parser") target_element = soup.find(attrs={"data-id": target_id}) target_area = None if not target_element: @@ -329,17 +426,21 @@ def is_action_event( Args: event: The action event to check. - action_name: The type of action (e.g., "click", "press", "release"). + action_name: The action name (eg "click", "press", "release", "move", "scroll"). key_or_button: The key or button associated with the action. Returns: bool: True if the event matches the action name and key/button, False otherwise. 
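The hunk above computes a per-event scale and offset with `fit_linear_transformation`, whose body is not shown in this diff. Assuming it performs an ordinary least-squares fit of `screen = scale * client + offset` over the sampled coordinate pairs, a stdlib-only sketch (names and signature assumed) could look like:

```python
def fit_linear_transformation(
    client: list[float], screen: list[float]
) -> tuple[float, float]:
    """Least-squares fit of screen = scale * client + offset."""
    n = len(client)
    mean_c = sum(client) / n
    mean_s = sum(screen) / n
    # Closed-form simple linear regression: slope = cov(c, s) / var(c)
    var_c = sum((c - mean_c) ** 2 for c in client)
    cov_cs = sum((c - mean_c) * (s - mean_s) for c, s in zip(client, screen))
    scale = cov_cs / var_c
    offset = mean_s - scale * mean_c
    return scale, offset


# e.g. browser clientX samples vs. the corresponding recorded screenX samples
scale, offset = fit_linear_transformation([0, 100, 200], [10, 210, 410])
print(round(scale, 6), round(offset, 6))  # 2.0 10.0
```

With at least two distinct sample points (the `>= 2` check in `are_mappings_valid`), the fit is exact for a true linear mapping; with more points it averages out jitter in the samples.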
""" - if action_name == "click": + if action_name == "move": + return event.name == action_name + elif action_name == "click": return event.name == action_name and event.mouse_button_name == key_or_button elif action_name in {"press", "release"}: raw_action_text = event._text(name_prefix="", name_suffix="") return event.name == action_name and raw_action_text == key_or_button + elif action_name == "scroll": + return event.name == action_name else: return False @@ -353,13 +454,15 @@ def is_browser_event( Args: event: The browser event to check. - action_name (str): The type of action (e.g., "click", "press", "release"). - key_or_button (str): The key or button associated with the action. + action_name: The action name (eg "click", "press", "release", "move", "scroll"). + key_or_button: The key or button associated with the action. Returns: bool: True if the event matches the action name and key/button, False otherwise. """ - if action_name == "click": + if action_name == "move": + return event.message.get("eventType") == EVENT_TYPE_MAPPING[action_name] + elif action_name == "click": return ( event.message.get("eventType") == EVENT_TYPE_MAPPING[action_name] and event.message.get("button") == MOUSE_BUTTON_MAPPING[key_or_button] @@ -369,6 +472,8 @@ def is_browser_event( event.message.get("eventType") == EVENT_TYPE_MAPPING[action_name] and event.message.get("key").lower() == key_or_button ) + elif action_name == "scroll": + return event.message.get("eventType") == EVENT_TYPE_MAPPING[action_name] else: return False @@ -410,7 +515,18 @@ def align_events( if spatial: # Prepare sequences for multidimensional DTW action_sequence = np.array( - [[e.timestamp, e.mouse_x or 0.0, e.mouse_y or 0.0] for e in action_events], + [ + [ + # TODO: refactor ActionEvent timestamps to be immutable; + # add playback_timestamp to implement current timestamp behavior + e.timestamp, + e.mouse_x or 0.0, + e.mouse_y or 0.0, + e.mouse_dx or 0.0, + e.mouse_dy or 0.0, + ] + for e in action_events + 
], dtype=np.double, ) @@ -420,6 +536,8 @@ def align_events( e.timestamp, e.message.get("screenX", 0.0), e.message.get("screenY", 0.0), + e.message.get("scrollDeltaX", 0.0), + e.message.get("scrollDeltaY", 0.0), ] for e in browser_events ], @@ -446,7 +564,9 @@ def evaluate_alignment( action_events: list, browser_events: list, spatial: bool = SPATIAL, -) -> tuple[int, list[float], list[float], list[float], list[float]]: +) -> tuple[ + int, list[float], list[float], list[float], list[float], list[float], list[float] +]: """Evaluate the alignment between action events and browser events. Args: @@ -465,6 +585,8 @@ def evaluate_alignment( - list[float]: Differences in local timestamps for matched events. - list[float]: Differences in mouse X positions for matched events. - list[float]: Differences in mouse Y positions for matched events. + - list[float]: Differences in mouse dX positions for matched events. + - list[float]: Differences in mouse dY positions for matched events. """ match_count = 0 mismatch_count = 0 @@ -479,11 +601,14 @@ def evaluate_alignment( mouse_y_differences = ( [] ) # To store differences in mouse Y positions for matching events + mouse_dx_differences = [] + mouse_dy_differences = [] logger.info(f"Alignment for {event_type} Events") for i, j in filtered_path: action_event = action_events[i] - browser_event = deepcopy(browser_events[j]) + # browser_event = deepcopy(browser_events[j]) + browser_event = browser_events[j] action_event_type = action_event.name.lower() browser_event_type = browser_event.message["eventType"].lower() @@ -501,7 +626,30 @@ def evaluate_alignment( local_time_differences.append(local_time_difference) # Compute differences between mouse positions - if action_event.mouse_x is not None: + if browser_event.message.get("scrollDeltaX") or browser_event.message.get( + "scrollDeltaY" + ): + # TODO XXX: accumulate differences before comparing to account for + # different sampling rates + mouse_dx_difference = ( + action_event.mouse_dx 
- browser_event.message["scrollDeltaX"] + ) + mouse_dy_difference = ( + action_event.mouse_dy - browser_event.message["scrollDeltaY"] + ) + if mouse_dx_difference > 1: + logger.warning( + f"{mouse_dx_difference=} {action_event.mouse_dx=}" + f" {browser_event.message['scrollDeltaX']=}" + ) + if mouse_dy_difference > 1: + logger.warning( + f"{mouse_dy_difference=} {action_event.mouse_dy=}" + f" {browser_event.message['scrollDeltaY']=}" + ) + mouse_dx_differences.append(mouse_dx_difference) + mouse_dy_differences.append(mouse_dy_difference) + elif action_event.mouse_x is not None: mouse_x_difference = ( action_event.mouse_x - browser_event.message["screenX"] ) @@ -573,6 +721,8 @@ def evaluate_alignment( local_time_differences, mouse_x_differences, mouse_y_differences, + mouse_dx_differences, + mouse_dy_differences, ) @@ -640,12 +790,15 @@ def assign_browser_events( ] add_screen_tlbr(browser_events) + session.add_all(browser_events) # Define event pairs dynamically for mouse events event_pairs = [ ("Left Click", "click", "left"), ("Right Click", "click", "right"), ("Middle Click", "click", "middle"), + ("Mouse Move", "move", ""), + ("Scroll", "scroll", ""), ] # Add keyboard events dynamically @@ -660,6 +813,8 @@ def assign_browser_events( total_local_time_differences = [] total_mouse_x_differences = [] total_mouse_y_differences = [] + total_mouse_dx_differences = [] + total_mouse_dy_differences = [] # Initialize additional statistics event_stats = { @@ -683,6 +838,8 @@ def assign_browser_events( browser_events, ) ) + # if action_name == "move": + # import ipdb; ipdb.set_trace() if action_filtered_events or browser_filtered_events: logger.info( @@ -698,8 +855,10 @@ def assign_browser_events( for i, j in filtered_path: action_event = action_filtered_events[i] browser_event = browser_filtered_events[j] + action_event.browser_event_timestamp = browser_event.timestamp action_event.browser_event_id = browser_event.id + action_event.browser_event = browser_event logger.info( 
f"assigning {action_event.timestamp=} ==>" f" {browser_event.timestamp=}" @@ -714,6 +873,8 @@ def assign_browser_events( local_time_differences, mouse_x_differences, mouse_y_differences, + mouse_dx_differences, + mouse_dy_differences, ) = evaluate_alignment( filtered_path, event_type, @@ -727,6 +888,8 @@ def assign_browser_events( total_local_time_differences += local_time_differences total_mouse_x_differences += mouse_x_differences total_mouse_y_differences += mouse_y_differences + total_mouse_dx_differences += mouse_dx_differences + total_mouse_dy_differences += mouse_dy_differences event_stats["match_count"] += len(filtered_path) event_stats["mismatch_count"] += errors @@ -795,6 +958,8 @@ def assign_browser_events( f" {event_stats['mouse_position_stats'][axis]['stddev']:.4f}" ) + # TODO XXX: Calculate and log statistics for mouse scroll position differences + event_stats["unmatched_browser_events"] = len( [ e diff --git a/openadapt/common.py b/openadapt/common.py index e9debeb31..953477704 100644 --- a/openadapt/common.py +++ b/openadapt/common.py @@ -1,9 +1,13 @@ """This module defines common constants used in OpenAdapt.""" -RAW_MOUSE_EVENTS = ( +RAW_PRECISE_MOUSE_EVENTS = ( "move", "click", - "scroll", +) +# location of cursor doesn't matter as much when scrolling compared to moving/clicking +RAW_IMPRECISE_MOUSE_EVENTS = ("scroll",) +RAW_MOUSE_EVENTS = tuple( + list(RAW_PRECISE_MOUSE_EVENTS) + list(RAW_IMPRECISE_MOUSE_EVENTS) ) FUSED_MOUSE_EVENTS = ( "singleclick", @@ -11,6 +15,7 @@ ) MOUSE_EVENTS = tuple(list(RAW_MOUSE_EVENTS) + list(FUSED_MOUSE_EVENTS)) MOUSE_CLICK_EVENTS = (event for event in MOUSE_EVENTS if event.endswith("click")) +PRECISE_MOUSE_EVENTS = tuple(list(RAW_PRECISE_MOUSE_EVENTS) + list(FUSED_MOUSE_EVENTS)) RAW_KEY_EVENTS = ( "press", diff --git a/openadapt/db/crud.py b/openadapt/db/crud.py index b0a9ff12d..3ebf8a152 100644 --- a/openadapt/db/crud.py +++ b/openadapt/db/crud.py @@ -423,6 +423,7 @@ def get_action_events( .options( 
joinedload(ActionEvent.recording), joinedload(ActionEvent.screenshot), + joinedload(ActionEvent.browser_event), subqueryload(ActionEvent.window_event).joinedload( WindowEvent.action_events ), diff --git a/openadapt/drivers/openai.py b/openadapt/drivers/openai.py index f81e23747..8fcb5a4e3 100644 --- a/openadapt/drivers/openai.py +++ b/openadapt/drivers/openai.py @@ -14,6 +14,12 @@ from openadapt import cache, utils from openadapt.config import config from openadapt.custom_logger import logger +from tokencost import ( + calculate_prompt_cost, + calculate_completion_cost, + count_message_tokens, +) + MODEL_NAME = [ "gpt-4-vision-preview", @@ -23,7 +29,11 @@ # TODO XXX: per model MAX_TOKENS = 4096 # TODO XXX undocumented -MAX_IMAGES = None +MAX_IMAGES = 90 + +# Track total tokens and cost +total_tokens_used = 0 +total_cost = 0.0 def create_payload( @@ -134,6 +144,68 @@ def get_response( return result +def calculate_tokens_and_cost( + result: dict, payload: dict +) -> tuple[int, float, int, float, float]: + """Calculate the tokens and costs for the API call. + + Args: + result: The response from the OpenAI API. + payload: The original request payload. 
+
+    Returns:
+        A tuple containing:
+            - Prompt tokens used
+            - Prompt cost
+            - Completion tokens used
+            - Completion cost
+            - Total cost
+    """  # noqa: D202
+
+    # Extract text content from payload messages
+    def extract_text(messages: list[dict]) -> list[dict]:
+        extracted = []
+        for message in messages:
+            role = message["role"]
+            content_list = message["content"]
+            text_content = "".join(
+                [item["text"] for item in content_list if item["type"] == "text"]
+            )
+            extracted.append({"role": role, "content": text_content})
+        return extracted
+
+    # Extracted messages for token counting
+    prompt_messages = extract_text(payload["messages"])
+    completion_text = result["choices"][0]["message"]["content"]
+
+    # Calculate tokens and costs
+    prompt_tokens = count_message_tokens(prompt_messages, model=payload["model"])
+    completion_tokens = count_message_tokens(
+        [{"role": "assistant", "content": completion_text}], model=payload["model"]
+    )
+
+    prompt_cost = calculate_prompt_cost(prompt_messages, model=payload["model"])
+    completion_cost = calculate_completion_cost(completion_text, model=payload["model"])
+
+    total_cost = prompt_cost + completion_cost
+    return prompt_tokens, prompt_cost, completion_tokens, completion_cost, total_cost
+
+
+def update_total_usage(
+    prompt_tokens: int, completion_tokens: int, api_call_cost: float
+) -> None:
+    """Update the total token and cost counters.
+
+    Args:
+        prompt_tokens: The number of tokens used in the prompt.
+        completion_tokens: The number of tokens used in the completion.
+        api_call_cost: The total cost of the API call.
+    """
+    global total_tokens_used, total_cost
+    total_tokens_used += prompt_tokens + completion_tokens
+    total_cost += float(api_call_cost)  # Convert Decimal to float
+
+
 def get_completion(payload: dict, dev_mode: bool = False) -> str:
     """Sends a request to the OpenAI API and returns the first message.
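The `extract_text` helper above strips the image parts out of multimodal chat messages so that only text reaches the token counter (image tokens are priced differently and are not handled by message-level counters). A self-contained version of the same idea:

```python
def extract_text(messages: list[dict]) -> list[dict]:
    """Keep only the text parts of multimodal chat messages."""
    extracted = []
    for message in messages:
        # Each content entry is a list of typed parts, e.g. text and image_url
        text = "".join(
            part["text"]
            for part in message["content"]
            if part["type"] == "text"
        )
        extracted.append({"role": message["role"], "content": text})
    return extracted


messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe "},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
            {"type": "text", "text": "this screenshot."},
        ],
    },
]
print(extract_text(messages))
# [{'role': 'user', 'content': 'Describe this screenshot.'}]
```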
@@ -153,10 +225,25 @@ def get_completion(payload: dict, dev_mode: bool = False) -> str: import ipdb ipdb.set_trace() - # TODO: handle more errors else: raise exc logger.info(f"result=\n{pformat(result)}") + + # Calculate tokens and cost + prompt_tokens, prompt_cost, completion_tokens, completion_cost, api_call_cost = ( + calculate_tokens_and_cost(result, payload) + ) + + # Update total usage + update_total_usage(prompt_tokens, completion_tokens, api_call_cost) + + # Log the results + logger.info( + f"API Call Tokens: {prompt_tokens + completion_tokens}, Total Tokens Used:" + f" {total_tokens_used}" + ) + logger.info(f"API Call Cost: ${api_call_cost:.6f}, Total Cost: ${total_cost:.6f}") + choices = result["choices"] choice = choices[0] message = choice["message"] diff --git a/openadapt/events.py b/openadapt/events.py index 5866a3656..596c1da62 100644 --- a/openadapt/events.py +++ b/openadapt/events.py @@ -81,7 +81,7 @@ def get_events( f"{num_process_iters=} " f"{num_action_events=} " f"{num_window_events=} " - f"{num_screenshots=}" + f"{num_screenshots=} " f"{num_browser_events=}" ) ( @@ -90,6 +90,7 @@ def get_events( screenshots, browser_events, ) = merge_events( + db, action_events, window_events, screenshots, @@ -180,7 +181,7 @@ def make_parent_event( children = extra.get("children", []) browser_events = [child.browser_event for child in children if child.browser_event] if browser_events: - assert len(browser_events) <= 1, len(browser_events) + # TODO: store additional browser events as children on first browser event? 
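The rewritten `get_merged_events` in `openadapt/events.py` collapses a run of consecutive scroll events into a single parent event positioned at the first child, with the deltas summed across the run. The core of that transformation, sketched on a toy event type:

```python
from dataclasses import dataclass


@dataclass
class ScrollEvent:
    timestamp: float
    mouse_x: float
    mouse_y: float
    mouse_dx: int
    mouse_dy: int


def merge_scrolls(to_merge: list[ScrollEvent]) -> ScrollEvent:
    """Collapse consecutive scrolls into one event at the first child's
    position and timestamp, with deltas summed over the whole run."""
    first = to_merge[0]
    return ScrollEvent(
        timestamp=first.timestamp,
        mouse_x=first.mouse_x,
        mouse_y=first.mouse_y,
        mouse_dx=sum(e.mouse_dx for e in to_merge),
        mouse_dy=sum(e.mouse_dy for e in to_merge),
    )


events = [ScrollEvent(0.0, 5, 5, 0, -3), ScrollEvent(0.1, 5, 5, 0, -2)]
merged = merge_scrolls(events)
print(merged.mouse_dy)  # -5
```

The real implementation additionally keeps the children on the parent event and adjusts timestamps by the accumulated merged duration (`state["dt"]`), which this sketch omits.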
browser_event = browser_events[0] event_dict["browser_event"] = browser_event @@ -347,13 +348,23 @@ def is_target_event(event: models.ActionEvent, state: dict[str, Any]) -> bool: def get_merged_events( to_merge: list[models.ActionEvent], state: dict[str, Any] ) -> list[models.ActionEvent]: - state["dt"] += to_merge[-1].timestamp - to_merge[0].timestamp - mouse_dx = sum(event.mouse_dx for event in to_merge) - mouse_dy = sum(event.mouse_dy for event in to_merge) - merged_event = to_merge[-1] - merged_event.timestamp -= state["dt"] - merged_event.mouse_dx = mouse_dx - merged_event.mouse_dy = mouse_dy + total_mouse_dx = sum(event.mouse_dx for event in to_merge) + total_mouse_dy = sum(event.mouse_dy for event in to_merge) + first_child = to_merge[0] + last_child = to_merge[-1] + merged_event = make_parent_event( + first_child, + { + "name": "scroll", + "mouse_x": first_child.mouse_x, + "mouse_y": first_child.mouse_y, + "mouse_dx": total_mouse_dx, + "mouse_dy": total_mouse_dy, + "timestamp": first_child.timestamp - state["dt"], + "children": to_merge, + }, + ) + state["dt"] += last_child.timestamp - first_child.timestamp return [merged_event] return merge_consecutive_action_events( @@ -818,7 +829,54 @@ def discard_unused_events( return referred_events +def filter_invalid_window_events( + db: crud.SaSession, + action_events: list[models.ActionEvent], + min_width: int = 100, + min_height: int = 100, +) -> list[models.WindowEvent]: + """Filter out invalid window events by updating action events. + + Update the associated window_event_timestamp and window_event_id + in action events to the previous valid window event in the sequence + if the current window event is invalid. Return a list of valid window events. + + Args: + action_events (list[models.ActionEvent]): The list of action events. + min_width (int): Minimum allowable width for a valid window event. + Default is 100. + min_height (int): Minimum allowable height for a valid window event. + Default is 100. 
+
+    Returns:
+        list[models.WindowEvent]: A list of valid window events.
+    """
+
+    def is_valid_window_event(event: models.WindowEvent) -> bool:
+        return event and event.width >= min_width and event.height >= min_height
+
+    prev_valid_window = None
+    valid_window_events = []
+
+    for action in action_events:
+        if is_valid_window_event(action.window_event):
+            prev_valid_window = action.window_event
+            valid_window_events.append(action.window_event)
+        else:
+            assert prev_valid_window is not None, "No previous valid window event found"
+            action.window_event_id = prev_valid_window.id
+            action.window_event_timestamp = prev_valid_window.timestamp
+            action.window_event = prev_valid_window
+
+    db.add_all(action_events)
+    logger.info(
+        "Completed filtering and updating invalid window events in action events."
+    )
+    return valid_window_events
+
+
 def merge_events(
+    db: crud.SaSession,
     action_events: list[models.ActionEvent],
     window_events: list[models.WindowEvent],
     screenshots: list[models.Screenshot],
@@ -893,6 +951,10 @@ def merge_events(
         action_events,
         "browser_event_timestamp",
     )
+
+    # TODO: prevent invalid window events from being triggered to begin with
+    window_events = filter_invalid_window_events(db, action_events)
+
     num_action_events_ = len(action_events)
     num_window_events_ = len(window_events)
     num_screenshots_ = len(screenshots)
diff --git a/openadapt/models.py b/openadapt/models.py
index 1df82c45e..131652f13 100644
--- a/openadapt/models.py
+++ b/openadapt/models.py
@@ -3,11 +3,12 @@
 from collections import OrderedDict
 from copy import deepcopy
 from itertools import zip_longest
-from typing import Any, Type
+from typing import Any, Type, Union
 import copy
 import io
 import sys

+from bs4 import BeautifulSoup
 from oa_pynput import keyboard
 from PIL import Image, ImageChops
 import numpy as np
@@ -193,6 +194,7 @@ def __init__(self, **kwargs: dict) -> None:
         for key, value in properties.items():
             setattr(self, key, value)

+    # TODO: rename "available" to "target"
     @property
def available_segment_descriptions(self) -> list[str]: """Gets the available segment descriptions.""" @@ -314,6 +316,9 @@ def text(self, value: str) -> None: """Validate the text property. Useful for ActionModel(**action_dict).""" if not value == self.text: logger.warning(f"{value=} did not match {self.text=}") + # if self.text: + # import ipdb; ipdb.set_trace() + # foo = 1 @property def canonical_text(self) -> str: @@ -372,6 +377,8 @@ def from_dict( ) -> "ActionEvent": """Get an ActionEvent from a dict. + See tests.openadapt.test_models::test_action_from_dict for behavior details. + Args: action_dict (dict): A dictionary representing the action. handle_separator_variations (bool): Whether to attempt to handle variations @@ -382,40 +389,40 @@ def from_dict( (ActionEvent) The ActionEvent. """ sep = config.ACTION_TEXT_SEP - name_prefix = config.ACTION_TEXT_NAME_PREFIX - name_suffix = config.ACTION_TEXT_NAME_SUFFIX children = [] - release_events = [] if "text" in action_dict: - # Splitting actions based on whether they are special keys or characters - if action_dict["text"].startswith(name_prefix) and action_dict[ - "text" - ].endswith(name_suffix): - # handle multiple key separators - # (each key separator must start and end with a prefix and suffix) + name_prefix = config.ACTION_TEXT_NAME_PREFIX + name_suffix = config.ACTION_TEXT_NAME_SUFFIX + text = action_dict["text"] + + # Check if the text contains named keys (starting with the name prefix) + # TODO: support sequences of the form -a- + contains_named_keys = text.startswith(name_prefix) and text.endswith( + name_suffix + ) + + if contains_named_keys: + # Handle named keys, potentially with separator variations + release_events = [] default_sep = "".join([name_suffix, sep, name_prefix]) - variation_seps = ["".join([name_suffix, name_prefix])] key_seps = [default_sep] if handle_separator_variations: + variation_seps = ["".join([name_suffix, name_prefix])] key_seps += variation_seps prefix_len = len(name_prefix) 
suffix_len = len(name_suffix) key_names = utils.split_by_separators( - action_dict.get("text", "")[prefix_len:-suffix_len], + text[prefix_len:-suffix_len], key_seps, ) canonical_key_names = utils.split_by_separators( action_dict.get("canonical_text", "")[prefix_len:-suffix_len], key_seps, ) - logger.info(f"{key_names=}") - logger.info(f"{canonical_key_names=}") # Process each key name and canonical key name found - children = [] - release_events = [] for key_name, canonical_key_name in zip_longest( key_names, canonical_key_names, @@ -424,19 +431,22 @@ def from_dict( key_name, canonical_key_name ) children.append(press) - release_events.append( - release - ) # Collect release events to append in reverse order later - + release_events.append(release) + children += release_events[::-1] else: - # Handling regular character sequences - sep_len = len(sep) - for key_char in action_dict["text"][:: sep_len + 1]: - # Press and release each character one after another - press, release = cls._create_key_events(key_char=key_char) + # Handle mixed sequences of named keys and regular characters + split_text = text.split(sep) + for part in split_text: + if part.startswith(name_prefix) and part.endswith(name_suffix): + # It's a named key + key_name = part[len(name_prefix) : -len(name_suffix)] + press, release = cls._create_key_events(key_name=key_name) + else: + # It's a character + press, release = cls._create_key_events(key_char=part) children.append(press) children.append(release) - children += release_events[::-1] + rval = ActionEvent(**action_dict, children=children) return rval @@ -482,6 +492,10 @@ def to_prompt_dict(self) -> dict[str, Any]: Returns: dictionary containing relevant properties from the ActionEvent. 
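The `from_dict` logic above relies on `utils.split_by_separators` to break the stripped key text into individual key names, allowing for multiple separator variations. That helper's body is outside this diff; a plausible stand-in that splits on any of several separator strings (an illustrative sketch, not the actual implementation) might be:

```python
import re


def split_by_separators(text: str, separators: list[str]) -> list[str]:
    """Split text on any of the given separator strings, dropping empties."""
    if not text:
        return []
    # Escape each separator so characters like '<' and '>' are taken literally;
    # earlier separators in the list win when alternatives overlap.
    pattern = "|".join(re.escape(sep) for sep in separators)
    return [part for part in re.split(pattern, text) if part]


# e.g. key names flattened as 'cmd>-<tab', with '><' as a known variation
print(split_by_separators("cmd>-<tab", [">-<", "><"]))  # ['cmd', 'tab']
```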
""" + if self.active_browser_element: + import ipdb + + ipdb.set_trace() action_dict = deepcopy( { key: val @@ -497,12 +511,44 @@ def to_prompt_dict(self) -> dict[str, Any]: for key in ("mouse_x", "mouse_y", "mouse_dx", "mouse_dy"): if key in action_dict: del action_dict[key] + # TODO XXX: add target_segment_description? + + # Manually add properties to the dictionary if self.available_segment_descriptions: action_dict["available_segment_descriptions"] = ( self.available_segment_descriptions ) + if self.active_browser_element: + action_dict["active_browser_element"] = str(self.active_browser_element) + if self.available_browser_elements: + # TODO XXX: available browser_elements contains raw HTML. We need to + # prompt to convert into descriptions. + action_dict["available_browser_elements"] = str( + self.available_browser_elements + ) + + if self.active_browser_element: + import ipdb + + ipdb.set_trace() return action_dict + @property + def next_event(self) -> Union["ActionEvent", None]: + """Get the next ActionEvent chronologically in the same recording. + + Returns: + ActionEvent | None: The next ActionEvent, or None if this is the last event. 
+ """ + if not self.recording or not self.recording.action_events: + return None + + current_index = self.recording.action_events.index(self) + if current_index < len(self.recording.action_events) - 1: + return self.recording.action_events[current_index + 1] + + return None + class WindowEvent(db.Base): """Class representing a window event in the database.""" @@ -649,10 +695,10 @@ def __str__(self) -> str: # Create a copy of the message to avoid modifying the original message_copy = copy.deepcopy(self.message) - # Truncate the visibleHtmlString in the copied message if it exists - if "visibleHtmlString" in message_copy: - message_copy["visibleHtmlString"] = utils.truncate_html( - message_copy["visibleHtmlString"], max_len=100 + # Truncate the visibleHTMLString in the copied message if it exists + if "visibleHTMLString" in message_copy: + message_copy["visibleHTMLString"] = utils.truncate_html( + message_copy["visibleHTMLString"], max_len=100 ) # Get all attributes except 'message' @@ -668,6 +714,39 @@ def __str__(self) -> str: # Return the complete representation including the truncated message return f"BrowserEvent({base_repr}, message={message_copy})" + def parse(self) -> tuple[BeautifulSoup, BeautifulSoup | None]: + """Parses the visible HTML and optionally extracts the target element. + + This method processes the browser event to parse the visible HTML and, + if the event has a targetId, extracts the target HTML element. + + Returns: + A tuple containing: + - BeautifulSoup: The parsed soup of the visible HTML. + - BeautifulSoup | None: The target HTML element if the event type is + "click"; otherwise, None. + + Raises: + AssertionError: If the necessary data is missing. 
+ """ + message = self.message + + visible_html_string = message.get("visibleHTMLString") + assert visible_html_string, "Cannot parse without visibleHTMLstring" + + # Parse the visible HTML using BeautifulSoup + soup = utils.parse_html(visible_html_string) + + target_element = None + + # Fetch the target element using its data-id + target_id = message.get("targetId") + if target_id: + target_element = soup.find(attrs={"data-id": target_id}) + assert target_element, f"No target element found for targetId: {target_id}" + + return soup, target_element + # # TODO: implement # @classmethod # def get_active_browser_event( diff --git a/openadapt/plotting.py b/openadapt/plotting.py index f01df0d4f..03c6a5b0c 100644 --- a/openadapt/plotting.py +++ b/openadapt/plotting.py @@ -261,11 +261,14 @@ def display_event( width_ratio, height_ratio = utils.get_scale_ratios(action_event) # dim area outside window event - x0 = window_event.left * width_ratio - y0 = window_event.top * height_ratio - x1 = x0 + window_event.width * width_ratio - y1 = y0 + window_event.height * height_ratio - image = draw_rectangle(x0, y0, x1, y1, image, outline_width=5) + if not window_event: + logger.error(f"{window_event=}") + else: + x0 = window_event.left * width_ratio + y0 = window_event.top * height_ratio + x1 = x0 + window_event.width * width_ratio + y1 = y0 + window_event.height * height_ratio + image = draw_rectangle(x0, y0, x1, y1, image, outline_width=5) # display diff bbox if diff: diff --git a/openadapt/record.py b/openadapt/record.py index 27eb9e578..eef25c7c8 100644 --- a/openadapt/record.py +++ b/openadapt/record.py @@ -24,6 +24,7 @@ from pympler import tracker import av +from openadapt.browser import set_browser_mode from openadapt.build_utils import redirect_stdout_stderr from openadapt.custom_logger import logger from openadapt.models import Recording @@ -1192,24 +1193,27 @@ def read_browser_events( """ utils.set_start_time(recording.timestamp) + # set the browser mode + 
set_browser_mode("record", websocket)
+
     logger.info("Starting Reading Browser Events ...")
     while not terminate_processing.is_set():
-        for message in websocket:
-            if not message:
-                continue
-
-            timestamp = utils.get_timestamp()
-
-            data = json.loads(message)
-
-            event_q.put(
-                Event(
-                    timestamp,
-                    "browser",
-                    {"message": data},
-                )
+        try:
+            message = websocket.recv(0.01)
+        except TimeoutError:
+            continue
+        timestamp = utils.get_timestamp()
+        data = json.loads(message)
+        event_q.put(
+            Event(
+                timestamp,
+                "browser",
+                {"message": data},
             )
+        )
+
+    set_browser_mode("idle", websocket)


 @logger.catch
diff --git a/openadapt/strategies/__init__.py b/openadapt/strategies/__init__.py
index fecc7c056..dca387d11 100644
--- a/openadapt/strategies/__init__.py
+++ b/openadapt/strategies/__init__.py
@@ -5,6 +5,7 @@

 # flake8: noqa
 from openadapt.strategies.base import BaseReplayStrategy
+from openadapt.strategies.visual_browser import VisualBrowserReplayStrategy

 # disabled because importing is expensive
 # from openadapt.strategies.demo import DemoReplayStrategy
diff --git a/openadapt/strategies/base.py b/openadapt/strategies/base.py
index ea7fb68c5..de114331b 100644
--- a/openadapt/strategies/base.py
+++ b/openadapt/strategies/base.py
@@ -10,7 +10,7 @@
 from openadapt import adapters, models, playback, utils
 from openadapt.custom_logger import logger

-CHECK_ACTION_COMPLETE = True
+CHECK_ACTION_COMPLETE = False

 MAX_FRAME_TIMES = 1000
@@ -21,12 +21,14 @@ def __init__(
         self,
         recording: models.Recording,
         max_frame_times: int = MAX_FRAME_TIMES,
+        include_a11y_data: bool = True,
     ) -> None:
         """Initialize the BaseReplayStrategy.

         Args:
             recording (models.Recording): The recording to replay.
             max_frame_times (int): The maximum number of frame times to track.
+            include_a11y_data (bool): Whether to include accessibility data.
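The rewritten `read_browser_events` loop polls `websocket.recv` with a short timeout so the `terminate_processing` flag is checked between messages, instead of blocking indefinitely on the websocket iterator. The same poll-with-timeout pattern, demonstrated on a plain queue (a sketch of the control flow, not the websocket API):

```python
import queue
import threading


def drain_until(
    stop: threading.Event,
    q: "queue.Queue[str]",
    handle,
    timeout: float = 0.01,
) -> None:
    """Poll the queue with a short timeout so the stop flag is honored
    promptly between messages, rather than blocking forever on get()."""
    while not stop.is_set():
        try:
            message = q.get(timeout=timeout)
        except queue.Empty:
            continue  # nothing arrived; re-check the stop flag
        handle(message)


stop = threading.Event()
q: "queue.Queue[str]" = queue.Queue()
received: list[str] = []
for msg in ("a", "b"):
    q.put(msg)


def handle(message: str) -> None:
    received.append(message)
    if len(received) == 2:
        stop.set()


drain_until(stop, q, handle)
print(received)  # ['a', 'b']
```

The tradeoff is a small amount of idle polling; in exchange, shutdown latency is bounded by the timeout rather than by the arrival of the next message.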
""" self.recording = recording self.max_frame_times = max_frame_times @@ -34,6 +36,7 @@ def __init__( self.screenshots = [] self.window_events = [] self.frame_times = [] + self.include_a11y_data = include_a11y_data @abstractmethod def get_next_action_event( @@ -67,7 +70,10 @@ def run(self) -> None: continue self.screenshots.append(screenshot) - window_event = models.WindowEvent.get_active_window_event() + window_event = models.WindowEvent.get_active_window_event( + # TODO: rename + include_window_data=self.include_a11y_data, + ) self.window_events.append(window_event) try: action_event = self.get_next_action_event( @@ -121,6 +127,16 @@ def log_fps(self) -> None: self.frame_times.pop(0) +# TODO XXX handle failure mode: +""" +expected_state='After pressing -, I would expect to see the application +switcher overlay, showing a row of applications that can be cycled through.' +is_complete=False +""" + + +# e.g. include next window state in prompt +# e.g. elaborate specifically for cmd-tab (i.e. 
that next application should be visible) def prompt_is_action_complete( current_screenshot: models.Screenshot, played_actions: list[models.ActionEvent], diff --git a/openadapt/strategies/visual_browser.py b/openadapt/strategies/visual_browser.py new file mode 100644 index 000000000..dfb5ee046 --- /dev/null +++ b/openadapt/strategies/visual_browser.py @@ -0,0 +1,671 @@ +"""Like visual.py but using instrumented DOM to generate segments instead of FastSAM.""" + +# TODO XXX: fix caching + +from dataclasses import dataclass +from pprint import pformat +import time + +from bs4 import BeautifulSoup +from PIL import Image, ImageDraw +import numpy as np + +from openadapt import adapters, common, models, plotting, strategies, utils, vision +from openadapt.custom_logger import logger + +DEBUG = True +DEBUG_REPLAY = False +SEGMENTATIONS = [] # TODO: store to db +MIN_SCREENSHOT_SSIM = 0.9 # threshold for considering screenshots structurally similar +MIN_SEGMENT_SSIM = 0.95 # threshold for considering segments structurally similar +MIN_SEGMENT_SIZE_SIM = 0 # threshold for considering segment sizes similar +SKIP_MOVE_BEFORE_CLICK = True # workaround for bug in events.remove_move_before_click + + +@dataclass +class Segmentation: + """A data class to encapsulate segmentation data of images. + + Attributes: + image: The original image used to generate segments. + marked_image: The marked image (for Set-of-Mark prompting). + masked_images: A list of PIL Image objects that have been masked based on + segmentation. + descriptions: Descriptions of each segmented region, correlating with each + image in `masked_images`. + bounding_boxes: A list of dictionaries containing bounding box + coordinates for each segmented region. Each dictionary should have the + keys "top", "left", "height", and "width" with float values indicating + the position and size of the box. + centroids: A list of tuples, each containing the x and y coordinates of the + centroid of each segmented region. 
+ """ + + image: Image.Image + marked_image: Image.Image + masked_images: list[Image.Image] + descriptions: list[str] + bounding_boxes: list[dict[str, float]] # "top", "left", "height", "width" + centroids: list[tuple[float, float]] + + +def add_active_segment_descriptions(action_events: list[models.ActionEvent]) -> None: + """Set the ActionEvent.active_segment_description where appropriate. + + Args: + action_events: list of ActionEvents to modify in-place. + """ + for action in action_events: + # TODO: handle terminal event + if action.name in common.MOUSE_EVENTS: + window_segmentation = get_window_segmentation(action) + if not window_segmentation: + logger.warning(f"{window_segmentation=}") + continue + active_segment_idx = get_active_segment(action, window_segmentation) + if not active_segment_idx: + logger.warning(f"{active_segment_idx=}") + active_segment_description = "(None)" + # XXX TODO handle + logger.error(f"{active_segment_idx=}") + # import ipdb; ipdb.set_trace() + else: + active_segment_description = window_segmentation.descriptions[ + active_segment_idx + ] + action.active_segment_description = active_segment_description + action.available_segment_descriptions = window_segmentation.descriptions + + +@utils.retry_with_exceptions() +def apply_replay_instructions( + action_events: list[models.ActionEvent], + replay_instructions: str, + exceptions: list[Exception], +) -> None: + """Modify the given ActionEvents according to the given replay instructions. + + Args: + action_events: list of action events to be modified in place. + replay_instructions: instructions for how action events should be modified. + exceptions: list of exceptions that were produced attempting to run this + function. 
+ """ + action_dicts = [action.to_prompt_dict() for action in action_events] + actions_dict = {"actions": action_dicts} + system_prompt = utils.render_template_from_file( + "prompts/system.j2", + ) + prompt = utils.render_template_from_file( + "prompts/apply_replay_instructions--browser.j2", + actions=actions_dict, + replay_instructions=replay_instructions, + exceptions=exceptions, + ) + prompt_adapter = adapters.get_default_prompt_adapter() + content = prompt_adapter.prompt( + prompt, + system_prompt=system_prompt, + ) + content_dict = utils.parse_code_snippet(content) + try: + action_dicts = content_dict["actions"] + except TypeError as exc: + logger.warning(exc) + # sometimes OpenAI returns a list of dicts directly, so let it slide + action_dicts = content_dict + modified_actions = [] + for action_dict in action_dicts: + action = models.ActionEvent.from_dict(action_dict) + modified_actions.append(action) + return modified_actions + + +class VisualBrowserReplayStrategy( + strategies.base.BaseReplayStrategy, +): + """ReplayStrategy using Large Multimodal Model and replay instructions.""" + + def __init__( + self, + recording: models.Recording, + instructions: str, + ) -> None: + """Initialize the VisualReplayStrategy. + + Args: + recording (models.Recording): The recording object. + instructions (str): Natural language instructions for how recording + should be replayed. + """ + super().__init__(recording) + self.recording_action_idx = 0 + self.action_history = [] + add_active_segment_descriptions(recording.processed_action_events) + self.modified_actions = apply_replay_instructions( + recording.processed_action_events, + instructions, + ) + # TODO: make this less of a hack + global DEBUG + DEBUG = DEBUG_REPLAY + + def get_next_action_event( + self, + active_screenshot: models.Screenshot, + active_window: models.WindowEvent, + ) -> models.ActionEvent: + """Get the next ActionEvent for replay. 
+ + Since we have already modified the actions, this function just determines + the appropriate coordinates for the modified actions (where appropriate). + + Args: + active_screenshot (models.Screenshot): The active screenshot object. + active_window (models.WindowEvent): The active window event object. + + Returns: + models.ActionEvent: The next ActionEvent for replay. + """ + logger.debug(f"{self.recording_action_idx=}") + if self.recording_action_idx >= len(self.modified_actions): + raise StopIteration() + + # TODO: hack + time.sleep(1) + active_window = models.WindowEvent.get_active_window_event() + + active_screenshot = models.Screenshot.take_screenshot() + logger.info(f"{active_window=}") + + modified_reference_action = self.modified_actions[self.recording_action_idx] + self.recording_action_idx += 1 + + # TODO XXX: how to handle scroll position? + if modified_reference_action.name in common.PRECISE_MOUSE_EVENTS: + modified_reference_action.screenshot = active_screenshot + modified_reference_action.window_event = active_window + modified_reference_action.recording = self.recording + exceptions = [] + while True: + active_window_segmentation = get_window_segmentation( + modified_reference_action, + exceptions=exceptions, + ) + try: + target_segment_idx = active_window_segmentation.descriptions.index( + modified_reference_action.active_segment_description + ) + except ValueError as exc: + logger.warning(f"{exc=}") + exceptions.append(exc) + else: + break + target_centroid = active_window_segmentation.centroids[target_segment_idx] + # = scale_ratio * + width_ratio, height_ratio = utils.get_scale_ratios( + modified_reference_action + ) + target_mouse_x = target_centroid[0] / width_ratio + active_window.left + target_mouse_y = target_centroid[1] / height_ratio + active_window.top + modified_reference_action.mouse_x = target_mouse_x + modified_reference_action.mouse_y = target_mouse_y + self.action_history.append(modified_reference_action) + return 
modified_reference_action + + def __del__(self) -> None: + """Log the action history.""" + action_history_dicts = [ + action.to_prompt_dict() for action in self.action_history + ] + logger.info(f"action_history=\n{pformat(action_history_dicts)}") + + +def get_active_segment( + action: models.ActionEvent, + window_segmentation: Segmentation, + debug: bool = DEBUG, +) -> int: + """Get the index of the bounding box containing the action's mouse coordinates. + + Adjust for the scaling of the cropped window and the action coordinates. + Optionally visualize segments and mouse position. + + Args: + action: the ActionEvent + window_segmentation: the Segmentation + debug: whether to display images for debugging + + Returns: + index of active segment in Segmentation + """ + # Obtain the scale ratios + width_ratio, height_ratio = utils.get_scale_ratios(action) + logger.info(f"{width_ratio=} {height_ratio=}") + + # Adjust action coordinates to be relative to the cropped window's top-left corner + adjusted_mouse_x = (action.mouse_x - action.window_event.left) * width_ratio + adjusted_mouse_y = (action.mouse_y - action.window_event.top) * height_ratio + logger.info(f"{action.mouse_x=} {action.window_event.left=} {adjusted_mouse_x=}") + logger.info(f"{action.mouse_y=} {action.window_event.top=} {adjusted_mouse_y=}") + + active_index = None + + if debug: + # Create an empty image with enough space to display all bounding boxes + width = int( + max( + box["left"] + box["width"] for box in window_segmentation.bounding_boxes + ) + ) + height = int( + max( + box["top"] + box["height"] for box in window_segmentation.bounding_boxes + ) + ) + image = Image.new("RGB", (width, height), "white") + draw = ImageDraw.Draw(image) + + for index, box in enumerate(window_segmentation.bounding_boxes): + box_left = box["left"] + box_top = box["top"] + box_right = box["left"] + box["width"] + box_bottom = box["top"] + box["height"] + + if debug: + # Draw each bounding box as a rectangle + 
draw.rectangle( + [box_left, box_top, box_right, box_bottom], outline="red", width=1 + ) + + # Check if the adjusted action's coordinates are within the bounding box + if ( + box_left <= adjusted_mouse_x < box_right + and box_top <= adjusted_mouse_y < box_bottom + ): + active_index = index + + if debug: + # Draw the adjusted mouse position + draw.ellipse( + [ + adjusted_mouse_x - 5, + adjusted_mouse_y - 5, + adjusted_mouse_x + 5, + adjusted_mouse_y + 5, + ], + fill="blue", + ) + # Display the image without blocking + image.show() + + if active_index is None: + # XXX TODO handle + logger.error(f"{active_index=}") + # import ipdb; ipdb.set_trace() + + return active_index + + +def find_similar_image_segmentation( + image: Image.Image, + min_ssim: float = MIN_SCREENSHOT_SSIM, +) -> tuple[Segmentation, np.ndarray] | tuple[None, None]: + """Identify a similar image in the cache based on the SSIM comparison. + + This function iterates through a global list of image segmentations, + comparing each against a given image using the SSIM index calculated by + get_image_similarity. + It logs and updates the best match found above a specified SSIM threshold. + + Args: + image (Image.Image): The image to compare against the cache. + min_ssim (float): The minimum SSIM threshold for considering a match. + + Returns: + tuple[Segmentation, np.ndarray] | tuple[None, None]: The best matching + segmentation and its difference image if a match is found; + otherwise, None for both. 
+ """ + similar_segmentation = None + similar_segmentation_diff = None + + for segmentation in SEGMENTATIONS: + similarity_index, ssim_image = vision.get_image_similarity( + image, + segmentation.image, + ) + if similarity_index > min_ssim: + logger.info(f"{similarity_index=}") + min_ssim = similarity_index + similar_segmentation = segmentation + similar_segmentation_diff = ssim_image + + return similar_segmentation, similar_segmentation_diff + + +def get_window_segmentation( + action_event: models.ActionEvent, + exceptions: list[Exception] | None = None, + # TODO: document or remove + return_similar_segmentation: bool = False, + handle_similar_image_groups: bool = False, +) -> Segmentation: + """Segments the active window from the action event's screenshot. + + Args: + action_event: action event containing the screenshot data. + exceptions: list of exceptions previously raised, added to prompt. + handle_similar_image_groups (bool): Whether to distinguish between similar + image groups. Work-in-progress. + + Returns: + Segmentation object containing detailed segmentation information. 
+ """ + screenshot = action_event.screenshot + # XXX TODO in visual.py we use screenshot.cropped_image, but to use that here + # we need to modify data-tlbr-screen + original_image = screenshot.image + if DEBUG: + original_image.show() + + if return_similar_segmentation and not exceptions: + similar_segmentation, similar_segmentation_diff = ( + find_similar_image_segmentation(original_image) + ) + if similar_segmentation: + # TODO XXX: create copy of similar_segmentation, but overwrite with segments + # of regions of new image where segments of similar_segmentation overlap + # non-zero regions of similar_segmentation_diff + return similar_segmentation + + if action_event.browser_event: + refined_masks, element_labels = get_dom_masks(action_event) + else: + # TODO XXX: get segments from A11Y, fallback to segmentation + + # XXX HACK: skip if this event is "move" and next is "click" + # TODO: consolidate with events.remove_move_before_click (currently disabled) + # the following was implemented because enabelling remove_move_before_click + # had no effect on order in visualize.py + if SKIP_MOVE_BEFORE_CLICK: + if ( + action_event.name == "move" + and action_event.next_event + and action_event.next_event.name == "click" + ): + logger.info("Skipping 'move' event followed by 'click'") + return None + + segmentation_adapter = adapters.get_default_segmentation_adapter() + segmented_image = segmentation_adapter.fetch_segmented_image(original_image) + if DEBUG: + segmented_image.show() + + masks = vision.get_masks_from_segmented_image(segmented_image) + if DEBUG: + plotting.display_binary_images_grid(masks) + + refined_masks = vision.refine_masks(masks) + if DEBUG: + plotting.display_binary_images_grid(refined_masks) + + masked_images = vision.extract_masked_images(original_image, refined_masks) + if action_event.browser_event and DEBUG: + plotting.display_images_table_with_titles(masked_images, element_labels) + + if handle_similar_image_groups: + similar_idx_groups, 
ungrouped_idxs, _, _ = vision.get_similar_image_idxs( + masked_images, + MIN_SEGMENT_SSIM, + MIN_SEGMENT_SIZE_SIM, + ) + # TODO XXX: handle similar image groups + raise ValueError("Currently unsupported.") + + descriptions = prompt_for_descriptions( + original_image, + masked_images, + action_event.active_segment_description, + exceptions, + ) + bounding_boxes, centroids = vision.calculate_bounding_boxes(refined_masks) + assert len(bounding_boxes) == len(descriptions) == len(centroids), ( + len(bounding_boxes), + len(descriptions), + len(centroids), + ) + marked_image = plotting.get_marked_image( + original_image, + refined_masks, # masks, + ) + segmentation = Segmentation( + original_image, + marked_image, + masked_images, + descriptions, + bounding_boxes, + centroids, + ) + if DEBUG: + plotting.display_images_table_with_titles(masked_images, descriptions) + + SEGMENTATIONS.append(segmentation) + return segmentation + + +MIN_ELEMENT_AREA_PIXELS = 10 +MIN_EXTENT = 0.1 +DEBUG_DISPLAY_MASKS = False +DEBUG_DISPLAY_LABELLED_SCREENSHOT = True + + +def get_dom_masks( + action_event: models.ActionEvent, + min_element_area_pixels: int = MIN_ELEMENT_AREA_PIXELS, + min_extent: float = MIN_EXTENT, + display_masks: bool = DEBUG_DISPLAY_MASKS, + display_screenshot_with_labels: bool = DEBUG_DISPLAY_LABELLED_SCREENSHOT, +) -> tuple[list[np.ndarray], list[str]]: + """Returns a list of binary masks for DOM elements in the given ActionEvent. + + Also returns a list of corresponding strings containing the element's data-id and + its top, left, bottom, right coordinates, scaled according to the screen. + + Args: + action_event (models.ActionEvent): The ActionEvent to extract DOM masks from. + min_element_area_pixels (int, optional): Minimum area of an element in pixels. + Defaults to MIN_ELEMENT_AREA_PIXELS. + min_extent (float, optional): Minimum extent of an element. + Defaults to MIN_EXTENT. + display_masks (bool, optional): Whether to display the masks. 
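The `min_element_area_pixels` and `min_extent` parameters implement a filter that drops tiny DOM elements and elements whose area is almost entirely covered by their children. In isolation (the real code sums child areas via a nested `find_all`; this sketch takes them as a plain list):

```python
def passes_extent_filter(
    area: float,
    child_areas: list[float],
    min_area_pixels: float = 10,
    min_extent: float = 0.1,
) -> bool:
    """Keep an element only if it is big enough and not mostly covered by children."""
    if area < min_area_pixels:
        return False
    # area not covered by children; clamped because overlapping children
    # can make the naive subtraction go negative
    adjusted_area = max(0.0, area - sum(child_areas))
    extent = adjusted_area / area if area > 0 else 0.0
    return extent >= min_extent
```

This keeps container elements that contribute visible area of their own while discarding wrappers that merely enclose other masked elements.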
+ display_screenshot_with_labels (bool, optional): Whether to display screenshot + with bounding boxes and labels. + + Returns: + tuple[list[np.ndarray], list[str]]: A tuple containing a list of binary masks + and a list of strings with the data-id and coordinates for each mask. + """ + # Get scale ratios for x and y + width_ratio, height_ratio = utils.get_scale_ratios(action_event) + + browser_event = action_event.browser_event + assert browser_event, action_event + soup, target_element = browser_event.parse() + elements = soup.find_all(attrs={"data-tlbr-screen": True}) + elements.sort(key=lambda el: calculate_area(el)) + + masks = [] + element_info = [] + + # If we want to display the screenshot with labels, create a drawable version of + # the screenshot + if display_screenshot_with_labels: + screenshot_with_labels = action_event.screenshot.image.copy() + draw_screenshot = ImageDraw.Draw(screenshot_with_labels) + + for element in elements: + try: + area = calculate_area(element) + if area < min_element_area_pixels: + logger.info(f"skipping {area=} < {min_element_area_pixels=}") + continue + + # Remove child masks from mask + child_area = 0 + for child in element.find_all(attrs={"data-tlbr-screen": True}): + child_area += calculate_area(child) + adjusted_area = max(0, area - child_area) + extent = adjusted_area / area if area > 0 else 0 + logger.info(f"{extent=}") + if extent < min_extent: + logger.info(f"<{min_extent=}, skipping") + continue + + # Create a binary mask for the element + mask_img = Image.new("L", action_event.screenshot.image.size, color=0) + draw = ImageDraw.Draw(mask_img) + + # Get the element's top, left, bottom, right in window coordinates + top, left, bottom, right = get_tlbr(element) + + # Apply scale ratios to convert to image space + top_scaled = top * height_ratio + left_scaled = left * width_ratio + bottom_scaled = bottom * height_ratio + right_scaled = right * width_ratio + + draw.rectangle( + [(left_scaled, top_scaled), (right_scaled, 
bottom_scaled)], fill=255 + ) + + # Convert the mask to a numpy array + mask = np.array(mask_img, dtype=np.uint8) / 255 + masks.append(mask) + + # Collect element data-id and scaled coordinates + data_id = element.get("data-id", "unknown") + element_info.append( + f"data-id: {data_id}, tlbr: ({top_scaled}, {left_scaled}," + f" {bottom_scaled}, {right_scaled})" + ) + + # If display_screenshot_with_labels is True, draw the bounding boxes and + # labels + if display_screenshot_with_labels: + draw_screenshot.rectangle( + [(left_scaled, top_scaled), (right_scaled, bottom_scaled)], + outline="red", + width=2, + ) + draw_screenshot.text( + (left_scaled, top_scaled), + f"{data_id}: ({top_scaled}, {left_scaled}, {bottom_scaled}," + f" {right_scaled})", + fill="yellow", + ) + + if display_masks: + logger.debug(f"Displaying mask for {element=}") + mask_img.show() # Display the mask using PIL.Image.Image.show() + + except (ValueError, KeyError) as exc: + logger.warning(f"Failed to process {element=}: {exc}") + + # If display_screenshot_with_labels is True, show the screenshot with the drawn + # labels + if display_screenshot_with_labels: + logger.debug("Displaying screenshot with bounding boxes and labels") + screenshot_with_labels.show() + + return masks, element_info + + + def get_tlbr(element: BeautifulSoup, attr: str = "data-tlbr-screen") -> tuple[float, float, float, float]: + """Get the (top, left, bottom, right) bounding box as floats.""" + top, left, bottom, right = [float(val) for val in element[attr].split(",")] + return top, left, bottom, right + + + def calculate_area(element: BeautifulSoup) -> float: + """Calculate the area of an element from its bounding box.""" + top, left, bottom, right = get_tlbr(element) + return (right - left) * (bottom - top) + + + def prompt_for_descriptions( + original_image: Image.Image, + masked_images: list[Image.Image], + active_segment_description: str | None, + exceptions: list[Exception] | None = None, + ) -> list[str]: + """Generate descriptions for the given image segments using a prompt adapter.
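`get_tlbr` plus the scaling step in `get_dom_masks` reduce to string parsing and per-axis multiplication; a self-contained version that needs no BeautifulSoup element:

```python
def parse_tlbr(attr_value: str) -> tuple[float, float, float, float]:
    """Parse a "top,left,bottom,right" attribute value into floats."""
    top, left, bottom, right = (float(v) for v in attr_value.split(","))
    return top, left, bottom, right


def scale_tlbr(
    tlbr: tuple[float, float, float, float],
    width_ratio: float,
    height_ratio: float,
) -> tuple[float, float, float, float]:
    """Convert window-space (top, left, bottom, right) to image space.

    Vertical coordinates (top, bottom) scale by height_ratio;
    horizontal coordinates (left, right) scale by width_ratio.
    """
    top, left, bottom, right = tlbr
    return (
        top * height_ratio,
        left * width_ratio,
        bottom * height_ratio,
        right * width_ratio,
    )
```

Keeping the parse and the scale separate makes it easy to test that the axis ratios are applied to the right coordinates, which is the usual source of off-by-axis bugs here.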
+ + Args: + original_image: The original image. + masked_images: List of masked images. + active_segment_description: Description of the active segment. + exceptions: List of exceptions previously raised, added to prompts. + + Returns: + list of descriptions for each masked image. + """ + # TODO: move inside adapters.prompt + for driver in adapters.prompt.DRIVER_ORDER: + # off by one to account for original image + if driver.MAX_IMAGES and (len(masked_images) + 1 > driver.MAX_IMAGES): + masked_images_batches = utils.split_list( + masked_images, + driver.MAX_IMAGES - 1, + ) + descriptions = [] + for masked_images_batch in masked_images_batches: + descriptions_batch = prompt_for_descriptions( + original_image, + masked_images_batch, + active_segment_description, + exceptions, + ) + descriptions += descriptions_batch + return descriptions + + images = [original_image] + masked_images + system_prompt = utils.render_template_from_file( + "prompts/system.j2", + ) + logger.info(f"system_prompt=\n{system_prompt}") + num_segments = len(masked_images) + prompt = utils.render_template_from_file( + "prompts/description.j2", + active_segment_description=active_segment_description, + num_segments=num_segments, + exceptions=exceptions, + ).strip() + logger.info(f"prompt=\n{prompt}") + logger.info(f"{len(images)=}") + descriptions_json = driver.prompt( + prompt, + system_prompt, + images, + ) + descriptions = utils.parse_code_snippet(descriptions_json)["descriptions"] + logger.info(f"{descriptions=}") + try: + assert len(descriptions) == len(masked_images), ( + len(descriptions), + len(masked_images), + ) + except Exception as exc: + exceptions = exceptions or [] + exceptions.append(exc) + logger.info(f"exceptions=\n{pformat(exceptions)}") + return prompt_for_descriptions( + original_image, + masked_images, + active_segment_description, + exceptions, + ) + + # remove indexes + descriptions = [desc for idx, desc in descriptions] + return descriptions diff --git a/openadapt/utils.py 
b/openadapt/utils.py index 82279c0d4..c57763694 100644 --- a/openadapt/utils.py +++ b/openadapt/utils.py @@ -16,6 +16,7 @@ import threading import time +from bs4 import BeautifulSoup from jinja2 import Environment, FileSystemLoader from PIL import Image, ImageEnhance from posthog import Posthog @@ -992,6 +993,49 @@ def truncate_html(html_str: str, max_len: int) -> str: return html_str +def parse_html(html: str, parser: str = "html.parser") -> BeautifulSoup: + """Parse the visible HTML using BeautifulSoup.""" + soup = BeautifulSoup(html, parser) + return soup + + +def get_html_prompt(html: str, convert_to_markdown: bool = False) -> str: + """Convert an HTML string to a processed version suitable for LLM prompts. + + Args: + html: The input HTML string. + convert_to_markdown: If True, converts the HTML to Markdown. Defaults to False. + + Returns: + A string with preserved semantic structure and interactable elements. + If convert_to_markdown is True, the string is in Markdown format. + """ + # Parse HTML with BeautifulSoup + soup = BeautifulSoup(html, "html.parser") + + # Remove non-interactive and unnecessary elements + for tag in soup(["style", "script", "noscript", "meta", "head", "iframe"]): + tag.decompose() + + assert not convert_to_markdown, "poetry add html2text" + if convert_to_markdown: + # XXX TODO: + import html2text + + # Initialize html2text converter + converter = html2text.HTML2Text() + converter.ignore_links = False # Keep all links + converter.ignore_images = False # Keep all images + converter.body_width = 0 # Preserve original width without wrapping + + # Convert the cleaned HTML to Markdown + markdown = converter.handle(str(soup)) + return markdown + + # Return processed HTML as a string if Markdown conversion is not required + return str(soup) + + class WrapStdout: """Class to be used a target for multiprocessing.Process.""" diff --git a/openadapt/window/_macos.py b/openadapt/window/_macos.py index 3b9fd7625..25a9d0428 100644 --- 
a/openadapt/window/_macos.py +++ b/openadapt/window/_macos.py @@ -3,6 +3,7 @@ import pickle import plistlib import re +import time import AppKit import ApplicationServices @@ -117,17 +118,37 @@ def get_window_data(window_meta: dict) -> dict: def dump_state( element: Union[AppKit.NSArray, list, AppKit.NSDictionary, dict, Any], - elements: set = None, -) -> Union[dict, list]: + elements: set | None = None, + max_depth: int = 10, + current_depth: int = 0, + timeout: float | None = None, + start_time: float | None = None, +) -> Union[dict, list, None]: """Dump the state of the given element and its descendants. Args: element: The element to dump the state for. elements (set): Set to track elements to prevent circular traversal. + max_depth (int): Maximum depth for recursion. + current_depth (int): Current depth in the recursion. + timeout (float): Maximum time in seconds for the dump_state operation. + start_time (float): Start time of the dump_state operation. Returns: - dict or list: State of element and descendants as dict or list + dict or list or None: State of element and descendants as dict or list, + or None if max depth reached """ + if timeout is not None and start_time is None: + start_time = time.time() + + if current_depth >= max_depth: + return None + + if timeout is not None and start_time is not None: + if time.time() - start_time > timeout: + logger.warning("dump_state timed out") + return None + elements = elements or set() if element in elements: return @@ -136,14 +157,18 @@ def dump_state( if isinstance(element, AppKit.NSArray) or isinstance(element, list): state = [] for child in element: - _state = dump_state(child, elements) + _state = dump_state( + child, elements, max_depth, current_depth + 1, timeout, start_time + ) if _state: state.append(_state) return state elif isinstance(element, AppKit.NSDictionary) or isinstance(element, dict): state = {} for k, v in element.items(): - _state = dump_state(v, elements) + _state = dump_state( + v, 
elements, max_depth, current_depth + 1, timeout, start_time + ) if _state: state[k] = _state return state @@ -179,7 +204,14 @@ def dump_state( ): continue - _state = dump_state(attr_val, elements) + _state = dump_state( + attr_val, + elements, + max_depth, + current_depth + 1, + timeout, + start_time, + ) if _state: state[attr_name] = _state return state diff --git a/poetry.lock b/poetry.lock index e83b8d488..fe840de78 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7298,6 +7298,53 @@ numpy = "*" [package.extras] all = ["defusedxml", "fsspec", "imagecodecs (>=2023.8.12)", "lxml", "matplotlib", "zarr"] +[[package]] +name = "tiktoken" +version = "0.8.0" +description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" +optional = false +python-versions = ">=3.9" +files = [ + {file = "tiktoken-0.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b07e33283463089c81ef1467180e3e00ab00d46c2c4bbcef0acab5f771d6695e"}, + {file = "tiktoken-0.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9269348cb650726f44dd3bbb3f9110ac19a8dcc8f54949ad3ef652ca22a38e21"}, + {file = "tiktoken-0.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25e13f37bc4ef2d012731e93e0fef21dc3b7aea5bb9009618de9a4026844e560"}, + {file = "tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f13d13c981511331eac0d01a59b5df7c0d4060a8be1e378672822213da51e0a2"}, + {file = "tiktoken-0.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:6b2ddbc79a22621ce8b1166afa9f9a888a664a579350dc7c09346a3b5de837d9"}, + {file = "tiktoken-0.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:d8c2d0e5ba6453a290b86cd65fc51fedf247e1ba170191715b049dac1f628005"}, + {file = "tiktoken-0.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d622d8011e6d6f239297efa42a2657043aaed06c4f68833550cac9e9bc723ef1"}, + {file = "tiktoken-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:2efaf6199717b4485031b4d6edb94075e4d79177a172f38dd934d911b588d54a"}, + {file = "tiktoken-0.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5637e425ce1fc49cf716d88df3092048359a4b3bbb7da762840426e937ada06d"}, + {file = "tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fb0e352d1dbe15aba082883058b3cce9e48d33101bdaac1eccf66424feb5b47"}, + {file = "tiktoken-0.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:56edfefe896c8f10aba372ab5706b9e3558e78db39dd497c940b47bf228bc419"}, + {file = "tiktoken-0.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:326624128590def898775b722ccc327e90b073714227175ea8febbc920ac0a99"}, + {file = "tiktoken-0.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:881839cfeae051b3628d9823b2e56b5cc93a9e2efb435f4cf15f17dc45f21586"}, + {file = "tiktoken-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fe9399bdc3f29d428f16a2f86c3c8ec20be3eac5f53693ce4980371c3245729b"}, + {file = "tiktoken-0.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a58deb7075d5b69237a3ff4bb51a726670419db6ea62bdcd8bd80c78497d7ab"}, + {file = "tiktoken-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2908c0d043a7d03ebd80347266b0e58440bdef5564f84f4d29fb235b5df3b04"}, + {file = "tiktoken-0.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:294440d21a2a51e12d4238e68a5972095534fe9878be57d905c476017bff99fc"}, + {file = "tiktoken-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:d8f3192733ac4d77977432947d563d7e1b310b96497acd3c196c9bddb36ed9db"}, + {file = "tiktoken-0.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:02be1666096aff7da6cbd7cdaa8e7917bfed3467cd64b38b1f112e96d3b06a24"}, + {file = "tiktoken-0.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c94ff53c5c74b535b2cbf431d907fc13c678bbd009ee633a2aca269a04389f9a"}, + {file = 
"tiktoken-0.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b231f5e8982c245ee3065cd84a4712d64692348bc609d84467c57b4b72dcbc5"}, + {file = "tiktoken-0.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4177faa809bd55f699e88c96d9bb4635d22e3f59d635ba6fd9ffedf7150b9953"}, + {file = "tiktoken-0.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5376b6f8dc4753cd81ead935c5f518fa0fbe7e133d9e25f648d8c4dabdd4bad7"}, + {file = "tiktoken-0.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:18228d624807d66c87acd8f25fc135665617cab220671eb65b50f5d70fa51f69"}, + {file = "tiktoken-0.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7e17807445f0cf1f25771c9d86496bd8b5c376f7419912519699f3cc4dc5c12e"}, + {file = "tiktoken-0.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:886f80bd339578bbdba6ed6d0567a0d5c6cfe198d9e587ba6c447654c65b8edc"}, + {file = "tiktoken-0.8.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6adc8323016d7758d6de7313527f755b0fc6c72985b7d9291be5d96d73ecd1e1"}, + {file = "tiktoken-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b591fb2b30d6a72121a80be24ec7a0e9eb51c5500ddc7e4c2496516dd5e3816b"}, + {file = "tiktoken-0.8.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:845287b9798e476b4d762c3ebda5102be87ca26e5d2c9854002825d60cdb815d"}, + {file = "tiktoken-0.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:1473cfe584252dc3fa62adceb5b1c763c1874e04511b197da4e6de51d6ce5a02"}, + {file = "tiktoken-0.8.0.tar.gz", hash = "sha256:9ccbb2740f24542534369c5635cfd9b2b3c2490754a78ac8831d99f89f94eeb2"}, +] + +[package.dependencies] +regex = ">=2022.1.18" +requests = ">=2.26.0" + +[package.extras] +blobfile = ["blobfile (>=2)"] + [[package]] name = "tldextract" version = "5.1.2" @@ -7319,6 +7366,25 @@ requests-file = ">=1.4" release = ["build", "twine"] testing = ["black", "mypy", "pytest", "pytest-gitignore", "pytest-mock", "responses", "ruff", "syrupy", 
"tox", "types-filelock", "types-requests"] +[[package]] +name = "tokencost" +version = "0.1.12" +description = "To calculate token and translated USD cost of string and message calls to OpenAI, for example when used by AI agents" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tokencost-0.1.12-py3-none-any.whl", hash = "sha256:cf08dabfa970eb56cc69fed45e405095dfdc3ada261d3e6ad944aadd2bac157d"}, + {file = "tokencost-0.1.12.tar.gz", hash = "sha256:555a2812dd1039923ef7ecc057d54eee898acdee004830ad845552196a5b2b59"}, +] + +[package.dependencies] +aiohttp = ">=3.9.3" +tiktoken = ">=0.7.0" + +[package.extras] +dev = ["coverage[toml] (>=7.4.0)", "flake8 (>=3.1.0)", "pytest (>=7.4.4)"] +llama-index = ["llama-index (>=0.10.23)"] + [[package]] name = "tokenizers" version = "0.13.3" @@ -8361,4 +8427,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = "3.10.x" -content-hash = "906630b6f2aa9fa40caa1d957967fea3c7432cb8dcc83762c22a520b0848fafe" +content-hash = "b4ab66bd022968943af70b0680e43a786de9b1d7e5133f6cde873f679eaf59d7" diff --git a/pyproject.toml b/pyproject.toml index 7f857f51a..6e130c199 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,7 @@ cython = "^3.0.10" av = "^12.3.0" beautifulsoup4 = "^4.12.3" dtaidistance = "^2.3.12" +tokencost = "^0.1.12" [tool.pytest.ini_options] filterwarnings = [ # suppress warnings starting from "setuptools>=67.3" diff --git a/tests/openadapt/test_events.py b/tests/openadapt/test_events.py index 293e5e5f1..0a86ed73f 100644 --- a/tests/openadapt/test_events.py +++ b/tests/openadapt/test_events.py @@ -215,17 +215,30 @@ def make_click_event( ) -def make_scroll_event(dy: int = 0, dx: int = 0) -> ActionEvent: +def make_scroll_event( + get_children: Optional[Callable[[], list[ActionEvent]]] = None, + dy: int = 0, + dx: int = 0, +) -> ActionEvent: """Create a scroll event with the given attributes. 
Args: + get_children (Callable[[], list[ActionEvent]]): Function that returns + the list of children events. dy (int, optional): Vertical scroll amount. Defaults to 0. dx (int, optional): Horizontal scroll amount. Defaults to 0. Returns: ActionEvent: An instance of the ActionEvent class representing the scroll event. """ - return make_action_event({"name": "scroll", "mouse_dx": dx, "mouse_dy": dy}) + return make_action_event( + { + "name": "scroll", + "mouse_dx": dx, + "mouse_dy": dy, + }, + get_pre_children=get_children, + ) def make_click_events( @@ -467,9 +480,23 @@ def test_merge_consecutive_mouse_scroll_events() -> None: expected_events = rows2dicts( [ make_move_event(), - make_scroll_event(dx=2), + make_scroll_event( + lambda: [ + make_scroll_event(dx=2), + make_scroll_event(dx=1), + make_scroll_event(dx=-1), + ], + dx=2, + ), make_move_event(), - make_scroll_event(dx=1, dy=1), + make_scroll_event( + lambda: [ + make_scroll_event(dy=1), + make_scroll_event(dx=1), + ], + dx=1, + dy=1, + ), ] ) logger.info(f"expected_events=\n{pformat(expected_events)}") @@ -633,7 +660,7 @@ def test_merge_consecutive_keyboard_events() -> None: make_press_event("h"), make_release_event("g"), make_release_event("h"), - make_scroll_event(1), + make_scroll_event(dx=1), ] logger.info(f"raw_events=\n{pformat(rows2dicts(raw_events))}") reset_timestamp() @@ -660,7 +687,7 @@ def test_merge_consecutive_keyboard_events() -> None: make_release_event("h"), ] ), - make_scroll_event(1), + make_scroll_event(dx=1), ] ) logger.info(f"expected_events=\n{pformat(expected_events)}") diff --git a/tests/openadapt/test_models.py b/tests/openadapt/test_models.py index 8319e7458..e8260bf97 100644 --- a/tests/openadapt/test_models.py +++ b/tests/openadapt/test_models.py @@ -5,22 +5,47 @@ def test_action_from_dict() -> None: """Test ActionEvent.from_dict().""" - texts = [ - # standard - "--", - # mal-formed - "", - # mixed - "-", - "-", - ] + input_variations_by_expected_output = { + # all named keys + 
"--": [ + # standard + "--", + # mal-formed + "", + # mixed + "-", + "-", + ], + # TODO: support malformed configurations below + # all char keys + "a-b-c": [ + # standard + "a-b-c", + # malformed + # "abc", + # mixed + # "a-bc", + # "ab-c", + ], + # mixed named and char + "-t": [ + # standard + "-t", + # malformed + # "t", + ], + } - for text in texts: - action_dict = { - "name": "type", - "text": text, - "canonical_text": text, - } - print(f"{text=}") - action_event = models.ActionEvent.from_dict(action_dict) - assert action_event.text == "--", action_event + for ( + expected_output, + input_variations, + ) in input_variations_by_expected_output.items(): + for input_variation in input_variations: + action_dict = { + "name": "type", + "text": input_variation, + "canonical_text": input_variation, + } + print(f"{input_variation=}") + action_event = models.ActionEvent.from_dict(action_dict) + assert action_event.text == expected_output, action_event