This repository has been archived by the owner on Jul 19, 2024. It is now read-only.

Commit: Pass cluster pool when creating executions (flyteorg#1208)
* Pass cluster pool when creating executions

Signed-off-by: Future Outlier <[email protected]>
iaroslav-ciupin authored and Future Outlier committed Oct 3, 2023
1 parent dd24867 commit 63b98a7
Showing 5 changed files with 81 additions and 2 deletions.
10 changes: 10 additions & 0 deletions flytekit/clis/sdk_in_container/run.py
@@ -243,6 +243,15 @@ class RunLevelParams(PyFlyteParams):
"if `from-server` option is used",
)
)
cluster_pool: str = make_field(
click.Option(
param_decls=["--cluster-pool", "cluster_pool"],
required=False,
type=str,
default="",
help="Assign newly created execution to a given cluster pool",
)
)
computed_params: RunLevelComputedParams = field(default_factory=RunLevelComputedParams)
_remote: typing.Optional[FlyteRemote] = None

@@ -427,6 +436,7 @@ def run_remote(
overwrite_cache=run_level_params.overwrite_cache,
envs=run_level_params.envvars,
tags=run_level_params.tags,
cluster_pool=run_level_params.cluster_pool,
)

console_url = remote.generate_console_url(execution)
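The new --cluster-pool option defaults to an empty string, so existing pyflyte run invocations are unchanged; when it is supplied, run_remote simply forwards run_level_params.cluster_pool into remote.execute. A hypothetical invocation (the workflow file and pool name are illustrative, not part of this change) would look like: pyflyte run --remote --cluster-pool gpu example.py my_wf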
5 changes: 3 additions & 2 deletions flytekit/clis/sdk_in_container/serve.py
@@ -42,16 +42,17 @@ def serve(_: click.Context, port, worker, timeout):

async def _start_grpc_server(port: int, worker: int, timeout: int):
click.secho("Starting up the server to expose the prometheus metrics...", fg="blue")
from flytekit.extend.backend.agent_service import AsyncAgentService

try:
from prometheus_client import start_http_server

from flytekit.extend.backend.agent_service import AsyncAgentService

start_http_server(9090)
except ImportError as e:
click.secho(f"Failed to start the prometheus server with error {e}", fg="red")
click.secho("Starting the agent service...", fg="blue")
server = aio.server(futures.ThreadPoolExecutor(max_workers=worker))

add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server)

server.add_insecure_port(f"[::]:{port}")
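Note on the serve.py hunk above: AsyncAgentService was previously imported inside the try block that guards the optional prometheus_client dependency, so a missing prometheus_client would also skip that import and the later AsyncAgentService() call could fail with a NameError. Hoisting the import above the try keeps the agent service available even when the metrics server cannot be started.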
42 changes: 42 additions & 0 deletions flytekit/models/execution.py
@@ -5,6 +5,7 @@
from typing import Optional

import flyteidl
import flyteidl.admin.cluster_assignment_pb2 as _cluster_assignment_pb2
import flyteidl.admin.execution_pb2 as _execution_pb2
import flyteidl.admin.node_execution_pb2 as _node_execution_pb2
import flyteidl.admin.task_execution_pb2 as _task_execution_pb2
@@ -179,6 +180,7 @@ def __init__(
overwrite_cache: Optional[bool] = None,
envs: Optional[_common_models.Envs] = None,
tags: Optional[typing.List[str]] = None,
cluster_assignment: Optional[ClusterAssignment] = None,
):
"""
:param flytekit.models.core.identifier.Identifier launch_plan: Launch plan unique identifier to execute
@@ -210,6 +212,7 @@ def __init__(
self._overwrite_cache = overwrite_cache
self._envs = envs
self._tags = tags
self._cluster_assignment = cluster_assignment

@property
def launch_plan(self):
@@ -288,6 +291,10 @@ def envs(self) -> Optional[_common_models.Envs]:
def tags(self) -> Optional[typing.List[str]]:
return self._tags

@property
def cluster_assignment(self) -> Optional[ClusterAssignment]:
return self._cluster_assignment

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.execution_pb2.ExecutionSpec
@@ -308,6 +315,7 @@ def to_flyte_idl(self):
overwrite_cache=self.overwrite_cache,
envs=self.envs.to_flyte_idl() if self.envs else None,
tags=self.tags,
cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None,
)

@classmethod
@@ -334,8 +342,42 @@ def from_flyte_idl(cls, p):
overwrite_cache=p.overwrite_cache,
envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None,
tags=p.tags,
cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment)
if p.HasField("cluster_assignment")
else None,
)


class ClusterAssignment(_common_models.FlyteIdlEntity):
def __init__(self, cluster_pool=None):
"""
:param Text cluster_pool:
"""
self._cluster_pool = cluster_pool

@property
def cluster_pool(self):
"""
:rtype: Text
"""
return self._cluster_pool

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin._cluster_assignment_pb2.ClusterAssignment
"""
return _cluster_assignment_pb2.ClusterAssignment(
cluster_pool_name=self._cluster_pool,
)

@classmethod
def from_flyte_idl(cls, p):
"""
:param flyteidl.admin._cluster_assignment_pb2.ClusterAssignment p:
:rtype: flyteidl.admin.ClusterAssignment
"""
return cls(cluster_pool=p.cluster_pool_name)


class LiteralMapBlob(_common_models.FlyteIdlEntity):
def __init__(self, values=None, uri=None):
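A minimal round-trip sketch for the new ClusterAssignment model added above (assuming flytekit with this change and a matching flyteidl release are installed; the pool name "gpu" is illustrative only):

from flytekit.models.execution import ClusterAssignment

assignment = ClusterAssignment(cluster_pool="gpu")

idl_message = assignment.to_flyte_idl()          # flyteidl admin ClusterAssignment proto
assert idl_message.cluster_pool_name == "gpu"    # wire field is cluster_pool_name

restored = ClusterAssignment.from_flyte_idl(idl_message)
assert restored.cluster_pool == "gpu"            # Python-side property is cluster_pool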
24 changes: 24 additions & 0 deletions flytekit/remote/remote.py
@@ -54,6 +54,7 @@
from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier
from flytekit.models.core.workflow import NodeMetadata
from flytekit.models.execution import (
ClusterAssignment,
ExecutionMetadata,
ExecutionSpec,
NodeExecutionGetDataResponse,
@@ -962,6 +963,7 @@ def _execute(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Common method for execution across all entities.
@@ -978,6 +980,7 @@ def _execute(
be available, overwriting the stored data once execution finishes successfully.
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution`
"""
if execution_name is not None and execution_name_prefix is not None:
@@ -1047,6 +1050,7 @@ def _execute(
security_context=options.security_context,
envs=common_models.Envs(envs) if envs else None,
tags=tags,
cluster_assignment=ClusterAssignment(cluster_pool=cluster_pool) if cluster_pool else None,
),
literal_inputs,
)
@@ -1106,6 +1110,7 @@ def execute(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute a task, workflow, or launchplan, either something that's been declared locally, or a fetched entity.
@@ -1144,6 +1149,7 @@ def execute(
be available, overwriting the stored data once execution finishes successfully.
:param envs: Environment variables to be set for the execution.
:param tags: Tags to be set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
.. note:
@@ -1166,6 +1172,7 @@
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
if isinstance(entity, FlyteWorkflow):
return self.execute_remote_wf(
@@ -1181,6 +1188,7 @@
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
if isinstance(entity, PythonTask):
return self.execute_local_task(
@@ -1197,6 +1205,7 @@
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
if isinstance(entity, WorkflowBase):
return self.execute_local_workflow(
@@ -1214,6 +1223,7 @@
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
if isinstance(entity, LaunchPlan):
return self.execute_local_launch_plan(
@@ -1229,6 +1239,7 @@
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
raise NotImplementedError(f"entity type {type(entity)} not recognized for execution")

@@ -1249,6 +1260,7 @@ def execute_remote_task_lp(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a FlyteTask, or FlyteLaunchplan.
@@ -1267,6 +1279,7 @@ def execute_remote_task_lp(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

def execute_remote_wf(
@@ -1283,6 +1296,7 @@ def execute_remote_wf(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a FlyteWorkflow.
@@ -1302,6 +1316,7 @@ def execute_remote_wf(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

# Flytekit Entities
@@ -1322,6 +1337,7 @@ def execute_local_task(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute a @task-decorated function or TaskTemplate task.
@@ -1338,6 +1354,7 @@ def execute_local_task(
:param overwrite_cache: If True, will overwrite the cache.
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:return: FlyteWorkflowExecution object.
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
@@ -1368,6 +1385,7 @@ def execute_local_task(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

def execute_local_workflow(
@@ -1386,6 +1404,7 @@ def execute_local_workflow(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute an @workflow decorated function.
@@ -1402,6 +1421,7 @@ def execute_local_workflow(
:param overwrite_cache:
:param envs:
:param tags:
:param cluster_pool:
:return:
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
@@ -1450,6 +1470,7 @@ def execute_local_workflow(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

def execute_local_launch_plan(
@@ -1466,6 +1487,7 @@ def execute_local_launch_plan(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
@@ -1480,6 +1502,7 @@ def execute_local_launch_plan(
:param overwrite_cache: If True, will overwrite the cache.
:param envs: Environment variables to be passed into the execution.
:param tags: Tags to be passed into the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:return: FlyteWorkflowExecution object
"""
try:
@@ -1509,6 +1532,7 @@ def execute_local_launch_plan(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

###################################
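For context, a minimal sketch of how the new parameter is used from FlyteRemote, assuming a configured backend and an already-registered workflow; the project, domain, workflow name, inputs, and pool name below are placeholders, not values taken from this commit:

from flytekit.configuration import Config
from flytekit.remote import FlyteRemote

remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")
wf = remote.fetch_workflow(name="my.workflows.example_wf", version="v1")

# cluster_pool is threaded through execute() into _execute(), which attaches
# ClusterAssignment(cluster_pool=...) to the ExecutionSpec only when the value is non-empty.
execution = remote.execute(wf, inputs={"a": 5}, cluster_pool="gpu")
print(remote.generate_console_url(execution))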
2 changes: 2 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
@@ -180,11 +180,13 @@ def test_execute_python_task(flyteclient, flyte_workflows_register, flyte_remote
overwrite_cache=True,
envs={"foo": "bar"},
tags=["flyte"],
cluster_pool="gpu",
)
assert execution.outputs["t1_int_output"] == 12
assert execution.outputs["c"] == "world"
assert execution.spec.envs == {"foo": "bar"}
assert execution.spec.tags == ["flyte"]
assert execution.spec.cluster_assignment.cluster_pool == "gpu"


def test_execute_python_workflow_and_launch_plan(flyteclient, flyte_workflows_register, flyte_remote_env):
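The integration test above passes cluster_pool="gpu" and then asserts it back from execution.spec.cluster_assignment, which presumably requires the Flyte deployment used in CI to have a cluster pool named "gpu" configured.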
