diff --git a/CHANGELOG b/CHANGELOG index 954a2329..eb517fec 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -2,3 +2,4 @@ [das-atom-db#216] Removed cursor from api of get_matched*() and get_incoming_links() [HOTFIX] Fixes after DAS api change in atom-db/#216 [das/#106] Change patterns and templates indexes to store handles only rather than handles+targets +[#333] Move handshake validation to the client-side. The server will now only send the expected library versions. diff --git a/hyperon_das/utils.py b/hyperon_das/utils.py index 0027f3a2..8a975c60 100644 --- a/hyperon_das/utils.py +++ b/hyperon_das/utils.py @@ -1,3 +1,4 @@ +import re import pickle from dataclasses import dataclass from http import HTTPStatus # noqa: F401 @@ -55,7 +56,9 @@ def __lt__(self, other) -> bool: return self.hashcode < other.hashcode def __repr__(self) -> str: - return str([tuple([label, self.mapping[label]]) for label in sorted(self.labels)]) + return str( + [tuple([label, self.mapping[label]]) for label in sorted(self.labels)] + ) def __str__(self) -> str: return self.__repr__() @@ -86,7 +89,7 @@ def assign( if label in self.labels: return self.mapping[label] == value else: - if parameters and parameters['no_overload'] and value in self.values: + if parameters and parameters["no_overload"] and value in self.values: return False self.labels.add(label) self.values.add(value) @@ -116,17 +119,17 @@ class QueryAnswer: assignment: Optional[Assignment] = None def _recursive_get_handle_set(self, atom, handle_set): - handle_set.add(atom['handle']) - targets = atom.get('targets', None) + handle_set.add(atom["handle"]) + targets = atom.get("targets", None) if targets is not None: for target_atom in targets: self._recursive_get_handle_set(target_atom, handle_set) def _recursive_get_handle_count(self, atom, handle_count): - key = atom['handle'] + key = atom["handle"] total = handle_count.get(key, 0) + 1 handle_count[key] = total - targets = atom.get('targets', None) + targets = atom.get("targets", None) if targets is not None: for target_atom in targets: self._recursive_get_handle_count(target_atom, handle_count) @@ -146,7 +149,7 @@ def get_handle_count(self): def get_package_version(package_name: str) -> str: package_module = import_module(package_name) - return getattr(package_module, '__version__', None) + return getattr(package_module, "__version__", None) def serialize(payload: Any) -> bytes: @@ -160,9 +163,9 @@ def deserialize(payload: bytes) -> Any: @retry(attempts=5, timeout_seconds=120) def connect_to_server(host: str, port: int) -> Tuple[int, str]: """Connect to the server and return the status connection and the url server""" - port = port or '8081' - openfaas_uri = f'http://{host}:{port}/function/query-engine' - aws_lambda_uri = f'http://{host}/prod/query-engine' + port = port or "8081" + openfaas_uri = f"http://{host}:{port}/function/query-engine" + aws_lambda_uri = f"http://{host}/prod/query-engine" for uri in [openfaas_uri, aws_lambda_uri]: status_code, message = check_server_connection(uri) @@ -175,50 +178,97 @@ def connect_to_server(host: str, port: int) -> Tuple[int, str]: def check_server_connection(url: str) -> Tuple[int, str]: - logger().debug(f'connecting to remote Das {url}') - error_msg = None + logger().debug(f"Connecting to remote DAS {url}") + try: das_version = get_package_version('hyperon_das') atom_db_version = get_package_version('hyperon_das_atomdb') with sessions.Session() as session: payload = { - 'action': 'handshake', - 'input': { - 'das_version': das_version, - 'atomdb_version': atom_db_version, - }, + "action": "handshake", + "input": {}, } - response = session.request( - method='POST', + response = session.post( url=url, data=serialize(payload), - headers={'Content-Type': 'application/octet-stream'}, + headers={"Content-Type": "application/octet-stream"}, timeout=10, ) - if response.status_code == HTTPStatus.CONFLICT: - try: - response = deserialize(response.content) - remote_das_version = response.get('das').get('version') - remote_atomdb_version = response.get('atom_db').get('version') - except JSONDecodeError as e: - raise Exception(str(e)) - logger().error( - f"Package version conflict error when connecting to remote DAS. Local DAS: 'das: {das_version} - atom_db: {atom_db_version}' -- Remote DAS: 'das: {remote_das_version} - atom_db: {remote_atomdb_version}'" + + response.raise_for_status() + + remote_data = deserialize(response.content) + remote_das_version = remote_data.get("das", {}).get("version") + remote_atomdb_version = remote_data.get("atom_db", {}).get("version") + + if not remote_das_version or not remote_atomdb_version: + raise ValueError("Invalid response from server, missing version info.") + + is_atomdb_compatible = compare_minor_versions( + remote_atomdb_version, + atom_db_version, + ) + is_das_compatible = compare_minor_versions( + remote_das_version, + das_version, + ) + + if not is_atomdb_compatible or not is_das_compatible: + local_versions = f"das: {das_version}, atom_db: {atom_db_version}" + remote_versions = ( + f"das: {remote_das_version}, atom_db: {remote_atomdb_version}" ) - raise Exception( - f"Local DAS version 'das: {das_version} and atom_db: {atom_db_version}', DAS server is expecting 'das: {remote_das_version} and atom_db: {remote_atomdb_version}'" + error_message = ( + f"Version mismatch. Local: {local_versions}. " + f"Remote: {remote_versions}." ) - if response.status_code == HTTPStatus.OK: - return response.status_code, "Successful connection" - else: - try: - error_msg = deserialize(response.content).get('error') - response.raise_for_status() - except pickle.UnpicklingError: - raise Exception("Error unpickling objects in peer's response") + logger().error(error_message) + raise Exception(error_message) + + return response.status_code, "Successful connection" + + except pickle.UnpicklingError: + logger().error("Failed to unpickle response from server.") + return 500, "Error unpickling objects in server response" except (ConnectionError, Timeout, HTTPError, RequestException) as e: - msg = f"{error_msg} - {str(e)}" if error_msg else str(e) - return 400, msg + logger().error(f"Connection error: {str(e)}") + return 400, f"Connection failed: {str(e)}" except Exception as e: + logger().error(f"Unexpected error: {str(e)}") return 500, str(e) + + +def get_version_components(version_string: str) -> Union[Tuple[int, int, int], None]: + pattern = r"^(\d+)\.(\d+)\.(\d+)$" + match = re.match(pattern, version_string) + + if match: + return tuple(map(int, match.groups())) + + return None + + +def compare_versions(version1: str, version2: str, component_index: int) -> Union[int, None]: + components1 = get_version_components(version1) + components2 = get_version_components(version2) + + if components1 is None or components2 is None: + return None + + for i in range(component_index + 1): + if components1[i] != components2[i]: + return False + + return True + +def compare_major_versions(version1: str, version2: str) -> Union[int, None]: + return compare_versions(version1, version2, 0) + + +def compare_minor_versions(version1: str, version2: str) -> Union[int, None]: + return compare_versions(version1, version2, 1) + + +def compare_patch_versions(version1: str, version2: str) -> Union[int, None]: + return compare_versions(version1, version2, 2) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 04b21a7e..f8c0153f 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,7 +1,14 @@ import pytest from hyperon_das.exceptions import InvalidAssignment -from hyperon_das.utils import Assignment, QueryAnswer +from hyperon_das.utils import ( + Assignment, + QueryAnswer, + compare_major_versions, + compare_minor_versions, + compare_patch_versions, + get_version_components, +) def _build_assignment(mappings): @@ -155,54 +162,109 @@ def _check_handle_set(self, atom, handles, count): def test_get_handle_stats(self): self._check_handle_set(None, set([]), []) - self._check_handle_set({'handle': 'h1'}, ['h1'], [1]) + self._check_handle_set({"handle": "h1"}, ["h1"], [1]) self._check_handle_set( { - 'handle': 'h1', - 'targets': [ - {'handle': 'h2'}, - {'handle': 'h3'}, + "handle": "h1", + "targets": [ + {"handle": "h2"}, + {"handle": "h3"}, ], }, - ['h1', 'h2', 'h3'], + ["h1", "h2", "h3"], [1, 1, 1], ) self._check_handle_set( { - 'handle': 'h1', - 'targets': [ - {'handle': 'h2'}, - {'handle': 'h1'}, + "handle": "h1", + "targets": [ + {"handle": "h2"}, + {"handle": "h1"}, ], }, - ['h1', 'h2'], + ["h1", "h2"], [2, 1], ) self._check_handle_set( { - 'handle': 'h1', - 'targets': [ - {'handle': 'h2'}, + "handle": "h1", + "targets": [ + {"handle": "h2"}, { - 'handle': 'h2', - 'targets': [ - {'handle': 'h4'}, - {'handle': 'h1'}, + "handle": "h2", + "targets": [ + {"handle": "h4"}, + {"handle": "h1"}, ], }, { - 'handle': 'h5', - 'targets': [ - {'handle': 'h1'}, - {'handle': 'h6'}, + "handle": "h5", + "targets": [ + {"handle": "h1"}, + {"handle": "h6"}, ], }, - {'handle': 'h3'}, + {"handle": "h3"}, ], }, - ['h1', 'h2', 'h3', 'h4', 'h5', 'h6'], + ["h1", "h2", "h3", "h4", "h5", "h6"], [3, 2, 1, 1, 1, 1], ) + + +@pytest.mark.parametrize( + "version_string, expected", + [ + ("1.2.3", (1, 2, 3)), + ("10.20.30", (10, 20, 30)), + ("0.0.0", (0, 0, 0)), + ("invalid", None), + ("1.2", None), + ("1.2.3.4", None), + ], +) +def test_get_version_components(version_string, expected): + assert get_version_components(version_string) == expected + + +@pytest.mark.parametrize( + "version1, version2, expected", + [ + ("1.8.0", "1.8.1", True), + ("1.8.0", "1.9.0", False), + ("1.8.0", "2.8.0", False), + ("1.8.0", "1.8.0", True), + ("1.8", "1.8.0", None), + ], +) +def test_compare_minor_versions(version1, version2, expected): + assert compare_minor_versions(version1, version2) == expected + + +@pytest.mark.parametrize( + "version1, version2, expected", + [ + ("1.8.0", "1.8.1", False), + ("1.8.0", "1.8.0", True), + ("2.8.0", "1.8.0", False), + ("1.8", "1.8.0", None), + ], +) +def test_compare_patch_versions(version1, version2, expected): + assert compare_patch_versions(version1, version2) == expected + + +@pytest.mark.parametrize( + "version1, version2, expected", + [ + ("1.8.0", "1.8.1", True), + ("1.8.0", "2.8.0", False), + ("1.8.0", "1.7.0", True), + ("invalid", "1.8.0", None), + ], +) +def test_compare_major_versions(version1, version2, expected): + assert compare_major_versions(version1, version2) == expected