Revise Pickle Remote Task for Jupyter Notebook Environment (#2799)
Signed-off-by: Mecoli1219 <[email protected]>
Mecoli1219 authored Oct 13, 2024
1 parent 3f0b218 commit ac7b7e8
Showing 12 changed files with 128 additions and 214 deletions.
5 changes: 0 additions & 5 deletions flytekit/configuration/__init__.py
@@ -828,7 +828,6 @@ class SerializationSettings(DataClassJsonMixin):
can be fast registered (and thus omit building a Docker image) this object contains additional parameters
for serialization.
source_root (Optional[str]): The root directory of the source code.
interactive_mode_enabled (bool): Whether or not the task is being serialized in interactive mode.
"""

image_config: ImageConfig
@@ -841,7 +840,6 @@ class SerializationSettings(DataClassJsonMixin):
flytekit_virtualenv_root: Optional[str] = None
fast_serialization_settings: Optional[FastSerializationSettings] = None
source_root: Optional[str] = None
interactive_mode_enabled: bool = False

def __post_init__(self):
if self.flytekit_virtualenv_root is None:
@@ -916,7 +914,6 @@ def new_builder(self) -> Builder:
python_interpreter=self.python_interpreter,
fast_serialization_settings=self.fast_serialization_settings,
source_root=self.source_root,
interactive_mode_enabled=self.interactive_mode_enabled,
)

def should_fast_serialize(self) -> bool:
@@ -968,7 +965,6 @@ class Builder(object):
python_interpreter: Optional[str] = None
fast_serialization_settings: Optional[FastSerializationSettings] = None
source_root: Optional[str] = None
interactive_mode_enabled: bool = False

def with_fast_serialization_settings(self, fss: fast_serialization_settings) -> SerializationSettings.Builder:
self.fast_serialization_settings = fss
@@ -986,5 +982,4 @@ def build(self) -> SerializationSettings:
python_interpreter=self.python_interpreter,
fast_serialization_settings=self.fast_serialization_settings,
source_root=self.source_root,
interactive_mode_enabled=self.interactive_mode_enabled,
)
8 changes: 0 additions & 8 deletions flytekit/core/promise.py
@@ -666,14 +666,6 @@ def _append_attr(self, key) -> Promise:

return new_promise

def __getstate__(self) -> Dict[str, Any]:
# This func is used to pickle the object.
return vars(self)

def __setstate__(self, state: Dict[str, Any]) -> None:
# This func is used to unpickle the object without infinite recursion.
vars(self).update(state)


def create_native_named_tuple(
ctx: FlyteContext,
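The `__getstate__`/`__setstate__` pair removed above was there to keep `Promise` picklable: a class that resolves attributes dynamically (see `_append_attr` above) can recurse forever during default unpickling, because pickle probes the half-restored instance for `__setstate__` before its `__dict__` is populated. A minimal standalone sketch of that failure mode and the `vars(...)`-based fix, using an illustrative `Proxy` class rather than flytekit's `Promise`:

import pickle

class Proxy:
    """Illustrative stand-in for an attribute-forwarding class like Promise."""

    def __init__(self, ref):
        self.ref = ref

    def __getattr__(self, item):
        # Only called for attributes missing from __dict__. During default
        # unpickling, pickle probes for "__setstate__" on an instance whose
        # __dict__ is still empty, so the self.ref lookup here would call
        # __getattr__("ref") again and recurse without the overrides below.
        return getattr(self.ref, item)

    def __getstate__(self):
        # Pickle the plain instance dict.
        return vars(self)

    def __setstate__(self, state):
        # Restore via vars() so no attribute lookup goes through __getattr__.
        vars(self).update(state)

p = pickle.loads(pickle.dumps(Proxy(ref=[1, 2, 3])))
assert p.ref == [1, 2, 3]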
8 changes: 5 additions & 3 deletions flytekit/core/python_auto_container.py
@@ -292,16 +292,18 @@ def name(self) -> str:

@timeit("Load task")
def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask:
_, entity_name, *_ = loader_args
import gzip

import cloudpickle

with gzip.open(PICKLE_FILE_PATH, "r") as f:
return cloudpickle.load(f)
entity_dict = cloudpickle.load(f)
return entity_dict[entity_name]

def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: # type:ignore
_, m, t, _ = extract_task_module(task)
return ["task-module", m, "task-name", t]
n, _, _, _ = extract_task_module(task)
return ["entity-name", n]

def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore
raise NotImplementedError
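Together, the two methods above change the notebook resolver protocol from module-and-name lookup to a name-keyed pickle map: `loader_args` now emits `["entity-name", n]`, and `load_task` indexes the gzipped cloudpickle dictionary with that name. A standalone sketch of the same round trip, with a hypothetical file path and a plain function standing in for a task:

import gzip

import cloudpickle

PICKLE_FILE_PATH = "pickled_entities.gz"  # hypothetical; flytekit defines its own constant

def dump_entities(entities: dict) -> None:
    # Registration side: serialize a name -> entity map into one gzipped pickle.
    with gzip.GzipFile(filename=PICKLE_FILE_PATH, mode="wb", mtime=0) as f:
        cloudpickle.dump(entities, f)

def load_entity(loader_args: list):
    # Execution side: mirrors load_task(["entity-name", <name>]) above.
    _, entity_name, *_ = loader_args
    with gzip.open(PICKLE_FILE_PATH, "r") as f:
        return cloudpickle.load(f)[entity_name]

dump_entities({"my_task": lambda x: x + 1})
assert load_entity(["entity-name", "my_task"])(1) == 2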
9 changes: 5 additions & 4 deletions flytekit/core/tracker.py
@@ -299,14 +299,15 @@ def _resolve_abs_module_name(self, path: str, package_root: typing.Optional[str]

# Execution in a Jupyter notebook, we cannot resolve the module path
if not os.path.exists(dirname):
logger.debug(
f"Directory {dirname} does not exist. It is likely that we are in a Jupyter notebook or a pickle file was received."
)

if not is_ipython_or_pickle_exists():
raise AssertionError(
f"Directory {dirname} does not exist, and we are not in a Jupyter notebook or received a pickle file."
)

logger.debug(
f"Directory {dirname} does not exist. It is likely that we are in a Jupyter notebook or a pickle file was received."
f"Returning {basename} as the module name."
)
return basename

# If we have reached a directory with no __init__, ignore
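The `is_ipython_or_pickle_exists` guard above is what separates a genuinely missing directory from the expected notebook case. A hedged sketch of the common way the IPython half of such a check is written; the real flytekit helper may differ (it also accounts for a received pickle file):

def in_ipython() -> bool:
    # Jupyter/IPython kernels expose get_ipython(); in a plain Python process
    # the function returns None, or IPython is not importable at all.
    try:
        from IPython import get_ipython
        return get_ipython() is not None
    except ImportError:
        return False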
102 changes: 84 additions & 18 deletions flytekit/remote/remote.py
@@ -10,6 +10,7 @@
import base64
import configparser
import functools
import gzip
import hashlib
import os
import pathlib
@@ -37,12 +38,18 @@
from flytekit.configuration import Config, FastSerializationSettings, ImageConfig, SerializationSettings
from flytekit.constants import CopyFileDetection
from flytekit.core import constants, utils
from flytekit.core.array_node_map_task import ArrayNodeMapTask
from flytekit.core.artifact import Artifact
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.node import Node as CoreNode
from flytekit.core.python_auto_container import (
PICKLE_FILE_PATH,
PythonAutoContainerTask,
default_notebook_task_resolver,
)
from flytekit.core.reference_entity import ReferenceSpec
from flytekit.core.task import ReferenceTask
from flytekit.core.tracker import extract_task_module
@@ -791,10 +798,6 @@ async def _serialize_and_register(
)
if serialization_settings.version is None:
serialization_settings.version = version
serialization_settings.interactive_mode_enabled = self.interactive_mode_enabled

options = options or Options()
options.file_uploader = options.file_uploader or self.upload_file

_ = get_serializable(m, settings=serialization_settings, entity=entity, options=options)
# concurrent register
@@ -1071,6 +1074,16 @@ def _version_from_hash(
# and does not increase entropy of the hash while making it very inconvenient to copy-and-paste.
return base64.urlsafe_b64encode(h.digest()).decode("ascii").rstrip("=")

def _get_image_names(self, entity: typing.Union[PythonAutoContainerTask, WorkflowBase]) -> typing.List[str]:
if isinstance(entity, PythonAutoContainerTask) and isinstance(entity.container_image, ImageSpec):
return [entity.container_image.image_name()]
if isinstance(entity, WorkflowBase):
image_names = []
for n in entity.nodes:
image_names.extend(self._get_image_names(n.flyte_entity))
return image_names
return []

def register_script(
self,
entity: typing.Union[WorkflowBase, PythonTask],
@@ -1144,17 +1157,6 @@ def register_script(
)

if version is None:

def _get_image_names(entity: typing.Union[PythonAutoContainerTask, WorkflowBase]) -> typing.List[str]:
if isinstance(entity, PythonAutoContainerTask) and isinstance(entity.container_image, ImageSpec):
return [entity.container_image.image_name()]
if isinstance(entity, WorkflowBase):
image_names = []
for n in entity.nodes:
image_names.extend(_get_image_names(n.flyte_entity))
return image_names
return []

default_inputs = None
if isinstance(entity, WorkflowBase):
default_inputs = entity.python_interface.default_inputs_as_kwargs
@@ -1163,7 +1165,7 @@ def _get_image_names(entity: typing.Union[PythonAutoContainerTask, WorkflowBase]
# but we don't have to use it when registering with the Flyte backend.
# For that add the hash of the compilation settings to hash of file
version = self._version_from_hash(
md5_bytes, serialization_settings, default_inputs, *_get_image_names(entity)
md5_bytes, serialization_settings, default_inputs, *self._get_image_names(entity)
)

if isinstance(entity, PythonTask):
@@ -1856,13 +1858,23 @@ def execute_local_task(
not_found = True

if not_found:
fast_serialization_settings = None
if self.interactive_mode_enabled:
md5_bytes, fast_serialization_settings = self._pickle_and_upload_entity(entity)

ss = SerializationSettings(
image_config=image_config or ImageConfig.auto_default_image(),
project=project or self.default_project,
domain=domain or self._default_domain,
version=version,
fast_serialization_settings=fast_serialization_settings,
)
flyte_task: FlyteTask = self.register_task(entity, ss)

default_inputs = entity.python_interface.default_inputs_as_kwargs
if version is None and self.interactive_mode_enabled:
version = self._version_from_hash(md5_bytes, ss, default_inputs, *self._get_image_names(entity))

flyte_task: FlyteTask = self.register_task(entity, ss, version)

return self.execute(
flyte_task,
@@ -1923,18 +1935,26 @@ def execute_local_workflow(
if not image_config:
image_config = ImageConfig.auto_default_image()

fast_serialization_settings = None
if self.interactive_mode_enabled:
md5_bytes, fast_serialization_settings = self._pickle_and_upload_entity(entity)

ss = SerializationSettings(
image_config=image_config,
project=resolved_identifiers.project,
domain=resolved_identifiers.domain,
version=resolved_identifiers.version,
fast_serialization_settings=fast_serialization_settings,
)
try:
# Just fetch to see if it already exists
# todo: Add logic to check that the fetched workflow is functionally equivalent.
self.fetch_workflow(**resolved_identifiers_dict)
except FlyteEntityNotExistException:
logger.info("Registering workflow because it wasn't found in Flyte Admin.")
default_inputs = entity.python_interface.default_inputs_as_kwargs
if not version and self.interactive_mode_enabled:
version = self._version_from_hash(md5_bytes, ss, default_inputs, *self._get_image_names(entity))
self.register_workflow(entity, ss, version=version, options=options)

try:
@@ -2551,3 +2571,49 @@ def download(
lm = data
for var, literal in lm.items():
download_literal(self.file_access, var, literal, download_to)

def _get_pickled_target_dict(self, root_entity: typing.Any) -> typing.Dict[str, typing.Any]:
"""
Walk the entity (task, workflow, or node) and collect every task that has to be pickled, keyed by
task name, pointing each collected task at the default_notebook_task_resolver.
:param root_entity: The entity to collect pickled targets for.
:return: A dictionary mapping task names to the tasks to pickle.
"""
queue = [root_entity]
pickled_target_dict = {}
while queue:
entity = queue.pop()
if isinstance(entity, PythonTask):
if isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)):
if isinstance(entity, ArrayNodeMapTask):
entity._run_task.set_resolver(default_notebook_task_resolver)
pickled_target_dict[entity._run_task.name] = entity._run_task
else:
entity.set_resolver(default_notebook_task_resolver)
pickled_target_dict[entity.name] = entity
elif isinstance(entity, WorkflowBase):
for task in entity.nodes:
queue.append(task)
elif isinstance(entity, CoreNode):
queue.append(entity.flyte_entity)
return pickled_target_dict

def _pickle_and_upload_entity(self, entity: typing.Any) -> typing.Tuple[bytes, FastSerializationSettings]:
"""
Pickle every task reachable from the entity and upload the gzipped archive to remote storage,
returning the archive's md5 digest and the FastSerializationSettings pointing at the uploaded file.
:param entity: The entity to pickle
"""
# get all entity tasks
pickled_dict = self._get_pickled_target_dict(entity)
with tempfile.TemporaryDirectory() as tmp_dir:
dest = pathlib.Path(tmp_dir, PICKLE_FILE_PATH)
with gzip.GzipFile(filename=dest, mode="wb", mtime=0) as gzipped:
cloudpickle.dump(pickled_dict, gzipped)
if os.path.getsize(dest) > 150 * 1024 * 1024:
raise ValueError(
"The size of the task to pickled exceeds the limit of 150MB. Please reduce the size of the task."
)
logger.debug(f"Uploading Pickled representation of Workflow `{entity.name}` to remote storage...")
md5_bytes, native_url = self.upload_file(dest)

return md5_bytes, FastSerializationSettings(enabled=True, distribution_location=native_url, destination_dir=".")
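Two details in `_pickle_and_upload_entity` are worth calling out: the gzip member is written with `mtime=0`, so re-pickling identical code yields byte-identical archives, and the md5 of those bytes feeds `_version_from_hash` to produce a stable version string. A loose standalone sketch of that pipeline with the upload step omitted; the 150MB guard and urlsafe-base64 encoding mirror the code above, while the helper names are hypothetical:

import base64
import gzip
import hashlib
import os
import pathlib
import tempfile

import cloudpickle

SIZE_LIMIT = 150 * 1024 * 1024  # same 150MB guard as above

def pickle_and_hash(obj) -> bytes:
    # Deterministic gzip (mtime=0): identical inputs produce identical bytes.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dest = pathlib.Path(tmp_dir, "pickled.gz")
        with gzip.GzipFile(filename=dest, mode="wb", mtime=0) as gzipped:
            cloudpickle.dump(obj, gzipped)
        if os.path.getsize(dest) > SIZE_LIMIT:
            raise ValueError("pickled payload exceeds the 150MB limit")
        return hashlib.md5(dest.read_bytes()).digest()

def version_from_hash(md5_bytes: bytes, *components: str) -> str:
    # Fold extra identifying strings (e.g. image names) into the digest, then
    # base64-encode; stripping '=' padding keeps the version copy-paste friendly.
    h = hashlib.md5(md5_bytes)
    for c in components:
        h.update(c.encode("utf-8"))
    return base64.urlsafe_b64encode(h.digest()).decode("ascii").rstrip("=")

print(version_from_hash(pickle_and_hash({"x": 1}), "cr.flyte.org/img:latest"))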
63 changes: 1 addition & 62 deletions flytekit/tools/translator.py
@@ -1,7 +1,4 @@
import os
import pathlib
import sys
import tempfile
import typing
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -23,16 +20,12 @@
from flytekit.core.node import Node
from flytekit.core.options import Options
from flytekit.core.python_auto_container import (
PICKLE_FILE_PATH,
PythonAutoContainerTask,
default_notebook_task_resolver,
)
from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec, ReferenceTemplate
from flytekit.core.task import ReferenceTask
from flytekit.core.utils import ClassDecorator, _dnsify
from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase
from flytekit.exceptions.user import FlyteAssertion
from flytekit.loggers import logger
from flytekit.models import common as _common_models
from flytekit.models import interface as interface_models
from flytekit.models import launch_plan as _launch_plan_models
@@ -127,52 +120,6 @@ def fn(settings: SerializationSettings) -> List[str]:
return fn


def _update_serialization_settings_for_ipython(
entity: FlyteLocalEntity,
serialization_settings: SerializationSettings,
options: Optional[Options] = None,
):
# We are in an interactive environment. We will serialize the task as a pickled object and upload it to remote
# storage.
if isinstance(entity, PythonFunctionTask):
if entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC:
raise FlyteAssertion(
f"Dynamic tasks are not supported in interactive mode. {entity.name} is a dynamic task."
)

if options is None or options.file_uploader is None:
raise FlyteAssertion("To work interactively with Flyte, a code transporter/uploader should be configured.")

# For map tasks, we need to serialize the actual task, not the map task itself
if isinstance(entity, ArrayNodeMapTask):
entity._run_task.set_resolver(default_notebook_task_resolver)
actual_task = entity._run_task
else:
entity.set_resolver(default_notebook_task_resolver)
actual_task = entity

import gzip

import cloudpickle

from flytekit.configuration import FastSerializationSettings

with tempfile.TemporaryDirectory() as tmp_dir:
dest = pathlib.Path(tmp_dir, PICKLE_FILE_PATH)
with gzip.GzipFile(filename=dest, mode="wb", mtime=0) as gzipped:
cloudpickle.dump(actual_task, gzipped)
if os.path.getsize(dest) > 150 * 1024 * 1024:
raise ValueError(
"The size of the task to pickled exceeds the limit of 150MB. Please reduce the size of the task."
)
logger.debug(f"Uploading Pickled representation of Task `{actual_task.name}` to remote storage...")
_, native_url = options.file_uploader(dest)

serialization_settings.fast_serialization_settings = FastSerializationSettings(
enabled=True, distribution_location=native_url, destination_dir="."
)


def get_serializable_task(
entity_mapping: OrderedDict,
settings: SerializationSettings,
@@ -187,14 +134,6 @@ def get_serializable_task(
settings.version,
)

# Try to update the serialization settings for ipython / jupyter notebook / interactive mode if we are in an
# interactive environment like Jupyter notebook
if settings.interactive_mode_enabled is True:
# If the entity is not a PythonAutoContainerTask, we don't need to do anything, as only Tasks with
# user code in the container need to be serialized as pickled objects.
if isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)):
_update_serialization_settings_for_ipython(entity, settings, options)

if isinstance(entity, PythonFunctionTask) and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC:
for e in context_manager.FlyteEntities.entities:
if isinstance(e, PythonAutoContainerTask):
@@ -800,7 +739,7 @@ def get_serializable(
cp_entity = get_reference_spec(entity_mapping, settings, entity)

elif isinstance(entity, PythonTask):
cp_entity = get_serializable_task(entity_mapping, settings, entity, options)
cp_entity = get_serializable_task(entity_mapping, settings, entity)

elif isinstance(entity, WorkflowBase):
cp_entity = get_serializable_workflow(entity_mapping, settings, entity, options)