Skip to content

Commit

Permalink
Support checkpointing in local mode from cached tasks (#1457)
Browse files Browse the repository at this point in the history
* support checkpointing in local mode from cached tasks

* clear cache before tests

---------

Co-authored-by: Stef Nelson-Lindall <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
3 people authored Feb 10, 2023
1 parent 47ac6ac commit 35b1fa6
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
23 changes: 18 additions & 5 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
# The cache returns None iff the key does not exist in the cache
if outputs_literal_map is None:
logger.info("Cache miss, task will be executed now")
outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)
# TODO: need `native_inputs`
LocalTaskCache.set(self.name, self.metadata.cache_version, input_literal_map, outputs_literal_map)
logger.info(
Expand All @@ -268,10 +268,10 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
else:
logger.info("Cache hit")
else:
es = ctx.execution_state
b = es.user_space_params.with_task_sandbox()
ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build()
outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
# This code should mirror the call to `sandbox_execute` in the above cache case.
# Code is simpler with duplication and less metaprogramming, but introduces regressions
# if one is changed and not the other.
outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)
outputs_literals = outputs_literal_map.literals

# TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
Expand Down Expand Up @@ -326,6 +326,19 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]
"""
return None

def sandbox_execute(
self,
ctx: FlyteContext,
input_literal_map: _literal_models.LiteralMap,
) -> _literal_models.LiteralMap:
"""
Call dispatch_execute, in the context of a local sandbox execution. Not invoked during runtime.
"""
es = ctx.execution_state
b = es.user_space_params.with_task_sandbox()
ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build()
return self.dispatch_execute(ctx, input_literal_map)

@abstractmethod
def dispatch_execute(
self,
Expand Down
19 changes: 19 additions & 0 deletions tests/flytekit/unit/core/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import flytekit
from flytekit.core.checkpointer import SyncCheckpoint
from flytekit.core.local_cache import LocalTaskCache


def test_sync_checkpoint_write(tmpdir):
Expand Down Expand Up @@ -123,5 +124,23 @@ def t1(n: int) -> int:
return n + 1


@flytekit.task(cache=True, cache_version="v0")
def t2(n: int) -> int:
ctx = flytekit.current_context()
cp = ctx.checkpoint
cp.write(bytes(n + 1))
return n + 1


@pytest.fixture(scope="function", autouse=True)
def setup():
LocalTaskCache.initialize()
LocalTaskCache.clear()


def test_checkpoint_task():
assert t1(n=5) == 6


def test_checkpoint_cached_task():
assert t2(n=5) == 6

0 comments on commit 35b1fa6

Please sign in to comment.