Skip to content

Commit

Permalink
[#333] Review the way our libraries are checked in client/server conn…
Browse files Browse the repository at this point in the history
…ection (#350)

* das-query-engine-333: Move handshake validation to the client-side.

* das-query-engine-333: Add tests for version comparison utilities

* das-query-engine-333: Update the http status

Co-authored-by: Pedro Borges Costa <[email protected]>

* das-query-engine-333: Update CHANGELOG file

---------

Co-authored-by: Pedro Borges Costa <[email protected]>
  • Loading branch information
levisingularity and Pedrobc89 authored Oct 2, 2024
1 parent e82d506 commit ea93212
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 66 deletions.
1 change: 1 addition & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
[das-atom-db#216] Removed cursor from api of get_matched*() and get_incoming_links()
[HOTFIX] Fixes after DAS api change in atom-db/#216
[das/#106] Change patterns and templates indexes to store handles only rather than handles+targets
[#333] Move handshake validation to the client-side. The server will now only send the expected library versions.
132 changes: 91 additions & 41 deletions hyperon_das/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import pickle
from dataclasses import dataclass
from http import HTTPStatus # noqa: F401
Expand Down Expand Up @@ -55,7 +56,9 @@ def __lt__(self, other) -> bool:
return self.hashcode < other.hashcode

def __repr__(self) -> str:
return str([tuple([label, self.mapping[label]]) for label in sorted(self.labels)])
return str(
[tuple([label, self.mapping[label]]) for label in sorted(self.labels)]
)

def __str__(self) -> str:
return self.__repr__()
Expand Down Expand Up @@ -86,7 +89,7 @@ def assign(
if label in self.labels:
return self.mapping[label] == value
else:
if parameters and parameters['no_overload'] and value in self.values:
if parameters and parameters["no_overload"] and value in self.values:
return False
self.labels.add(label)
self.values.add(value)
Expand Down Expand Up @@ -116,17 +119,17 @@ class QueryAnswer:
assignment: Optional[Assignment] = None

def _recursive_get_handle_set(self, atom, handle_set):
handle_set.add(atom['handle'])
targets = atom.get('targets', None)
handle_set.add(atom["handle"])
targets = atom.get("targets", None)
if targets is not None:
for target_atom in targets:
self._recursive_get_handle_set(target_atom, handle_set)

def _recursive_get_handle_count(self, atom, handle_count):
key = atom['handle']
key = atom["handle"]
total = handle_count.get(key, 0) + 1
handle_count[key] = total
targets = atom.get('targets', None)
targets = atom.get("targets", None)
if targets is not None:
for target_atom in targets:
self._recursive_get_handle_count(target_atom, handle_count)
Expand All @@ -146,7 +149,7 @@ def get_handle_count(self):

def get_package_version(package_name: str) -> str:
package_module = import_module(package_name)
return getattr(package_module, '__version__', None)
return getattr(package_module, "__version__", None)


def serialize(payload: Any) -> bytes:
Expand All @@ -160,9 +163,9 @@ def deserialize(payload: bytes) -> Any:
@retry(attempts=5, timeout_seconds=120)
def connect_to_server(host: str, port: int) -> Tuple[int, str]:
"""Connect to the server and return the status connection and the url server"""
port = port or '8081'
openfaas_uri = f'http://{host}:{port}/function/query-engine'
aws_lambda_uri = f'http://{host}/prod/query-engine'
port = port or "8081"
openfaas_uri = f"http://{host}:{port}/function/query-engine"
aws_lambda_uri = f"http://{host}/prod/query-engine"

for uri in [openfaas_uri, aws_lambda_uri]:
status_code, message = check_server_connection(uri)
Expand All @@ -175,50 +178,97 @@ def connect_to_server(host: str, port: int) -> Tuple[int, str]:


def check_server_connection(url: str) -> Tuple[int, str]:
logger().debug(f'connecting to remote Das {url}')
error_msg = None
logger().debug(f"Connecting to remote DAS {url}")

try:
das_version = get_package_version('hyperon_das')
atom_db_version = get_package_version('hyperon_das_atomdb')

with sessions.Session() as session:
payload = {
'action': 'handshake',
'input': {
'das_version': das_version,
'atomdb_version': atom_db_version,
},
"action": "handshake",
"input": {},
}
response = session.request(
method='POST',
response = session.post(
url=url,
data=serialize(payload),
headers={'Content-Type': 'application/octet-stream'},
headers={"Content-Type": "application/octet-stream"},
timeout=10,
)
if response.status_code == HTTPStatus.CONFLICT:
try:
response = deserialize(response.content)
remote_das_version = response.get('das').get('version')
remote_atomdb_version = response.get('atom_db').get('version')
except JSONDecodeError as e:
raise Exception(str(e))
logger().error(
f"Package version conflict error when connecting to remote DAS. Local DAS: 'das: {das_version} - atom_db: {atom_db_version}' -- Remote DAS: 'das: {remote_das_version} - atom_db: {remote_atomdb_version}'"

response.raise_for_status()

remote_data = deserialize(response.content)
remote_das_version = remote_data.get("das", {}).get("version")
remote_atomdb_version = remote_data.get("atom_db", {}).get("version")

if not remote_das_version or not remote_atomdb_version:
raise ValueError("Invalid response from server, missing version info.")

is_atomdb_compatible = compare_minor_versions(
remote_atomdb_version,
atom_db_version,
)
is_das_compatible = compare_minor_versions(
remote_das_version,
das_version,
)

if not is_atomdb_compatible or not is_das_compatible:
local_versions = f"das: {das_version}, atom_db: {atom_db_version}"
remote_versions = (
f"das: {remote_das_version}, atom_db: {remote_atomdb_version}"
)
raise Exception(
f"Local DAS version 'das: {das_version} and atom_db: {atom_db_version}', DAS server is expecting 'das: {remote_das_version} and atom_db: {remote_atomdb_version}'"
error_message = (
f"Version mismatch. Local: {local_versions}. "
f"Remote: {remote_versions}."
)
if response.status_code == HTTPStatus.OK:
return response.status_code, "Successful connection"
else:
try:
error_msg = deserialize(response.content).get('error')
response.raise_for_status()
except pickle.UnpicklingError:
raise Exception("Error unpickling objects in peer's response")
logger().error(error_message)
raise Exception(error_message)

return response.status_code, "Successful connection"

except pickle.UnpicklingError:
logger().error("Failed to unpickle response from server.")
return 500, "Error unpickling objects in server response"
except (ConnectionError, Timeout, HTTPError, RequestException) as e:
msg = f"{error_msg} - {str(e)}" if error_msg else str(e)
return 400, msg
logger().error(f"Connection error: {str(e)}")
return 400, f"Connection failed: {str(e)}"
except Exception as e:
logger().error(f"Unexpected error: {str(e)}")
return 500, str(e)


def get_version_components(version_string: str) -> Union[Tuple[int, int, int], None]:
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
match = re.match(pattern, version_string)

if match:
return tuple(map(int, match.groups()))

return None


def compare_versions(version1: str, version2: str, component_index: int) -> Union[int, None]:
components1 = get_version_components(version1)
components2 = get_version_components(version2)

if components1 is None or components2 is None:
return None

for i in range(component_index + 1):
if components1[i] != components2[i]:
return False

return True

def compare_major_versions(version1: str, version2: str) -> Union[int, None]:
return compare_versions(version1, version2, 0)


def compare_minor_versions(version1: str, version2: str) -> Union[int, None]:
return compare_versions(version1, version2, 1)


def compare_patch_versions(version1: str, version2: str) -> Union[int, None]:
return compare_versions(version1, version2, 2)
112 changes: 87 additions & 25 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import pytest

from hyperon_das.exceptions import InvalidAssignment
from hyperon_das.utils import Assignment, QueryAnswer
from hyperon_das.utils import (
Assignment,
QueryAnswer,
compare_major_versions,
compare_minor_versions,
compare_patch_versions,
get_version_components,
)


def _build_assignment(mappings):
Expand Down Expand Up @@ -155,54 +162,109 @@ def _check_handle_set(self, atom, handles, count):
def test_get_handle_stats(self):
self._check_handle_set(None, set([]), [])

self._check_handle_set({'handle': 'h1'}, ['h1'], [1])
self._check_handle_set({"handle": "h1"}, ["h1"], [1])

self._check_handle_set(
{
'handle': 'h1',
'targets': [
{'handle': 'h2'},
{'handle': 'h3'},
"handle": "h1",
"targets": [
{"handle": "h2"},
{"handle": "h3"},
],
},
['h1', 'h2', 'h3'],
["h1", "h2", "h3"],
[1, 1, 1],
)

self._check_handle_set(
{
'handle': 'h1',
'targets': [
{'handle': 'h2'},
{'handle': 'h1'},
"handle": "h1",
"targets": [
{"handle": "h2"},
{"handle": "h1"},
],
},
['h1', 'h2'],
["h1", "h2"],
[2, 1],
)

self._check_handle_set(
{
'handle': 'h1',
'targets': [
{'handle': 'h2'},
"handle": "h1",
"targets": [
{"handle": "h2"},
{
'handle': 'h2',
'targets': [
{'handle': 'h4'},
{'handle': 'h1'},
"handle": "h2",
"targets": [
{"handle": "h4"},
{"handle": "h1"},
],
},
{
'handle': 'h5',
'targets': [
{'handle': 'h1'},
{'handle': 'h6'},
"handle": "h5",
"targets": [
{"handle": "h1"},
{"handle": "h6"},
],
},
{'handle': 'h3'},
{"handle": "h3"},
],
},
['h1', 'h2', 'h3', 'h4', 'h5', 'h6'],
["h1", "h2", "h3", "h4", "h5", "h6"],
[3, 2, 1, 1, 1, 1],
)


@pytest.mark.parametrize(
"version_string, expected",
[
("1.2.3", (1, 2, 3)),
("10.20.30", (10, 20, 30)),
("0.0.0", (0, 0, 0)),
("invalid", None),
("1.2", None),
("1.2.3.4", None),
],
)
def test_get_version_components(version_string, expected):
assert get_version_components(version_string) == expected


@pytest.mark.parametrize(
"version1, version2, expected",
[
("1.8.0", "1.8.1", True),
("1.8.0", "1.9.0", False),
("1.8.0", "2.8.0", False),
("1.8.0", "1.8.0", True),
("1.8", "1.8.0", None),
],
)
def test_compare_minor_versions(version1, version2, expected):
assert compare_minor_versions(version1, version2) == expected


@pytest.mark.parametrize(
"version1, version2, expected",
[
("1.8.0", "1.8.1", False),
("1.8.0", "1.8.0", True),
("2.8.0", "1.8.0", False),
("1.8", "1.8.0", None),
],
)
def test_compare_patch_versions(version1, version2, expected):
assert compare_patch_versions(version1, version2) == expected


@pytest.mark.parametrize(
"version1, version2, expected",
[
("1.8.0", "1.8.1", True),
("1.8.0", "2.8.0", False),
("1.8.0", "1.7.0", True),
("invalid", "1.8.0", None),
],
)
def test_compare_major_versions(version1, version2, expected):
assert compare_major_versions(version1, version2) == expected

0 comments on commit ea93212

Please sign in to comment.