diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index 304932a828..9e9535dec2 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -1,5 +1,7 @@ import json +import logging import os +import sys import typing from typing import Any @@ -84,6 +86,13 @@ class NotebookTask(PythonInstanceTask[T]): Users can access these notebooks after execution of the task locally or from remote servers. + .. note:: + + By default, print statements in your notebook won't be transmitted to the pod logs/stdout. If you would + like to have logs forwarded as the notebook executes, pass the stream_logs argument. Note that notebook + logs can be quite verbose, so ensure you are prepared for any downstream log ingestion costs + (e.g., CloudWatch). + .. todo: Implicit extraction of SparkConfiguration from the notebook is not supported. @@ -112,6 +121,7 @@ def __init__( name: str, notebook_path: str, render_deck: bool = False, + stream_logs: bool = False, task_config: T = None, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, outputs: typing.Optional[typing.Dict[str, typing.Type]] = None, @@ -132,6 +142,16 @@ def __init__( self._notebook_path = os.path.abspath(notebook_path) self._render_deck = render_deck + self._stream_logs = stream_logs + + # Send the papermill logger to stdout so that it appears in pod logs. Note that papermill doesn't allow + # injecting a logger, so we cannot redirect logs to the flyte child loggers (e.g., the userspace logger) + # and inherit their settings, but we instead must send logs to stdout directly + if self._stream_logs: + papermill_logger = logging.getLogger("papermill") + papermill_logger.addHandler(logging.StreamHandler(sys.stdout)) + # Papermill leaves the default level of DEBUG. We increase it here. 
+ papermill_logger.setLevel(logging.INFO) if not os.path.exists(self._notebook_path): raise ValueError(f"Illegal notebook path passed in {self._notebook_path}") @@ -207,7 +227,7 @@ def execute(self, **kwargs) -> Any: """ logger.info(f"Hijacking the call for task-type {self.task_type}, to call notebook.") # Execute Notebook via Papermill. - pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=kwargs) # type: ignore + pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=kwargs, log_output=self._stream_logs) # type: ignore outputs = self.extract_outputs(self.output_notebook_path) self.render_nb_html(self.output_notebook_path, self.rendered_output_path)