From 4047256f3dd9398ba6b106a912a40f48ac42ed83 Mon Sep 17 00:00:00 2001 From: Boris Date: Mon, 25 Sep 2023 16:32:36 -0500 Subject: [PATCH] update tests --- src/clients/KubernetesClients.py | 32 ++++- src/dependencies/k8_wrapper.py | 56 +++------ src/dependencies/lifecycle.py | 2 - src/dependencies/logs.py | 2 +- src/fastapi_routes/rpc.py | 4 + src/rpc/common.py | 14 ++- src/rpc/error_responses.py | 2 +- src/rpc/models.py | 31 +++-- src/rpc/unauthenticated_handlers.py | 10 +- test/conftest.py | 49 ++------ ...t_lifecycle_helpers.py => test_helpers.py} | 6 +- test/src/dependencies/test_k8_wrapper.py | 113 +++++++++++++++--- test/src/dependencies/test_lifecycle.py | 108 +++++++---------- test/src/dependencies/test_lifecycle2.py | 108 ----------------- test/src/fixtures/fixtures.py | 82 +++++++++++++ 15 files changed, 323 insertions(+), 296 deletions(-) rename test/src/dependencies/{test_lifecycle_helpers.py => test_helpers.py} (87%) delete mode 100644 test/src/dependencies/test_lifecycle2.py create mode 100644 test/src/fixtures/fixtures.py diff --git a/src/clients/KubernetesClients.py b/src/clients/KubernetesClients.py index d0cd289..2281e68 100644 --- a/src/clients/KubernetesClients.py +++ b/src/clients/KubernetesClients.py @@ -2,8 +2,9 @@ from typing import Optional from cacheout import LRUCache +from fastapi.requests import Request from kubernetes import config -from kubernetes.client import CoreV1Api, AppsV1Api, NetworkingV1Api +from kubernetes.client import CoreV1Api, AppsV1Api, NetworkingV1Api, V1Deployment from src.configs.settings import Settings @@ -66,3 +67,32 @@ def __init__( self.network_client = k8s_network_client self.service_status_cache = LRUCache(ttl=10) self.all_service_status_cache = LRUCache(ttl=10) + + +def get_k8s_core_client(request: Request) -> CoreV1Api: + return request.app.state.k8s_clients.core_client + + +def get_k8s_app_client(request: Request) -> AppsV1Api: + return request.app.state.k8s_clients.app_client + + +def get_k8s_networking_client(request: Request) -> NetworkingV1Api: + return request.app.state.k8s_clients.network_client + + +def get_k8s_service_status_cache(request: Request) -> LRUCache: + return request.app.state.k8s_clients.service_status_cache + + +def get_k8s_all_service_status_cache(request: Request) -> LRUCache: + return request.app.state.k8s_clients.all_service_status_cache + + +def check_service_status_cache(request: Request, label_selector_text) -> V1Deployment: + cache = get_k8s_service_status_cache(request) + return cache.get(label_selector_text, None) + + +def populate_service_status_cache(request: Request, label_selector_text, data: list): + get_k8s_service_status_cache(request).set(label_selector_text, data) diff --git a/src/dependencies/k8_wrapper.py b/src/dependencies/k8_wrapper.py index de2cb83..e2d5cde 100644 --- a/src/dependencies/k8_wrapper.py +++ b/src/dependencies/k8_wrapper.py @@ -3,7 +3,6 @@ import time from typing import Optional, List -from cacheout import LRUCache from fastapi import Request from kubernetes import client from kubernetes.client import ( @@ -14,48 +13,23 @@ V1IngressSpec, V1IngressRule, ApiException, - CoreV1Api, - AppsV1Api, - NetworkingV1Api, V1HTTPIngressPath, V1IngressBackend, - V1Deployment, V1HTTPIngressRuleValue, V1Toleration, ) +from src.clients.KubernetesClients import ( + get_k8s_core_client, + get_k8s_app_client, + get_k8s_networking_client, + get_k8s_all_service_status_cache, + check_service_status_cache, + populate_service_status_cache, +) from src.configs.settings import get_settings -def get_k8s_core_client(request: Request) -> CoreV1Api: - return request.app.state.k8s_clients.core_client - - -def get_k8s_app_client(request: Request) -> AppsV1Api: - return request.app.state.k8s_clients.app_client - - -def get_k8s_networking_client(request: Request) -> NetworkingV1Api: - return request.app.state.k8s_clients.network_client - - -def _get_k8s_service_status_cache(request: Request) -> LRUCache: - return request.app.state.k8s_clients.service_status_cache - - -def _get_k8s_all_service_status_cache(request: Request) -> LRUCache: - return request.app.state.k8s_clients.all_service_status_cache - - -def check_service_status_cache(request: Request, label_selector_text) -> V1Deployment: - cache = _get_k8s_service_status_cache(request) - return cache.get(label_selector_text, None) - - -def populate_service_status_cache(request: Request, label_selector_text, data: list): - _get_k8s_service_status_cache(request).set(label_selector_text, data) - - def get_pods_in_namespace( k8s_client: client.CoreV1Api, field_selector=None, @@ -88,7 +62,7 @@ def v1_volume_mount_factory(mounts): return volumes, volume_mounts -def _sanitize_deployment_name(module_name, module_git_commit_hash): +def sanitize_deployment_name(module_name, module_git_commit_hash): """ Create a deployment name based on the module name and git commit hash. Adhere to Kubernetes API naming rules and create valid DNS labels. @@ -108,7 +82,7 @@ def _sanitize_deployment_name(module_name, module_git_commit_hash): def create_clusterip_service(request, module_name, module_git_commit_hash, labels) -> client.V1Service: core_v1_api = get_k8s_core_client(request) - deployment_name, service_name = _sanitize_deployment_name(module_name, module_git_commit_hash) + deployment_name, service_name = sanitize_deployment_name(module_name, module_git_commit_hash) # Define the service service = V1Service( @@ -189,7 +163,7 @@ def _update_ingress_with_retries(request, new_path, namespace, retries=3): def update_ingress_to_point_to_service(request: Request, module_name: str, git_commit_hash: str): settings = request.app.state.settings namespace = settings.namespace - deployment_name, service_name = _sanitize_deployment_name(module_name, git_commit_hash) + deployment_name, service_name = sanitize_deployment_name(module_name, git_commit_hash) # Need to sync this with Status methods path = f"/{settings.external_ds_url.split('/')[-1]}/{module_name}.{git_commit_hash}(/|$)(.*)" new_path = V1HTTPIngressPath(path=path, path_type="ImplementationSpecific", backend=V1IngressBackend(service={"name": service_name, "port": {"number": 5000}})) @@ -197,7 +171,7 @@ def update_ingress_to_point_to_service(request: Request, module_name: str, git_c def create_and_launch_deployment(request, module_name, module_git_commit_hash, image, labels, annotations, env, mounts) -> client.V1LabelSelector: - deployment_name, service_name = _sanitize_deployment_name(module_name, module_git_commit_hash) + deployment_name, service_name = sanitize_deployment_name(module_name, module_git_commit_hash) namespace = request.app.state.settings.namespace annotations["k8s_deployment_name"] = deployment_name @@ -264,7 +238,7 @@ def get_k8s_deployments(request, label_selector="us.kbase.dynamicservice=true") :return: A list of deployments """ - cache = _get_k8s_all_service_status_cache(request) + cache = get_k8s_all_service_status_cache(request) cached_deployments = cache.get(label_selector, None) if cached_deployments is not None: return cached_deployments @@ -278,7 +252,7 @@ def get_k8s_deployments(request, label_selector="us.kbase.dynamicservice=true") def delete_deployment(request, module_name, module_git_commit_hash) -> str: - deployment_name, _ = _sanitize_deployment_name(module_name, module_git_commit_hash) + deployment_name, _ = sanitize_deployment_name(module_name, module_git_commit_hash) namespace = request.app.state.settings.namespace get_k8s_app_client(request).delete_namespaced_deployment(name=deployment_name, namespace=namespace) return deployment_name @@ -292,7 +266,7 @@ def scale_replicas(request, module_name, module_git_commit_hash, replicas: int) def get_logs_for_first_pod_in_deployment(request, module_name, module_git_commit_hash): - deployment_name, _ = _sanitize_deployment_name(module_name, module_git_commit_hash) + deployment_name, _ = sanitize_deployment_name(module_name, module_git_commit_hash) namespace = request.app.state.settings.namespace label_selector_text = f"us.kbase.module.module_name={module_name.lower()}," + f"us.kbase.module.git_commit_hash={module_git_commit_hash}" diff --git a/src/dependencies/lifecycle.py b/src/dependencies/lifecycle.py index d70ee73..757032e 100644 --- a/src/dependencies/lifecycle.py +++ b/src/dependencies/lifecycle.py @@ -72,8 +72,6 @@ def get_volume_mounts(request, module_name, module_version): def _setup_metadata(module_name, requested_module_version, git_commit_hash, version, git_url) -> Tuple[Dict, Dict]: """ - TESTED=TRUE - Convenience method to set up the labels and annotations for a deployment. :param module_name: Module name that comes from the web request diff --git a/src/dependencies/logs.py b/src/dependencies/logs.py index d8862d6..84d2ca3 100644 --- a/src/dependencies/logs.py +++ b/src/dependencies/logs.py @@ -29,7 +29,7 @@ def get_service_log(request: Request, module_name: str, module_version: str) -> return [{"instance_id": pod_name, "log": logs}] -def get_service_log_web_socket(request: Request, module_name: str, module_version: str) -> List[dict]: +def get_service_log_web_socket(request: Request, module_name: str, module_version: str) -> List[dict]: # pragma: no cover """ Get logs for a service. This isn't used anywhere but can require a dependency on rancher if implemented. diff --git a/src/fastapi_routes/rpc.py b/src/fastapi_routes/rpc.py index 2838645..22231d5 100644 --- a/src/fastapi_routes/rpc.py +++ b/src/fastapi_routes/rpc.py @@ -55,8 +55,12 @@ def json_rpc(request: Request, body: bytes = Depends(get_body)) -> Response | HT else: request.state.user_auth_roles = user_auth_roles + print(request, params, jrpc_id) valid_response = request_function(request, params, jrpc_id) # type:JSONRPCResponse + print("RESPONSE IS", valid_response) converted_response = jsonable_encoder(valid_response) + print("CONVERTED RESPONSE IS", converted_response) if "error" in converted_response: + print("HERE YOU GO") return JSONResponse(content=converted_response, status_code=500) return JSONResponse(content=converted_response, status_code=200) diff --git a/src/rpc/common.py b/src/rpc/common.py index afee039..880bb19 100644 --- a/src/rpc/common.py +++ b/src/rpc/common.py @@ -34,8 +34,9 @@ def validate_rpc_request(body): params = json_data.get("params", []) jrpc_id = json_data.get("id", 0) - if not isinstance(method, str) or not isinstance(params, list): + if not isinstance(method, str) and not isinstance(params, list): raise ServerError(message=f"`method` must be a valid SW1 method string. Params must be a dictionary. {json_data}", code=-32600, name="Invalid Request") + print(type(method), type(params), type(jrpc_id)) return method, params, jrpc_id @@ -83,6 +84,16 @@ def handle_rpc_request( method_name = action.__name__ try: params = params[0] + if not isinstance(params, dict): + return JSONRPCResponse( + id=jrpc_id, + error=ErrorResponse( + message=f"Invalid params for ServiceWizard.{method_name}", + code=-32602, + name="Invalid params", + error=f"Params must be a dictionary. Got {type(params)}", + ), + ) except IndexError: return no_params_passed(method=method_name, jrpc_id=jrpc_id) @@ -93,6 +104,7 @@ def handle_rpc_request( try: result = action(request, module_name, module_version) + print("ABOUT TO RETURN RESULT", result) return JSONRPCResponse(id=jrpc_id, result=[result]) except ServerError as e: traceback_str = traceback.format_exc() diff --git a/src/rpc/error_responses.py b/src/rpc/error_responses.py index 876d671..51171d1 100644 --- a/src/rpc/error_responses.py +++ b/src/rpc/error_responses.py @@ -68,4 +68,4 @@ def token_validation_failed(jrpc_id): def json_rpc_response_to_exception(content: JSONRPCResponse, status_code=500): - return JSONResponse(content=content.dict(), status_code=status_code) + return JSONResponse(content=content.model_dump(), status_code=status_code) diff --git a/src/rpc/models.py b/src/rpc/models.py index 9ca6d00..8003987 100644 --- a/src/rpc/models.py +++ b/src/rpc/models.py @@ -1,4 +1,4 @@ -from typing import Optional, Any +from typing import Any, Optional, Union from pydantic import BaseModel @@ -10,22 +10,27 @@ class ErrorResponse(BaseModel): error: str = None + + class JSONRPCResponse(BaseModel): version: str = "1.0" - id: Optional[int | str] - error: Optional[ErrorResponse] + id: Optional[Union[int, str]] = 0 + error: Optional[ErrorResponse] = None result: Any = None - def dict(self, *args, **kwargs): - response_dict = super().dict(*args, **kwargs) - if self.result is None: - response_dict.pop("result", None) + def model_dump(self, *args, **kwargs) -> dict[str, Any]: + # Default behavior for the serialization + serialized_data = super().model_dump(*args, **kwargs) + + # Custom logic to exclude fields based on their values + if serialized_data.get("result") is None: + serialized_data.pop("result", None) - if self.error is None: - response_dict.pop("error", None) - response_dict.pop("version", None) + if serialized_data.get("error") is None: + serialized_data.pop("error", None) + serialized_data.pop("version", None) - if self.id is None: - response_dict.pop("id", None) + if serialized_data.get("id") is None: + serialized_data.pop("id", None) - return response_dict + return serialized_data diff --git a/src/rpc/unauthenticated_handlers.py b/src/rpc/unauthenticated_handlers.py index 39d762a..ef7bfeb 100644 --- a/src/rpc/unauthenticated_handlers.py +++ b/src/rpc/unauthenticated_handlers.py @@ -18,13 +18,11 @@ def start(request: Request, params: list[dict], jrpc_id: str) -> JSONRPCResponse return handle_rpc_request(request, params, jrpc_id, start_deployment) -def status(request: Request, params: list[dict], jrpc_id: str) -> JSONRPCResponse: - if not params: - params = [{}] +def status(request: Request, params: list[dict], jrpc_id: str) -> JSONRPCResponse: # noqa F811 + params = [{}] return handle_rpc_request(request, params, jrpc_id, get_status) -def version(request: Request, params: list[dict], jrpc_id: str) -> JSONRPCResponse: - if not params: - params = [{}] +def version(request: Request, params: list[dict], jrpc_id: str) -> JSONRPCResponse: # noqa F811 + params = [{}] return handle_rpc_request(request, params, jrpc_id, get_version) diff --git a/test/conftest.py b/test/conftest.py index 4be0ff6..eedcc91 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,50 +1,21 @@ import os +from glob import glob import pytest from dotenv import load_dotenv +def _as_module(fixture_path: str) -> str: + return fixture_path.replace("/", ".").replace("\\", ".").replace(".py", "") + +def pytest_collectreport(report): + print("CONFTEST loaded") + @pytest.fixture(autouse=True) def load_environment(): # Ensure that the environment variables are loaded before running the tests - load_dotenv() + load_dotenv(os.environ.get("DOTENV_FILE_LOCATION", ".env")) -@pytest.fixture(autouse=True) -def generate_kubeconfig(): - # Generate a kubeconfig file for testing - # Overwrite kubeconfig - os.environ["KUBECONFIG"] = "test_kubeconfig_file" - kubeconfig_path = os.environ["KUBECONFIG"] - - kubeconfig_content = """\ -apiVersion: v1 -kind: Config -current-context: test-context -clusters: -- name: test-cluster - cluster: - server: https://test-api-server - insecure-skip-tls-verify: true -contexts: -- name: test-context - context: - cluster: test-cluster - user: test-user -users: -- name: test-user - user: - exec: - command: echo - apiVersion: client.authentication.k8s.io/v1alpha1 - args: - - "access_token" -""" - - with open(kubeconfig_path, "w") as kubeconfig_file: - kubeconfig_file.write(kubeconfig_content.strip()) - - yield - - # Clean up the generated kubeconfig file after the tests - os.remove(kubeconfig_path) + +pytest_plugins = [_as_module(fixture) for fixture in glob("test/src/fixtures/[!_]*.py")] diff --git a/test/src/dependencies/test_lifecycle_helpers.py b/test/src/dependencies/test_helpers.py similarity index 87% rename from test/src/dependencies/test_lifecycle_helpers.py rename to test/src/dependencies/test_helpers.py index 19f8a57..80eec63 100644 --- a/test/src/dependencies/test_lifecycle_helpers.py +++ b/test/src/dependencies/test_helpers.py @@ -6,14 +6,14 @@ def get_running_deployment(deployment_name) -> DynamicServiceStatus: module_info = _create_sample_module_info() - deployment = _create_sample_deployment(deployment_name=deployment_name, ready_replicas=1) + deployment = create_sample_deployment(deployment_name=deployment_name, ready_replicas=1, replicas=1, available_replicas=1, unavailable_replicas=0) deployment_status = _create_deployment_status(module_info, deployment) return deployment_status def get_stopped_deployment(deployment_name) -> DynamicServiceStatus: module_info = _create_sample_module_info() - deployment = _create_sample_deployment(deployment_name=deployment_name, ready_replicas=0, available_replicas=0) + deployment = create_sample_deployment(deployment_name=deployment_name, ready_replicas=0, available_replicas=0, unavailable_replicas=1, replicas=0) deployment_status = _create_deployment_status(module_info, deployment) return deployment_status @@ -34,7 +34,7 @@ def _create_deployment_status(module_info, deployment) -> DynamicServiceStatus: ) -def _create_sample_deployment(deployment_name, replicas=1, ready_replicas=1, available_replicas=1, unavailable_replicas=0): +def create_sample_deployment(deployment_name, replicas, ready_replicas, available_replicas, unavailable_replicas): deployment_status = V1DeploymentStatus( updated_replicas=replicas, ready_replicas=ready_replicas, available_replicas=available_replicas, unavailable_replicas=unavailable_replicas ) diff --git a/test/src/dependencies/test_k8_wrapper.py b/test/src/dependencies/test_k8_wrapper.py index afe42ef..0c1db46 100644 --- a/test/src/dependencies/test_k8_wrapper.py +++ b/test/src/dependencies/test_k8_wrapper.py @@ -5,7 +5,7 @@ from src.dependencies.k8_wrapper import ( create_clusterip_service, - _sanitize_deployment_name, + sanitize_deployment_name, update_ingress_to_point_to_service, create_and_launch_deployment, query_k8s_deployment_status, @@ -13,6 +13,7 @@ delete_deployment, scale_replicas, get_logs_for_first_pod_in_deployment, + get_k8s_deployment_status_from_label, ) # Sample Data @@ -58,22 +59,109 @@ def test_create_and_launch_deployment(mock_get_k8s_app_client): assert isinstance(result, client.V1LabelSelector) -@patch("src.dependencies.k8_wrapper._get_k8s_service_status_cache") -@patch("src.dependencies.k8_wrapper.get_k8s_app_client") -def test_query_k8s_deployment_status(mock_get_k8s_app_client, mock_get_k8s_service_status_cache): +@patch("src.dependencies.k8_wrapper.check_service_status_cache") +@patch("src.dependencies.k8_wrapper._get_deployment_status") +def test_query_k8s_deployment_status( mock__get_deployment_status, mock_check_service_status_cache, mock_request,): + + module_info = mock_request.app.state.mock_module_info + module_name = module_info["module_name"] + module_git_commit_hash = module_info["git_commit_hash"] + result = query_k8s_deployment_status(mock_request, module_name, module_git_commit_hash) + assert mock__get_deployment_status.call_count == 1 + assert mock__get_deployment_status.called_with(mock_request, module_name, module_git_commit_hash) + + + + assert result == module_info + # ls = label_selector: client.V1LabelSelector + selector = client.V1LabelSelector(match_labels={"app": "d-test-module-1234567"}) + result = get_k8s_deployment_status_from_label(Mock(), selector) + assert result == "deployment1" + # Assert that _get_deployment_status was called with the deployment + """ + def query_k8s_deployment_status(request, module_name, module_git_commit_hash) -> client.V1Deployment: + label_selector_text = f"us.kbase.module.module_name={module_name.lower()}," + f"us.kbase.module.git_commit_hash={module_git_commit_hash}" + return _get_deployment_status(request, label_selector_text) + + + def get_k8s_deployment_status_from_label(request, label_selector: client.V1LabelSelector) -> client.V1Deployment: + label_selector_text = ",".join([f"{key}={value}" for key, value in label_selector.match_labels.items()]) + return _get_deployment_status(request, label_selector_text) + + """ + assert mock__get_deployment_status.call_count == 2 + # Mock the service status cache to return None - mock_cache = Mock() - mock_cache.get.return_value = None - mock_get_k8s_service_status_cache.return_value = mock_cache + + # mock_check_service_status_cache.return_value = "deployment1" + # mock__get_deployment_status.return_value = LRUCache(ttl=10) # Mock the Kubernetes client to return the deployment - mock_get_k8s_app_client.return_value.list_namespaced_deployment.return_value.items = ["deployment1"] - result = query_k8s_deployment_status(Mock(), sample_module_name, sample_git_commit_hash) + # mock_get_k8s_app_client.return_value.list_namespaced_deployment.return_value.items = ["deployment1"] + + +@patch("src.dependencies.k8_wrapper._get_deployment_status") +@patch("src.dependencies.k8_wrapper.check_service_status_cache") +def test_query_k8s_deployment_status2(mock_check_service_status_cache, mock__get_deployment_status): + # Mock the service status cache to return "deployment1" + mock_check_service_status_cache.return_value = "deployment1" + + # Mock the _get_deployment_status to return "deployment1" + mock__get_deployment_status.return_value = "deployment1" + + # Call the function and assert its result + request_obj = Mock() + result = query_k8s_deployment_status(request_obj, sample_module_name, sample_git_commit_hash) + assert result == "deployment1" + + expected_label1 = f"us.kbase.module.module_name={sample_module_name.lower()},us.kbase.module.git_commit_hash={sample_git_commit_hash}" + mock__get_deployment_status.assert_called_with(request_obj, expected_label1) + + # Create a label selector + selector = client.V1LabelSelector(match_labels={"app": "d-test-module-1234567"}) + result = get_k8s_deployment_status_from_label(request_obj, selector) + assert result == "deployment1" + + expected_label2 = "app=d-test-module-1234567" + mock__get_deployment_status.assert_called_with(request_obj, expected_label2) + + # Assert that _get_deployment_status was called twice + assert mock__get_deployment_status.call_count == 2 + +@patch("src.dependencies.k8_wrapper._get_deployment_status") +@patch("src.dependencies.k8_wrapper.check_service_status_cache") +def test_combined_query_k8s_deployment_status(mock_check_service_status_cache, mock__get_deployment_status): + # Mock the service status cache to return "deployment1" + mock_check_service_status_cache.return_value = "deployment1" + + # Mock the _get_deployment_status to return "deployment1" + mock__get_deployment_status.return_value = "deployment1" + + # Call the function and assert its result + request_obj = Mock() + result = query_k8s_deployment_status(request_obj, sample_module_name, sample_git_commit_hash) + assert result == "deployment1" + + expected_label1 = f"us.kbase.module.module_name={sample_module_name.lower()},us.kbase.module.git_commit_hash={sample_git_commit_hash}" + mock__get_deployment_status.assert_called_with(request_obj, expected_label1) + + # Create a label selector + selector = client.V1LabelSelector(match_labels={"app": "d-test-module-1234567"}) + result = get_k8s_deployment_status_from_label(request_obj, selector) assert result == "deployment1" + expected_label2 = "app=d-test-module-1234567" + mock__get_deployment_status.assert_called_with(request_obj, expected_label2) + + # Assert that _get_deployment_status was called twice + assert mock__get_deployment_status.call_count == 2 + + -@patch("src.dependencies.k8_wrapper._get_k8s_all_service_status_cache") + + +@patch("src.dependencies.k8_wrapper.get_k8s_all_service_status_cache") @patch("src.dependencies.k8_wrapper.get_k8s_app_client") def test_get_k8s_deployments(mock_get_k8s_app_client, mock_get_k8s_all_service_status_cache): # Mock the all service status cache to return None @@ -139,10 +227,7 @@ def test_get_logs_for_first_pod_in_deployment(mock_get_k8s_core_client): ) def test_sanitize_deployment_name(module_name, git_commit_hash, expected_deployment_name): # When we sanitize the deployment name - deployment_name, _ = _sanitize_deployment_name(module_name, git_commit_hash) + deployment_name, _ = sanitize_deployment_name(module_name, git_commit_hash) # Then the deployment name should match the expected format assert deployment_name == expected_deployment_name assert len(deployment_name) <= 63 - - -# THESE TESTS SUCK, AND SHOULD BE INTEGRATION TESTS INSTEAD diff --git a/test/src/dependencies/test_lifecycle.py b/test/src/dependencies/test_lifecycle.py index c8db20a..45de509 100644 --- a/test/src/dependencies/test_lifecycle.py +++ b/test/src/dependencies/test_lifecycle.py @@ -1,41 +1,16 @@ import logging import re -from unittest.mock import Mock, patch +from unittest.mock import patch import pytest -from fastapi import Request, HTTPException +from fastapi import HTTPException from kubernetes.client import ApiException -from src.clients.CachedCatalogClient import CachedCatalogClient +from src.clients.baseclient import ServerError from src.configs.settings import get_settings from src.dependencies import lifecycle -from test.src.dependencies import test_lifecycle_helpers as tlh - - -def _get_mock_request(): - request = Mock(spec=Request) - request.app.state.settings = get_settings() - - mock_module_info = { - "git_commit_hash": "test_hash", - "version": "test_version", - "git_url": "https://github.com/test/repo", - "module_name": "test_module", - "release_tags": ["test_tag"], - "owners": ["test_owner"], - "docker_img_name": "test_img_name", - } - - request.app.state.catalog_client = Mock(spec=CachedCatalogClient) - request.app.state.catalog_client.get_combined_module_info.return_value = mock_module_info - request.app.state.catalog_client.list_service_volume_mounts.return_value = [] - request.app.state.catalog_client.get_secure_params.return_value = [{"param_name": "test_secure_param_name", "param_value": "test_secure_param_value"}] - return request - - -@pytest.fixture -def mock_request(): - return _get_mock_request() +from src.models.models import DynamicServiceStatus, ServiceStatus +from test.src.dependencies import test_helpers as tlh def test_simple_get_volume_mounts(mock_request): @@ -99,12 +74,13 @@ def test_start_deployment( _create_cluster_ip_service_helper_mock, get_service_status_with_retries_mock, scale_replicas_mock, + mock_request, ): # Test Deployment Does Not Already exist, no need to scale replicas _create_and_launch_deployment_helper_mock.return_value = False _setup_metadata_mock.return_value = {}, {} get_service_status_with_retries_mock.return_value = tlh.get_stopped_deployment("tester") - mock_request = _get_mock_request() + rv = lifecycle.start_deployment(request=mock_request, module_name="test_module", module_version="dev") scale_replicas_mock.assert_not_called() assert rv == tlh.get_stopped_deployment("tester") @@ -204,38 +180,38 @@ def test_create_and_launch_deployment_helper(mock_logging_warning, mock_update_i assert mock_logging_warning.call_count == 1 -# -# @patch("src.dependencies.lifecycle.get_service_status_with_retries") -# @patch("src.dependencies.lifecycle._update_ingress_for_service_helper") -# @patch("src.dependencies.lifecycle._create_cluster_ip_service_helper") -# @patch("src.dependencies.lifecycle._create_and_launch_deployment_helper") -# @patch("src.dependencies.lifecycle.get_env") -# @patch("src.dependencies.lifecycle.get_volume_mounts") -# @patch("src.dependencies.lifecycle._setup_metadata") -# def test_start_deployment_existing( -# mock_setup_metadata, mock_get_volume_mounts, mock_get_env, -# mock_create_and_launch, mock_create_cluster_ip, mock_update_ingress, -# mock_get_status, mock_request): -# # Arrange -# module_name = "test_module" -# module_version = "1.0" -# -# mock_setup_metadata.return_value = ({}, {}) -# mock_get_volume_mounts.return_value = [] -# mock_get_env.return_value = {} -# mock_create_and_launch.return_value = False -# mock_create_cluster_ip.return_value = None -# mock_update_ingress.return_value = None -# mock_get_status.return_value = tlh.get_running_deployment(deployment_name="test_existing_deployment") -# -# # Act -# result = start_deployment(mock_request, module_name, module_version) -# -# # Assert -# assert isinstance(result, DynamicServiceStatus) -# -# -# -# -# -# +@patch("src.dependencies.lifecycle.scale_replicas") +def test_stop_deployment(mock_scale_replicas, mock_request): + mock_request.state.user_auth_roles.is_admin_or_owner.return_value = False + with pytest.raises(ServerError) as e: + lifecycle.stop_deployment(request=mock_request, module_name="test_module", module_version="test_version") + assert mock_request.state.user_auth_roles.is_admin_or_owner.call_count == 1 + assert e.value.code == -32000 + assert e.value.message == "Only admins or module owners can stop dynamic services" + + mock_request.state.user_auth_roles.is_admin_or_owner.return_value = True + + deployment = tlh.create_sample_deployment(deployment_name="test_deployment_name", replicas=0, ready_replicas=0, available_replicas=0, unavailable_replicas=0) + + mock_scale_replicas.return_value = deployment + + rv = lifecycle.stop_deployment(request=mock_request, module_name="test_module", module_version="test_version") + + dds = DynamicServiceStatus( + git_commit_hash="test_hash", + status=ServiceStatus.STOPPED, + version="test_version", + hash="test_hash", + release_tags=["test_tag"], + url="https://ci.kbase.us/dynamic_services/test_module.test_hash", + module_name="test_module", + health=ServiceStatus.STOPPED, + up=0, + deployment_name="test_deployment_name", + replicas=0, + updated_replicas=0, + ready_replicas=0, + available_replicas=0, + unavailable_replicas=0, + ) + assert rv == dds diff --git a/test/src/dependencies/test_lifecycle2.py b/test/src/dependencies/test_lifecycle2.py deleted file mode 100644 index b7bb3b2..0000000 --- a/test/src/dependencies/test_lifecycle2.py +++ /dev/null @@ -1,108 +0,0 @@ -# -# from unittest.mock import patch -# -# import pytest -# -# from src.dependencies.lifecycle import get_env, get_volume_mounts, _setup_metadata, _create_and_launch_deployment_helper, \ -# _create_cluster_ip_service_helper, _update_ingress_for_service_helper, start_deployment, stop_deployment -# -# - -# -# def test_get_volume_mounts_success(): -# # Mocking the response from the KBase Catalog to simulate valid volume mount data -# mock_response = [ -# { -# 'volume_name': 'test_volume', -# 'mount_path': '/test/path', -# 'read_only': False -# } -# ] -# with patch('src.dependencies.lifecycle.get_catalog_volume_mounts', return_value=mock_response): -# mounts = get_volume_mounts('test_module') -# assert len(mounts) == 1 -# assert mounts[0].name == 'test_volume' -# assert mounts[0].mount_path == '/test/path' -# assert mounts[0].read_only == False -# -# def test_get_volume_mounts_invalid_data(): -# # Mocking an invalid response from the KBase Catalog -# mock_response = [ -# { -# 'invalid_key': 'test_volume', -# 'mount_path': '/test/path' -# } -# ] -# with patch('src.dependencies.lifecycle.get_catalog_volume_mounts', return_value=mock_response): -# with pytest.raises(Exception, match="Invalid data from KBase Catalog"): -# get_volume_mounts('test_module') -# -# -# -# def test_setup_metadata_success(): -# labels, annotations = _setup_metadata('test_module', 'test_version', 'test_commit') -# assert labels['module-name'] == 'test_module' -# assert labels['module-version'] == 'test_version' -# assert labels['module-git-commit'] == 'test_commit' -# assert annotations['module-name'] == 'test_module' -# assert annotations['module-version'] == 'test_version' -# assert annotations['module-git-commit'] == 'test_commit' -# -# def test_setup_metadata_missing_params(): -# labels, annotations = _setup_metadata('test_module', None, None) -# assert labels['module-name'] == 'test_module' -# assert 'module-version' not in labels -# assert 'module-git-commit' not in labels -# assert annotations['module-name'] == 'test_module' -# assert 'module-version' not in annotations -# assert 'module-git-commit' not in annotations -# -# -# -# def test_create_and_launch_deployment_helper_success(): -# # Mocking the necessary Kubernetes interactions -# with patch('src.dependencies.lifecycle.some_k8s_function', return_value=True): -# result = _create_and_launch_deployment_helper('test_params') -# assert result is True -# -# def test_create_cluster_ip_service_helper_success(): -# # Mocking the necessary Kubernetes interactions -# with patch('src.dependencies.lifecycle.some_k8s_function', return_value=True): -# result = _create_cluster_ip_service_helper('test_params') -# assert result is True -# -# def test_update_ingress_for_service_helper_success(): -# # Mocking the necessary Kubernetes interactions -# with patch('src.dependencies.lifecycle.some_k8s_function', return_value=True): -# result = _update_ingress_for_service_helper('test_params') -# assert result is True -# -# # Additional tests can be added for failure scenarios, exceptions, etc. -# -# -# -# def test_start_deployment_success(): -# # Mocking the necessary functions and interactions for a successful deployment start -# with patch('src.dependencies.lifecycle.some_k8s_function', return_value=True), patch('src.dependencies.lifecycle.some_other_function', return_value=True): -# result = start_deployment('test_module', 'test_version', 1) -# # Assertions to check successful deployment start -# -# def test_start_deployment_service_exists(): -# # Mocking the scenario where the service already exists -# with patch('src.dependencies.lifecycle.some_k8s_function', return_value=True): -# with pytest.raises(Exception, match="Service already exists"): -# start_deployment('test_module', 'test_version', 1) -# -# def test_stop_deployment_success(): -# # Mocking the necessary functions and interactions for a successful deployment stop -# with patch('src.dependencies.lifecycle.some_k8s_function', return_value=True), patch('src.dependencies.lifecycle.some_other_function', return_value=True): -# result = stop_deployment('test_module', 'test_version') -# # Assertions to check successful deployment stop -# -# def test_stop_deployment_no_rights(): -# # Mocking the scenario where the user doesn't have the necessary rights -# with patch('src.dependencies.lifecycle.some_k8s_function', return_value=False): -# with pytest.raises(Exception, match="User does not have rights"): -# stop_deployment('test_module', 'test_version') -# -# # Additional tests can be added for other scenarios, exceptions, etc. diff --git a/test/src/fixtures/fixtures.py b/test/src/fixtures/fixtures.py new file mode 100644 index 0000000..4c3266c --- /dev/null +++ b/test/src/fixtures/fixtures.py @@ -0,0 +1,82 @@ +import os +from unittest.mock import Mock + +import pytest +from dotenv import load_dotenv +from fastapi import Request + +from src.clients.CachedCatalogClient import CachedCatalogClient +from src.clients.KubernetesClients import K8sClients +from src.configs.settings import get_settings + + +@pytest.fixture(autouse=True) +def mock_request(): + return get_example_mock_request() + + + + +@pytest.fixture(autouse=True) +def generate_kubeconfig(): + # Generate a kubeconfig file for testing + # Overwrite kubeconfig + os.environ["KUBECONFIG"] = "test_kubeconfig_file" + kubeconfig_path = os.environ["KUBECONFIG"] + + kubeconfig_content = """\ +apiVersion: v1 +kind: Config +current-context: test-context +clusters: +- name: test-cluster + cluster: + server: https://test-api-server + insecure-skip-tls-verify: true +contexts: +- name: test-context + context: + cluster: test-cluster + user: test-user +users: +- name: test-user + user: + exec: + command: echo + apiVersion: client.authentication.k8s.io/v1alpha1 + args: + - "access_token" +""" + + with open(kubeconfig_path, "w") as kubeconfig_file: + kubeconfig_file.write(kubeconfig_content.strip()) + + yield + + # Clean up the generated kubeconfig file after the tests + os.remove(kubeconfig_path) + + +def get_example_mock_request(): + request = Mock(spec=Request) + request.app.state.settings = get_settings() + + mock_module_info = { + "git_commit_hash": "test_hash", + "version": "test_version", + "git_url": "https://github.com/test/repo", + "module_name": "test_module", + "release_tags": ["test_tag"], + "owners": ["test_owner"], + "docker_img_name": "test_img_name", + } + + request.app.state.catalog_client = Mock(spec=CachedCatalogClient) + request.app.state.catalog_client.get_combined_module_info.return_value = mock_module_info + request.app.state.catalog_client.list_service_volume_mounts.return_value = [] + request.app.state.catalog_client.get_secure_params.return_value = [{"param_name": "test_secure_param_name", "param_value": "test_secure_param_value"}] + request.app.state.k8s_clients = Mock(spec=K8sClients) + request.app.state.mock_module_info = mock_module_info + + + return request