diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index b40b5029bb..52325ecb59 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -16,7 +16,7 @@ from flytekit.core.constants import SdkTaskType from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.interface import transform_interface_to_list_interface -from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask from flytekit.core.tracker import TrackedInstance from flytekit.core.utils import timeit from flytekit.exceptions import scopes as exception_scopes @@ -34,7 +34,7 @@ class MapPythonTask(PythonTask): def __init__( self, - python_function_task: typing.Union[PythonFunctionTask, functools.partial], + python_function_task: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial], concurrency: Optional[int] = None, min_success_ratio: Optional[float] = None, bound_inputs: Optional[Set[str]] = None, @@ -65,7 +65,10 @@ def __init__( actual_task = python_function_task if not isinstance(actual_task, PythonFunctionTask): - raise ValueError("Map tasks can only compose of Python Functon Tasks currently") + if isinstance(actual_task, PythonInstanceTask): + pass + else: + raise ValueError("Map tasks can only compose of PythonFuncton and PythonInstanceTasks currently") if len(actual_task.python_interface.outputs.keys()) > 1: raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs") @@ -76,7 +79,11 @@ def __init__( collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs) self._run_task: PythonFunctionTask = actual_task - _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) + if isinstance(actual_task, PythonInstanceTask): + mod = actual_task.task_type + f = actual_task.lhs + else: + _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest() name = f"{mod}.map_{f}_{h}" @@ -271,7 +278,7 @@ def _raw_execute(self, **kwargs) -> Any: def map_task( - task_function: typing.Union[PythonFunctionTask, functools.partial], + task_function: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial], concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs, diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index b1f472e99a..6f4ed6886c 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -133,6 +133,7 @@ def __init__( task_config: T = None, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, outputs: typing.Optional[typing.Dict[str, typing.Type]] = None, + output_notebooks: typing.Optional[bool] = True, **kwargs, ): # Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used @@ -165,13 +166,16 @@ def __init__( if not os.path.exists(self._notebook_path): raise ValueError(f"Illegal notebook path passed in {self._notebook_path}") - if outputs: + if output_notebooks: + if outputs is None: + outputs = {} outputs.update( { self._IMPLICIT_OP_NOTEBOOK: self._IMPLICIT_OP_NOTEBOOK_TYPE, self._IMPLICIT_RENDERED_NOTEBOOK: self._IMPLICIT_RENDERED_NOTEBOOK_TYPE, } ) + super().__init__( name, task_config, @@ -287,6 +291,8 @@ def execute(self, **kwargs) -> Any: else: raise TypeError(f"Expected output {k} of type {type_v} not found in the notebook outputs") + if len(output_list) == 1: + return output_list[0] return tuple(output_list) def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: diff --git a/plugins/flytekit-papermill/setup.py b/plugins/flytekit-papermill/setup.py index 33b9816081..538946a6d7 100644 --- a/plugins/flytekit-papermill/setup.py +++ b/plugins/flytekit-papermill/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.3.0b2,<2.0.0", + "flytekit", "papermill>=1.2.0", "nbconvert>=6.0.7", "ipykernel>=5.0.0", diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index 0e54e7082e..47db35793d 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -1,6 +1,7 @@ import datetime import os import tempfile +import typing import pandas as pd from flytekitplugins.papermill import NotebookTask @@ -8,7 +9,7 @@ from kubernetes.client import V1Container, V1PodSpec import flytekit -from flytekit import StructuredDataset, kwtypes, task +from flytekit import StructuredDataset, kwtypes, map_task, task, workflow from flytekit.configuration import Image, ImageConfig from flytekit.types.directory import FlyteDirectory from flytekit.types.file import FlyteFile, PythonNotebook @@ -33,6 +34,14 @@ def _get_nb_path(name: str, suffix: str = "", abs: bool = True, ext: str = ".ipy outputs=kwtypes(square=float), ) +nb_sub_task = NotebookTask( + name="test", + notebook_path=_get_nb_path(nb_name, abs=False), + inputs=kwtypes(a=float), + outputs=kwtypes(square=float), + output_notebooks=False, +) + def test_notebook_task_simple(): serialization_settings = flytekit.configuration.SerializationSettings( @@ -172,3 +181,11 @@ def create_sd() -> StructuredDataset: ) success, out, render = nb_types.execute(ff=ff, fd=fd, sd=sd) assert success is True, "Notebook execution failed" + + +def test_map_over_notebook_task(): + @workflow + def wf(a: float) -> typing.List[float]: + return map_task(nb_sub_task)(a=[a, a]) + + assert wf(a=3.14) == [9.8596, 9.8596]