From 3efca1bcb573f57d0764d0c490cf13dcfd50222a Mon Sep 17 00:00:00 2001 From: Alex Rodionov Date: Sun, 19 May 2024 09:56:15 -0700 Subject: [PATCH 1/3] [py] Add low-level sync API to use DevTools --- py/BUILD.bazel | 2 + py/generate.py | 1 + py/requirements.txt | 1 + py/requirements_lock.txt | 4 + py/selenium/webdriver/remote/webdriver.py | 29 ++++ .../webdriver/remote/websocket_connection.py | 125 ++++++++++++++++++ .../webdriver/common/devtools_tests.py | 39 ++++++ 7 files changed, 201 insertions(+) create mode 100644 py/selenium/webdriver/remote/websocket_connection.py create mode 100644 py/test/selenium/webdriver/common/devtools_tests.py diff --git a/py/BUILD.bazel b/py/BUILD.bazel index f521e51befce0..466bbb67cae55 100644 --- a/py/BUILD.bazel +++ b/py/BUILD.bazel @@ -179,6 +179,7 @@ py_library( requirement("trio_websocket"), requirement("urllib3"), requirement("certifi"), + requirement("websocket-client"), ], ) @@ -263,6 +264,7 @@ py_wheel( "trio-websocket~=0.9", "certifi>=2021.10.8", "typing_extensions>=4.9.0", + "websocket-client>=1.8.0", ], strip_path_prefixes = [ "py/", diff --git a/py/generate.py b/py/generate.py index 5d6ea7cb4b0ef..026e80673b40f 100644 --- a/py/generate.py +++ b/py/generate.py @@ -77,6 +77,7 @@ def event_class(method): ''' A decorator that registers a class as an event class. ''' def decorate(cls): _event_parsers[method] = cls + cls.event_class = method return cls return decorate diff --git a/py/requirements.txt b/py/requirements.txt index bc91501adabf5..5f706ccb5eeec 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -32,5 +32,6 @@ trio-websocket==0.9.2 twine==4.0.2 typing_extensions==4.9.0 urllib3[socks]==2.0.7 +websocket-client==1.8.0 wsproto==1.2.0 zipp==3.17.0 diff --git a/py/requirements_lock.txt b/py/requirements_lock.txt index b4746763676c5..037a8bda20bd5 100644 --- a/py/requirements_lock.txt +++ b/py/requirements_lock.txt @@ -521,6 +521,10 @@ urllib3[socks]==2.0.7 \ # -r py/requirements.txt # requests # twine +websocket-client==1.8.0 \ + --hash=sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526 \ + --hash=sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da + # via -r py/requirements.txt wsproto==1.2.0 \ --hash=sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065 \ --hash=sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736 diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index c1fa5110898b7..afde33c29d0c9 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -62,8 +62,10 @@ from .shadowroot import ShadowRoot from .switch_to import SwitchTo from .webelement import WebElement +from .websocket_connection import WebSocketConnection cdp = None +devtools = None def import_cdp(): @@ -206,6 +208,7 @@ def __init__( self._authenticator_id = None self.start_client() self.start_session(capabilities) + self._websocket_connection = None def __repr__(self): return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>' @@ -1017,6 +1020,32 @@ def get_log(self, log_type): """ return self.execute(Command.GET_LOG, {"type": log_type})["value"] + def start_devtools(self): + global devtools + if self._websocket_connection: + return devtools, self._websocket_connection + else: + global cdp + import_cdp() + + if not devtools: + if self.caps.get("se:cdp"): + ws_url = self.caps.get("se:cdp") + version = self.caps.get("se:cdpVersion").split(".")[0] + else: + version, ws_url = self._get_cdp_details() + + if not ws_url: + raise WebDriverException("Unable to find url to connect to from capabilities") + + devtools = cdp.import_devtools(version) + self._websocket_connection = WebSocketConnection(ws_url) + targets = self._websocket_connection.execute(devtools.target.get_targets()) + target_id = targets[0].target_id + session = self._websocket_connection.execute(devtools.target.attach_to_target(target_id, True)) + self._websocket_connection.session_id = session + return devtools, self._websocket_connection + @asynccontextmanager async def bidi_connection(self): global cdp diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py new file mode 100644 index 0000000000000..5925fdeb00226 --- /dev/null +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -0,0 +1,125 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import logging +from ssl import CERT_NONE +from threading import Thread +from time import sleep + +from websocket import WebSocketApp + +logger = logging.getLogger("websocket") + + +class WebSocketConnection: + _response_wait_timeout = 30 + _response_wait_interval = 0.1 + + _max_log_message_size = 9999 + + def __init__(self, url): + self.session_id = None + self.url = url + + self._id = 0 + self._callbacks = {} + self._messages = {} + self._started = False + + self._start_ws() + self._wait_until(lambda: self._started) + + def close(self): + self._ws_thread.join(timeout=_response_wait_timeout) + self._ws.close() + self._started = False + self._ws = None + + def execute(self, command): + self._id += 1 + payload = self._serialize_command(command) + payload["id"] = self._id + if self.session_id: + payload["sessionId"] = self.session_id + + data = json.dumps(payload) + logger.debug(f"WebSocket -> {data}"[: self._max_log_message_size]) + self._ws.send(data) + + self._wait_until(lambda: self._id in self._messages) + result = self._messages.pop(self._id)["result"] + return self._deserialize_result(result, command) + + def on(self, event, callback): + if event not in self._callbacks: + self._callbacks[event.event_class] = [] + self._callbacks[event.event_class].append(lambda params: callback(event.from_json(params))) + + def _serialize_command(self, command): + return next(command) + + def _deserialize_result(self, result, command): + try: + _ = command.send(result) + raise InternalError("The command's generator function did not exit when expected!") + except StopIteration as exit: + return exit.value + + def _start_ws(self): + def on_open(ws): + self._started = True + + def on_message(ws, message): + self._process_message(message) + + def on_error(ws, error): + logger.debug(f"WebSocket error: {error}") + ws.close() + + def run_socket(): + if self.url.startswith("wss://"): + self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True) + else: + self._ws.run_forever(suppress_origin=True) + + self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error) + self._ws_thread = Thread(target=run_socket) + self._ws_thread.start() + + def _process_message(self, message): + message = json.loads(message) + logger.debug(f"WebSocket <- {message}"[: self._max_log_message_size]) + + if "id" in message: + self._messages[message["id"]] = message + + if "method" in message: + params = message["params"] + for callback in self._callbacks.get(message["method"], []): + callback(params) + + def _wait_until(self, condition): + timeout = self._response_wait_timeout + interval = self._response_wait_interval + + while timeout > 0: + result = condition() + if result: + return result + else: + timeout -= interval + sleep(interval) diff --git a/py/test/selenium/webdriver/common/devtools_tests.py b/py/test/selenium/webdriver/common/devtools_tests.py new file mode 100644 index 0000000000000..9d789aad6cc8f --- /dev/null +++ b/py/test/selenium/webdriver/common/devtools_tests.py @@ -0,0 +1,39 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from selenium.webdriver.common.by import By +from selenium.webdriver.common.log import Log +from selenium.webdriver.support import expected_conditions as EC +from selenium.webdriver.support.ui import WebDriverWait + + +@pytest.mark.xfail_safari +def test_check_console_messages(driver, pages): + devtools, connection = driver.start_devtools() + console_api_calls = [] + + connection.execute(devtools.runtime.enable()) + connection.on(devtools.runtime.ConsoleAPICalled, console_api_calls.append) + driver.execute_script("console.log('I love cheese')") + driver.execute_script("console.error('I love bread')") + WebDriverWait(driver, 5).until(lambda _: len(console_api_calls) == 2) + + assert console_api_calls[0].type_ == "log" + assert console_api_calls[0].args[0].value == "I love cheese" + assert console_api_calls[1].type_ == "error" + assert console_api_calls[1].args[0].value == "I love bread" From de85fbaf25a508f953d75cbc01e257293921d738 Mon Sep 17 00:00:00 2001 From: Alex Rodionov Date: Thu, 6 Jun 2024 18:03:45 -0700 Subject: [PATCH 2/3] [py] Fix linter --- .../webdriver/remote/websocket_connection.py | 18 +++++++++++++----- .../webdriver/common/devtools_tests.py | 3 --- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 5925fdeb00226..8c0cbb7919f1f 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -44,7 +44,7 @@ def __init__(self, url): self._wait_until(lambda: self._started) def close(self): - self._ws_thread.join(timeout=_response_wait_timeout) + self._ws_thread.join(timeout=self._response_wait_timeout) self._ws.close() self._started = False self._ws = None @@ -67,7 +67,9 @@ def execute(self, command): def on(self, event, callback): if event not in self._callbacks: self._callbacks[event.event_class] = [] - self._callbacks[event.event_class].append(lambda params: callback(event.from_json(params))) + self._callbacks[event.event_class].append( + lambda params: callback(event.from_json(params)) + ) def _serialize_command(self, command): return next(command) @@ -75,7 +77,9 @@ def _serialize_command(self, command): def _deserialize_result(self, result, command): try: _ = command.send(result) - raise InternalError("The command's generator function did not exit when expected!") + raise Exception( + "The command's generator function did not exit when expected!" + ) except StopIteration as exit: return exit.value @@ -92,11 +96,15 @@ def on_error(ws, error): def run_socket(): if self.url.startswith("wss://"): - self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True) + self._ws.run_forever( + sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True + ) else: self._ws.run_forever(suppress_origin=True) - self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error) + self._ws = WebSocketApp( + self.url, on_open=on_open, on_message=on_message, on_error=on_error + ) self._ws_thread = Thread(target=run_socket) self._ws_thread.start() diff --git a/py/test/selenium/webdriver/common/devtools_tests.py b/py/test/selenium/webdriver/common/devtools_tests.py index 9d789aad6cc8f..b4b3bcbce361a 100644 --- a/py/test/selenium/webdriver/common/devtools_tests.py +++ b/py/test/selenium/webdriver/common/devtools_tests.py @@ -16,9 +16,6 @@ # under the License. import pytest -from selenium.webdriver.common.by import By -from selenium.webdriver.common.log import Log -from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support.ui import WebDriverWait From a4b8fa4fbc523deec943c82fa1a736185c60bb9e Mon Sep 17 00:00:00 2001 From: Alex Rodionov Date: Thu, 6 Jun 2024 20:39:24 -0700 Subject: [PATCH 3/3] [py] fix lint --- .../webdriver/remote/websocket_connection.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 8c0cbb7919f1f..0eb842bb8d8f8 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -67,9 +67,7 @@ def execute(self, command): def on(self, event, callback): if event not in self._callbacks: self._callbacks[event.event_class] = [] - self._callbacks[event.event_class].append( - lambda params: callback(event.from_json(params)) - ) + self._callbacks[event.event_class].append(lambda params: callback(event.from_json(params))) def _serialize_command(self, command): return next(command) @@ -77,9 +75,7 @@ def _serialize_command(self, command): def _deserialize_result(self, result, command): try: _ = command.send(result) - raise Exception( - "The command's generator function did not exit when expected!" - ) + raise Exception("The command's generator function did not exit when expected!") except StopIteration as exit: return exit.value @@ -96,15 +92,11 @@ def on_error(ws, error): def run_socket(): if self.url.startswith("wss://"): - self._ws.run_forever( - sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True - ) + self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True) else: self._ws.run_forever(suppress_origin=True) - self._ws = WebSocketApp( - self.url, on_open=on_open, on_message=on_message, on_error=on_error - ) + self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error) self._ws_thread = Thread(target=run_socket) self._ws_thread.start()