Skip to content

Commit

Permalink
feat: Support SSH config in task config
Browse files Browse the repository at this point in the history
Add `ssh_conf` field to let users specify connection secrets

Note that reconnection is done in both `get` and `delete`. This is just
a temporary workaround.

Signed-off-by: JiaWei Jiang <[email protected]>
  • Loading branch information
JiangJiaWei1103 committed Dec 31, 2024
1 parent 0e538f0 commit 8229418
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 25 deletions.
46 changes: 24 additions & 22 deletions plugins/flytekit-slurm/flytekitplugins/slurm/agent.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
import os
from dataclasses import dataclass
from typing import Dict, Optional

import asyncssh
from asyncssh import SSHClientConnection

from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.extend.backend.utils import convert_to_flyte_phase
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


# Configure ssh info
# Holder for SSH connection settings. Every attribute below is read from the
# environment once, at the moment this class body is executed.
class SSHCfg:
# Each lookup indexes os.environ directly, so an unset variable raises KeyError.
host = os.environ["SSH_HOST"]
# The port is converted to int; a non-numeric value raises ValueError.
port = int(os.environ["SSH_PORT"])
username = os.environ["SSH_USERNAME"]
password = os.environ["SSH_PASSWORD"]


@dataclass
class SlurmJobMetadata(ResourceMeta):
"""Slurm job metadata.
Expand All @@ -27,11 +19,16 @@ class SlurmJobMetadata(ResourceMeta):
"""

job_id: str
ssh_conf: Dict[str, str]


class SlurmAgent(AsyncAgentBase):
name = "Slurm Agent"

# SSH connection pool for multi-host environment
# _ssh_clients: Dict[str, SSHClientConnection]
_conn: Optional[SSHClientConnection] = None

def __init__(self) -> None:
super(SlurmAgent, self).__init__(task_type_name="slurm", metadata_type=SlurmJobMetadata)

Expand All @@ -42,29 +39,27 @@ async def create(
**kwargs,
) -> SlurmJobMetadata:
# Retrieve task config
ssh_conf = task_template.custom["ssh_conf"]
srun_conf = task_template.custom["srun_conf"]

# Construct srun command for Slurm cluster
cmd = _get_srun_cmd(srun_conf=srun_conf, entrypoint=" ".join(task_template.container.args))

# Run Slurm job
async with asyncssh.connect(
host=SSHCfg.host, port=SSHCfg.port, username=SSHCfg.username, password=SSHCfg.password
) as conn:
res = await conn.run(cmd, check=True)
if self._conn is None:
await self._connect(ssh_conf)
res = await self._conn.run(cmd, check=True)

# Direct return for sbatch
# job_id = res.stdout.split()[-1]
# Use echo trick for srun
job_id = res.stdout.strip()

return SlurmJobMetadata(job_id=job_id)
return SlurmJobMetadata(job_id=job_id, ssh_conf=ssh_conf)

async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource:
async with asyncssh.connect(
host=SSHCfg.host, port=SSHCfg.port, username=SSHCfg.username, password=SSHCfg.password
) as conn:
res = await conn.run(f"scontrol show job {resource_meta.job_id}", check=True)
await self._connect(resource_meta.ssh_conf)
res = await self._conn.run(f"scontrol show job {resource_meta.job_id}", check=True)

# Determine the current flyte phase from Slurm job state
job_state = "running"
Expand All @@ -77,10 +72,17 @@ async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource:
return Resource(phase=cur_phase)

async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None:
async with asyncssh.connect(
host=SSHCfg.host, port=SSHCfg.port, username=SSHCfg.username, password=SSHCfg.password
) as conn:
_ = await conn.run(f"scancel {resource_meta.job_id}", check=True)
await self._connect(resource_meta.ssh_conf)
_ = await self._conn.run(f"scancel {resource_meta.job_id}", check=True)

async def _connect(self, ssh_conf: Dict[str, str]) -> None:
    """Open an SSH client connection and store it on ``self._conn``.

    Args:
        ssh_conf: Connection settings. The ``host``, ``port``, ``username``
            and ``password`` entries are read; ``port`` is converted to int.

    Raises:
        KeyError: If any of the four entries is missing from ``ssh_conf``.
    """
    # NOTE(review): whatever connection ``self._conn`` held before this call
    # is replaced without being closed here -- confirm that is intended.
    target_host = ssh_conf["host"]
    target_port = int(ssh_conf["port"])
    login_name = ssh_conf["username"]
    login_secret = ssh_conf["password"]

    connection = await asyncssh.connect(
        host=target_host,
        port=target_port,
        username=login_name,
        password=login_secret,
    )
    self._conn = connection


def _get_srun_cmd(srun_conf: Dict[str, str], entrypoint: str) -> str:
Expand Down
9 changes: 6 additions & 3 deletions plugins/flytekit-slurm/flytekitplugins/slurm/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@ class Slurm(object):
Compared with spark, please refer to https://api-docs.databricks.com/python/pyspark/latest/api/pyspark.SparkContext.html.
Args:
ssh_conf: Options of ssh connection. The keys should match what asyncssh.connect method expects:
https://asyncssh.readthedocs.io/en/latest/api.html#asyncssh.connect
srun_conf: Options of srun command.
"""

ssh_conf: Optional[Dict[str, str]] = None
srun_conf: Optional[Dict[str, str]] = None

def __post_init__(self):
if self.ssh_conf is None:
self.ssh_conf = {}
if self.srun_conf is None:
self.srun_conf = {}

Expand Down Expand Up @@ -52,9 +57,7 @@ def __init__(
)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
return {
"srun_conf": self.task_config.srun_conf,
}
return {"ssh_conf": self.task_config.ssh_conf, "srun_conf": self.task_config.srun_conf}

def execute(self, **kwargs) -> Any:
ctx = FlyteContextManager.current_context()
Expand Down

0 comments on commit 8229418

Please sign in to comment.