diff --git a/flytekit/core/checkpointer.py b/flytekit/core/checkpointer.py index d0fdf129e4..eae7cebd39 100644 --- a/flytekit/core/checkpointer.py +++ b/flytekit/core/checkpointer.py @@ -3,7 +3,7 @@ import typing from abc import abstractmethod from pathlib import Path - +import concurrent.futures as cf class Checkpoint(object): """ @@ -82,6 +82,7 @@ def __init__(self, checkpoint_dest: str, checkpoint_src: typing.Optional[str] = self._checkpoint_src = checkpoint_src if checkpoint_src and checkpoint_src != "" else None self._td = tempfile.TemporaryDirectory() self._prev_download_path: typing.Optional[Path] = None + self._torch_checkpoint: TorchAsyncCheckpoint = None def __del__(self): self._td.cleanup() @@ -113,8 +114,14 @@ def restore(self, path: typing.Optional[typing.Union[Path, str]] = None) -> typi self._prev_download_path = path return self._prev_download_path - def save(self, cp: typing.Union[Path, str, io.BufferedReader]): + def save(self, cp: typing.Union[Path, str, io.BufferedReader], future: cf.Future=None): # We have to lazy load, until we fix the imports + if future is not None: + if self._torch_checkpoint is None: + self._torch_checkpoint = TorchAsyncCheckpoint(self._checkpoint_dest, self._checkpoint_src) + self._torch_checkpoint.save(cp, future) + return + from flytekit.core.context_manager import FlyteContextManager fa = FlyteContextManager.current_context().file_access @@ -156,3 +163,56 @@ def write(self, b: bytes): p = io.BytesIO(b) f = typing.cast(io.BufferedReader, p) self.save(f) + + +class TorchAsyncCheckpoint(SyncCheckpoint): + """ + This class is NOT THREAD-SAFE! + Sync Checkpoint, will synchronously checkpoint a user given file or folder. + It will also synchronously download / restore previous checkpoints, when restore is invoked. + + TODO: Implement an async checkpoint system + """ + + SRC_LOCAL_FOLDER = "prev_cp" + TMP_DST_PATH = "_dst_cp" + + def __init__(self, checkpoint_dest: str, checkpoint_src: typing.Optional[str] = None): + """ + Args: + checkpoint_src: If a previous checkpoint should exist, this path should be set to the folder that contains the checkpoint information + checkpoint_dest: Location where the new checkpoint should be copied to + """ + super().__init__(checkpoint_dest, checkpoint_src) + self._async_upload: cf.Future = None + + def __del__(self): + super().__del__() + if self._async_upload: + self._async_upload.cancel() + + def _on_local_saved(self, cp: typing.Union[Path, str], fut: cf.Future): + # We have to lazy load, until we fix the imports + from flytekit.core.context_manager import FlyteContextManager + # wait for the checkpoint to be saved + fut.result() + print("local saved") + fa = FlyteContextManager.current_context().file_access + if isinstance(cp, str): + cp = Path(cp) + if cp.is_dir(): + fa.upload_directory(str(cp), self._checkpoint_dest) + return + + def save(self, cp: typing.Union[Path, str], future: cf.Future): + # We have to lazy load, until we fix the imports + from flytekit.core.context_manager import FlyteContextManager + + if self._async_upload: + self._async_upload.result() + print("remote saved") + executor = cf.ThreadPoolExecutor(max_workers=1) + self._async_upload = executor.submit(self._on_local_saved, cp, future) + return + + diff --git a/plugins/flytekit-kf-pytorch/dev-requirements.in b/plugins/flytekit-kf-pytorch/dev-requirements.in index 12c6d5d5ea..08ed5eeb4b 100644 --- a/plugins/flytekit-kf-pytorch/dev-requirements.in +++ b/plugins/flytekit-kf-pytorch/dev-requirements.in @@ -1 +1 @@ -torch +torch \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6b50981faf..e5d7222c0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "flytekit" -dynamic = ["version"] +version = "1.15.0" authors = [{ name = "Flyte Contributors", email = "admin@flyte.org" }] description = "Flyte SDK for Python" license = { text = "Apache-2.0" }