Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Databricks config to Spark Job #1358

Merged
merged 12 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ jobs:
- flytekit-modin
- flytekit-onnx-pytorch
- flytekit-onnx-scikitlearn
# onnxx-tensorflow needs a version of tensorflow that does not work with protobuf>4.
# onnx-tensorflow needs a version of tensorflow that does not work with protobuf>4.
# The issue is being tracked on the tensorflow side in https://github.com/tensorflow/tensorflow/issues/53234#issuecomment-1330111693
# flytekit-onnx-tensorflow
- flytekit-pandera
Expand Down
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ filelock==3.8.2
# via virtualenv
flatbuffers==22.12.6
# via tensorflow
flyteidl==1.3.0
flyteidl==1.3.1
# via
# -c requirements.txt
# flytekit
Expand Down
2 changes: 1 addition & 1 deletion doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ filelock==3.8.2
# virtualenv
flatbuffers==22.12.6
# via tensorflow
flyteidl==1.3.0
flyteidl==1.3.1
# via flytekit
fonttools==4.38.0
# via matplotlib
Expand Down
70 changes: 62 additions & 8 deletions plugins/flytekit-spark/flytekitplugins/spark/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import enum
import typing
from typing import Dict, Optional

from flyteidl.plugins import spark_pb2 as _spark_task
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

from flytekit.exceptions import user as _user_exceptions
from flytekit.models import common as _common
Expand All @@ -17,42 +19,60 @@ class SparkType(enum.Enum):
class SparkJob(_common.FlyteIdlEntity):
def __init__(
self,
spark_type,
application_file,
main_class,
spark_conf,
hadoop_conf,
executor_path,
spark_type: SparkType,
application_file: str,
main_class: str,
spark_conf: Dict[str, str],
hadoop_conf: Dict[str, str],
executor_path: str,
databricks_conf: Dict[str, Dict[str, Dict]] = {},
databricks_token: Optional[str] = None,
databricks_instance: Optional[str] = None,
):
"""
This defines a SparkJob target. It will execute the appropriate SparkJob.

:param application_file: The main application file to execute.
:param dict[Text, Text] spark_conf: A definition of key-value pairs for spark config for the job.
:param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job.
:param Optional[dict[Text, dict]] databricks_conf: A definition of key-value pairs for databricks config for the job. Refer to https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit.
:param Optional[str] databricks_token: databricks access token.
:param Optional[str] databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com.
"""
self._application_file = application_file
self._spark_type = spark_type
self._main_class = main_class
self._executor_path = executor_path
self._spark_conf = spark_conf
self._hadoop_conf = hadoop_conf
self._databricks_conf = databricks_conf
self._databricks_token = databricks_token
self._databricks_instance = databricks_instance

def with_overrides(
self, new_spark_conf: typing.Dict[str, str] = None, new_hadoop_conf: typing.Dict[str, str] = None
self,
new_spark_conf: Optional[Dict[str, str]] = None,
new_hadoop_conf: Optional[Dict[str, str]] = None,
new_databricks_conf: Optional[Dict[str, Dict]] = None,
) -> "SparkJob":
if not new_spark_conf:
new_spark_conf = self.spark_conf

if not new_hadoop_conf:
new_hadoop_conf = self.hadoop_conf

if not new_databricks_conf:
new_databricks_conf = self.databricks_conf

return SparkJob(
spark_type=self.spark_type,
application_file=self.application_file,
main_class=self.main_class,
spark_conf=new_spark_conf,
hadoop_conf=new_hadoop_conf,
databricks_conf=new_databricks_conf,
databricks_token=self.databricks_token,
databricks_instance=self.databricks_instance,
executor_path=self.executor_path,
)

Expand Down Expand Up @@ -104,6 +124,31 @@ def hadoop_conf(self):
"""
return self._hadoop_conf

@property
def databricks_conf(self) -> Dict[str, Dict]:
"""
databricks_conf: Databricks job configuration.
Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
:rtype: dict[Text, dict[Text, Text]]
"""
return self._databricks_conf

@property
def databricks_token(self) -> str:
"""
Databricks access token
:rtype: str
"""
return self._databricks_token

@property
def databricks_instance(self) -> str:
"""
Domain name of your deployment. Use the form <account>.cloud.databricks.com.
:rtype: str
"""
return self._databricks_instance

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins.spark_pb2.SparkJob
Expand All @@ -120,13 +165,19 @@ def to_flyte_idl(self):
else:
raise _user_exceptions.FlyteValidationException("Invalid Spark Application Type Specified")

databricks_conf = Struct()
databricks_conf.update(self.databricks_conf)

return _spark_task.SparkJob(
applicationType=application_type,
mainApplicationFile=self.application_file,
mainClass=self.main_class,
executorPath=self.executor_path,
sparkConf=self.spark_conf,
hadoopConf=self.hadoop_conf,
databricksConf=databricks_conf,
databricksToken=self.databricks_token,
databricksInstance=self.databricks_instance,
)

@classmethod
Expand All @@ -151,4 +202,7 @@ def from_flyte_idl(cls, pb2_object):
main_class=pb2_object.mainClass,
hadoop_conf=pb2_object.hadoopConf,
executor_path=pb2_object.executorPath,
databricks_conf=json_format.MessageToDict(pb2_object.databricksConf),
databricks_token=pb2_object.databricksToken,
databricks_instance=pb2_object.databricksInstance,
)
23 changes: 23 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ def __post_init__(self):
self.hadoop_conf = {}


@dataclass
class Databricks(Spark):
"""
Use this to configure a Databricks task. Task's marked with this will automatically execute
natively onto databricks platform as a distributed execution of spark

Args:
databricks_conf: Databricks job configuration. Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
databricks_token: Databricks access token. https://docs.databricks.com/dev-tools/api/latest/authentication.html.
databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com.
"""

databricks_conf: typing.Optional[Dict[str, typing.Union[str, dict]]] = None
databricks_token: Optional[str] = None
databricks_instance: Optional[str] = None


# This method does not reset the SparkSession since it's a bit hard to handle multiple
# Spark sessions in a single application as it's described in:
# https://stackoverflow.com/questions/41491972/how-can-i-tear-down-a-sparksession-and-create-a-new-one-within-one-application.
Expand Down Expand Up @@ -100,6 +117,12 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
main_class="",
spark_type=SparkType.PYTHON,
)
if isinstance(self.task_config, Databricks):
cfg = typing.cast(self.task_config, Databricks)
job._databricks_conf = cfg.databricks_conf
job._databricks_token = cfg.databricks_token
job._databricks_instance = cfg.databricks_instance

return MessageToDict(job.to_flyte_idl())

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-spark/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ docker-image-py==0.1.12
# via flytekit
docstring-parser==0.15
# via flytekit
flyteidl==1.3.0
flyteidl==1.3.1
# via flytekit
flytekit==1.3.0b2
# via flytekitplugins-spark
Expand Down
41 changes: 40 additions & 1 deletion plugins/flytekit-spark/tests/test_spark_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pyspark
import pytest
from flytekitplugins.spark import Spark
from flytekitplugins.spark.task import new_spark_session
from flytekitplugins.spark.task import Databricks, new_spark_session
from pyspark.sql import SparkSession

import flytekit
Expand All @@ -19,6 +19,23 @@ def reset_spark_session() -> None:


def test_spark_task(reset_spark_session):
databricks_conf = {
"name": "flytekit databricks plugin example",
"new_cluster": {
"spark_version": "11.0.x-scala2.12",
"node_type_id": "r3.xlarge",
"aws_attributes": {"availability": "ON_DEMAND"},
"num_workers": 4,
"docker_image": {"url": "pingsutw/databricks:latest"},
},
"timeout_seconds": 3600,
"max_retries": 1,
"spark_python_task": {
"python_file": "dbfs:///FileStore/tables/entrypoint-1.py",
"parameters": "ls",
},
}

@task(task_config=Spark(spark_conf={"spark": "1"}))
def my_spark(a: str) -> int:
session = flytekit.current_context().spark_session
Expand Down Expand Up @@ -53,6 +70,28 @@ def my_spark(a: str) -> int:
assert ("spark", "1") in configs
assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs

databricks_token = "token"
databricks_instance = "account.cloud.databricks.com"

@task(
task_config=Databricks(
spark_conf={"spark": "2"},
databricks_conf=databricks_conf,
databricks_instance="account.cloud.databricks.com",
databricks_token="token",
)
)
def my_databricks(a: str) -> int:
session = flytekit.current_context().spark_session
assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
return 10

assert my_databricks.task_config is not None
assert my_databricks.task_config.spark_conf == {"spark": "2"}
assert my_databricks.task_config.databricks_conf == databricks_conf
assert my_databricks.task_config.databricks_instance == databricks_instance
assert my_databricks.task_config.databricks_token == databricks_token


def test_new_spark_session():
name = "SessionName"
Expand Down
2 changes: 1 addition & 1 deletion requirements-spark2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ docker-image-py==0.1.12
# via flytekit
docstring-parser==0.15
# via flytekit
flyteidl==1.3.0
flyteidl==1.3.1
# via flytekit
googleapis-common-protos==1.57.0
# via
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ docker-image-py==0.1.12
# via flytekit
docstring-parser==0.15
# via flytekit
flyteidl==1.3.0
flyteidl==1.3.1
# via flytekit
googleapis-common-protos==1.57.0
# via
Expand Down