From dc575e4e8b7cb66f83cafd70151e0ac6a598602e Mon Sep 17 00:00:00 2001 From: Jan Fiedler <89976021+fiedlerNr9@users.noreply.github.com> Date: Mon, 29 Jul 2024 22:50:14 +0200 Subject: [PATCH] Enable Ray Fast Register (#2606) Signed-off-by: Jan Fiedler Signed-off-by: mao3267 --- .../flytekit-ray/flytekitplugins/ray/task.py | 48 ++++++++++++++++--- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/plugins/flytekit-ray/flytekitplugins/ray/task.py b/plugins/flytekit-ray/flytekitplugins/ray/task.py index e6b3ad8039..86bc12a4c4 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/task.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/task.py @@ -1,16 +1,22 @@ import base64 import json +import os import typing from dataclasses import dataclass from typing import Any, Callable, Dict, Optional import yaml -from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec +from flytekitplugins.ray.models import ( + HeadGroupSpec, + RayCluster, + RayJob, + WorkerGroupSpec, +) from google.protobuf.json_format import MessageToDict from flytekit import lazy_module from flytekit.configuration import SerializationSettings -from flytekit.core.context_manager import ExecutionParameters +from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager from flytekit.core.python_function_task import PythonFunctionTask from flytekit.extend import TaskPlugins @@ -40,6 +46,7 @@ class RayJobConfig: address: typing.Optional[str] = None shutdown_after_job_finishes: bool = False ttl_seconds_after_finished: typing.Optional[int] = None + excludes_working_dir: typing.Optional[typing.List[str]] = None class RayFunctionTask(PythonFunctionTask): @@ -50,11 +57,30 @@ class RayFunctionTask(PythonFunctionTask): _RAY_TASK_TYPE = "ray" def __init__(self, task_config: RayJobConfig, task_function: Callable, **kwargs): - super().__init__(task_config=task_config, task_type=self._RAY_TASK_TYPE, task_function=task_function, **kwargs) + super().__init__( + task_config=task_config, + task_type=self._RAY_TASK_TYPE, + task_function=task_function, + **kwargs, + ) self._task_config = task_config def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: - ray.init(address=self._task_config.address) + init_params = {"address": self._task_config.address} + + ctx = FlyteContextManager.current_context() + if not ctx.execution_state.is_local_execution(): + working_dir = os.getcwd() + init_params["runtime_env"] = { + "working_dir": working_dir, + "excludes": ["script_mode.tar.gz", "fast*.tar.gz"], + } + + cfg = self._task_config + if cfg.excludes_working_dir: + init_params["runtime_env"]["excludes"].extend(cfg.excludes_working_dir) + + ray.init(**init_params) return user_params def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]: @@ -67,12 +93,20 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] ray_job = RayJob( ray_cluster=RayCluster( - head_group_spec=HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None, + head_group_spec=( + HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None + ), worker_group_spec=[ - WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params) + WorkerGroupSpec( + c.group_name, + c.replicas, + c.min_replicas, + c.max_replicas, + c.ray_start_params, + ) for c in cfg.worker_node_config ], - enable_autoscaling=cfg.enable_autoscaling if cfg.enable_autoscaling else False, + enable_autoscaling=(cfg.enable_autoscaling if cfg.enable_autoscaling else False), ), runtime_env=runtime_env, runtime_env_yaml=runtime_env_yaml,