-
Notifications
You must be signed in to change notification settings - Fork 301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Map over notebook task #1650
Map over notebook task #1650
Changes from 5 commits
0470d09
29f40e9
7aaa9c3
022ade9
53ec8b8
26237aa
6291806
022d366
d221d1f
4ded410
a44193f
fad7075
ab0375a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ | |
from flytekit.core.constants import SdkTaskType | ||
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager | ||
from flytekit.core.interface import transform_interface_to_list_interface | ||
from flytekit.core.python_function_task import PythonFunctionTask | ||
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask | ||
from flytekit.core.tracker import TrackedInstance | ||
from flytekit.core.utils import timeit | ||
from flytekit.exceptions import scopes as exception_scopes | ||
|
@@ -64,7 +64,7 @@ def __init__( | |
else: | ||
actual_task = python_function_task | ||
|
||
if not isinstance(actual_task, PythonFunctionTask): | ||
if not isinstance(actual_task, PythonTask) or not issubclass(type(actual_task), PythonInstanceTask): | ||
raise ValueError("Map tasks can only compose of Python Functon Tasks currently") | ||
|
||
if len(actual_task.python_interface.outputs.keys()) > 1: | ||
|
@@ -76,7 +76,11 @@ def __init__( | |
|
||
collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs) | ||
self._run_task: PythonFunctionTask = actual_task | ||
_, mod, f, _ = tracker.extract_task_module(actual_task.task_function) | ||
if hasattr(actual_task, "_IMPLICIT_OP_NOTEBOOK_TYPE"): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This seems too special-cased for notebooks; we should make it work for instance tasks in general. |
||
mod = "papermill" | ||
f = actual_task.name | ||
else: | ||
_, mod, f, _ = tracker.extract_task_module(actual_task.task_function) | ||
h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest() | ||
name = f"{mod}.map_{f}_{h}" | ||
|
||
|
@@ -271,7 +275,7 @@ def _raw_execute(self, **kwargs) -> Any: | |
|
||
|
||
def map_task( | ||
task_function: typing.Union[PythonFunctionTask, functools.partial], | ||
task_function: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial], | ||
concurrency: int = 0, | ||
min_success_ratio: float = 1.0, | ||
**kwargs, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -165,13 +165,6 @@ def __init__( | |
if not os.path.exists(self._notebook_path): | ||
raise ValueError(f"Illegal notebook path passed in {self._notebook_path}") | ||
|
||
if outputs: | ||
outputs.update( | ||
{ | ||
self._IMPLICIT_OP_NOTEBOOK: self._IMPLICIT_OP_NOTEBOOK_TYPE, | ||
self._IMPLICIT_RENDERED_NOTEBOOK: self._IMPLICIT_RENDERED_NOTEBOOK_TYPE, | ||
} | ||
) | ||
super().__init__( | ||
name, | ||
task_config, | ||
|
@@ -277,16 +270,14 @@ def execute(self, **kwargs) -> Any: | |
output_list = [] | ||
|
||
for k, type_v in self.python_interface.outputs.items(): | ||
if k == self._IMPLICIT_OP_NOTEBOOK: | ||
output_list.append(self.output_notebook_path) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are you just removing these? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. cc @rahul-theorem: I know you're using the papermill plugin; do you use these paths in the output? |
||
elif k == self._IMPLICIT_RENDERED_NOTEBOOK: | ||
output_list.append(self.rendered_output_path) | ||
elif k in m: | ||
if k in m: | ||
v = TypeEngine.to_python_value(ctx=FlyteContext.current_context(), lv=m[k], expected_python_type=type_v) | ||
output_list.append(v) | ||
else: | ||
raise TypeError(f"Expected output {k} of type {type_v} not found in the notebook outputs") | ||
|
||
if len(output_list) == 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What happens if we don't do this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There will be a mismatch between the output type and the downstream task's input type. |
||
return output_list[0] | ||
return tuple(output_list) | ||
|
||
def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need the second check; the instance is already a subclass of `PythonTask`.