Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Databricks Agent #1797

Merged
merged 31 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
540eb5d
Merge branch 'master' of https://github.com/Future-Outlier/flytekit
Aug 12, 2023
984d44a
Merge branch 'master' of https://github.com/Future-Outlier/flytekit
Aug 16, 2023
d1491f0
databricks agent v1
Aug 16, 2023
5e5492d
revision for docker image
Aug 16, 2023
003fa8b
rerun make lint and make fmt
Aug 17, 2023
7c52cba
add PERMANENT_FAILURE
Aug 17, 2023
442be79
REST API for databricks agent v1, async get function is unsure
Aug 18, 2023
6e51c36
add aiohttp in setup.py
Aug 18, 2023
b9c98f1
databricks agent with getting token by secret
Aug 20, 2023
23c5501
revise the code and delete the databricks_token member
Aug 20, 2023
7a98b19
remove databricks_token member
Aug 22, 2023
fa2059d
add databricks agent test
Aug 22, 2023
8d4c77d
Merge branch 'flyteorg:master' into databricks-python-sdk-agent
Future-Outlier Aug 22, 2023
6b3e745
revise by kevin
Aug 22, 2023
a5e0412
Merge branch 'databricks-python-sdk-agent' of https://github.com/Futu…
Aug 22, 2023
8f9dcda
edit get function
Aug 22, 2023
bf857ac
add spark plugin_requires in setup.py
Aug 22, 2023
9c20c4b
Refactor and Revise test_agent.py after kevin's refactor
Aug 22, 2023
77d2d70
remove databricks endpoint member
Aug 23, 2023
ab01850
fix databricks test_agent.py args error
Aug 23, 2023
232d19d
Databricks Agent With Agent Server Only
Aug 30, 2023
ef2b2f7
Fix dev-requirements.in lint error
Aug 30, 2023
606c09f
Merge branch 'flyteorg:master' into databricks-python-sdk-agent
Future-Outlier Aug 30, 2023
d65525f
error handle
pingsutw Sep 4, 2023
550973d
lint
pingsutw Sep 4, 2023
008cf7b
nit
pingsutw Sep 4, 2023
92e3310
Update from kevin's revision
Sep 4, 2023
502ae3f
fix the mocked header in test
Sep 4, 2023
1dbfd38
Merge branch 'master' into databricks-python-sdk-agent
Sep 8, 2023
5b091b9
update spark agent test
Sep 8, 2023
f074740
rename token to access_token
Sep 9, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ def convert_to_flyte_state(state: str) -> State:
Convert the state from the agent to the state in flyte.
"""
state = state.lower()
if state in ["failed"]:
if state in ["failed", "timedout", "canceled"]:
return RETRYABLE_FAILURE
elif state in ["done", "succeeded"]:
elif state in ["done", "succeeded", "success"]:
return SUCCEEDED
elif state in ["running"]:
return RUNNING
Expand Down
3 changes: 2 additions & 1 deletion plugins/flytekit-spark/flytekitplugins/spark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from flytekit.configuration import internal as _internal

from .agent import DatabricksAgent
from .pyspark_transformers import PySparkPipelineModelTransformer
from .schema import SparkDataFrameSchemaReader, SparkDataFrameSchemaWriter, SparkDataFrameTransformer # noqa
from .sd_transformers import ParquetToSparkDecodingHandler, SparkToParquetEncodingHandler
from .task import Databricks, Spark, new_spark_session # noqa
from .task import Databricks, DatabricksAgentTask, Spark, new_spark_session # noqa
146 changes: 146 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import json
from dataclasses import asdict, dataclass
from typing import Optional

import aiohttp
import grpc
from flyteidl.admin.agent_pb2 import PENDING, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource

import flytekit
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


@dataclass
class Metadata:
databricks_endpoint: Optional[str]
databricks_instance: Optional[str]
token: str
run_id: str


class DatabricksAgent(AgentBase):
def __init__(self):
super().__init__(task_type="spark")

async def async_create(
self,
context: grpc.ServicerContext,
output_prefix: str,
task_template: TaskTemplate,
inputs: Optional[LiteralMap] = None,
) -> CreateTaskResponse:

custom = task_template.custom
container = task_template.container
databricks_job = custom["databricks_conf"]
if not databricks_job["new_cluster"].get("docker_image"):
databricks_job["new_cluster"]["docker_image"] = {"url": container.image}
if not databricks_job["new_cluster"].get("spark_conf"):
databricks_job["new_cluster"]["spark_conf"] = custom["spark_conf"]
databricks_job["spark_python_task"] = {
"python_file": custom["applications_path"],
"parameters": container.args,
}

secrets = task_template.security_context.secrets[0]
ctx = flytekit.current_context()
token = ctx.secrets.get(group=secrets.group, key=secrets.key, group_version=secrets.group_version)

response = await send_request(
method="POST",
databricks_job=databricks_job,
databricks_endpoint=custom["databricks_endpoint"],
databricks_instance=custom["databricks_instance"],
token=token,
run_id=None,
is_cancel=False,
)

metadata = Metadata(
databricks_endpoint=custom["databricks_endpoint"],
databricks_instance=custom["databricks_instance"],
token=token,
run_id=str(response["run_id"]),
)

return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8"))

async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse:
metadata = Metadata(**json.loads(resource_meta.decode("utf-8")))

response = await send_request(
method="GET",
databricks_job=None,
databricks_endpoint=metadata.databricks_endpoint,
databricks_instance=metadata.databricks_instance,
token=metadata.token,
run_id=metadata.run_id,
is_cancel=False,
)

cur_state = PENDING
if response["state"].get("result_state"):
cur_state = convert_to_flyte_state(response["state"]["result_state"])

return GetTaskResponse(resource=Resource(state=cur_state))

async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
metadata = Metadata(**json.loads(resource_meta.decode("utf-8")))

await send_request(
method="POST",
databricks_job=None,
databricks_endpoint=metadata.databricks_endpoint,
databricks_instance=metadata.databricks_instance,
token=metadata.token,
run_id=metadata.run_id,
is_cancel=True,
)

return DeleteTaskResponse()


async def send_request(
method: str,
databricks_job: dict,
databricks_endpoint: Optional[str],
databricks_instance: Optional[str],
token: str,
run_id: Optional[str],
is_cancel: bool,
) -> dict:
databricksAPI = "/api/2.0/jobs/runs"
post = "POST"
get = "GET"
data = None
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

if not databricks_endpoint:
databricks_url = f"https://{databricks_instance}{databricksAPI}"
else:
databricks_url = f"{databricks_endpoint}{databricksAPI}"

if is_cancel:
databricks_url += "/cancel"
data = json.dumps({"run_id": run_id})
elif method == post:
databricks_url += "/submit"
try:
data = json.dumps(databricks_job)
except json.JSONDecodeError:
raise ValueError("Failed to marshal databricksJob to JSON")
elif method == get:
databricks_url += f"/get?run_id={run_id}"

async with aiohttp.ClientSession() as session:
if method == post:
async with session.post(databricks_url, headers=headers, data=data) as resp:
return await resp.json()
elif method == get:
async with session.get(databricks_url, headers=headers) as resp:
return await resp.json()


AgentRegistry.register(DatabricksAgent())
50 changes: 50 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flytekit.configuration import DefaultImages, SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.extend import ExecutionState, TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec

from .models import SparkJob, SparkType
Expand Down Expand Up @@ -169,6 +170,55 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
return user_params.builder().add_attr("SPARK_SESSION", self.sess).build()


@dataclass
class DatabricksAgentTask(Spark):
"""
Use this to configure a Databricks task. Task's marked with this will automatically execute
natively onto databricks platform as a distributed execution of spark
For databricks token, you can get it from here. https://docs.databricks.com/dev-tools/api/latest/authentication.html.

Args:
databricks_conf: Databricks job configuration. Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com.
databricks_endpoint: Use for test.
"""

databricks_conf: Optional[Dict[str, Union[str, dict]]] = None
databricks_instance: Optional[str] = None
databricks_endpoint: Optional[str] = None


class PySparkDatabricksTask(AsyncAgentExecutorMixin, PythonFunctionTask[Spark]):
_SPARK_TASK_TYPE = "spark"

def __init__(
self,
task_config: Spark,
task_function: Callable,
**kwargs,
):
self._default_applications_path: Optional[str] = task_config.applications_path

super(PySparkDatabricksTask, self).__init__(
task_config=task_config,
task_type=self._SPARK_TASK_TYPE,
task_function=task_function,
**kwargs,
)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
config = {
"spark_conf": self.task_config.spark_conf,
"applications_path": self.task_config.applications_path,
"databricks_conf": self.task_config.databricks_conf,
"databricks_instance": self.task_config.databricks_instance,
"databricks_endpoint": self.task_config.databricks_endpoint,
}

return config


# Inject the Spark plugin into flytekits dynamic plugin loading system
TaskPlugins.register_pythontask_plugin(Spark, PysparkFunctionTask)
TaskPlugins.register_pythontask_plugin(Databricks, PysparkFunctionTask)
TaskPlugins.register_pythontask_plugin(DatabricksAgentTask, PySparkDatabricksTask)
145 changes: 145 additions & 0 deletions plugins/flytekit-spark/tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import json
from dataclasses import asdict
from datetime import timedelta
from unittest import mock
from unittest.mock import MagicMock

import grpc
import pytest
from aioresponses import aioresponses
from flyteidl.admin.agent_pb2 import SUCCEEDED
from flytekitplugins.spark.agent import Metadata

from flytekit import Secret
from flytekit.extend.backend.base_agent import AgentRegistry
from flytekit.interfaces.cli_identifiers import Identifier
from flytekit.models import literals, task
from flytekit.models.core.identifier import ResourceType
from flytekit.models.security import SecurityContext
from flytekit.models.task import Container, Resources, TaskTemplate


@pytest.mark.asyncio
async def test_databricks_agent():
ctx = MagicMock(spec=grpc.ServicerContext)
agent = AgentRegistry.get_agent(ctx, "spark")

task_id = Identifier(
resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version"
)
task_metadata = task.TaskMetadata(
True,
task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
timedelta(days=1),
literals.RetryStrategy(3),
True,
"0.1.1b0",
"This is deprecated!",
True,
"A",
)
task_config = {
"spark_conf": {
"spark.driver.memory": "1000M",
"spark.executor.memory": "1000M",
"spark.executor.cores": "1",
"spark.executor.instances": "2",
"spark.driver.cores": "1",
},
"applications_path": "dbfs:/entrypoint.py",
"databricks_conf": {
"run_name": "flytekit databricks plugin example",
"new_cluster": {
"spark_version": "12.2.x-scala2.12",
"node_type_id": "n2-highmem-4",
"num_workers": 1,
},
"timeout_seconds": 3600,
"max_retries": 1,
},
"databricks_instance": "test-account.cloud.databricks.com",
"databricks_endpoint": None,
}
container = Container(
image="flyteorg/flytekit:databricks-0.18.0-py3.7",
command=[],
args=[
"pyflyte-execute",
"--inputs",
"{{.input}}",
"--output-prefix",
"{{.outputPrefix}}",
"--raw-output-data-prefix",
"{{.rawOutputDataPrefix}}",
"--checkpoint-path",
"{{.checkpointOutputPrefix}}",
"--prev-checkpoint",
"{{.prevCheckpointPrefix}}",
"--resolver",
"flytekit.core.python_auto_container.default_task_resolver",
"--",
"task-module",
"",
"task-name",
"hello_spark",
],
resources=Resources(
requests=[],
limits=[],
),
env={},
config={},
)

SECRET_GROUP = "token-info"
SECRET_NAME = "token_secret"
mocked_token = "mocked_secret_token"
dummy_template = TaskTemplate(
id=task_id,
custom=task_config,
metadata=task_metadata,
container=container,
interface=None,
type="spark",
security_context=SecurityContext(
secrets=Secret(
group=SECRET_GROUP,
key=SECRET_NAME,
mount_requirement=Secret.MountType.ENV_VAR,
)
),
)
mocked_context = mock.patch("flytekit.current_context", autospec=True).start()
mocked_context.return_value.secrets.get.return_value = mocked_token

metadata_bytes = json.dumps(
asdict(
Metadata(
databricks_endpoint=None,
databricks_instance="test-account.cloud.databricks.com",
token=mocked_token,
run_id="123",
)
)
).encode("utf-8")

mock_create_response = {"run_id": "123"}
mock_get_response = {"run_id": "123", "state": {"result_state": "SUCCESS"}}
mock_delete_response = {}
create_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/submit"
get_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/get?run_id=123"
delete_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/cancel"
with aioresponses() as mocked:
mocked.post(create_url, status=200, payload=mock_create_response)
res = await agent.async_create(ctx, "/tmp", dummy_template, None)
assert res.resource_meta == metadata_bytes

mocked.get(get_url, status=200, payload=mock_get_response)
res = await agent.async_get(ctx, metadata_bytes)
assert res.resource.state == SUCCEEDED
assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl()

mocked.post(delete_url, status=200, payload=mock_delete_response)
await agent.async_delete(ctx, metadata_bytes)

mock.patch.stopall()
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
"kubernetes>=12.0.1",
"rich",
"rich_click",
"aiohttp",
Future-Outlier marked this conversation as resolved.
Show resolved Hide resolved
],
extras_require=extras_require,
scripts=[
Expand Down