Add Databricks config to Spark Job #1358

Merged: 12 commits, Dec 19, 2022
Changes from 6 commits
4 changes: 4 additions & 0 deletions .github/workflows/pythonbuild.yml
@@ -46,6 +46,7 @@ jobs:
- name: Install dependencies
run: |
make setup${{ matrix.spark-version-suffix }}
pip install "git+https://github.com/flyteorg/flyteidl@databricks"
pip freeze
- name: Test with coverage
run: |
@@ -131,6 +132,7 @@ jobs:
pip install -r requirements.txt
if [ -f dev-requirements.txt ]; then pip install -r dev-requirements.txt; fi
pip install --no-deps -U https://github.com/flyteorg/flytekit/archive/${{ github.sha }}.zip#egg=flytekit
pip install "git+https://github.com/flyteorg/flyteidl@databricks"
pip freeze
- name: Test with coverage
run: |
@@ -155,6 +157,7 @@ jobs:
run: |
python -m pip install --upgrade pip==21.2.4
pip install -r dev-requirements.txt
pip install "git+https://github.com/flyteorg/flyteidl@databricks"
- name: Lint
run: |
make lint
@@ -176,5 +179,6 @@ jobs:
run: |
python -m pip install --upgrade pip==21.2.4 setuptools wheel
pip install -r doc-requirements.txt
pip install "git+https://github.com/flyteorg/flyteidl@databricks"
- name: Build the documentation
run: make -C docs html
16 changes: 8 additions & 8 deletions flytekit/bin/entrypoint.py
@@ -445,14 +445,14 @@ def _pass_through():
def execute_task_cmd(
inputs,
output_prefix,
raw_output_data_prefix,
test,
prev_checkpoint,
checkpoint_path,
dynamic_addl_distro,
dynamic_dest_dir,
resolver,
resolver_args,
raw_output_data_prefix=None,
Contributor: databricks will use a different entrypoint file uploaded to dbfs right?

Member Author: yes, sorry. I reverted it.

test=False,
prev_checkpoint=None,
checkpoint_path=None,
dynamic_addl_distro=None,
dynamic_dest_dir=None,
resolver=None,
resolver_args=None,
):
logger.info(get_version_message())
# We get weird errors if there are no click echo messages at all, so emit an empty string so that unit tests pass.
1 change: 1 addition & 0 deletions flytekit/core/python_auto_container.py
@@ -2,6 +2,7 @@

import importlib
import re
import sys
Contributor: is this needed?

Member Author: removed

from abc import ABC
from types import ModuleType
from typing import Callable, Dict, List, Optional, TypeVar, Union
58 changes: 56 additions & 2 deletions plugins/flytekit-spark/flytekitplugins/spark/models.py
@@ -1,7 +1,9 @@
import enum
import typing
from typing import Dict, Optional

from flyteidl.plugins import spark_pb2 as _spark_task
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

from flytekit.exceptions import user as _user_exceptions
from flytekit.models import common as _common
@@ -22,6 +24,9 @@ def __init__(
main_class,
spark_conf,
hadoop_conf,
databricks_conf,
databricks_token,
databricks_instance,
executor_path,
):
"""
@@ -30,29 +35,44 @@
:param application_file: The main application file to execute.
:param dict[Text, Text] spark_conf: A definition of key-value pairs for spark config for the job.
:param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job.
:param dict[Text, dict] databricks_conf: A definition of key-value pairs for databricks config for the job. Refer to https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit.
:param str databricks_token: databricks access token.
:param str databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com.
"""
self._application_file = application_file
self._spark_type = spark_type
self._main_class = main_class
self._executor_path = executor_path
self._spark_conf = spark_conf
self._hadoop_conf = hadoop_conf
self._databricks_conf = databricks_conf
self._databricks_token = databricks_token
self._databricks_instance = databricks_instance

def with_overrides(
self, new_spark_conf: typing.Dict[str, str] = None, new_hadoop_conf: typing.Dict[str, str] = None
self,
new_spark_conf: Optional[Dict[str, str]] = None,
new_hadoop_conf: Optional[Dict[str, str]] = None,
new_databricks_conf: Optional[Dict[str, Dict]] = None,
) -> "SparkJob":
if not new_spark_conf:
new_spark_conf = self.spark_conf

if not new_hadoop_conf:
new_hadoop_conf = self.hadoop_conf

if not new_databricks_conf:
new_databricks_conf = self.databricks_conf

return SparkJob(
spark_type=self.spark_type,
application_file=self.application_file,
main_class=self.main_class,
spark_conf=new_spark_conf,
hadoop_conf=new_hadoop_conf,
databricks_conf=new_databricks_conf,
databricks_token=self.databricks_token,
databricks_instance=self.databricks_instance,
executor_path=self.executor_path,
)

@@ -104,6 +124,31 @@ def hadoop_conf(self):
"""
return self._hadoop_conf

@property
def databricks_conf(self) -> Dict[str, Dict]:
"""
databricks_conf: Databricks job configuration.
Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
:rtype: dict[Text, dict[Text, Text]]
"""
return self._databricks_conf

@property
def databricks_token(self) -> str:
"""
Databricks access token
:rtype: str
"""
return self._databricks_token

@property
def databricks_instance(self) -> str:
"""
Domain name of your deployment. Use the form <account>.cloud.databricks.com.
:rtype: str
"""
return self._databricks_instance

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins.spark_pb2.SparkJob
@@ -120,13 +165,19 @@ def to_flyte_idl(self):
else:
raise _user_exceptions.FlyteValidationException("Invalid Spark Application Type Specified")

databricks_conf = Struct()
databricks_conf.update(self.databricks_conf)

return _spark_task.SparkJob(
applicationType=application_type,
mainApplicationFile=self.application_file,
mainClass=self.main_class,
executorPath=self.executor_path,
sparkConf=self.spark_conf,
hadoopConf=self.hadoop_conf,
databricksConf=databricks_conf,
databricksToken=self.databricks_token,
databricksInstance=self.databricks_instance,
)

@classmethod
@@ -151,4 +202,7 @@ def from_flyte_idl(cls, pb2_object):
main_class=pb2_object.mainClass,
hadoop_conf=pb2_object.hadoopConf,
executor_path=pb2_object.executorPath,
databricks_conf=json_format.MessageToDict(pb2_object.databricksConf),
databricks_token=pb2_object.databricksToken,
databricks_instance=pb2_object.databricksInstance,
)
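
For reference, here is a minimal round-trip sketch (not part of this diff) showing how the new fields flow through to_flyte_idl and back. It assumes the module's existing SparkType enum; the config values, token, and instance strings are placeholders.

from flytekitplugins.spark.models import SparkJob, SparkType

job = SparkJob(
    spark_type=SparkType.PYTHON,
    application_file="local:///usr/bin/entrypoint.py",
    main_class="",
    executor_path="/usr/bin/python3",
    spark_conf={"spark.executor.instances": "2"},
    hadoop_conf={},
    databricks_conf={
        "new_cluster": {"spark_version": "11.0.x-scala2.12", "num_workers": 4},
        "timeout_seconds": 3600,
    },
    databricks_token="<access-token>",  # placeholder
    databricks_instance="<account>.cloud.databricks.com",  # placeholder
)

idl_job = job.to_flyte_idl()  # databricks_conf is packed into a protobuf Struct
restored = SparkJob.from_flyte_idl(idl_job)  # unpacked via json_format.MessageToDict
# Note: Struct stores numbers as doubles, so integers such as num_workers and
# timeout_seconds come back as floats after the round trip.
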
14 changes: 14 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -1,3 +1,5 @@
import base64
import json
import os
import typing
from dataclasses import dataclass
@@ -23,10 +25,16 @@ class Spark(object):
Args:
spark_conf: Dictionary of spark config. The variables should match what spark expects
hadoop_conf: Dictionary of hadoop conf. The variables should match a typical hadoop configuration for spark
databricks_conf: Databricks job configuration. Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
databricks_token: Databricks access token. https://docs.databricks.com/dev-tools/api/latest/authentication.html.
databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com.
"""

spark_conf: Optional[Dict[str, str]] = None
hadoop_conf: Optional[Dict[str, str]] = None
databricks_conf: typing.Optional[Dict[str, typing.Union[str, dict]]] = None
databricks_token: Optional[str] = None
databricks_instance: Optional[str] = None

def __post_init__(self):
if self.spark_conf is None:
@@ -35,6 +43,9 @@ def __post_init__(self):
if self.hadoop_conf is None:
self.hadoop_conf = {}

if self.databricks_conf is None:
self.databricks_conf = {}


# This method does not reset the SparkSession since it's a bit hard to handle multiple
# Spark sessions in a single application as it's described in:
@@ -95,6 +106,9 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
job = SparkJob(
spark_conf=self.task_config.spark_conf,
hadoop_conf=self.task_config.hadoop_conf,
databricks_conf=self.task_config.databricks_conf,
databricks_token=self.task_config.databricks_token,
databricks_instance=self.task_config.databricks_instance,
application_file="local://" + settings.entrypoint_settings.path,
executor_path=settings.python_interpreter,
main_class="",
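
As a usage illustration (a sketch, not taken from this PR), a task could opt into the Databricks backend as below; the cluster settings, token, and instance are placeholder values.

from flytekit import task
from flytekitplugins.spark import Spark

@task(
    task_config=Spark(
        spark_conf={"spark.executor.instances": "2"},
        databricks_conf={
            "run_name": "flytekit databricks example",
            "new_cluster": {"spark_version": "11.0.x-scala2.12", "num_workers": 4},
            "timeout_seconds": 3600,
        },
        databricks_token="<access-token>",
        databricks_instance="<account>.cloud.databricks.com",
    )
)
def my_databricks_task(n: int) -> int:
    # Executes as an ordinary PySpark task when run locally; the Databricks
    # fields only change how the backend plugin submits the job.
    return n + 1

At serialization time, get_custom() folds these fields into the SparkJob message shown in models.py, so they reach the backend plugin through the task's custom field.
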
20 changes: 19 additions & 1 deletion plugins/flytekit-spark/tests/test_spark_task.py
@@ -19,14 +19,32 @@ def reset_spark_session() -> None:


def test_spark_task(reset_spark_session):
@task(task_config=Spark(spark_conf={"spark": "1"}))
databricks_conf = {
"name": "flytekit databricks plugin example",
"new_cluster": {
"spark_version": "11.0.x-scala2.12",
"node_type_id": "r3.xlarge",
"aws_attributes": {"availability": "ON_DEMAND"},
"num_workers": 4,
"docker_image": {"url": "pingsutw/databricks:latest"},
},
"timeout_seconds": 3600,
"max_retries": 1,
"spark_python_task": {
"python_file": "dbfs:///FileStore/tables/entrypoint-1.py",
"parameters": "ls",
},
}

@task(task_config=Spark(spark_conf={"spark": "1"}, databricks_conf=databricks_conf))
def my_spark(a: str) -> int:
session = flytekit.current_context().spark_session
assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
return 10

assert my_spark.task_config is not None
assert my_spark.task_config.spark_conf == {"spark": "1"}
assert my_spark.task_config.databricks_conf == databricks_conf

default_img = Image(name="default", fqn="test", tag="tag")
settings = SerializationSettings(