From ac7b7e8b4801ccaf30fa51768750a805ba5540b8 Mon Sep 17 00:00:00 2001 From: Chun-Mao Lai <72752478+Mecoli1219@users.noreply.github.com> Date: Sat, 12 Oct 2024 19:26:38 -0700 Subject: [PATCH] Revise Pickle Remote Task for Jupyter Notebook Environment (#2799) Signed-off-by: Mecoli1219 --- flytekit/configuration/__init__.py | 5 - flytekit/core/promise.py | 8 -- flytekit/core/python_auto_container.py | 8 +- flytekit/core/tracker.py | 9 +- flytekit/remote/remote.py | 102 ++++++++++++++---- flytekit/tools/translator.py | 63 +---------- .../unit/core/test_array_node_map_task.py | 49 --------- .../unit/core/test_context_manager.py | 2 +- tests/flytekit/unit/core/test_promise.py | 18 ---- .../unit/core/test_python_auto_container.py | 33 +----- tests/flytekit/unit/core/test_resolver.py | 28 ++++- tests/flytekit/unit/test_translator.py | 17 --- 12 files changed, 128 insertions(+), 214 deletions(-) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 6dab3e0cb0..97a9940425 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -828,7 +828,6 @@ class SerializationSettings(DataClassJsonMixin): can be fast registered (and thus omit building a Docker image) this object contains additional parameters for serialization. source_root (Optional[str]): The root directory of the source code. - interactive_mode_enabled (bool): Whether or not the task is being serialized in interactive mode. """ image_config: ImageConfig @@ -841,7 +840,6 @@ class SerializationSettings(DataClassJsonMixin): flytekit_virtualenv_root: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None source_root: Optional[str] = None - interactive_mode_enabled: bool = False def __post_init__(self): if self.flytekit_virtualenv_root is None: @@ -916,7 +914,6 @@ def new_builder(self) -> Builder: python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, source_root=self.source_root, - interactive_mode_enabled=self.interactive_mode_enabled, ) def should_fast_serialize(self) -> bool: @@ -968,7 +965,6 @@ class Builder(object): python_interpreter: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None source_root: Optional[str] = None - interactive_mode_enabled: bool = False def with_fast_serialization_settings(self, fss: fast_serialization_settings) -> SerializationSettings.Builder: self.fast_serialization_settings = fss @@ -986,5 +982,4 @@ def build(self) -> SerializationSettings: python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, source_root=self.source_root, - interactive_mode_enabled=self.interactive_mode_enabled, ) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index a193c0702f..e959be9106 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -666,14 +666,6 @@ def _append_attr(self, key) -> Promise: return new_promise - def __getstate__(self) -> Dict[str, Any]: - # This func is used to pickle the object. - return vars(self) - - def __setstate__(self, state: Dict[str, Any]) -> None: - # This func is used to unpickle the object without infinite recursion. 
-        vars(self).update(state)
-
 
 def create_native_named_tuple(
     ctx: FlyteContext,
diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py
index e644f87db3..1466c351ac 100644
--- a/flytekit/core/python_auto_container.py
+++ b/flytekit/core/python_auto_container.py
@@ -292,16 +292,18 @@ def name(self) -> str:
 
     @timeit("Load task")
     def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask:
+        _, entity_name, *_ = loader_args
         import gzip
 
         import cloudpickle
 
         with gzip.open(PICKLE_FILE_PATH, "r") as f:
-            return cloudpickle.load(f)
+            entity_dict = cloudpickle.load(f)
+        return entity_dict[entity_name]
 
     def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]:  # type:ignore
-        _, m, t, _ = extract_task_module(task)
-        return ["task-module", m, "task-name", t]
+        n, _, _, _ = extract_task_module(task)
+        return ["entity-name", n]
 
     def get_all_tasks(self) -> List[PythonAutoContainerTask]:  # type: ignore
         raise NotImplementedError
diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py
index 382ca4b234..c65b660ede 100644
--- a/flytekit/core/tracker.py
+++ b/flytekit/core/tracker.py
@@ -299,14 +299,15 @@ def _resolve_abs_module_name(self, path: str, package_root: typing.Optional[str]
 
         # Execution in a Jupyter notebook, we cannot resolve the module path
         if not os.path.exists(dirname):
-            logger.debug(
-                f"Directory {dirname} does not exist. It is likely that we are in a Jupyter notebook or a pickle file was received."
-            )
-
             if not is_ipython_or_pickle_exists():
                 raise AssertionError(
                     f"Directory {dirname} does not exist, and we are not in a Jupyter notebook or received a pickle file."
                 )
+
+            logger.debug(
+                f"Directory {dirname} does not exist. It is likely that we are in a Jupyter notebook or a pickle file was received."
+                f" Returning {basename} as the module name."
+ ) return basename # If we have reached a directory with no __init__, ignore diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 0528d0d155..98374ff26c 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -10,6 +10,7 @@ import base64 import configparser import functools +import gzip import hashlib import os import pathlib @@ -37,12 +38,18 @@ from flytekit.configuration import Config, FastSerializationSettings, ImageConfig, SerializationSettings from flytekit.constants import CopyFileDetection from flytekit.core import constants, utils +from flytekit.core.array_node_map_task import ArrayNodeMapTask from flytekit.core.artifact import Artifact from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan -from flytekit.core.python_auto_container import PythonAutoContainerTask +from flytekit.core.node import Node as CoreNode +from flytekit.core.python_auto_container import ( + PICKLE_FILE_PATH, + PythonAutoContainerTask, + default_notebook_task_resolver, +) from flytekit.core.reference_entity import ReferenceSpec from flytekit.core.task import ReferenceTask from flytekit.core.tracker import extract_task_module @@ -791,10 +798,6 @@ async def _serialize_and_register( ) if serialization_settings.version is None: serialization_settings.version = version - serialization_settings.interactive_mode_enabled = self.interactive_mode_enabled - - options = options or Options() - options.file_uploader = options.file_uploader or self.upload_file _ = get_serializable(m, settings=serialization_settings, entity=entity, options=options) # concurrent register @@ -1071,6 +1074,16 @@ def _version_from_hash( # and does not increase entropy of the hash while making it very inconvenient to copy-and-paste. return base64.urlsafe_b64encode(h.digest()).decode("ascii").rstrip("=") + def _get_image_names(self, entity: typing.Union[PythonAutoContainerTask, WorkflowBase]) -> typing.List[str]: + if isinstance(entity, PythonAutoContainerTask) and isinstance(entity.container_image, ImageSpec): + return [entity.container_image.image_name()] + if isinstance(entity, WorkflowBase): + image_names = [] + for n in entity.nodes: + image_names.extend(self._get_image_names(n.flyte_entity)) + return image_names + return [] + def register_script( self, entity: typing.Union[WorkflowBase, PythonTask], @@ -1144,17 +1157,6 @@ def register_script( ) if version is None: - - def _get_image_names(entity: typing.Union[PythonAutoContainerTask, WorkflowBase]) -> typing.List[str]: - if isinstance(entity, PythonAutoContainerTask) and isinstance(entity.container_image, ImageSpec): - return [entity.container_image.image_name()] - if isinstance(entity, WorkflowBase): - image_names = [] - for n in entity.nodes: - image_names.extend(_get_image_names(n.flyte_entity)) - return image_names - return [] - default_inputs = None if isinstance(entity, WorkflowBase): default_inputs = entity.python_interface.default_inputs_as_kwargs @@ -1163,7 +1165,7 @@ def _get_image_names(entity: typing.Union[PythonAutoContainerTask, WorkflowBase] # but we don't have to use it when registering with the Flyte backend. 
             # For that add the hash of the compilation settings to hash of file
             version = self._version_from_hash(
-                md5_bytes, serialization_settings, default_inputs, *_get_image_names(entity)
+                md5_bytes, serialization_settings, default_inputs, *self._get_image_names(entity)
             )
 
         if isinstance(entity, PythonTask):
@@ -1856,13 +1858,23 @@ def execute_local_task(
             not_found = True
 
         if not_found:
+            fast_serialization_settings = None
+            if self.interactive_mode_enabled:
+                md5_bytes, fast_serialization_settings = self._pickle_and_upload_entity(entity)
+
             ss = SerializationSettings(
                 image_config=image_config or ImageConfig.auto_default_image(),
                 project=project or self.default_project,
                 domain=domain or self._default_domain,
                 version=version,
+                fast_serialization_settings=fast_serialization_settings,
             )
-            flyte_task: FlyteTask = self.register_task(entity, ss)
+
+            default_inputs = entity.python_interface.default_inputs_as_kwargs
+            if version is None and self.interactive_mode_enabled:
+                version = self._version_from_hash(md5_bytes, ss, default_inputs, *self._get_image_names(entity))
+
+            flyte_task: FlyteTask = self.register_task(entity, ss, version)
 
         return self.execute(
             flyte_task,
@@ -1923,11 +1935,16 @@ def execute_local_workflow(
         if not image_config:
             image_config = ImageConfig.auto_default_image()
 
+        fast_serialization_settings = None
+        if self.interactive_mode_enabled:
+            md5_bytes, fast_serialization_settings = self._pickle_and_upload_entity(entity)
+
         ss = SerializationSettings(
             image_config=image_config,
             project=resolved_identifiers.project,
             domain=resolved_identifiers.domain,
             version=resolved_identifiers.version,
+            fast_serialization_settings=fast_serialization_settings,
         )
         try:
             # Just fetch to see if it already exists
             self.fetch_workflow(**resolved_identifiers_dict)
         except FlyteEntityNotExistException:
             logger.info("Registering workflow because it wasn't found in Flyte Admin.")
+            default_inputs = entity.python_interface.default_inputs_as_kwargs
+            if not version and self.interactive_mode_enabled:
+                version = self._version_from_hash(md5_bytes, ss, default_inputs, *self._get_image_names(entity))
             self.register_workflow(entity, ss, version=version, options=options)
 
         try:
@@ -2551,3 +2571,49 @@ def download(
             lm = data
         for var, literal in lm.items():
             download_literal(self.file_access, var, literal, download_to)
+
+    def _get_pickled_target_dict(self, root_entity: typing.Any) -> typing.Dict[str, typing.Any]:
+        """
+        Get the pickled target dictionary for the entity.
+        :param root_entity: The entity to get the pickled target for.
+        :return: The pickled target dictionary.
+        """
+        queue = [root_entity]
+        pickled_target_dict = {}
+        while queue:
+            entity = queue.pop()
+            if isinstance(entity, PythonTask):
+                if isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)):
+                    if isinstance(entity, ArrayNodeMapTask):
+                        entity._run_task.set_resolver(default_notebook_task_resolver)
+                        pickled_target_dict[entity._run_task.name] = entity._run_task
+                    else:
+                        entity.set_resolver(default_notebook_task_resolver)
+                        pickled_target_dict[entity.name] = entity
+            elif isinstance(entity, WorkflowBase):
+                for task in entity.nodes:
+                    queue.append(task)
+            elif isinstance(entity, CoreNode):
+                queue.append(entity.flyte_entity)
+        return pickled_target_dict
+
+    def _pickle_and_upload_entity(self, entity: typing.Any) -> typing.Tuple[bytes, FastSerializationSettings]:
+        """
+        Pickle all tasks reachable from the entity and upload the gzipped pickle file to remote storage.
+        :param entity: The entity to pickle
+        :return: The md5 bytes of the uploaded pickle file and the matching FastSerializationSettings
+        """
+        # Collect every task reachable from the entity
+        pickled_dict = self._get_pickled_target_dict(entity)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            dest = pathlib.Path(tmp_dir, PICKLE_FILE_PATH)
+            with gzip.GzipFile(filename=dest, mode="wb", mtime=0) as gzipped:
+                cloudpickle.dump(pickled_dict, gzipped)
+            if os.path.getsize(dest) > 150 * 1024 * 1024:
+                raise ValueError(
+                    "The size of the pickled task exceeds the limit of 150MB. Please reduce the size of the task."
+                )
+            logger.debug(f"Uploading Pickled representation of Workflow `{entity.name}` to remote storage...")
+            md5_bytes, native_url = self.upload_file(dest)
+
+        return md5_bytes, FastSerializationSettings(enabled=True, distribution_location=native_url, destination_dir=".")
diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py
index 9800d0eee3..812a71b5d3 100644
--- a/flytekit/tools/translator.py
+++ b/flytekit/tools/translator.py
@@ -1,7 +1,4 @@
-import os
-import pathlib
 import sys
-import tempfile
 import typing
 from collections import OrderedDict
 from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -23,16 +20,12 @@
 from flytekit.core.node import Node
 from flytekit.core.options import Options
 from flytekit.core.python_auto_container import (
-    PICKLE_FILE_PATH,
     PythonAutoContainerTask,
-    default_notebook_task_resolver,
 )
 from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec, ReferenceTemplate
 from flytekit.core.task import ReferenceTask
 from flytekit.core.utils import ClassDecorator, _dnsify
 from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase
-from flytekit.exceptions.user import FlyteAssertion
-from flytekit.loggers import logger
 from flytekit.models import common as _common_models
 from flytekit.models import interface as interface_models
 from flytekit.models import launch_plan as _launch_plan_models
@@ -127,52 +120,6 @@ def fn(settings: SerializationSettings) -> List[str]:
     return fn
 
 
-def _update_serialization_settings_for_ipython(
-    entity: FlyteLocalEntity,
-    serialization_settings: SerializationSettings,
-    options: Optional[Options] = None,
-):
-    # We are in an interactive environment. We will serialize the task as a pickled object and upload it to remote
-    # storage.
-    if isinstance(entity, PythonFunctionTask):
-        if entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC:
-            raise FlyteAssertion(
-                f"Dynamic tasks are not supported in interactive mode. {entity.name} is a dynamic task."
-            )
-
-    if options is None or options.file_uploader is None:
-        raise FlyteAssertion("To work interactively with Flyte, a code transporter/uploader should be configured.")
-
-    # For map tasks, we need to serialize the actual task, not the map task itself
-    if isinstance(entity, ArrayNodeMapTask):
-        entity._run_task.set_resolver(default_notebook_task_resolver)
-        actual_task = entity._run_task
-    else:
-        entity.set_resolver(default_notebook_task_resolver)
-        actual_task = entity
-
-    import gzip
-
-    import cloudpickle
-
-    from flytekit.configuration import FastSerializationSettings
-
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        dest = pathlib.Path(tmp_dir, PICKLE_FILE_PATH)
-        with gzip.GzipFile(filename=dest, mode="wb", mtime=0) as gzipped:
-            cloudpickle.dump(actual_task, gzipped)
-        if os.path.getsize(dest) > 150 * 1024 * 1024:
-            raise ValueError(
-                "The size of the task to pickled exceeds the limit of 150MB. Please reduce the size of the task."
- ) - logger.debug(f"Uploading Pickled representation of Task `{actual_task.name}` to remote storage...") - _, native_url = options.file_uploader(dest) - - serialization_settings.fast_serialization_settings = FastSerializationSettings( - enabled=True, distribution_location=native_url, destination_dir="." - ) - - def get_serializable_task( entity_mapping: OrderedDict, settings: SerializationSettings, @@ -187,14 +134,6 @@ def get_serializable_task( settings.version, ) - # Try to update the serialization settings for ipython / jupyter notebook / interactive mode if we are in an - # interactive environment like Jupyter notebook - if settings.interactive_mode_enabled is True: - # If the entity is not a PythonAutoContainerTask, we don't need to do anything, as only Tasks with container | - # user code in container needs to be serialized as pickled objects. - if isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)): - _update_serialization_settings_for_ipython(entity, settings, options) - if isinstance(entity, PythonFunctionTask) and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC: for e in context_manager.FlyteEntities.entities: if isinstance(e, PythonAutoContainerTask): @@ -800,7 +739,7 @@ def get_serializable( cp_entity = get_reference_spec(entity_mapping, settings, entity) elif isinstance(entity, PythonTask): - cp_entity = get_serializable_task(entity_mapping, settings, entity, options) + cp_entity = get_serializable_task(entity_mapping, settings, entity) elif isinstance(entity, WorkflowBase): cp_entity = get_serializable_workflow(entity_mapping, settings, entity, options) diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index 1d365d6629..dcde77e8cd 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -154,55 +154,6 @@ def t1(a: int) -> int: ] -def test_interactive_serialization(interactive_serialization_settings): - @task - def t1(a: int) -> int: - return a + 1 - - def mock_file_uploader(dest: pathlib.Path): - return (0, dest.name) - - arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2)) - option = Options() - option.file_uploader = mock_file_uploader - task_spec = get_serializable(OrderedDict(), interactive_serialization_settings, arraynode_maptask, options=option) - - assert task_spec.template.metadata.retries.retries == 2 - assert task_spec.template.custom["minSuccessRatio"] == 1.0 - assert task_spec.template.type == "python-task" - assert task_spec.template.task_type_version == 1 - assert task_spec.template.container.args == [ - "pyflyte-fast-execute", - "--additional-distribution", - PICKLE_FILE_PATH, - "--dest-dir", - ".", - "--", - "pyflyte-map-execute", - "--inputs", - "{{.input}}", - "--output-prefix", - "{{.outputPrefix}}", - "--raw-output-data-prefix", - "{{.rawOutputDataPrefix}}", - "--checkpoint-path", - "{{.checkpointOutputPrefix}}", - "--prev-checkpoint", - "{{.prevCheckpointPrefix}}", - "--resolver", - "flytekit.core.array_node_map_task.ArrayNodeMapTaskResolver", - "--", - "vars", - "", - "resolver", - "flytekit.core.python_auto_container.default_notebook_task_resolver", - "task-module", - "tests.flytekit.unit.core.test_array_node_map_task", - "task-name", - "t1", - ] - - def test_fast_serialization(serialization_settings): serialization_settings.fast_serialization_settings = FastSerializationSettings(enabled=True) diff --git a/tests/flytekit/unit/core/test_context_manager.py 
b/tests/flytekit/unit/core/test_context_manager.py index 8379a3d3eb..70a2d552d8 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -267,7 +267,7 @@ def test_serialization_settings_transport(): ss = SerializationSettings.from_transport(tp) assert ss is not None assert ss == serialization_settings - assert len(tp) == 432 + assert len(tp) == 408 def test_exec_params(): diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 59faefdc38..8154f9994e 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -273,21 +273,3 @@ async def test_prom_with_union_literals(): assert bd.scalar.union.stored_type.structure.tag == "int" bd = await binding_data_from_python_std(ctx, lt, "hello", pt, []) assert bd.scalar.union.stored_type.structure.tag == "str" - -def test_pickling_promise_object(): - @task - def t1(a: int) -> int: - return a - - ctx = context_manager.FlyteContext.current_context().with_compilation_state(CompilationState(prefix="")) - p = create_and_link_node(ctx, t1, a=3) - assert p.ref.node_id == "n0" - assert p.ref.var == "o0" - assert len(p.ref.node.bindings) == 1 - - import cloudpickle - - p2 = cloudpickle.loads(cloudpickle.dumps(p)) - assert p2.ref.node_id == "n0" - assert p2.ref.var == "o0" - assert len(p2.ref.node.bindings) == 1 diff --git a/tests/flytekit/unit/core/test_python_auto_container.py b/tests/flytekit/unit/core/test_python_auto_container.py index 2f05c1227b..331a2e0561 100644 --- a/tests/flytekit/unit/core/test_python_auto_container.py +++ b/tests/flytekit/unit/core/test_python_auto_container.py @@ -8,7 +8,7 @@ from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.base_task import TaskMetadata from flytekit.core.pod_template import PodTemplate -from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image, PICKLE_FILE_PATH +from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image, default_notebook_task_resolver from flytekit.core.resources import Resources from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec from flytekit.tools.translator import get_serializable_task, Options @@ -43,13 +43,6 @@ def minimal_serialization_settings_no_default_image(no_default_image_config): return SerializationSettings(project="p", domain="d", version="v", image_config=no_default_image_config) -@pytest.fixture -def interactive_serialization_settings(default_image_config): - return SerializationSettings( - project="p", domain="d", version="v", image_config=default_image_config, env={"FOO": "bar"}, interactive_mode_enabled=True - ) - - @pytest.fixture( params=[ "default_serialization_settings", @@ -142,26 +135,6 @@ def test_get_container_without_serialization_settings_envvars(minimal_serializat assert ts.template.container.env == {"HAM": "spam"} -def test_get_container_with_interactive_settings(interactive_serialization_settings): - c = task_with_env_vars.get_container(interactive_serialization_settings) - assert c.image == "docker.io/xyz:some-git-hash" - assert c.env == {"FOO": "bar", "HAM": "spam"} - - def mock_file_uploader(dest: pathlib.Path): - return (0, dest.name) - - option = Options() - option.file_uploader = mock_file_uploader - ts = get_serializable_task(OrderedDict(), interactive_serialization_settings, task_with_env_vars, options=option) - assert ts.template.container.image == 
"docker.io/xyz:some-git-hash" - assert ts.template.container.env == {"FOO": "bar", "HAM": "spam"} - assert 'flytekit.core.python_auto_container.default_notebook_task_resolver' in ts.template.container.args - assert interactive_serialization_settings.fast_serialization_settings is not None - assert interactive_serialization_settings.fast_serialization_settings.enabled is True - assert interactive_serialization_settings.fast_serialization_settings.destination_dir == "." - assert interactive_serialization_settings.fast_serialization_settings.distribution_location == PICKLE_FILE_PATH - - task_with_pod_template = DummyAutoContainerTask( name="x", metadata=TaskMetadata( @@ -409,3 +382,7 @@ def test_pod_template_with_image_spec(default_serialization_settings, mock_image pod = image_spec_task.get_k8s_pod(default_serialization_settings) assert pod.pod_spec["containers"][0]["image"] == image_spec_1.image_name() assert pod.pod_spec["containers"][1]["image"] == image_spec_2.image_name() + +def test_set_resolver(): + task.set_resolver(default_notebook_task_resolver) + assert task._task_resolver == default_notebook_task_resolver diff --git a/tests/flytekit/unit/core/test_resolver.py b/tests/flytekit/unit/core/test_resolver.py index 09d016fbbc..116b1251ae 100644 --- a/tests/flytekit/unit/core/test_resolver.py +++ b/tests/flytekit/unit/core/test_resolver.py @@ -1,13 +1,15 @@ import typing from collections import OrderedDict +import cloudpickle +import mock import pytest import flytekit.configuration from flytekit.configuration import Image, ImageConfig from flytekit.core.base_task import TaskResolverMixin from flytekit.core.class_based_resolver import ClassStorageTaskResolver -from flytekit.core.python_auto_container import default_task_resolver +from flytekit.core.python_auto_container import default_task_resolver, default_notebook_task_resolver, PICKLE_FILE_PATH from flytekit.core.task import task from flytekit.core.workflow import workflow from flytekit.tools.translator import get_serializable @@ -104,3 +106,27 @@ def test_mixin(): def test_error(): with pytest.raises(Exception): default_task_resolver.get_all_tasks() + + +@mock.patch("cloudpickle.load") # Mock cloudpickle.load and pass it as the first parameter +@mock.patch("gzip.open", new_callable=mock.mock_open) +def test_notebook_resolver(mock_gzip_open, mock_cloudpickle): + c = default_notebook_task_resolver + assert c.name() != "" + + with pytest.raises(ValueError): + c.load_task([]) + + @task + def t1(a: str, b: str) -> str: + return b + a + + assert c.loader_args(None, t1) == ["entity-name", "tests.flytekit.unit.core.test_resolver.t1"] + + pickled_dict = {"tests.flytekit.unit.core.test_resolver.t1": t1} + custom_pickled_object = cloudpickle.dumps(pickled_dict) + mock_gzip_open.return_value.read.return_value = custom_pickled_object + mock_cloudpickle.return_value = pickled_dict + + t = c.load_task(["entity-name", "tests.flytekit.unit.core.test_resolver.t1"]) + assert t == t1 diff --git a/tests/flytekit/unit/test_translator.py b/tests/flytekit/unit/test_translator.py index 213a267611..24f2c14131 100644 --- a/tests/flytekit/unit/test_translator.py +++ b/tests/flytekit/unit/test_translator.py @@ -84,23 +84,6 @@ def my_wf(a: int, b: str) -> (int, str): assert lp_model.id.name == "testlp" -def test_interactive(): - @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): - return a + 2, "world" - - b = serialization_settings.new_builder() - b.interactive_mode_enabled = True - ssettings = b.build() - - fake_file_uploader = 
lambda dest: (0, dest) - options = Options(file_uploader=fake_file_uploader) - - task_spec = get_serializable(OrderedDict(), ssettings, t1, options) - assert "--dest-dir" in task_spec.template.container.args - assert task_spec.template.container.args[task_spec.template.container.args.index("--dest-dir") + 1] == "." - - def test_fast(): @task def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
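
For context, a minimal usage sketch (not part of the patch) of how the interactive flow introduced here is expected to be driven from a Jupyter notebook. The endpoint, project, and domain values are illustrative, and the `interactive_mode_enabled` keyword is an assumption; flytekit may also detect the notebook environment automatically.

# Hypothetical notebook cell exercising the pickle-and-upload path added in flytekit/remote/remote.py.
from flytekit import task
from flytekit.configuration import Config
from flytekit.remote import FlyteRemote

remote = FlyteRemote(
    config=Config.for_endpoint("flyte.example.com"),  # hypothetical endpoint
    default_project="flytesnacks",
    default_domain="development",
    interactive_mode_enabled=True,  # assumed flag; enables _pickle_and_upload_entity during execute()
)

@task
def add_one(a: int) -> int:
    return a + 1

# execute() on a locally defined task pickles every reachable task into one gzipped dict, uploads it,
# and registers the task with default_notebook_task_resolver, so the container later loads it back by
# name via loader args of the form ["entity-name", "<qualified task name>"].
execution = remote.execute(add_one, inputs={"a": 41}, wait=True)
print(execution.outputs["o0"])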