TypeTransformers for PyTorch Tensor, Module, and Checkpoint #1032
Conversation
Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: Samhita Alla <[email protected]>
Codecov Report

Attention: Patch coverage is …

Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #1032      +/-   ##
==========================================
+ Coverage   86.90%   86.92%   +0.01%
==========================================
  Files         269      275       +6
  Lines       25144    25448     +304
  Branches     2834     2862      +28
==========================================
+ Hits        21851    22120     +269
- Misses       2823     2850      +27
- Partials      470      478       +8
```

☔ View full report in Codecov by Sentry.
@samhita-alla I think the type should still be `torch.nn.Module` and we should do the right thing? I read your point, but is there a problem with reloading `nn.Module`? Cc @cosmicBboy
I feel like this should be a flytekit plugin and not part of the core flytekit package. It sets a precedent of adding more and more additional deps to the core package that I'm not sure we want to shoulder (pytorch, tensorflow, sklearn, etc.)
thoughts @eapolinario @kumare3 ?
edit: I see how we're handling the case if `torch` isn't installed. I suppose this is okay, as long as we're okay establishing this pattern for other types.
Agreed, I think any subclass of `nn.Module` …

So we should do the "right thing" automatically, using the state dict instead of pickling the module. We're already abstracting away how these types are stored in Flyte, so I think the risk of confusion here is minimal. The extra layer of indirection with … In the …

Of course, all of this assumes that the user code has access to the class:

```python
class MyModel(nn.Module): ...

@task
def my_task(model: MyModel):
    # model is automatically converted to a MyModel instance by the type engine.
    ...
```
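To make the state-dict suggestion concrete, here's a minimal sketch of the round trip the transformer could perform internally (the `MyModel` layer sizes and the file path are illustrative assumptions, not the PR's actual implementation):

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)  # illustrative architecture

    def forward(self, x):
        return self.linear(x)


# serialize: persist only the state dict instead of pickling the whole module
model = MyModel()
torch.save(model.state_dict(), "/tmp/my_model.pt")

# deserialize: this only works if the consuming code can import MyModel,
# which is exactly the assumption called out above
restored = MyModel()
restored.load_state_dict(torch.load("/tmp/my_model.pt"))
restored.eval()
```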
@cosmicBboy, I don't think this can be a standalone plugin, because doing so will require users to install …
How should we serialize and deserialize a PyTorch model?
The plugin could specify an unpinned …
Playing around with this PR locally, it does seem like there are a bunch of issues associated with trying to handle serialization/deserialization in the type transformer:
I have two thoughts here:

```python
T = TypeVar("T", bound=typing.Union[typing.Dict, typing.NamedTuple])

@dataclass_json
@dataclass
class PyTorchCheckpoint(Generic[T]):
    module: typing.Optional[torch.nn.Module] = None
    hyperparameters: typing.Optional[T] = None  # not required for models that have a hard-coded architecture
    optimizer: typing.Optional[torch.optim.Optimizer] = None
    epoch: typing.Optional[int] = None
    loss: typing.Optional[float] = None
```

Basically this supports the special case of just wanting to store the module state dict, while also supporting a fully generalized checkpoint that we can probably use in concert with intra-task checkpointing. Since hyperparameters must be user-provided, we can't know their type ahead of time, hence the use of `Generic[T]`:

```python
class Hyperparameters(typing.NamedTuple):
    ...

ModelCheckpoint = PyTorchCheckpoint[Hyperparameters]

@task
def produce_model(hyperparameters: Hyperparameters) -> ModelCheckpoint:
    model = MyModel(**hyperparameters._asdict())
    optim = torch.optim.SGD(model.parameters(), lr=...)
    ...  # train
    return ModelCheckpoint(
        module=model,
        hyperparameters=hyperparameters,
        optimizer=optim,
        epoch=...,
        loss=...,
    )
```

edit: since we'd include the hyperparameters in the state dict that's serialized, we may not need the Generic stuff
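For illustration, a downstream task consuming this checkpoint might look like the sketch below (hypothetical; it assumes the field names from the `PyTorchCheckpoint` proposal above and that `MyModel` is importable in the consuming task):

```python
@task
def resume_training(checkpoint: ModelCheckpoint) -> ModelCheckpoint:
    # rebuild the architecture from the stored hyperparameters,
    # then copy the weights carried by the checkpointed module
    model = MyModel(**checkpoint.hyperparameters._asdict())
    model.load_state_dict(checkpoint.module.state_dict())
    optim = torch.optim.SGD(model.parameters(), lr=0.01)
    ...  # continue training from checkpoint.epoch
    return ModelCheckpoint(
        module=model,
        hyperparameters=checkpoint.hyperparameters,
        optimizer=optim,
        epoch=...,
        loss=...,
    )
```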
Can we not store hyperparameters when the user returns a module with a hard-coded architecture?

```python
@task
def generate_model() -> PyTorchStateDict:
    bn = MyModel(...)
    return PyTorchStateDict(module=bn)
```
Yeah! But isn't this the approach we'll have to follow in case we want to support applying …
I love this! We can have …
Yep! This is pretty much the …
So if we use …
So if we agree that …, do we want to support …?
Do you mean to say that we support: …
Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: Samhita Alla <[email protected]>
I've dynamically created transformers for … On the whole, the following are the types enclosed in this PR: …
All the examples are available in the PR description. Note: …
Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: Samhita Alla <[email protected]>
@@ -1,13 +1 @@
"""
is removing this from docs intentional?
Yeah. I don't think we'd want to have the Transformer in the API reference, because the methods within the TypeTransformer class remain the same.
flytekit/types/pytorch/native.py
Outdated
@@ -0,0 +1,110 @@
from __future__ import annotations |
overall notes on this module:
I can appreciate the DRY principle here, but the dynamically generated classes in this module make it a little hard to reason about and read imo... I'm also not sure whether these dynamically generated classes will have useful code/auto-completion on the user side.
Suggestion: DRYing this logic could be achieved with a parent class like `BaseTensor` and two explicit subclasses for `Tensor` and `Module`; that could be clearer to read.
One suggestion for an approach would be:
T = typing.TypeVar("T")
# use generics to abstract out the types in the method definitions
class PytorchTypeTransformer(TypeTransformer, typing.Generic[T]):
def get_literal_type(self, t: Type[T]) -> LiteralType:
return LiteralType(
blob=_core_types.BlobType(
format=self.PYTORCH_FORMAT,
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
)
)
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
...
# implement the other common methods here
class PyTorchTensorTransformer(PytorchTypeTransformer[nn.Tensor]):
PYTORCH_FORMAT = "PytorchTensor"
def __init__(self):
super().__init__(name="Pytorch Tensor", t=nn.Tensor)
...
class PyTorchModuleTransformer(PytorchTypeTransformer[nn.Module]): ...
PYTORCH_FORMAT = "PytorchModule"
def __init__(self):
super().__init__(name="Pytorch Module", t=nn.Module)
There's probably a better way of doing this, but just wanted to propose a feasible alternative.
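For completeness, the transformers sketched above would still be registered with the type engine in the usual way (a sketch, assuming flytekit's `TypeEngine.register` and the class names suggested here):

```python
from flytekit.core.type_engine import TypeEngine

# make the type engine aware of the transformers so that torch.Tensor /
# nn.Module annotations on task signatures are handled automatically
TypeEngine.register(PyTorchTensorTransformer())
TypeEngine.register(PyTorchModuleTransformer())
```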
I followed an imperative style to tackle this, but yes, the OOP approach is more readable. There shouldn't be a problem with auto-completion, because users wouldn't be importing any module in their code as such. However, I modified the code to use inheritance now. :) Thanks for looking into this!
Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: Samhita Alla <[email protected]>
great work!
oh, one thing I forgot to mention here: how do we want to handle CPU/CUDA tensors?
For example, if I serialize a CUDA tensor/module, the type engine will fail if we try to deserialize it on a CPU-only machine.
Should we leave that up to the user? Or should we do some automagic to handle that?
(depending on the answer, we can work on a follow-up PR)
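One possible direction, sketched under the assumption that deserialization goes through `torch.load`: pick the device at load time and remap storages with `map_location`, so an object serialized on a GPU machine can still be loaded on a CPU-only one. The file path is illustrative.

```python
import torch

# choose the best available device at deserialization time
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# map_location rewrites the storage device of every tensor being loaded,
# so CUDA-saved tensors/modules can be read on a CPU-only machine
obj = torch.load("/tmp/serialized.pt", map_location=device)
```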
Amazing! Let's merge this @samhita-alla @eapolinario
+1, please. @cosmicBboy
Couple of minor comments. LGTM otherwise
tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt
Outdated
tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt
Outdated
Signed-off-by: Samhita Alla <[email protected]>
@pingsutw, I missed this! Resolved it now.
should we make this part of …?
Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: Samhita Alla <[email protected]>
just for the paper trail: …
Module organization makes sense to me, see comment about the flytekit/__init__.py import.
flytekit/__init__.py
Outdated
@@ -183,6 +183,7 @@
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence
from flytekit.extras.pytorch import PyTorchModuleTransformer, PyTorchTensorTransformer
I think this will break if torch isn't installed (?)
Maybe we can do something like:

```python
from flytekit.extras import pytorch
```

which would automatically register the tensor and module type transformers (if `torch` is installed).
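A minimal sketch of what that guarded import/registration could look like in `flytekit/extras/pytorch/__init__.py` (the submodule name and log message are assumptions):

```python
from flytekit.loggers import logger

try:
    # importing the transformer module registers the transformers as a side effect
    from .native import PyTorchModuleTransformer, PyTorchTensorTransformer
except ImportError:
    logger.info(
        "Unable to register PyTorchTensorTransformer and PyTorchModuleTransformer "
        "because torch is not installed."
    )
```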
My bad. Fixed the import now.
Signed-off-by: Samhita Alla <[email protected]>
```python
from flytekit.loggers import logger

try:
    from .checkpoint import PyTorchCheckpoint, PyTorchCheckpointTransformer
```
We should stick with the full import in the future, just for consistency. Merge as is, I'll update it in the future.
```python
try:
    from typing import Protocol
except ImportError:
    from typing_extensions import Protocol
```
we can always use typing_extensions, right?
Yep! I'll merge this now but will make sure to modify the import to use typing_extensions in a different PR.
Modified the import — I had to resolve a merge conflict.
Signed-off-by: Samhita Alla <[email protected]>
a48472a
* TypeTransformers for PyTorch Tensor and Module Signed-off-by: Samhita Alla <[email protected]>
* add torch to requirements Signed-off-by: Samhita Alla <[email protected]>
* add module as a native type and PyTorchCheckpoint Signed-off-by: Samhita Alla <[email protected]>
* update requirements Signed-off-by: Samhita Alla <[email protected]>
* procedural to OOP approach Signed-off-by: Samhita Alla <[email protected]>
* nit Signed-off-by: Samhita Alla <[email protected]>
* verify device conversion Signed-off-by: Samhita Alla <[email protected]>
* verify device conversion Signed-off-by: Samhita Alla <[email protected]>
* hyperparameters can be None Signed-off-by: Samhita Alla <[email protected]>
* device conversion Signed-off-by: Samhita Alla <[email protected]>
* device conversion Signed-off-by: Samhita Alla <[email protected]>
* checkpoint code cleanup Signed-off-by: Samhita Alla <[email protected]>
* resolve merge conflict Signed-off-by: Samhita Alla <[email protected]>
* fix pytorch api reference; resolve merge conflict Signed-off-by: Samhita Alla <[email protected]>
* fix pytorch import Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: Samhita Alla [email protected]
TL;DR
This PR adds support for:

- `torch.Tensor` as a native type.
- `state_dict`: `PyTorchStateDict` is the custom type defined to handle serialization and deserialization of `state_dict`.
- `torch.nn.Module` (the base class for all neural network modules) isn't considered a native type here; this is per the docs: "Instead of saving a module directly, for compatibility reasons it is recommended to instead save only its state dict." Subclassing `torch.nn.Module` in a TypeTransformer might mislead users into thinking that a PyTorch model is being serialized and deserialized when instead the model's `state_dict` is being considered.

Type
Are all requirements met?
Complete description
Module/Model serialization example:
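The embedded example didn't survive the export; below is a minimal sketch of module serialization following the pattern discussed in this thread (class name, layer sizes, and values are illustrative):

```python
import torch.nn as nn
from flytekit import task, workflow


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)


@task
def train() -> MyModel:
    # the module transformer serializes the returned module to blob storage
    return MyModel()


@task
def count_params(model: MyModel) -> int:
    # the module is rehydrated by the type engine before the task body runs
    return sum(p.numel() for p in model.parameters())


@workflow
def module_wf() -> int:
    return count_params(model=train())
```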
Tensor serialization example:
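Similarly, a hedged sketch of tensor serialization (the tensor values are arbitrary):

```python
import torch
from flytekit import task, workflow


@task
def make_tensor() -> torch.Tensor:
    # the tensor transformer persists the tensor as a single blob
    return torch.zeros(3, 3)


@task
def add_one(t: torch.Tensor) -> torch.Tensor:
    return t + 1


@workflow
def tensor_wf() -> torch.Tensor:
    return add_one(t=make_tensor())
```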
Checkpoint example:
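And a sketch of the checkpoint flow, assuming `PyTorchCheckpoint` is exposed from `flytekit.extras.pytorch` with the fields discussed in this thread (the model, hyperparameters, and numbers are illustrative):

```python
import torch
import torch.nn as nn
from flytekit import task
from flytekit.extras.pytorch import PyTorchCheckpoint


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)


@task
def checkpoint_model() -> PyTorchCheckpoint:
    model = Net()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # bundle module, optimizer state, and bookkeeping metadata into one checkpoint
    return PyTorchCheckpoint(
        module=model,
        hyperparameters={"lr": 0.01},
        optimizer=optimizer,
        epoch=1,
        loss=0.5,
    )
```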
Tracking Issue
flyteorg/flyte#2544
Follow-up issue
NA