Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Extended Resources] GPU Accelerators #1843

Merged
merged 71 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
e029c2a
pip through to container
wild-endeavor Aug 16, 2023
550c3ff
move around
wild-endeavor Aug 16, 2023
dfdb38e
add asserts
wild-endeavor Aug 16, 2023
8bd3446
delete bad line
wild-endeavor Aug 16, 2023
f12e10c
switch to abc and add support for gpu unpartitioned
jeevb Sep 1, 2023
94b6cb8
Add Azure-specific headers when uploading to blob storage (#1784)
devictr Aug 17, 2023
7024010
Add async delete function in base_agent (#1800)
Future-Outlier Aug 19, 2023
34e0c68
Add support for execution name prefixes (#1803)
troychiu Aug 21, 2023
e636025
Remove ref in output (#1794)
wild-endeavor Aug 21, 2023
993df0d
Inherit directly from DataClassJsonMixin instead of using @dataclass_…
ringohoffman Aug 21, 2023
f322aef
Async file sensor (#1790)
pingsutw Aug 23, 2023
3cc6326
Eager workflows to support async workflows (#1579)
cosmicBboy Aug 25, 2023
99d2ea8
Enable SecretsManager.get to load and return bytes (#1798)
ysysys3074 Aug 25, 2023
68f87d9
Batch upload flyte directory (#1806)
pingsutw Aug 26, 2023
ca33b5f
Better error messaging for overrides (#1807)
kumare3 Aug 28, 2023
f48e4e9
Run remote Launchplan from `pyflyte run` (#1785)
kumare3 Aug 29, 2023
26f7de0
Add is none function (#1757)
pingsutw Aug 29, 2023
88a108f
Dynamic workflow should not throw nested task warning (#1812)
Aug 31, 2023
112f740
Add a manual image building GH action (#1816)
wild-endeavor Sep 1, 2023
40a789f
catch abfs protocol in data_persistence.py/get_filesystem and set ano…
fiedlerNr9 Sep 1, 2023
9a599bf
None doesnt work
jeevb Sep 1, 2023
18be31a
unpartitioned selector
jeevb Sep 1, 2023
015e24f
Fix list of annotated structured dataset (#1817)
wild-endeavor Sep 1, 2023
d167977
Support the flytectl config.yaml admin.clientSecretEnvVar option in f…
chaohengstudent Sep 6, 2023
1b6a027
Async agent delete function for while loop case (#1802)
Future-Outlier Sep 7, 2023
9e0a91a
refactor
jeevb Sep 12, 2023
810a5cf
fix docs warnings (#1827)
samhita-alla Sep 11, 2023
305864d
Fix extract_task_module (#1829)
pingsutw Sep 11, 2023
c8fc69d
Feat: Add type support for pydantic BaseModels (#1660)
ArthurBook Sep 11, 2023
e70ac1e
add test for unspecified mig
jeevb Sep 12, 2023
5201d1a
add support for overriding accelerator
jeevb Sep 12, 2023
b8fc677
cleanup
jeevb Sep 12, 2023
fbe8bc5
move from core to extras
jeevb Sep 12, 2023
ad49582
fixes
jeevb Sep 12, 2023
cefa3d3
fixes
jeevb Sep 12, 2023
ab1e0d6
fixes
jeevb Sep 12, 2023
e98517b
cleanup
jeevb Sep 13, 2023
68f423d
Make FlyteRemote slightly more copy/pastable (#1830)
katrogan Sep 12, 2023
2f681e9
Pyflyte meta inputs (#1823)
kumare3 Sep 12, 2023
62255f5
Use mashumaro to serialize/deserialize dataclass (#1735)
hhcs9527 Sep 12, 2023
a1af299
Databricks Agent (#1797)
Future-Outlier Sep 12, 2023
f0fc698
Prometheus metrics (#1815)
pingsutw Sep 13, 2023
c4ade35
Pyflyte register optionally activates schedule (#1832)
kumare3 Sep 14, 2023
3690e41
Remove versions 3.9 and 3.10 (#1831)
wild-endeavor Sep 14, 2023
92d4340
Snowflake agent (#1799)
hhcs9527 Sep 15, 2023
8a7a092
Update agent metric name (#1835)
pingsutw Sep 15, 2023
3e61111
MemVerge MMCloud Agent (#1821)
edwinyyyu Sep 15, 2023
948ae71
Add download badges in readme (#1836)
pingsutw Sep 18, 2023
f63faaf
Eager local entrypoint and support for offloaded types (#1833)
cosmicBboy Sep 18, 2023
92a7fd6
update requirements and add snowflake agent to api reference (#1838)
samhita-alla Sep 19, 2023
2c1f729
Fix: Make sure decks created in elastic task workers are transferred …
fg91 Sep 19, 2023
99abcb4
add accept grpc (#1841)
wild-endeavor Sep 20, 2023
3220a3e
Feat: Enable `flytekit` to authenticate with proxy in front of FlyteA…
fg91 Sep 20, 2023
5c81d17
bump flyteidl
jeevb Sep 20, 2023
8691dcb
Merge branch 'master' into gpu-selector
jeevb Sep 20, 2023
6f112a0
make requirements
jeevb Sep 21, 2023
92acdae
fix failing tests
jeevb Sep 21, 2023
662d1ee
move gpu accelerator to flyteidl.core.Resources
jeevb Sep 23, 2023
ab0c555
Use ResourceExtensions for extended resources
jeevb Oct 2, 2023
c38b76b
cleanup
jeevb Oct 2, 2023
c370262
Switch to using ExtendedResources in TaskTemplate
jeevb Oct 3, 2023
e239b33
Merge remote-tracking branch 'origin/master' into gpu-selector
jeevb Oct 6, 2023
2f342b8
cleanups
jeevb Oct 6, 2023
b95ace8
Merge branch 'master' into gpu-selector
jeevb Oct 23, 2023
41cf59c
update flyteidl
jeevb Oct 24, 2023
ee42a67
Replace _core_task imports with tasks_pb2
jeevb Oct 26, 2023
3c59c67
less verbose definitions
jeevb Oct 27, 2023
ae1f44d
Attempt at less confusing syntax
jeevb Oct 27, 2023
1d19dce
Streamline UX
jeevb Oct 31, 2023
0d35c5f
Run make fmt
jeevb Oct 31, 2023
6ead99c
Merge branch 'master' into gpu-selector
wild-endeavor Nov 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from dataclasses import dataclass
from typing import Any, Coroutine, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast

from flyteidl.core import tasks_pb2

from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import (
ExecutionParameters,
Expand Down Expand Up @@ -344,6 +346,12 @@
"""
return None

def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
"""
Returns the extended resources to allocate to the task on hosted Flyte.
"""
return None

Check warning on line 353 in flytekit/core/base_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/base_task.py#L353

Added line #L353 was not covered by tests

def local_execution_mode(self) -> ExecutionState.Mode:
""" """
return ExecutionState.Mode.LOCAL_TASK_EXECUTION
Expand Down
8 changes: 8 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import typing
from typing import Any, List

from flyteidl.core import tasks_pb2

from flytekit.core.resources import Resources, convert_resources_to_resource_model
from flytekit.core.utils import _dnsify
from flytekit.loggers import logger
Expand Down Expand Up @@ -62,6 +64,7 @@
self._aliases: _workflow_model.Alias = None
self._outputs = None
self._resources: typing.Optional[_resources_model] = None
self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None

def runs_before(self, other: Node):
"""
Expand Down Expand Up @@ -172,6 +175,11 @@
assert_not_promise(v, "container_image")
self.flyte_entity._container_image = v

if "accelerator" in kwargs:
v = kwargs["accelerator"]
assert_not_promise(v, "accelerator")
self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=v.to_flyte_idl())

Check warning on line 181 in flytekit/core/node.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/node.py#L179-L181

Added lines #L179 - L181 were not covered by tests

return self


Expand Down
15 changes: 15 additions & 0 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from abc import ABC
from typing import Callable, Dict, List, Optional, TypeVar, Union

from flyteidl.core import tasks_pb2

from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.context_manager import FlyteContextManager
Expand All @@ -13,6 +15,7 @@
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance, extract_task_module
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.loggers import logger
from flytekit.models import task as _task_model
Expand Down Expand Up @@ -44,6 +47,7 @@
secret_requests: Optional[List[Secret]] = None,
pod_template: Optional[PodTemplate] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
**kwargs,
):
"""
Expand All @@ -70,6 +74,7 @@
- `AWS Parameter store <https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html>`__
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
"""
sec_ctx = None
if secret_requests:
Expand Down Expand Up @@ -110,6 +115,7 @@
self._get_command_fn = self.get_default_command

self.pod_template = pod_template
self.accelerator = accelerator

@property
def task_resolver(self) -> TaskResolverMixin:
Expand Down Expand Up @@ -219,6 +225,15 @@
return {}
return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name}

def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
"""
Returns the extended resources to allocate to the task on hosted Flyte.
"""
if self.accelerator is None:
return None

Check warning on line 233 in flytekit/core/python_auto_container.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_auto_container.py#L233

Added line #L233 was not covered by tests

return tasks_pb2.ExtendedResources(gpu_accelerator=self.accelerator.to_flyte_idl())

Check warning on line 235 in flytekit/core/python_auto_container.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_auto_container.py#L235

Added line #L235 was not covered by tests


class DefaultTaskResolver(TrackedInstance, TaskResolverMixin):
"""
Expand Down
6 changes: 6 additions & 0 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, TaskReference
from flytekit.core.resources import Resources
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.image_spec.image_spec import ImageSpec
from flytekit.models.documentation import Documentation
from flytekit.models.security import Secret
Expand Down Expand Up @@ -102,6 +103,7 @@ def task(
enable_deck: Optional[bool] = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]:
...

Expand Down Expand Up @@ -129,6 +131,7 @@ def task(
enable_deck: Optional[bool] = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
) -> Union[PythonFunctionTask[T], Callable[..., FuncOut]]:
...

Expand All @@ -155,6 +158,7 @@ def task(
enable_deck: Optional[bool] = None,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], PythonFunctionTask[T], Callable[..., FuncOut]]:
"""
This is the core decorator to use for any task type in flytekit.
Expand Down Expand Up @@ -248,6 +252,7 @@ def foo2():
:param docs: Documentation about this task
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
"""

def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
Expand Down Expand Up @@ -277,6 +282,7 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
docs=docs,
pod_template=pod_template,
pod_template_name=pod_template_name,
accelerator=accelerator,
)
update_wrapper(task_instance, fn)
return task_instance
Expand Down
90 changes: 90 additions & 0 deletions flytekit/extras/accelerators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import abc
import copy
from typing import ClassVar, Generic, Optional, Type, TypeVar

from flyteidl.core import tasks_pb2

T = TypeVar("T")
MIG = TypeVar("MIG", bound="MultiInstanceGPUAccelerator")


class BaseAccelerator(abc.ABC, Generic[T]):
@abc.abstractmethod
def to_flyte_idl(self) -> T:
...

Check warning on line 14 in flytekit/extras/accelerators.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extras/accelerators.py#L14

Added line #L14 was not covered by tests


class GPUAccelerator(BaseAccelerator):
def __init__(self, device: str) -> None:
self._device = device

def to_flyte_idl(self) -> tasks_pb2.GPUAccelerator:
return tasks_pb2.GPUAccelerator(device=self._device)

Check warning on line 22 in flytekit/extras/accelerators.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extras/accelerators.py#L22

Added line #L22 was not covered by tests


A10G = GPUAccelerator("nvidia-a10g")
L4 = GPUAccelerator("nvidia-l4-vws")
K80 = GPUAccelerator("nvidia-tesla-k80")
M60 = GPUAccelerator("nvidia-tesla-m60")
P4 = GPUAccelerator("nvidia-tesla-p4")
P100 = GPUAccelerator("nvidia-tesla-p100")
T4 = GPUAccelerator("nvidia-tesla-t4")
V100 = GPUAccelerator("nvidia-tesla-v100")


class MultiInstanceGPUAccelerator(BaseAccelerator):
device: ClassVar[str]
_partition_size: Optional[str]

@property
def unpartitioned(self: MIG) -> MIG:
instance = copy.deepcopy(self)
instance._partition_size = None
return instance

Check warning on line 43 in flytekit/extras/accelerators.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extras/accelerators.py#L41-L43

Added lines #L41 - L43 were not covered by tests

@classmethod
def partitioned(cls: Type[MIG], partition_size: str) -> MIG:
instance = cls()
instance._partition_size = partition_size
return instance

def to_flyte_idl(self) -> tasks_pb2.GPUAccelerator:
msg = tasks_pb2.GPUAccelerator(device=self.device)

Check warning on line 52 in flytekit/extras/accelerators.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extras/accelerators.py#L52

Added line #L52 was not covered by tests
if not hasattr(self, "_partition_size"):
return msg

Check warning on line 54 in flytekit/extras/accelerators.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extras/accelerators.py#L54

Added line #L54 was not covered by tests

if self._partition_size is None:
msg.unpartitioned = True

Check warning on line 57 in flytekit/extras/accelerators.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extras/accelerators.py#L57

Added line #L57 was not covered by tests
else:
msg.partition_size = self._partition_size
return msg

Check warning on line 60 in flytekit/extras/accelerators.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extras/accelerators.py#L59-L60

Added lines #L59 - L60 were not covered by tests


class _A100_Base(MultiInstanceGPUAccelerator):
device = "nvidia-tesla-a100"


class _A100(_A100_Base):
partition_1g_5gb = _A100_Base.partitioned("1g.5gb")
partition_2g_10gb = _A100_Base.partitioned("2g.10gb")
partition_3g_20gb = _A100_Base.partitioned("3g.20gb")
partition_4g_20gb = _A100_Base.partitioned("4g.20gb")
partition_7g_40gb = _A100_Base.partitioned("7g.40gb")


A100 = _A100()


class _A100_80GB_Base(MultiInstanceGPUAccelerator):
device = "nvidia-a100-80gb"


class _A100_80GB(_A100_80GB_Base):
partition_1g_10gb = _A100_80GB_Base.partitioned("1g.10gb")
partition_2g_20gb = _A100_80GB_Base.partitioned("2g.20gb")
partition_3g_40gb = _A100_80GB_Base.partitioned("3g.40gb")
partition_4g_40gb = _A100_80GB_Base.partitioned("4g.40gb")
partition_7g_80gb = _A100_80GB_Base.partitioned("7g.80gb")


A100_80GB = _A100_80GB()
16 changes: 13 additions & 3 deletions flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import typing

from flyteidl.core import tasks_pb2
from flyteidl.core import workflow_pb2 as _core_workflow

from flytekit.models import common as _common
Expand Down Expand Up @@ -562,24 +563,33 @@


class TaskNodeOverrides(_common.FlyteIdlEntity):
def __init__(self, resources: typing.Optional[Resources] = None):
def __init__(
self, resources: typing.Optional[Resources], extended_resources: typing.Optional[tasks_pb2.ExtendedResources]
):
self._resources = resources
self._extended_resources = extended_resources

Check warning on line 570 in flytekit/models/core/workflow.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/core/workflow.py#L570

Added line #L570 was not covered by tests

@property
def resources(self) -> Resources:
return self._resources

@property
def extended_resources(self) -> tasks_pb2.ExtendedResources:
return self._extended_resources

Check warning on line 578 in flytekit/models/core/workflow.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/core/workflow.py#L578

Added line #L578 was not covered by tests

def to_flyte_idl(self):
return _core_workflow.TaskNodeOverrides(
resources=self.resources.to_flyte_idl() if self.resources is not None else None,
extended_resources=self.extended_resources,
)

@classmethod
def from_flyte_idl(cls, pb2_object):
resources = Resources.from_flyte_idl(pb2_object.resources)
extended_resources = pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None

Check warning on line 589 in flytekit/models/core/workflow.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/core/workflow.py#L589

Added line #L589 was not covered by tests
if bool(resources.requests) or bool(resources.limits):
return cls(resources=resources)
return cls(resources=None)
return cls(resources=resources, extended_resources=extended_resources)
return cls(resources=None, extended_resources=extended_resources)

Check warning on line 592 in flytekit/models/core/workflow.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/core/workflow.py#L591-L592

Added lines #L591 - L592 were not covered by tests


class TaskNode(_common.FlyteIdlEntity):
Expand Down
13 changes: 13 additions & 0 deletions flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@
config=None,
k8s_pod=None,
sql=None,
extended_resources=None,
):
"""
A task template represents the full set of information necessary to perform a unit of work in the Flyte system.
Expand All @@ -359,6 +360,7 @@
in tandem with the custom.
:param K8sPod k8s_pod: Alternative to the container used to execute this task.
:param Sql sql: This is used to execute query in FlytePropeller instead of running container or k8s_pod.
:param flyteidl.core.tasks_pb2.ExtendedResources extended_resources: The extended resources to allocate to the task.
"""
if (
(container is not None and k8s_pod is not None)
Expand All @@ -377,6 +379,7 @@
self._security_context = security_context
self._k8s_pod = k8s_pod
self._sql = sql
self._extended_resources = extended_resources

Check warning on line 382 in flytekit/models/task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/task.py#L382

Added line #L382 was not covered by tests

@property
def id(self):
Expand Down Expand Up @@ -451,6 +454,14 @@
def sql(self):
return self._sql

@property
def extended_resources(self):
"""
If not None, the extended resources to allocate to the task.
:rtype: flyteidl.core.tasks_pb2.ExtendedResources
"""
return self._extended_resources

Check warning on line 463 in flytekit/models/task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/task.py#L463

Added line #L463 was not covered by tests

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.tasks_pb2.TaskTemplate
Expand All @@ -464,6 +475,7 @@
container=self.container.to_flyte_idl() if self.container else None,
task_type_version=self.task_type_version,
security_context=self.security_context.to_flyte_idl() if self.security_context else None,
extended_resources=self.extended_resources,
config={k: v for k, v in self.config.items()} if self.config is not None else None,
k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None,
sql=self.sql.to_flyte_idl() if self.sql else None,
Expand All @@ -487,6 +499,7 @@
security_context=_sec.SecurityContext.from_flyte_idl(pb2_object.security_context)
if pb2_object.security_context and pb2_object.security_context.ByteSize() > 0
else None,
extended_resources=pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None,
config={k: v for k, v in pb2_object.config.items()} if pb2_object.config is not None else None,
k8s_pod=K8sPod.from_flyte_idl(pb2_object.k8s_pod) if pb2_object.HasField("k8s_pod") else None,
sql=Sql.from_flyte_idl(pb2_object.sql) if pb2_object.HasField("sql") else None,
Expand Down
9 changes: 6 additions & 3 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def get_serializable_task(
config=entity.get_config(settings),
k8s_pod=pod,
sql=entity.get_sql(settings),
extended_resources=entity.get_extended_resources(settings),
)
if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask):
entity.reset_command_fn()
Expand Down Expand Up @@ -440,7 +441,8 @@ def get_serializable_node(
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
task_node=workflow_model.TaskNode(
reference_id=task_spec.template.id, overrides=TaskNodeOverrides(resources=entity._resources)
reference_id=task_spec.template.id,
overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources),
),
)
if entity._aliases:
Expand Down Expand Up @@ -516,7 +518,8 @@ def get_serializable_node(
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
task_node=workflow_model.TaskNode(
reference_id=entity.flyte_entity.id, overrides=TaskNodeOverrides(resources=entity._resources)
reference_id=entity.flyte_entity.id,
overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources),
),
)
elif isinstance(entity.flyte_entity, FlyteWorkflow):
Expand Down Expand Up @@ -565,7 +568,7 @@ def get_serializable_array_node(
task_spec = get_serializable(entity_mapping, settings, entity, options)
task_node = workflow_model.TaskNode(
reference_id=task_spec.template.id,
overrides=TaskNodeOverrides(resources=node._resources),
overrides=TaskNodeOverrides(resources=node._resources, extended_resources=node._extended_resources),
)
node = workflow_model.Node(
id=entity.name,
Expand Down
Loading
Loading