Skip to content

Commit

Permalink
Ignore inapplicable FlytePropeller limits
Browse files Browse the repository at this point in the history
Signed-off-by: Edwin Yu <[email protected]>
  • Loading branch information
edwinyyyu committed Sep 12, 2023
1 parent 6fce44e commit ee55506
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 43 deletions.
9 changes: 5 additions & 4 deletions plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import grpc
from flyteidl.admin.agent_pb2 import CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource
from flytekitplugins.mmcloud.utils import async_check_output, flyte_to_mmcloud_resources, mmcloud_status_to_flyte_state
from flytekitplugins.mmcloud.utils import async_check_output, mmcloud_status_to_flyte_state

from flytekit import current_context
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry
Expand Down Expand Up @@ -73,12 +73,13 @@ async def async_create(
*self._response_format,
]

container = task_template.container

min_cpu, min_mem, max_cpu, max_mem = flyte_to_mmcloud_resources(container.resources)
# We do not use container.resources because FlytePropeller will impose limits that should not apply to MMCloud
min_cpu, min_mem, max_cpu, max_mem = task_template.custom["resources"]
submit_command.extend(["--cpu", f"{min_cpu}:{max_cpu}"] if max_cpu else ["--cpu", f"{min_cpu}"])
submit_command.extend(["--mem", f"{min_mem}:{max_mem}"] if max_mem else ["--mem", f"{min_mem}"])

container = task_template.container

image = container.image
submit_command.extend(["--image", image])

Expand Down
13 changes: 11 additions & 2 deletions plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from dataclasses import dataclass
from typing import Any, Optional, Union

from flytekitplugins.mmcloud.utils import flyte_to_mmcloud_resources
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

from flytekit.configuration import DefaultImages, SerializationSettings
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.resources import Resources
from flytekit.extend import TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec.image_spec import ImageSpec
Expand All @@ -30,7 +32,9 @@ def __init__(
self,
task_config: Optional[MMCloudConfig],
task_function: Callable,
container_image: Optional[Union[str, ImageSpec]],
container_image: Optional[Union[str, ImageSpec]] = None,
requests: Optional[Resources] = None,
limits: Optional[Resources] = None,
**kwargs,
):
super().__init__(
Expand All @@ -41,6 +45,8 @@ def __init__(
**kwargs,
)

self._mmcloud_resources = flyte_to_mmcloud_resources(requests=requests, limits=limits)

def execute(self, **kwargs) -> Any:
# FLOAT_JOB_ID should always and only be defined on a Memory Machine Cloud worker node
if os.getenv("FLOAT_JOB_ID"):
Expand All @@ -54,7 +60,10 @@ def get_custom(self, settings: SerializationSettings) -> dict[str, Any]:
"""
Return plugin-specific data as a serializable dictionary.
"""
config = {"submit_extra": self.task_config.submit_extra}
config = {
"submit_extra": self.task_config.submit_extra,
"resources": [str(resource) if resource else None for resource in self._mmcloud_resources],
}
s = Struct()
s.update(config)
return json_format.MessageToDict(s)
Expand Down
32 changes: 7 additions & 25 deletions plugins/flytekit-mmcloud/flytekitplugins/mmcloud/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from flyteidl.admin.agent_pb2 import PERMANENT_FAILURE, RETRYABLE_FAILURE, RUNNING, SUCCEEDED, State
from kubernetes.utils.quantity import parse_quantity

from flytekit.models.task import Resources
from flytekit.core.resources import Resources

MMCLOUD_STATUS_TO_FLYTE_STATE = {
"Submitted": RUNNING,
Expand Down Expand Up @@ -39,35 +39,17 @@ def mmcloud_status_to_flyte_state(status: str) -> State:
return MMCLOUD_STATUS_TO_FLYTE_STATE[status]


def flyte_to_mmcloud_resources(resources: Resources) -> tuple[int, int, int, int]:
def flyte_to_mmcloud_resources(requests: Resources, limits: Resources) -> tuple[int, int, int, int]:
"""
Map Flyte (K8s) resources to MMCloud resources.
"""
requests = resources.requests
limits = resources.limits

B_IN_GIB = 1073741824

req_cpu = None
req_mem = None
lim_cpu = None
lim_mem = None

for request in requests:
if request.name == Resources.ResourceName.CPU:
# MMCloud does not support cpu under 1
req_cpu = max(Decimal(1), parse_quantity(request.value))
elif request.name == Resources.ResourceName.MEMORY:
# MMCloud does not support mem under 1Gi
req_mem = max(Decimal(B_IN_GIB), parse_quantity(request.value))

for limit in limits:
if limit.name == Resources.ResourceName.CPU:
# MMCloud does not support cpu under 1
lim_cpu = max(Decimal(1), parse_quantity(limit.value))
elif limit.name == Resources.ResourceName.MEMORY:
# MMCloud does not support mem under 1Gi
lim_mem = max(Decimal(B_IN_GIB), parse_quantity(limit.value))
# MMCloud does not support cpu under 1 or mem under 1Gi
req_cpu = max(Decimal(1), parse_quantity(requests.cpu)) if requests and requests.cpu else None
req_mem = max(Decimal(B_IN_GIB), parse_quantity(requests.mem)) if requests and requests.mem else None
lim_cpu = max(Decimal(1), parse_quantity(limits.cpu)) if limits and limits.cpu else None
lim_mem = max(Decimal(B_IN_GIB), parse_quantity(limits.mem)) if limits and limits.mem else None

# Convert Decimal to int
# Round up so that resource demands are met
Expand Down
18 changes: 6 additions & 12 deletions plugins/flytekit-mmcloud/tests/test_mmcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from flytekit import Resources, task
from flytekit.configuration import DefaultImages, ImageConfig, SerializationSettings
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.extend import get_serializable
from flytekit.extend.backend.base_agent import AgentRegistry

Expand All @@ -22,7 +21,7 @@
def test_mmcloud_task():
task_config = MMCloudConfig(submit_extra="--migratePolicy [enable=true]")
requests = Resources(cpu="2", mem="4Gi")
limits = Resources(cpu="4", mem="16Gi")
limits = Resources(cpu="4")
container_image = DefaultImages.default_image()
environment = {"KEY": "value"}

Expand All @@ -45,8 +44,7 @@ def say_hello(name: str) -> str:
template = task_spec.template
container = template.container

assert template.custom == {"submit_extra": "--migratePolicy [enable=true]"}
assert container.resources == convert_resources_to_resource_model(requests=requests, limits=limits)
assert template.custom == {"submit_extra": "--migratePolicy [enable=true]", "resources": ["2", "4", "4", None]}
assert container.image == container_image
assert container.env == environment

Expand Down Expand Up @@ -78,10 +76,8 @@ def test_flyte_to_mmcloud_resources():

for (req_cpu, req_mem, lim_cpu, lim_mem), (min_cpu, min_mem, max_cpu, max_mem) in success_cases.items():
resources = flyte_to_mmcloud_resources(
convert_resources_to_resource_model(
requests=Resources(cpu=req_cpu, mem=req_mem),
limits=Resources(cpu=lim_cpu, mem=lim_mem),
)
requests=Resources(cpu=req_cpu, mem=req_mem),
limits=Resources(cpu=lim_cpu, mem=lim_mem),
)
assert resources == (min_cpu, min_mem, max_cpu, max_mem)

Expand All @@ -93,10 +89,8 @@ def test_flyte_to_mmcloud_resources():
for (req_cpu, req_mem, lim_cpu, lim_mem) in error_cases:
with pytest.raises(ValueError):
flyte_to_mmcloud_resources(
convert_resources_to_resource_model(
requests=Resources(cpu=req_cpu, mem=req_mem),
limits=Resources(cpu=lim_cpu, mem=lim_mem),
)
requests=Resources(cpu=req_cpu, mem=req_mem),
limits=Resources(cpu=lim_cpu, mem=lim_mem),
)


Expand Down

0 comments on commit ee55506

Please sign in to comment.