diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index fe38f946f9..4f4962309d 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -270,6 +270,8 @@ def setup_execution( if compressed_serialization_settings: ss = SerializationSettings.from_transport(compressed_serialization_settings) ssb = ss.new_builder() + ssb.project = exe_project + ssb.domain = exe_domain ssb.version = tk_version if dynamic_addl_distro: ssb.fast_serialization_settings = FastSerializationSettings( diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 2b21f7c40b..2cf8032a6f 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -258,7 +258,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr # The cache returns None iff the key does not exist in the cache if outputs_literal_map is None: logger.info("Cache miss, task will be executed now") - outputs_literal_map = self.dispatch_execute(ctx, input_literal_map) + outputs_literal_map = self.sandbox_execute(ctx, input_literal_map) # TODO: need `native_inputs` LocalTaskCache.set(self.name, self.metadata.cache_version, input_literal_map, outputs_literal_map) logger.info( @@ -268,10 +268,10 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr else: logger.info("Cache hit") else: - es = ctx.execution_state - b = es.user_space_params.with_task_sandbox() - ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build() - outputs_literal_map = self.dispatch_execute(ctx, input_literal_map) + # This code should mirror the call to `sandbox_execute` in the above cache case. + # Code is simpler with duplication and less metaprogramming, but introduces regressions + # if one is changed and not the other. + outputs_literal_map = self.sandbox_execute(ctx, input_literal_map) outputs_literals = outputs_literal_map.literals # TODO maybe this is the part that should be done for local execution, we pass the outputs to some special @@ -326,6 +326,19 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str] """ return None + def sandbox_execute( + self, + ctx: FlyteContext, + input_literal_map: _literal_models.LiteralMap, + ) -> _literal_models.LiteralMap: + """ + Call dispatch_execute, in the context of a local sandbox execution. Not invoked during runtime. + """ + es = ctx.execution_state + b = es.user_space_params.with_task_sandbox() + ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build() + return self.dispatch_execute(ctx, input_literal_map) + @abstractmethod def dispatch_execute( self, diff --git a/setup.py b/setup.py index f560e6c112..74a466394b 100644 --- a/setup.py +++ b/setup.py @@ -1,16 +1,5 @@ -import sys - from setuptools import find_packages, setup # noqa -MIN_PYTHON_VERSION = (3, 7) -CURRENT_PYTHON = sys.version_info[:2] -if CURRENT_PYTHON < MIN_PYTHON_VERSION: - print( - f"Flytekit API is only supported for Python version is {MIN_PYTHON_VERSION}+. Detected you are on" - f" version {CURRENT_PYTHON}, installation will not proceed!" - ) - sys.exit(-1) - extras_require = {} __version__ = "0.0.0+develop" @@ -92,7 +81,7 @@ "flytekit/bin/entrypoint.py", ], license="apache2", - python_requires=">=3.7", + python_requires=">=3.7,<3.11", classifiers=[ "Intended Audience :: Science/Research", "Intended Audience :: Developers", diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 6a8b8c430e..6dd1785585 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -7,7 +7,6 @@ from flyteidl.core.errors_pb2 import ErrorDocument from flytekit.bin.entrypoint import _dispatch_execute, normalize_inputs, setup_execution -from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.base_task import IgnoreOutputs from flytekit.core.data_persistence import DiskPersistence @@ -324,22 +323,6 @@ def test_setup_cloud_prefix(): assert isinstance(ctx.file_access._default_remote, GCSPersistence) -def test_persist_ss(): - default_img = Image(name="default", fqn="test", tag="tag") - ss = SerializationSettings( - project="proj1", - domain="dom", - version="version123", - env=None, - image_config=ImageConfig(default_image=default_img, images=[default_img]), - ) - ss_txt = ss.serialized_context - os.environ["_F_SS_C"] = ss_txt - with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx: - assert ctx.serialization_settings.project == "proj1" - assert ctx.serialization_settings.domain == "dom" - - def test_normalize_inputs(): assert normalize_inputs("{{.rawOutputDataPrefix}}", "{{.checkpointOutputPrefix}}", "{{.prevCheckpointPrefix}}") == ( None, diff --git a/tests/flytekit/unit/core/test_checkpoint.py b/tests/flytekit/unit/core/test_checkpoint.py index f1dbbbd5ef..2add1b9e7d 100644 --- a/tests/flytekit/unit/core/test_checkpoint.py +++ b/tests/flytekit/unit/core/test_checkpoint.py @@ -4,6 +4,7 @@ import flytekit from flytekit.core.checkpointer import SyncCheckpoint +from flytekit.core.local_cache import LocalTaskCache def test_sync_checkpoint_write(tmpdir): @@ -123,5 +124,23 @@ def t1(n: int) -> int: return n + 1 +@flytekit.task(cache=True, cache_version="v0") +def t2(n: int) -> int: + ctx = flytekit.current_context() + cp = ctx.checkpoint + cp.write(bytes(n + 1)) + return n + 1 + + +@pytest.fixture(scope="function", autouse=True) +def setup(): + LocalTaskCache.initialize() + LocalTaskCache.clear() + + def test_checkpoint_task(): assert t1(n=5) == 6 + + +def test_checkpoint_cached_task(): + assert t2(n=5) == 6