Commit 4047256: update tests
Boris committed Sep 25, 2023
1 parent 3e64749 commit 4047256
Showing 15 changed files with 323 additions and 296 deletions.
32 changes: 31 additions & 1 deletion src/clients/KubernetesClients.py
@@ -2,8 +2,9 @@
from typing import Optional

from cacheout import LRUCache
from fastapi.requests import Request
from kubernetes import config
from kubernetes.client import CoreV1Api, AppsV1Api, NetworkingV1Api
from kubernetes.client import CoreV1Api, AppsV1Api, NetworkingV1Api, V1Deployment

from src.configs.settings import Settings

@@ -66,3 +67,32 @@ def __init__(
self.network_client = k8s_network_client
self.service_status_cache = LRUCache(ttl=10)
self.all_service_status_cache = LRUCache(ttl=10)


def get_k8s_core_client(request: Request) -> CoreV1Api:
return request.app.state.k8s_clients.core_client


def get_k8s_app_client(request: Request) -> AppsV1Api:
return request.app.state.k8s_clients.app_client


def get_k8s_networking_client(request: Request) -> NetworkingV1Api:
return request.app.state.k8s_clients.network_client


def get_k8s_service_status_cache(request: Request) -> LRUCache:
return request.app.state.k8s_clients.service_status_cache


def get_k8s_all_service_status_cache(request: Request) -> LRUCache:
return request.app.state.k8s_clients.all_service_status_cache


def check_service_status_cache(request: Request, label_selector_text) -> V1Deployment:
cache = get_k8s_service_status_cache(request)
return cache.get(label_selector_text, None)


def populate_service_status_cache(request: Request, label_selector_text, data: list):
get_k8s_service_status_cache(request).set(label_selector_text, data)
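
Note (not part of the commit): these module-level accessors read the shared clients off request.app.state, so they can be plugged straight into FastAPI's dependency injection. A minimal sketch, assuming a hypothetical /pods route that does not exist in this repository:

from fastapi import Depends, FastAPI, Request
from kubernetes.client import CoreV1Api

from src.clients.KubernetesClients import get_k8s_core_client

app = FastAPI()


@app.get("/pods")  # hypothetical route, for illustration only
def list_pods(request: Request, core: CoreV1Api = Depends(get_k8s_core_client)):
    # The accessor returns the single CoreV1Api instance stored on
    # app.state.k8s_clients, so all requests share one client/connection pool.
    namespace = request.app.state.settings.namespace
    pods = core.list_namespaced_pod(namespace=namespace)
    return {"pods": [p.metadata.name for p in pods.items]}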
56 changes: 15 additions & 41 deletions src/dependencies/k8_wrapper.py
@@ -3,7 +3,6 @@
import time
from typing import Optional, List

from cacheout import LRUCache
from fastapi import Request
from kubernetes import client
from kubernetes.client import (
@@ -14,48 +13,23 @@
V1IngressSpec,
V1IngressRule,
ApiException,
CoreV1Api,
AppsV1Api,
NetworkingV1Api,
V1HTTPIngressPath,
V1IngressBackend,
V1Deployment,
V1HTTPIngressRuleValue,
V1Toleration,
)

from src.clients.KubernetesClients import (
get_k8s_core_client,
get_k8s_app_client,
get_k8s_networking_client,
get_k8s_all_service_status_cache,
check_service_status_cache,
populate_service_status_cache,
)
from src.configs.settings import get_settings


def get_k8s_core_client(request: Request) -> CoreV1Api:
return request.app.state.k8s_clients.core_client


def get_k8s_app_client(request: Request) -> AppsV1Api:
return request.app.state.k8s_clients.app_client


def get_k8s_networking_client(request: Request) -> NetworkingV1Api:
return request.app.state.k8s_clients.network_client


def _get_k8s_service_status_cache(request: Request) -> LRUCache:
return request.app.state.k8s_clients.service_status_cache


def _get_k8s_all_service_status_cache(request: Request) -> LRUCache:
return request.app.state.k8s_clients.all_service_status_cache


def check_service_status_cache(request: Request, label_selector_text) -> V1Deployment:
cache = _get_k8s_service_status_cache(request)
return cache.get(label_selector_text, None)


def populate_service_status_cache(request: Request, label_selector_text, data: list):
_get_k8s_service_status_cache(request).set(label_selector_text, data)


def get_pods_in_namespace(
k8s_client: client.CoreV1Api,
field_selector=None,
@@ -88,7 +62,7 @@ def v1_volume_mount_factory(mounts):
return volumes, volume_mounts


def _sanitize_deployment_name(module_name, module_git_commit_hash):
def sanitize_deployment_name(module_name, module_git_commit_hash):
"""
Create a deployment name based on the module name and git commit hash.
Adhere to Kubernetes API naming rules and create valid DNS labels.
@@ -108,7 +82,7 @@ def _sanitize_deployment_name(module_name, module_git_commit_hash):

def create_clusterip_service(request, module_name, module_git_commit_hash, labels) -> client.V1Service:
core_v1_api = get_k8s_core_client(request)
deployment_name, service_name = _sanitize_deployment_name(module_name, module_git_commit_hash)
deployment_name, service_name = sanitize_deployment_name(module_name, module_git_commit_hash)

# Define the service
service = V1Service(
@@ -189,15 +163,15 @@ def _update_ingress_with_retries(request, new_path, namespace, retries=3):
def update_ingress_to_point_to_service(request: Request, module_name: str, git_commit_hash: str):
settings = request.app.state.settings
namespace = settings.namespace
deployment_name, service_name = _sanitize_deployment_name(module_name, git_commit_hash)
deployment_name, service_name = sanitize_deployment_name(module_name, git_commit_hash)
# Need to sync this with Status methods
path = f"/{settings.external_ds_url.split('/')[-1]}/{module_name}.{git_commit_hash}(/|$)(.*)"
new_path = V1HTTPIngressPath(path=path, path_type="ImplementationSpecific", backend=V1IngressBackend(service={"name": service_name, "port": {"number": 5000}}))
_update_ingress_with_retries(request=request, new_path=new_path, namespace=namespace)
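
For reference, the path built here has the shape /<last path segment of external_ds_url>/<module>.<hash>(/|$)(.*). A small illustration with made-up values for external_ds_url, module_name, and git_commit_hash:

# Illustrative values only; the real ones come from Settings and the request.
external_ds_url = "https://ci.kbase.us/services/dynamic_services"
module_name = "NarrativeService"
git_commit_hash = "8a9f1e2"

path = f"/{external_ds_url.split('/')[-1]}/{module_name}.{git_commit_hash}(/|$)(.*)"
# -> /dynamic_services/NarrativeService.8a9f1e2(/|$)(.*)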


def create_and_launch_deployment(request, module_name, module_git_commit_hash, image, labels, annotations, env, mounts) -> client.V1LabelSelector:
deployment_name, service_name = _sanitize_deployment_name(module_name, module_git_commit_hash)
deployment_name, service_name = sanitize_deployment_name(module_name, module_git_commit_hash)
namespace = request.app.state.settings.namespace

annotations["k8s_deployment_name"] = deployment_name
@@ -264,7 +238,7 @@ def get_k8s_deployments(request, label_selector="us.kbase.dynamicservice=true")
:return: A list of deployments
"""

cache = _get_k8s_all_service_status_cache(request)
cache = get_k8s_all_service_status_cache(request)
cached_deployments = cache.get(label_selector, None)
if cached_deployments is not None:
return cached_deployments
@@ -278,7 +252,7 @@


def delete_deployment(request, module_name, module_git_commit_hash) -> str:
deployment_name, _ = _sanitize_deployment_name(module_name, module_git_commit_hash)
deployment_name, _ = sanitize_deployment_name(module_name, module_git_commit_hash)
namespace = request.app.state.settings.namespace
get_k8s_app_client(request).delete_namespaced_deployment(name=deployment_name, namespace=namespace)
return deployment_name
@@ -292,7 +266,7 @@ def scale_replicas(request, module_name, module_git_commit_hash, replicas: int)


def get_logs_for_first_pod_in_deployment(request, module_name, module_git_commit_hash):
deployment_name, _ = _sanitize_deployment_name(module_name, module_git_commit_hash)
deployment_name, _ = sanitize_deployment_name(module_name, module_git_commit_hash)
namespace = request.app.state.settings.namespace
label_selector_text = f"us.kbase.module.module_name={module_name.lower()}," + f"us.kbase.module.git_commit_hash={module_git_commit_hash}"

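The rename from _sanitize_deployment_name to sanitize_deployment_name makes the helper public, presumably so the tests can import it; its body is collapsed in this view. A rough sketch of what a DNS-1123-compliant implementation could look like — the prefixes, length limit, and exact character handling below are assumptions, not the repository's code:

import re


def sanitize_deployment_name(module_name, module_git_commit_hash):
    # Sketch only: lowercase, replace disallowed characters with '-', and keep
    # the result within the 63-character DNS label limit.
    raw = f"{module_name}-{module_git_commit_hash}".lower()
    sanitized = re.sub(r"[^a-z0-9-]", "-", raw).strip("-")
    deployment_name = f"d-{sanitized}"[:63].rstrip("-")
    service_name = f"s-{sanitized}"[:63].rstrip("-")
    return deployment_name, service_name
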
2 changes: 0 additions & 2 deletions src/dependencies/lifecycle.py
@@ -72,8 +72,6 @@ def get_volume_mounts(request, module_name, module_version):

def _setup_metadata(module_name, requested_module_version, git_commit_hash, version, git_url) -> Tuple[Dict, Dict]:
"""
TESTED=TRUE
Convenience method to set up the labels and annotations for a deployment.
:param module_name: Module name that comes from the web request
2 changes: 1 addition & 1 deletion src/dependencies/logs.py
@@ -29,7 +29,7 @@ def get_service_log(request: Request, module_name: str, module_version: str) ->
return [{"instance_id": pod_name, "log": logs}]


def get_service_log_web_socket(request: Request, module_name: str, module_version: str) -> List[dict]:
def get_service_log_web_socket(request: Request, module_name: str, module_version: str) -> List[dict]: # pragma: no cover
"""
Get logs for a service. This isn't used anywhere but can require a dependency on rancher if implemented.
4 changes: 4 additions & 0 deletions src/fastapi_routes/rpc.py
@@ -55,8 +55,12 @@ def json_rpc(request: Request, body: bytes = Depends(get_body)) -> Response | HT
else:
request.state.user_auth_roles = user_auth_roles

print(request, params, jrpc_id)
valid_response = request_function(request, params, jrpc_id) # type:JSONRPCResponse
print("RESPONSE IS", valid_response)
converted_response = jsonable_encoder(valid_response)
print("CONVERTED RESPONSE IS", converted_response)
if "error" in converted_response:
print("HERE YOU GO")
return JSONResponse(content=converted_response, status_code=500)
return JSONResponse(content=converted_response, status_code=200)
14 changes: 13 additions & 1 deletion src/rpc/common.py
@@ -34,8 +34,9 @@ def validate_rpc_request(body):
params = json_data.get("params", [])
jrpc_id = json_data.get("id", 0)

if not isinstance(method, str) or not isinstance(params, list):
if not isinstance(method, str) and not isinstance(params, list):
raise ServerError(message=f"`method` must be a valid SW1 method string. Params must be a dictionary. {json_data}", code=-32600, name="Invalid Request")
print(type(method), type(params), type(jrpc_id))
return method, params, jrpc_id


@@ -83,6 +84,16 @@ def handle_rpc_request(
method_name = action.__name__
try:
params = params[0]
if not isinstance(params, dict):
return JSONRPCResponse(
id=jrpc_id,
error=ErrorResponse(
message=f"Invalid params for ServiceWizard.{method_name}",
code=-32602,
name="Invalid params",
error=f"Params must be a dictionary. Got {type(params)}",
),
)
except IndexError:
return no_params_passed(method=method_name, jrpc_id=jrpc_id)

@@ -93,6 +104,7 @@

try:
result = action(request, module_name, module_version)
print("ABOUT TO RETURN RESULT", result)
return JSONRPCResponse(id=jrpc_id, result=[result])
except ServerError as e:
traceback_str = traceback.format_exc()
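The new guard means the first entry of params must be a dict of keyword arguments; anything else is answered with the -32602 "Invalid params" error shown above. Illustrative request bodies (the method and field names are made up for the example):

# Accepted: params is a one-element list holding a dict.
ok_body = {"method": "ServiceWizard.status", "params": [{"module_name": "MyModule"}], "id": 1}

# Rejected with code -32602: the first params entry is a list, not a dict.
bad_body = {"method": "ServiceWizard.status", "params": [["MyModule"]], "id": 2}
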
2 changes: 1 addition & 1 deletion src/rpc/error_responses.py
@@ -68,4 +68,4 @@ def token_validation_failed(jrpc_id):


def json_rpc_response_to_exception(content: JSONRPCResponse, status_code=500):
return JSONResponse(content=content.dict(), status_code=status_code)
return JSONResponse(content=content.model_dump(), status_code=status_code)
31 changes: 18 additions & 13 deletions src/rpc/models.py
@@ -1,4 +1,4 @@
from typing import Optional, Any
from typing import Any, Optional, Union

from pydantic import BaseModel

@@ -10,22 +10,27 @@ class ErrorResponse(BaseModel):
error: str = None




class JSONRPCResponse(BaseModel):
version: str = "1.0"
id: Optional[int | str]
error: Optional[ErrorResponse]
id: Optional[Union[int, str]] = 0
error: Optional[ErrorResponse] = None
result: Any = None

def dict(self, *args, **kwargs):
response_dict = super().dict(*args, **kwargs)
if self.result is None:
response_dict.pop("result", None)
def model_dump(self, *args, **kwargs) -> dict[str, Any]:
# Default behavior for the serialization
serialized_data = super().model_dump(*args, **kwargs)

# Custom logic to exclude fields based on their values
if serialized_data.get("result") is None:
serialized_data.pop("result", None)

if self.error is None:
response_dict.pop("error", None)
response_dict.pop("version", None)
if serialized_data.get("error") is None:
serialized_data.pop("error", None)
serialized_data.pop("version", None)

if self.id is None:
response_dict.pop("id", None)
if serialized_data.get("id") is None:
serialized_data.pop("id", None)

return response_dict
return serialized_data
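
A quick sketch (not part of the commit) of how the overridden model_dump trims empty fields; the values are arbitrary and pydantic v2 is assumed:

from src.rpc.models import ErrorResponse, JSONRPCResponse

# Success path: 'error' is None, so both 'error' and 'version' are dropped.
ok = JSONRPCResponse(id=1, result=[{"status": "active"}])
print(ok.model_dump())   # {'id': 1, 'result': [{'status': 'active'}]}

# Error path: 'result' is None and gets dropped; 'version' and 'error' survive.
err = JSONRPCResponse(id=2, error=ErrorResponse(message="boom", code=-32000, name="Server error"))
print(err.model_dump())  # {'version': '1.0', 'id': 2, 'error': {...}}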
10 changes: 4 additions & 6 deletions src/rpc/unauthenticated_handlers.py
@@ -18,13 +18,11 @@ def start(request: Request, params: list[dict], jrpc_id: str) -> JSONRPCResponse
return handle_rpc_request(request, params, jrpc_id, start_deployment)


def status(request: Request, params: list[dict], jrpc_id: str) -> JSONRPCResponse:
if not params:
params = [{}]
def status(request: Request, params: list[dict], jrpc_id: str) -> JSONRPCResponse: # noqa F811
params = [{}]
return handle_rpc_request(request, params, jrpc_id, get_status)


def version(request: Request, params: list[dict], jrpc_id: str) -> JSONRPCResponse:
if not params:
params = [{}]
def version(request: Request, params: list[dict], jrpc_id: str) -> JSONRPCResponse: # noqa F811
params = [{}]
return handle_rpc_request(request, params, jrpc_id, get_version)
49 changes: 10 additions & 39 deletions test/conftest.py
@@ -1,50 +1,21 @@
import os
from glob import glob

import pytest
from dotenv import load_dotenv


def _as_module(fixture_path: str) -> str:
return fixture_path.replace("/", ".").replace("\\", ".").replace(".py", "")

def pytest_collectreport(report):
print("CONFTEST loaded")

@pytest.fixture(autouse=True)
def load_environment():
# Ensure that the environment variables are loaded before running the tests
load_dotenv()
load_dotenv(os.environ.get("DOTENV_FILE_LOCATION", ".env"))


@pytest.fixture(autouse=True)
def generate_kubeconfig():
# Generate a kubeconfig file for testing
# Overwrite kubeconfig
os.environ["KUBECONFIG"] = "test_kubeconfig_file"
kubeconfig_path = os.environ["KUBECONFIG"]

kubeconfig_content = """\
apiVersion: v1
kind: Config
current-context: test-context
clusters:
- name: test-cluster
cluster:
server: https://test-api-server
insecure-skip-tls-verify: true
contexts:
- name: test-context
context:
cluster: test-cluster
user: test-user
users:
- name: test-user
user:
exec:
command: echo
apiVersion: client.authentication.k8s.io/v1alpha1
args:
- "access_token"
"""

with open(kubeconfig_path, "w") as kubeconfig_file:
kubeconfig_file.write(kubeconfig_content.strip())

yield

# Clean up the generated kubeconfig file after the tests
os.remove(kubeconfig_path)

pytest_plugins = [_as_module(fixture) for fixture in glob("test/src/fixtures/[!_]*.py")]
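
The new conftest drops the generated kubeconfig fixture and instead auto-loads every fixture module under test/src/fixtures/ as a pytest plugin. A small illustration of the path-to-module conversion (the fixture filename is hypothetical):

def _as_module(fixture_path: str) -> str:
    # Same helper as above: turn a file path into a dotted module path.
    return fixture_path.replace("/", ".").replace("\\", ".").replace(".py", "")


print(_as_module("test/src/fixtures/k8s_fixtures.py"))
# -> test.src.fixtures.k8s_fixtures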
@@ -6,14 +6,14 @@

def get_running_deployment(deployment_name) -> DynamicServiceStatus:
module_info = _create_sample_module_info()
deployment = _create_sample_deployment(deployment_name=deployment_name, ready_replicas=1)
deployment = create_sample_deployment(deployment_name=deployment_name, ready_replicas=1, replicas=1, available_replicas=1, unavailable_replicas=0)
deployment_status = _create_deployment_status(module_info, deployment)
return deployment_status


def get_stopped_deployment(deployment_name) -> DynamicServiceStatus:
module_info = _create_sample_module_info()
deployment = _create_sample_deployment(deployment_name=deployment_name, ready_replicas=0, available_replicas=0)
deployment = create_sample_deployment(deployment_name=deployment_name, ready_replicas=0, available_replicas=0, unavailable_replicas=1, replicas=0)
deployment_status = _create_deployment_status(module_info, deployment)
return deployment_status

@@ -34,7 +34,7 @@ def _create_deployment_status(module_info, deployment) -> DynamicServiceStatus:
)


def _create_sample_deployment(deployment_name, replicas=1, ready_replicas=1, available_replicas=1, unavailable_replicas=0):
def create_sample_deployment(deployment_name, replicas, ready_replicas, available_replicas, unavailable_replicas):
deployment_status = V1DeploymentStatus(
updated_replicas=replicas, ready_replicas=ready_replicas, available_replicas=available_replicas, unavailable_replicas=unavailable_replicas
)
