Add Databricks config to Spark Job #1358

Merged (12 commits, Dec 19, 2022)
Changes from 5 commits
4 changes: 4 additions & 0 deletions .github/workflows/pythonbuild.yml
@@ -46,6 +46,7 @@ jobs:
- name: Install dependencies
run: |
make setup${{ matrix.spark-version-suffix }}
pip install "git+https://github.com/flyteorg/flyteidl@databricks"
pip freeze
- name: Test with coverage
run: |
@@ -131,6 +132,7 @@ jobs:
pip install -r requirements.txt
if [ -f dev-requirements.txt ]; then pip install -r dev-requirements.txt; fi
pip install --no-deps -U https://github.com/flyteorg/flytekit/archive/${{ github.sha }}.zip#egg=flytekit
pip install "git+https://github.com/flyteorg/flyteidl@databricks"
pip freeze
- name: Test with coverage
run: |
@@ -155,6 +157,7 @@
run: |
python -m pip install --upgrade pip==21.2.4
pip install -r dev-requirements.txt
pip install "git+https://github.com/flyteorg/flyteidl@databricks"
- name: Lint
run: |
make lint
@@ -176,5 +179,6 @@ jobs:
run: |
python -m pip install --upgrade pip==21.2.4 setuptools wheel
pip install -r doc-requirements.txt
pip install "git+https://github.com/flyteorg/flyteidl@databricks"
- name: Build the documentation
run: make -C docs html
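
Note (not part of the diff): the extra pip install in each job pulls flyteidl from the databricks branch so the generated SparkJob message already carries the new Databricks field while a matching flyteidl release is pending. A minimal local sanity check, assuming that branch is installed and that the proto field is named databricksConf as the keyword argument in models.py below suggests:

from flyteidl.plugins import spark_pb2

# On the flyteidl "databricks" branch the generated SparkJob message should expose the
# new field; on a released flyteidl without it this prints False.
print("databricksConf" in spark_pb2.SparkJob.DESCRIPTOR.fields_by_name)
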
2 changes: 2 additions & 0 deletions flytekit/core/python_auto_container.py
@@ -2,6 +2,7 @@

import importlib
import re
import sys
Contributor: is this needed?

Member Author: removed

from abc import ABC
from types import ModuleType
from typing import Callable, Dict, List, Optional, TypeVar, Union
@@ -191,6 +192,7 @@ def name(self) -> str:
def load_task(self, loader_args: List[Union[T, ModuleType]]) -> PythonAutoContainerTask:
_, task_module, _, task_name, *_ = loader_args

sys.path.append(".")
Contributor: can you explain this a bit? Things like this kind of worry me. We also already have this; should we add that to entrypoint.py instead?

Member Author: removed it

task_module = importlib.import_module(task_module)
task_def = getattr(task_module, task_name)
return task_def
20 changes: 18 additions & 2 deletions plugins/flytekit-spark/flytekitplugins/spark/models.py
@@ -1,5 +1,5 @@
import enum
import typing
from typing import Dict, Optional

from flyteidl.plugins import spark_pb2 as _spark_task

@@ -22,6 +22,7 @@ def __init__(
main_class,
spark_conf,
hadoop_conf,
databricks_conf,
executor_path,
):
"""
@@ -30,29 +31,38 @@
:param application_file: The main application file to execute.
:param dict[Text, Text] spark_conf: A definition of key-value pairs for spark config for the job.
:param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job.
:param dict[Text, dict] databricks_conf: A definition of key-value pairs for databricks config for the job. Refer to https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit.
"""
self._application_file = application_file
self._spark_type = spark_type
self._main_class = main_class
self._executor_path = executor_path
self._spark_conf = spark_conf
self._hadoop_conf = hadoop_conf
self._databricks_conf = databricks_conf

def with_overrides(
self, new_spark_conf: typing.Dict[str, str] = None, new_hadoop_conf: typing.Dict[str, str] = None
self,
new_spark_conf: Optional[Dict[str, str]] = None,
new_hadoop_conf: Optional[Dict[str, str]] = None,
new_databricks_conf: Optional[str] = None,
) -> "SparkJob":
if not new_spark_conf:
new_spark_conf = self.spark_conf

if not new_hadoop_conf:
new_hadoop_conf = self.hadoop_conf

if not new_databricks_conf:
new_databricks_conf = self.databricks_conf

return SparkJob(
spark_type=self.spark_type,
application_file=self.application_file,
main_class=self.main_class,
spark_conf=new_spark_conf,
hadoop_conf=new_hadoop_conf,
databricks_conf=new_databricks_conf,
executor_path=self.executor_path,
)

@@ -104,6 +114,10 @@ def hadoop_conf(self):
"""
return self._hadoop_conf

@property
def databricks_conf(self) -> str:
return self._databricks_conf

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins.spark_pb2.SparkJob
@@ -127,6 +141,7 @@ def to_flyte_idl(self):
executorPath=self.executor_path,
sparkConf=self.spark_conf,
hadoopConf=self.hadoop_conf,
databricksConf=self.databricks_conf,
)

@classmethod
@@ -151,4 +166,5 @@ def from_flyte_idl(cls, pb2_object):
main_class=pb2_object.mainClass,
hadoop_conf=pb2_object.hadoopConf,
executor_path=pb2_object.executorPath,
databricks_conf=pb2_object.databricksConf,
)
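
A small usage sketch, not taken from the PR, of how the new field travels through the model. SparkType.PYTHON is assumed to be the enum already defined earlier in models.py, and the base64-encoded JSON string form mirrors what get_custom in task.py below produces:

import base64
import json

from flytekitplugins.spark.models import SparkJob, SparkType

databricks_conf = {"timeout_seconds": 3600, "max_retries": 1}

job = SparkJob(
    spark_type=SparkType.PYTHON,  # assumed existing enum member
    application_file="local:///usr/bin/entrypoint.py",
    main_class="",
    spark_conf={"spark.driver.memory": "1g"},
    hadoop_conf={},
    # At this stage of the PR the value is carried as a base64-encoded JSON string.
    databricks_conf=base64.b64encode(json.dumps(databricks_conf).encode()).decode(),
    executor_path="/usr/bin/python3",
)

# with_overrides keeps the Databricks config unless a new value is supplied.
assert job.with_overrides(new_spark_conf={"spark.executor.cores": "2"}).databricks_conf == job.databricks_conf
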
9 changes: 8 additions & 1 deletion plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -1,3 +1,5 @@
import base64
import json
import os
import typing
from dataclasses import dataclass
@@ -27,6 +29,7 @@ class Spark(object):

spark_conf: Optional[Dict[str, str]] = None
hadoop_conf: Optional[Dict[str, str]] = None
databricks_conf: typing.Optional[Dict[str, typing.Union[str, dict]]] = None

def __post_init__(self):
if self.spark_conf is None:
@@ -35,6 +38,9 @@ def __post_init__(self):
if self.hadoop_conf is None:
self.hadoop_conf = {}

if self.databricks_conf is None:
self.databricks_conf = {}


# This method does not reset the SparkSession since it's a bit hard to handle multiple
# Spark sessions in a single application as it's described in:
@@ -95,6 +101,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
job = SparkJob(
spark_conf=self.task_config.spark_conf,
hadoop_conf=self.task_config.hadoop_conf,
databricks_conf=base64.b64encode(json.dumps(self.task_config.databricks_conf).encode()).decode(),
application_file="local://" + settings.entrypoint_settings.path,
executor_path=settings.python_interpreter,
main_class="",
@@ -107,7 +114,7 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:

ctx = FlyteContextManager.current_context()
sess_builder = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {user_params.execution_id}")
if not (ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION):
if self.task_config.spark_conf and not (ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION):
# If either of above cases is not true, then we are in local execution of this task
# Add system spark-conf for local/notebook based execution.
spark_conf = _pyspark.SparkConf()
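
For clarity, a sketch of the transformation get_custom applies to databricks_conf before handing it to SparkJob: the dict is serialized to JSON and base64-encoded, presumably so a backend plugin can decode it back into a Jobs Runs Submit payload. The round trip is lossless:

import base64
import json

databricks_conf = {"new_cluster": {"num_workers": 4}, "timeout_seconds": 3600}

# Same encoding as in get_custom above.
encoded = base64.b64encode(json.dumps(databricks_conf).encode()).decode()

# Decoding recovers the original dict unchanged.
assert json.loads(base64.b64decode(encoded)) == databricks_conf
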
20 changes: 19 additions & 1 deletion plugins/flytekit-spark/tests/test_spark_task.py
@@ -19,14 +19,32 @@ def reset_spark_session() -> None:


def test_spark_task(reset_spark_session):
@task(task_config=Spark(spark_conf={"spark": "1"}))
databricks_conf = {
"name": "flytekit databricks plugin example",
"new_cluster": {
"spark_version": "11.0.x-scala2.12",
"node_type_id": "r3.xlarge",
"aws_attributes": {"availability": "ON_DEMAND"},
"num_workers": 4,
"docker_image": {"url": "pingsutw/databricks:latest"},
},
"timeout_seconds": 3600,
"max_retries": 1,
"spark_python_task": {
"python_file": "dbfs:///FileStore/tables/entrypoint-1.py",
"parameters": "ls",
},
}

@task(task_config=Spark(spark_conf={"spark": "1"}, databricks_conf=databricks_conf))
def my_spark(a: str) -> int:
session = flytekit.current_context().spark_session
assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
return 10

assert my_spark.task_config is not None
assert my_spark.task_config.spark_conf == {"spark": "1"}
assert my_spark.task_config.databricks_conf == databricks_conf

default_img = Image(name="default", fqn="test", tag="tag")
settings = SerializationSettings(
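
For reference, a hedged end-user sketch mirroring the test above: the databricks_conf dict follows the Databricks Jobs Runs Submit schema linked from models.py and is attached through the Spark task config. Cluster values are illustrative only:

from flytekit import task
from flytekitplugins.spark import Spark

@task(
    task_config=Spark(
        spark_conf={"spark.driver.memory": "1000M"},
        databricks_conf={
            "name": "flytekit databricks plugin example",  # illustrative values
            "new_cluster": {
                "spark_version": "11.0.x-scala2.12",
                "node_type_id": "r3.xlarge",
                "num_workers": 4,
            },
            "timeout_seconds": 3600,
            "max_retries": 1,
        },
    )
)
def hello_spark(i: int) -> int:
    return i + 1
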