Skip to content

Commit

Permalink
Add Databricks config to Spark Job (#1358)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Dec 19, 2022
1 parent 315956f commit e4911e7
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ jobs:
- flytekit-modin
- flytekit-onnx-pytorch
- flytekit-onnx-scikitlearn
# onnxx-tensorflow needs a version of tensorflow that does not work with protobuf>4.
# onnx-tensorflow needs a version of tensorflow that does not work with protobuf>4.
# The issue is being tracked on the tensorflow side in https://github.com/tensorflow/tensorflow/issues/53234#issuecomment-1330111693
# flytekit-onnx-tensorflow
- flytekit-pandera
Expand Down
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ filelock==3.8.2
# via virtualenv
flatbuffers==22.12.6
# via tensorflow
flyteidl==1.3.0
flyteidl==1.3.1
# via
# -c requirements.txt
# flytekit
Expand Down
2 changes: 1 addition & 1 deletion doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ filelock==3.8.2
# virtualenv
flatbuffers==22.12.6
# via tensorflow
flyteidl==1.3.0
flyteidl==1.3.1
# via flytekit
fonttools==4.38.0
# via matplotlib
Expand Down
70 changes: 62 additions & 8 deletions plugins/flytekit-spark/flytekitplugins/spark/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import enum
import typing
from typing import Dict, Optional

from flyteidl.plugins import spark_pb2 as _spark_task
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

from flytekit.exceptions import user as _user_exceptions
from flytekit.models import common as _common
Expand All @@ -17,42 +19,60 @@ class SparkType(enum.Enum):
class SparkJob(_common.FlyteIdlEntity):
def __init__(
    self,
    spark_type: SparkType,
    application_file: str,
    main_class: str,
    spark_conf: Dict[str, str],
    hadoop_conf: Dict[str, str],
    executor_path: str,
    databricks_conf: Optional[Dict[str, Dict[str, Dict]]] = None,
    databricks_token: Optional[str] = None,
    databricks_instance: Optional[str] = None,
):
    """
    This defines a SparkJob target. It will execute the appropriate SparkJob.

    :param SparkType spark_type: The kind of Spark application this job runs.
    :param str application_file: The main application file to execute.
    :param str main_class: The main class of the application.
    :param dict[str, str] spark_conf: A definition of key-value pairs for spark config for the job.
    :param dict[str, str] hadoop_conf: A definition of key-value pairs for hadoop config for the job.
    :param str executor_path: The executor path for the job.
    :param Optional[dict[str, dict]] databricks_conf: A definition of key-value pairs for databricks config
        for the job. Refer to https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit.
        ``None`` (the default) is stored as a new empty dict.
    :param Optional[str] databricks_token: databricks access token.
    :param Optional[str] databricks_instance: Domain name of your deployment. Use the form
        <account>.cloud.databricks.com.
    """
    self._application_file = application_file
    self._spark_type = spark_type
    self._main_class = main_class
    self._executor_path = executor_path
    self._spark_conf = spark_conf
    self._hadoop_conf = hadoop_conf
    # Build the fallback dict here rather than in the signature: a `{}` default is
    # created once at definition time and would be shared by every SparkJob.
    self._databricks_conf = {} if databricks_conf is None else databricks_conf
    self._databricks_token = databricks_token
    self._databricks_instance = databricks_instance

def with_overrides(
    self,
    new_spark_conf: Optional[Dict[str, str]] = None,
    new_hadoop_conf: Optional[Dict[str, str]] = None,
    new_databricks_conf: Optional[Dict[str, Dict]] = None,
) -> "SparkJob":
    """
    Build a new SparkJob that copies this one, swapping in any config overrides given.

    An override that is ``None`` or empty is ignored and this job's current value is used.

    :param new_spark_conf: Replacement spark config, if any.
    :param new_hadoop_conf: Replacement hadoop config, if any.
    :param new_databricks_conf: Replacement databricks config, if any.
    :rtype: SparkJob
    """
    # `a or b` keeps the original "falsy override means keep current" rule.
    effective_spark_conf = new_spark_conf or self.spark_conf
    effective_hadoop_conf = new_hadoop_conf or self.hadoop_conf
    effective_databricks_conf = new_databricks_conf or self.databricks_conf

    return SparkJob(
        spark_type=self.spark_type,
        application_file=self.application_file,
        main_class=self.main_class,
        spark_conf=effective_spark_conf,
        hadoop_conf=effective_hadoop_conf,
        databricks_conf=effective_databricks_conf,
        databricks_token=self.databricks_token,
        databricks_instance=self.databricks_instance,
        executor_path=self.executor_path,
    )

Expand Down Expand Up @@ -104,6 +124,31 @@ def hadoop_conf(self):
"""
return self._hadoop_conf

@property
def databricks_conf(self) -> Dict[str, Dict]:
    """
    Databricks job configuration held by this job.
    Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
    :rtype: dict[str, dict]
    """
    return self._databricks_conf

@property
def databricks_token(self) -> Optional[str]:
    """
    Databricks access token, or ``None`` when no token was supplied.
    :rtype: Optional[str]
    """
    return self._databricks_token

@property
def databricks_instance(self) -> Optional[str]:
    """
    Domain name of your deployment, or ``None`` when none was supplied.
    Use the form <account>.cloud.databricks.com.
    :rtype: Optional[str]
    """
    return self._databricks_instance

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins.spark_pb2.SparkJob
Expand All @@ -120,13 +165,19 @@ def to_flyte_idl(self):
else:
raise _user_exceptions.FlyteValidationException("Invalid Spark Application Type Specified")

databricks_conf = Struct()
databricks_conf.update(self.databricks_conf)

return _spark_task.SparkJob(
applicationType=application_type,
mainApplicationFile=self.application_file,
mainClass=self.main_class,
executorPath=self.executor_path,
sparkConf=self.spark_conf,
hadoopConf=self.hadoop_conf,
databricksConf=databricks_conf,
databricksToken=self.databricks_token,
databricksInstance=self.databricks_instance,
)

@classmethod
Expand All @@ -151,4 +202,7 @@ def from_flyte_idl(cls, pb2_object):
main_class=pb2_object.mainClass,
hadoop_conf=pb2_object.hadoopConf,
executor_path=pb2_object.executorPath,
databricks_conf=json_format.MessageToDict(pb2_object.databricksConf),
databricks_token=pb2_object.databricksToken,
databricks_instance=pb2_object.databricksInstance,
)
23 changes: 23 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ def __post_init__(self):
self.hadoop_conf = {}


@dataclass
class Databricks(Spark):
    """
    Use this to configure a Databricks task. Tasks marked with this will automatically execute
    natively on the Databricks platform as a distributed Spark execution.

    Args:
        databricks_conf: Databricks job configuration. Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
        databricks_token: Databricks access token. https://docs.databricks.com/dev-tools/api/latest/authentication.html.
        databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com.
    """

    # Values may be strings, numbers or nested dicts (e.g. "timeout_seconds": 3600).
    databricks_conf: Optional[Dict[str, Any]] = None
    databricks_token: Optional[str] = None
    databricks_instance: Optional[str] = None


# This method does not reset the SparkSession since it's a bit hard to handle multiple
# Spark sessions in a single application as it's described in:
# https://stackoverflow.com/questions/41491972/how-can-i-tear-down-a-sparksession-and-create-a-new-one-within-one-application.
Expand Down Expand Up @@ -100,6 +117,12 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
main_class="",
spark_type=SparkType.PYTHON,
)
if isinstance(self.task_config, Databricks):
    # typing.cast takes (type, value) and returns the value unchanged. With the
    # arguments swapped it returned the Databricks class itself, so every field
    # below read the class-level default (None) instead of the user's config.
    cfg = typing.cast(Databricks, self.task_config)
    # databricks_conf defaults to None on the task config; fall back to an empty
    # dict so the later Struct.update(...) in to_flyte_idl gets a mapping.
    job._databricks_conf = cfg.databricks_conf or {}
    job._databricks_token = cfg.databricks_token
    job._databricks_instance = cfg.databricks_instance

return MessageToDict(job.to_flyte_idl())

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-spark/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ docker-image-py==0.1.12
# via flytekit
docstring-parser==0.15
# via flytekit
flyteidl==1.3.0
flyteidl==1.3.1
# via flytekit
flytekit==1.3.0b2
# via flytekitplugins-spark
Expand Down
41 changes: 40 additions & 1 deletion plugins/flytekit-spark/tests/test_spark_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pyspark
import pytest
from flytekitplugins.spark import Spark
from flytekitplugins.spark.task import new_spark_session
from flytekitplugins.spark.task import Databricks, new_spark_session
from pyspark.sql import SparkSession

import flytekit
Expand All @@ -19,6 +19,23 @@ def reset_spark_session() -> None:


def test_spark_task(reset_spark_session):
databricks_conf = {
"name": "flytekit databricks plugin example",
"new_cluster": {
"spark_version": "11.0.x-scala2.12",
"node_type_id": "r3.xlarge",
"aws_attributes": {"availability": "ON_DEMAND"},
"num_workers": 4,
"docker_image": {"url": "pingsutw/databricks:latest"},
},
"timeout_seconds": 3600,
"max_retries": 1,
"spark_python_task": {
"python_file": "dbfs:///FileStore/tables/entrypoint-1.py",
"parameters": "ls",
},
}

@task(task_config=Spark(spark_conf={"spark": "1"}))
def my_spark(a: str) -> int:
session = flytekit.current_context().spark_session
Expand Down Expand Up @@ -53,6 +70,28 @@ def my_spark(a: str) -> int:
assert ("spark", "1") in configs
assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs

databricks_token = "token"
databricks_instance = "account.cloud.databricks.com"

@task(
task_config=Databricks(
spark_conf={"spark": "2"},
databricks_conf=databricks_conf,
databricks_instance="account.cloud.databricks.com",
databricks_token="token",
)
)
def my_databricks(a: str) -> int:
session = flytekit.current_context().spark_session
assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
return 10

assert my_databricks.task_config is not None
assert my_databricks.task_config.spark_conf == {"spark": "2"}
assert my_databricks.task_config.databricks_conf == databricks_conf
assert my_databricks.task_config.databricks_instance == databricks_instance
assert my_databricks.task_config.databricks_token == databricks_token


def test_new_spark_session():
name = "SessionName"
Expand Down
2 changes: 1 addition & 1 deletion requirements-spark2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ docker-image-py==0.1.12
# via flytekit
docstring-parser==0.15
# via flytekit
flyteidl==1.3.0
flyteidl==1.3.1
# via flytekit
googleapis-common-protos==1.57.0
# via
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ docker-image-py==0.1.12
# via flytekit
docstring-parser==0.15
# via flytekit
flyteidl==1.3.0
flyteidl==1.3.1
# via flytekit
googleapis-common-protos==1.57.0
# via
Expand Down

0 comments on commit e4911e7

Please sign in to comment.