From fbe18cc1e0ddc5b2a55b411d810574e1ce6caaa4 Mon Sep 17 00:00:00 2001 From: Felix Ruess Date: Wed, 22 Feb 2023 14:50:38 +0100 Subject: [PATCH] add pod_template and pod_template_name arguments for ContainerTask Signed-off-by: Felix Ruess --- flytekit/core/container_task.py | 84 ++++++++++++++++++++++++++++++++- 1 file changed, 82 insertions(+), 2 deletions(-) diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 677142736c..5054e3a362 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,16 +1,26 @@ from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, cast + +from flyteidl.core import tasks_pb2 as _core_task +from kubernetes.client import ApiClient +from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.interface import Interface +from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.utils import _get_container_definition from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext +_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" + + +def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str: + return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") + -# TODO: do we need pod_template here? Seems that it is a raw container not running in pods class ContainerTask(PythonTask): """ This is an intermediate class that represents Flyte Tasks that run a container at execution time. This is the vast @@ -47,6 +57,8 @@ def __init__( metadata_format: MetadataFormat = MetadataFormat.JSON, io_strategy: Optional[IOStrategy] = None, secret_requests: Optional[List[Secret]] = None, + pod_template: Optional[PodTemplate] = None, + pod_template_name: Optional[str] = None, **kwargs, ): sec_ctx = None @@ -55,6 +67,11 @@ def __init__( if not isinstance(s, Secret): raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") sec_ctx = SecurityContext(secrets=secret_requests) + + # pod_template_name overwrites the metadata.pod_template_name + kwargs["metadata"] = kwargs["metadata"] if "metadata" in kwargs else TaskMetadata() + kwargs["metadata"].pod_template_name = pod_template_name + super().__init__( task_type="raw-container", name=name, @@ -74,6 +91,7 @@ def __init__( self._resources = ResourceSpec( requests=requests if requests else Resources(), limits=limits if limits else Resources() ) + self.pod_template = pod_template @property def resources(self) -> ResourceSpec: @@ -91,6 +109,13 @@ def execute(self, **kwargs) -> Any: return None def get_container(self, settings: SerializationSettings) -> _task_model.Container: + # if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container + if self.pod_template is not None: + return None + + return self._get_container(settings) + + def _get_container(self, settings: SerializationSettings) -> _task_model.Container: env = settings.env or {} env = {**env, **self.environment} if self.environment else env return _get_container_definition( @@ -116,3 +141,58 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe gpu_limit=self.resources.limits.gpu, memory_limit=self.resources.limits.mem, ) + + def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]: + containers = cast(PodTemplate, self.pod_template).pod_spec.containers + primary_exists = False + + for container in containers: + if container.name == cast(PodTemplate, self.pod_template).primary_container_name: + primary_exists = True + break + + if not primary_exists: + # insert a placeholder primary container if it is not defined in the pod spec. + containers.append(V1Container(name=cast(PodTemplate, self.pod_template).primary_container_name)) + final_containers = [] + for container in containers: + # In the case of the primary container, we overwrite specific container attributes + # with the default values used in the regular Python task. + # The attributes include: image, command, args, resource, and env (env is unioned) + if container.name == cast(PodTemplate, self.pod_template).primary_container_name: + container.image = self._image + container.command = self._cmd + container.args = self._args + + limits, requests = {}, {} + for resource in self.resources.limits: + limits[_sanitize_resource_name(resource)] = resource.value + for resource in self.resources.requests: + requests[_sanitize_resource_name(resource)] = resource.value + resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) + if len(limits) > 0 or len(requests) > 0: + # Important! Only copy over resource requirements if they are non-empty. + container.resources = resource_requirements + env = settings.env or {} + env = {**env, **self.environment} if self.environment else env + container.env = [V1EnvVar(name=key, value=val) for key, val in env.items()] + (container.env or []) + final_containers.append(container) + cast(PodTemplate, self.pod_template).pod_spec.containers = final_containers + + return ApiClient().sanitize_for_serialization(cast(PodTemplate, self.pod_template).pod_spec) + + def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: + if self.pod_template is None: + return None + return _task_model.K8sPod( + pod_spec=self._serialize_pod_spec(settings), + metadata=_task_model.K8sObjectMetadata( + labels=self.pod_template.labels, + annotations=self.pod_template.annotations, + ), + ) + + def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]: + if self.pod_template is None: + return {} + return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name}