Skip to content

Commit

Permalink
Ray Task Support (#1093)
Browse files Browse the repository at this point in the history
* Add Ray Plugin

Signed-off-by: Kevin <[email protected]>

* wip

Signed-off-by: Kevin Su <[email protected]>

* update dep

Signed-off-by: Kevin Su <[email protected]>

* wip

Signed-off-by: Kevin Su <[email protected]>

* update

Signed-off-by: Kevin Su <[email protected]>

* update

Signed-off-by: Kevin Su <[email protected]>

* Added to ci

Signed-off-by: Kevin Su <[email protected]>

* Added to ci

Signed-off-by: Kevin Su <[email protected]>

* Fixed tests

Signed-off-by: Kevin Su <[email protected]>

* Fixed tests

Signed-off-by: Kevin Su <[email protected]>

* Updated image

Signed-off-by: Kevin Su <[email protected]>

* wip

Signed-off-by: Kevin Su <[email protected]>

* test

Signed-off-by: Kevin Su <[email protected]>

* test

Signed-off-by: Kevin Su <[email protected]>

* wip

Signed-off-by: Kevin Su <[email protected]>

* typo

Signed-off-by: Kevin Su <[email protected]>

* wip

Signed-off-by: Kevin Su <[email protected]>

* wip

Signed-off-by: Kevin Su <[email protected]>

* wip

Signed-off-by: Kevin Su <[email protected]>

* wip

Signed-off-by: Kevin Su <[email protected]>

* Add runtime

Signed-off-by: Kevin Su <[email protected]>

* update

Signed-off-by: Kevin Su <[email protected]>

* Fix error

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* remove example file

Signed-off-by: Kevin Su <[email protected]>

* Fix test error

Signed-off-by: Kevin Su <[email protected]>

* update type

Signed-off-by: Kevin Su <[email protected]>

* Update idl

Signed-off-by: Kevin Su <[email protected]>

* fix test error

Signed-off-by: Kevin Su <[email protected]>

* fix test error

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Aug 10, 2022
1 parent 8d7efb0 commit 73eaad1
Show file tree
Hide file tree
Showing 12 changed files with 629 additions and 1 deletion.
1 change: 1 addition & 0 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ jobs:
- flytekit-pandera
- flytekit-papermill
- flytekit-polars
- flytekit-ray
- flytekit-snowflake
- flytekit-spark
- flytekit-sqlalchemy
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(
can be used to inject some client side variables only. Prefer using ExecutionParams
:param Optional[ExecutionBehavior] execution_mode: Defines how the execution should behave, for example
executing normally or specially handling a dynamic case.
:param Optional[TaskResolverMixin] task_type: String task type to be associated with this Task
:param str task_type: String task type to be associated with this Task
"""
if task_function is None:
raise ValueError("TaskFunction is a required parameter for PythonFunctionTask")
Expand Down
9 changes: 9 additions & 0 deletions plugins/flytekit-ray/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Flytekit Ray Plugin

Flyte backend can be connected with Ray. Once enabled, it allows you to run flyte task on Ray cluster

To install the plugin, run the following command:

```bash
pip install flytekitplugins-ray
```
13 changes: 13 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
.. currentmodule:: flytekitplugins.ray
This package contains things that are useful when extending Flytekit.
.. autosummary::
:template: custom.rst
:toctree: generated/
RayConfig
"""

from .task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig
204 changes: 204 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import typing

from flyteidl.plugins import ray_pb2 as _ray_pb2

from flytekit.models import common as _common


class WorkerGroupSpec(_common.FlyteIdlEntity):
def __init__(
self,
group_name: str,
replicas: int,
min_replicas: typing.Optional[int] = 0,
max_replicas: typing.Optional[int] = None,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
):
self._group_name = group_name
self._replicas = replicas
self._min_replicas = min_replicas
self._max_replicas = max_replicas if max_replicas else replicas
self._ray_start_params = ray_start_params

@property
def group_name(self):
"""
Group name of the current worker group.
:rtype: str
"""
return self._group_name

@property
def replicas(self):
"""
Desired replicas of the worker group.
:rtype: int
"""
return self._replicas

@property
def min_replicas(self):
"""
Min replicas of the worker group.
:rtype: int
"""
return self._min_replicas

@property
def max_replicas(self):
"""
Max replicas of the worker group.
:rtype: int
"""
return self._max_replicas

@property
def ray_start_params(self):
"""
The ray start params of worker node group.
:rtype: typing.Dict[str, str]
"""
return self._ray_start_params

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins._ray_pb2.WorkerGroupSpec
"""
return _ray_pb2.WorkerGroupSpec(
group_name=self.group_name,
replicas=self.replicas,
min_replicas=self.min_replicas,
max_replicas=self.max_replicas,
ray_start_params=self.ray_start_params,
)

@classmethod
def from_flyte_idl(cls, proto):
"""
:param flyteidl.plugins._ray_pb2.WorkerGroupSpec proto:
:rtype: WorkerGroupSpec
"""
return cls(
group_name=proto.group_name,
replicas=proto.replicas,
min_replicas=proto.min_replicas,
max_replicas=proto.max_replicas,
ray_start_params=proto.ray_start_params,
)


class HeadGroupSpec(_common.FlyteIdlEntity):
def __init__(
self,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
):
self._ray_start_params = ray_start_params

@property
def ray_start_params(self):
"""
The ray start params of worker node group.
:rtype: typing.Dict[str, str]
"""
return self._ray_start_params

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins._ray_pb2.HeadGroupSpec
"""
return _ray_pb2.HeadGroupSpec(
ray_start_params=self.ray_start_params if self.ray_start_params else {},
)

@classmethod
def from_flyte_idl(cls, proto):
"""
:param flyteidl.plugins._ray_pb2.HeadGroupSpec proto:
:rtype: HeadGroupSpec
"""
return cls(
ray_start_params=proto.ray_start_params,
)


class RayCluster(_common.FlyteIdlEntity):
"""
Define RayCluster spec that will be used by KubeRay to launch the cluster.
"""

def __init__(
self, worker_group_spec: typing.List[WorkerGroupSpec], head_group_spec: typing.Optional[HeadGroupSpec] = None
):
self._head_group_spec = head_group_spec
self._worker_group_spec = worker_group_spec

@property
def head_group_spec(self) -> HeadGroupSpec:
"""
The head group configuration.
:rtype: HeadGroupSpec
"""
return self._head_group_spec

@property
def worker_group_spec(self) -> typing.List[WorkerGroupSpec]:
"""
The worker group configurations.
:rtype: typing.List[WorkerGroupSpec]
"""
return self._worker_group_spec

def to_flyte_idl(self) -> _ray_pb2.RayCluster:
"""
:rtype: flyteidl.plugins._ray_pb2.RayCluster
"""
return _ray_pb2.RayCluster(
head_group_spec=self.head_group_spec.to_flyte_idl() if self.head_group_spec else None,
worker_group_spec=[wg.to_flyte_idl() for wg in self.worker_group_spec],
)

@classmethod
def from_flyte_idl(cls, proto):
"""
:param flyteidl.plugins._ray_pb2.RayCluster proto:
:rtype: RayCluster
"""
return cls(
head_group_spec=HeadGroupSpec.from_flyte_idl(proto.head_group_spec) if proto.head_group_spec else None,
worker_group_spec=[WorkerGroupSpec.from_flyte_idl(wg) for wg in proto.worker_group_spec],
)


class RayJob(_common.FlyteIdlEntity):
"""
Models _ray_pb2.RayJob
"""

def __init__(
self,
ray_cluster: RayCluster,
runtime_env: typing.Optional[str],
):
self._ray_cluster = ray_cluster
self._runtime_env = runtime_env

@property
def ray_cluster(self) -> RayCluster:
return self._ray_cluster

@property
def runtime_env(self) -> typing.Optional[str]:
return self._runtime_env

def to_flyte_idl(self) -> _ray_pb2.RayJob:
return _ray_pb2.RayJob(
ray_cluster=self.ray_cluster.to_flyte_idl(),
runtime_env=self.runtime_env,
)

@classmethod
def from_flyte_idl(cls, proto: _ray_pb2.RayJob):
return cls(
ray_cluster=RayCluster.from_flyte_idl(proto.ray_cluster) if proto.ray_cluster else None,
runtime_env=proto.runtime_env,
)
76 changes: 76 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import base64
import json
import typing
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

import ray
from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec
from google.protobuf.json_format import MessageToDict

from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.extend import TaskPlugins


@dataclass
class HeadNodeConfig:
ray_start_params: typing.Optional[typing.Dict[str, str]] = None


@dataclass
class WorkerNodeConfig:
group_name: str
replicas: int
min_replicas: typing.Optional[int] = None
max_replicas: typing.Optional[int] = None
ray_start_params: typing.Optional[typing.Dict[str, str]] = None


@dataclass
class RayJobConfig:
worker_node_config: typing.List[WorkerNodeConfig]
head_node_config: typing.Optional[HeadNodeConfig] = None
runtime_env: typing.Optional[dict] = None
address: typing.Optional[str] = None


class RayFunctionTask(PythonFunctionTask):
"""
Actual Plugin that transforms the local python code for execution within Ray job.
"""

_RAY_TASK_TYPE = "ray"

def __init__(self, task_config: RayJobConfig, task_function: Callable, **kwargs):
super().__init__(task_config=task_config, task_type=self._RAY_TASK_TYPE, task_function=task_function, **kwargs)
self._task_config = task_config

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
ray.init(address=self._task_config.address)
return user_params

def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
ray.shutdown()
return rval

def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
cfg = self._task_config

ray_job = RayJob(
ray_cluster=RayCluster(
head_group_spec=HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None,
worker_group_spec=[
WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params)
for c in cfg.worker_node_config
],
),
# Use base64 to encode runtime_env dict and convert it to byte string
runtime_env=base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode(),
)
return MessageToDict(ray_job.to_flyte_idl())


# Inject the Ray plugin into flytekits dynamic plugin loading system
TaskPlugins.register_pythontask_plugin(RayJobConfig, RayFunctionTask)
2 changes: 2 additions & 0 deletions plugins/flytekit-ray/requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.
-e file:.#egg=flytekitplugins-ray
Loading

0 comments on commit 73eaad1

Please sign in to comment.