From 5eefe0fc9d1281c17313ddb08d5efcbdea7d8141 Mon Sep 17 00:00:00 2001 From: WinPlay02 Date: Mon, 20 Nov 2023 02:14:33 +0100 Subject: [PATCH] feat: execute pipelines in new subprocess and send messages back to main process using a queue --- src/safeds_runner/server/main.py | 151 ++++++++------------- src/safeds_runner/server/messages.py | 56 ++++++++ src/safeds_runner/server/module_manager.py | 120 +++++++++++++--- 3 files changed, 215 insertions(+), 112 deletions(-) create mode 100644 src/safeds_runner/server/messages.py diff --git a/src/safeds_runner/server/main.py b/src/safeds_runner/server/main.py index 47b9228..7d50d30 100644 --- a/src/safeds_runner/server/main.py +++ b/src/safeds_runner/server/main.py @@ -2,15 +2,15 @@ import json import logging -import typing -from typing import Any, Optional +from typing import Any -import stack_data -from flask import Flask, request +from flask import Flask from flask_cors import CORS from flask_sock import Sock -from safeds_runner.server.module_manager import execute_pipeline +from safeds_runner.server import messages +from safeds_runner.server.module_manager import execute_pipeline, get_placeholder, set_new_websocket_target, \ + start_message_queue_handling, setup_multiprocessing app = Flask(__name__) # Websocket Configuration @@ -19,38 +19,18 @@ # Allow access from VSCode extension CORS(app, resources={r"/*": {"origins": "vscode-webview://*"}}) +""" +Args should contain every source file that was generated +code: ["" => ["" => "", ...], ...] +main: {"package": , "module": , "pipeline": } +:return: Tuple: (Result String, HTTP Code) +""" -## HTTP Route -@app.route("/PostRunProgram", methods=["POST"]) -def post_run_program(): - """ - Args should contain every source file that was generated - code: ["" => ["" => "", ...], ...] - main: {"package": , "module": , "pipeline": } - :return: Tuple: (Result String, HTTP Code) - """ - logging.debug(f"{request.path}: {request.form}") - if not request.is_json: - return "Body is not JSON", 400 - request_data = request.get_json() - # Validate - valid, invalid_message = validate_message(request_data) - if not valid: - return invalid_message, 400 - code = request_data["code"] - main = request_data["main"] - # Execute - # Dynamically define Safe-DS Modules only in our runtime scope - # TODO forward memoization map here - context_globals = {} - # This should only be called from the extension as it is a security risk - result = execute_pipeline(code, main['package'], main['module'], main['pipeline'], context_globals) - return json.dumps(result), 200 - - -@sock.route("/WSRunProgram") + +@sock.route("/WSMain") def ws_run_program(ws): logging.debug(f"Request to WSRunProgram") + set_new_websocket_target(ws) while True: # This would be a JSON message received_message: str = ws.receive() @@ -65,76 +45,62 @@ def ws_run_program(ws): logging.warn(f"No message type specified in: {received_message}") ws.close(None) return - match received_object["type"]: + if "id" not in received_object: + logging.warn(f"No message id specified in: {received_message}") + ws.close(None) + return + if "data" not in received_object: + logging.warn(f"No message data specified in: {received_message}") + ws.close(None) + return + if not isinstance(received_object["type"], str): + logging.warn(f"Message type is not a string: {received_message}") + ws.close(None) + return + if not isinstance(received_object["id"], str): + logging.warn(f"Message id is not a string: {received_message}") + ws.close(None) + return + request_data = received_object["data"] + message_type = received_object["type"] + execution_id = received_object["id"] + match message_type: case "program": - if "data" not in received_object: - logging.warn(f"No message data specified in: {received_message}") - ws.close(None) - return - request_data = received_object["data"] - valid, invalid_message = validate_message(request_data) + valid, invalid_message = messages.validate_program_message(request_data) if not valid: logging.warn(f"Invalid message data specified in: {received_message} ({invalid_message})") ws.close(None) return code = request_data["code"] main = request_data["main"] - # Execute - # Dynamically define Safe-DS Modules only in our runtime scope - # TODO forward memoization map here - context_globals = {"connection": ws, "send_value": send_value} # This should only be called from the extension as it is a security risk - try: - execute_pipeline(code, main['package'], main['module'], main['pipeline'], context_globals) - send_message(ws, "progress", "done") - except BaseException as error: - send_message(ws, "runtime_error", - {"message": error.__str__(), "backtrace": get_backtrace_info(error)}) - - -def get_backtrace_info(error: BaseException) -> list[dict[str, typing.Any]]: - backtrace_list = [] - for frame in stack_data.core.FrameInfo.stack_data(error.__traceback__): - backtrace_list.append({"file": frame.filename, "line": str(frame.lineno)}) - return backtrace_list + execute_pipeline(code, main['package'], main['module'], main['pipeline'], execution_id) + case "placeholder_query": + valid, invalid_message = messages.validate_placeholder_query_message(request_data) + if not valid: + logging.warn(f"Invalid message data specified in: {received_message} ({invalid_message})") + ws.close(None) + return + placeholder_type, placeholder_value = get_placeholder(execution_id, request_data) + if placeholder_type is not None: + send_websocket_value(ws, request_data, placeholder_type, placeholder_value) + else: + # Send back empty type / value, to communicate that no placeholder exists (yet) + send_websocket_value(ws, request_data, "", "") + case _: + if message_type not in messages.message_types: + logging.warn(f"Invalid message type {message_type}") -def send_value(connection, name: str, var_type: str, value: str): - send_message(connection, "value", {"name": name, "type": var_type, "value": value}) +def send_websocket_value(connection, name: str, var_type: str, value: str): + send_websocket_message(connection, "value", {"name": name, "type": var_type, "value": value}) -def send_message(connection, msg_type: str, msg_data): +def send_websocket_message(connection, msg_type: str, msg_data): message = {"type": msg_type, "data": msg_data} connection.send(json.dumps(message)) -def validate_message(message: dict[str, Any]) -> (bool, Optional[str]): - if "code" not in message: - return False, "No 'code' parameter given" - if "main" not in message: - return False, "No 'main' parameter given" - if "package" not in message["main"] or "module" not in message["main"] or "pipeline" not in message["main"]: - return False, "Invalid 'main' parameter given" - if len(message["main"]) != 3: - return False, "Invalid 'main' parameter given" - main: dict[str, str] = message["main"] - if not isinstance(message["code"], dict): - return False, "Invalid 'code' parameter given" - code: dict = message["code"] - for key in code.keys(): - if not isinstance(key, str): - return False, "Invalid 'code' parameter given" - if not isinstance(code[key], dict): - return False, "Invalid 'code' parameter given" - next_dict: dict = code[key] - for next_key in next_dict.keys(): - if not isinstance(next_key, str): - return False, "Invalid 'code' parameter given" - if not isinstance(next_dict[next_key], str): - return False, "Invalid 'code' parameter given" - return True, None - - if __name__ == "__main__": # Allow prints to be unbuffered by default import functools @@ -143,14 +109,13 @@ def validate_message(message: dict[str, Any]) -> (bool, Optional[str]): builtins.print = functools.partial(print, flush=True) logging.getLogger().setLevel(logging.DEBUG) - from gevent import monkey - - monkey.patch_all() from gevent.pywsgi import WSGIServer parser = argparse.ArgumentParser(description="Start Safe-DS Runner on a specific port.") parser.add_argument('--port', type=int, default=5000, help='Port on which to run the python server.') args = parser.parse_args() + setup_multiprocessing() + start_message_queue_handling() logging.info(f"Starting Safe-DS Runner on port {args.port}") - # TODO Maybe only bind to host=127.0.0.1? Connections from other devices would then not be accepted - WSGIServer(('0.0.0.0', args.port), app).serve_forever() + # Only bind to host=127.0.0.1. Connections from other devices should not be accepted + WSGIServer(('127.0.0.1', args.port), app).serve_forever() diff --git a/src/safeds_runner/server/messages.py b/src/safeds_runner/server/messages.py new file mode 100644 index 0000000..338d4a7 --- /dev/null +++ b/src/safeds_runner/server/messages.py @@ -0,0 +1,56 @@ +import typing + +message_types = ["program", "placeholder_query", "placeholder", "placeholder_value", "runtime_error", + "runtime_progress"] + + +def create_placeholder_description(name: str, placeholder_type: str) -> dict[str, typing.Any]: + return {"name": name, "type": placeholder_type} + + +def create_placeholder_value(name: str, placeholder_type: str, value: str) -> dict[str, typing.Any]: + return {"name": name, "type": placeholder_type, "value": value} + + +def create_runtime_error_description(message: str, backtrace: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + return {"message": message, "backtrace": backtrace} + + +def create_runtime_progress_done() -> str: + return "done" + + +def validate_program_message(message_data: dict[str, typing.Any] | str) -> (bool, typing.Optional[str]): + if not isinstance(message_data, dict): + return False, "Message data is not a JSON object" + if "code" not in message_data: + return False, "No 'code' parameter given" + if "main" not in message_data: + return False, "No 'main' parameter given" + if "package" not in message_data["main"] or "module" not in message_data["main"] or "pipeline" not in message_data[ + "main"]: + return False, "Invalid 'main' parameter given" + if len(message_data["main"]) != 3: + return False, "Invalid 'main' parameter given" + main: dict[str, str] = message_data["main"] + if not isinstance(message_data["code"], dict): + return False, "Invalid 'code' parameter given" + code: dict = message_data["code"] + for key in code.keys(): + if not isinstance(key, str): + return False, "Invalid 'code' parameter given" + if not isinstance(code[key], dict): + return False, "Invalid 'code' parameter given" + next_dict: dict = code[key] + for next_key in next_dict.keys(): + if not isinstance(next_key, str): + return False, "Invalid 'code' parameter given" + if not isinstance(next_dict[next_key], str): + return False, "Invalid 'code' parameter given" + return True, None + + +def validate_placeholder_query_message(message_data: dict[str, typing.Any] | str) -> (bool, typing.Optional[str]): + if not isinstance(message_data, str): + return False, "Message data is not a string" + return True, None diff --git a/src/safeds_runner/server/module_manager.py b/src/safeds_runner/server/module_manager.py index db2cf62..9dd7ea6 100644 --- a/src/safeds_runner/server/module_manager.py +++ b/src/safeds_runner/server/module_manager.py @@ -1,5 +1,7 @@ import importlib.abc -import typing +import multiprocessing +import threading +import queue from abc import ABC from importlib.machinery import ModuleSpec import sys @@ -7,9 +9,22 @@ import types import runpy import logging +import typing +import json import stack_data +multiprocessing_manager = None +placeholder_map = None +messages_queue: queue.Queue | None = None + + +def setup_multiprocessing(): + global multiprocessing_manager, placeholder_map, messages_queue + multiprocessing_manager = multiprocessing.Manager() + placeholder_map = multiprocessing_manager.dict() + messages_queue = multiprocessing_manager.Queue() + class InMemoryLoader(importlib.abc.SourceLoader, ABC): def __init__(self, code_bytes: bytes, filename: str): @@ -74,21 +89,88 @@ def detach(self): sys.meta_path.remove(self) -def _execute_pipeline(code: dict[str, dict[str, str]], sdspackage: str, sdsmodule: str, sdspipeline: str): - pipeline_finder = InMemoryFinder(code) - pipeline_finder.attach() - main_module = f"gen_{sdsmodule}_{sdspipeline}" - try: - runpy.run_module(main_module, run_name="__main__") # TODO Is the Safe-DS-Package relevant here? - except BaseException: - raise # This should keep the backtrace - finally: - pipeline_finder.detach() - - -def execute_pipeline(code: dict[str, dict[str, str]], sdspackage: str, sdsmodule: str, sdspipeline: str, - context_globals: dict): - logging.info(f"Executing {sdspackage}.{sdsmodule}.{sdspipeline}...") - exec('_execute_pipeline(code, sdspackage, sdsmodule, sdspipeline)', context_globals, - {"code": code, "sdspackage": sdspackage, "sdsmodule": sdsmodule, "sdspipeline": sdspipeline, - "_execute_pipeline": _execute_pipeline, "runpy": runpy}) +class PipelineProcess: + def __init__(self, code: dict[str, dict[str, str]], sdspackage: str, sdsmodule: str, sdspipeline: str, + execution_id: str, messages_queue: queue.Queue): + self.code = code + self.sdspackage = sdspackage + self.sdsmodule = sdsmodule + self.sdspipeline = sdspipeline + self.id = execution_id + self.messages_queue = messages_queue + self.process = multiprocessing.Process(target=self._execute, daemon=True) + + def _send_message(self, message_type: str, value: dict[typing.Any, typing.Any] | str) -> None: + global messages_queue + self.messages_queue.put({"type": message_type, "id": self.id, "data": value}) + + def _send_exception(self, exception: BaseException): + backtrace = get_backtrace_info(exception) + self._send_message("runtime_error", {"message": exception.__str__(), "backtrace": backtrace}) + + def _execute(self): + logging.info(f"Executing {self.sdspackage}.{self.sdsmodule}.{self.sdspipeline}...") + pipeline_finder = InMemoryFinder(self.code) + pipeline_finder.attach() + main_module = f"gen_{self.sdsmodule}_{self.sdspipeline}" + try: + runpy.run_module(main_module, run_name="__main__") # TODO Is the Safe-DS-Package relevant here? + self._send_message("progress", "done") + except BaseException as error: + self._send_exception(error) + finally: + pipeline_finder.detach() + + def execute(self): + self.process.start() + + +def get_backtrace_info(error: BaseException) -> list[dict[str, typing.Any]]: + backtrace_list = [] + for frame in stack_data.core.FrameInfo.stack_data(error.__traceback__): + backtrace_list.append({"file": frame.filename, "line": str(frame.lineno)}) + return backtrace_list + + +def execute_pipeline(code: dict[str, dict[str, str]], sdspackage: str, sdsmodule: str, sdspipeline: str, exec_id: str): + global messages_queue + process = PipelineProcess(code, sdspackage, sdsmodule, sdspipeline, exec_id, messages_queue) + process.execute() + + +def get_placeholder(exec_id: str, placeholder_name: str) -> (str | None, typing.Any): + if exec_id not in placeholder_map: + return None, None + if placeholder_name not in placeholder_map[exec_id]: + return None, None + # TODO type + return "anytype", placeholder_map[exec_id][placeholder_name] + + +def save_placeholder(exec_id: str, placeholder_name: str, value: typing.Any) -> None: + if exec_id not in placeholder_map: + placeholder_map[exec_id] = {} + placeholder_map[exec_id][placeholder_name] = value + + +websocket_target = None +messages_queue_thread = None + + +def handle_queue_messages(): + global websocket_target + while True: + message = messages_queue.get() + if websocket_target is not None: + websocket_target.send(json.dumps(message)) + + +def start_message_queue_handling(): + global messages_queue_thread + messages_queue_thread = threading.Thread(target=handle_queue_messages, daemon=True) + messages_queue_thread.start() + + +def set_new_websocket_target(ws): + global websocket_target + websocket_target = ws