diff --git a/py/BUILD.bazel b/py/BUILD.bazel index f521e51befce0..466bbb67cae55 100644 --- a/py/BUILD.bazel +++ b/py/BUILD.bazel @@ -179,6 +179,7 @@ py_library( requirement("trio_websocket"), requirement("urllib3"), requirement("certifi"), + requirement("websocket-client"), ], ) @@ -263,6 +264,7 @@ py_wheel( "trio-websocket~=0.9", "certifi>=2021.10.8", "typing_extensions>=4.9.0", + "websocket-client>=1.8.0", ], strip_path_prefixes = [ "py/", diff --git a/py/generate.py b/py/generate.py index 5d6ea7cb4b0ef..026e80673b40f 100644 --- a/py/generate.py +++ b/py/generate.py @@ -77,6 +77,7 @@ def event_class(method): ''' A decorator that registers a class as an event class. ''' def decorate(cls): _event_parsers[method] = cls + cls.event_class = method return cls return decorate diff --git a/py/requirements.txt b/py/requirements.txt index 64b2ee18ca655..b4b8092407fa8 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -32,5 +32,6 @@ trio-websocket==0.9.2 twine==4.0.2 typing_extensions==4.9.0 urllib3[socks]==2.0.7 +websocket-client==1.8.0 wsproto==1.2.0 zipp==3.17.0 diff --git a/py/requirements_lock.txt b/py/requirements_lock.txt index b4746763676c5..037a8bda20bd5 100644 --- a/py/requirements_lock.txt +++ b/py/requirements_lock.txt @@ -521,6 +521,10 @@ urllib3[socks]==2.0.7 \ # -r py/requirements.txt # requests # twine +websocket-client==1.8.0 \ + --hash=sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526 \ + --hash=sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da + # via -r py/requirements.txt wsproto==1.2.0 \ --hash=sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065 \ --hash=sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736 diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index ccb6b0b5e1830..9e05d5165e8e8 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -63,8 +63,10 @@ from .shadowroot import ShadowRoot from .switch_to import SwitchTo from .webelement import WebElement +from .websocket_connection import WebSocketConnection cdp = None +devtools = None def import_cdp(): @@ -207,6 +209,7 @@ def __init__( self._authenticator_id = None self.start_client() self.start_session(capabilities) + self._websocket_connection = None def __repr__(self): return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>' @@ -1018,6 +1021,32 @@ def get_log(self, log_type): """ return self.execute(Command.GET_LOG, {"type": log_type})["value"] + def start_devtools(self): + global devtools + if self._websocket_connection: + return devtools, self._websocket_connection + else: + global cdp + import_cdp() + + if not devtools: + if self.caps.get("se:cdp"): + ws_url = self.caps.get("se:cdp") + version = self.caps.get("se:cdpVersion").split(".")[0] + else: + version, ws_url = self._get_cdp_details() + + if not ws_url: + raise WebDriverException("Unable to find url to connect to from capabilities") + + devtools = cdp.import_devtools(version) + self._websocket_connection = WebSocketConnection(ws_url) + targets = self._websocket_connection.execute(devtools.target.get_targets()) + target_id = targets[0].target_id + session = self._websocket_connection.execute(devtools.target.attach_to_target(target_id, True)) + self._websocket_connection.session_id = session + return devtools, self._websocket_connection + @asynccontextmanager async def bidi_connection(self): global cdp diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py new file mode 100644 index 0000000000000..0eb842bb8d8f8 --- /dev/null +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -0,0 +1,125 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import logging +from ssl import CERT_NONE +from threading import Thread +from time import sleep + +from websocket import WebSocketApp + +logger = logging.getLogger("websocket") + + +class WebSocketConnection: + _response_wait_timeout = 30 + _response_wait_interval = 0.1 + + _max_log_message_size = 9999 + + def __init__(self, url): + self.session_id = None + self.url = url + + self._id = 0 + self._callbacks = {} + self._messages = {} + self._started = False + + self._start_ws() + self._wait_until(lambda: self._started) + + def close(self): + self._ws_thread.join(timeout=self._response_wait_timeout) + self._ws.close() + self._started = False + self._ws = None + + def execute(self, command): + self._id += 1 + payload = self._serialize_command(command) + payload["id"] = self._id + if self.session_id: + payload["sessionId"] = self.session_id + + data = json.dumps(payload) + logger.debug(f"WebSocket -> {data}"[: self._max_log_message_size]) + self._ws.send(data) + + self._wait_until(lambda: self._id in self._messages) + result = self._messages.pop(self._id)["result"] + return self._deserialize_result(result, command) + + def on(self, event, callback): + if event not in self._callbacks: + self._callbacks[event.event_class] = [] + self._callbacks[event.event_class].append(lambda params: callback(event.from_json(params))) + + def _serialize_command(self, command): + return next(command) + + def _deserialize_result(self, result, command): + try: + _ = command.send(result) + raise Exception("The command's generator function did not exit when expected!") + except StopIteration as exit: + return exit.value + + def _start_ws(self): + def on_open(ws): + self._started = True + + def on_message(ws, message): + self._process_message(message) + + def on_error(ws, error): + logger.debug(f"WebSocket error: {error}") + ws.close() + + def run_socket(): + if self.url.startswith("wss://"): + self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True) + else: + self._ws.run_forever(suppress_origin=True) + + self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error) + self._ws_thread = Thread(target=run_socket) + self._ws_thread.start() + + def _process_message(self, message): + message = json.loads(message) + logger.debug(f"WebSocket <- {message}"[: self._max_log_message_size]) + + if "id" in message: + self._messages[message["id"]] = message + + if "method" in message: + params = message["params"] + for callback in self._callbacks.get(message["method"], []): + callback(params) + + def _wait_until(self, condition): + timeout = self._response_wait_timeout + interval = self._response_wait_interval + + while timeout > 0: + result = condition() + if result: + return result + else: + timeout -= interval + sleep(interval) diff --git a/py/test/selenium/webdriver/common/devtools_tests.py b/py/test/selenium/webdriver/common/devtools_tests.py new file mode 100644 index 0000000000000..b4b3bcbce361a --- /dev/null +++ b/py/test/selenium/webdriver/common/devtools_tests.py @@ -0,0 +1,36 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from selenium.webdriver.support.ui import WebDriverWait + + +@pytest.mark.xfail_safari +def test_check_console_messages(driver, pages): + devtools, connection = driver.start_devtools() + console_api_calls = [] + + connection.execute(devtools.runtime.enable()) + connection.on(devtools.runtime.ConsoleAPICalled, console_api_calls.append) + driver.execute_script("console.log('I love cheese')") + driver.execute_script("console.error('I love bread')") + WebDriverWait(driver, 5).until(lambda _: len(console_api_calls) == 2) + + assert console_api_calls[0].type_ == "log" + assert console_api_calls[0].args[0].value == "I love cheese" + assert console_api_calls[1].type_ == "error" + assert console_api_calls[1].args[0].value == "I love bread"