TypeTransformers for PyTorch Tensor, Module, and Checkpoint #1032

Merged · 19 commits · Jul 7, 2022
Changes from 6 commits
1 change: 1 addition & 0 deletions dev-requirements.in
@@ -11,3 +11,4 @@ codespell
google-cloud-bigquery
google-cloud-bigquery-storage
IPython
torch
15 changes: 5 additions & 10 deletions dev-requirements.txt
@@ -8,6 +8,8 @@
# via
# -c requirements.txt
# pytest-flyte
appnope==0.1.3
# via ipython
arrow==1.2.2
# via
# -c requirements.txt
@@ -76,7 +78,6 @@ cryptography==37.0.2
# -c requirements.txt
# paramiko
# pyopenssl
# secretstorage
dataclasses-json==0.5.7
# via
# -c requirements.txt
@@ -186,11 +187,6 @@ ipython==7.34.0
# via -r dev-requirements.in
jedi==0.18.1
# via ipython
jeepney==0.8.0
# via
# -c requirements.txt
# keyring
# secretstorage
jinja2==3.1.2
# via
# -c requirements.txt
@@ -403,10 +399,6 @@ retry==0.9.2
# flytekit
rsa==4.8
# via google-auth
secretstorage==3.3.2
# via
# -c requirements.txt
# keyring
singledispatchmethod==1.0
# via
# -c requirements.txt
@@ -443,6 +435,8 @@ tomli==2.0.1
# coverage
# mypy
# pytest
torch==1.11.0
# via -r dev-requirements.in
traitlets==5.2.2.post1
# via
# ipython
@@ -457,6 +451,7 @@ typing-extensions==4.2.0
# importlib-metadata
# mypy
# responses
# torch
# typing-inspect
typing-inspect==0.7.1
# via
1 change: 1 addition & 0 deletions doc-requirements.in
@@ -33,3 +33,4 @@ papermill # papermill
jupyter # papermill
pyspark # spark
sqlalchemy # sqlalchemy
torch # pytorch
16 changes: 8 additions & 8 deletions doc-requirements.txt
@@ -12,6 +12,10 @@ altair==4.2.0
# via great-expectations
ansiwrap==0.8.4
# via papermill
appnope==0.1.3
# via
# ipykernel
# ipython
argon2-cffi==21.3.0
# via notebook
argon2-cffi-bindings==21.2.0
@@ -42,7 +46,7 @@ binaryornot==0.4.4
# via cookiecutter
bleach==5.0.0
# via nbconvert
botocore==1.27.7
botocore==1.27.8
# via -r doc-requirements.in
cachetools==5.2.0
# via google-auth
@@ -77,7 +81,6 @@ cryptography==37.0.2
# -r doc-requirements.in
# great-expectations
# pyopenssl
# secretstorage
css-html-js-minify==2.5.5
# via sphinx-material
cycler==0.11.0
@@ -213,10 +216,6 @@ ipywidgets==7.7.0
# via jupyter
jedi==0.18.1
# via ipython
jeepney==0.8.0
# via
# keyring
# secretstorage
jinja2==3.0.3
# via
# altair
@@ -572,8 +571,6 @@ seaborn==0.11.2
# via
# missingno
# pandas-profiling
secretstorage==3.3.2
# via keyring
send2trash==1.8.0
# via notebook
singledispatchmethod==1.0
@@ -663,6 +660,8 @@ tinycss2==1.1.1
# via nbconvert
toolz==0.11.2
# via altair
torch==1.11.0
# via -r doc-requirements.in
tornado==6.1
# via
# ipykernel
@@ -702,6 +701,7 @@ typing-extensions==4.2.0
# pandera
# pydantic
# responses
# torch
# typing-inspect
typing-inspect==0.7.1
# via
4 changes: 4 additions & 0 deletions docs/source/types.builtins.pytorch.rst
@@ -0,0 +1,4 @@
.. automodule:: flytekit.types.pytorch
:no-members:
:no-inherited-members:
:no-special-members:
1 change: 1 addition & 0 deletions docs/source/types.extend.rst
@@ -11,3 +11,4 @@ Feel free to follow the pattern of the built-in types.
types.builtins.structured
types.builtins.file
types.builtins.directory
types.builtins.pytorch
2 changes: 1 addition & 1 deletion flytekit/__init__.py
@@ -189,7 +189,7 @@
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types import directory, file, numpy, schema
from flytekit.types import directory, file, numpy, pytorch, schema
from flytekit.types.structured.structured_dataset import (
StructuredDataset,
StructuredDatasetFormat,
12 changes: 0 additions & 12 deletions flytekit/types/numpy/__init__.py
@@ -1,13 +1 @@
"""
Contributor:
Is removing this from docs intentional?

Contributor Author:
Yeah. I don't think we'd want to have the transformer in the API reference, because the methods within the TypeTransformer class remain the same.

Flytekit Numpy
==============
.. currentmodule:: flytekit.types.numpy

.. autosummary::
:template: custom.rst
:toctree: generated/

NumpyArrayTransformer
"""

from .ndarray import NumpyArrayTransformer
20 changes: 20 additions & 0 deletions flytekit/types/pytorch/__init__.py
@@ -0,0 +1,20 @@
"""
Flytekit PyTorch
=========================================
.. currentmodule:: flytekit.types.pytorch

.. autosummary::
:template: custom.rst
:toctree: generated/

PyTorchCheckpoint
"""
from flytekit.loggers import logger

try:
from .checkpoint import PyTorchCheckpoint, PyTorchCheckpointTransformer
from .native import PyTorchModuleTransformer, PyTorchTensorTransformer
except ImportError:
logger.info(
"We won't register PyTorchCheckpointTransformer, PyTorchTensorTransformer, and PyTorchModuleTransformer because torch is not installed."
)
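
For orientation, here is a minimal usage sketch of what these transformers enable once torch is installed. The task names and bodies are illustrative, not part of this diff; only the torch.Tensor annotations rely on the new transformers:

import torch

from flytekit import task, workflow


@task
def make_tensor() -> torch.Tensor:
    # Serialized by PyTorchTensorTransformer behind the scenes.
    return torch.randn(3, 3)


@task
def double(t: torch.Tensor) -> torch.Tensor:
    return t * 2


@workflow
def wf() -> torch.Tensor:
    return double(t=make_tensor())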
131 changes: 131 additions & 0 deletions flytekit/types/pytorch/checkpoint.py
@@ -0,0 +1,131 @@
import pathlib
import typing
from dataclasses import asdict, dataclass, fields, is_dataclass
from typing import Any, Callable, Dict, NamedTuple, Optional, Type, Union

import torch
from dataclasses_json import dataclass_json

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType

try:
from typing import Protocol
except ImportError:
from typing_extensions import Protocol


class IsDataclass(Protocol):
__dataclass_fields__: Dict
__dataclass_params__: Dict
__post_init__: Optional[Callable]


@dataclass_json
@dataclass
class PyTorchCheckpoint:
"""
    A helper class for saving a checkpoint: it bundles an optional ``torch.nn.Module``, its hyperparameters, and an optional ``torch.optim.Optimizer``.
"""

module: Optional[torch.nn.Module] = None
hyperparameters: Optional[Union[Dict[str, Any], NamedTuple, IsDataclass]] = None
optimizer: Optional[torch.optim.Optimizer] = None

    def __post_init__(self):
        # hyperparameters defaults to None, so None must be accepted here;
        # otherwise constructing a checkpoint without hyperparameters always raises.
        if not (
            self.hyperparameters is None
            or isinstance(self.hyperparameters, dict)
            or (is_dataclass(self.hyperparameters) and not isinstance(self.hyperparameters, type))
            or (isinstance(self.hyperparameters, tuple) and hasattr(self.hyperparameters, "_fields"))
        ):
            raise TypeError(
                f"hyperparameters must be a dict, dataclass, NamedTuple, or None. Got {type(self.hyperparameters)}"
            )


class PyTorchCheckpointTransformer(TypeTransformer[PyTorchCheckpoint]):
"""
    TypeTransformer that supports serializing and deserializing a PyTorchCheckpoint.
"""

PYTORCH_CHECKPOINT_FORMAT = "PyTorchCheckpoint"

def __init__(self):
super().__init__(name="PyTorch Checkpoint", t=PyTorchCheckpoint)

def get_literal_type(self, t: Type[PyTorchCheckpoint]) -> LiteralType:
return LiteralType(
blob=_core_types.BlobType(
format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
)

def to_literal(
self,
ctx: FlyteContext,
python_val: PyTorchCheckpoint,
python_type: Type[PyTorchCheckpoint],
expected: LiteralType,
) -> Literal:
meta = BlobMetadata(
type=_core_types.BlobType(
format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
)

local_path = ctx.file_access.get_random_local_path() + ".pt"
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

        to_save = {}
        for field in fields(python_val):
            value = getattr(python_val, field.name)

            # Skip unset fields; module and optimizer are both Optional,
            # and calling state_dict() on None would raise AttributeError.
            if value is None:
                continue

            if field.name in ["module", "optimizer"]:
                to_save[field.name + "_state_dict"] = value.state_dict()
            elif field.name == "hyperparameters":
                if isinstance(value, dict):
                    to_save.update(value)
                elif isinstance(value, tuple):
                    to_save.update(value._asdict())
                elif is_dataclass(value):
                    to_save.update(asdict(value))

if not to_save:
raise TypeTransformerFailedError(f"Cannot save empty {python_val}")

# save checkpoint to a file
torch.save(to_save, local_path)

remote_path = ctx.file_access.get_random_remote_path(local_path)
ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

def to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[PyTorchCheckpoint]
) -> PyTorchCheckpoint:
        try:
            uri = lv.scalar.blob.uri
        except AttributeError:
            # the exception must be raised, not merely constructed
            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

local_path = ctx.file_access.get_random_local_path()
ctx.file_access.get_data(uri, local_path, is_multipart=False)

# load checkpoint from a file
return typing.cast(PyTorchCheckpoint, torch.load(local_path))

def guess_python_type(self, literal_type: LiteralType) -> Type[PyTorchCheckpoint]:
if (
literal_type.blob is not None
and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
and literal_type.blob.format == self.PYTORCH_CHECKPOINT_FORMAT
):
return PyTorchCheckpoint

raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


TypeEngine.register(PyTorchCheckpointTransformer())
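
To make the round trip concrete, a sketch of the checkpoint flow under the same assumptions (the toy model and task names are illustrative). Note that to_python_value returns the dict produced by torch.load, so a downstream task indexes it by the *_state_dict keys written in to_literal:

import torch

from flytekit import task
from flytekit.types.pytorch import PyTorchCheckpoint


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


@task
def train() -> PyTorchCheckpoint:
    model = Net()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # hyperparameters may be a dict, a dataclass, or a NamedTuple
    return PyTorchCheckpoint(module=model, hyperparameters={"lr": 0.01}, optimizer=optimizer)


@task
def resume(checkpoint: PyTorchCheckpoint):
    model = Net()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # restore from the keys saved by PyTorchCheckpointTransformer.to_literal
    model.load_state_dict(checkpoint["module_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])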