-
Notifications
You must be signed in to change notification settings - Fork 301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for python pickle type in flytekit/flyte #667
Changes from all commits
bc8b607
5c596ba
97c1ff1
5cbff83
5f7094e
0cdb183
46a7180
12abcb8
97b25bd
395220c
ff2083f
7bdeaba
d3907f0
5865e4e
7f0b513
e6c04c6
3ce284b
0f4fad2
3fc3943
8e4b3cf
249ff13
842425f
e22cc93
59f6b70
1dd1fd4
c25aafe
d3de4c0
5ab1328
d2c984d
a3daf49
55eea32
70158c7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
import collections | ||
import copy | ||
import inspect | ||
import logging as _logging | ||
import typing | ||
from collections import OrderedDict | ||
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union | ||
|
@@ -13,6 +14,9 @@ | |
from flytekit.core.type_engine import TypeEngine | ||
from flytekit.loggers import logger | ||
from flytekit.models.core import interface as _interface_models | ||
from flytekit.types.pickle import FlytePickle | ||
|
||
T = typing.TypeVar("T") | ||
|
||
|
||
class Interface(object): | ||
|
@@ -244,20 +248,41 @@ def transform_interface_to_list_interface(interface: Interface) -> Interface: | |
return Interface(inputs=map_inputs, outputs=map_outputs) | ||
|
||
|
||
def _change_unrecognized_type_to_pickle(t: Type[T]) -> Type[T]: | ||
try: | ||
if hasattr(t, "__origin__") and hasattr(t, "__args__"): | ||
if t.__origin__ == list: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm having a hard time understanding why there needs to be special casing for list/dict. Other types don't have this, right? For example, we don't need this for strings. So why do we need it for Pickle? I pushed a couple of tests that are failing for me... could you take a look at them, please? Thanks. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should check that the elements in the list can be converted to a literal; otherwise, we will get an error here.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've updated it to handle nested lists and dicts. |
||
return typing.List[_change_unrecognized_type_to_pickle(t.__args__[0])] | ||
elif t.__origin__ == dict and t.__args__[0] == str: | ||
return typing.Dict[str, _change_unrecognized_type_to_pickle(t.__args__[1])] | ||
else: | ||
TypeEngine.get_transformer(t) | ||
except ValueError: | ||
_logging.warning( | ||
f"Unsupported Type {t} found, Flyte will default to use PickleFile as the transport. " | ||
f"Pickle can only be used to send objects between the exact same version of Python, " | ||
f"and we strongly recommend to use python type that flyte support." | ||
) | ||
return FlytePickle[t] | ||
return t | ||
|
||
|
||
def transform_signature_to_interface(signature: inspect.Signature, docstring: Optional[Docstring] = None) -> Interface: | ||
""" | ||
From the annotations on a task function that the user should have provided, and the output names they want to use | ||
for each output parameter, construct the TypedInterface object | ||
|
||
For now the fancy object, maybe in the future a dumb object. | ||
|
||
""" | ||
outputs = extract_return_annotation(signature.return_annotation) | ||
|
||
for k, v in outputs.items(): | ||
outputs[k] = _change_unrecognized_type_to_pickle(v) | ||
inputs = OrderedDict() | ||
for k, v in signature.parameters.items(): | ||
annotation = v.annotation | ||
default = v.default if v.default is not inspect.Parameter.empty else None | ||
# Inputs with default values are currently ignored, we may want to look into that in the future | ||
inputs[k] = (v.annotation, v.default if v.default is not inspect.Parameter.empty else None) | ||
inputs[k] = (_change_unrecognized_type_to_pickle(annotation), default) | ||
|
||
# This is just for typing.NamedTuples - in those cases, the user can select a name to call the NamedTuple. We | ||
# would like to preserve that name in our custom collections.namedtuple. | ||
|
@@ -273,7 +298,8 @@ def transform_signature_to_interface(signature: inspect.Signature, docstring: Op | |
|
||
|
||
def transform_variable_map( | ||
variable_map: Dict[str, type], descriptions: Dict[str, str] = {} | ||
variable_map: Dict[str, type], | ||
descriptions: Dict[str, str] = {}, | ||
) -> Dict[str, _interface_models.Variable]: | ||
""" | ||
Given a map of str (names of inputs for instance) to their Python native types, return a map of the name to a | ||
|
@@ -283,6 +309,14 @@ def transform_variable_map( | |
if variable_map: | ||
for k, v in variable_map.items(): | ||
res[k] = transform_type(v, descriptions.get(k, k)) | ||
sub_type: Type[T] = v | ||
if hasattr(v, "__origin__") and hasattr(v, "__args__"): | ||
if v.__origin__ is list: | ||
sub_type = v.__args__[0] | ||
elif v.__origin__ is dict: | ||
sub_type = v.__args__[1] | ||
if hasattr(sub_type, "__origin__") and sub_type.__origin__ is FlytePickle: | ||
res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__} | ||
|
||
return res | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
Flytekit Pickle Type | ||
========================================================== | ||
.. currentmodule:: flytekit.types.pickle | ||
|
||
.. autosummary:: | ||
:toctree: generated/ | ||
|
||
FlytePickle | ||
""" | ||
|
||
from .pickle import FlytePickle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import os | ||
import typing | ||
from typing import Type | ||
|
||
import cloudpickle | ||
|
||
from flytekit.core.context_manager import FlyteContext | ||
from flytekit.core.type_engine import TypeEngine, TypeTransformer | ||
from flytekit.models.core import types as _core_types | ||
from flytekit.models.core.literals import Blob, BlobMetadata, Literal, Scalar | ||
from flytekit.models.core.types import LiteralType | ||
|
||
T = typing.TypeVar("T") | ||
|
||
|
||
class FlytePickle(typing.Generic[T]):
    """
    Marker type used internally by flytekit; users should not use it directly.

    Any type that flyte cannot recognize is wrapped as ``FlytePickle[<type>]``.
    Subscripting returns a dedicated subclass that remembers the wrapped type and
    exposes it through :meth:`python_type`.
    """

    @classmethod
    def python_type(cls) -> None:
        """Return the wrapped Python type; the unparameterized base class has none."""
        return None

    def __class_getitem__(cls, python_type: typing.Type) -> typing.Type[T]:
        """
        Build the parameterized form ``FlytePickle[python_type]``.

        :param python_type: the Python type being wrapped, or ``None``.
        :return: ``cls`` itself when ``python_type`` is ``None``; otherwise a new
            subclass of ``FlytePickle`` whose ``python_type()`` returns ``python_type``.
        """
        if python_type is not None:
            wrapped_type = python_type

            class _SpecificFormatClass(FlytePickle):
                # Get the type engine to see this as kind of a generic
                __origin__ = FlytePickle

                @classmethod
                def python_type(cls) -> typing.Type:
                    return wrapped_type

            return _SpecificFormatClass

        return cls
|
||
|
||
class FlytePickleTransformer(TypeTransformer[FlytePickle]):
    """
    Type transformer that carries a Python value as a single pickle blob.

    Values are written with ``cloudpickle`` to a local file, uploaded through
    ``ctx.file_access`` and referenced from the resulting literal by URI.
    """

    PYTHON_PICKLE_FORMAT = "PythonPickle"

    def __init__(self):
        super().__init__(name="FlytePickle", t=FlytePickle)

    def _pickle_blob_type(self) -> "_core_types.BlobType":
        """Build the single-file blob type tagged with the pickle format name."""
        return _core_types.BlobType(
            format=self.PYTHON_PICKLE_FORMAT,
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )

    def assert_type(self, t: Type[T], v: T):
        """Accept any value without checking it."""
        # Every type can serialize to pickle, so we don't need to check the type here.
        ...

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
        """
        Load the pickled object referenced by the literal's blob URI.

        :param ctx: context whose ``file_access`` is used to fetch remote data.
        :param lv: literal carrying the blob URI at ``lv.scalar.blob.uri``.
        :param expected_python_type: not consulted; the unpickled object is returned as-is.
        :return: the object stored in the pickle file.
        """
        pickle_path = lv.scalar.blob.uri
        # Remote pickles are first downloaded to a random local path.
        if ctx.file_access.is_remote(pickle_path):
            downloaded = ctx.file_access.get_random_local_path()
            ctx.file_access.get_data(pickle_path, downloaded, False)
            pickle_path = downloaded
        # NOTE(review): unpickling can execute arbitrary code; this presumably relies on the
        # blob having been produced by a trusted task — confirm.
        with open(pickle_path, "rb") as infile:
            return cloudpickle.load(infile)

    def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        """
        Pickle ``python_val`` to a local file, upload it and return a blob literal pointing at it.

        :param ctx: context whose ``file_access`` provides local/remote paths and the upload.
        :param python_val: the object to pickle.
        :param python_type: not consulted.
        :param expected: not consulted.
        :return: a literal whose scalar blob carries the pickle metadata and the remote URI.
        """
        meta = BlobMetadata(type=self._pickle_blob_type())

        # Dump the task output into pickle
        scratch_dir = ctx.file_access.get_random_local_directory()
        os.makedirs(scratch_dir, exist_ok=True)
        scratch_file = ctx.file_access.get_random_local_path()
        local_uri = os.path.join(scratch_dir, scratch_file)
        with open(local_uri, "w+b") as outfile:
            cloudpickle.dump(python_val, outfile)

        remote_uri = ctx.file_access.get_random_remote_path(local_uri)
        ctx.file_access.put_data(local_uri, remote_uri, is_multipart=False)
        pickled_blob = Blob(metadata=meta, uri=remote_uri)
        return Literal(scalar=Scalar(blob=pickled_blob))

    def get_literal_type(self, t: Type[T]) -> LiteralType:
        """Return the literal type for pickled values: a single-file blob in the pickle format."""
        return _core_types.LiteralType(blob=self._pickle_blob_type())
|
||
|
||
TypeEngine.register(FlytePickleTransformer()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from collections import OrderedDict | ||
from typing import Dict, List | ||
|
||
from flytekit.common.translator import get_serializable | ||
from flytekit.core import context_manager | ||
from flytekit.core.context_manager import Image, ImageConfig | ||
from flytekit.core.task import task | ||
from flytekit.models.core.literals import BlobMetadata | ||
from flytekit.models.core.types import BlobType, LiteralType | ||
from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer | ||
|
||
# Module-level fixtures: one placeholder image and the settings passed to get_serializable below.
default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = context_manager.SerializationSettings(
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
    env=None,
    version="version",
    domain="domain",
    project="project",
)
|
||
|
||
def test_to_python_value_and_literal():
    """Round-trip a plain string through the pickle transformer and check the blob metadata."""
    ctx = context_manager.FlyteContext.current_context()
    transformer = FlytePickleTransformer()
    python_val = "fake_output"
    literal_type = transformer.get_literal_type(FlytePickle)

    literal = transformer.to_literal(ctx, python_val, type(python_val), literal_type)  # type: ignore
    expected_metadata = BlobMetadata(
        type=BlobType(
            format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT,
            dimensionality=BlobType.BlobDimensionality.SINGLE,
        )
    )
    blob = literal.scalar.blob
    assert blob.metadata == expected_metadata
    assert blob.uri is not None

    # Reading the literal back must yield the value that was pickled.
    restored = transformer.to_python_value(ctx, literal, str)
    assert restored == python_val
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we add some more unit tests:
|
||
|
||
def test_get_literal_type():
    """The transformer must report a single-file blob literal type in the pickle format."""
    expected = LiteralType(
        blob=BlobType(
            format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT,
            dimensionality=BlobType.BlobDimensionality.SINGLE,
        )
    )
    transformer = FlytePickleTransformer()
    assert transformer.get_literal_type(FlytePickle) == expected
|
||
|
||
def test_nested():
    """A task returning ``List[List[Foo]]`` must serialize its innermost element as a pickle blob."""

    class Foo(object):
        def __init__(self, number: int):
            self.number = number

    @task
    def t1(a: int) -> List[List[Foo]]:
        return [[Foo(number=a)]]

    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
    # Compare by value: `is` checks object identity, so on strings it would only pass when
    # both sides happen to be the very same object.
    assert (
        task_spec.template.interface.outputs["o0"].type.collection_type.collection_type.blob.format
        == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
    )
|
||
|
||
def test_nested2():
    """A task returning ``List[Dict[str, Foo]]`` must serialize the map values as pickle blobs."""

    class Foo(object):
        def __init__(self, number: int):
            self.number = number

    @task
    def t1(a: int) -> List[Dict[str, Foo]]:
        return [{"a": Foo(number=a)}]

    task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
    # Compare by value: `is` checks object identity, so on strings it would only pass when
    # both sides happen to be the very same object.
    assert (
        task_spec.template.interface.outputs["o0"].type.collection_type.map_value_type.blob.format
        == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
    )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what does this do?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With this change, we will not cancel all in-progress jobs when any unit test fails.
For example, sometimes a test fails only with Python 3.9, but that failure cancels the other test runs with Python 3.7 or 3.8. It is easier to debug if I know which Python versions the test fails with.