-
Notifications
You must be signed in to change notification settings - Fork 301
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add Ray Plugin Signed-off-by: Kevin <[email protected]> * wip Signed-off-by: Kevin Su <[email protected]> * update dep Signed-off-by: Kevin Su <[email protected]> * wip Signed-off-by: Kevin Su <[email protected]> * update Signed-off-by: Kevin Su <[email protected]> * update Signed-off-by: Kevin Su <[email protected]> * Added to ci Signed-off-by: Kevin Su <[email protected]> * Added to ci Signed-off-by: Kevin Su <[email protected]> * Fixed tests Signed-off-by: Kevin Su <[email protected]> * Fixed tests Signed-off-by: Kevin Su <[email protected]> * Updated image Signed-off-by: Kevin Su <[email protected]> * wip Signed-off-by: Kevin Su <[email protected]> * test Signed-off-by: Kevin Su <[email protected]> * test Signed-off-by: Kevin Su <[email protected]> * wip Signed-off-by: Kevin Su <[email protected]> * typo Signed-off-by: Kevin Su <[email protected]> * wip Signed-off-by: Kevin Su <[email protected]> * wip Signed-off-by: Kevin Su <[email protected]> * wip Signed-off-by: Kevin Su <[email protected]> * wip Signed-off-by: Kevin Su <[email protected]> * Add runtime Signed-off-by: Kevin Su <[email protected]> * update Signed-off-by: Kevin Su <[email protected]> * Fix error Signed-off-by: Kevin Su <[email protected]> * nit Signed-off-by: Kevin Su <[email protected]> * remove example file Signed-off-by: Kevin Su <[email protected]> * Fix test error Signed-off-by: Kevin Su <[email protected]> * update type Signed-off-by: Kevin Su <[email protected]> * Update idl Signed-off-by: Kevin Su <[email protected]> * fix test error Signed-off-by: Kevin Su <[email protected]> * fix test error Signed-off-by: Kevin Su <[email protected]>
- Loading branch information
Showing
12 changed files
with
629 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Flytekit Ray Plugin | ||
|
||
Flyte backend can be connected with Ray. Once enabled, it allows you to run flyte task on Ray cluster | ||
|
||
To install the plugin, run the following command: | ||
|
||
```bash | ||
pip install flytekitplugins-ray | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
""" | ||
.. currentmodule:: flytekitplugins.ray | ||
This package contains things that are useful when extending Flytekit. | ||
.. autosummary:: | ||
:template: custom.rst | ||
:toctree: generated/ | ||
RayConfig | ||
""" | ||
|
||
from .task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
import typing | ||
|
||
from flyteidl.plugins import ray_pb2 as _ray_pb2 | ||
|
||
from flytekit.models import common as _common | ||
|
||
|
||
class WorkerGroupSpec(_common.FlyteIdlEntity): | ||
def __init__( | ||
self, | ||
group_name: str, | ||
replicas: int, | ||
min_replicas: typing.Optional[int] = 0, | ||
max_replicas: typing.Optional[int] = None, | ||
ray_start_params: typing.Optional[typing.Dict[str, str]] = None, | ||
): | ||
self._group_name = group_name | ||
self._replicas = replicas | ||
self._min_replicas = min_replicas | ||
self._max_replicas = max_replicas if max_replicas else replicas | ||
self._ray_start_params = ray_start_params | ||
|
||
@property | ||
def group_name(self): | ||
""" | ||
Group name of the current worker group. | ||
:rtype: str | ||
""" | ||
return self._group_name | ||
|
||
@property | ||
def replicas(self): | ||
""" | ||
Desired replicas of the worker group. | ||
:rtype: int | ||
""" | ||
return self._replicas | ||
|
||
@property | ||
def min_replicas(self): | ||
""" | ||
Min replicas of the worker group. | ||
:rtype: int | ||
""" | ||
return self._min_replicas | ||
|
||
@property | ||
def max_replicas(self): | ||
""" | ||
Max replicas of the worker group. | ||
:rtype: int | ||
""" | ||
return self._max_replicas | ||
|
||
@property | ||
def ray_start_params(self): | ||
""" | ||
The ray start params of worker node group. | ||
:rtype: typing.Dict[str, str] | ||
""" | ||
return self._ray_start_params | ||
|
||
def to_flyte_idl(self): | ||
""" | ||
:rtype: flyteidl.plugins._ray_pb2.WorkerGroupSpec | ||
""" | ||
return _ray_pb2.WorkerGroupSpec( | ||
group_name=self.group_name, | ||
replicas=self.replicas, | ||
min_replicas=self.min_replicas, | ||
max_replicas=self.max_replicas, | ||
ray_start_params=self.ray_start_params, | ||
) | ||
|
||
@classmethod | ||
def from_flyte_idl(cls, proto): | ||
""" | ||
:param flyteidl.plugins._ray_pb2.WorkerGroupSpec proto: | ||
:rtype: WorkerGroupSpec | ||
""" | ||
return cls( | ||
group_name=proto.group_name, | ||
replicas=proto.replicas, | ||
min_replicas=proto.min_replicas, | ||
max_replicas=proto.max_replicas, | ||
ray_start_params=proto.ray_start_params, | ||
) | ||
|
||
|
||
class HeadGroupSpec(_common.FlyteIdlEntity): | ||
def __init__( | ||
self, | ||
ray_start_params: typing.Optional[typing.Dict[str, str]] = None, | ||
): | ||
self._ray_start_params = ray_start_params | ||
|
||
@property | ||
def ray_start_params(self): | ||
""" | ||
The ray start params of worker node group. | ||
:rtype: typing.Dict[str, str] | ||
""" | ||
return self._ray_start_params | ||
|
||
def to_flyte_idl(self): | ||
""" | ||
:rtype: flyteidl.plugins._ray_pb2.HeadGroupSpec | ||
""" | ||
return _ray_pb2.HeadGroupSpec( | ||
ray_start_params=self.ray_start_params if self.ray_start_params else {}, | ||
) | ||
|
||
@classmethod | ||
def from_flyte_idl(cls, proto): | ||
""" | ||
:param flyteidl.plugins._ray_pb2.HeadGroupSpec proto: | ||
:rtype: HeadGroupSpec | ||
""" | ||
return cls( | ||
ray_start_params=proto.ray_start_params, | ||
) | ||
|
||
|
||
class RayCluster(_common.FlyteIdlEntity): | ||
""" | ||
Define RayCluster spec that will be used by KubeRay to launch the cluster. | ||
""" | ||
|
||
def __init__( | ||
self, worker_group_spec: typing.List[WorkerGroupSpec], head_group_spec: typing.Optional[HeadGroupSpec] = None | ||
): | ||
self._head_group_spec = head_group_spec | ||
self._worker_group_spec = worker_group_spec | ||
|
||
@property | ||
def head_group_spec(self) -> HeadGroupSpec: | ||
""" | ||
The head group configuration. | ||
:rtype: HeadGroupSpec | ||
""" | ||
return self._head_group_spec | ||
|
||
@property | ||
def worker_group_spec(self) -> typing.List[WorkerGroupSpec]: | ||
""" | ||
The worker group configurations. | ||
:rtype: typing.List[WorkerGroupSpec] | ||
""" | ||
return self._worker_group_spec | ||
|
||
def to_flyte_idl(self) -> _ray_pb2.RayCluster: | ||
""" | ||
:rtype: flyteidl.plugins._ray_pb2.RayCluster | ||
""" | ||
return _ray_pb2.RayCluster( | ||
head_group_spec=self.head_group_spec.to_flyte_idl() if self.head_group_spec else None, | ||
worker_group_spec=[wg.to_flyte_idl() for wg in self.worker_group_spec], | ||
) | ||
|
||
@classmethod | ||
def from_flyte_idl(cls, proto): | ||
""" | ||
:param flyteidl.plugins._ray_pb2.RayCluster proto: | ||
:rtype: RayCluster | ||
""" | ||
return cls( | ||
head_group_spec=HeadGroupSpec.from_flyte_idl(proto.head_group_spec) if proto.head_group_spec else None, | ||
worker_group_spec=[WorkerGroupSpec.from_flyte_idl(wg) for wg in proto.worker_group_spec], | ||
) | ||
|
||
|
||
class RayJob(_common.FlyteIdlEntity): | ||
""" | ||
Models _ray_pb2.RayJob | ||
""" | ||
|
||
def __init__( | ||
self, | ||
ray_cluster: RayCluster, | ||
runtime_env: typing.Optional[str], | ||
): | ||
self._ray_cluster = ray_cluster | ||
self._runtime_env = runtime_env | ||
|
||
@property | ||
def ray_cluster(self) -> RayCluster: | ||
return self._ray_cluster | ||
|
||
@property | ||
def runtime_env(self) -> typing.Optional[str]: | ||
return self._runtime_env | ||
|
||
def to_flyte_idl(self) -> _ray_pb2.RayJob: | ||
return _ray_pb2.RayJob( | ||
ray_cluster=self.ray_cluster.to_flyte_idl(), | ||
runtime_env=self.runtime_env, | ||
) | ||
|
||
@classmethod | ||
def from_flyte_idl(cls, proto: _ray_pb2.RayJob): | ||
return cls( | ||
ray_cluster=RayCluster.from_flyte_idl(proto.ray_cluster) if proto.ray_cluster else None, | ||
runtime_env=proto.runtime_env, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import base64 | ||
import json | ||
import typing | ||
from dataclasses import dataclass | ||
from typing import Any, Callable, Dict, Optional | ||
|
||
import ray | ||
from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec | ||
from google.protobuf.json_format import MessageToDict | ||
|
||
from flytekit.configuration import SerializationSettings | ||
from flytekit.core.context_manager import ExecutionParameters | ||
from flytekit.core.python_function_task import PythonFunctionTask | ||
from flytekit.extend import TaskPlugins | ||
|
||
|
||
@dataclass | ||
class HeadNodeConfig: | ||
ray_start_params: typing.Optional[typing.Dict[str, str]] = None | ||
|
||
|
||
@dataclass | ||
class WorkerNodeConfig: | ||
group_name: str | ||
replicas: int | ||
min_replicas: typing.Optional[int] = None | ||
max_replicas: typing.Optional[int] = None | ||
ray_start_params: typing.Optional[typing.Dict[str, str]] = None | ||
|
||
|
||
@dataclass | ||
class RayJobConfig: | ||
worker_node_config: typing.List[WorkerNodeConfig] | ||
head_node_config: typing.Optional[HeadNodeConfig] = None | ||
runtime_env: typing.Optional[dict] = None | ||
address: typing.Optional[str] = None | ||
|
||
|
||
class RayFunctionTask(PythonFunctionTask): | ||
""" | ||
Actual Plugin that transforms the local python code for execution within Ray job. | ||
""" | ||
|
||
_RAY_TASK_TYPE = "ray" | ||
|
||
def __init__(self, task_config: RayJobConfig, task_function: Callable, **kwargs): | ||
super().__init__(task_config=task_config, task_type=self._RAY_TASK_TYPE, task_function=task_function, **kwargs) | ||
self._task_config = task_config | ||
|
||
def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: | ||
ray.init(address=self._task_config.address) | ||
return user_params | ||
|
||
def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: | ||
ray.shutdown() | ||
return rval | ||
|
||
def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]: | ||
cfg = self._task_config | ||
|
||
ray_job = RayJob( | ||
ray_cluster=RayCluster( | ||
head_group_spec=HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None, | ||
worker_group_spec=[ | ||
WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params) | ||
for c in cfg.worker_node_config | ||
], | ||
), | ||
# Use base64 to encode runtime_env dict and convert it to byte string | ||
runtime_env=base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode(), | ||
) | ||
return MessageToDict(ray_job.to_flyte_idl()) | ||
|
||
|
||
# Inject the Ray plugin into flytekits dynamic plugin loading system | ||
TaskPlugins.register_pythontask_plugin(RayJobConfig, RayFunctionTask) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
. | ||
-e file:.#egg=flytekitplugins-ray |
Oops, something went wrong.