-
Notifications
You must be signed in to change notification settings - Fork 305
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TypeTransformers for PyTorch Tensor, Module, and Checkpoint (#1032)
* TypeTransformers for PyTorch Tensor and Module Signed-off-by: Samhita Alla <[email protected]> * add torch to requirements Signed-off-by: Samhita Alla <[email protected]> * add module as a native type and PyTorchCheckpoint Signed-off-by: Samhita Alla <[email protected]> * update requirements Signed-off-by: Samhita Alla <[email protected]> * procedural to OOP approach Signed-off-by: Samhita Alla <[email protected]> * nit Signed-off-by: Samhita Alla <[email protected]> * verify device conversion Signed-off-by: Samhita Alla <[email protected]> * verify device conversion Signed-off-by: Samhita Alla <[email protected]> * hyperparameters can be None Signed-off-by: Samhita Alla <[email protected]> * device conversion Signed-off-by: Samhita Alla <[email protected]> * device conversion Signed-off-by: Samhita Alla <[email protected]> * checkpoint code cleanup Signed-off-by: Samhita Alla <[email protected]> * resolve merge conflict Signed-off-by: Samhita Alla <[email protected]> * fix pytorch api reference; resolve merge conflict Signed-off-by: Samhita Alla <[email protected]> * fix pytorch import Signed-off-by: Samhita Alla <[email protected]>
- Loading branch information
1 parent
edbd900
commit e5f9d88
Showing
17 changed files
with
574 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,3 +11,4 @@ codespell | |
google-cloud-bigquery | ||
google-cloud-bigquery-storage | ||
IPython | ||
torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,3 +33,4 @@ papermill # papermill | |
jupyter # papermill | ||
pyspark # spark | ||
sqlalchemy # sqlalchemy | ||
torch # pytorch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
############ | ||
PyTorch Type | ||
############ | ||
.. automodule:: flytekit.extras.pytorch | ||
:no-members: | ||
:no-inherited-members: | ||
:no-special-members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
""" | ||
Flytekit PyTorch | ||
========================================= | ||
.. currentmodule:: flytekit.extras.pytorch | ||
.. autosummary:: | ||
:template: custom.rst | ||
:toctree: generated/ | ||
PyTorchCheckpoint | ||
""" | ||
from flytekit.loggers import logger | ||
|
||
try: | ||
from .checkpoint import PyTorchCheckpoint, PyTorchCheckpointTransformer | ||
from .native import PyTorchModuleTransformer, PyTorchTensorTransformer | ||
except ImportError: | ||
logger.info( | ||
"We won't register PyTorchCheckpointTransformer, PyTorchTensorTransformer, and PyTorchModuleTransformer because torch is not installed." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import pathlib | ||
import typing | ||
from dataclasses import asdict, dataclass, fields, is_dataclass | ||
from typing import Any, Callable, Dict, NamedTuple, Optional, Type, Union | ||
|
||
import torch | ||
from dataclasses_json import dataclass_json | ||
from typing_extensions import Protocol | ||
|
||
from flytekit.core.context_manager import FlyteContext | ||
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError | ||
from flytekit.models.core import types as _core_types | ||
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar | ||
from flytekit.models.types import LiteralType | ||
|
||
|
||
class IsDataclass(Protocol):
    """
    Structural type standing in for "an instance of some dataclass".

    It only declares the dunder attributes listed below; it is used purely
    as a type annotation and carries no behavior of its own.
    """

    # Attribute declarations only; no values are assigned here.
    __dataclass_fields__: Dict
    __dataclass_params__: Dict
    __post_init__: Optional[Callable]
|
||
|
||
@dataclass_json
@dataclass
class PyTorchCheckpoint:
    """
    Container for the parts of a PyTorch training checkpoint.

    All three fields are optional individually, but at least one of them
    must be supplied; this is validated in ``__post_init__``.

    Raises:
        TypeTransformerFailedError: if ``hyperparameters`` is not a dict,
            a dataclass instance, a NamedTuple, or ``None``; or if no field
            carries any content.
    """

    # Model to checkpoint.
    module: Optional[torch.nn.Module] = None
    # Free-form training settings: dict, NamedTuple, or dataclass instance.
    hyperparameters: Optional[Union[Dict[str, Any], NamedTuple, IsDataclass]] = None
    # Optimizer to checkpoint.
    optimizer: Optional[torch.optim.Optimizer] = None

    def __post_init__(self):
        """Validate the type of ``hyperparameters`` and that the checkpoint is not empty."""
        hp = self.hyperparameters
        hp_is_supported = (
            hp is None
            or isinstance(hp, dict)
            # is_dataclass() is also true for dataclass *classes*; only instances are accepted.
            or (is_dataclass(hp) and not isinstance(hp, type))
            # NamedTuple instances are tuples that expose ``_fields``.
            or (isinstance(hp, tuple) and hasattr(hp, "_fields"))
        )
        if not hp_is_supported:
            raise TypeTransformerFailedError(
                f"hyperparameters must be a dict, dataclass, or NamedTuple. Got {type(self.hyperparameters)}"
            )

        # Compare module/optimizer against None instead of relying on truthiness:
        # container modules such as an empty ``nn.Sequential`` define ``__len__``
        # and would otherwise be mistaken for "not provided".
        if self.module is None and not self.hyperparameters and self.optimizer is None:
            raise TypeTransformerFailedError("Must have at least one of module, hyperparameters, or optimizer")
|
||
|
||
class PyTorchCheckpointTransformer(TypeTransformer[PyTorchCheckpoint]):
    """
    TypeTransformer that serializes a ``PyTorchCheckpoint`` to a single-file
    blob with ``torch.save`` and reads it back with ``torch.load``.
    """

    PYTORCH_CHECKPOINT_FORMAT = "PyTorchCheckpoint"

    def __init__(self):
        super().__init__(name="PyTorch Checkpoint", t=PyTorchCheckpoint)

    def get_literal_type(self, t: Type[PyTorchCheckpoint]) -> LiteralType:
        """Return the single-file blob literal type tagged with the checkpoint format."""
        return LiteralType(
            blob=_core_types.BlobType(
                format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
            )
        )

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: PyTorchCheckpoint,
        python_type: Type[PyTorchCheckpoint],
        expected: LiteralType,
    ) -> Literal:
        """
        Save the checkpoint to a local ``.pt`` file, upload it, and return a blob literal.

        The saved object is a flat dict holding ``module_state_dict`` and
        ``optimizer_state_dict`` (when those fields are set) plus every
        hyperparameter as its own top-level key.

        Raises:
            TypeTransformerFailedError: if there is nothing to save.
        """
        meta = BlobMetadata(
            type=_core_types.BlobType(
                format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
            )
        )

        local_path = ctx.file_access.get_random_local_path() + ".pt"
        pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

        to_save = {}
        for spec in fields(python_val):
            value = getattr(python_val, spec.name)

            if spec.name in ("module", "optimizer"):
                # Test against None rather than truthiness: container modules
                # (e.g. an empty nn.Sequential) define __len__ and can be falsy.
                if value is not None:
                    to_save[spec.name + "_state_dict"] = value.state_dict()
            elif spec.name == "hyperparameters" and value:
                if isinstance(value, dict):
                    to_save.update(value)
                elif isinstance(value, tuple):
                    # NamedTuple instance
                    to_save.update(value._asdict())
                elif is_dataclass(value):
                    to_save.update(asdict(value))

        if not to_save:
            raise TypeTransformerFailedError(f"Cannot save empty {python_val}")

        # save checkpoint to a file
        torch.save(to_save, local_path)

        remote_path = ctx.file_access.get_random_remote_path(local_path)
        ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[PyTorchCheckpoint]
    ) -> PyTorchCheckpoint:
        """
        Download the blob referenced by ``lv`` and load it with ``torch.load``.

        The return value is whatever ``torch.load`` yields. For blobs written by
        ``to_literal`` that is the flat dict saved there; ``typing.cast`` only
        informs type checkers and performs no conversion at runtime.

        Raises:
            TypeTransformerFailedError: if ``lv`` does not carry a blob URI.
        """
        try:
            uri = lv.scalar.blob.uri
        except AttributeError as err:
            # The error must actually be raised; otherwise execution would
            # continue and fail later with an unrelated NameError on ``uri``.
            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") from err

        local_path = ctx.file_access.get_random_local_path()
        ctx.file_access.get_data(uri, local_path, is_multipart=False)

        # cpu <-> gpu conversion: remap storages to whatever device is available here
        map_location = "cuda:0" if torch.cuda.is_available() else torch.device("cpu")

        # load checkpoint from a file
        # NOTE(review): torch.load relies on pickle; only load blobs from trusted sources.
        return typing.cast(PyTorchCheckpoint, torch.load(local_path, map_location=map_location))

    def guess_python_type(self, literal_type: LiteralType) -> Type[PyTorchCheckpoint]:
        """
        Return ``PyTorchCheckpoint`` when ``literal_type`` is a single-file blob
        in the checkpoint format.

        Raises:
            ValueError: if the literal type does not match.
        """
        blob = literal_type.blob
        if (
            blob is not None
            and blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
            and blob.format == self.PYTORCH_CHECKPOINT_FORMAT
        ):
            return PyTorchCheckpoint

        raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


# Make the transformer available to the type engine as soon as this module is imported.
TypeEngine.register(PyTorchCheckpointTransformer())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import pathlib | ||
from typing import Generic, Type, TypeVar | ||
|
||
import torch | ||
|
||
from flytekit.core.context_manager import FlyteContext | ||
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError | ||
from flytekit.models.core import types as _core_types | ||
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar | ||
from flytekit.models.types import LiteralType | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
class PyTorchTypeTransformer(TypeTransformer, Generic[T]):
    """
    Shared TypeTransformer logic for PyTorch objects stored as single-file blobs.

    Subclasses must define the ``PYTORCH_FORMAT`` class attribute, which is
    used as the blob format string, and pass their concrete Python type to
    ``TypeTransformer.__init__``.
    """

    # Blob format tag; supplied by each concrete subclass.
    PYTORCH_FORMAT: str

    def get_literal_type(self, t: Type[T]) -> LiteralType:
        """Return the single-file blob literal type tagged with ``PYTORCH_FORMAT``."""
        return LiteralType(
            blob=_core_types.BlobType(
                format=self.PYTORCH_FORMAT,
                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
            )
        )

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: T,
        python_type: Type[T],
        expected: LiteralType,
    ) -> Literal:
        """Save ``python_val`` to a local ``.pt`` file, upload it, and return a blob literal."""
        meta = BlobMetadata(
            type=_core_types.BlobType(
                format=self.PYTORCH_FORMAT,
                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
            )
        )

        local_path = ctx.file_access.get_random_local_path() + ".pt"
        pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

        # save pytorch tensor/module to a file
        torch.save(python_val, local_path)

        remote_path = ctx.file_access.get_random_remote_path(local_path)
        ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
        """
        Download the blob referenced by ``lv`` and load it with ``torch.load``.

        Raises:
            TypeTransformerFailedError: if ``lv`` does not carry a blob URI.
        """
        try:
            uri = lv.scalar.blob.uri
        except AttributeError as err:
            # The error must actually be raised; otherwise execution would
            # continue and fail later with an unrelated NameError on ``uri``.
            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") from err

        local_path = ctx.file_access.get_random_local_path()
        ctx.file_access.get_data(uri, local_path, is_multipart=False)

        # cpu <-> gpu conversion: remap storages to whatever device is available here
        map_location = "cuda:0" if torch.cuda.is_available() else torch.device("cpu")

        # load pytorch tensor/module from a file
        # NOTE(review): torch.load relies on pickle; only load blobs from trusted sources.
        return torch.load(local_path, map_location=map_location)

    def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
        """
        Return this transformer's Python type when ``literal_type`` is a
        single-file blob in ``PYTORCH_FORMAT``.

        Raises:
            ValueError: if the literal type does not match.
        """
        blob = literal_type.blob
        if (
            blob is not None
            and blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
            and blob.format == self.PYTORCH_FORMAT
        ):
            # Return the concrete type registered in __init__ (e.g. torch.Tensor),
            # not the unbound TypeVar ``T``, which is not a usable Python type.
            return self.python_type

        raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
|
||
|
||
class PyTorchTensorTransformer(PyTorchTypeTransformer[torch.Tensor]):
    """Transformer for ``torch.Tensor`` values, using the "PyTorchTensor" blob format."""

    PYTORCH_FORMAT = "PyTorchTensor"

    def __init__(self):
        super().__init__(t=torch.Tensor, name="PyTorch Tensor")
|
||
|
||
class PyTorchModuleTransformer(PyTorchTypeTransformer[torch.nn.Module]):
    """Transformer for ``torch.nn.Module`` values, using the "PyTorchModule" blob format."""

    PYTORCH_FORMAT = "PyTorchModule"

    def __init__(self):
        super().__init__(t=torch.nn.Module, name="PyTorch Module")
|
||
|
||
# Register the tensor transformer first, then the module transformer, at import time.
for _transformer_cls in (PyTorchTensorTransformer, PyTorchModuleTransformer):
    TypeEngine.register(_transformer_cls())
del _transformer_cls  # do not leave the loop variable in the module namespace
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1 @@ | ||
""" | ||
Flytekit Numpy | ||
============== | ||
.. currentmodule:: flytekit.types.numpy | ||
.. autosummary:: | ||
:template: custom.rst | ||
:toctree: generated/ | ||
NumpyArrayTransformer | ||
""" | ||
|
||
from .ndarray import NumpyArrayTransformer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.