Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Map over notebook task #1650

Merged
merged 13 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flytekit.core.constants import SdkTaskType
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.tracker import TrackedInstance
from flytekit.core.utils import timeit
from flytekit.exceptions import scopes as exception_scopes
Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(
else:
actual_task = python_function_task

if not isinstance(actual_task, PythonFunctionTask):
if not issubclass(type(actual_task), PythonTask):
raise ValueError("Map tasks can only compose of Python Functon Tasks currently")

if len(actual_task.python_interface.outputs.keys()) > 1:
Expand All @@ -76,7 +76,11 @@ def __init__(

collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs)
self._run_task: PythonFunctionTask = actual_task
_, mod, f, _ = tracker.extract_task_module(actual_task.task_function)
if issubclass(type(actual_task), PythonInstanceTask):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a difference between this and isinstance(actual_task, PythonInstanceTask)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated it to isinstance(actual_task, PythonInstanceTask)

mod = actual_task.task_type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be actual_task.instantiated_in? what happens if there are two of the same notebook tasks, named the same, with the same interface, but in two different .py files? will there be confusion?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

used lhs instead

f = actual_task.name
else:
_, mod, f, _ = tracker.extract_task_module(actual_task.task_function)
h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest()
name = f"{mod}.map_{f}_{h}"

Expand Down Expand Up @@ -271,7 +275,7 @@ def _raw_execute(self, **kwargs) -> Any:


def map_task(
task_function: typing.Union[PythonFunctionTask, functools.partial],
task_function: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial],
concurrency: int = 0,
min_success_ratio: float = 1.0,
**kwargs,
Expand Down
10 changes: 8 additions & 2 deletions plugins/flytekit-papermill/flytekitplugins/papermill/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
task_config: T = None,
inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
outputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
output_notebooks: typing.Optional[bool] = True,
**kwargs,
):
# Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used
Expand Down Expand Up @@ -165,13 +166,14 @@ def __init__(
if not os.path.exists(self._notebook_path):
raise ValueError(f"Illegal notebook path passed in {self._notebook_path}")

if outputs:
if outputs and output_notebooks:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if outputs and output_notebooks:
if output_notebooks:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if you just want a notebook? don't want to have to make a fake output.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated it

outputs.update(
{
self._IMPLICIT_OP_NOTEBOOK: self._IMPLICIT_OP_NOTEBOOK_TYPE,
self._IMPLICIT_RENDERED_NOTEBOOK: self._IMPLICIT_RENDERED_NOTEBOOK_TYPE,
}
)

super().__init__(
name,
task_config,
Expand Down Expand Up @@ -282,11 +284,15 @@ def execute(self, **kwargs) -> Any:
elif k == self._IMPLICIT_RENDERED_NOTEBOOK:
output_list.append(self.rendered_output_path)
elif k in m:
v = TypeEngine.to_python_value(ctx=FlyteContext.current_context(), lv=m[k], expected_python_type=type_v)
v = TypeEngine.to_python_value(
ctx=FlyteContext.current_context(), lv=m[k], expected_python_type=type_v
)
output_list.append(v)
else:
raise TypeError(f"Expected output {k} of type {type_v} not found in the notebook outputs")

if len(output_list) == 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if we don't do this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There will be a mismatch between the output type and the downstream task's input type.

return output_list[0]
return tuple(output_list)

def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-papermill/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = [
"flytekit>=1.3.0b2,<2.0.0",
"flytekit",
"papermill>=1.2.0",
"nbconvert>=6.0.7",
"ipykernel>=5.0.0",
Expand Down
34 changes: 22 additions & 12 deletions plugins/flytekit-papermill/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import datetime
import os
import tempfile
import typing

import pandas as pd
from flytekitplugins.papermill import NotebookTask
from flytekitplugins.pod import Pod
from kubernetes.client import V1Container, V1PodSpec

import flytekit
from flytekit import StructuredDataset, kwtypes, task
from flytekit import StructuredDataset, kwtypes, map_task, task, workflow
from flytekit.configuration import Image, ImageConfig
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile, PythonNotebook
Expand Down Expand Up @@ -43,12 +44,10 @@ def test_notebook_task_simple():
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
)

sqr, out, render = nb_simple.execute(pi=4)
sqr = nb_simple.execute(pi=4)
assert sqr == 16.0
assert nb_simple.python_interface.inputs == {"pi": float}
assert nb_simple.python_interface.outputs.keys() == {"square", "out_nb", "out_rendered_nb"}
assert nb_simple.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out")
assert nb_simple.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html")
assert (
nb_simple.get_command(settings=serialization_settings)
== nb_simple.get_container(settings=serialization_settings).args
Expand All @@ -63,15 +62,13 @@ def test_notebook_task_multi_values():
inputs=kwtypes(x=int, y=int, h=str),
outputs=kwtypes(z=int, m=int, h=str, n=datetime.datetime),
)
z, m, h, n, out, render = nb.execute(x=10, y=10, h="blah")
z, m, h, n = nb.execute(x=10, y=10, h="blah")
assert z == 20
assert m == 100
assert h == "blah world!"
assert type(n) == datetime.datetime
assert nb.python_interface.inputs == {"x": int, "y": int, "h": str}
assert nb.python_interface.outputs.keys() == {"z", "m", "h", "n", "out_nb", "out_rendered_nb"}
assert nb.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out")
assert nb.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html")


def test_notebook_task_complex():
Expand All @@ -82,14 +79,12 @@ def test_notebook_task_complex():
inputs=kwtypes(h=str, n=int, w=str),
outputs=kwtypes(h=str, w=PythonNotebook, x=X),
)
h, w, x, out, render = nb.execute(h="blah", n=10, w=_get_nb_path("nb-multi"))
h, w, x = nb.execute(h="blah", n=10, w=_get_nb_path("nb-multi"))
assert h == "blah world!"
assert w is not None
assert x.x == 10
assert nb.python_interface.inputs == {"n": int, "h": str, "w": str}
assert nb.python_interface.outputs.keys() == {"h", "w", "x", "out_nb", "out_rendered_nb"}
assert nb.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out")
assert nb.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html")


def test_notebook_deck_local_execution_doesnt_fail():
Expand All @@ -101,7 +96,7 @@ def test_notebook_deck_local_execution_doesnt_fail():
inputs=kwtypes(pi=float),
outputs=kwtypes(square=float),
)
sqr, out, render = nb.execute(pi=4)
sqr = nb.execute(pi=4)
# This is largely a no assert test to ensure render_deck never inhibits local execution.
assert nb._render_deck, "Passing render deck to init should result in private attribute being set"

Expand Down Expand Up @@ -170,5 +165,20 @@ def create_sd() -> StructuredDataset:
inputs=kwtypes(ff=FlyteFile, fd=FlyteDirectory, sd=StructuredDataset),
outputs=kwtypes(success=bool),
)
success, out, render = nb_types.execute(ff=ff, fd=fd, sd=sd)
success = nb_types.execute(ff=ff, fd=fd, sd=sd)
assert success is True, "Notebook execution failed"


def test_map_over_notebook_task():
nb_simple = NotebookTask(
name="test",
notebook_path=_get_nb_path(nb_name, abs=False),
inputs=kwtypes(a=float),
outputs=kwtypes(square=float),
)

@workflow
def wf(a: float) -> typing.List[float]:
return map_task(nb_simple)(a=[a, a, a])

assert wf(a=3.14) == [9.8596, 9.8596, 9.8596]