diff --git a/doc-requirements.txt b/doc-requirements.txt index 19f20af9fc..57286bfb07 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -204,7 +204,7 @@ flask==2.2.3 # via mlflow flatbuffers==23.1.21 # via tensorflow -flyteidl==1.3.7 +flyteidl==1.3.12 # via flytekit fonttools==4.38.0 # via matplotlib diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 677142736c..d51f71d837 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -4,13 +4,15 @@ from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.interface import Interface +from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import Resources, ResourceSpec -from flytekit.core.utils import _get_container_definition +from flytekit.core.utils import _get_container_definition, _serialize_pod_spec from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext +_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" + -# TODO: do we need pod_template here? Seems that it is a raw container not running in pods class ContainerTask(PythonTask): """ This is an intermediate class that represents Flyte Tasks that run a container at execution time. This is the vast @@ -47,6 +49,8 @@ def __init__( metadata_format: MetadataFormat = MetadataFormat.JSON, io_strategy: Optional[IOStrategy] = None, secret_requests: Optional[List[Secret]] = None, + pod_template: Optional[PodTemplate] = None, + pod_template_name: Optional[str] = None, **kwargs, ): sec_ctx = None @@ -55,6 +59,11 @@ def __init__( if not isinstance(s, Secret): raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") sec_ctx = SecurityContext(secrets=secret_requests) + + # pod_template_name overwrites the metadata.pod_template_name + metadata = metadata or TaskMetadata() + metadata.pod_template_name = pod_template_name + super().__init__( task_type="raw-container", name=name, @@ -74,6 +83,7 @@ def __init__( self._resources = ResourceSpec( requests=requests if requests else Resources(), limits=limits if limits else Resources() ) + self.pod_template = pod_template @property def resources(self) -> ResourceSpec: @@ -91,19 +101,29 @@ def execute(self, **kwargs) -> Any: return None def get_container(self, settings: SerializationSettings) -> _task_model.Container: + # if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container + if self.pod_template is not None: + return None + + return self._get_container(settings) + + def _get_data_loading_config(self) -> _task_model.DataLoadingConfig: + return _task_model.DataLoadingConfig( + input_path=self._input_data_dir, + output_path=self._output_data_dir, + format=self._md_format.value, + enabled=True, + io_strategy=self._io_strategy.value if self._io_strategy else None, + ) + + def _get_container(self, settings: SerializationSettings) -> _task_model.Container: env = settings.env or {} env = {**env, **self.environment} if self.environment else env return _get_container_definition( image=self._image, command=self._cmd, args=self._args, - data_loading_config=_task_model.DataLoadingConfig( - input_path=self._input_data_dir, - output_path=self._output_data_dir, - format=self._md_format.value, - enabled=True, - io_strategy=self._io_strategy.value if self._io_strategy else None, - ), + data_loading_config=self._get_data_loading_config(), environment=env, storage_request=self.resources.requests.storage, ephemeral_storage_request=self.resources.requests.ephemeral_storage, @@ -116,3 +136,20 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe gpu_limit=self.resources.limits.gpu, memory_limit=self.resources.limits.mem, ) + + def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: + if self.pod_template is None: + return None + return _task_model.K8sPod( + pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings)), + metadata=_task_model.K8sObjectMetadata( + labels=self.pod_template.labels, + annotations=self.pod_template.annotations, + ), + data_config=self._get_data_loading_config(), + ) + + def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]: + if self.pod_template is None: + return {} + return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name} diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 113f94a998..774825f347 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -3,11 +3,7 @@ import importlib import re from abc import ABC -from typing import Any, Callable, Dict, List, Optional, TypeVar, cast - -from flyteidl.core import tasks_pb2 as _core_task -from kubernetes.client import ApiClient -from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements +from typing import Callable, Dict, List, Optional, TypeVar from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin @@ -16,7 +12,7 @@ from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance, extract_task_module -from flytekit.core.utils import _get_container_definition +from flytekit.core.utils import _get_container_definition, _serialize_pod_spec from flytekit.loggers import logger from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext @@ -25,10 +21,6 @@ _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" -def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str: - return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") - - class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC): """ A Python AutoContainer task should be used as the base for all extensions that want the user's code to be in the @@ -206,52 +198,11 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain memory_limit=self.resources.limits.mem, ) - def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]: - containers = cast(PodTemplate, self.pod_template).pod_spec.containers - primary_exists = False - - for container in containers: - if container.name == cast(PodTemplate, self.pod_template).primary_container_name: - primary_exists = True - break - - if not primary_exists: - # insert a placeholder primary container if it is not defined in the pod spec. - containers.append(V1Container(name=cast(PodTemplate, self.pod_template).primary_container_name)) - final_containers = [] - for container in containers: - # In the case of the primary container, we overwrite specific container attributes - # with the default values used in the regular Python task. - # The attributes include: image, command, args, resource, and env (env is unioned) - if container.name == cast(PodTemplate, self.pod_template).primary_container_name: - sdk_default_container = self._get_container(settings) - container.image = sdk_default_container.image - # clear existing commands - container.command = sdk_default_container.command - # also clear existing args - container.args = sdk_default_container.args - limits, requests = {}, {} - for resource in sdk_default_container.resources.limits: - limits[_sanitize_resource_name(resource)] = resource.value - for resource in sdk_default_container.resources.requests: - requests[_sanitize_resource_name(resource)] = resource.value - resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) - if len(limits) > 0 or len(requests) > 0: - # Important! Only copy over resource requirements if they are non-empty. - container.resources = resource_requirements - container.env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()] + ( - container.env or [] - ) - final_containers.append(container) - cast(PodTemplate, self.pod_template).pod_spec.containers = final_containers - - return ApiClient().sanitize_for_serialization(cast(PodTemplate, self.pod_template).pod_spec) - def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: if self.pod_template is None: return None return _task_model.K8sPod( - pod_spec=self._serialize_pod_spec(settings), + pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings)), metadata=_task_model.K8sObjectMetadata( labels=self.pod_template.labels, annotations=self.pod_template.annotations, diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index ee2c841465..24ce4d07d8 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -4,9 +4,15 @@ import time as _time from hashlib import sha224 as _sha224 from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, cast +from flyteidl.core import tasks_pb2 as _core_task +from kubernetes.client import ApiClient +from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements + +from flytekit.core.pod_template import PodTemplate from flytekit.loggers import logger +from flytekit.models import task as _task_model from flytekit.models import task as task_models @@ -125,6 +131,51 @@ def _get_container_definition( ) +def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str: + return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") + + +def _serialize_pod_spec(pod_template: PodTemplate, primary_container: _task_model.Container) -> Dict[str, Any]: + containers = cast(PodTemplate, pod_template).pod_spec.containers + primary_exists = False + + for container in containers: + if container.name == cast(PodTemplate, pod_template).primary_container_name: + primary_exists = True + break + + if not primary_exists: + # insert a placeholder primary container if it is not defined in the pod spec. + containers.append(V1Container(name=cast(PodTemplate, pod_template).primary_container_name)) + final_containers = [] + for container in containers: + # In the case of the primary container, we overwrite specific container attributes + # with the values given to ContainerTask. + # The attributes include: image, command, args, resource, and env (env is unioned) + if container.name == cast(PodTemplate, pod_template).primary_container_name: + container.image = primary_container.image + container.command = primary_container.command + container.args = primary_container.args + + limits, requests = {}, {} + for resource in primary_container.resources.limits: + limits[_sanitize_resource_name(resource)] = resource.value + for resource in primary_container.resources.requests: + requests[_sanitize_resource_name(resource)] = resource.value + resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) + if len(limits) > 0 or len(requests) > 0: + # Important! Only copy over resource requirements if they are non-empty. + container.resources = resource_requirements + if primary_container.env is not None: + container.env = [V1EnvVar(name=key, value=val) for key, val in primary_container.env.items()] + ( + container.env or [] + ) + final_containers.append(container) + cast(PodTemplate, pod_template).pod_spec.containers = final_containers + + return ApiClient().sanitize_for_serialization(cast(PodTemplate, pod_template).pod_spec) + + def load_proto_from_file(pb2_type, path): with open(path, "rb") as reader: out = pb2_type() diff --git a/flytekit/models/task.py b/flytekit/models/task.py index fc79c87a2d..f7f1d710c9 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -868,12 +868,18 @@ def from_flyte_idl(cls, pb2_object: _core_task.K8sObjectMetadata): class K8sPod(_common.FlyteIdlEntity): - def __init__(self, metadata: K8sObjectMetadata = None, pod_spec: typing.Dict[str, typing.Any] = None): + def __init__( + self, + metadata: K8sObjectMetadata = None, + pod_spec: typing.Dict[str, typing.Any] = None, + data_config: typing.Optional[DataLoadingConfig] = None, + ): """ This defines a kubernetes pod target. It will build the pod target during task execution """ self._metadata = metadata self._pod_spec = pod_spec + self._data_config = data_config @property def metadata(self) -> K8sObjectMetadata: @@ -883,10 +889,15 @@ def metadata(self) -> K8sObjectMetadata: def pod_spec(self) -> typing.Dict[str, typing.Any]: return self._pod_spec + @property + def data_config(self) -> typing.Optional[DataLoadingConfig]: + return self._data_config + def to_flyte_idl(self) -> _core_task.K8sPod: return _core_task.K8sPod( metadata=self._metadata.to_flyte_idl(), pod_spec=_json_format.Parse(_json.dumps(self.pod_spec), _struct.Struct()) if self.pod_spec else None, + data_config=self.data_config.to_flyte_idl() if self.data_config else None, ) @classmethod @@ -894,6 +905,9 @@ def from_flyte_idl(cls, pb2_object: _core_task.K8sPod): return cls( metadata=K8sObjectMetadata.from_flyte_idl(pb2_object.metadata), pod_spec=_json_format.MessageToDict(pb2_object.pod_spec) if pb2_object.HasField("pod_spec") else None, + data_config=DataLoadingConfig.from_flyte_idl(pb2_object.data_config) + if pb2_object.HasField("data_config") + else None, ) diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 5ec249fa4b..8b30fc4d36 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -9,6 +9,7 @@ from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask from flytekit.core.condition import BranchNode +from flytekit.core.container_task import ContainerTask from flytekit.core.gate import Gate from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan from flytekit.core.map_task import MapPythonTask @@ -189,7 +190,7 @@ def get_serializable_task( # If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect. # The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because # the pod spec is a K8s library object, and we shouldn't be messing around with it in this file. - elif pod: + elif pod and not isinstance(entity, ContainerTask): if isinstance(entity, MapPythonTask): entity.set_command_prefix(get_command_prefix_for_fast_execute(settings)) pod = entity.get_k8s_pod(settings) diff --git a/setup.py b/setup.py index 18ffb75187..26eca35e99 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ }, install_requires=[ "googleapis-common-protos>=1.57", - "flyteidl>=1.3.5,<1.4.0", + "flyteidl>=1.3.12,<1.4.0", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", diff --git a/tests/flytekit/unit/core/test_container_task.py b/tests/flytekit/unit/core/test_container_task.py new file mode 100644 index 0000000000..599061d403 --- /dev/null +++ b/tests/flytekit/unit/core/test_container_task.py @@ -0,0 +1,80 @@ +from kubernetes.client.models import ( + V1Affinity, + V1NodeAffinity, + V1NodeSelectorRequirement, + V1NodeSelectorTerm, + V1PodSpec, + V1PreferredSchedulingTerm, + V1Toleration, +) + +from flytekit import kwtypes +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.container_task import ContainerTask +from flytekit.core.pod_template import PodTemplate +from flytekit.tools.translator import get_serializable_task + + +def test_pod_template(): + ps = V1PodSpec( + containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")] + ) + ps.runtime_class_name = "nvidia" + nsr = V1NodeSelectorRequirement(key="nvidia.com/gpu.memory", operator="Gt", values=["10000"]) + pref_sched = V1PreferredSchedulingTerm(preference=V1NodeSelectorTerm(match_expressions=[nsr]), weight=1) + ps.affinity = V1Affinity( + node_affinity=V1NodeAffinity(preferred_during_scheduling_ignored_during_execution=[pref_sched]) + ) + pt = PodTemplate(pod_spec=ps, labels={"somelabel": "foobar"}) + + image = "ghcr.io/flyteorg/rawcontainers-shell:v2" + cmd = [ + "./calculate-ellipse-area.sh", + "{{.inputs.a}}", + "{{.inputs.b}}", + "/var/outputs", + ] + ct = ContainerTask( + name="ellipse-area-metadata-shell", + input_data_dir="/var/inputs", + output_data_dir="/var/outputs", + inputs=kwtypes(a=float, b=float), + outputs=kwtypes(area=float, metadata=str), + image=image, + command=cmd, + pod_template=pt, + pod_template_name="my-base-template", + ) + + assert ct.metadata.pod_template_name == "my-base-template" + + default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") + default_image_config = ImageConfig(default_image=default_image) + default_serialization_settings = SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config + ) + + container = ct.get_container(default_serialization_settings) + assert container is None + + k8s_pod = ct.get_k8s_pod(default_serialization_settings) + assert k8s_pod.metadata.labels == {"somelabel": "foobar"} + + primary_container = k8s_pod.pod_spec["containers"][0] + + assert primary_container["image"] == image + assert primary_container["command"] == cmd + + ################# + # Test Serialization + ################# + ts = get_serializable_task(default_serialization_settings, ct) + assert ts.template.metadata.pod_template_name == "my-base-template" + assert ts.template.container is None + assert ts.template.k8s_pod is not None + serialized_pod_spec = ts.template.k8s_pod.pod_spec + assert serialized_pod_spec["affinity"]["nodeAffinity"] is not None + assert serialized_pod_spec["tolerations"] == [ + {"effect": "NoSchedule", "key": "nvidia.com/gpu", "operator": "Exists"} + ] + assert serialized_pod_spec["runtimeClassName"] == "nvidia"