-
Notifications
You must be signed in to change notification settings - Fork 301
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for python pickle type in flytekit/flyte (#667)
* Init python pickle Signed-off-by: Kevin Su <[email protected]> * Refactor pickle support Signed-off-by: Kevin Su <[email protected]> * Fixed lint Signed-off-by: Kevin Su <[email protected]> * Fixed lint Signed-off-by: Kevin Su <[email protected]> * Fixed register error Signed-off-by: Kevin Su <[email protected]> * Added tests Signed-off-by: Kevin Su <[email protected]> * Fixed lint Signed-off-by: Kevin Su <[email protected]> * Handle list of pickle Signed-off-by: Kevin Su <[email protected]> * Fixed lint Signed-off-by: Kevin Su <[email protected]> * Added assert_type Signed-off-by: Kevin Su <[email protected]> * Updated comment Signed-off-by: Kevin Su <[email protected]> * Remove unnecessary cast Signed-off-by: Kevin Su <[email protected]> * Added more tests Signed-off-by: Kevin Su <[email protected]> * Fixed test Signed-off-by: Kevin Su <[email protected]> * Address comment Signed-off-by: Kevin Su <[email protected]> * Update list of pickle Signed-off-by: Kevin Su <[email protected]> * Update list of pickle Signed-off-by: Kevin Su <[email protected]> * Update list of pickle Signed-off-by: Kevin Su <[email protected]> * Fix tests Signed-off-by: Kevin Su <[email protected]> * Fix tests Signed-off-by: Kevin Su <[email protected]> * Fix tests Signed-off-by: Kevin Su <[email protected]> * Fix tests Signed-off-by: Kevin Su <[email protected]> * Fixed test Signed-off-by: Kevin Su <[email protected]> * Fixed test Signed-off-by: Kevin Su <[email protected]> * Fixed test Signed-off-by: Kevin Su <[email protected]> * add test Signed-off-by: Yee Hing Tong <[email protected]> * add test Signed-off-by: Yee Hing Tong <[email protected]> * Fixed lint Signed-off-by: Kevin Su <[email protected]> * Fixed import Signed-off-by: Kevin Su <[email protected]> * Handle nested list and dict Signed-off-by: Kevin Su <[email protected]> * Added tests Signed-off-by: Kevin Su <[email protected]> * Fixed lint Signed-off-by: Kevin Su <[email protected]> Co-authored-by: Yee Hing Tong <[email 
protected]>
- Loading branch information
1 parent
895cded
commit 88b590c
Showing
14 changed files
with
271 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
Flytekit Pickle Type | ||
========================================================== | ||
.. currentmodule:: flytekit.types.pickle | ||
.. autosummary:: | ||
:toctree: generated/ | ||
FlytePickle | ||
""" | ||
|
||
from .pickle import FlytePickle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import os | ||
import typing | ||
from typing import Type | ||
|
||
import cloudpickle | ||
|
||
from flytekit.core.context_manager import FlyteContext | ||
from flytekit.core.type_engine import TypeEngine, TypeTransformer | ||
from flytekit.models.core import types as _core_types | ||
from flytekit.models.core.literals import Blob, BlobMetadata, Literal, Scalar | ||
from flytekit.models.core.types import LiteralType | ||
|
||
T = typing.TypeVar("T") | ||
|
||
|
||
class FlytePickle(typing.Generic[T]):
    """
    Marker type for Python values that are carried as pickle files.

    This type is meant for flytekit's internal use; user code should not reference it
    directly. Any type that flyte can't recognize will become FlytePickle.

    Subscripting the class (``FlytePickle[SomeType]``) yields a subclass that remembers
    ``SomeType`` and reports it from :meth:`python_type`.
    """

    @classmethod
    def python_type(cls) -> None:
        """Return the wrapped Python type; the unparameterized class wraps nothing, hence ``None``."""
        return None

    def __class_getitem__(cls, python_type: typing.Type) -> typing.Type[T]:
        """
        Build the parameterized variant of this class for ``python_type``.

        :param python_type: the type being wrapped, or ``None`` for "no parameter".
        :return: ``cls`` unchanged when ``python_type`` is ``None``; otherwise a freshly
            created subclass of ``FlytePickle`` whose ``python_type()`` returns the given
            type. A new subclass object is created on every subscription.
        """
        if python_type is not None:

            class _SpecificFormatClass(FlytePickle):
                # Get the type engine to see this as kind of a generic
                __origin__ = FlytePickle

                @classmethod
                def python_type(cls) -> typing.Type:
                    # Closes over the argument given to __class_getitem__.
                    return python_type

            return _SpecificFormatClass

        return cls
|
||
|
||
class FlytePickleTransformer(TypeTransformer[FlytePickle]):
    """
    Type transformer that represents arbitrary Python values as single-file blobs,
    written and read with ``cloudpickle``.
    """

    # Blob format tag placed on every literal / literal type this transformer produces.
    PYTHON_PICKLE_FORMAT = "PythonPickle"

    def __init__(self):
        super().__init__(name="FlytePickle", t=FlytePickle)

    def _pickle_blob_type(self) -> _core_types.BlobType:
        """Return the single-file blob type tagged with the pickle format."""
        return _core_types.BlobType(
            format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
        )

    def assert_type(self, t: Type[T], v: T):
        """Accept every value: anything can be serialized to pickle, so no type check is done."""
        ...

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
        """
        Load the Python object stored in the pickle file that ``lv`` points at.

        :param ctx: context whose ``file_access`` is used to fetch remote files.
        :param lv: literal holding a blob; only ``lv.scalar.blob.uri`` is read.
        :param expected_python_type: not consulted; whatever the pickle contains is returned.
        :return: the unpickled object.
        """
        pickle_path = lv.scalar.blob.uri
        # A remote file has to be copied to the local file system before it can be opened.
        if ctx.file_access.is_remote(pickle_path):
            downloaded_path = ctx.file_access.get_random_local_path()
            ctx.file_access.get_data(pickle_path, downloaded_path, False)
            pickle_path = downloaded_path
        # NOTE(review): unpickling can execute arbitrary code, so this is only safe when
        # the blob comes from a trusted producer - confirm that holds for all callers.
        with open(pickle_path, "rb") as pickle_file:
            return cloudpickle.load(pickle_file)

    def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        """
        Pickle ``python_val`` to a local file, upload it, and wrap the remote location in a literal.

        :param ctx: context whose ``file_access`` supplies local/remote paths and performs the upload.
        :param python_val: the object to serialize.
        :param python_type: not consulted.
        :param expected: not consulted.
        :return: a scalar blob literal whose uri is the uploaded file's remote path.
        """
        blob_metadata = BlobMetadata(type=self._pickle_blob_type())

        # Dump the task output into pickle
        scratch_dir = ctx.file_access.get_random_local_directory()
        os.makedirs(scratch_dir, exist_ok=True)
        file_name = ctx.file_access.get_random_local_path()
        # NOTE(review): if get_random_local_path() returns an absolute path, os.path.join
        # discards scratch_dir and the directory created above goes unused - confirm intent.
        local_file = os.path.join(scratch_dir, file_name)
        with open(local_file, "w+b") as pickle_file:
            cloudpickle.dump(python_val, pickle_file)

        upload_target = ctx.file_access.get_random_remote_path(local_file)
        ctx.file_access.put_data(local_file, upload_target, is_multipart=False)
        pickle_blob = Blob(metadata=blob_metadata, uri=upload_target)
        return Literal(scalar=Scalar(blob=pickle_blob))

    def get_literal_type(self, t: Type[T]) -> LiteralType:
        """Return the pickle blob literal type; the result is the same for every ``t``."""
        return _core_types.LiteralType(blob=self._pickle_blob_type())


# Importing this module registers the transformer with the type engine.
TypeEngine.register(FlytePickleTransformer())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from collections import OrderedDict | ||
from typing import Dict, List | ||
|
||
from flytekit.common.translator import get_serializable | ||
from flytekit.core import context_manager | ||
from flytekit.core.context_manager import Image, ImageConfig | ||
from flytekit.core.task import task | ||
from flytekit.models.core.literals import BlobMetadata | ||
from flytekit.models.core.types import BlobType, LiteralType | ||
from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer | ||
|
||
# Shared fixtures: a placeholder image and the serialization settings that the
# get_serializable() calls in the tests below are given.
default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = context_manager.SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    # The same placeholder image is both the default and the only listed image.
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)
|
||
|
||
def test_to_python_value_and_literal():
    """Round-trip a plain string through the pickle transformer and check the blob metadata."""
    ctx = context_manager.FlyteContext.current_context()
    transformer = FlytePickleTransformer()
    original = "fake_output"
    literal_type = transformer.get_literal_type(FlytePickle)

    # Python value -> literal: the blob must carry the pickle format and point somewhere.
    literal = transformer.to_literal(ctx, original, type(original), literal_type)  # type: ignore
    expected_metadata = BlobMetadata(
        type=BlobType(
            format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT,
            dimensionality=BlobType.BlobDimensionality.SINGLE,
        )
    )
    assert literal.scalar.blob.metadata == expected_metadata
    assert literal.scalar.blob.uri is not None

    # Literal -> Python value: the original object comes back.
    restored = transformer.to_python_value(ctx, literal, str)
    assert restored == original
|
||
|
||
def test_get_literal_type():
    """The transformer reports a single-file blob literal type tagged with the pickle format."""
    expected = LiteralType(
        blob=BlobType(
            format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE
        )
    )
    actual = FlytePickleTransformer().get_literal_type(FlytePickle)
    assert actual == expected
|
||
|
||
def test_nested():
    """A task returning List[List[<unknown class>]] serializes its innermost type as a pickle blob."""

    class Foo(object):
        def __init__(self, number: int):
            self.number = number

    @task
    def t1(a: int) -> List[List[Foo]]:
        return [[Foo(number=a)]]

    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
    # Compare by value: `is` tests object identity, which for strings only holds when both
    # sides happen to be the very same object and is not a reliable equality check.
    assert (
        task_spec.template.interface.outputs["o0"].type.collection_type.collection_type.blob.format
        == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
    )
|
||
|
||
def test_nested2():
    """A task returning List[Dict[str, <unknown class>]] serializes the map value type as a pickle blob."""

    class Foo(object):
        def __init__(self, number: int):
            self.number = number

    @task
    def t1(a: int) -> List[Dict[str, Foo]]:
        return [{"a": Foo(number=a)}]

    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
    # Compare by value: `is` tests object identity, which for strings only holds when both
    # sides happen to be the very same object and is not a reliable equality check.
    assert (
        task_spec.template.interface.outputs["o0"].type.collection_type.map_value_type.blob.format
        == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.