diff --git a/CHANGELOG.md b/CHANGELOG.md index f0ea866..d493aa4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- A `ResilientWatcher` utility class to reconnect Kubernetes client streams on `ProtocolErrors` - [#107](https://github.com/PrefectHQ/prefect-kubernetes/pull/107) + ### Changed ### Deprecated diff --git a/prefect_kubernetes/events.py b/prefect_kubernetes/events.py index 3f22ed7..18ff4e9 100644 --- a/prefect_kubernetes/events.py +++ b/prefect_kubernetes/events.py @@ -1,10 +1,13 @@ import atexit +import logging import threading from typing import TYPE_CHECKING, Dict, List, Optional from prefect.events import Event, RelatedResource, emit_event from prefect.utilities.importtools import lazy_import +from prefect_kubernetes.utilities import ResilientStreamWatcher + if TYPE_CHECKING: import kubernetes import kubernetes.client @@ -38,11 +41,13 @@ def __init__( worker_resource: Dict[str, str], related_resources: List[RelatedResource], timeout_seconds: int, + logger: Optional[logging.Logger] = None, ): self._client = client self._job_name = job_name self._namespace = namespace self._timeout_seconds = timeout_seconds + self._logger = logger # All events emitted by this replicator have the pod itself as the # resource. The `worker_resource` is what the worker uses when it's @@ -52,7 +57,7 @@ def __init__( worker_related_resource = RelatedResource(__root__=worker_resource) self._related_resources = related_resources + [worker_related_resource] - self._watch = kubernetes.watch.Watch() + self._watch = ResilientStreamWatcher(logger=self._logger) self._thread = threading.Thread(target=self._replicate_pod_events) self._state = "READY" @@ -90,7 +95,7 @@ def _replicate_pod_events(self): try: core_client = kubernetes.client.CoreV1Api(api_client=self._client) - for event in self._watch.stream( + for event in self._watch.api_object_stream( func=core_client.list_namespaced_pod, namespace=self._namespace, label_selector=f"job-name={self._job_name}", diff --git a/prefect_kubernetes/pods.py b/prefect_kubernetes/pods.py index ab8999b..a33a981 100644 --- a/prefect_kubernetes/pods.py +++ b/prefect_kubernetes/pods.py @@ -2,11 +2,11 @@ from typing import Any, Callable, Dict, Optional, Union from kubernetes.client.models import V1DeleteOptions, V1Pod, V1PodList -from kubernetes.watch import Watch from prefect import task from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect_kubernetes.credentials import KubernetesCredentials +from prefect_kubernetes.utilities import ResilientStreamWatcher @task @@ -45,7 +45,6 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( core_v1_client.create_namespaced_pod, namespace=namespace, @@ -93,7 +92,6 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( core_v1_client.delete_namespaced_pod, pod_name, @@ -135,7 +133,6 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( core_v1_client.list_namespaced_pod, namespace=namespace, **kube_kwargs ) @@ -180,7 +177,6 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( core_v1_client.patch_namespaced_pod, name=pod_name, @@ -224,7 +220,6 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( core_v1_client.read_namespaced_pod, name=pod_name, @@ -281,11 +276,11 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: - if print_func is not None: # should no longer need to manually refresh on ApiException.status == 410 # as of https://github.com/kubernetes-client/python-base/pull/133 - for log_line in Watch().stream( + watcher = ResilientStreamWatcher() + for log_line in watcher.stream( core_v1_client.read_namespaced_pod_log, name=pod_name, namespace=namespace, @@ -341,7 +336,6 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( core_v1_client.replace_namespaced_pod, body=new_pod, diff --git a/prefect_kubernetes/utilities.py b/prefect_kubernetes/utilities.py index 18ee4d6..d5c0381 100644 --- a/prefect_kubernetes/utilities.py +++ b/prefect_kubernetes/utilities.py @@ -1,7 +1,11 @@ """ Utilities for working with the Python Kubernetes API. """ +import logging +import time from pathlib import Path -from typing import Optional, TypeVar, Union +from typing import Callable, List, Optional, Set, Type, TypeVar, Union +import urllib3 +from kubernetes import watch from kubernetes.client import models as k8s_models from prefect.infrastructure.kubernetes import KubernetesJob, KubernetesManifest from slugify import slugify @@ -13,6 +17,24 @@ V1KubernetesModel = TypeVar("V1KubernetesModel") +class _CappedSet(set): + """ + A set with a bounded size. + """ + + def __init__(self, maxsize): + super().__init__() + self.maxsize = maxsize + + def add(self, value): + """ + Add to the set and maintain its max size. + """ + if len(self) >= self.maxsize: + self.pop() + super().add(value) + + def convert_manifest_to_model( manifest: Union[Path, str, KubernetesManifest], v1_model_name: str ) -> V1KubernetesModel: @@ -180,3 +202,123 @@ def _slugify_label_value(value: str, max_length: int = 63) -> str: # Kubernetes to throw the validation error return slug + + +class ResilientStreamWatcher: + """ + A wrapper class around kuberenetes.watch.Watch that will reconnect on + certain exceptions. + """ + + DEFAULT_RECONNECT_EXCEPTIONS = (urllib3.exceptions.ProtocolError,) + + def __init__( + self, + logger: Optional[logging.Logger] = None, + max_cache_size: int = 50000, + reconnect_exceptions: Optional[List[Type[Exception]]] = None, + ) -> None: + """ + A utility class for managing streams of Kuberenetes API objects and logs + + Attributes: + logger: A logger which will be used interally to log errors + max_cache_size: The maximum number of API objects to track in an + internal cache to help deduplicate results on stream reconnects + reconnect_exceptions: A list of exceptions that will cause the stream + to reconnect. + """ + + self.max_cache_size = max_cache_size + self.logger = logger + self.watch = watch.Watch() + + reconnect_exceptions = ( + reconnect_exceptions + if reconnect_exceptions is not None + else self.DEFAULT_RECONNECT_EXCEPTIONS + ) + self.reconnect_exceptions = tuple(reconnect_exceptions) + + def stream(self, func: Callable, *args, cache: Optional[Set] = None, **kwargs): + """ + A method for streaming API objects or logs from a Kubernetes + client function. This method will reconnect the stream on certain + configurable exceptions and deduplicate results on reconnects if + streaming API objects and a cache is provided. + + Note that client functions that produce a stream will + restart a stream from the beginning of the log's history on reconnect. + If a cache is not provided, it is possible for duplicate entries to be yielded. + + Args: + func: A Kubernetes client function to call which produces a stream + of logs + *args: Positional arguments to pass to `func` + cache: A keyward argument that provides a way to deduplicate + results on reconnects and bound + **kwargs: Keyword arguments to pass to `func` + + Returns: + An iterator of log + """ + keep_streaming = True + while keep_streaming: + try: + for event in self.watch.stream(func, *args, **kwargs): + # check that we want to and can track this object + if ( + cache is not None + and isinstance(event, dict) + and "object" in event + ): + uid = event["object"].metadata.uid + if uid not in cache: + cache.add(uid) + yield event + else: + yield event + else: + # Case: we've finished iterating + keep_streaming = False + except self.reconnect_exceptions: + # Case: We've hit an exception we're willing to retry on + if self.logger: + self.logger.error("Unable to connect, retrying...", exc_info=True) + time.sleep(1) + except Exception: + # Case: We hit an exception we're unwilling to retry on + if self.logger: + self.logger.exception( + f"Unexpected error while streaming {func.__name__}" + ) + keep_streaming = False + self.stop() + raise + + self.stop() + + def api_object_stream(self, func: Callable, *args, **kwargs): + """ + Create a cache to maintain a record of API objects that have been + seen. This is useful because `stream` will reconnect a stream on + `self.reconnect_exceptions` and on reconnect it will restart streaming all + objects. This cache prevents the same object from being yielded twice. + + Args: + func: A Kubernetes client function to call which produces a stream of API o + bjects + *args: Positional arguments to pass to `func` + **kwargs: Keyword arguments to pass to `func` + + Returns: + An iterator of API objects + """ + cache = _CappedSet(self.max_cache_size) + yield from self.stream(func, *args, cache=cache, **kwargs) + + def stop(self): + """ + Shut down the internal Watch object. + """ + self.watch.stop() diff --git a/prefect_kubernetes/worker.py b/prefect_kubernetes/worker.py index a3b0def..ec37d77 100644 --- a/prefect_kubernetes/worker.py +++ b/prefect_kubernetes/worker.py @@ -144,6 +144,7 @@ from prefect_kubernetes.events import KubernetesEventsReplicator from prefect_kubernetes.utilities import ( + ResilientStreamWatcher, _slugify_label_key, _slugify_label_value, _slugify_name, @@ -575,7 +576,6 @@ async def run( task_status.started(pid) # Monitor the job until completion - events_replicator = KubernetesEventsReplicator( client=client, job_name=job.metadata.name, @@ -585,6 +585,7 @@ async def run( configuration=configuration ), timeout_seconds=configuration.pod_watch_timeout_seconds, + logger=logger, ) with events_replicator: @@ -910,15 +911,16 @@ def _watch_job( if configuration.stream_output: with self._get_core_client(client) as core_client: - logs = core_client.read_namespaced_pod_log( - pod.metadata.name, - configuration.namespace, - follow=True, - _preload_content=False, - container="prefect-job", - ) + watch = ResilientStreamWatcher(logger=logger) try: - for log in logs.stream(): + for log in watch.stream( + core_client.read_namespaced_pod_log, + pod.metadata.name, + configuration.namespace, + follow=True, + _preload_content=False, + container="prefect-job", + ): print(log.decode().rstrip()) # Check if we have passed the deadline and should stop streaming @@ -928,7 +930,6 @@ def _watch_job( ) if deadline and remaining_time <= 0: break - except Exception: logger.warning( ( @@ -957,7 +958,7 @@ def _watch_job( ) return -1 - watch = kubernetes.watch.Watch() + watch = ResilientStreamWatcher(logger=logger) # The kubernetes library will disable retries if the timeout kwarg is # present regardless of the value so we do not pass it unless given # https://github.com/kubernetes-client/python/blob/84f5fea2a3e4b161917aa597bf5e5a1d95e24f5a/kubernetes/base/watch/watch.py#LL160 @@ -965,7 +966,7 @@ def _watch_job( {"timeout_seconds": remaining_time} if deadline else {} ) - for event in watch.stream( + for event in watch.api_object_stream( func=batch_client.list_namespaced_job, field_selector=f"metadata.name={job_name}", namespace=configuration.namespace, @@ -1065,12 +1066,12 @@ def _get_job_pod( """Get the first running pod for a job.""" from kubernetes.client.models import V1Pod - watch = kubernetes.watch.Watch() + watch = ResilientStreamWatcher(logger=logger) logger.debug(f"Job {job_name!r}: Starting watch for pod start...") last_phase = None last_pod_name: Optional[str] = None with self._get_core_client(client) as core_client: - for event in watch.stream( + for event in watch.api_object_stream( func=core_client.list_namespaced_pod, namespace=configuration.namespace, label_selector=f"job-name={job_name}", diff --git a/tests/conftest.py b/tests/conftest.py index f86a269..72bb7ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,14 @@ import pytest import yaml -from kubernetes.client import AppsV1Api, BatchV1Api, CoreV1Api, CustomObjectsApi, models -from kubernetes.client.exceptions import ApiException +from kubernetes.client import ( + ApiException, + AppsV1Api, + BatchV1Api, + CoreV1Api, + CustomObjectsApi, + models, +) from prefect.blocks.kubernetes import KubernetesClusterConfig from prefect.settings import PREFECT_LOGGING_TO_API_ENABLED, temporary_settings from prefect.testing.utilities import prefect_test_harness @@ -180,7 +186,7 @@ def mock_delete_namespaced_job(monkeypatch): @pytest.fixture def mock_stream_timeout(monkeypatch): monkeypatch.setattr( - "kubernetes.watch.Watch.stream", + "prefect_kubernetes.utilities.watch.Watch.stream", MagicMock(side_effect=ApiException(status=408)), ) diff --git a/tests/test_events_replicator.py b/tests/test_events_replicator.py index 96ea3d2..b34fccd 100644 --- a/tests/test_events_replicator.py +++ b/tests/test_events_replicator.py @@ -9,6 +9,7 @@ from prefect.utilities.importtools import lazy_import from prefect_kubernetes.events import EVICTED_REASONS, KubernetesEventsReplicator +from prefect_kubernetes.utilities import ResilientStreamWatcher kubernetes = lazy_import("kubernetes") @@ -169,8 +170,8 @@ def test_lifecycle(replicator): def test_replicate_successful_pod_events(replicator, successful_pod_stream): - mock_watch = MagicMock(spec=kubernetes.watch.Watch) - mock_watch.stream.return_value = successful_pod_stream + mock_watch = MagicMock(spec=ResilientStreamWatcher) + mock_watch.api_object_stream.return_value = successful_pod_stream event_count = 0 @@ -257,12 +258,12 @@ def event(*args, **kwargs): ), ] ) - mock_watch.stop.assert_called_once_with() + # mock_watch.stop.assert_called_once_with() def test_replicate_failed_pod_events(replicator, failed_pod_stream): - mock_watch = MagicMock(spec=kubernetes.watch.Watch) - mock_watch.stream.return_value = failed_pod_stream + mock_watch = MagicMock(spec=ResilientStreamWatcher) + mock_watch.api_object_stream.return_value = failed_pod_stream event_count = 0 @@ -353,8 +354,8 @@ def event(*args, **kwargs): def test_replicate_evicted_pod_events(replicator, evicted_pod_stream): - mock_watch = MagicMock(spec=kubernetes.watch.Watch) - mock_watch.stream.return_value = evicted_pod_stream + mock_watch = MagicMock(spec=ResilientStreamWatcher) + mock_watch.api_object_stream.return_value = evicted_pod_stream event_count = 0 diff --git a/tests/test_flows.py b/tests/test_flows.py index 0b5b84d..e0f1d3f 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -37,7 +37,6 @@ async def test_run_namespaced_job_successful( mock_list_namespaced_pod, read_pod_logs, ): - await run_namespaced_job(kubernetes_job=valid_kubernetes_job_block) assert mock_create_namespaced_job.call_count == 1 @@ -84,7 +83,6 @@ async def test_run_namespaced_job_unsuccessful( mock_list_namespaced_pod, read_pod_logs, ): - successful_job_status.status.failed = 1 successful_job_status.status.succeeded = None mock_read_namespaced_job_status.return_value = successful_job_status diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 4bc69cd..ed09ab2 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -1,8 +1,18 @@ +import logging +import uuid +from typing import Type +from unittest import mock + +import kubernetes import pytest +import urllib3 from kubernetes.client import models as k8s_models from prefect.infrastructure.kubernetes import KubernetesJob -from prefect_kubernetes.utilities import convert_manifest_to_model +from prefect_kubernetes.utilities import ( + ResilientStreamWatcher, + convert_manifest_to_model, +) base_path = "tests/sample_k8s_resources" @@ -200,3 +210,110 @@ def test_bad_model_type_raises(v1_model_name): match="`v1_model` must be the name of a valid Kubernetes client model.", ): convert_manifest_to_model(sample_deployment_manifest, v1_model_name) + + +def test_resilient_streaming_retries_on_configured_errors(caplog): + watcher = ResilientStreamWatcher(logger=logging.getLogger("test")) + + with mock.patch.object( + watcher.watch, + "stream", + side_effect=[ + watcher.reconnect_exceptions[0], + watcher.reconnect_exceptions[0], + ["random_success"], + ], + ) as mocked_stream: + for log in watcher.api_object_stream(str): + assert log == "random_success" + + assert mocked_stream.call_count == 3 + assert "Unable to connect, retrying..." in caplog.text + + +@pytest.mark.parametrize( + "exc", [Exception, TypeError, ValueError, urllib3.exceptions.ProtocolError] +) +def test_resilient_streaming_raises_on_unconfigured_errors( + exc: Type[Exception], caplog +): + watcher = ResilientStreamWatcher( + logger=logging.getLogger("test"), reconnect_exceptions=[] + ) + + with mock.patch.object(watcher.watch, "stream", side_effect=[exc]) as mocked_stream: + with pytest.raises(exc): + for _ in watcher.api_object_stream(str): + pass + + assert mocked_stream.call_count == 1 + assert "Unexpected error" in caplog.text + assert exc.__name__ in caplog.text + + +def _create_api_objects_mocks(n: int = 3): + objects = [] + for _ in range(n): + o = mock.MagicMock(spec=kubernetes.client.V1Pod) + o.metadata = mock.PropertyMock() + o.metadata.uid = uuid.uuid4() + objects.append(o) + return objects + + +def test_resilient_streaming_deduplicates_api_objects_on_reconnects(): + watcher = ResilientStreamWatcher(logger=logging.getLogger("test")) + + object_pool = _create_api_objects_mocks() + thrown_exceptions = 0 + + def my_stream(*args, **kwargs): + """ + Simulate a stream that throws exceptions after yielding the first + object before yielding the rest of the objects. + """ + for o in object_pool: + yield {"object": o} + + nonlocal thrown_exceptions + if thrown_exceptions < 3: + thrown_exceptions += 1 + raise watcher.reconnect_exceptions[0] + + watcher.watch.stream = my_stream + results = [obj for obj in watcher.api_object_stream(str)] + + assert len(object_pool) == len(results) + + +def test_resilient_streaming_pulls_all_logs_on_reconnects(): + watcher = ResilientStreamWatcher(logger=logging.getLogger("test")) + + logs = ["log1", "log2", "log3", "log4"] + thrown_exceptions = 0 + + def my_stream(*args, **kwargs): + """ + Simulate a stream that throws exceptions after yielding the first + object before yielding the rest of the objects. + """ + for log in logs: + yield log + + nonlocal thrown_exceptions + if thrown_exceptions < 3: + thrown_exceptions += 1 + raise watcher.reconnect_exceptions[0] + + watcher.watch.stream = my_stream + results = [obj for obj in watcher.stream(str)] + + assert results == [ + "log1", # Before first exception + "log1", # Before second exception + "log1", # Before third exception + "log1", # No more exceptions from here onward + "log2", + "log3", + "log4", + ] diff --git a/tests/test_worker.py b/tests/test_worker.py index 708e6a6..dbf5dd6 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -67,7 +67,11 @@ def mock_watch(monkeypatch): mock = MagicMock() - monkeypatch.setattr("kubernetes.watch.Watch", MagicMock(return_value=mock)) + monkeypatch.setattr( + "prefect_kubernetes.worker.ResilientStreamWatcher", + MagicMock(return_value=mock), + raising=True, + ) return mock @@ -201,7 +205,7 @@ def enable_store_api_key_in_secret(monkeypatch): stream_output=True, ), lambda flow_run, deployment, flow: KubernetesWorkerJobConfiguration( - command="python -m prefect.engine", + command="prefect flow-run execute", env={ **get_current_settings().to_environment_variables(exclude_unset=True), "PREFECT__FLOW_RUN_ID": str(flow_run.id), @@ -259,7 +263,11 @@ def enable_store_api_key_in_secret(monkeypatch): }, ], "image": get_prefect_image_name(), - "args": ["python", "-m", "prefect.engine"], + "args": [ + "prefect", + "flow-run", + "execute", + ], } ], } @@ -478,7 +486,7 @@ def enable_store_api_key_in_secret(monkeypatch): stream_output=True, ), lambda flow_run, deployment, flow: KubernetesWorkerJobConfiguration( - command="python -m prefect.engine", + command="prefect flow-run execute", env={ **get_current_settings().to_environment_variables(exclude_unset=True), "PREFECT__FLOW_RUN_ID": str(flow_run.id), @@ -545,7 +553,11 @@ def enable_store_api_key_in_secret(monkeypatch): }, ], "image": get_prefect_image_name(), - "args": ["python", "-m", "prefect.engine"], + "args": [ + "prefect", + "flow-run", + "execute", + ], } ], } @@ -1211,7 +1223,7 @@ async def test_user_can_supply_a_sidecar_container_and_volume(self, flow_run): # the prefect-job container is still populated assert pod["containers"][0]["name"] == "prefect-job" - assert pod["containers"][0]["args"] == ["python", "-m", "prefect.engine"] + assert pod["containers"][0]["args"] == ["prefect", "flow-run", "execute"] assert pod["containers"][1] == { "name": "my-sidecar", @@ -1239,7 +1251,9 @@ async def test_creates_job_by_building_a_manifest( mock_core_client, mock_watch, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream.return_value = ( + _mock_pods_stream_that_returns_running_pod() + ) default_configuration.prepare_for_flow_run(flow_run) expected_manifest = default_configuration.job_manifest @@ -1359,7 +1373,7 @@ async def test_job_name_creates_valid_name( job_name, clean_name, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod default_configuration.name = job_name default_configuration.prepare_for_flow_run(flow_run) async with KubernetesWorker(work_pool_name="test") as k8s_worker: @@ -1377,7 +1391,7 @@ async def test_uses_image_variable( mock_watch, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"image": "foo"} @@ -1398,7 +1412,7 @@ async def test_can_store_api_key_in_secret( mock_batch_client, enable_store_api_key_in_secret, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod mock_core_client.read_namespaced_secret.side_effect = ApiException(status=404) configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( @@ -1452,7 +1466,7 @@ async def test_store_api_key_in_existing_secret( mock_batch_client, enable_store_api_key_in_secret, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"image": "foo"} @@ -1654,7 +1668,7 @@ async def test_allows_image_setting_from_manifest( mock_watch, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod default_configuration.job_manifest["spec"]["template"]["spec"]["containers"][0][ "image" @@ -1676,7 +1690,7 @@ async def test_uses_labels_setting( mock_watch, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), @@ -1793,7 +1807,7 @@ async def test_sanitizes_user_label_keys( given, expected, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"labels": {given: "foo"}}, @@ -1841,7 +1855,7 @@ async def test_sanitizes_user_label_values( given, expected, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), @@ -1864,7 +1878,7 @@ async def test_uses_namespace_setting( mock_watch, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"namespace": "foo"}, @@ -1886,7 +1900,7 @@ async def test_allows_namespace_setting_from_manifest( mock_watch, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod default_configuration.job_manifest["metadata"]["namespace"] = "test" default_configuration.prepare_for_flow_run(flow_run) @@ -1906,7 +1920,7 @@ async def test_uses_service_account_name_setting( mock_watch, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"service_account_name": "foo"}, @@ -1927,7 +1941,7 @@ async def test_uses_finished_job_ttl_setting( mock_watch, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"finished_job_ttl": 123}, @@ -1948,7 +1962,7 @@ async def test_uses_specified_image_pull_policy( mock_watch, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"image_pull_policy": "IfNotPresent"}, @@ -1970,7 +1984,7 @@ async def test_defaults_to_incluster_config( mock_cluster_config, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) @@ -1987,7 +2001,7 @@ async def test_uses_cluster_config_if_not_in_cluster( mock_cluster_config, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod mock_cluster_config.load_incluster_config.side_effect = ConfigException() async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) @@ -2004,7 +2018,9 @@ async def test_allows_configurable_timeouts_for_pod_and_job_watches( default_configuration: KubernetesWorkerJobConfiguration, flow_run, ): - mock_watch.stream = Mock(side_effect=_mock_pods_stream_that_returns_running_pod) + mock_watch.api_object_stream = Mock( + side_effect=_mock_pods_stream_that_returns_running_pod + ) # The job should not be completed to start mock_batch_client.read_namespaced_job.return_value.status.completion_time = None @@ -2031,7 +2047,7 @@ async def test_allows_configurable_timeouts_for_pod_and_job_watches( async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) - mock_watch.stream.assert_has_calls( + mock_watch.api_object_stream.assert_has_calls( [ mock.call( func=mock_core_client.list_namespaced_pod, @@ -2053,7 +2069,7 @@ async def test_excludes_timeout_from_job_watches_when_null( mock_batch_client, job_timeout, ): - mock_watch.stream = mock.Mock( + mock_watch.api_object_stream = mock.Mock( side_effect=_mock_pods_stream_that_returns_running_pod ) # The job should not be completed to start @@ -2064,7 +2080,7 @@ async def test_excludes_timeout_from_job_watches_when_null( async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) - mock_watch.stream.assert_has_calls( + mock_watch.api_object_stream.assert_has_calls( [ mock.call( func=mock_core_client.list_namespaced_pod, @@ -2089,7 +2105,7 @@ async def test_watches_the_right_namespace( mock_watch, mock_batch_client, ): - mock_watch.stream = mock.Mock( + mock_watch.api_object_stream = mock.Mock( side_effect=_mock_pods_stream_that_returns_running_pod ) # The job should not be completed to start @@ -2100,7 +2116,7 @@ async def test_watches_the_right_namespace( async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) - mock_watch.stream.assert_has_calls( + mock_watch.api_object_stream.assert_has_calls( [ mock.call( func=mock_core_client.list_namespaced_pod, @@ -2125,14 +2141,11 @@ async def test_streaming_pod_logs_timeout_warns( mock_batch_client, caplog, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod # The job should not be completed to start mock_batch_client.read_namespaced_job.return_value.status.completion_time = None - mock_logs = MagicMock() - mock_logs.stream = MagicMock(side_effect=RuntimeError("something went wrong")) - - mock_core_client.read_namespaced_pod_log = MagicMock(return_value=mock_logs) + mock_watch.stream = MagicMock(side_effect=RuntimeError("something went wrong")) async with KubernetesWorker(work_pool_name="test") as k8s_worker: with caplog.at_level("WARNING"): @@ -2165,7 +2178,7 @@ def mock_stream(*args, **kwargs): sleep(0.5) yield {"object": job} - mock_watch.stream.side_effect = mock_stream + mock_watch.api_object_stream.side_effect = mock_stream default_configuration.pod_watch_timeout_seconds = 42 default_configuration.job_watch_timeout_seconds = 0 @@ -2206,7 +2219,7 @@ def mock_log_stream(*args, **kwargs): return MagicMock() mock_core_client.read_namespaced_pod_log.side_effect = mock_log_stream - mock_watch.stream.side_effect = mock_stream + mock_watch.api_object_stream.side_effect = mock_stream default_configuration.job_watch_timeout_seconds = 1000 async with KubernetesWorker(work_pool_name="test") as k8s_worker: @@ -2214,7 +2227,7 @@ def mock_log_stream(*args, **kwargs): assert result.status_code == 1 - mock_watch.stream.assert_has_calls( + mock_watch.api_object_stream.assert_has_calls( [ mock.call( func=mock_core_client.list_namespaced_pod, @@ -2256,7 +2269,7 @@ def mock_stream(*args, **kwargs): # Yield the job then return exiting the stream # After restarting the watch a few times, we'll report completion job.status.completion_time = ( - None if mock_watch.stream.call_count < 3 else True + None if mock_watch.api_object_stream.call_count < 3 else True ) yield {"object": job} @@ -2265,8 +2278,8 @@ def mock_log_stream(*args, **kwargs): sleep(0.25) yield f"test {i}".encode() - mock_core_client.read_namespaced_pod_log.return_value.stream = mock_log_stream - mock_watch.stream.side_effect = mock_stream + mock_watch.stream = mock_log_stream + mock_watch.api_object_stream.side_effect = mock_stream default_configuration.job_watch_timeout_seconds = 1 @@ -2276,7 +2289,7 @@ def mock_log_stream(*args, **kwargs): # The job should timeout assert result.status_code == -1 - mock_watch.stream.assert_has_calls( + mock_watch.api_object_stream.assert_has_calls( [ mock.call( func=mock_core_client.list_namespaced_pod, @@ -2320,8 +2333,8 @@ def mock_log_stream(*args, **kwargs): sleep(0.25) yield f"test {i}".encode() - mock_core_client.read_namespaced_pod_log.return_value.stream = mock_log_stream - mock_watch.stream.side_effect = mock_stream + mock_watch.stream = mock_log_stream + mock_watch.api_object_stream.side_effect = mock_stream default_configuration.job_watch_timeout_seconds = 1 async with KubernetesWorker(work_pool_name="test") as k8s_worker: @@ -2330,7 +2343,7 @@ def mock_log_stream(*args, **kwargs): # The job should not timeout assert result.status_code == 1 - mock_watch.stream.assert_has_calls( + mock_watch.api_object_stream.assert_has_calls( [ mock.call( func=mock_core_client.list_namespaced_pod, @@ -2382,14 +2395,14 @@ def mock_stream(*args, **kwargs): job.spec.backoff_limit = 6 yield {"object": job, "type": "ADDED"} - mock_watch.stream.side_effect = mock_stream + mock_watch.api_object_stream.side_effect = mock_stream default_configuration.job_watch_timeout_seconds = 40 async with KubernetesWorker(work_pool_name="test") as k8s_worker: result = await k8s_worker.run(flow_run, default_configuration) assert result.status_code == -1 - mock_watch.stream.assert_has_calls( + mock_watch.api_object_stream.assert_has_calls( [ mock.call( func=mock_core_client.list_namespaced_pod, @@ -2457,7 +2470,7 @@ def mock_stream(*args, **kwargs): job.status.failed = i yield {"object": job, "type": "ADDED"} - mock_watch.stream.side_effect = mock_stream + mock_watch.api_object_stream.side_effect = mock_stream async with KubernetesWorker(work_pool_name="test") as k8s_worker: result = await k8s_worker.run(flow_run, default_configuration) @@ -2492,7 +2505,7 @@ def mock_stream(*args, **kwargs): job.status.failed = i yield {"object": job, "type": "ADDED"} - mock_watch.stream.side_effect = mock_stream + mock_watch.api_object_stream.side_effect = mock_stream async with KubernetesWorker(work_pool_name="test") as k8s_worker: result = await k8s_worker.run(flow_run, default_configuration) @@ -2537,7 +2550,7 @@ def mock_stream(*args, **kwargs): job.status.failed = i yield {"object": job, "type": "ADDED"} - mock_watch.stream.side_effect = mock_stream + mock_watch.api_object_stream.side_effect = mock_stream async with KubernetesWorker(work_pool_name="test") as k8s_worker: result = await k8s_worker.run(flow_run, default_configuration)