Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Map over notebook task #1650

Merged
merged 13 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flytekit.core.constants import SdkTaskType
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.tracker import TrackedInstance
from flytekit.core.utils import timeit
from flytekit.exceptions import scopes as exception_scopes
Expand All @@ -34,7 +34,7 @@ class MapPythonTask(PythonTask):

def __init__(
self,
python_function_task: typing.Union[PythonFunctionTask, functools.partial],
python_function_task: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial],
concurrency: Optional[int] = None,
min_success_ratio: Optional[float] = None,
bound_inputs: Optional[Set[str]] = None,
Expand Down Expand Up @@ -65,7 +65,10 @@ def __init__(
actual_task = python_function_task

if not isinstance(actual_task, PythonFunctionTask):
raise ValueError("Map tasks can only compose of Python Functon Tasks currently")
if isinstance(actual_task, PythonInstanceTask):
pass
else:
raise ValueError("Map tasks can only compose of PythonFuncton and PythonInstanceTasks currently")

if len(actual_task.python_interface.outputs.keys()) > 1:
raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs")
Expand All @@ -76,7 +79,11 @@ def __init__(

collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs)
self._run_task: PythonFunctionTask = actual_task
_, mod, f, _ = tracker.extract_task_module(actual_task.task_function)
if isinstance(actual_task, PythonInstanceTask):
mod = actual_task.task_type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be actual_task.instantiated_in? what happens if there are two of the same notebook tasks, named the same, with the same interface, but in two different .py files? will there be confusion?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

used lhs instead

f = actual_task.lhs
else:
_, mod, f, _ = tracker.extract_task_module(actual_task.task_function)
h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest()
name = f"{mod}.map_{f}_{h}"

Expand Down Expand Up @@ -271,7 +278,7 @@ def _raw_execute(self, **kwargs) -> Any:


def map_task(
task_function: typing.Union[PythonFunctionTask, functools.partial],
task_function: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial],
concurrency: int = 0,
min_success_ratio: float = 1.0,
**kwargs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
task_config: T = None,
inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
outputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
output_notebooks: typing.Optional[bool] = True,
**kwargs,
):
# Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used
Expand Down Expand Up @@ -165,13 +166,16 @@ def __init__(
if not os.path.exists(self._notebook_path):
raise ValueError(f"Illegal notebook path passed in {self._notebook_path}")

if outputs:
if output_notebooks:
if outputs is None:
Copy link
Contributor

@wild-endeavor wild-endeavor May 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can do outputs = outputs or {}

outputs = {}
outputs.update(
{
self._IMPLICIT_OP_NOTEBOOK: self._IMPLICIT_OP_NOTEBOOK_TYPE,
self._IMPLICIT_RENDERED_NOTEBOOK: self._IMPLICIT_RENDERED_NOTEBOOK_TYPE,
}
)

super().__init__(
name,
task_config,
Expand Down Expand Up @@ -287,6 +291,8 @@ def execute(self, **kwargs) -> Any:
else:
raise TypeError(f"Expected output {k} of type {type_v} not found in the notebook outputs")

if len(output_list) == 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if we don't do this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There will be a mismatch between the output type and the downstream task's input type.

return output_list[0]
return tuple(output_list)

def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-papermill/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = [
"flytekit>=1.3.0b2,<2.0.0",
"flytekit",
"papermill>=1.2.0",
"nbconvert>=6.0.7",
"ipykernel>=5.0.0",
Expand Down
19 changes: 18 additions & 1 deletion plugins/flytekit-papermill/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import datetime
import os
import tempfile
import typing

import pandas as pd
from flytekitplugins.papermill import NotebookTask
from flytekitplugins.pod import Pod
from kubernetes.client import V1Container, V1PodSpec

import flytekit
from flytekit import StructuredDataset, kwtypes, task
from flytekit import StructuredDataset, kwtypes, map_task, task, workflow
from flytekit.configuration import Image, ImageConfig
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile, PythonNotebook
Expand All @@ -33,6 +34,14 @@ def _get_nb_path(name: str, suffix: str = "", abs: bool = True, ext: str = ".ipy
outputs=kwtypes(square=float),
)

nb_sub_task = NotebookTask(
name="test",
notebook_path=_get_nb_path(nb_name, abs=False),
inputs=kwtypes(a=float),
outputs=kwtypes(square=float),
output_notebooks=False,
)


def test_notebook_task_simple():
serialization_settings = flytekit.configuration.SerializationSettings(
Expand Down Expand Up @@ -172,3 +181,11 @@ def create_sd() -> StructuredDataset:
)
success, out, render = nb_types.execute(ff=ff, fd=fd, sd=sd)
assert success is True, "Notebook execution failed"


def test_map_over_notebook_task():
@workflow
def wf(a: float) -> typing.List[float]:
return map_task(nb_sub_task)(a=[a, a])

assert wf(a=3.14) == [9.8596, 9.8596]