From d1491f0eb0c9dfd7b2ab835a07dd38ed16fa6e7e Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 16 Aug 2023 20:47:12 +0800 Subject: [PATCH 01/24] databricks agent v1 Signed-off-by: Future Outlier --- .../flytekitplugins/spark/__init__.py | 3 +- .../flytekitplugins/spark/agent.py | 143 ++++++++++++++++++ .../flytekitplugins/spark/task.py | 80 ++++++++++ 3 files changed, 225 insertions(+), 1 deletion(-) create mode 100644 plugins/flytekit-spark/flytekitplugins/spark/agent.py diff --git a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py index e769540aea..9cee417247 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py @@ -17,7 +17,8 @@ from flytekit.configuration import internal as _internal +from .agent import DatabricksAgent from .pyspark_transformers import PySparkPipelineModelTransformer from .schema import SparkDataFrameSchemaReader, SparkDataFrameSchemaWriter, SparkDataFrameTransformer # noqa from .sd_transformers import ParquetToSparkDecodingHandler, SparkToParquetEncodingHandler -from .task import Databricks, Spark, new_spark_session # noqa +from .task import Databricks, DatabricksAgentTask, Spark, new_spark_session # noqa diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py new file mode 100644 index 0000000000..649862bc01 --- /dev/null +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -0,0 +1,143 @@ +import json +from dataclasses import asdict, dataclass +from typing import Optional + +import grpc +from flyteidl.admin.agent_pb2 import ( + PENDING, + SUCCEEDED, + CreateTaskResponse, + DeleteTaskResponse, + GetTaskResponse, + Resource, +) + +# for databricks +from databricks.sdk import WorkspaceClient +from databricks.sdk.service import jobs + +from flytekit import FlyteContextManager, StructuredDataset, logger +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import ( + AgentBase, + AgentRegistry, +) +from flytekit.models import literals +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate +from flytekit.models.types import LiteralType, StructuredDatasetType + + +@dataclass +class Metadata: + cluster_id: str + host: str + token: str + run_id: int + + +class DatabricksAgent(AgentBase): + def __init__(self): + super().__init__(task_type="spark", asynchronous=False) + + def create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + for attr_name, attr_value in vars(task_template).items(): + print(f"Type:{type(attr_value)}, {attr_name}: {attr_value}") + + custom = task_template.custom + for attr_name, attr_value in custom.items(): + print(f"{attr_name}: {attr_value}") + + w = WorkspaceClient(host=custom["host"], token=custom["token"]) + + clstr = w.clusters.create_and_wait( + cluster_name = custom["cluster_name"], + spark_version = custom["spark_version"], + node_type_id = custom["node_type_id"], + autotermination_minutes = custom["autotermination_minutes"], + num_workers = custom["num_workers"] + ) + cluster_id = clstr.cluster_id # important metadata + + tasks = [ + jobs.Task( + description=custom["description"], + existing_cluster_id=cluster_id, + spark_python_task=jobs.SparkPythonTask(python_file=custom["python_file"]), + task_key=custom["task_key"], # metadata 
+ timeout_seconds=custom["timeout_seconds"], # metadata + ) + ] + + run = w.jobs.submit( + name=custom["cluster_name"], # metadata + tasks=tasks, # tasks + ).result() + + # metadata + metadata = Metadata( + cluster_id=cluster_id, + host=custom["host"], + token=custom["token"], + run_id=run.tasks[0].run_id, + ) + + return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) + + def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + + w = WorkspaceClient(host=metadata.host, + token=metadata.token,) + job = w.jobs.get_run_output(metadata.run_id) + + if job.error: # have already checked databricks sdk + logger.error(job.errors.__str__()) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(job.errors.__str__()) + return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE)) + + if job.metadata.state.result_state == jobs.RunResultState.SUCCESS: + cur_state = SUCCEEDED + else: + # TODO: Discuss with Kevin for the state, considering mapping technique + cur_state = PENDING + + res = None + + if cur_state == SUCCEEDED: + ctx = FlyteContextManager.current_context() + # output page_url and task_output + if job.metadata.run_page_url: + output_location = job.metadata.run_page_url + res = literals.LiteralMap( + { + "results": TypeEngine.to_literal( + ctx, + StructuredDataset(uri=output_location), + StructuredDataset, + LiteralType(structured_dataset_type=StructuredDatasetType(format="")), + ) + } + ).to_flyte_idl() + w.clusters.permanent_delete(cluster_id=metadata.cluster_id) + + return GetTaskResponse(resource=Resource(state=cur_state, outputs=res)) + + def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + w = WorkspaceClient(host=metadata.host, + token=metadata.token, + ) + w.jobs.delete_run(metadata.run_id) + w.clusters.permanent_delete(cluster_id=metadata.cluster_id) + return DeleteTaskResponse() + + +AgentRegistry.register(DatabricksAgent()) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 0d8ecd5b6e..aa0b556865 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -169,6 +169,86 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: return user_params.builder().add_attr("SPARK_SESSION", self.sess).build() +@dataclass +class DatabricksAgentTask(Spark): + """ + Use this to configure a Databricks task. Task's marked with this will automatically execute + natively onto databricks platform as a distributed execution of spark + + Args: + databricks_conf: Databricks job configuration. Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure + databricks_token: Databricks access token. https://docs.databricks.com/dev-tools/api/latest/authentication.html. + databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. 
+ """ + + databricks_conf: Optional[Dict[str, Union[str, dict]]] = None + databricks_token: Optional[str] = None + databricks_instance: Optional[str] = None + + +class PySparkDatabricksTask(AsyncAgentExecutorMixin, PythonFunctionTask[Spark]): + _SPARK_TASK_TYPE = "spark" + + def __init__( + self, + task_config: Spark, + task_function: Callable, + container_image: Optional[Union[str, ImageSpec]] = None, + **kwargs, + ): + self.sess: Optional[SparkSession] = None + self._default_executor_path: Optional[str] = task_config.executor_path + self._default_applications_path: Optional[str] = task_config.applications_path + + super(PySparkDatabricksTask, self).__init__( + task_config=task_config, + task_type=self._SPARK_TASK_TYPE, + task_function=task_function, + container_image=container_image, + **kwargs, + ) + + """ + For Cluster + cluster_name: str, from task_config.databricks_conf.run_name + spark_conf:Optional[Dict[str, str]] = None, from task_config.spark_conf + spark_version: str, from task_config.databricks_conf.spark_version + node_type_id: str, from task_config.databricks_conf.node_type_id + num_workers: int, from task_config.databricks_conf.num_workers + For Job + description: str, from task_config.databricks_conf.description + existing_cluster_id: str (you should create int agent, and get it by metadata) + python_file: str, from task_config.databricks_conf.description + task_key: str, from task_config.databricks_conf.description + timeout_seconds: int, from task_config.databricks_conf.description + max_retries: int, from task_config.databricks_conf.description + """ + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + databricks_conf = getattr(self.task_config, "databricks_conf", {}) + new_cluster_conf = databricks_conf.get("new_cluster", {}) + config = { + "host": self.task_config.databricks_instance, + "token": self.task_config.databricks_token, + "cluster_name": databricks_conf.get("run_name"), + "spark_conf": getattr(self.task_config, "spark_conf", {}), + "spark_version": new_cluster_conf.get("spark_version"), + "node_type_id": new_cluster_conf.get("node_type_id"), + "autotermination_minutes": new_cluster_conf.get("autotermination_minutes"), + "num_workers": new_cluster_conf.get("num_workers"), + "timeout_seconds": databricks_conf.get("timeout_seconds"), + "max_retries": databricks_conf.get("max_retries"), + "description": databricks_conf.get("description"), + "python_file": databricks_conf.get("python_file"), + "task_key": databricks_conf.get("task_key"), + } + + s = Struct() + s.update(config) + return json_format.MessageToDict(s) + + # Inject the Spark plugin into flytekits dynamic plugin loading system TaskPlugins.register_pythontask_plugin(Spark, PysparkFunctionTask) TaskPlugins.register_pythontask_plugin(Databricks, PysparkFunctionTask) +TaskPlugins.register_pythontask_plugin(DatabricksAgentTask, PySparkDatabricksTask) From 5e5492d8d39eca3c32dfefa595927c3b59ec2b53 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 16 Aug 2023 22:47:05 +0800 Subject: [PATCH 02/24] revision for docker image Signed-off-by: Future Outlier --- .../flytekitplugins/spark/agent.py | 56 +++++++++++-------- .../flytekitplugins/spark/task.py | 20 +++++++ 2 files changed, 54 insertions(+), 22 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 649862bc01..738be4d605 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ 
b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -3,6 +3,10 @@ from typing import Optional import grpc + +from databricks.sdk import WorkspaceClient +from databricks.sdk.service import jobs +from databricks.sdk.service.compute import DockerBasicAuth, DockerImage from flyteidl.admin.agent_pb2 import ( PENDING, SUCCEEDED, @@ -12,16 +16,9 @@ Resource, ) -# for databricks -from databricks.sdk import WorkspaceClient -from databricks.sdk.service import jobs - from flytekit import FlyteContextManager, StructuredDataset, logger from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import ( - AgentBase, - AgentRegistry, -) +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry from flytekit.models import literals from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -55,21 +52,33 @@ def create( print(f"{attr_name}: {attr_value}") w = WorkspaceClient(host=custom["host"], token=custom["token"]) - + # to be done, docker image, azure, aws, gcp + docker_image_conf = custom["docker_image_conf"] + basic_auth_conf = docker_image_conf.get("basic_auth", {}) + auth = DockerBasicAuth( + username=basic_auth_conf.get("username"), + password=basic_auth_conf.get("password"), + ) + docker_image = DockerImage( + url=docker_image_conf.get("url"), + basic_auth=auth, + ) + clstr = w.clusters.create_and_wait( - cluster_name = custom["cluster_name"], - spark_version = custom["spark_version"], - node_type_id = custom["node_type_id"], - autotermination_minutes = custom["autotermination_minutes"], - num_workers = custom["num_workers"] - ) - cluster_id = clstr.cluster_id # important metadata + cluster_name=custom["cluster_name"], + docker_image=docker_image, + spark_version=custom["spark_version"], + node_type_id=custom["node_type_id"], + autotermination_minutes=custom["autotermination_minutes"], + num_workers=custom["num_workers"], + ) + cluster_id = clstr.cluster_id # important metadata tasks = [ jobs.Task( description=custom["description"], existing_cluster_id=cluster_id, - spark_python_task=jobs.SparkPythonTask(python_file=custom["python_file"]), + spark_python_task=jobs.SparkPythonTask(python_file=custom["python_file"]), task_key=custom["task_key"], # metadata timeout_seconds=custom["timeout_seconds"], # metadata ) @@ -93,8 +102,10 @@ def create( def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - w = WorkspaceClient(host=metadata.host, - token=metadata.token,) + w = WorkspaceClient( + host=metadata.host, + token=metadata.token, + ) job = w.jobs.get_run_output(metadata.run_id) if job.error: # have already checked databricks sdk @@ -132,9 +143,10 @@ def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskRes def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - w = WorkspaceClient(host=metadata.host, - token=metadata.token, - ) + w = WorkspaceClient( + host=metadata.host, + token=metadata.token, + ) w.jobs.delete_run(metadata.run_id) w.clusters.permanent_delete(cluster_id=metadata.cluster_id) return DeleteTaskResponse() diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index aa0b556865..3fb674d0bb 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py 
@@ -2,13 +2,16 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Union, cast +from google.protobuf import json_format from google.protobuf.json_format import MessageToDict +from google.protobuf.struct_pb2 import Struct from pyspark.sql import SparkSession from flytekit import FlyteContextManager, PythonFunctionTask from flytekit.configuration import DefaultImages, SerializationSettings from flytekit.core.context_manager import ExecutionParameters from flytekit.extend import ExecutionState, TaskPlugins +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin from flytekit.image_spec import ImageSpec from .models import SparkJob, SparkType @@ -225,12 +228,29 @@ def __init__( """ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + # from databricks.sdk.service.compute import DockerBasicAuth, DockerImage + databricks_conf = getattr(self.task_config, "databricks_conf", {}) new_cluster_conf = databricks_conf.get("new_cluster", {}) + # docker_image_conf = databricks_conf.get("docker_image", {}) + # basic_auth_conf = docker_image_conf.get("basic_auth", {}) + # auth = DockerBasicAuth( + # username=basic_auth_conf.get("username"), + # password=basic_auth_conf.get("password"), + # ) + # docker_image = DockerImage( + # url=docker_image_conf.get("url"), + # basic_auth=auth, + # ) + + # AwsAttributes + # GcpAttributes + # AzureAttributes config = { "host": self.task_config.databricks_instance, "token": self.task_config.databricks_token, "cluster_name": databricks_conf.get("run_name"), + "docker_image_conf": databricks_conf.get("docker_image", {}), "spark_conf": getattr(self.task_config, "spark_conf", {}), "spark_version": new_cluster_conf.get("spark_version"), "node_type_id": new_cluster_conf.get("node_type_id"), From 003fa8b37a65c62c8564efc9e0aa8accc9780ecf Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Thu, 17 Aug 2023 10:29:16 +0800 Subject: [PATCH 03/24] rerun make lint and make fmt Signed-off-by: Future Outlier --- plugins/flytekit-spark/flytekitplugins/spark/agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 738be4d605..7afd3e86fa 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -3,7 +3,6 @@ from typing import Optional import grpc - from databricks.sdk import WorkspaceClient from databricks.sdk.service import jobs from databricks.sdk.service.compute import DockerBasicAuth, DockerImage From 7c52cbaf16689c93cf840aec07ae6f43647e3758 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Thu, 17 Aug 2023 10:41:35 +0800 Subject: [PATCH 04/24] add PERMANENT_FAILURE Signed-off-by: Future Outlier --- plugins/flytekit-spark/flytekitplugins/spark/agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 7afd3e86fa..2c8e279aef 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -8,6 +8,7 @@ from databricks.sdk.service.compute import DockerBasicAuth, DockerImage from flyteidl.admin.agent_pb2 import ( PENDING, + PERMANENT_FAILURE, SUCCEEDED, CreateTaskResponse, DeleteTaskResponse, From 442be790d22305edb880ebb05594a9b6e37520ab Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 18 Aug 2023 14:51:02 +0800 Subject: [PATCH 05/24] REST API for 
databricks agent v1, async get function is unsure Signed-off-by: Future Outlier --- .../flytekitplugins/spark/agent.py | 230 +++++++++++------- .../flytekitplugins/spark/task.py | 59 +---- 2 files changed, 149 insertions(+), 140 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 2c8e279aef..0d7b4d97e2 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -2,10 +2,8 @@ from dataclasses import asdict, dataclass from typing import Optional +import aiohttp import grpc -from databricks.sdk import WorkspaceClient -from databricks.sdk.service import jobs -from databricks.sdk.service.compute import DockerBasicAuth, DockerImage from flyteidl.admin.agent_pb2 import ( PENDING, PERMANENT_FAILURE, @@ -16,7 +14,7 @@ Resource, ) -from flytekit import FlyteContextManager, StructuredDataset, logger +from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry from flytekit.models import literals @@ -27,129 +25,183 @@ @dataclass class Metadata: - cluster_id: str - host: str + databricks_endpoint: Optional[str] + databricks_instance: Optional[str] token: str - run_id: int + run_id: str class DatabricksAgent(AgentBase): def __init__(self): - super().__init__(task_type="spark", asynchronous=False) + super().__init__(task_type="spark") - def create( + async def async_create( self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, ) -> CreateTaskResponse: - for attr_name, attr_value in vars(task_template).items(): - print(f"Type:{type(attr_value)}, {attr_name}: {attr_value}") custom = task_template.custom + container = task_template.container + print("@@@ custom") for attr_name, attr_value in custom.items(): + print(f"{type(attr_value)} {attr_name}: {attr_value}") + """ + add 1.docker image 2.spark config 3.spark python task + all of them into databricks_job + """ + + databricks_job = custom["databricks_conf"] + # note: current docker image does not support basic auth + # todo image and arguments + # if not databricks_job["new_cluster"].get("docker_image"): + # databricks_job["new_cluster"]["docker_image"] = {} + # databricks_job["new_cluster"]["docker_image"]["url"] = container.image + if not databricks_job["new_cluster"].get("spark_conf"): + databricks_job["new_cluster"]["spark_conf"] = custom["spark_conf"] + databricks_job["spark_python_task"] = { + "python_file": custom["applications_path"], + "parameters": container.args, + } + + print("@@@ databricks_job") + for attr_name, attr_value in databricks_job.items(): print(f"{attr_name}: {attr_value}") - w = WorkspaceClient(host=custom["host"], token=custom["token"]) - # to be done, docker image, azure, aws, gcp - docker_image_conf = custom["docker_image_conf"] - basic_auth_conf = docker_image_conf.get("basic_auth", {}) - auth = DockerBasicAuth( - username=basic_auth_conf.get("username"), - password=basic_auth_conf.get("password"), - ) - docker_image = DockerImage( - url=docker_image_conf.get("url"), - basic_auth=auth, + # with open("/mnt/c/code/dev/example/plugins/databricks.json", "w") as f: + # f.write(str(databricks_job)) + # json.dump(databricks_job, json_file) + response = await build_request( + method="POST", + databricks_job=databricks_job, + databricks_endpoint=custom["databricks_endpoint"], + 
databricks_instance=custom["databricks_instance"], + token=custom["token"], + run_id="", + is_cancel=False, ) - clstr = w.clusters.create_and_wait( - cluster_name=custom["cluster_name"], - docker_image=docker_image, - spark_version=custom["spark_version"], - node_type_id=custom["node_type_id"], - autotermination_minutes=custom["autotermination_minutes"], - num_workers=custom["num_workers"], - ) - cluster_id = clstr.cluster_id # important metadata - - tasks = [ - jobs.Task( - description=custom["description"], - existing_cluster_id=cluster_id, - spark_python_task=jobs.SparkPythonTask(python_file=custom["python_file"]), - task_key=custom["task_key"], # metadata - timeout_seconds=custom["timeout_seconds"], # metadata - ) - ] - - run = w.jobs.submit( - name=custom["cluster_name"], # metadata - tasks=tasks, # tasks - ).result() - - # metadata + print("response:", response) + print(type(response["run_id"])) metadata = Metadata( - cluster_id=cluster_id, - host=custom["host"], + databricks_endpoint=custom["databricks_endpoint"], + databricks_instance=custom["databricks_instance"], token=custom["token"], - run_id=run.tasks[0].run_id, + run_id=str(response["run_id"]), ) - return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) - def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - - w = WorkspaceClient( - host=metadata.host, + response = await build_request( + method="GET", + databricks_job=None, + databricks_endpoint=metadata.databricks_endpoint, + databricks_instance=metadata.databricks_instance, token=metadata.token, + run_id=metadata.run_id, + is_cancel=False, ) - job = w.jobs.get_run_output(metadata.run_id) - - if job.error: # have already checked databricks sdk - logger.error(job.errors.__str__()) - context.set_code(grpc.StatusCode.INTERNAL) - context.set_details(job.errors.__str__()) - return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE)) - - if job.metadata.state.result_state == jobs.RunResultState.SUCCESS: - cur_state = SUCCEEDED - else: - # TODO: Discuss with Kevin for the state, considering mapping technique - cur_state = PENDING + print("response:", response) + print("@@@ response") + for attr_name, attr_value in response.items(): + print(f"{type(attr_value)} {attr_name}: {attr_value}") + + """ + jobState := data["state"].(map[string]interface{}) + message := fmt.Sprintf("%s", jobState["state_message"]) + jobID := fmt.Sprintf("%.0f", data["job_id"]) + lifeCycleState := fmt.Sprintf("%s", jobState["life_cycle_state"]) + resultState := fmt.Sprintf("%s", jobState["result_state"]) + + + response->state->state_message + ->life_cycle_state + ->result_state + ->job_id + """ + + """ + cur_state = response["state"]["result_state"] + + SUCCESS + FAILED + TIMEDOUT + CANCELED + """ + + cur_state = PENDING + if response["state"].get("result_state"): + if response["state"]["result_state"] == "SUCCESS": + cur_state = SUCCEEDED + else: + context.set_code(grpc.StatusCode.INTERNAL) + return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE)) res = None - if cur_state == SUCCEEDED: ctx = FlyteContextManager.current_context() # output page_url and task_output - if job.metadata.run_page_url: - output_location = job.metadata.run_page_url - res = literals.LiteralMap( - { - "results": TypeEngine.to_literal( - ctx, - StructuredDataset(uri=output_location), - 
StructuredDataset, - LiteralType(structured_dataset_type=StructuredDatasetType(format="")), - ) - } - ).to_flyte_idl() - w.clusters.permanent_delete(cluster_id=metadata.cluster_id) - + res = literals.LiteralMap({}).to_flyte_idl() return GetTaskResponse(resource=Resource(state=cur_state, outputs=res)) - def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - w = WorkspaceClient( - host=metadata.host, + await build_request( + method="POST", + databricks_job=None, + databricks_endpoint=metadata.databricks_endpoint, + databricks_instance=metadata.databricks_instance, token=metadata.token, + run_id=metadata.run_id, + is_cancel=True, ) - w.jobs.delete_run(metadata.run_id) - w.clusters.permanent_delete(cluster_id=metadata.cluster_id) return DeleteTaskResponse() +async def build_request( + method: str, + databricks_job: dict, + databricks_endpoint: str, + databricks_instance: str, + token: str, + run_id: str, + is_cancel: bool, +) -> dict: + databricksAPI = "/api/2.0/jobs/runs" + post = "POST" + + # Build the databricks URL + if not databricks_endpoint: + databricks_url = f"https://{databricks_instance}{databricksAPI}" + else: + databricks_url = f"{databricks_endpoint}{databricksAPI}" + + data = None + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + if is_cancel: + databricks_url += "/cancel" + data = json.dumps({"run_id": run_id}) + elif method == post: + databricks_url += "/submit" + try: + data = json.dumps(databricks_job) + except json.JSONDecodeError: + raise ValueError("Failed to marshal databricksJob to JSON") + else: + databricks_url += f"/get?run_id={run_id}" + + print(databricks_url) + async with aiohttp.ClientSession() as session: + if method == post: + async with session.post(databricks_url, headers=headers, data=data) as resp: + return await resp.json() + else: + async with session.get(databricks_url, headers=headers) as resp: + return await resp.json() + + AgentRegistry.register(DatabricksAgent()) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 3fb674d0bb..65abd803fe 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -187,6 +187,7 @@ class DatabricksAgentTask(Spark): databricks_conf: Optional[Dict[str, Union[str, dict]]] = None databricks_token: Optional[str] = None databricks_instance: Optional[str] = None + databricks_endpoint: Optional[str] = None class PySparkDatabricksTask(AsyncAgentExecutorMixin, PythonFunctionTask[Spark]): @@ -211,61 +212,17 @@ def __init__( **kwargs, ) - """ - For Cluster - cluster_name: str, from task_config.databricks_conf.run_name - spark_conf:Optional[Dict[str, str]] = None, from task_config.spark_conf - spark_version: str, from task_config.databricks_conf.spark_version - node_type_id: str, from task_config.databricks_conf.node_type_id - num_workers: int, from task_config.databricks_conf.num_workers - For Job - description: str, from task_config.databricks_conf.description - existing_cluster_id: str (you should create int agent, and get it by metadata) - python_file: str, from task_config.databricks_conf.description - task_key: str, from task_config.databricks_conf.description - timeout_seconds: int, from task_config.databricks_conf.description - 
max_retries: int, from task_config.databricks_conf.description - """ - def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - # from databricks.sdk.service.compute import DockerBasicAuth, DockerImage - - databricks_conf = getattr(self.task_config, "databricks_conf", {}) - new_cluster_conf = databricks_conf.get("new_cluster", {}) - # docker_image_conf = databricks_conf.get("docker_image", {}) - # basic_auth_conf = docker_image_conf.get("basic_auth", {}) - # auth = DockerBasicAuth( - # username=basic_auth_conf.get("username"), - # password=basic_auth_conf.get("password"), - # ) - # docker_image = DockerImage( - # url=docker_image_conf.get("url"), - # basic_auth=auth, - # ) - - # AwsAttributes - # GcpAttributes - # AzureAttributes config = { - "host": self.task_config.databricks_instance, - "token": self.task_config.databricks_token, - "cluster_name": databricks_conf.get("run_name"), - "docker_image_conf": databricks_conf.get("docker_image", {}), - "spark_conf": getattr(self.task_config, "spark_conf", {}), - "spark_version": new_cluster_conf.get("spark_version"), - "node_type_id": new_cluster_conf.get("node_type_id"), - "autotermination_minutes": new_cluster_conf.get("autotermination_minutes"), - "num_workers": new_cluster_conf.get("num_workers"), - "timeout_seconds": databricks_conf.get("timeout_seconds"), - "max_retries": databricks_conf.get("max_retries"), - "description": databricks_conf.get("description"), - "python_file": databricks_conf.get("python_file"), - "task_key": databricks_conf.get("task_key"), + "spark_conf": getattr(self.task_config, "spark_conf", None), + "applications_path": getattr(self.task_config, "applications_path", None), + "databricks_conf": getattr(self.task_config, "databricks_conf", None), + "token": getattr(self.task_config, "databricks_token", None), + "databricks_instance": getattr(self.task_config, "databricks_instance", None), + "databricks_endpoint": getattr(self.task_config, "databricks_endpoint", None), } - s = Struct() - s.update(config) - return json_format.MessageToDict(s) + return config # Inject the Spark plugin into flytekits dynamic plugin loading system From 6e51c36e2ac07ac83a663982450c50c1f4cca82f Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 18 Aug 2023 15:02:07 +0800 Subject: [PATCH 06/24] add aiohttp in setup.py Signed-off-by: Future Outlier --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index b758a3178d..a9174501ac 100644 --- a/setup.py +++ b/setup.py @@ -75,6 +75,7 @@ "kubernetes>=12.0.1", "rich", "rich_click", + "aiohttp", ], extras_require=extras_require, scripts=[ From b9c98f1dbd56f679abaa9abd6a2a999c738f5b27 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Sun, 20 Aug 2023 15:39:56 +0800 Subject: [PATCH 07/24] databricks agent with getting token by secret Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 4 +- .../flytekitplugins/spark/agent.py | 110 ++++-------------- .../flytekitplugins/spark/task.py | 20 ++-- 3 files changed, 33 insertions(+), 101 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index a6f21e1b0e..00388158a9 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -136,9 +136,9 @@ def convert_to_flyte_state(state: str) -> State: Convert the state from the agent to the state in flyte. 
""" state = state.lower() - if state in ["failed"]: + if state in ["failed", "timedout", "canceled"]: return RETRYABLE_FAILURE - elif state in ["done", "succeeded"]: + elif state in ["done", "succeeded", "success"]: return SUCCEEDED elif state in ["running"]: return RUNNING diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 0d7b4d97e2..1f98f078f6 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -4,23 +4,12 @@ import aiohttp import grpc -from flyteidl.admin.agent_pb2 import ( - PENDING, - PERMANENT_FAILURE, - SUCCEEDED, - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) - -from flytekit import FlyteContextManager -from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry -from flytekit.models import literals +from flyteidl.admin.agent_pb2 import PENDING, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource + +import flytekit +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flytekit.models.types import LiteralType, StructuredDatasetType @dataclass @@ -45,20 +34,9 @@ async def async_create( custom = task_template.custom container = task_template.container - print("@@@ custom") - for attr_name, attr_value in custom.items(): - print(f"{type(attr_value)} {attr_name}: {attr_value}") - """ - add 1.docker image 2.spark config 3.spark python task - all of them into databricks_job - """ - databricks_job = custom["databricks_conf"] - # note: current docker image does not support basic auth - # todo image and arguments - # if not databricks_job["new_cluster"].get("docker_image"): - # databricks_job["new_cluster"]["docker_image"] = {} - # databricks_job["new_cluster"]["docker_image"]["url"] = container.image + if not databricks_job["new_cluster"].get("docker_image"): + databricks_job["new_cluster"]["docker_image"] = {"url": container.image} if not databricks_job["new_cluster"].get("spark_conf"): databricks_job["new_cluster"]["spark_conf"] = custom["spark_conf"] databricks_job["spark_python_task"] = { @@ -66,36 +44,33 @@ async def async_create( "parameters": container.args, } - print("@@@ databricks_job") - for attr_name, attr_value in databricks_job.items(): - print(f"{attr_name}: {attr_value}") + secrets = task_template.security_context.secrets[0] + ctx = flytekit.current_context() + token = ctx.secrets.get(group=secrets.group, key=secrets.key, group_version=secrets.group_version) - # with open("/mnt/c/code/dev/example/plugins/databricks.json", "w") as f: - # f.write(str(databricks_job)) - # json.dump(databricks_job, json_file) - response = await build_request( + response = await send_request( method="POST", databricks_job=databricks_job, databricks_endpoint=custom["databricks_endpoint"], databricks_instance=custom["databricks_instance"], - token=custom["token"], + token=token, run_id="", is_cancel=False, ) - print("response:", response) - print(type(response["run_id"])) metadata = Metadata( databricks_endpoint=custom["databricks_endpoint"], databricks_instance=custom["databricks_instance"], - token=custom["token"], + token=token, run_id=str(response["run_id"]), ) + return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) async def async_get(self, context: 
grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - response = await build_request( + + response = await send_request( method="GET", databricks_job=None, databricks_endpoint=metadata.databricks_endpoint, @@ -104,52 +79,17 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - run_id=metadata.run_id, is_cancel=False, ) - print("response:", response) - print("@@@ response") - for attr_name, attr_value in response.items(): - print(f"{type(attr_value)} {attr_name}: {attr_value}") - - """ - jobState := data["state"].(map[string]interface{}) - message := fmt.Sprintf("%s", jobState["state_message"]) - jobID := fmt.Sprintf("%.0f", data["job_id"]) - lifeCycleState := fmt.Sprintf("%s", jobState["life_cycle_state"]) - resultState := fmt.Sprintf("%s", jobState["result_state"]) - - - response->state->state_message - ->life_cycle_state - ->result_state - ->job_id - """ - - """ - cur_state = response["state"]["result_state"] - - SUCCESS - FAILED - TIMEDOUT - CANCELED - """ cur_state = PENDING if response["state"].get("result_state"): - if response["state"]["result_state"] == "SUCCESS": - cur_state = SUCCEEDED - else: - context.set_code(grpc.StatusCode.INTERNAL) - return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE)) - - res = None - if cur_state == SUCCEEDED: - ctx = FlyteContextManager.current_context() - # output page_url and task_output - res = literals.LiteralMap({}).to_flyte_idl() - return GetTaskResponse(resource=Resource(state=cur_state, outputs=res)) + cur_state = convert_to_flyte_state(response["state"]["result_state"]) + + return GetTaskResponse(resource=Resource(state=cur_state)) async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - await build_request( + + await send_request( method="POST", databricks_job=None, databricks_endpoint=metadata.databricks_endpoint, @@ -158,10 +98,11 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes run_id=metadata.run_id, is_cancel=True, ) + return DeleteTaskResponse() -async def build_request( +async def send_request( method: str, databricks_job: dict, databricks_endpoint: str, @@ -172,16 +113,14 @@ async def build_request( ) -> dict: databricksAPI = "/api/2.0/jobs/runs" post = "POST" + data = None + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} - # Build the databricks URL if not databricks_endpoint: databricks_url = f"https://{databricks_instance}{databricksAPI}" else: databricks_url = f"{databricks_endpoint}{databricksAPI}" - data = None - headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} - if is_cancel: databricks_url += "/cancel" data = json.dumps({"run_id": run_id}) @@ -194,7 +133,6 @@ async def build_request( else: databricks_url += f"/get?run_id={run_id}" - print(databricks_url) async with aiohttp.ClientSession() as session: if method == post: async with session.post(databricks_url, headers=headers, data=data) as resp: diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 65abd803fe..0a2a23b65a 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -2,9 +2,7 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Union, cast -from 
google.protobuf import json_format from google.protobuf.json_format import MessageToDict -from google.protobuf.struct_pb2 import Struct from pyspark.sql import SparkSession from flytekit import FlyteContextManager, PythonFunctionTask @@ -177,11 +175,12 @@ class DatabricksAgentTask(Spark): """ Use this to configure a Databricks task. Task's marked with this will automatically execute natively onto databricks platform as a distributed execution of spark + For databricks token, you can get it from here. https://docs.databricks.com/dev-tools/api/latest/authentication.html. Args: databricks_conf: Databricks job configuration. Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure - databricks_token: Databricks access token. https://docs.databricks.com/dev-tools/api/latest/authentication.html. databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. + databricks_endpoint: Use for test. """ databricks_conf: Optional[Dict[str, Union[str, dict]]] = None @@ -197,29 +196,24 @@ def __init__( self, task_config: Spark, task_function: Callable, - container_image: Optional[Union[str, ImageSpec]] = None, **kwargs, ): - self.sess: Optional[SparkSession] = None - self._default_executor_path: Optional[str] = task_config.executor_path self._default_applications_path: Optional[str] = task_config.applications_path super(PySparkDatabricksTask, self).__init__( task_config=task_config, task_type=self._SPARK_TASK_TYPE, task_function=task_function, - container_image=container_image, **kwargs, ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: config = { - "spark_conf": getattr(self.task_config, "spark_conf", None), - "applications_path": getattr(self.task_config, "applications_path", None), - "databricks_conf": getattr(self.task_config, "databricks_conf", None), - "token": getattr(self.task_config, "databricks_token", None), - "databricks_instance": getattr(self.task_config, "databricks_instance", None), - "databricks_endpoint": getattr(self.task_config, "databricks_endpoint", None), + "spark_conf": self.task_config.spark_conf, + "applications_path": self.task_config.applications_path, + "databricks_conf": self.task_config.databricks_conf, + "databricks_instance": self.task_config.databricks_instance, + "databricks_endpoint": self.task_config.databricks_endpoint, } return config From 23c550171a94f50962460d41c2490ea13b1c7b44 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Sun, 20 Aug 2023 15:51:57 +0800 Subject: [PATCH 08/24] revise the code and delete the databricks_token member Signed-off-by: Future Outlier --- .../flytekit-spark/flytekitplugins/spark/agent.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 1f98f078f6..bc8a125f9e 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -54,7 +54,7 @@ async def async_create( databricks_endpoint=custom["databricks_endpoint"], databricks_instance=custom["databricks_instance"], token=token, - run_id="", + run_id=None, is_cancel=False, ) @@ -105,14 +105,15 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes async def send_request( method: str, databricks_job: dict, - databricks_endpoint: str, - databricks_instance: str, + databricks_endpoint: Optional[str], + databricks_instance: Optional[str], token: str, - 
run_id: str, + run_id: Optional[str], is_cancel: bool, ) -> dict: databricksAPI = "/api/2.0/jobs/runs" post = "POST" + get = "GET" data = None headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} @@ -130,14 +131,14 @@ async def send_request( data = json.dumps(databricks_job) except json.JSONDecodeError: raise ValueError("Failed to marshal databricksJob to JSON") - else: + elif method == get: databricks_url += f"/get?run_id={run_id}" async with aiohttp.ClientSession() as session: if method == post: async with session.post(databricks_url, headers=headers, data=data) as resp: return await resp.json() - else: + elif method == get: async with session.get(databricks_url, headers=headers) as resp: return await resp.json() From 7a98b19129a59aee534c98fea5b46a2743662eb7 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 22 Aug 2023 14:28:19 +0800 Subject: [PATCH 09/24] remove databricks_token member Signed-off-by: Future Outlier --- plugins/flytekit-spark/flytekitplugins/spark/task.py | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 0a2a23b65a..b00d2d0e19 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -184,7 +184,6 @@ class DatabricksAgentTask(Spark): """ databricks_conf: Optional[Dict[str, Union[str, dict]]] = None - databricks_token: Optional[str] = None databricks_instance: Optional[str] = None databricks_endpoint: Optional[str] = None From fa2059de677bbdf1cb6aa3975ef2c55d4c9ae165 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 22 Aug 2023 15:04:48 +0800 Subject: [PATCH 10/24] add databricks agent test Signed-off-by: Future Outlier --- plugins/flytekit-spark/tests/test_agent.py | 145 +++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 plugins/flytekit-spark/tests/test_agent.py diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py new file mode 100644 index 0000000000..30b8567e91 --- /dev/null +++ b/plugins/flytekit-spark/tests/test_agent.py @@ -0,0 +1,145 @@ +import json +from dataclasses import asdict +from datetime import timedelta +from unittest import mock +from unittest.mock import MagicMock + +import grpc +import pytest +from aioresponses import aioresponses +from flyteidl.admin.agent_pb2 import SUCCEEDED +from flytekitplugins.spark.agent import Metadata + +from flytekit import Secret +from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals, task +from flytekit.models.core.identifier import ResourceType +from flytekit.models.security import SecurityContext +from flytekit.models.task import Container, Resources, TaskTemplate + + +@pytest.mark.asyncio +async def test_databricks_agent(): + ctx = MagicMock(spec=grpc.ServicerContext) + agent = AgentRegistry.get_agent(ctx, "spark") + + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + ) + task_config = { + "spark_conf": { + "spark.driver.memory": "1000M", + "spark.executor.memory": "1000M", + "spark.executor.cores": "1", + 
"spark.executor.instances": "2", + "spark.driver.cores": "1", + }, + "applications_path": "dbfs:/entrypoint.py", + "databricks_conf": { + "run_name": "flytekit databricks plugin example", + "new_cluster": { + "spark_version": "12.2.x-scala2.12", + "node_type_id": "n2-highmem-4", + "num_workers": 1, + }, + "timeout_seconds": 3600, + "max_retries": 1, + }, + "databricks_instance": "test-account.cloud.databricks.com", + "databricks_endpoint": None, + } + container = Container( + image="flyteorg/flytekit:databricks-0.18.0-py3.7", + command=[], + args=[ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "", + "task-name", + "hello_spark", + ], + resources=Resources( + requests=[], + limits=[], + ), + env={}, + config={}, + ) + + SECRET_GROUP = "token-info" + SECRET_NAME = "token_secret" + mocked_token = "mocked_secret_token" + dummy_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + container=container, + interface=None, + type="spark", + security_context=SecurityContext( + secrets=Secret( + group=SECRET_GROUP, + key=SECRET_NAME, + mount_requirement=Secret.MountType.ENV_VAR, + ) + ), + ) + mocked_context = mock.patch("flytekit.current_context", autospec=True).start() + mocked_context.return_value.secrets.get.return_value = mocked_token + + metadata_bytes = json.dumps( + asdict( + Metadata( + databricks_endpoint=None, + databricks_instance="test-account.cloud.databricks.com", + token=mocked_token, + run_id="123", + ) + ) + ).encode("utf-8") + + mock_create_response = {"run_id": "123"} + mock_get_response = {"run_id": "123", "state": {"result_state": "SUCCESS"}} + mock_delete_response = {} + create_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/submit" + get_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/get?run_id=123" + delete_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/cancel" + with aioresponses() as mocked: + mocked.post(create_url, status=200, payload=mock_create_response) + res = await agent.async_create(ctx, "/tmp", dummy_template, None) + assert res.resource_meta == metadata_bytes + + mocked.get(get_url, status=200, payload=mock_get_response) + res = await agent.async_get(ctx, metadata_bytes) + assert res.resource.state == SUCCEEDED + assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl() + + mocked.post(delete_url, status=200, payload=mock_delete_response) + await agent.async_delete(ctx, metadata_bytes) + + mock.patch.stopall() From 6b3e7451bb360c46f2b85af3b09f3e1fbc0826d6 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 22 Aug 2023 16:19:56 +0800 Subject: [PATCH 11/24] revise by kevin Signed-off-by: Future Outlier --- flytekit/extend/backend/agent_service.py | 3 - .../flytekitplugins/spark/agent.py | 130 ++++++------------ .../flytekitplugins/spark/task.py | 6 +- plugins/flytekit-spark/setup.py | 2 +- setup.py | 1 - 5 files changed, 43 insertions(+), 99 deletions(-) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 470bd01e2e..fe169ccd71 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -44,7 +44,6 @@ async def CreateTask(self, request: 
CreateTaskRequest, context: grpc.ServicerCon logger.error(f"failed to run sync create with error {e}") raise except Exception as e: - logger.error(f"failed to create task with error {e}") context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"failed to create task with error {e}") @@ -66,7 +65,6 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) logger.error(f"failed to run sync get with error {e}") raise except Exception as e: - logger.error(f"failed to get task with error {e}") context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"failed to get task with error {e}") @@ -88,6 +86,5 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon logger.error(f"failed to run sync delete with error {e}") raise except Exception as e: - logger.error(f"failed to delete task with error {e}") context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"failed to delete task with error {e}") diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index bc8a125f9e..3d7c60d0f8 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -1,5 +1,7 @@ import json -from dataclasses import asdict, dataclass +import pickle +import typing +from dataclasses import dataclass from typing import Optional import aiohttp @@ -14,9 +16,7 @@ @dataclass class Metadata: - databricks_endpoint: Optional[str] - databricks_instance: Optional[str] - token: str + databricks_instance: str run_id: str @@ -35,112 +35,62 @@ async def async_create( custom = task_template.custom container = task_template.container databricks_job = custom["databricks_conf"] - if not databricks_job["new_cluster"].get("docker_image"): + if databricks_job["new_cluster"].get("docker_image"): databricks_job["new_cluster"]["docker_image"] = {"url": container.image} - if not databricks_job["new_cluster"].get("spark_conf"): + if databricks_job["new_cluster"].get("spark_conf"): databricks_job["new_cluster"]["spark_conf"] = custom["spark_conf"] databricks_job["spark_python_task"] = { "python_file": custom["applications_path"], - "parameters": container.args, + "parameters": tuple(container.args), } - secrets = task_template.security_context.secrets[0] - ctx = flytekit.current_context() - token = ctx.secrets.get(group=secrets.group, key=secrets.key, group_version=secrets.group_version) - - response = await send_request( - method="POST", - databricks_job=databricks_job, - databricks_endpoint=custom["databricks_endpoint"], - databricks_instance=custom["databricks_instance"], - token=token, - run_id=None, - is_cancel=False, - ) + databricks_instance = custom["databricks_instance"] + databricks_url = f"https://{databricks_instance}/api/2.0/jobs/runs/submit" + data = json.dumps(databricks_job) + + async with aiohttp.ClientSession() as session: + async with session.post(databricks_url, headers=get_header(), data=data) as resp: + response = await resp.json() metadata = Metadata( - databricks_endpoint=custom["databricks_endpoint"], - databricks_instance=custom["databricks_instance"], - token=token, + databricks_instance=databricks_instance, run_id=str(response["run_id"]), ) - - return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) + return CreateTaskResponse(resource_meta=pickle.dumps(metadata)) async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: - metadata = 
Metadata(**json.loads(resource_meta.decode("utf-8"))) - - response = await send_request( - method="GET", - databricks_job=None, - databricks_endpoint=metadata.databricks_endpoint, - databricks_instance=metadata.databricks_instance, - token=metadata.token, - run_id=metadata.run_id, - is_cancel=False, - ) + metadata = pickle.loads(resource_meta) + databricks_instance = metadata.databricks_instance + databricks_url = f"https://{databricks_instance}/api/2.0/jobs/runs/get?run_id={metadata.run_id}" + + async with aiohttp.ClientSession() as session: + async with session.get(databricks_url, headers=get_header()) as resp: + response = await resp.json() cur_state = PENDING - if response["state"].get("result_state"): + if response.get("state"): cur_state = convert_to_flyte_state(response["state"]["result_state"]) return GetTaskResponse(resource=Resource(state=cur_state)) async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - - await send_request( - method="POST", - databricks_job=None, - databricks_endpoint=metadata.databricks_endpoint, - databricks_instance=metadata.databricks_instance, - token=metadata.token, - run_id=metadata.run_id, - is_cancel=True, - ) + metadata = pickle.loads(resource_meta) + + databricks_url = f"https://{metadata.databricks_instance}/api/2.0/jobs/runs/cancel" + data = json.dumps({"run_id": metadata.run_id}) + + async with aiohttp.ClientSession() as session: + async with session.post(databricks_url, headers=get_header(), data=data) as resp: + if resp.status != 200: + raise Exception(f"Failed to cancel job {metadata.run_id}") + await resp.json() return DeleteTaskResponse() -async def send_request( - method: str, - databricks_job: dict, - databricks_endpoint: Optional[str], - databricks_instance: Optional[str], - token: str, - run_id: Optional[str], - is_cancel: bool, -) -> dict: - databricksAPI = "/api/2.0/jobs/runs" - post = "POST" - get = "GET" - data = None - headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} - - if not databricks_endpoint: - databricks_url = f"https://{databricks_instance}{databricksAPI}" - else: - databricks_url = f"{databricks_endpoint}{databricksAPI}" - - if is_cancel: - databricks_url += "/cancel" - data = json.dumps({"run_id": run_id}) - elif method == post: - databricks_url += "/submit" - try: - data = json.dumps(databricks_job) - except json.JSONDecodeError: - raise ValueError("Failed to marshal databricksJob to JSON") - elif method == get: - databricks_url += f"/get?run_id={run_id}" - - async with aiohttp.ClientSession() as session: - if method == post: - async with session.post(databricks_url, headers=headers, data=data) as resp: - return await resp.json() - elif method == get: - async with session.get(databricks_url, headers=headers) as resp: - return await resp.json() - - -AgentRegistry.register(DatabricksAgent()) +def get_header() -> typing.Dict[str, str]: + token = flytekit.current_context().secrets.get("databricks", "token") + return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + +AgentRegistry.register(DatabricksAgent()) \ No newline at end of file diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index b00d2d0e19..e2c5e05d96 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -176,7 +176,6 @@ class 
DatabricksAgentTask(Spark): Use this to configure a Databricks task. Task's marked with this will automatically execute natively onto databricks platform as a distributed execution of spark For databricks token, you can get it from here. https://docs.databricks.com/dev-tools/api/latest/authentication.html. - Args: databricks_conf: Databricks job configuration. Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. @@ -188,12 +187,12 @@ class DatabricksAgentTask(Spark): databricks_endpoint: Optional[str] = None -class PySparkDatabricksTask(AsyncAgentExecutorMixin, PythonFunctionTask[Spark]): +class PySparkDatabricksTask(AsyncAgentExecutorMixin, PythonFunctionTask[DatabricksAgentTask]): _SPARK_TASK_TYPE = "spark" def __init__( self, - task_config: Spark, + task_config: DatabricksAgentTask, task_function: Callable, **kwargs, ): @@ -214,7 +213,6 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: "databricks_instance": self.task_config.databricks_instance, "databricks_endpoint": self.task_config.databricks_endpoint, } - return config diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py index 4207a0265c..21305263a6 100644 --- a/plugins/flytekit-spark/setup.py +++ b/plugins/flytekit-spark/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0"] +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0", "aiohttp"] __version__ = "0.0.0+develop" diff --git a/setup.py b/setup.py index a9174501ac..b758a3178d 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,6 @@ "kubernetes>=12.0.1", "rich", "rich_click", - "aiohttp", ], extras_require=extras_require, scripts=[ From 8f9dcda5d4705d8fc204cd54907c98ecfb540f7f Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 22 Aug 2023 16:49:59 +0800 Subject: [PATCH 12/24] edit get function Signed-off-by: Future Outlier --- plugins/flytekit-spark/flytekitplugins/spark/agent.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 3d7c60d0f8..cfbd4ef041 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -31,7 +31,6 @@ async def async_create( task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, ) -> CreateTaskResponse: - custom = task_template.custom container = task_template.container databricks_job = custom["databricks_conf"] @@ -68,7 +67,7 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - response = await resp.json() cur_state = PENDING - if response.get("state"): + if response.get("state") and response["state"].get("result_state"): cur_state = convert_to_flyte_state(response["state"]["result_state"]) return GetTaskResponse(resource=Resource(state=cur_state)) @@ -93,4 +92,4 @@ def get_header() -> typing.Dict[str, str]: return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} -AgentRegistry.register(DatabricksAgent()) \ No newline at end of file +AgentRegistry.register(DatabricksAgent()) From bf857ac5f46f130ea87bac3ba113e93e30aa7c1b Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 22 Aug 2023 17:06:25 +0800 Subject: [PATCH 13/24] add spark plugin_requires in setup.py Signed-off-by: Future Outlier 
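For context, a minimal usage sketch of the task config these patches introduce (not part of the patch series): the instance name, entrypoint path, and cluster settings below are placeholders taken from the plugin's test fixture, and the Databricks token is expected to be supplied to the agent as the secret group "databricks", key "token".

from flytekit import task
from flytekitplugins.spark import DatabricksAgentTask

@task(
    task_config=DatabricksAgentTask(
        spark_conf={"spark.driver.memory": "1000M"},
        applications_path="dbfs:/entrypoint.py",  # PySpark entrypoint uploaded to DBFS
        databricks_instance="<account>.cloud.databricks.com",
        databricks_conf={
            "run_name": "flytekit databricks plugin example",
            "new_cluster": {
                "spark_version": "12.2.x-scala2.12",
                "node_type_id": "n2-highmem-4",
                "num_workers": 1,
            },
            "timeout_seconds": 3600,
            "max_retries": 1,
        },
    ),
)
def hello_spark() -> None:
    ...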
--- plugins/flytekit-spark/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py index 21305263a6..f47d9a0a8e 100644 --- a/plugins/flytekit-spark/setup.py +++ b/plugins/flytekit-spark/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0", "aiohttp"] +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0", "aiohttp", "aioresponses", "pytest-asyncio"] __version__ = "0.0.0+develop" From 9c20c4b00fbe981c344ac727e609c8d5b50c3c90 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 22 Aug 2023 17:46:58 +0800 Subject: [PATCH 14/24] Refactor and Revise test_agent.py after kevin's refactor Signed-off-by: Future Outlier --- plugins/flytekit-spark/tests/test_agent.py | 35 +++++++--------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py index 30b8567e91..6a2408761d 100644 --- a/plugins/flytekit-spark/tests/test_agent.py +++ b/plugins/flytekit-spark/tests/test_agent.py @@ -1,5 +1,4 @@ -import json -from dataclasses import asdict +import pickle from datetime import timedelta from unittest import mock from unittest.mock import MagicMock @@ -8,14 +7,12 @@ import pytest from aioresponses import aioresponses from flyteidl.admin.agent_pb2 import SUCCEEDED -from flytekitplugins.spark.agent import Metadata +from flytekitplugins.spark.agent import Metadata, get_header -from flytekit import Secret from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.interfaces.cli_identifiers import Identifier from flytekit.models import literals, task from flytekit.models.core.identifier import ResourceType -from flytekit.models.security import SecurityContext from flytekit.models.task import Container, Resources, TaskTemplate @@ -91,9 +88,6 @@ async def test_databricks_agent(): config={}, ) - SECRET_GROUP = "token-info" - SECRET_NAME = "token_secret" - mocked_token = "mocked_secret_token" dummy_template = TaskTemplate( id=task_id, custom=task_config, @@ -101,27 +95,17 @@ async def test_databricks_agent(): container=container, interface=None, type="spark", - security_context=SecurityContext( - secrets=Secret( - group=SECRET_GROUP, - key=SECRET_NAME, - mount_requirement=Secret.MountType.ENV_VAR, - ) - ), ) + mocked_token = "mocked_databricks_token" mocked_context = mock.patch("flytekit.current_context", autospec=True).start() mocked_context.return_value.secrets.get.return_value = mocked_token - metadata_bytes = json.dumps( - asdict( - Metadata( - databricks_endpoint=None, - databricks_instance="test-account.cloud.databricks.com", - token=mocked_token, - run_id="123", - ) + metadata_bytes = pickle.dumps( + Metadata( + databricks_instance="test-account.cloud.databricks.com", + run_id="123", ) - ).encode("utf-8") + ) mock_create_response = {"run_id": "123"} mock_get_response = {"run_id": "123", "state": {"result_state": "SUCCESS"}} @@ -142,4 +126,7 @@ async def test_databricks_agent(): mocked.post(delete_url, status=200, payload=mock_delete_response) await agent.async_delete(ctx, metadata_bytes) + mocked_header = {"Authorization": f"Bearer {mocked_token}", "Content-Type": "application/json"} + assert mocked_header == get_header() + mock.patch.stopall() From 77d2d705f0fe621a79c65189f53ba3532c06c36c Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 23 Aug 2023 13:30:03 +0800 Subject: [PATCH 15/24] remove databricks 
endpoint member Signed-off-by: Future Outlier --- plugins/flytekit-spark/flytekitplugins/spark/task.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index e2c5e05d96..173eb857fa 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -184,7 +184,6 @@ class DatabricksAgentTask(Spark): databricks_conf: Optional[Dict[str, Union[str, dict]]] = None databricks_instance: Optional[str] = None - databricks_endpoint: Optional[str] = None class PySparkDatabricksTask(AsyncAgentExecutorMixin, PythonFunctionTask[DatabricksAgentTask]): @@ -211,7 +210,6 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: "applications_path": self.task_config.applications_path, "databricks_conf": self.task_config.databricks_conf, "databricks_instance": self.task_config.databricks_instance, - "databricks_endpoint": self.task_config.databricks_endpoint, } return config From ab01850de44e01c1a5703f08699d753901d3c8e6 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 23 Aug 2023 13:30:49 +0800 Subject: [PATCH 16/24] fix databricks test_agent.py args error Signed-off-by: Future Outlier --- plugins/flytekit-spark/tests/test_agent.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py index 6a2408761d..54a3dcc665 100644 --- a/plugins/flytekit-spark/tests/test_agent.py +++ b/plugins/flytekit-spark/tests/test_agent.py @@ -63,15 +63,15 @@ async def test_databricks_agent(): args=[ "pyflyte-execute", "--inputs", - "{{.input}}", + "s3://bucket-name/path/to/object", "--output-prefix", - "{{.outputPrefix}}", + "s3://bucket-name/path/to/object", "--raw-output-data-prefix", - "{{.rawOutputDataPrefix}}", + "s3://bucket-name/path/to/object", "--checkpoint-path", - "{{.checkpointOutputPrefix}}", + "s3://bucket-name/path/to/object", "--prev-checkpoint", - "{{.prevCheckpointPrefix}}", + "s3://bucket-name/path/to/object", "--resolver", "flytekit.core.python_auto_container.default_task_resolver", "--", From 232d19d578add71e637f6cca94ff22550a5831bb Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 30 Aug 2023 16:07:02 +0800 Subject: [PATCH 17/24] Databricks Agent With Agent Server Only Signed-off-by: Future Outlier --- plugins/flytekit-spark/dev-requirements.in | 2 + plugins/flytekit-spark/dev-requirements.txt | 44 ++++++++++++++++++ .../flytekitplugins/spark/__init__.py | 2 +- .../flytekitplugins/spark/agent.py | 12 ++--- .../flytekitplugins/spark/task.py | 46 ------------------- plugins/flytekit-spark/setup.py | 2 +- plugins/flytekit-spark/tests/test_agent.py | 27 ++++++----- 7 files changed, 70 insertions(+), 65 deletions(-) create mode 100644 plugins/flytekit-spark/dev-requirements.in create mode 100644 plugins/flytekit-spark/dev-requirements.txt diff --git a/plugins/flytekit-spark/dev-requirements.in b/plugins/flytekit-spark/dev-requirements.in new file mode 100644 index 0000000000..a5f1f83bf9 --- /dev/null +++ b/plugins/flytekit-spark/dev-requirements.in @@ -0,0 +1,2 @@ +aioresponses +pytest-asyncio \ No newline at end of file diff --git a/plugins/flytekit-spark/dev-requirements.txt b/plugins/flytekit-spark/dev-requirements.txt new file mode 100644 index 0000000000..8d30230498 --- /dev/null +++ b/plugins/flytekit-spark/dev-requirements.txt @@ -0,0 +1,44 @@ +# +# This file is autogenerated by pip-compile with 
Python 3.9 +# by the following command: +# +# pip-compile dev-requirements.in +# +aiohttp==3.8.5 + # via aioresponses +aioresponses==0.7.4 + # via -r dev-requirements.in +aiosignal==1.3.1 + # via aiohttp +async-timeout==4.0.3 + # via aiohttp +attrs==23.1.0 + # via aiohttp +charset-normalizer==3.2.0 + # via aiohttp +exceptiongroup==1.1.3 + # via pytest +frozenlist==1.4.0 + # via + # aiohttp + # aiosignal +idna==3.4 + # via yarl +iniconfig==2.0.0 + # via pytest +multidict==6.0.4 + # via + # aiohttp + # yarl +packaging==23.1 + # via pytest +pluggy==1.3.0 + # via pytest +pytest==7.4.0 + # via pytest-asyncio +pytest-asyncio==0.21.1 + # via -r dev-requirements.in +tomli==2.0.1 + # via pytest +yarl==1.9.2 + # via aiohttp diff --git a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py index 9cee417247..72c9f37c9f 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py @@ -21,4 +21,4 @@ from .pyspark_transformers import PySparkPipelineModelTransformer from .schema import SparkDataFrameSchemaReader, SparkDataFrameSchemaWriter, SparkDataFrameTransformer # noqa from .sd_transformers import ParquetToSparkDecodingHandler, SparkToParquetEncodingHandler -from .task import Databricks, DatabricksAgentTask, Spark, new_spark_session # noqa +from .task import Databricks, Spark, new_spark_session # noqa diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index cfbd4ef041..60fff230b1 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -33,17 +33,17 @@ async def async_create( ) -> CreateTaskResponse: custom = task_template.custom container = task_template.container - databricks_job = custom["databricks_conf"] - if databricks_job["new_cluster"].get("docker_image"): + databricks_job = custom["databricksConf"] + if not databricks_job["new_cluster"].get("docker_image"): databricks_job["new_cluster"]["docker_image"] = {"url": container.image} - if databricks_job["new_cluster"].get("spark_conf"): - databricks_job["new_cluster"]["spark_conf"] = custom["spark_conf"] + if not databricks_job["new_cluster"].get("spark_conf"): + databricks_job["new_cluster"]["spark_conf"] = custom["sparkConf"] databricks_job["spark_python_task"] = { - "python_file": custom["applications_path"], + "python_file": custom["mainApplicationFile"], "parameters": tuple(container.args), } - databricks_instance = custom["databricks_instance"] + databricks_instance = custom["databricksInstance"] databricks_url = f"https://{databricks_instance}/api/2.0/jobs/runs/submit" data = json.dumps(databricks_job) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 173eb857fa..0d8ecd5b6e 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -9,7 +9,6 @@ from flytekit.configuration import DefaultImages, SerializationSettings from flytekit.core.context_manager import ExecutionParameters from flytekit.extend import ExecutionState, TaskPlugins -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin from flytekit.image_spec import ImageSpec from .models import SparkJob, SparkType @@ -170,51 +169,6 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: return 
user_params.builder().add_attr("SPARK_SESSION", self.sess).build() -@dataclass -class DatabricksAgentTask(Spark): - """ - Use this to configure a Databricks task. Task's marked with this will automatically execute - natively onto databricks platform as a distributed execution of spark - For databricks token, you can get it from here. https://docs.databricks.com/dev-tools/api/latest/authentication.html. - Args: - databricks_conf: Databricks job configuration. Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure - databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. - databricks_endpoint: Use for test. - """ - - databricks_conf: Optional[Dict[str, Union[str, dict]]] = None - databricks_instance: Optional[str] = None - - -class PySparkDatabricksTask(AsyncAgentExecutorMixin, PythonFunctionTask[DatabricksAgentTask]): - _SPARK_TASK_TYPE = "spark" - - def __init__( - self, - task_config: DatabricksAgentTask, - task_function: Callable, - **kwargs, - ): - self._default_applications_path: Optional[str] = task_config.applications_path - - super(PySparkDatabricksTask, self).__init__( - task_config=task_config, - task_type=self._SPARK_TASK_TYPE, - task_function=task_function, - **kwargs, - ) - - def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - config = { - "spark_conf": self.task_config.spark_conf, - "applications_path": self.task_config.applications_path, - "databricks_conf": self.task_config.databricks_conf, - "databricks_instance": self.task_config.databricks_instance, - } - return config - - # Inject the Spark plugin into flytekits dynamic plugin loading system TaskPlugins.register_pythontask_plugin(Spark, PysparkFunctionTask) TaskPlugins.register_pythontask_plugin(Databricks, PysparkFunctionTask) -TaskPlugins.register_pythontask_plugin(DatabricksAgentTask, PySparkDatabricksTask) diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py index f47d9a0a8e..21305263a6 100644 --- a/plugins/flytekit-spark/setup.py +++ b/plugins/flytekit-spark/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0", "aiohttp", "aioresponses", "pytest-asyncio"] +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0", "aiohttp"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py index 54a3dcc665..547443fa89 100644 --- a/plugins/flytekit-spark/tests/test_agent.py +++ b/plugins/flytekit-spark/tests/test_agent.py @@ -36,15 +36,15 @@ async def test_databricks_agent(): "A", ) task_config = { - "spark_conf": { + "sparkConf": { "spark.driver.memory": "1000M", "spark.executor.memory": "1000M", "spark.executor.cores": "1", "spark.executor.instances": "2", "spark.driver.cores": "1", }, - "applications_path": "dbfs:/entrypoint.py", - "databricks_conf": { + "mainApplicationFile": "dbfs:/entrypoint.py", + "databricksConf": { "run_name": "flytekit databricks plugin example", "new_cluster": { "spark_version": "12.2.x-scala2.12", @@ -54,29 +54,34 @@ async def test_databricks_agent(): "timeout_seconds": 3600, "max_retries": 1, }, - "databricks_instance": "test-account.cloud.databricks.com", - "databricks_endpoint": None, + "databricksInstance": "test-account.cloud.databricks.com", } container = Container( image="flyteorg/flytekit:databricks-0.18.0-py3.7", command=[], args=[ + "pyflyte-fast-execute", + "--additional-distribution", + 
"s3://my-s3-bucket/flytesnacks/development/24UYJEF2HDZQN3SG4VAZSM4PLI======/script_mode.tar.gz", + "--dest-dir", + "/root", + "--", "pyflyte-execute", "--inputs", - "s3://bucket-name/path/to/object", + "s3://my-s3-bucket", "--output-prefix", - "s3://bucket-name/path/to/object", + "s3://my-s3-bucket", "--raw-output-data-prefix", - "s3://bucket-name/path/to/object", + "s3://my-s3-bucket", "--checkpoint-path", - "s3://bucket-name/path/to/object", + "s3://my-s3-bucket", "--prev-checkpoint", - "s3://bucket-name/path/to/object", + "s3://my-s3-bucket", "--resolver", "flytekit.core.python_auto_container.default_task_resolver", "--", "task-module", - "", + "spark_local_example", "task-name", "hello_spark", ], From ef2b2f75d424ba1e7393b31d98196c31dae550ac Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 30 Aug 2023 16:46:37 +0800 Subject: [PATCH 18/24] Fix dev-requirements.in lint error Signed-off-by: Future Outlier --- plugins/flytekit-spark/dev-requirements.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-spark/dev-requirements.in b/plugins/flytekit-spark/dev-requirements.in index a5f1f83bf9..78d0eca127 100644 --- a/plugins/flytekit-spark/dev-requirements.in +++ b/plugins/flytekit-spark/dev-requirements.in @@ -1,2 +1,2 @@ aioresponses -pytest-asyncio \ No newline at end of file +pytest-asyncio From d65525ff42abc03e91bd7d21506775c9ff351772 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 4 Sep 2023 02:39:08 -0700 Subject: [PATCH 19/24] error handle Signed-off-by: Kevin Su --- plugins/flytekit-spark/flytekitplugins/spark/agent.py | 10 ++++++++-- plugins/flytekit-spark/flytekitplugins/spark/task.py | 8 ++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 60fff230b1..5c8fa4780d 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -49,6 +49,10 @@ async def async_create( async with aiohttp.ClientSession() as session: async with session.post(databricks_url, headers=get_header(), data=data) as resp: + if resp.status != 200: + print(resp.content) + print(resp) + raise Exception(f"Failed to create databricks job with error: {resp.reason}") response = await resp.json() metadata = Metadata( @@ -64,6 +68,8 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - async with aiohttp.ClientSession() as session: async with session.get(databricks_url, headers=get_header()) as resp: + if resp.status != 200: + raise Exception(f"Failed to get databricks job {metadata.run_id} with error: {resp.reason}") response = await resp.json() cur_state = PENDING @@ -81,7 +87,7 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes async with aiohttp.ClientSession() as session: async with session.post(databricks_url, headers=get_header(), data=data) as resp: if resp.status != 200: - raise Exception(f"Failed to cancel job {metadata.run_id}") + raise Exception(f"Failed to cancel databricks job {metadata.run_id} with error: {resp.reason}") await resp.json() return DeleteTaskResponse() @@ -89,7 +95,7 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes def get_header() -> typing.Dict[str, str]: token = flytekit.current_context().secrets.get("databricks", "token") - return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + return {"Authorization": f"Bearer {token}", 
"content-type": "application/json"} AgentRegistry.register(DatabricksAgent()) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 0d8ecd5b6e..53ac994fde 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -111,16 +111,16 @@ def __init__( **kwargs, ): self.sess: Optional[SparkSession] = None - self._default_executor_path: Optional[str] = task_config.executor_path - self._default_applications_path: Optional[str] = task_config.applications_path + self._default_executor_path: str = task_config.executor_path + self._default_applications_path: str = task_config.applications_path if isinstance(container_image, ImageSpec): if container_image.base_image is None: img = f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}" container_image.base_image = img # default executor path and applications path in apache/spark-py:3.3.1 - self._default_executor_path = "/usr/bin/python3" - self._default_applications_path = "local:///usr/local/bin/entrypoint.py" + self._default_executor_path = self._default_executor_path or "/usr/bin/python3" + self._default_applications_path = self._default_applications_path or "local:///usr/local/bin/entrypoint.py" super(PysparkFunctionTask, self).__init__( task_config=task_config, task_type=self._SPARK_TASK_TYPE, From 550973d8ba3b8461b7cf780b2d2221ff07584f52 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 4 Sep 2023 02:44:44 -0700 Subject: [PATCH 20/24] lint Signed-off-by: Kevin Su --- plugins/flytekit-spark/flytekitplugins/spark/task.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 53ac994fde..17099350e4 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -120,7 +120,9 @@ def __init__( container_image.base_image = img # default executor path and applications path in apache/spark-py:3.3.1 self._default_executor_path = self._default_executor_path or "/usr/bin/python3" - self._default_applications_path = self._default_applications_path or "local:///usr/local/bin/entrypoint.py" + self._default_applications_path = ( + self._default_applications_path or "local:///usr/local/bin/entrypoint.py" + ) super(PysparkFunctionTask, self).__init__( task_config=task_config, task_type=self._SPARK_TASK_TYPE, From 008cf7b465fc174f3e0fabe0bcd5d9eec9fa0efe Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 4 Sep 2023 03:08:41 -0700 Subject: [PATCH 21/24] nit Signed-off-by: Kevin Su --- plugins/flytekit-spark/flytekitplugins/spark/agent.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 5c8fa4780d..1b8c31d674 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -50,8 +50,6 @@ async def async_create( async with aiohttp.ClientSession() as session: async with session.post(databricks_url, headers=get_header(), data=data) as resp: if resp.status != 200: - print(resp.content) - print(resp) raise Exception(f"Failed to create databricks job with error: {resp.reason}") response = await resp.json() From 502ae3f5457967d9cded5af50a795811aa9dc735 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 4 Sep 2023 19:57:05 +0800 Subject: 
[PATCH 22/24] fix the mocked header in test Signed-off-by: Future Outlier --- plugins/flytekit-spark/tests/test_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py index 547443fa89..24fb074844 100644 --- a/plugins/flytekit-spark/tests/test_agent.py +++ b/plugins/flytekit-spark/tests/test_agent.py @@ -131,7 +131,7 @@ async def test_databricks_agent(): mocked.post(delete_url, status=200, payload=mock_delete_response) await agent.async_delete(ctx, metadata_bytes) - mocked_header = {"Authorization": f"Bearer {mocked_token}", "Content-Type": "application/json"} + mocked_header = {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"} assert mocked_header == get_header() mock.patch.stopall() From 5b091b9f4f1982b721bd6ea6d19bea29c7143140 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 8 Sep 2023 08:27:47 +0800 Subject: [PATCH 23/24] update spark agent test Signed-off-by: Future Outlier --- plugins/flytekit-spark/tests/test_agent.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py index 24fb074844..1b3941d7f6 100644 --- a/plugins/flytekit-spark/tests/test_agent.py +++ b/plugins/flytekit-spark/tests/test_agent.py @@ -19,7 +19,7 @@ @pytest.mark.asyncio async def test_databricks_agent(): ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent(ctx, "spark") + agent = AgentRegistry.get_agent("spark") task_id = Identifier( resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" @@ -131,7 +131,6 @@ async def test_databricks_agent(): mocked.post(delete_url, status=200, payload=mock_delete_response) await agent.async_delete(ctx, metadata_bytes) - mocked_header = {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"} - assert mocked_header == get_header() + assert get_header() == {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"} mock.patch.stopall() From f074740d0dee9782e9ce535396286fc57eeb39aa Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Sat, 9 Sep 2023 16:40:52 +0800 Subject: [PATCH 24/24] rename token to access_token Signed-off-by: Future Outlier --- plugins/flytekit-spark/flytekitplugins/spark/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 1b8c31d674..93c03b5156 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -92,7 +92,7 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes def get_header() -> typing.Dict[str, str]: - token = flytekit.current_context().secrets.get("databricks", "token") + token = flytekit.current_context().secrets.get("databricks", "access_token") return {"Authorization": f"Bearer {token}", "content-type": "application/json"}
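(Reviewer note: illustrative sketch, not part of the patch series.) After PATCH 24 the agent resolves the Databricks personal access token through flytekit's secret manager with group "databricks" and key "access_token", so that secret must be available wherever the agent service runs rather than on the task pod. A hedged usage sketch of how a task author might target the agent with the existing Databricks task config; the workspace name, cluster settings, and image are placeholders modeled on the test fixture, and whether "spark" tasks are routed to this agent depends on backend configuration:

from flytekit import task
from flytekitplugins.spark import Databricks


@task(
    task_config=Databricks(
        spark_conf={"spark.driver.memory": "1000M"},
        databricks_conf={
            "run_name": "flytekit databricks plugin example",
            "new_cluster": {
                "spark_version": "12.2.x-scala2.12",
                "node_type_id": "n2-highmem-4",
                "num_workers": 1,
            },
            "timeout_seconds": 3600,
            "max_retries": 1,
        },
        databricks_instance="test-account.cloud.databricks.com",  # placeholder workspace
    ),
    container_image="flyteorg/flytekit:databricks-0.18.0-py3.7",  # image name taken from the test fixture
)
def hello_spark(partitions: int = 2) -> int:
    # A real task body would run PySpark code; the agent submits this task as a
    # spark_python_task on the new_cluster described in databricks_conf.
    return partitions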