diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 286097c668e..2d7938c67c2 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -24,6 +24,8 @@ from dataclasses import dataclass from typing import Any, Coroutine, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast +from flyteidl.core import tasks_pb2 + from flytekit.configuration import SerializationSettings from flytekit.core.context_manager import ( ExecutionParameters, @@ -344,6 +346,12 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str] """ return None + def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]: + """ + Returns the extended resources to allocate to the task on hosted Flyte. + """ + return None + def local_execution_mode(self) -> ExecutionState.Mode: """ """ return ExecutionState.Mode.LOCAL_TASK_EXECUTION diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 1038c005219..2957abe0dfb 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -4,6 +4,8 @@ import typing from typing import Any, List +from flyteidl.core import tasks_pb2 + from flytekit.core.resources import Resources, convert_resources_to_resource_model from flytekit.core.utils import _dnsify from flytekit.loggers import logger @@ -62,6 +64,7 @@ def __init__( self._aliases: _workflow_model.Alias = None self._outputs = None self._resources: typing.Optional[_resources_model] = None + self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None def runs_before(self, other: Node): """ @@ -172,6 +175,11 @@ def with_overrides(self, *args, **kwargs): assert_not_promise(v, "container_image") self.flyte_entity._container_image = v + if "accelerator" in kwargs: + v = kwargs["accelerator"] + assert_not_promise(v, "accelerator") + self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=v.to_flyte_idl()) + return self diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 5335410a79f..1ad1de0216f 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -5,6 +5,8 @@ from abc import ABC from typing import Callable, Dict, List, Optional, TypeVar, Union +from flyteidl.core import tasks_pb2 + from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin from flytekit.core.context_manager import FlyteContextManager @@ -13,6 +15,7 @@ from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance, extract_task_module from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit +from flytekit.extras.accelerators import BaseAccelerator from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec from flytekit.loggers import logger from flytekit.models import task as _task_model @@ -44,6 +47,7 @@ def __init__( secret_requests: Optional[List[Secret]] = None, pod_template: Optional[PodTemplate] = None, pod_template_name: Optional[str] = None, + accelerator: Optional[BaseAccelerator] = None, **kwargs, ): """ @@ -70,6 +74,7 @@ def __init__( - `AWS Parameter store `__ :param pod_template: Custom PodTemplate for this task. :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. + :param accelerator: The accelerator to use for this task. """ sec_ctx = None if secret_requests: @@ -110,6 +115,7 @@ def __init__( self._get_command_fn = self.get_default_command self.pod_template = pod_template + self.accelerator = accelerator @property def task_resolver(self) -> TaskResolverMixin: @@ -219,6 +225,15 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str] return {} return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name} + def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]: + """ + Returns the extended resources to allocate to the task on hosted Flyte. + """ + if self.accelerator is None: + return None + + return tasks_pb2.ExtendedResources(gpu_accelerator=self.accelerator.to_flyte_idl()) + class DefaultTaskResolver(TrackedInstance, TaskResolverMixin): """ diff --git a/flytekit/core/task.py b/flytekit/core/task.py index ce16e9634d5..547abd41fa1 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -8,6 +8,7 @@ from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference from flytekit.core.resources import Resources +from flytekit.extras.accelerators import BaseAccelerator from flytekit.image_spec.image_spec import ImageSpec from flytekit.models.documentation import Documentation from flytekit.models.security import Secret @@ -102,6 +103,7 @@ def task( enable_deck: Optional[bool] = ..., pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., + accelerator: Optional[BaseAccelerator] = ..., ) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ... @@ -129,6 +131,7 @@ def task( enable_deck: Optional[bool] = ..., pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., + accelerator: Optional[BaseAccelerator] = ..., ) -> Union[PythonFunctionTask[T], Callable[..., FuncOut]]: ... @@ -155,6 +158,7 @@ def task( enable_deck: Optional[bool] = None, pod_template: Optional["PodTemplate"] = None, pod_template_name: Optional[str] = None, + accelerator: Optional[BaseAccelerator] = None, ) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], PythonFunctionTask[T], Callable[..., FuncOut]]: """ This is the core decorator to use for any task type in flytekit. @@ -248,6 +252,7 @@ def foo2(): :param docs: Documentation about this task :param pod_template: Custom PodTemplate for this task. :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. + :param accelerator: The accelerator to use for this task. """ def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: @@ -277,6 +282,7 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: docs=docs, pod_template=pod_template, pod_template_name=pod_template_name, + accelerator=accelerator, ) update_wrapper(task_instance, fn) return task_instance diff --git a/flytekit/extras/accelerators.py b/flytekit/extras/accelerators.py new file mode 100644 index 00000000000..3615f32bdbd --- /dev/null +++ b/flytekit/extras/accelerators.py @@ -0,0 +1,90 @@ +import abc +import copy +from typing import ClassVar, Generic, Optional, Type, TypeVar + +from flyteidl.core import tasks_pb2 + +T = TypeVar("T") +MIG = TypeVar("MIG", bound="MultiInstanceGPUAccelerator") + + +class BaseAccelerator(abc.ABC, Generic[T]): + @abc.abstractmethod + def to_flyte_idl(self) -> T: + ... + + +class GPUAccelerator(BaseAccelerator): + def __init__(self, device: str) -> None: + self._device = device + + def to_flyte_idl(self) -> tasks_pb2.GPUAccelerator: + return tasks_pb2.GPUAccelerator(device=self._device) + + +A10G = GPUAccelerator("nvidia-a10g") +L4 = GPUAccelerator("nvidia-l4-vws") +K80 = GPUAccelerator("nvidia-tesla-k80") +M60 = GPUAccelerator("nvidia-tesla-m60") +P4 = GPUAccelerator("nvidia-tesla-p4") +P100 = GPUAccelerator("nvidia-tesla-p100") +T4 = GPUAccelerator("nvidia-tesla-t4") +V100 = GPUAccelerator("nvidia-tesla-v100") + + +class MultiInstanceGPUAccelerator(BaseAccelerator): + device: ClassVar[str] + _partition_size: Optional[str] + + @property + def unpartitioned(self: MIG) -> MIG: + instance = copy.deepcopy(self) + instance._partition_size = None + return instance + + @classmethod + def partitioned(cls: Type[MIG], partition_size: str) -> MIG: + instance = cls() + instance._partition_size = partition_size + return instance + + def to_flyte_idl(self) -> tasks_pb2.GPUAccelerator: + msg = tasks_pb2.GPUAccelerator(device=self.device) + if not hasattr(self, "_partition_size"): + return msg + + if self._partition_size is None: + msg.unpartitioned = True + else: + msg.partition_size = self._partition_size + return msg + + +class _A100_Base(MultiInstanceGPUAccelerator): + device = "nvidia-tesla-a100" + + +class _A100(_A100_Base): + partition_1g_5gb = _A100_Base.partitioned("1g.5gb") + partition_2g_10gb = _A100_Base.partitioned("2g.10gb") + partition_3g_20gb = _A100_Base.partitioned("3g.20gb") + partition_4g_20gb = _A100_Base.partitioned("4g.20gb") + partition_7g_40gb = _A100_Base.partitioned("7g.40gb") + + +A100 = _A100() + + +class _A100_80GB_Base(MultiInstanceGPUAccelerator): + device = "nvidia-a100-80gb" + + +class _A100_80GB(_A100_80GB_Base): + partition_1g_10gb = _A100_80GB_Base.partitioned("1g.10gb") + partition_2g_20gb = _A100_80GB_Base.partitioned("2g.20gb") + partition_3g_40gb = _A100_80GB_Base.partitioned("3g.40gb") + partition_4g_40gb = _A100_80GB_Base.partitioned("4g.40gb") + partition_7g_80gb = _A100_80GB_Base.partitioned("7g.80gb") + + +A100_80GB = _A100_80GB() diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index e60038c0f6d..efd18babcda 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -1,6 +1,7 @@ import datetime import typing +from flyteidl.core import tasks_pb2 from flyteidl.core import workflow_pb2 as _core_workflow from flytekit.models import common as _common @@ -562,24 +563,33 @@ def from_flyte_idl(cls, pb2_object): class TaskNodeOverrides(_common.FlyteIdlEntity): - def __init__(self, resources: typing.Optional[Resources] = None): + def __init__( + self, resources: typing.Optional[Resources], extended_resources: typing.Optional[tasks_pb2.ExtendedResources] + ): self._resources = resources + self._extended_resources = extended_resources @property def resources(self) -> Resources: return self._resources + @property + def extended_resources(self) -> tasks_pb2.ExtendedResources: + return self._extended_resources + def to_flyte_idl(self): return _core_workflow.TaskNodeOverrides( resources=self.resources.to_flyte_idl() if self.resources is not None else None, + extended_resources=self.extended_resources, ) @classmethod def from_flyte_idl(cls, pb2_object): resources = Resources.from_flyte_idl(pb2_object.resources) + extended_resources = pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None if bool(resources.requests) or bool(resources.limits): - return cls(resources=resources) - return cls(resources=None) + return cls(resources=resources, extended_resources=extended_resources) + return cls(resources=None, extended_resources=extended_resources) class TaskNode(_common.FlyteIdlEntity): diff --git a/flytekit/models/task.py b/flytekit/models/task.py index f7f1d710c9f..48a8abfde17 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -336,6 +336,7 @@ def __init__( config=None, k8s_pod=None, sql=None, + extended_resources=None, ): """ A task template represents the full set of information necessary to perform a unit of work in the Flyte system. @@ -359,6 +360,7 @@ def __init__( in tandem with the custom. :param K8sPod k8s_pod: Alternative to the container used to execute this task. :param Sql sql: This is used to execute query in FlytePropeller instead of running container or k8s_pod. + :param flyteidl.core.tasks_pb2.ExtendedResources extended_resources: The extended resources to allocate to the task. """ if ( (container is not None and k8s_pod is not None) @@ -377,6 +379,7 @@ def __init__( self._security_context = security_context self._k8s_pod = k8s_pod self._sql = sql + self._extended_resources = extended_resources @property def id(self): @@ -451,6 +454,14 @@ def k8s_pod(self): def sql(self): return self._sql + @property + def extended_resources(self): + """ + If not None, the extended resources to allocate to the task. + :rtype: flyteidl.core.tasks_pb2.ExtendedResources + """ + return self._extended_resources + def to_flyte_idl(self): """ :rtype: flyteidl.core.tasks_pb2.TaskTemplate @@ -464,6 +475,7 @@ def to_flyte_idl(self): container=self.container.to_flyte_idl() if self.container else None, task_type_version=self.task_type_version, security_context=self.security_context.to_flyte_idl() if self.security_context else None, + extended_resources=self.extended_resources, config={k: v for k, v in self.config.items()} if self.config is not None else None, k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None, sql=self.sql.to_flyte_idl() if self.sql else None, @@ -487,6 +499,7 @@ def from_flyte_idl(cls, pb2_object): security_context=_sec.SecurityContext.from_flyte_idl(pb2_object.security_context) if pb2_object.security_context and pb2_object.security_context.ByteSize() > 0 else None, + extended_resources=pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None, config={k: v for k, v in pb2_object.config.items()} if pb2_object.config is not None else None, k8s_pod=K8sPod.from_flyte_idl(pb2_object.k8s_pod) if pb2_object.HasField("k8s_pod") else None, sql=Sql.from_flyte_idl(pb2_object.sql) if pb2_object.HasField("sql") else None, diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 87ccd2f5346..cfe43544f36 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -214,6 +214,7 @@ def get_serializable_task( config=entity.get_config(settings), k8s_pod=pod, sql=entity.get_sql(settings), + extended_resources=entity.get_extended_resources(settings), ) if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask): entity.reset_command_fn() @@ -440,7 +441,8 @@ def get_serializable_node( upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], task_node=workflow_model.TaskNode( - reference_id=task_spec.template.id, overrides=TaskNodeOverrides(resources=entity._resources) + reference_id=task_spec.template.id, + overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources), ), ) if entity._aliases: @@ -516,7 +518,8 @@ def get_serializable_node( upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], task_node=workflow_model.TaskNode( - reference_id=entity.flyte_entity.id, overrides=TaskNodeOverrides(resources=entity._resources) + reference_id=entity.flyte_entity.id, + overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources), ), ) elif isinstance(entity.flyte_entity, FlyteWorkflow): @@ -565,7 +568,7 @@ def get_serializable_array_node( task_spec = get_serializable(entity_mapping, settings, entity, options) task_node = workflow_model.TaskNode( reference_id=task_spec.template.id, - overrides=TaskNodeOverrides(resources=node._resources), + overrides=TaskNodeOverrides(resources=node._resources, extended_resources=node._extended_resources), ) node = workflow_model.Node( id=entity.name, diff --git a/tests/flytekit/common/parameterizers.py b/tests/flytekit/common/parameterizers.py index d5b07fe4202..96c30b69b4e 100644 --- a/tests/flytekit/common/parameterizers.py +++ b/tests/flytekit/common/parameterizers.py @@ -1,6 +1,9 @@ from datetime import timedelta from itertools import product +from flyteidl.core import tasks_pb2 + +from flytekit.extras.accelerators import A100, T4 from flytekit.models import interface, literals, security, task, types from flytekit.models.core import identifier from flytekit.models.core import types as _core_types @@ -136,7 +139,6 @@ ) ] - LIST_OF_TASK_TEMPLATES = [ task.TaskTemplate( identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), @@ -250,3 +252,19 @@ LIST_OF_SECURITY_CONTEXT = [ security.SecurityContext(run_as=r, secrets=s, tokens=None) for r in LIST_RUN_AS for s in LIST_OF_SECRETS ] + [None] + +LIST_OF_ACCELERATORS = [ + None, + T4, + A100, + A100.unpartitioned, + A100.partition_1g_5gb, +] + +LIST_OF_EXTENDED_RESOURCES = [ + None, + *[ + tasks_pb2.ExtendedResources(gpu_accelerator=None if accelerator is None else accelerator.to_flyte_idl()) + for accelerator in LIST_OF_ACCELERATORS + ], +] diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 81621ef3fc4..cb790f3c2e1 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -13,6 +13,7 @@ from flytekit.core.task import task from flytekit.core.workflow import workflow from flytekit.exceptions.user import FlyteAssertion +from flytekit.extras.accelerators import A100, T4 from flytekit.models import literals as _literal_models from flytekit.models.task import Resources as _resources_models from flytekit.tools.translator import get_serializable @@ -465,3 +466,29 @@ def wf() -> str: return "hi" assert wf.nodes[0].flyte_entity.container_image == "hello/world" + + +def test_override_accelerator(): + @task(accelerator=T4) + def bar() -> str: + return "hello" + + @workflow + def my_wf() -> str: + return bar().with_overrides(accelerator=A100.partition_1g_5gb) + + serialization_settings = flytekit.configuration.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + ) + wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) + assert len(wf_spec.template.nodes) == 1 + assert wf_spec.template.nodes[0].task_node.overrides is not None + assert wf_spec.template.nodes[0].task_node.overrides.extended_resources is not None + accelerator = wf_spec.template.nodes[0].task_node.overrides.extended_resources.gpu_accelerator + assert accelerator.device == "nvidia-tesla-a100" + assert accelerator.partition_size == "1g.5gb" + assert not accelerator.HasField("unpartitioned") diff --git a/tests/flytekit/unit/extras/test_accelerators.py b/tests/flytekit/unit/extras/test_accelerators.py new file mode 100644 index 00000000000..a62dff7af79 --- /dev/null +++ b/tests/flytekit/unit/extras/test_accelerators.py @@ -0,0 +1,64 @@ +from collections import OrderedDict + +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.task import task +from flytekit.extras.accelerators import A100, T4 +from flytekit.tools.translator import get_serializable + +serialization_settings = SerializationSettings( + project="proj", + domain="dom", + version="123", + image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")), + env={}, +) + + +class TestAccelerators: + def test_gpu_accelerator(self): + @task(accelerator=T4) + def needs_t4(a: int): + pass + + ts = get_serializable(OrderedDict(), serialization_settings, needs_t4).to_flyte_idl() + gpu_accelerator = ts.template.extended_resources.gpu_accelerator + assert gpu_accelerator is not None + assert gpu_accelerator.device == "nvidia-tesla-t4" + assert not gpu_accelerator.HasField("unpartitioned") + assert not gpu_accelerator.HasField("partition_size") + + def test_mig(self): + @task(accelerator=A100) + def needs_a100(a: int): + pass + + ts = get_serializable(OrderedDict(), serialization_settings, needs_a100).to_flyte_idl() + gpu_accelerator = ts.template.extended_resources.gpu_accelerator + assert gpu_accelerator is not None + assert gpu_accelerator.device == "nvidia-tesla-a100" + assert not gpu_accelerator.HasField("unpartitioned") + assert not gpu_accelerator.HasField("partition_size") + + def test_mig_unpartitioned(self): + @task(accelerator=A100.unpartitioned) + def needs_unpartitioned_a100(a: int): + pass + + ts = get_serializable(OrderedDict(), serialization_settings, needs_unpartitioned_a100).to_flyte_idl() + gpu_accelerator = ts.template.extended_resources.gpu_accelerator + assert gpu_accelerator is not None + assert gpu_accelerator.device == "nvidia-tesla-a100" + assert gpu_accelerator.unpartitioned + assert not gpu_accelerator.HasField("partition_size") + + def test_mig_partitioned(self): + @task(accelerator=A100.partition_1g_5gb) + def needs_partitioned_a100(a: int): + pass + + ts = get_serializable(OrderedDict(), serialization_settings, needs_partitioned_a100).to_flyte_idl() + gpu_accelerator = ts.template.extended_resources.gpu_accelerator + assert gpu_accelerator is not None + assert gpu_accelerator.device == "nvidia-tesla-a100" + assert gpu_accelerator.partition_size == "1g.5gb" + assert not gpu_accelerator.HasField("unpartitioned") diff --git a/tests/flytekit/unit/models/core/test_workflow.py b/tests/flytekit/unit/models/core/test_workflow.py index 6775d589403..cd36381ea01 100644 --- a/tests/flytekit/unit/models/core/test_workflow.py +++ b/tests/flytekit/unit/models/core/test_workflow.py @@ -1,5 +1,8 @@ from datetime import timedelta +from flyteidl.core import tasks_pb2 + +from flytekit.extras.accelerators import T4 from flytekit.models import interface as _interface from flytekit.models import literals as _literals from flytekit.models import types as _types @@ -300,10 +303,12 @@ def test_task_node_overrides(): Resources( requests=[Resources.ResourceEntry(Resources.ResourceName.CPU, "1")], limits=[Resources.ResourceEntry(Resources.ResourceName.CPU, "2")], - ) + ), + tasks_pb2.ExtendedResources(gpu_accelerator=T4.to_flyte_idl()), ) assert overrides.resources.requests == [Resources.ResourceEntry(Resources.ResourceName.CPU, "1")] assert overrides.resources.limits == [Resources.ResourceEntry(Resources.ResourceName.CPU, "2")] + assert overrides.extended_resources.gpu_accelerator == T4.to_flyte_idl() obj = _workflow.TaskNodeOverrides.from_flyte_idl(overrides.to_flyte_idl()) assert overrides == obj @@ -316,12 +321,14 @@ def test_task_node_with_overrides(): Resources( requests=[Resources.ResourceEntry(Resources.ResourceName.CPU, "1")], limits=[Resources.ResourceEntry(Resources.ResourceName.CPU, "2")], - ) + ), + tasks_pb2.ExtendedResources(gpu_accelerator=T4.to_flyte_idl()), ), ) assert task_node.overrides.resources.requests == [Resources.ResourceEntry(Resources.ResourceName.CPU, "1")] assert task_node.overrides.resources.limits == [Resources.ResourceEntry(Resources.ResourceName.CPU, "2")] + assert task_node.overrides.extended_resources.gpu_accelerator == T4.to_flyte_idl() obj = _workflow.TaskNode.from_flyte_idl(task_node.to_flyte_idl()) assert task_node == obj diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py index a979a39b661..b4158c38521 100644 --- a/tests/flytekit/unit/models/test_tasks.py +++ b/tests/flytekit/unit/models/test_tasks.py @@ -2,12 +2,13 @@ from itertools import product import pytest -from flyteidl.core.tasks_pb2 import TaskMetadata +from flyteidl.core.tasks_pb2 import ExtendedResources, TaskMetadata from google.protobuf import text_format import flytekit.models.interface as interface_models import flytekit.models.literals as literal_models from flytekit import Description, Documentation, SourceCode +from flytekit.extras.accelerators import T4 from flytekit.models import literals, task, types from flytekit.models.core import identifier from tests.flytekit.common import parameterizers @@ -108,6 +109,7 @@ def test_task_template(in_tuple): {"d": "e"}, ), config={"a": "b"}, + extended_resources=ExtendedResources(gpu_accelerator=T4.to_flyte_idl()), ) assert obj.id.resource_type == identifier.ResourceType.TASK assert obj.id.project == "project" @@ -124,6 +126,9 @@ def test_task_template(in_tuple): task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl() ) assert obj.config == {"a": "b"} + assert obj.extended_resources.gpu_accelerator.device == "nvidia-tesla-t4" + assert not obj.extended_resources.gpu_accelerator.HasField("unpartitioned") + assert not obj.extended_resources.gpu_accelerator.HasField("partition_size") def test_task_spec(): @@ -166,6 +171,7 @@ def test_task_spec(): {"d": "e"}, ), config={"a": "b"}, + extended_resources=ExtendedResources(gpu_accelerator=T4.to_flyte_idl()), ) short_description = "short" @@ -212,6 +218,7 @@ def test_task_template_k8s_pod_target(): metadata=task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"}), pod_spec={"str": "val", "int": 1}, ), + extended_resources=ExtendedResources(gpu_accelerator=T4.to_flyte_idl()), ) assert obj.id.resource_type == identifier.ResourceType.TASK assert obj.id.project == "project" @@ -226,6 +233,9 @@ def test_task_template_k8s_pod_target(): task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl() ) assert obj.config == {"a": "b"} + assert obj.extended_resources.gpu_accelerator.device == "nvidia-tesla-t4" + assert not obj.extended_resources.gpu_accelerator.HasField("unpartitioned") + assert not obj.extended_resources.gpu_accelerator.HasField("partition_size") @pytest.mark.parametrize("sec_ctx", parameterizers.LIST_OF_SECURITY_CONTEXT) @@ -254,6 +264,28 @@ def test_task_template_security_context(sec_ctx): assert task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).security_context == expected +@pytest.mark.parametrize("extended_resources", parameterizers.LIST_OF_EXTENDED_RESOURCES) +def test_task_template_extended_resources(extended_resources): + obj = task.TaskTemplate( + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), + "python", + parameterizers.LIST_OF_TASK_METADATA[0], + parameterizers.LIST_OF_INTERFACES[0], + {"a": 1, "b": {"c": 2, "d": 3}}, + container=task.Container( + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + parameterizers.LIST_OF_RESOURCES[0], + {"a": "b"}, + {"d": "e"}, + ), + extended_resources=extended_resources, + ) + assert obj.extended_resources == extended_resources + assert task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).extended_resources == extended_resources + + @pytest.mark.parametrize("task_closure", parameterizers.LIST_OF_TASK_CLOSURES) def test_task(task_closure): obj = task.Task( diff --git a/tests/flytekit/unit/models/test_workflow_closure.py b/tests/flytekit/unit/models/test_workflow_closure.py index 2b5b06696b6..64e5a577139 100644 --- a/tests/flytekit/unit/models/test_workflow_closure.py +++ b/tests/flytekit/unit/models/test_workflow_closure.py @@ -1,5 +1,8 @@ from datetime import timedelta +from flyteidl.core import tasks_pb2 + +from flytekit.extras.accelerators import T4 from flytekit.models import interface as _interface from flytekit.models import literals as _literals from flytekit.models import task as _task @@ -58,6 +61,7 @@ def test_workflow_closure(): {}, {}, ), + extended_resources=tasks_pb2.ExtendedResources(gpu_accelerator=T4.to_flyte_idl()), ) task_node = _workflow.TaskNode(task.id)