Pass cluster pool when creating executions #1208

Merged: 8 commits, Sep 22, 2023
Changes from 3 commits
10 changes: 10 additions & 0 deletions flytekit/clis/sdk_in_container/run.py
@@ -243,6 +243,15 @@ class RunLevelParams(PyFlyteParams):
"if `from-server` option is used",
)
)
cluster_pool: str = make_field(
click.Option(
param_decls=["--cluster-pool", "cluster_pool"],
required=False,
type=str,
default="",
help="Assign newly created execution to a given cluster pool",
)
)
computed_params: RunLevelComputedParams = field(default_factory=RunLevelComputedParams)
_remote: typing.Optional[FlyteRemote] = None

@@ -427,6 +436,7 @@ def run_remote(
overwrite_cache=run_level_params.overwrite_cache,
envs=run_level_params.envvars,
tags=run_level_params.tags,
cluster_pool=run_level_params.cluster_pool,
)

console_url = remote.generate_console_url(execution)
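With this change, `pyflyte run --remote` gains a `--cluster-pool` flag whose value `run_remote` forwards to `remote.execute`. Below is a minimal stand-in sketch, not flytekit's actual CLI wiring: a bare click command with the same option shape, showing how the parsed value would reach the execution call. The command name and echoed message are illustrative.

import click


@click.command()
@click.option(
    "--cluster-pool",
    "cluster_pool",
    required=False,
    type=str,
    default="",
    help="Assign newly created execution to a given cluster pool",
)
def run(cluster_pool: str) -> None:
    # In run_remote above, this value is passed along as
    # cluster_pool=run_level_params.cluster_pool.
    click.echo(f"would launch execution with cluster_pool={cluster_pool!r}")


if __name__ == "__main__":
    run()  # e.g. python demo.py --cluster-pool gpu
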
4 changes: 2 additions & 2 deletions flytekit/clis/sdk_in_container/serve.py
@@ -4,8 +4,6 @@
from flyteidl.service.agent_pb2_grpc import add_AsyncAgentServiceServicer_to_server
from grpc import aio

from flytekit.extend.backend.agent_service import AsyncAgentService

_serve_help = """Start a grpc server for the agent service."""


@@ -52,6 +50,8 @@ async def _start_grpc_server(port: int, worker: int, timeout: int):
click.secho(f"Failed to start the prometheus server with error {e}", fg="red")
click.secho("Starting the agent service...", fg="blue")
server = aio.server(futures.ThreadPoolExecutor(max_workers=worker))
from flytekit.extend.backend.agent_service import AsyncAgentService

add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server)

server.add_insecure_port(f"[::]:{port}")
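This hunk also moves the AsyncAgentService import from module level into _start_grpc_server, so the dependency is only resolved when the agent server actually starts. A generic sketch of the same lazy-import pattern, with illustrative names only (not flytekit code):

def start_server(port: int) -> None:
    # Deferred imports: loaded only when the server is started, keeping module
    # import (and e.g. --help) cheap and avoiding import-order problems.
    from concurrent import futures

    import grpc

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    server.add_insecure_port(f"[::]:{port}")
    server.start()
    server.wait_for_termination()
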
31 changes: 31 additions & 0 deletions flytekit/models/execution.py
@@ -5,6 +5,7 @@
from typing import Optional

import flyteidl
import flyteidl.admin.cluster_assignment_pb2 as _cluster_assignment_pb2
import flyteidl.admin.execution_pb2 as _execution_pb2
import flyteidl.admin.node_execution_pb2 as _node_execution_pb2
import flyteidl.admin.task_execution_pb2 as _task_execution_pb2
@@ -179,6 +180,7 @@ def __init__(
overwrite_cache: Optional[bool] = None,
envs: Optional[_common_models.Envs] = None,
tags: Optional[typing.List[str]] = None,
cluster_assignment: Optional[ClusterAssignment] = None,
):
"""
:param flytekit.models.core.identifier.Identifier launch_plan: Launch plan unique identifier to execute
@@ -210,6 +212,7 @@
self._overwrite_cache = overwrite_cache
self._envs = envs
self._tags = tags
self._cluster_assignment = cluster_assignment

@property
def launch_plan(self):
@@ -288,6 +291,10 @@ def envs(self) -> Optional[_common_models.Envs]:
def tags(self) -> Optional[typing.List[str]]:
return self._tags

@property
def cluster_assignment(self) -> Optional[ClusterAssignment]:
return self._cluster_assignment

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.execution_pb2.ExecutionSpec
@@ -308,6 +315,7 @@ def to_flyte_idl(self):
overwrite_cache=self.overwrite_cache,
envs=self.envs.to_flyte_idl() if self.envs else None,
tags=self.tags,
cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None,
)

@classmethod
@@ -337,6 +345,29 @@ def from_flyte_idl(cls, p):
)


class ClusterAssignment(_common_models.FlyteIdlEntity):
def __init__(self, cluster_pool=None):
"""
:param Text cluster_pool:
"""
self._cluster_pool = cluster_pool

@property
def cluster_pool(self):
"""
:rtype: Text
"""
return self._cluster_pool

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.ClusterAssignment
"""
return _cluster_assignment_pb2.ClusterAssignment(
cluster_pool_name=self.cluster_pool,
)


class LiteralMapBlob(_common_models.FlyteIdlEntity):
def __init__(self, values=None, uri=None):
"""
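The new ClusterAssignment model is a thin wrapper over the admin proto's cluster_pool_name field. A quick sketch of the round trip (assumes flyteidl is installed; the pool name is illustrative):

from flytekit.models.execution import ClusterAssignment

# Wrap a pool name and serialize it to flyteidl.admin.ClusterAssignment.
assignment = ClusterAssignment(cluster_pool="gpu")
idl = assignment.to_flyte_idl()

assert assignment.cluster_pool == "gpu"
assert idl.cluster_pool_name == "gpu"  # proto field name per to_flyte_idl above
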
24 changes: 24 additions & 0 deletions flytekit/remote/remote.py
@@ -54,6 +54,7 @@
from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier
from flytekit.models.core.workflow import NodeMetadata
from flytekit.models.execution import (
ClusterAssignment,
ExecutionMetadata,
ExecutionSpec,
NodeExecutionGetDataResponse,
@@ -962,6 +963,7 @@ def _execute(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Common method for execution across all entities.

@@ -978,6 +980,7 @@ def _execute(
be available, overwriting the stored data once execution finishes successfully.
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution`
"""
if execution_name is not None and execution_name_prefix is not None:
@@ -1047,6 +1050,7 @@ def _execute(
security_context=options.security_context,
envs=common_models.Envs(envs) if envs else None,
tags=tags,
cluster_assignment=ClusterAssignment(cluster_pool=cluster_pool) if cluster_pool else None,
),
literal_inputs,
)
@@ -1106,6 +1110,7 @@ def execute(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute a task, workflow, or launchplan, either something that's been declared locally, or a fetched entity.
@@ -1144,6 +1149,7 @@ def execute(
be available, overwriting the stored data once execution finishes successfully.
:param envs: Environment variables to be set for the execution.
:param tags: Tags to be set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.

.. note:

@@ -1166,6 +1172,7 @@ def execute(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
if isinstance(entity, FlyteWorkflow):
return self.execute_remote_wf(
@@ -1181,6 +1188,7 @@ def execute(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
if isinstance(entity, PythonTask):
return self.execute_local_task(
@@ -1197,6 +1205,7 @@ def execute(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
if isinstance(entity, WorkflowBase):
return self.execute_local_workflow(
@@ -1214,6 +1223,7 @@ def execute(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
if isinstance(entity, LaunchPlan):
return self.execute_local_launch_plan(
@@ -1229,6 +1239,7 @@ def execute(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
raise NotImplementedError(f"entity type {type(entity)} not recognized for execution")

@@ -1249,6 +1260,7 @@ def execute_remote_task_lp(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a FlyteTask, or FlyteLaunchplan.

@@ -1267,6 +1279,7 @@ def execute_remote_task_lp(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

def execute_remote_wf(
@@ -1283,6 +1296,7 @@ def execute_remote_wf(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a FlyteWorkflow.

@@ -1302,6 +1316,7 @@ def execute_remote_wf(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

# Flytekit Entities
@@ -1322,6 +1337,7 @@ def execute_local_task(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute a @task-decorated function or TaskTemplate task.
@@ -1338,6 +1354,7 @@ def execute_local_task(
:param overwrite_cache: If True, will overwrite the cache.
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:return: FlyteWorkflowExecution object.
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
@@ -1368,6 +1385,7 @@ def execute_local_task(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

def execute_local_workflow(
@@ -1386,6 +1404,7 @@ def execute_local_workflow(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute an @workflow decorated function.
@@ -1402,6 +1421,7 @@ def execute_local_workflow(
:param overwrite_cache:
:param envs:
:param tags:
:param cluster_pool:
:return:
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
@@ -1450,6 +1470,7 @@ def execute_local_workflow(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

def execute_local_launch_plan(
@@ -1466,6 +1487,7 @@ def execute_local_launch_plan(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""

@@ -1480,6 +1502,7 @@ def execute_local_launch_plan(
:param overwrite_cache: If True, will overwrite the cache.
:param envs: Environment variables to be passed into the execution.
:param tags: Tags to be passed into the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:return: FlyteWorkflowExecution object
"""
try:
@@ -1509,6 +1532,7 @@ def execute_local_launch_plan(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

###################################
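Taken together, a caller only sets cluster_pool on FlyteRemote.execute, and _execute wraps it in a ClusterAssignment on the ExecutionSpec. A hedged usage sketch; the config source, project, domain, workflow name, version, and inputs are placeholders, and the "gpu" pool must exist in the target deployment:

from flytekit.configuration import Config
from flytekit.remote import FlyteRemote

# Placeholder project/domain; adjust to your deployment.
remote = FlyteRemote(
    config=Config.auto(),
    default_project="flytesnacks",
    default_domain="development",
)

# Placeholder workflow name/version already registered on the backend.
wf = remote.fetch_workflow(name="my.workflows.wf", version="v1")

# cluster_pool is forwarded to ExecutionSpec as ClusterAssignment(cluster_pool="gpu").
execution = remote.execute(wf, inputs={"a": 1}, cluster_pool="gpu")
print(remote.generate_console_url(execution))
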
2 changes: 2 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
@@ -180,11 +180,13 @@ def test_execute_python_task(flyteclient, flyte_workflows_register, flyte_remote
overwrite_cache=True,
envs={"foo": "bar"},
tags=["flyte"],
cluster_pool="gpu",
)
assert execution.outputs["t1_int_output"] == 12
assert execution.outputs["c"] == "world"
assert execution.spec.envs == {"foo": "bar"}
assert execution.spec.tags == ["flyte"]
assert execution.spec.cluster_assignment.cluster_pool == "gpu"


def test_execute_python_workflow_and_launch_plan(flyteclient, flyte_workflows_register, flyte_remote_env):