diff --git a/py/generate.py b/py/generate.py index 5d6ea7cb4b0ef..026e80673b40f 100644 --- a/py/generate.py +++ b/py/generate.py @@ -77,6 +77,7 @@ def event_class(method): ''' A decorator that registers a class as an event class. ''' def decorate(cls): _event_parsers[method] = cls + cls.event_class = method return cls return decorate diff --git a/py/requirements.txt b/py/requirements.txt index bc91501adabf5..5f706ccb5eeec 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -32,5 +32,6 @@ trio-websocket==0.9.2 twine==4.0.2 typing_extensions==4.9.0 urllib3[socks]==2.0.7 +websocket-client==1.8.0 wsproto==1.2.0 zipp==3.17.0 diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index c1fa5110898b7..afde33c29d0c9 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -62,8 +62,10 @@ from .shadowroot import ShadowRoot from .switch_to import SwitchTo from .webelement import WebElement +from .websocket_connection import WebSocketConnection cdp = None +devtools = None def import_cdp(): @@ -206,6 +208,7 @@ def __init__( self._authenticator_id = None self.start_client() self.start_session(capabilities) + self._websocket_connection = None def __repr__(self): return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>' @@ -1017,6 +1020,32 @@ def get_log(self, log_type): """ return self.execute(Command.GET_LOG, {"type": log_type})["value"] + def start_devtools(self): + global devtools + if self._websocket_connection: + return devtools, self._websocket_connection + else: + global cdp + import_cdp() + + if not devtools: + if self.caps.get("se:cdp"): + ws_url = self.caps.get("se:cdp") + version = self.caps.get("se:cdpVersion").split(".")[0] + else: + version, ws_url = self._get_cdp_details() + + if not ws_url: + raise WebDriverException("Unable to find url to connect to from capabilities") + + devtools = cdp.import_devtools(version) + self._websocket_connection = WebSocketConnection(ws_url) + targets = self._websocket_connection.execute(devtools.target.get_targets()) + target_id = targets[0].target_id + session = self._websocket_connection.execute(devtools.target.attach_to_target(target_id, True)) + self._websocket_connection.session_id = session + return devtools, self._websocket_connection + @asynccontextmanager async def bidi_connection(self): global cdp diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py new file mode 100644 index 0000000000000..1c3b54c44fe52 --- /dev/null +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -0,0 +1,108 @@ +import json +import logging +from ssl import CERT_NONE +from threading import Thread +from time import sleep + +from websocket import WebSocketApp + +logger = logging.getLogger("websocket") + +class WebSocketConnection: + _response_wait_timeout = 30 + _response_wait_interval = 0.1 + + _max_log_message_size = 9999 + + def __init__(self, url): + self.session_id = None + self.url = url + + self._id = 0 + self._callbacks = {} + self._messages = {} + self._started = False + + self._start_ws() + self._wait_until(lambda: self._started) + + def close(self): + self._ws_thread.join(timeout=_response_wait_timeout) + self._ws.close() + self._started = False + self._ws = None + + def execute(self, command): + self._id += 1 + payload = self._serialize_command(command) + payload["id"] = self._id + if self.session_id: + payload["sessionId"] = self.session_id + + data = json.dumps(payload) + logger.debug(f"WebSocket -> {data}"[:self._max_log_message_size]) + self._ws.send(data) + + self._wait_until(lambda: self._id in self._messages) + result = self._messages.pop(self._id)["result"] + return self._deserialize_result(result, command) + + def on(self, event, callback): + if event not in self._callbacks: + self._callbacks[event.event_class] = [] + self._callbacks[event.event_class].append(lambda params: callback(event.from_json(params))) + + def _serialize_command(self, command): + return next(command) + + def _deserialize_result(self, result, command): + try: + _ = command.send(result) + raise InternalError("The command's generator function did not exit when expected!") + except StopIteration as exit: + return exit.value + + def _start_ws(self): + def on_open(ws): + self._started = True + + def on_message(ws, message): + self._process_message(message) + + def on_error(ws, error): + logger.debug(f"WebSocket error: {error}") + ws.close() + + def run_socket(): + if self.url.startswith("wss://"): + self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True) + else: + self._ws.run_forever(suppress_origin=True) + + self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error) + self._ws_thread = Thread(target=run_socket) + self._ws_thread.start() + + def _process_message(self, message): + message = json.loads(message) + logger.debug(f"WebSocket <- {message}"[:self._max_log_message_size]) + + if 'id' in message: + self._messages[message['id']] = message + + if 'method' in message: + params = message['params'] + for callback in self._callbacks.get(message['method'], []): + callback(params) + + def _wait_until(self, condition): + timeout = self._response_wait_timeout + interval = self._response_wait_interval + + while timeout > 0: + result = condition() + if result: + return result + else: + timeout -= interval + sleep(interval) diff --git a/py/test/selenium/webdriver/common/devtools_tests.py b/py/test/selenium/webdriver/common/devtools_tests.py new file mode 100644 index 0000000000000..9d789aad6cc8f --- /dev/null +++ b/py/test/selenium/webdriver/common/devtools_tests.py @@ -0,0 +1,39 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from selenium.webdriver.common.by import By +from selenium.webdriver.common.log import Log +from selenium.webdriver.support import expected_conditions as EC +from selenium.webdriver.support.ui import WebDriverWait + + +@pytest.mark.xfail_safari +def test_check_console_messages(driver, pages): + devtools, connection = driver.start_devtools() + console_api_calls = [] + + connection.execute(devtools.runtime.enable()) + connection.on(devtools.runtime.ConsoleAPICalled, console_api_calls.append) + driver.execute_script("console.log('I love cheese')") + driver.execute_script("console.error('I love bread')") + WebDriverWait(driver, 5).until(lambda _: len(console_api_calls) == 2) + + assert console_api_calls[0].type_ == "log" + assert console_api_calls[0].args[0].value == "I love cheese" + assert console_api_calls[1].type_ == "error" + assert console_api_calls[1].args[0].value == "I love bread"