Skip to content

Commit

Permalink
add pod_template and pod_template_name arguments for ContainerTask (#…
Browse files Browse the repository at this point in the history
…1515)

* add pod_template and pod_template_name arguments for ContainerTask

Signed-off-by: Felix Ruess <[email protected]>

* factor out _serialize_pod_spec into separate util function

Signed-off-by: Felix Ruess <[email protected]>

* model file changes, couple other changes

Signed-off-by: Yee Hing Tong <[email protected]>

* minor cleanup

Signed-off-by: Felix Ruess <[email protected]>

* add unit test for container_task pod_template

Signed-off-by: Felix Ruess <[email protected]>

* bump min version of flyteidl to 1.3.12

for pod template data config support

Signed-off-by: Felix Ruess <[email protected]>

* require flyteidl==1.3.12 in doc-requirements.txt

Signed-off-by: Felix Ruess <[email protected]>

---------

Signed-off-by: Felix Ruess <[email protected]>
Signed-off-by: Yee Hing Tong <[email protected]>
Co-authored-by: Yee Hing Tong <[email protected]>
  • Loading branch information
flixr and wild-endeavor authored Mar 24, 2023
1 parent de28789 commit 36fe581
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 66 deletions.
2 changes: 1 addition & 1 deletion doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ flask==2.2.3
# via mlflow
flatbuffers==23.1.21
# via tensorflow
flyteidl==1.3.7
flyteidl==1.3.12
# via flytekit
fonttools==4.38.0
# via matplotlib
Expand Down
55 changes: 46 additions & 9 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata
from flytekit.core.interface import Interface
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.utils import _get_container_definition
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec
from flytekit.models import task as _task_model
from flytekit.models.security import Secret, SecurityContext

_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"


# TODO: do we need pod_template here? Seems that it is a raw container not running in pods
class ContainerTask(PythonTask):
"""
This is an intermediate class that represents Flyte Tasks that run a container at execution time. This is the vast
Expand Down Expand Up @@ -47,6 +49,8 @@ def __init__(
metadata_format: MetadataFormat = MetadataFormat.JSON,
io_strategy: Optional[IOStrategy] = None,
secret_requests: Optional[List[Secret]] = None,
pod_template: Optional[PodTemplate] = None,
pod_template_name: Optional[str] = None,
**kwargs,
):
sec_ctx = None
Expand All @@ -55,6 +59,11 @@ def __init__(
if not isinstance(s, Secret):
raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}")
sec_ctx = SecurityContext(secrets=secret_requests)

# pod_template_name overwrites the metadata.pod_template_name
metadata = metadata or TaskMetadata()
metadata.pod_template_name = pod_template_name

super().__init__(
task_type="raw-container",
name=name,
Expand All @@ -74,6 +83,7 @@ def __init__(
self._resources = ResourceSpec(
requests=requests if requests else Resources(), limits=limits if limits else Resources()
)
self.pod_template = pod_template

@property
def resources(self) -> ResourceSpec:
Expand All @@ -91,19 +101,29 @@ def execute(self, **kwargs) -> Any:
return None

def get_container(self, settings: SerializationSettings) -> _task_model.Container:
    """
    Return the container definition for this task.

    When a pod template is set, ``None`` is returned here; in that case
    ``get_k8s_pod`` returns the pod template merged with the container.
    """
    if self.pod_template is None:
        return self._get_container(settings)
    return None

def _get_data_loading_config(self) -> _task_model.DataLoadingConfig:
    """
    Build the data loading config from this task's input/output directories,
    metadata format and (if set) IO strategy. Data loading is always enabled.
    """
    io_strategy = None
    if self._io_strategy:
        io_strategy = self._io_strategy.value

    return _task_model.DataLoadingConfig(
        input_path=self._input_data_dir,
        output_path=self._output_data_dir,
        format=self._md_format.value,
        enabled=True,
        io_strategy=io_strategy,
    )

def _get_container(self, settings: SerializationSettings) -> _task_model.Container:
env = settings.env or {}
env = {**env, **self.environment} if self.environment else env
return _get_container_definition(
image=self._image,
command=self._cmd,
args=self._args,
data_loading_config=_task_model.DataLoadingConfig(
input_path=self._input_data_dir,
output_path=self._output_data_dir,
format=self._md_format.value,
enabled=True,
io_strategy=self._io_strategy.value if self._io_strategy else None,
),
data_loading_config=self._get_data_loading_config(),
environment=env,
storage_request=self.resources.requests.storage,
ephemeral_storage_request=self.resources.requests.ephemeral_storage,
Expand All @@ -116,3 +136,20 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe
gpu_limit=self.resources.limits.gpu,
memory_limit=self.resources.limits.mem,
)

def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod:
    """
    Return the pod target for this task, or ``None`` when no pod template is set.

    The pod spec is the pod template merged with this task's container; labels and
    annotations come from the pod template, and the data loading config is attached
    to the pod.
    """
    template = self.pod_template
    if template is None:
        return None

    merged_spec = _serialize_pod_spec(template, self._get_container(settings))
    pod_metadata = _task_model.K8sObjectMetadata(
        labels=template.labels,
        annotations=template.annotations,
    )
    return _task_model.K8sPod(
        pod_spec=merged_spec,
        metadata=pod_metadata,
        data_config=self._get_data_loading_config(),
    )

def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]:
    """
    Return the task config: empty without a pod template, otherwise a single
    entry naming the pod template's primary container.
    """
    template = self.pod_template
    return {} if template is None else {_PRIMARY_CONTAINER_NAME_FIELD: template.primary_container_name}
55 changes: 3 additions & 52 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
import importlib
import re
from abc import ABC
from typing import Any, Callable, Dict, List, Optional, TypeVar, cast

from flyteidl.core import tasks_pb2 as _core_task
from kubernetes.client import ApiClient
from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements
from typing import Callable, Dict, List, Optional, TypeVar

from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
Expand All @@ -16,7 +12,7 @@
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance, extract_task_module
from flytekit.core.utils import _get_container_definition
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models.security import Secret, SecurityContext
Expand All @@ -25,10 +21,6 @@
_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"


def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str:
    """Return the resource's ``ResourceName`` enum name, lower-cased and with ``_`` replaced by ``-``."""
    enum_name = _core_task.Resources.ResourceName.Name(resource.name)
    return enum_name.lower().replace("_", "-")


class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC):
"""
A Python AutoContainer task should be used as the base for all extensions that want the user's code to be in the
Expand Down Expand Up @@ -206,52 +198,11 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain
memory_limit=self.resources.limits.mem,
)

def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]:
    """
    Merge this task's default container into the pod template's pod spec and serialize it.

    The container named after the template's ``primary_container_name`` (a placeholder is
    appended if the pod spec has none) gets its image, command and args overwritten with the
    values of the regular Python task container, its resources overwritten when any are set,
    and its env vars unioned (default container entries first). The pod template's pod spec
    is updated in place.
    """
    template = cast(PodTemplate, self.pod_template)
    primary_name = template.primary_container_name
    pod_containers = template.pod_spec.containers

    if not any(candidate.name == primary_name for candidate in pod_containers):
        # insert a placeholder primary container if it is not defined in the pod spec.
        pod_containers.append(V1Container(name=primary_name))

    merged = []
    for current in pod_containers:
        if current.name == primary_name:
            defaults = self._get_container(settings)
            current.image = defaults.image
            # existing command and args are replaced, not extended
            current.command = defaults.command
            current.args = defaults.args

            limits = {_sanitize_resource_name(entry): entry.value for entry in defaults.resources.limits}
            requests = {_sanitize_resource_name(entry): entry.value for entry in defaults.resources.requests}
            if limits or requests:
                # Important! Only copy over resource requirements if they are non-empty.
                current.resources = V1ResourceRequirements(limits=limits, requests=requests)

            injected_env = [V1EnvVar(name=key, value=val) for key, val in defaults.env.items()]
            current.env = injected_env + (current.env or [])
        merged.append(current)
    template.pod_spec.containers = merged

    return ApiClient().sanitize_for_serialization(template.pod_spec)

def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod:
if self.pod_template is None:
return None
return _task_model.K8sPod(
pod_spec=self._serialize_pod_spec(settings),
pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings)),
metadata=_task_model.K8sObjectMetadata(
labels=self.pod_template.labels,
annotations=self.pod_template.annotations,
Expand Down
53 changes: 52 additions & 1 deletion flytekit/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@
import time as _time
from hashlib import sha224 as _sha224
from pathlib import Path
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional, cast

from flyteidl.core import tasks_pb2 as _core_task
from kubernetes.client import ApiClient
from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements

from flytekit.core.pod_template import PodTemplate
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models import task as task_models


Expand Down Expand Up @@ -125,6 +131,51 @@ def _get_container_definition(
)


def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str:
    """Map a resource entry to its ``ResourceName`` enum name in lower case, using ``-`` instead of ``_``."""
    raw_name = _core_task.Resources.ResourceName.Name(resource.name)
    lowered = raw_name.lower()
    return lowered.replace("_", "-")


def _serialize_pod_spec(pod_template: PodTemplate, primary_container: _task_model.Container) -> Dict[str, Any]:
    """
    Merge ``primary_container`` into the pod spec of ``pod_template`` and return it as a plain dict.

    The container whose name equals ``pod_template.primary_container_name`` gets its image, command
    and args overwritten with the values of ``primary_container``, its resources overwritten when
    ``primary_container`` defines any, and its env vars unioned (entries of ``primary_container``
    first). If the pod spec defines no such container, a placeholder with that name is added.
    All other containers are passed through unchanged.

    The merge works on shallow copies, so ``pod_template`` and its pod spec are left unmodified.
    Serializing the same template more than once therefore gives the same result every time
    instead of stacking the injected env vars on each call.

    :param pod_template: template providing the pod spec and the primary container name.
    :param primary_container: container model whose settings are applied to the primary container.
    :return: the merged pod spec as returned by ``ApiClient().sanitize_for_serialization``.
    """
    # Local import: only needed here, keeps the module's import block untouched.
    import copy

    primary_name = pod_template.primary_container_name

    # Work on a new list so that adding the placeholder does not touch the template's own list.
    containers = list(pod_template.pod_spec.containers or [])
    if not any(container.name == primary_name for container in containers):
        # insert a placeholder primary container if it is not defined in the pod spec.
        containers.append(V1Container(name=primary_name))

    limits = {_sanitize_resource_name(entry): entry.value for entry in primary_container.resources.limits}
    requests = {_sanitize_resource_name(entry): entry.value for entry in primary_container.resources.requests}

    final_containers = []
    for container in containers:
        # In the case of the primary container, we overwrite specific container attributes
        # with the values given to the task.
        # The attributes include: image, command, args, resource, and env (env is unioned)
        if container.name == primary_name:
            # Shallow copy: attributes are reassigned below, never mutated, so the
            # template's container object keeps its original values.
            container = copy.copy(container)
            container.image = primary_container.image
            container.command = primary_container.command
            container.args = primary_container.args

            if limits or requests:
                # Important! Only copy over resource requirements if they are non-empty.
                container.resources = V1ResourceRequirements(limits=dict(limits), requests=dict(requests))
            if primary_container.env is not None:
                container.env = [V1EnvVar(name=key, value=val) for key, val in primary_container.env.items()] + (
                    container.env or []
                )
        final_containers.append(container)

    # Attach the merged containers to a copy of the pod spec rather than to the template's own spec.
    pod_spec = copy.copy(pod_template.pod_spec)
    pod_spec.containers = final_containers

    return ApiClient().sanitize_for_serialization(pod_spec)


def load_proto_from_file(pb2_type, path):
with open(path, "rb") as reader:
out = pb2_type()
Expand Down
16 changes: 15 additions & 1 deletion flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,12 +868,18 @@ def from_flyte_idl(cls, pb2_object: _core_task.K8sObjectMetadata):


class K8sPod(_common.FlyteIdlEntity):
def __init__(self, metadata: K8sObjectMetadata = None, pod_spec: typing.Dict[str, typing.Any] = None):
def __init__(
self,
metadata: K8sObjectMetadata = None,
pod_spec: typing.Dict[str, typing.Any] = None,
data_config: typing.Optional[DataLoadingConfig] = None,
):
"""
This defines a kubernetes pod target. It will build the pod target during task execution
"""
self._metadata = metadata
self._pod_spec = pod_spec
self._data_config = data_config

@property
def metadata(self) -> K8sObjectMetadata:
Expand All @@ -883,17 +889,25 @@ def metadata(self) -> K8sObjectMetadata:
def pod_spec(self) -> typing.Dict[str, typing.Any]:
return self._pod_spec

@property
def data_config(self) -> typing.Optional[DataLoadingConfig]:
    """
    The data loading config given to the constructor, or ``None`` if none was given.
    """
    return self._data_config

def to_flyte_idl(self) -> _core_task.K8sPod:
    """
    Convert this model to its ``flyteidl`` ``K8sPod`` message.

    ``metadata``, ``pod_spec`` and ``data_config`` all default to ``None`` in the constructor;
    each one is only converted when it is set and is otherwise left out of the message.
    """
    # metadata is optional like the other two fields; calling to_flyte_idl() on None would raise.
    metadata = self._metadata.to_flyte_idl() if self._metadata is not None else None
    pod_spec = _json_format.Parse(_json.dumps(self.pod_spec), _struct.Struct()) if self.pod_spec else None
    data_config = self.data_config.to_flyte_idl() if self.data_config else None
    return _core_task.K8sPod(
        metadata=metadata,
        pod_spec=pod_spec,
        data_config=data_config,
    )

@classmethod
def from_flyte_idl(cls, pb2_object: _core_task.K8sPod):
    """
    Build a ``K8sPod`` model from its ``flyteidl`` message.

    ``pod_spec`` and ``data_config`` are only converted when the message has the
    respective field set; otherwise they are ``None``.
    """
    metadata = K8sObjectMetadata.from_flyte_idl(pb2_object.metadata)

    pod_spec = None
    if pb2_object.HasField("pod_spec"):
        pod_spec = _json_format.MessageToDict(pb2_object.pod_spec)

    data_config = None
    if pb2_object.HasField("data_config"):
        data_config = DataLoadingConfig.from_flyte_idl(pb2_object.data_config)

    return cls(metadata=metadata, pod_spec=pod_spec, data_config=data_config)


Expand Down
3 changes: 2 additions & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flytekit.core import constants as _common_constants
from flytekit.core.base_task import PythonTask
from flytekit.core.condition import BranchNode
from flytekit.core.container_task import ContainerTask
from flytekit.core.gate import Gate
from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan
from flytekit.core.map_task import MapPythonTask
Expand Down Expand Up @@ -189,7 +190,7 @@ def get_serializable_task(
# If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect.
# The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because
# the pod spec is a K8s library object, and we shouldn't be messing around with it in this file.
elif pod:
elif pod and not isinstance(entity, ContainerTask):
if isinstance(entity, MapPythonTask):
entity.set_command_prefix(get_command_prefix_for_fast_execute(settings))
pod = entity.get_k8s_pod(settings)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
},
install_requires=[
"googleapis-common-protos>=1.57",
"flyteidl>=1.3.5,<1.4.0",
"flyteidl>=1.3.12,<1.4.0",
"wheel>=0.30.0,<1.0.0",
"pandas>=1.0.0,<2.0.0",
"pyarrow>=4.0.0,<11.0.0",
Expand Down
Loading

0 comments on commit 36fe581

Please sign in to comment.