diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 16235a68ec..d08685f06d 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -326,7 +326,7 @@ class AsyncAgentExecutorMixin: def execute(self: PythonTask, **kwargs) -> LiteralMap: ctx = FlyteContext.current_context() - ss = ctx.serialization_settings or SerializationSettings(ImageConfig()) + ss = ctx.serialization_settings or SerializationSettings(ImageConfig.auto_default_image()) output_prefix = ctx.file_access.get_random_remote_directory() self.resource_meta = None diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py index 4dcdf3174a..8aa7952134 100644 --- a/flytekit/extend/backend/utils.py +++ b/flytekit/extend/backend/utils.py @@ -20,9 +20,9 @@ def convert_to_flyte_phase(state: str) -> TaskExecution.Phase: Convert the state from the agent to the phase in flyte. """ state = state.lower() - if state in ["failed", "timeout", "timedout", "canceled", "skipped", "internal_error"]: + if state in ["failed", "timeout", "timedout", "canceled", "cancelled", "skipped", "internal_error"]: return TaskExecution.FAILED - elif state in ["done", "succeeded", "success"]: + elif state in ["done", "succeeded", "success", "completed"]: return TaskExecution.SUCCEEDED elif state in ["running", "terminating"]: return TaskExecution.RUNNING diff --git a/flytekit/extras/tasks/shell.py b/flytekit/extras/tasks/shell.py index 32ae33fcc7..296666c85e 100644 --- a/flytekit/extras/tasks/shell.py +++ b/flytekit/extras/tasks/shell.py @@ -250,7 +250,10 @@ def __init__( if task_config is not None: fully_qualified_class_name = task_config.__module__ + "." 
+ task_config.__class__.__name__ - if not fully_qualified_class_name == "flytekitplugins.pod.task.Pod": + if fully_qualified_class_name not in [ + "flytekitplugins.pod.task.Pod", + "flytekitplugins.slurm.script.task.Slurm", + ]: raise ValueError("TaskConfig can either be empty - indicating simple container task or a PodConfig.") # Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used @@ -259,11 +262,14 @@ def __init__( # errors. # This seem like a hack. We should use a plugin_class that doesn't require a fake-function to make work. plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config)) - self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func) - # Rename the internal task so that there are no conflicts at serialization time. Technically these internal - # tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities - # at serialization time. - self._config_task_instance._name = f"_bash.{name}" + if plugin_class.__name__ in ["SlurmShellTask"]: + self._config_task_instance = None + else: + self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func) + # Rename the internal task so that there are no conflicts at serialization time. Technically these internal + # tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities + # at serialization time. 
+ self._config_task_instance._name = f"_bash.{name}" self._script = script self._script_file = script_file self._debug = debug @@ -275,7 +281,9 @@ def __init__( super().__init__( name, task_config, - task_type=self._config_task_instance.task_type, + task_type=kwargs.pop("task_type") + if self._config_task_instance is None + else self._config_task_instance.task_type, interface=Interface(inputs=inputs, outputs=outputs), **kwargs, ) @@ -309,7 +317,10 @@ def script_file(self) -> typing.Optional[os.PathLike]: return self._script_file def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: - return self._config_task_instance.pre_execute(user_params) + if self._config_task_instance is None: + return user_params + else: + return self._config_task_instance.pre_execute(user_params) def execute(self, **kwargs) -> typing.Any: """ @@ -367,7 +378,10 @@ def execute(self, **kwargs) -> typing.Any: return None def post_execute(self, user_params: ExecutionParameters, rval: typing.Any) -> typing.Any: - return self._config_task_instance.post_execute(user_params, rval) + if self._config_task_instance is None: + return rval + else: + return self._config_task_instance.post_execute(user_params, rval) class RawShellTask(ShellTask): diff --git a/plugins/flytekit-slurm/README.md b/plugins/flytekit-slurm/README.md new file mode 100644 index 0000000000..af6596cf28 --- /dev/null +++ b/plugins/flytekit-slurm/README.md @@ -0,0 +1,5 @@ +# Flytekit Slurm Plugin + +The Slurm agent is designed to integrate Flyte workflows with Slurm-managed high-performance computing (HPC) clusters, enabling users to leverage Slurm's capabilities for compute resource allocation, scheduling, and monitoring. + +This [guide](demo.md) provides a concise overview of the design philosophy behind the Slurm agent and explains how to set up a local environment for testing the agent. 
diff --git a/plugins/flytekit-slurm/assets/basic_arch.png b/plugins/flytekit-slurm/assets/basic_arch.png new file mode 100644 index 0000000000..b1ee5d4771 Binary files /dev/null and b/plugins/flytekit-slurm/assets/basic_arch.png differ diff --git a/plugins/flytekit-slurm/assets/flyte_client.png b/plugins/flytekit-slurm/assets/flyte_client.png new file mode 100644 index 0000000000..454769bce5 Binary files /dev/null and b/plugins/flytekit-slurm/assets/flyte_client.png differ diff --git a/plugins/flytekit-slurm/assets/overview_v2.png b/plugins/flytekit-slurm/assets/overview_v2.png new file mode 100644 index 0000000000..c47caa1304 Binary files /dev/null and b/plugins/flytekit-slurm/assets/overview_v2.png differ diff --git a/plugins/flytekit-slurm/assets/remote_tiny_slurm_cluster.png b/plugins/flytekit-slurm/assets/remote_tiny_slurm_cluster.png new file mode 100644 index 0000000000..276b93f304 Binary files /dev/null and b/plugins/flytekit-slurm/assets/remote_tiny_slurm_cluster.png differ diff --git a/plugins/flytekit-slurm/assets/slurm_basic_result.png b/plugins/flytekit-slurm/assets/slurm_basic_result.png new file mode 100644 index 0000000000..4b15aeea51 Binary files /dev/null and b/plugins/flytekit-slurm/assets/slurm_basic_result.png differ diff --git a/plugins/flytekit-slurm/demo.md b/plugins/flytekit-slurm/demo.md new file mode 100644 index 0000000000..170632b19b --- /dev/null +++ b/plugins/flytekit-slurm/demo.md @@ -0,0 +1,264 @@ +# Slurm Agent Demo + +> Note: This document is still a work in progress, focusing on demonstrating the initial implementation. It will be updated and refined frequently until a stable version is ready. + +In this guide, we will briefly introduce how to setup an environment to test Slurm agent locally without running the backend service (e.g., flyte agent gRPC server). 
It covers both basic and advanced use cases: the basic use case involves executing a shell script directly, while the advanced use case enables running user-defined functions on a Slurm cluster. + +## Table of Contents +* [Overview](#overview) +* [Setup a Local Test Environment](#setup-a-local-test-environment) + * [Flyte Client (Localhost)](#flyte-client-localhost) + * [Remote Tiny Slurm Cluster](#remote-tiny-slurm-cluster) + * [SSH Configuration](#ssh-configuration) + * [(Optional) Setup Amazon S3 Bucket](#optional-setup-amazon-s3-bucket) +* [Rich Use Cases](#rich-use-cases) + * [`SlurmTask`](#slurmtask) + * [`SlurmShellTask`](#slurmshelltask) + * [`SlurmFunctionTask`](#slurmfunctiontask) + +## Overview +At the highest level, the Slurm agent has three core methods to interact with a Slurm cluster: +1. `create`: Use `srun` or `sbatch` to run a job on a Slurm cluster +2. `get`: Use `scontrol show job <job_id>` to monitor the Slurm job state +3. 
`delete`: Use `scancel <job_id>` to cancel the Slurm job (this method is still under test) + +In its simplest form, the Slurm agent supports directly running a batch script using `sbatch` on a Slurm cluster as shown below: + +![Basic architecture](assets/basic_arch.png) + +## Setup a Local Test Environment +Without running the backend service, we can set up an environment to test the Slurm agent locally. The setup consists of two main components: a client (localhost) and a remote tiny Slurm cluster. Then, we need to configure an SSH connection to facilitate communication between the two, which relies on `asyncssh`. Additionally, an S3-compatible object storage is needed for advanced use cases and we choose [Amazon S3](https://us-west-2.console.aws.amazon.com/s3/get-started?region=us-west-2&bucketType=general) for demonstration here. +> Note: A persistence layer (such as S3-compatible object storage) becomes essential as scenarios grow more complex, especially when integrating heterogeneous task types into a workflow in the future. + +### Flyte Client (Localhost) +1. Set up a local Flyte cluster following this [official guide](https://docs.flyte.org/en/latest/community/contribute/contribute_code.html#how-to-setup-dev-environment-for-flytekit) +2. Build a virtual environment (e.g., [poetry](https://python-poetry.org/), [conda](https://docs.conda.io/en/latest/)) and activate it +3. Clone Flytekit [repo](https://github.com/flyteorg/flytekit), check out the Slurm agent [PR](https://github.com/flyteorg/flytekit/pull/3005/), and install Flytekit +``` +git clone https://github.com/flyteorg/flytekit.git && cd flytekit +gh pr checkout 3005 +make setup && pip install -e . +``` +4. Install Flytekit Slurm agent +``` +cd plugins/flytekit-slurm/ +pip install -e . 
+``` + +### Remote Tiny Slurm Cluster +To simplify the setup process, we follow this [guide](https://github.com/JiangJiaWei1103/Slurm-101) to configure a single-host Slurm cluster, covering `slurmctld` (the central management daemon) and `slurmd` (the compute node daemon). + +After building a Slurm cluster, we need to install Flytekit and the Slurm agent, just as we did in the previous [section](#flyte-client-localhost). +1. Build a virtual environment and activate it (we take `poetry` as an example): +``` +poetry new demo-env && cd demo-env + +# For running a subshell with the virtual environment activated +poetry self add poetry-plugin-shell + +# Activate the virtual environment +poetry shell +``` +2. Clone Flytekit [repo](https://github.com/flyteorg/flytekit), check out the Slurm agent [PR](https://github.com/flyteorg/flytekit/pull/3005/), and install Flytekit +``` +git clone https://github.com/flyteorg/flytekit.git && cd flytekit +gh pr checkout 3005 +make setup && pip install -e . +``` +3. Install Flytekit Slurm agent +``` +cd plugins/flytekit-slurm/ +pip install -e . +``` + +### SSH Configuration +To facilitate communication between the Flyte client and the remote Slurm cluster, we set up SSH on the Flyte client side as follows: +1. Create a new authentication key pair +``` +ssh-keygen -t rsa -b 4096 +``` +2. Copy the public key into the remote Slurm cluster +``` +ssh-copy-id <username>@<remote_server_ip> +``` +3. Enable key-based authentication +``` +# ~/.ssh/config +Host <host_alias> + HostName <remote_server_ip> + Port <ssh_port> + User <username> + IdentityFile <path_to_private_key> +``` +Then, run a sanity check to make sure we can connect to the Slurm cluster: +``` +ssh <host_alias> +``` +Simple and elegant! + +### (Optional) Setup Amazon S3 Bucket +For those interested in advanced use cases, in which user-defined functions are sent and executed on the Slurm cluster, an S3-compatible object storage becomes a necessary component. The following summarizes the setup process: +1. 
Click the "Create bucket" button (marked in yellow) to create a bucket on this [page](https://us-west-2.console.aws.amazon.com/s3/get-started?region=us-west-2&bucketType=general) + * Give the bucket a unique name and leave the other settings at their defaults +2. Click the user menu in the top-right corner and go to "Security credentials" +3. Create an access key and save it +4. Configure AWS access on **both** machines +``` +# ~/.aws/config +[default] +region=<region> + +# ~/.aws/credentials +[default] +aws_access_key_id=<access_key_id> +aws_secret_access_key=<secret_access_key> +``` + +Now, both machines have access to the Amazon S3 bucket. Perfect! + + +## Rich Use Cases +In this section, we will demonstrate three supported use cases, ranging from basic to advanced. + +### `SlurmTask` +In the simplest use case, we specify the path to the batch script that is already available on the cluster. + +Suppose we have a batch script as follows: +``` +#!/bin/bash + +echo "Hello AWS slurm, run a Flyte SlurmTask!" >> ./echo_aws.txt +``` + +We use the following Python script to test the Slurm agent on the [client](#flyte-client-localhost): +```python +import os + +from flytekit import workflow +from flytekitplugins.slurm import SlurmRemoteScript, SlurmTask + + +echo_job = SlurmTask( + name="<task_name>", + task_config=SlurmRemoteScript( + slurm_host="<host_alias>", + batch_script_path="<remote_batch_script_path>", + sbatch_conf={ + "partition": "debug", + "job-name": "tiny-slurm", + } + ) +) + + +@workflow +def wf() -> None: + echo_job() + + +if __name__ == "__main__": + from flytekit.clis.sdk_in_container import pyflyte + from click.testing import CliRunner + + runner = CliRunner() + path = os.path.realpath(__file__) + + print(f">>> LOCAL EXEC <<<") + result = runner.invoke(pyflyte.main, ["run", path, "wf"]) + print(result.output) +``` + +### `SlurmShellTask` +`SlurmShellTask` offers users the flexibility to define the content of shell scripts. 
Below is an example of creating a task that executes a Python script already present on the Slurm cluster: +```python +import os + +from flytekit import workflow +from flytekitplugins.slurm import Slurm, SlurmShellTask + + +shell_task = SlurmShellTask( + name="test-shell", + script="""#!/bin/bash +# We can define sbatch options here, but using sbatch_conf is neater +echo "Run a Flyte SlurmShellTask...\n" + +# Run a python script on Slurm +# Activate the virtual env first if any +python3 <path_to_python_script> +""", + task_config=Slurm( + slurm_host="<host_alias>", + sbatch_conf={ + "partition": "debug", + "job-name": "tiny-slurm", + } + ), +) + + +@workflow +def wf() -> None: + shell_task() + + +if __name__ == "__main__": + from flytekit.clis.sdk_in_container import pyflyte + from click.testing import CliRunner + + runner = CliRunner() + path = os.path.realpath(__file__) + + print(f">>> LOCAL EXEC <<<") + result = runner.invoke(pyflyte.main, ["run", path, "wf"]) + print(result.output) +``` + +### `SlurmFunctionTask` +In the most advanced use case, `SlurmFunctionTask` allows users to define custom Python functions that are sent to and executed on the Slurm cluster. The following figure demonstrates the process of running a `SlurmFunctionTask`: + +![SlurmFunctionTask overview](assets/overview_v2.png) + +```python +import os + +from flytekit import task, workflow +from flytekitplugins.slurm import SlurmFunction + + +@task( + task_config=SlurmFunction( + slurm_host="<host_alias>", + srun_conf={ + "partition": "debug", + "job-name": "tiny-slurm", + } + ) +) +def plus_one(x: int) -> int: + return x + 1 + + +@task +def greet(year: int) -> str: + return f"Hello {year}!!!" 
+ + +@workflow +def wf(x: int) -> str: + x = plus_one(x=x) + msg = greet(year=x) + return msg + + +if __name__ == "__main__": + from flytekit.clis.sdk_in_container import pyflyte + from click.testing import CliRunner + + runner = CliRunner() + path = os.path.realpath(__file__) + + print(f">>> LOCAL EXEC <<<") + result = runner.invoke(pyflyte.main, ["run", "--raw-output-data-prefix", "", path, "wf", "--x", 2024]) + print(result.output) +``` diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py b/plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py new file mode 100644 index 0000000000..75dc5ea9ff --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py @@ -0,0 +1,4 @@ +from .function.agent import SlurmFunctionAgent +from .function.task import SlurmFunction, SlurmFunctionTask +from .script.agent import SlurmScriptAgent +from .script.task import Slurm, SlurmRemoteScript, SlurmShellTask, SlurmTask diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py b/plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py new file mode 100644 index 0000000000..0f9f0119c2 --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py @@ -0,0 +1,115 @@ +from dataclasses import dataclass +from typing import Dict, Optional + +import asyncssh +from asyncssh import SSHClientConnection + +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + + +@dataclass +class SlurmJobMetadata(ResourceMeta): + """Slurm job metadata. + + Args: + job_id: Slurm job id. 
+ """ + + job_id: str + slurm_host: str + + +class SlurmFunctionAgent(AsyncAgentBase): + name = "Slurm Function Agent" + + # SSH connection pool for multi-host environment + _conn: Optional[SSHClientConnection] = None + + def __init__(self) -> None: + super(SlurmFunctionAgent, self).__init__(task_type_name="slurm_fn", metadata_type=SlurmJobMetadata) + + async def create( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + **kwargs, + ) -> SlurmJobMetadata: + # Retrieve task config + slurm_host = task_template.custom["slurm_host"] + srun_conf = task_template.custom["srun_conf"] + + # Construct srun command for Slurm cluster + cmd = _get_srun_cmd(srun_conf=srun_conf, entrypoint=" ".join(task_template.container.args)) + + # Run Slurm job + if self._conn is None: + await self._connect(slurm_host) + res = await self._conn.run(cmd, check=True) + + # Direct return for sbatch + # job_id = res.stdout.split()[-1] + # Use echo trick for srun + job_id = res.stdout.strip() + + return SlurmJobMetadata(job_id=job_id, slurm_host=slurm_host) + + async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource: + await self._connect(resource_meta.slurm_host) + res = await self._conn.run(f"scontrol show job {resource_meta.job_id}", check=True) + + # Determine the current flyte phase from Slurm job state + job_state = "running" + for o in res.stdout.split(" "): + if "JobState" in o: + job_state = o.split("=")[1].strip().lower() + cur_phase = convert_to_flyte_phase(job_state) + + return Resource(phase=cur_phase) + + async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None: + await self._connect(resource_meta.slurm_host) + _ = await self._conn.run(f"scancel {resource_meta.job_id}", check=True) + + async def _connect(self, slurm_host: str) -> None: + """Make an SSH client connection.""" + self._conn = await asyncssh.connect(host=slurm_host) + + +def _get_srun_cmd(srun_conf: Dict[str, str], entrypoint: str) -> str: + """Construct 
Slurm srun command. + + Flyte entrypoint, pyflyte-execute, is run within a bash shell process. + + Args: + srun_conf: Options of srun command. + entrypoint: Flyte entrypoint. + + Returns: + cmd: Slurm srun command. + """ + # Setup srun options + cmd = ["srun"] + for opt, val in srun_conf.items(): + cmd.extend([f"--{opt}", str(val)]) + + cmd.extend(["bash", "-c"]) + cmd = " ".join(cmd) + + cmd += f""" '# Activate the pre-built virtual env + . /home/ubuntu/.cache/pypoetry/virtualenvs/demo-poetry-RLi6T71_-py3.12/bin/activate; + + # Run entrypoints in a subshell with virtual env activated, + # including pyflyte-fast-execute and pyflyte-execute + {entrypoint}; + + # A trick to show Slurm job id on stdout + echo $SLURM_JOB_ID;' + """ + + return cmd + + +AgentRegistry.register(SlurmFunctionAgent()) diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py b/plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py new file mode 100644 index 0000000000..ac55109d75 --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py @@ -0,0 +1,72 @@ +""" +Slurm task. +""" + +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Union + +from flytekit import FlyteContextManager, PythonFunctionTask +from flytekit.configuration import SerializationSettings +from flytekit.extend import TaskPlugins +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.image_spec import ImageSpec + + +@dataclass +class SlurmFunction(object): + """Configure Slurm settings. Note that we focus on srun command now. + + Compared with spark, please refer to https://api-docs.databricks.com/python/pyspark/latest/api/pyspark.SparkContext.html. + + Args: + slurm_host: Slurm host name. We assume there's no default Slurm host now. + srun_conf: Options of srun command. 
+ """ + + slurm_host: str + srun_conf: Optional[Dict[str, str]] = None + + def __post_init__(self): + if self.srun_conf is None: + self.srun_conf = {} + + +class SlurmFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SlurmFunction]): + """ + Actual Plugin that transforms the local python code for execution within a slurm context... + """ + + _TASK_TYPE = "slurm_fn" + + def __init__( + self, + task_config: SlurmFunction, + task_function: Callable, + container_image: Optional[Union[str, ImageSpec]] = None, + **kwargs, + ): + super(SlurmFunctionTask, self).__init__( + task_config=task_config, + task_type=self._TASK_TYPE, + task_function=task_function, + container_image=container_image, + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + return { + "slurm_host": self.task_config.slurm_host, + "srun_conf": self.task_config.srun_conf, + } + + def execute(self, **kwargs) -> Any: + ctx = FlyteContextManager.current_context() + if ctx.execution_state and ctx.execution_state.is_local_execution(): + # Mimic the propeller's behavior in local agent test + return AsyncAgentExecutorMixin.execute(self, **kwargs) + else: + # Execute the task with a direct python function call + return PythonFunctionTask.execute(self, **kwargs) + + +TaskPlugins.register_pythontask_plugin(SlurmFunction, SlurmFunctionTask) diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/script/agent.py b/plugins/flytekit-slurm/flytekitplugins/slurm/script/agent.py new file mode 100644 index 0000000000..7591cc6b5a --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/script/agent.py @@ -0,0 +1,136 @@ +import tempfile +from dataclasses import dataclass +from typing import Dict, List, Optional + +import asyncssh +from asyncssh import SSHClientConnection + +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase +from flytekit.models.literals 
import LiteralMap +from flytekit.models.task import TaskTemplate + + +@dataclass +class SlurmJobMetadata(ResourceMeta): + """Slurm job metadata. + + Args: + job_id: Slurm job id. + """ + + job_id: str + slurm_host: str + + +class SlurmScriptAgent(AsyncAgentBase): + name = "Slurm Script Agent" + + # SSH connection pool for multi-host environment + # _ssh_clients: Dict[str, SSHClientConnection] + _conn: Optional[SSHClientConnection] = None + + # Tmp remote path of the batch script + REMOTE_PATH = "/tmp/echo_shell.slurm" + + # Dummy script content + DUMMY_SCRIPT = "#!/bin/bash" + + def __init__(self) -> None: + super(SlurmScriptAgent, self).__init__(task_type_name="slurm", metadata_type=SlurmJobMetadata) + + async def create( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + **kwargs, + ) -> SlurmJobMetadata: + # Retrieve task config + slurm_host = task_template.custom["slurm_host"] + batch_script_args = task_template.custom["batch_script_args"] + sbatch_conf = task_template.custom["sbatch_conf"] + + # Construct sbatch command for Slurm cluster + upload_script = False + if "script" in task_template.custom: + script = task_template.custom["script"] + assert script != self.DUMMY_SCRIPT, "Please write the user-defined batch script content." 
+ + batch_script_path = self.REMOTE_PATH + upload_script = True + else: + # Assume the batch script is already on Slurm + batch_script_path = task_template.custom["batch_script_path"] + cmd = _get_sbatch_cmd( + sbatch_conf=sbatch_conf, batch_script_path=batch_script_path, batch_script_args=batch_script_args + ) + + # Run Slurm job + if self._conn is None: + await self._connect(slurm_host) + if upload_script: + with tempfile.NamedTemporaryFile("w") as f: + f.write(script) + f.flush() + async with self._conn.start_sftp_client() as sftp: + await sftp.put(f.name, self.REMOTE_PATH) + res = await self._conn.run(cmd, check=True) + + # Retrieve Slurm job id + job_id = res.stdout.split()[-1] + + return SlurmJobMetadata(job_id=job_id, slurm_host=slurm_host) + + async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource: + await self._connect(resource_meta.slurm_host) + res = await self._conn.run(f"scontrol show job {resource_meta.job_id}", check=True) + + # Determine the current flyte phase from Slurm job state + job_state = "running" + for o in res.stdout.split(" "): + if "JobState" in o: + job_state = o.split("=")[1].strip().lower() + cur_phase = convert_to_flyte_phase(job_state) + + return Resource(phase=cur_phase) + + async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None: + await self._connect(resource_meta.slurm_host) + _ = await self._conn.run(f"scancel {resource_meta.job_id}", check=True) + + async def _connect(self, slurm_host: str) -> None: + """Make an SSH client connection.""" + self._conn = await asyncssh.connect(host=slurm_host) + + +def _get_sbatch_cmd(sbatch_conf: Dict[str, str], batch_script_path: str, batch_script_args: List[str] = None) -> str: + """Construct Slurm sbatch command. + + We assume all main scripts and dependencies are on Slurm cluster. + + Args: + sbatch_conf: Options of srun command. + batch_script_path: Absolute path of the batch script on Slurm cluster. 
+ batch_script_args: Additional args for the batch script on Slurm cluster. + + Returns: + cmd: Slurm sbatch command. + """ + # Setup sbatch options + cmd = ["sbatch"] + for opt, val in sbatch_conf.items(): + cmd.extend([f"--{opt}", str(val)]) + + # Assign the batch script to run + cmd.append(batch_script_path) + + # Add args if present + if batch_script_args: + for arg in batch_script_args: + cmd.append(arg) + + cmd = " ".join(cmd) + return cmd + + +AgentRegistry.register(SlurmScriptAgent()) diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/script/task.py b/plugins/flytekit-slurm/flytekitplugins/slurm/script/task.py new file mode 100644 index 0000000000..0892dfd900 --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/script/task.py @@ -0,0 +1,102 @@ +""" +Slurm task. +""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from flytekit.configuration import SerializationSettings +from flytekit.core.base_task import PythonTask +from flytekit.core.interface import Interface +from flytekit.extend import TaskPlugins +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extras.tasks.shell import ShellTask + + +@dataclass +class Slurm(object): + """Configure Slurm settings. Note that we focus on sbatch command now. + + Compared with spark, please refer to https://api-docs.databricks.com/python/pyspark/latest/api/pyspark.SparkContext.html. + + Args: + slurm_host: Slurm host name. We assume there's no default Slurm host now. + sbatch_conf: Options of sbatch command. For available options, please refer to + https://slurm.schedmd.com/sbatch.html. + batch_script_args: Additional args for the batch script on Slurm cluster. 
+ """ + + slurm_host: str + sbatch_conf: Optional[Dict[str, str]] = None + batch_script_args: Optional[List[str]] = None + + def __post_init__(self): + if self.sbatch_conf is None: + self.sbatch_conf = {} + + +# See https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses +@dataclass(kw_only=True) +class SlurmRemoteScript(Slurm): + """Encounter collision if Slurm is shared btw SlurmTask and SlurmShellTask.""" + + batch_script_path: str + + +class SlurmTask(AsyncAgentExecutorMixin, PythonTask[SlurmRemoteScript]): + _TASK_TYPE = "slurm" + + def __init__( + self, + name: str, + task_config: SlurmRemoteScript, + **kwargs, + ): + super(SlurmTask, self).__init__( + task_type=self._TASK_TYPE, + name=name, + task_config=task_config, + # Dummy interface, will support this after discussion + interface=Interface(None, None), + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + return { + "slurm_host": self.task_config.slurm_host, + "batch_script_path": self.task_config.batch_script_path, + "batch_script_args": self.task_config.batch_script_args, + "sbatch_conf": self.task_config.sbatch_conf, + } + + +class SlurmShellTask(AsyncAgentExecutorMixin, ShellTask[Slurm]): + _TASK_TYPE = "slurm" + + def __init__( + self, + name: str, + task_config: Slurm, + script: Optional[str] = None, + **kwargs, + ): + super(SlurmShellTask, self).__init__( + name, + task_config=task_config, + task_type=self._TASK_TYPE, + script=script, + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + return { + "slurm_host": self.task_config.slurm_host, + "batch_script_args": self.task_config.batch_script_args, + "sbatch_conf": self.task_config.sbatch_conf, + # User-defined script content + "script": self._script, + } + + +TaskPlugins.register_pythontask_plugin(SlurmRemoteScript, SlurmTask) +TaskPlugins.register_pythontask_plugin(Slurm, SlurmShellTask) diff --git a/plugins/flytekit-slurm/setup.py 
b/plugins/flytekit-slurm/setup.py new file mode 100644 index 0000000000..2c338db47e --- /dev/null +++ b/plugins/flytekit-slurm/setup.py @@ -0,0 +1,40 @@ +from setuptools import setup + +PLUGIN_NAME = "slurm" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>1.13.8", "flyteidl>=1.11.0b1", "asyncssh"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package holds the Slurm plugins for flytekit", + namespace_packages=["flytekitplugins"], + packages=[ + f"flytekitplugins.{PLUGIN_NAME}", + f"flytekitplugins.{PLUGIN_NAME}.script", + f"flytekitplugins.{PLUGIN_NAME}.function", + ], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.9", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-slurm/tests/__init__.py b/plugins/flytekit-slurm/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-slurm/tests/test_slurm.py b/plugins/flytekit-slurm/tests/test_slurm.py new file mode 100644 index 0000000000..e69de29bb2