diff --git a/.gitignore b/.gitignore index e4adb09ff4f98..d0e0fbe7d099c 100644 --- a/.gitignore +++ b/.gitignore @@ -76,6 +76,7 @@ py/selenium/webdriver/remote/isDisplayed.js py/docs/build/ py/build/ py/LICENSE +py/pytestdebug.log selenium.egg-info/ third_party/java/jetty/jetty-repacked.jar *.user diff --git a/py/conftest.py b/py/conftest.py index 3c433f9baae11..6177a899005a1 100644 --- a/py/conftest.py +++ b/py/conftest.py @@ -79,6 +79,14 @@ def pytest_addoption(parser): dest="use_lan_ip", help="Whether to start test server with lan ip instead of localhost", ) + parser.addoption( + "--bidi", + action="store", + dest="bidi", + metavar="BIDI", + default=True, + help="Whether to enable BiDi support", + ) def pytest_ignore_collect(path, config): @@ -166,6 +174,7 @@ def get_options(driver_class, config): browser_path = config.option.binary browser_args = config.option.args headless = bool(config.option.headless) + bidi = bool(config.option.bidi) options = None if browser_path or browser_args: @@ -187,6 +196,14 @@ def get_options(driver_class, config): options.add_argument("--headless=new") if driver_class == "Firefox": options.add_argument("-headless") + + if bidi: + if not options: + options = getattr(webdriver, f"{driver_class}Options")() + + if driver_class == "Chrome" or driver_class == "Edge" or driver_class == "Firefox": + options.web_socket_url = True + return options diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py new file mode 100644 index 0000000000000..88a26b6437ca2 --- /dev/null +++ b/py/selenium/webdriver/common/bidi/script.py @@ -0,0 +1,111 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import typing +from dataclasses import dataclass + +from .session import session_subscribe +from .session import session_unsubscribe + + +class Script: + def __init__(self, conn): + self.conn = conn + self.log_entry_subscribed = False + + def add_console_message_handler(self, handler): + self._subscribe_to_log_entries() + return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("console", handler)) + + def add_javascript_error_handler(self, handler): + self._subscribe_to_log_entries() + return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("javascript", handler)) + + def remove_console_message_handler(self, id): + self.conn.remove_callback(LogEntryAdded, id) + self._unsubscribe_from_log_entries() + + remove_javascript_error_handler = remove_console_message_handler + + def _subscribe_to_log_entries(self): + if not self.log_entry_subscribed: + self.conn.execute(session_subscribe(LogEntryAdded.event_class)) + self.log_entry_subscribed = True + + def _unsubscribe_from_log_entries(self): + if self.log_entry_subscribed and LogEntryAdded.event_class not in self.conn.callbacks: + self.conn.execute(session_unsubscribe(LogEntryAdded.event_class)) + self.log_entry_subscribed = False + + def _handle_log_entry(self, type, handler): + def _handle_log_entry(log_entry): + if log_entry.type_ == type: + handler(log_entry) + + return _handle_log_entry + + +class LogEntryAdded: + event_class = "log.entryAdded" + + @classmethod + def from_json(cls, json): + print(json) + if json["type"] == "console": + return ConsoleLogEntry.from_json(json) + elif json["type"] == "javascript": + return JavaScriptLogEntry.from_json(json) + + +@dataclass +class ConsoleLogEntry: + level: str + text: str + timestamp: str + method: str + args: typing.List[dict] + type_: str + + @classmethod + def from_json(cls, json): + return cls( + level=json["level"], + text=json["text"], + timestamp=json["timestamp"], + method=json["method"], + args=json["args"], + type_=json["type"], + ) + + +@dataclass +class JavaScriptLogEntry: + level: str + text: str + timestamp: str + stacktrace: dict + type_: str + + @classmethod + def from_json(cls, json): + return cls( + level=json["level"], + text=json["text"], + timestamp=json["timestamp"], + stacktrace=json["stackTrace"], + type_=json["type"], + ) diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py new file mode 100644 index 0000000000000..9c334498e5952 --- /dev/null +++ b/py/selenium/webdriver/common/bidi/session.py @@ -0,0 +1,42 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +def session_subscribe(*events, browsing_contexts=[]): + cmd_dict = { + "method": "session.subscribe", + "params": { + "events": events, + }, + } + if browsing_contexts: + cmd_dict["params"]["browsingContexts"] = browsing_contexts + _ = yield cmd_dict + return None + + +def session_unsubscribe(*events, browsing_contexts=[]): + cmd_dict = { + "method": "session.unsubscribe", + "params": { + "events": events, + }, + } + if browsing_contexts: + cmd_dict["params"]["browsingContexts"] = browsing_contexts + _ = yield cmd_dict + return None diff --git a/py/selenium/webdriver/common/options.py b/py/selenium/webdriver/common/options.py index 2f1579115dcce..0066acab76c6d 100644 --- a/py/selenium/webdriver/common/options.py +++ b/py/selenium/webdriver/common/options.py @@ -44,7 +44,13 @@ def __init__(self, name): self.name = name def __get__(self, obj, cls): - if self.name in ("acceptInsecureCerts", "strictFileInteractability", "setWindowRect", "se:downloadsEnabled"): + if self.name in ( + "acceptInsecureCerts", + "strictFileInteractability", + "setWindowRect", + "se:downloadsEnabled", + "webSocketUrl", + ): return obj._caps.get(self.name, False) return obj._caps.get(self.name) @@ -361,6 +367,28 @@ class BaseOptions(metaclass=ABCMeta): - `None` """ + web_socket_url = _BaseOptionsDescriptor("webSocketUrl") + """Gets and Sets WebSocket URL. + + Usage + ----- + - Get + - `self.web_socket_url` + - Set + - `self.web_socket_url` = `value` + + Parameters + ---------- + `value`: `bool` + + Returns + ------- + - Get + - `bool` + - Set + - `None` + """ + def __init__(self) -> None: super().__init__() self._caps = self.default_capabilities diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 9e05d5165e8e8..41c4645bdc686 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -41,6 +41,7 @@ from selenium.common.exceptions import NoSuchCookieException from selenium.common.exceptions import NoSuchElementException from selenium.common.exceptions import WebDriverException +from selenium.webdriver.common.bidi.script import Script from selenium.webdriver.common.by import By from selenium.webdriver.common.options import BaseOptions from selenium.webdriver.common.print_page_options import PrintOptions @@ -209,7 +210,9 @@ def __init__( self._authenticator_id = None self.start_client() self.start_session(capabilities) + self._websocket_connection = None + self._script = None def __repr__(self): return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>' @@ -1067,6 +1070,24 @@ async def bidi_connection(self): async with conn.open_session(target_id) as session: yield BidiConnection(session, cdp, devtools) + @property + def script(self): + if not self._websocket_connection: + self._start_bidi() + + if not self._script: + self._script = Script(self._websocket_connection) + + return self._script + + def _start_bidi(self): + if self.caps.get("webSocketUrl"): + ws_url = self.caps.get("webSocketUrl") + else: + raise WebDriverException("Unable to find url to connect to from capabilities") + + self._websocket_connection = WebSocketConnection(ws_url) + def _get_cdp_details(self): import json diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 0eb842bb8d8f8..ee0e6ba6d26e6 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -22,7 +22,7 @@ from websocket import WebSocketApp -logger = logging.getLogger("websocket") +logger = logging.getLogger(__name__) class WebSocketConnection: @@ -32,11 +32,11 @@ class WebSocketConnection: _max_log_message_size = 9999 def __init__(self, url): + self.callbacks = {} self.session_id = None self.url = url self._id = 0 - self._callbacks = {} self._messages = {} self._started = False @@ -57,17 +57,38 @@ def execute(self, command): payload["sessionId"] = self.session_id data = json.dumps(payload) - logger.debug(f"WebSocket -> {data}"[: self._max_log_message_size]) + logger.debug(f"-> {data}"[: self._max_log_message_size]) self._ws.send(data) self._wait_until(lambda: self._id in self._messages) - result = self._messages.pop(self._id)["result"] - return self._deserialize_result(result, command) + response = self._messages.pop(self._id) - def on(self, event, callback): - if event not in self._callbacks: - self._callbacks[event.event_class] = [] - self._callbacks[event.event_class].append(lambda params: callback(event.from_json(params))) + if "error" in response: + raise Exception(response["error"]) + else: + result = response["result"] + return self._deserialize_result(result, command) + + def add_callback(self, event, callback): + event_name = event.event_class + if event_name not in self.callbacks: + self.callbacks[event_name] = [] + + def _callback(params): + callback(event.from_json(params)) + + self.callbacks[event_name].append(_callback) + return id(_callback) + + on = add_callback + + def remove_callback(self, event, callback_id): + event_name = event.event_class + if event_name in self.callbacks: + for callback in self.callbacks[event_name]: + if id(callback) == callback_id: + self.callbacks[event_name].remove(callback) + return def _serialize_command(self, command): return next(command) @@ -87,7 +108,7 @@ def on_message(ws, message): self._process_message(message) def on_error(ws, error): - logger.debug(f"WebSocket error: {error}") + logger.debug(f"error: {error}") ws.close() def run_socket(): @@ -102,14 +123,14 @@ def run_socket(): def _process_message(self, message): message = json.loads(message) - logger.debug(f"WebSocket <- {message}"[: self._max_log_message_size]) + logger.debug(f"<- {message}"[: self._max_log_message_size]) if "id" in message: self._messages[message["id"]] = message if "method" in message: params = message["params"] - for callback in self._callbacks.get(message["method"], []): + for callback in self.callbacks.get(message["method"], []): callback(params) def _wait_until(self, condition): diff --git a/py/test/selenium/webdriver/common/bidi_script_tests.py b/py/test/selenium/webdriver/common/bidi_script_tests.py new file mode 100644 index 0000000000000..f031cb0cd917d --- /dev/null +++ b/py/test/selenium/webdriver/common/bidi_script_tests.py @@ -0,0 +1,72 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from selenium.webdriver.common.by import By +from selenium.webdriver.support.ui import WebDriverWait + + +@pytest.mark.xfail_safari +def test_logs_console_messages(driver, pages): + log_entries = [] + pages.load("bidi/LogEntryAdded.html") + driver.script.add_console_message_handler(log_entries.append) + + driver.find_element(By.ID, "jsException").click() + driver.find_element(By.ID, "consoleLog").click() + + WebDriverWait(driver, 5).until(lambda _: log_entries) + + log_entry = log_entries[0] + assert log_entry.level == "info" + assert log_entry.method == "log" + assert log_entry.text == "Hello, world!" + assert log_entry.type_ == "console" + + +@pytest.mark.xfail_safari +def test_logs_multiple_console_messages(driver, pages): + log_entries = [] + pages.load("bidi/LogEntryAdded.html") + driver.script.add_console_message_handler(log_entries.append) + driver.script.add_console_message_handler(log_entries.append) + + driver.find_element(By.ID, "jsException").click() + driver.find_element(By.ID, "consoleLog").click() + + WebDriverWait(driver, 5).until(lambda _: len(log_entries) > 1) + assert len(log_entries) == 2 + + +@pytest.mark.xfail_safari +def test_removes_console_message_handler(driver, pages): + log_entries1 = [] + log_entries2 = [] + pages.load("bidi/LogEntryAdded.html") + + id = driver.script.add_console_message_handler(log_entries1.append) + driver.script.add_console_message_handler(log_entries2.append) + + driver.find_element(By.ID, "consoleLog").click() + WebDriverWait(driver, 5).until(lambda _: len(log_entries1) and len(log_entries2)) + + # import pdb; pdb.set_trace() + driver.script.remove_console_message_handler(id) + driver.find_element(By.ID, "consoleLog").click() + + WebDriverWait(driver, 5).until(lambda _: len(log_entries2) == 2) + assert len(log_entries1) == 1 diff --git a/rb/spec/integration/selenium/webdriver/bidi/script_spec.rb b/rb/spec/integration/selenium/webdriver/bidi/script_spec.rb index 34fef2e314393..eb660f9b4d763 100644 --- a/rb/spec/integration/selenium/webdriver/bidi/script_spec.rb +++ b/rb/spec/integration/selenium/webdriver/bidi/script_spec.rb @@ -65,7 +65,7 @@ module WebDriver expect(log_entries.size).to eq(2) end - it 'logs removes console message handler' do + it 'removes console message handler' do driver.navigate.to url_for('bidi/logEntryAdded.html') log_entries = []