Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#333] Review the way our libraries are checked in client/server connection #350

Merged
merged 4 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
[das-atom-db#216] Removed cursor from api of get_matched*() and get_incoming_links()
[HOTFIX] Fixes after DAS api change in atom-db/#216
[das/#106] Change patterns and templates indexes to store handles only rather than handles+targets
[#333] Move handshake validation to the client-side. The server will now only send the expected library versions.
132 changes: 91 additions & 41 deletions hyperon_das/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import pickle
from dataclasses import dataclass
from http import HTTPStatus # noqa: F401
Expand Down Expand Up @@ -55,7 +56,9 @@ def __lt__(self, other) -> bool:
return self.hashcode < other.hashcode

def __repr__(self) -> str:
return str([tuple([label, self.mapping[label]]) for label in sorted(self.labels)])
return str(
[tuple([label, self.mapping[label]]) for label in sorted(self.labels)]
)

def __str__(self) -> str:
return self.__repr__()
Expand Down Expand Up @@ -86,7 +89,7 @@ def assign(
if label in self.labels:
return self.mapping[label] == value
else:
if parameters and parameters['no_overload'] and value in self.values:
if parameters and parameters["no_overload"] and value in self.values:
return False
self.labels.add(label)
self.values.add(value)
Expand Down Expand Up @@ -116,17 +119,17 @@ class QueryAnswer:
assignment: Optional[Assignment] = None

def _recursive_get_handle_set(self, atom, handle_set):
handle_set.add(atom['handle'])
targets = atom.get('targets', None)
handle_set.add(atom["handle"])
targets = atom.get("targets", None)
if targets is not None:
for target_atom in targets:
self._recursive_get_handle_set(target_atom, handle_set)

def _recursive_get_handle_count(self, atom, handle_count):
key = atom['handle']
key = atom["handle"]
total = handle_count.get(key, 0) + 1
handle_count[key] = total
targets = atom.get('targets', None)
targets = atom.get("targets", None)
if targets is not None:
for target_atom in targets:
self._recursive_get_handle_count(target_atom, handle_count)
Expand All @@ -146,7 +149,7 @@ def get_handle_count(self):

def get_package_version(package_name: str) -> str:
package_module = import_module(package_name)
return getattr(package_module, '__version__', None)
return getattr(package_module, "__version__", None)


def serialize(payload: Any) -> bytes:
Expand All @@ -160,9 +163,9 @@ def deserialize(payload: bytes) -> Any:
@retry(attempts=5, timeout_seconds=120)
def connect_to_server(host: str, port: int) -> Tuple[int, str]:
"""Connect to the server and return the status connection and the url server"""
port = port or '8081'
openfaas_uri = f'http://{host}:{port}/function/query-engine'
aws_lambda_uri = f'http://{host}/prod/query-engine'
port = port or "8081"
openfaas_uri = f"http://{host}:{port}/function/query-engine"
aws_lambda_uri = f"http://{host}/prod/query-engine"

for uri in [openfaas_uri, aws_lambda_uri]:
status_code, message = check_server_connection(uri)
Expand All @@ -175,50 +178,97 @@ def connect_to_server(host: str, port: int) -> Tuple[int, str]:


def check_server_connection(url: str) -> Tuple[int, str]:
logger().debug(f'connecting to remote Das {url}')
error_msg = None
logger().debug(f"Connecting to remote DAS {url}")

try:
das_version = get_package_version('hyperon_das')
atom_db_version = get_package_version('hyperon_das_atomdb')

with sessions.Session() as session:
payload = {
'action': 'handshake',
'input': {
'das_version': das_version,
'atomdb_version': atom_db_version,
},
"action": "handshake",
"input": {},
}
response = session.request(
method='POST',
response = session.post(
url=url,
data=serialize(payload),
headers={'Content-Type': 'application/octet-stream'},
headers={"Content-Type": "application/octet-stream"},
timeout=10,
)
if response.status_code == HTTPStatus.CONFLICT:
try:
response = deserialize(response.content)
remote_das_version = response.get('das').get('version')
remote_atomdb_version = response.get('atom_db').get('version')
except JSONDecodeError as e:
raise Exception(str(e))
logger().error(
f"Package version conflict error when connecting to remote DAS. Local DAS: 'das: {das_version} - atom_db: {atom_db_version}' -- Remote DAS: 'das: {remote_das_version} - atom_db: {remote_atomdb_version}'"

response.raise_for_status()

remote_data = deserialize(response.content)
remote_das_version = remote_data.get("das", {}).get("version")
remote_atomdb_version = remote_data.get("atom_db", {}).get("version")

if not remote_das_version or not remote_atomdb_version:
raise ValueError("Invalid response from server, missing version info.")

is_atomdb_compatible = compare_minor_versions(
remote_atomdb_version,
atom_db_version,
)
is_das_compatible = compare_minor_versions(
remote_das_version,
das_version,
)

if not is_atomdb_compatible or not is_das_compatible:
local_versions = f"das: {das_version}, atom_db: {atom_db_version}"
remote_versions = (
f"das: {remote_das_version}, atom_db: {remote_atomdb_version}"
)
raise Exception(
f"Local DAS version 'das: {das_version} and atom_db: {atom_db_version}', DAS server is expecting 'das: {remote_das_version} and atom_db: {remote_atomdb_version}'"
error_message = (
f"Version mismatch. Local: {local_versions}. "
f"Remote: {remote_versions}."
)
if response.status_code == HTTPStatus.OK:
return response.status_code, "Successful connection"
else:
try:
error_msg = deserialize(response.content).get('error')
response.raise_for_status()
except pickle.UnpicklingError:
raise Exception("Error unpickling objects in peer's response")
logger().error(error_message)
raise Exception(error_message)

return response.status_code, "Successful connection"

except pickle.UnpicklingError:
logger().error("Failed to unpickle response from server.")
return 500, "Error unpickling objects in server response"
except (ConnectionError, Timeout, HTTPError, RequestException) as e:
msg = f"{error_msg} - {str(e)}" if error_msg else str(e)
return 400, msg
logger().error(f"Connection error: {str(e)}")
return 400, f"Connection failed: {str(e)}"
except Exception as e:
logger().error(f"Unexpected error: {str(e)}")
return 500, str(e)


def get_version_components(version_string: str) -> Union[Tuple[int, int, int], None]:
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
match = re.match(pattern, version_string)

if match:
return tuple(map(int, match.groups()))

return None


def compare_versions(version1: str, version2: str, component_index: int) -> Union[int, None]:
components1 = get_version_components(version1)
components2 = get_version_components(version2)

if components1 is None or components2 is None:
return None

for i in range(component_index + 1):
if components1[i] != components2[i]:
return False

return True

def compare_major_versions(version1: str, version2: str) -> Union[int, None]:
return compare_versions(version1, version2, 0)


def compare_minor_versions(version1: str, version2: str) -> Union[int, None]:
return compare_versions(version1, version2, 1)


def compare_patch_versions(version1: str, version2: str) -> Union[int, None]:
return compare_versions(version1, version2, 2)
112 changes: 87 additions & 25 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import pytest

from hyperon_das.exceptions import InvalidAssignment
from hyperon_das.utils import Assignment, QueryAnswer
from hyperon_das.utils import (
Assignment,
QueryAnswer,
compare_major_versions,
compare_minor_versions,
compare_patch_versions,
get_version_components,
)


def _build_assignment(mappings):
Expand Down Expand Up @@ -155,54 +162,109 @@ def _check_handle_set(self, atom, handles, count):
def test_get_handle_stats(self):
self._check_handle_set(None, set([]), [])

self._check_handle_set({'handle': 'h1'}, ['h1'], [1])
self._check_handle_set({"handle": "h1"}, ["h1"], [1])

self._check_handle_set(
{
'handle': 'h1',
'targets': [
{'handle': 'h2'},
{'handle': 'h3'},
"handle": "h1",
"targets": [
{"handle": "h2"},
{"handle": "h3"},
],
},
['h1', 'h2', 'h3'],
["h1", "h2", "h3"],
[1, 1, 1],
)

self._check_handle_set(
{
'handle': 'h1',
'targets': [
{'handle': 'h2'},
{'handle': 'h1'},
"handle": "h1",
"targets": [
{"handle": "h2"},
{"handle": "h1"},
],
},
['h1', 'h2'],
["h1", "h2"],
[2, 1],
)

self._check_handle_set(
{
'handle': 'h1',
'targets': [
{'handle': 'h2'},
"handle": "h1",
"targets": [
{"handle": "h2"},
{
'handle': 'h2',
'targets': [
{'handle': 'h4'},
{'handle': 'h1'},
"handle": "h2",
"targets": [
{"handle": "h4"},
{"handle": "h1"},
],
},
{
'handle': 'h5',
'targets': [
{'handle': 'h1'},
{'handle': 'h6'},
"handle": "h5",
"targets": [
{"handle": "h1"},
{"handle": "h6"},
],
},
{'handle': 'h3'},
{"handle": "h3"},
],
},
['h1', 'h2', 'h3', 'h4', 'h5', 'h6'],
["h1", "h2", "h3", "h4", "h5", "h6"],
[3, 2, 1, 1, 1, 1],
)


@pytest.mark.parametrize(
"version_string, expected",
[
("1.2.3", (1, 2, 3)),
("10.20.30", (10, 20, 30)),
("0.0.0", (0, 0, 0)),
("invalid", None),
("1.2", None),
("1.2.3.4", None),
],
)
def test_get_version_components(version_string, expected):
assert get_version_components(version_string) == expected


@pytest.mark.parametrize(
"version1, version2, expected",
[
("1.8.0", "1.8.1", True),
("1.8.0", "1.9.0", False),
("1.8.0", "2.8.0", False),
("1.8.0", "1.8.0", True),
("1.8", "1.8.0", None),
],
)
def test_compare_minor_versions(version1, version2, expected):
assert compare_minor_versions(version1, version2) == expected


@pytest.mark.parametrize(
"version1, version2, expected",
[
("1.8.0", "1.8.1", False),
("1.8.0", "1.8.0", True),
("2.8.0", "1.8.0", False),
("1.8", "1.8.0", None),
],
)
def test_compare_patch_versions(version1, version2, expected):
assert compare_patch_versions(version1, version2) == expected


@pytest.mark.parametrize(
"version1, version2, expected",
[
("1.8.0", "1.8.1", True),
("1.8.0", "2.8.0", False),
("1.8.0", "1.7.0", True),
("invalid", "1.8.0", None),
],
)
def test_compare_major_versions(version1, version2, expected):
assert compare_major_versions(version1, version2) == expected
Loading