Skip to content

Commit

Permalink
fix: allow multiple connections to work with the runner (#31)
Browse files Browse the repository at this point in the history
- fix: allow multiple websocket connections in parallel (custom
threading / multiprocessing is initialized before patching for use with
gevent)
- feat: send events (during execution) to all connected (listening)
websocket connections
- Queries (placeholder_query) is still using a request-response pattern
and does not multicast to all connections
- green thread pool size for connections is limited to 8 for now

---------

Co-authored-by: megalinter-bot <[email protected]>
Co-authored-by: Lars Reimann <[email protected]>
  • Loading branch information
3 people authored Dec 16, 2023
1 parent 054cca4 commit 64685a3
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 23 deletions.
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ gevent = "^23.9.1"
[tool.poetry.dev-dependencies]
pytest = "^7.4.3"
pytest-cov = "^4.1.0"
pytest-timeout = "^2.2.0"

[tool.poetry.group.docs.dependencies]
mkdocs = "^1.4.3"
Expand Down
38 changes: 25 additions & 13 deletions src/safeds_runner/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,20 @@ def ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) -> N
Manager used to execute pipelines on, and retrieve placeholders from
"""
logging.debug("Request to WSRunProgram")
pipeline_manager.set_new_websocket_target(ws)
pipeline_manager.connect(ws)
while True:
# This would be a JSON message
received_message: str = ws.receive()
if received_message is None:
logging.debug("Received EOF, closing connection")
pipeline_manager.disconnect(ws)
ws.close()
return
logging.debug("Received Message: %s", received_message)
received_object, error_detail, error_short = parse_validate_message(received_message)
if received_object is None:
logging.error(error_detail)
pipeline_manager.disconnect(ws)
ws.close(message=error_short)
return
match received_object.type:
Expand All @@ -110,16 +112,19 @@ def ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) -> N
program_data, invalid_message = messages.validate_program_message_data(received_object.data)
if program_data is None:
logging.error("Invalid message data specified in: %s (%s)", received_message, invalid_message)
pipeline_manager.disconnect(ws)
ws.close(None, invalid_message)
return
# This should only be called from the extension as it is a security risk
pipeline_manager.execute_pipeline(program_data, received_object.id)
case "placeholder_query":
# For this query, a response can be directly sent to the requesting connection
placeholder_query_data, invalid_message = messages.validate_placeholder_query_message_data(
received_object.data,
)
if placeholder_query_data is None:
logging.error("Invalid message data specified in: %s (%s)", received_message, invalid_message)
pipeline_manager.disconnect(ws)
ws.close(None, invalid_message)
return
placeholder_type, placeholder_value = pipeline_manager.get_placeholder(
Expand All @@ -129,8 +134,8 @@ def ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) -> N
# send back a value message
if placeholder_type is not None:
try:
send_websocket_message(
ws,
broadcast_message(
[ws],
Message(
message_type_placeholder_value,
received_object.id,
Expand All @@ -139,8 +144,8 @@ def ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) -> N
)
except TypeError as _encoding_error:
# if the value can't be encoded send back that the value exists but is not displayable
send_websocket_message(
ws,
broadcast_message(
[ws],
Message(
message_type_placeholder_value,
received_object.id,
Expand All @@ -150,8 +155,8 @@ def ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) -> N
else:
# Send back empty type / value, to communicate that no placeholder exists (yet)
# Use name from query to allow linking a response to a request on the peer
send_websocket_message(
ws,
broadcast_message(
[ws],
Message(
message_type_placeholder_value,
received_object.id,
Expand All @@ -163,18 +168,20 @@ def ws_main(ws: simple_websocket.Server, pipeline_manager: PipelineManager) -> N
logging.warning("Invalid message type: %s", received_object.type)


def send_websocket_message(connection: simple_websocket.Server, message: Message) -> None:
def broadcast_message(connections: list[simple_websocket.Server], message: Message) -> None:
"""
Send any message to the VS Code extension.
Send any message to all the provided connections (to the VS Code extension).
Parameters
----------
connection : simple_websocket.Server
Websocket connection.
connections : list[simple_websocket.Server]
List of Websocket connections that should receive the message.
message : Message
Object that will be sent.
"""
connection.send(json.dumps(message.to_dict(), cls=SafeDsEncoder))
message_encoded = json.dumps(message.to_dict(), cls=SafeDsEncoder)
for connection in connections:
connection.send(message_encoded)


def start_server(port: int) -> None: # pragma: no cover
Expand All @@ -186,8 +193,13 @@ def start_server(port: int) -> None: # pragma: no cover
builtins.print = functools.partial(print, flush=True) # type: ignore[assignment]

logging.getLogger().setLevel(logging.DEBUG)
# Startup early, so our multiprocessing setup works
app_pipeline_manager.startup()
from gevent.monkey import patch_all
from gevent.pywsgi import WSGIServer

# Patch WebSockets to work in parallel
patch_all()
logging.info("Starting Safe-DS Runner on port %s", str(port))
# Only bind to host=127.0.0.1. Connections from other devices should not be accepted
WSGIServer(("127.0.0.1", port), app).serve_forever()
WSGIServer(("127.0.0.1", port), app, spawn=8).serve_forever()
29 changes: 21 additions & 8 deletions src/safeds_runner/server/pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class PipelineManager:
def __init__(self) -> None:
"""Create a new PipelineManager object, which is lazily started, when needed."""
self._placeholder_map: dict = {}
self._websocket_target: simple_websocket.Server | None = None
self._websocket_target: list[simple_websocket.Server] = []

@cached_property
def _multiprocessing_manager(self) -> SyncManager:
Expand All @@ -56,7 +56,7 @@ def _messages_queue_thread(self) -> threading.Thread:
daemon=True,
)

def _startup(self) -> None:
def startup(self) -> None:
"""
Prepare the runner for running Safe-DS pipelines.
Expand All @@ -80,21 +80,34 @@ def _handle_queue_messages(self) -> None:
try:
while self._messages_queue is not None:
message = self._messages_queue.get()
if self._websocket_target is not None:
self._websocket_target.send(json.dumps(message.to_dict()))
message_encoded = json.dumps(message.to_dict())
# only send messages to the same connection once
for connection in set(self._websocket_target):
connection.send(message_encoded)
except BaseException as error: # noqa: BLE001 # pragma: no cover
logging.warning("Message queue terminated: %s", error.__repr__()) # pragma: no cover

def set_new_websocket_target(self, websocket_connection: simple_websocket.Server) -> None:
def connect(self, websocket_connection: simple_websocket.Server) -> None:
"""
Change the websocket connection to relay messages to, which are occurring during pipeline execution.
Add a websocket connection to relay event messages to, which are occurring during pipeline execution.
Parameters
----------
websocket_connection : simple_websocket.Server
New websocket connection.
"""
self._websocket_target = websocket_connection
self._websocket_target.append(websocket_connection)

def disconnect(self, websocket_connection: simple_websocket.Server) -> None:
"""
Remove a websocket target connection to no longer receive messages.
Parameters
----------
websocket_connection : simple_websocket.Server
Websocket connection to be removed.
"""
self._websocket_target.remove(websocket_connection)

def execute_pipeline(
self,
Expand All @@ -111,7 +124,7 @@ def execute_pipeline(
execution_id : str
Unique ID to identify this execution.
"""
self._startup()
self.startup()
if execution_id not in self._placeholder_map:
self._placeholder_map[execution_id] = self._multiprocessing_manager.dict()
process = PipelineProcess(
Expand Down
52 changes: 51 additions & 1 deletion tests/safeds_runner/server/test_websocket_mock.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import json
import logging
import multiprocessing
import os
import sys
import threading
import time

import pytest
import safeds_runner.server.main
import simple_websocket
from safeds_runner.server.main import app_pipeline_manager, ws_main
from safeds_runner.server.messages import (
Message,
Expand Down Expand Up @@ -170,6 +174,7 @@ def get_next_received_message(self) -> str:
)
def test_should_fail_message_validation(websocket_message: str, exception_message: str) -> None:
mock_connection = MockWebsocketConnection([websocket_message])
app_pipeline_manager.connect(mock_connection)
ws_main(mock_connection, app_pipeline_manager)
assert str(mock_connection.close_message) == exception_message

Expand Down Expand Up @@ -212,6 +217,7 @@ def test_should_execute_pipeline_return_exception(
expected_response_runtime_error: Message,
) -> None:
mock_connection = MockWebsocketConnection(messages)
app_pipeline_manager.connect(mock_connection)
ws_main(mock_connection, app_pipeline_manager)
mock_connection.wait_for_messages(1)
exception_message = Message.from_dict(json.loads(mock_connection.get_next_received_message()))
Expand Down Expand Up @@ -300,6 +306,7 @@ def test_should_execute_pipeline_return_valid_placeholder(
) -> None:
# Initial execution
mock_connection = MockWebsocketConnection(initial_messages)
app_pipeline_manager.connect(mock_connection)
ws_main(mock_connection, app_pipeline_manager)
# Wait for at least enough messages to successfully execute pipeline
mock_connection.wait_for_messages(initial_execution_message_wait)
Expand Down Expand Up @@ -374,6 +381,7 @@ def test_should_execute_pipeline_return_valid_placeholder(
)
def test_should_successfully_execute_simple_flow(messages: list[str], expected_response: Message) -> None:
mock_connection = MockWebsocketConnection(messages)
app_pipeline_manager.connect(mock_connection)
ws_main(mock_connection, app_pipeline_manager)
mock_connection.wait_for_messages(1)
query_result_invalid = Message.from_dict(json.loads(mock_connection.get_next_received_message()))
Expand Down Expand Up @@ -408,4 +416,46 @@ def helper_should_shut_itself_down_run_in_subprocess(sub_messages: list[str]) ->
ws_main(mock_connection, PipelineManager())


helper_should_shut_itself_down_run_in_subprocess.__test__ = False # type: ignore[attr-defined]
@pytest.mark.timeout(45)
def test_should_accept_at_least_2_parallel_connections_in_subprocess() -> None:
port = 6000
server_output_pipes_stderr_r, server_output_pipes_stderr_w = multiprocessing.Pipe()
process = multiprocessing.Process(
target=helper_should_accept_at_least_2_parallel_connections_in_subprocess_server,
args=(port, server_output_pipes_stderr_w),
)
process.start()
while process.is_alive():
if not server_output_pipes_stderr_r.poll(0.1):
continue
process_line = str(server_output_pipes_stderr_r.recv()).strip()
# Wait for first line of log
if process_line.startswith("INFO:root:Starting Safe-DS Runner"):
break
connected = False
client1 = None
for _i in range(10):
try:
client1 = simple_websocket.Client.connect(f"ws://127.0.0.1:{port}/WSMain")
client2 = simple_websocket.Client.connect(f"ws://127.0.0.1:{port}/WSMain")
connected = client1.connected and client2.connected
break
except ConnectionRefusedError as e:
logging.warning("Connection refused: %s", e)
connected = False
time.sleep(0.5)
if client1 is not None and client1.connected:
client1.send('{"id": "", "type": "shutdown", "data": ""}')
process.join(5)
if process.is_alive():
process.kill()
assert connected


def helper_should_accept_at_least_2_parallel_connections_in_subprocess_server(
port: int,
pipe: multiprocessing.connection.Connection,
) -> None:
sys.stderr.write = lambda value: pipe.send(value) # type: ignore[method-assign, assignment]
sys.stdout.write = lambda value: pipe.send(value) # type: ignore[method-assign, assignment]
safeds_runner.server.main.start_server(port)

0 comments on commit 64685a3

Please sign in to comment.