diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 0ee09dc981..326ce4913d 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -243,6 +243,15 @@ class RunLevelParams(PyFlyteParams): "if `from-server` option is used", ) ) + cluster_pool: str = make_field( + click.Option( + param_decls=["--cluster-pool", "cluster_pool"], + required=False, + type=str, + default="", + help="Assign newly created execution to a given cluster pool", + ) + ) computed_params: RunLevelComputedParams = field(default_factory=RunLevelComputedParams) _remote: typing.Optional[FlyteRemote] = None @@ -427,6 +436,7 @@ def run_remote( overwrite_cache=run_level_params.overwrite_cache, envs=run_level_params.envvars, tags=run_level_params.tags, + cluster_pool=run_level_params.cluster_pool, ) console_url = remote.generate_console_url(execution) diff --git a/flytekit/clis/sdk_in_container/serve.py b/flytekit/clis/sdk_in_container/serve.py index 53f02b6481..145dc90212 100644 --- a/flytekit/clis/sdk_in_container/serve.py +++ b/flytekit/clis/sdk_in_container/serve.py @@ -42,16 +42,17 @@ def serve(_: click.Context, port, worker, timeout): async def _start_grpc_server(port: int, worker: int, timeout: int): click.secho("Starting up the server to expose the prometheus metrics...", fg="blue") + from flytekit.extend.backend.agent_service import AsyncAgentService + try: from prometheus_client import start_http_server - from flytekit.extend.backend.agent_service import AsyncAgentService - start_http_server(9090) except ImportError as e: click.secho(f"Failed to start the prometheus server with error {e}", fg="red") click.secho("Starting the agent service...", fg="blue") server = aio.server(futures.ThreadPoolExecutor(max_workers=worker)) + add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server) server.add_insecure_port(f"[::]:{port}") diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 468adb8884..b76c01c967 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -5,6 +5,7 @@ from typing import Optional import flyteidl +import flyteidl.admin.cluster_assignment_pb2 as _cluster_assignment_pb2 import flyteidl.admin.execution_pb2 as _execution_pb2 import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 import flyteidl.admin.task_execution_pb2 as _task_execution_pb2 @@ -179,6 +180,7 @@ def __init__( overwrite_cache: Optional[bool] = None, envs: Optional[_common_models.Envs] = None, tags: Optional[typing.List[str]] = None, + cluster_assignment: Optional[ClusterAssignment] = None, ): """ :param flytekit.models.core.identifier.Identifier launch_plan: Launch plan unique identifier to execute @@ -210,6 +212,7 @@ def __init__( self._overwrite_cache = overwrite_cache self._envs = envs self._tags = tags + self._cluster_assignment = cluster_assignment @property def launch_plan(self): @@ -288,6 +291,10 @@ def envs(self) -> Optional[_common_models.Envs]: def tags(self) -> Optional[typing.List[str]]: return self._tags + @property + def cluster_assignment(self) -> Optional[ClusterAssignment]: + return self._cluster_assignment + def to_flyte_idl(self): """ :rtype: flyteidl.admin.execution_pb2.ExecutionSpec @@ -308,6 +315,7 @@ def to_flyte_idl(self): overwrite_cache=self.overwrite_cache, envs=self.envs.to_flyte_idl() if self.envs else None, tags=self.tags, + cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None, ) @classmethod @@ -334,8 +342,42 @@ def from_flyte_idl(cls, p): overwrite_cache=p.overwrite_cache, envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None, tags=p.tags, + cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment) + if p.HasField("cluster_assignment") + else None, + ) + + +class ClusterAssignment(_common_models.FlyteIdlEntity): + def __init__(self, cluster_pool=None): + """ + :param Text cluster_pool: + """ + self._cluster_pool = cluster_pool + + @property + def cluster_pool(self): + """ + :rtype: Text + """ + return self._cluster_pool + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin._cluster_assignment_pb2.ClusterAssignment + """ + return _cluster_assignment_pb2.ClusterAssignment( + cluster_pool_name=self._cluster_pool, ) + @classmethod + def from_flyte_idl(cls, p): + """ + :param flyteidl.admin._cluster_assignment_pb2.ClusterAssignment p: + :rtype: flyteidl.admin.ClusterAssignment + """ + return cls(cluster_pool=p.cluster_pool_name) + class LiteralMapBlob(_common_models.FlyteIdlEntity): def __init__(self, values=None, uri=None): diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 96ce18f562..78fa3271e7 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -54,6 +54,7 @@ from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier from flytekit.models.core.workflow import NodeMetadata from flytekit.models.execution import ( + ClusterAssignment, ExecutionMetadata, ExecutionSpec, NodeExecutionGetDataResponse, @@ -962,6 +963,7 @@ def _execute( overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """Common method for execution across all entities. @@ -978,6 +980,7 @@ def _execute( be available, overwriting the stored data once execution finishes successfully. :param envs: Environment variables to set for the execution. :param tags: Tags to set for the execution. + :param cluster_pool: Specify cluster pool on which newly created execution should be placed. :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` """ if execution_name is not None and execution_name_prefix is not None: @@ -1047,6 +1050,7 @@ def _execute( security_context=options.security_context, envs=common_models.Envs(envs) if envs else None, tags=tags, + cluster_assignment=ClusterAssignment(cluster_pool=cluster_pool) if cluster_pool else None, ), literal_inputs, ) @@ -1106,6 +1110,7 @@ def execute( overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ Execute a task, workflow, or launchplan, either something that's been declared locally, or a fetched entity. @@ -1144,6 +1149,7 @@ def execute( be available, overwriting the stored data once execution finishes successfully. :param envs: Environment variables to be set for the execution. :param tags: Tags to be set for the execution. + :param cluster_pool: Specify cluster pool on which newly created execution should be placed. .. note: @@ -1166,6 +1172,7 @@ def execute( overwrite_cache=overwrite_cache, envs=envs, tags=tags, + cluster_pool=cluster_pool, ) if isinstance(entity, FlyteWorkflow): return self.execute_remote_wf( @@ -1181,6 +1188,7 @@ def execute( overwrite_cache=overwrite_cache, envs=envs, tags=tags, + cluster_pool=cluster_pool, ) if isinstance(entity, PythonTask): return self.execute_local_task( @@ -1197,6 +1205,7 @@ def execute( overwrite_cache=overwrite_cache, envs=envs, tags=tags, + cluster_pool=cluster_pool, ) if isinstance(entity, WorkflowBase): return self.execute_local_workflow( @@ -1214,6 +1223,7 @@ def execute( overwrite_cache=overwrite_cache, envs=envs, tags=tags, + cluster_pool=cluster_pool, ) if isinstance(entity, LaunchPlan): return self.execute_local_launch_plan( @@ -1229,6 +1239,7 @@ def execute( overwrite_cache=overwrite_cache, envs=envs, tags=tags, + cluster_pool=cluster_pool, ) raise NotImplementedError(f"entity type {type(entity)} not recognized for execution") @@ -1249,6 +1260,7 @@ def execute_remote_task_lp( overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """Execute a FlyteTask, or FlyteLaunchplan. @@ -1267,6 +1279,7 @@ def execute_remote_task_lp( overwrite_cache=overwrite_cache, envs=envs, tags=tags, + cluster_pool=cluster_pool, ) def execute_remote_wf( @@ -1283,6 +1296,7 @@ def execute_remote_wf( overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """Execute a FlyteWorkflow. @@ -1302,6 +1316,7 @@ def execute_remote_wf( overwrite_cache=overwrite_cache, envs=envs, tags=tags, + cluster_pool=cluster_pool, ) # Flytekit Entities @@ -1322,6 +1337,7 @@ def execute_local_task( overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ Execute a @task-decorated function or TaskTemplate task. @@ -1338,6 +1354,7 @@ def execute_local_task( :param overwrite_cache: If True, will overwrite the cache. :param envs: Environment variables to set for the execution. :param tags: Tags to set for the execution. + :param cluster_pool: Specify cluster pool on which newly created execution should be placed. :return: FlyteWorkflowExecution object. """ resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) @@ -1368,6 +1385,7 @@ def execute_local_task( overwrite_cache=overwrite_cache, envs=envs, tags=tags, + cluster_pool=cluster_pool, ) def execute_local_workflow( @@ -1386,6 +1404,7 @@ def execute_local_workflow( overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ Execute an @workflow decorated function. @@ -1402,6 +1421,7 @@ def execute_local_workflow( :param overwrite_cache: :param envs: :param tags: + :param cluster_pool: :return: """ resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) @@ -1450,6 +1470,7 @@ def execute_local_workflow( overwrite_cache=overwrite_cache, envs=envs, tags=tags, + cluster_pool=cluster_pool, ) def execute_local_launch_plan( @@ -1466,6 +1487,7 @@ def execute_local_launch_plan( overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ @@ -1480,6 +1502,7 @@ def execute_local_launch_plan( :param overwrite_cache: If True, will overwrite the cache. :param envs: Environment variables to be passed into the execution. :param tags: Tags to be passed into the execution. + :param cluster_pool: Specify cluster pool on which newly created execution should be placed. :return: FlyteWorkflowExecution object """ try: @@ -1509,6 +1532,7 @@ def execute_local_launch_plan( overwrite_cache=overwrite_cache, envs=envs, tags=tags, + cluster_pool=cluster_pool, ) ################################### diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 72b3cb9192..84f9746d2f 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -180,11 +180,13 @@ def test_execute_python_task(flyteclient, flyte_workflows_register, flyte_remote overwrite_cache=True, envs={"foo": "bar"}, tags=["flyte"], + cluster_pool="gpu", ) assert execution.outputs["t1_int_output"] == 12 assert execution.outputs["c"] == "world" assert execution.spec.envs == {"foo": "bar"} assert execution.spec.tags == ["flyte"] + assert execution.spec.cluster_assignment.cluster_pool == "gpu" def test_execute_python_workflow_and_launch_plan(flyteclient, flyte_workflows_register, flyte_remote_env):