Skip to content

Commit

Permalink
Add support for python pickle type in flytekit/flyte (#667)
Browse files Browse the repository at this point in the history
* Init python pickle

Signed-off-by: Kevin Su <[email protected]>

* Refactor pickle support

Signed-off-by: Kevin Su <[email protected]>

* Fixed lint

Signed-off-by: Kevin Su <[email protected]>

* Fixed lint

Signed-off-by: Kevin Su <[email protected]>

* Fixed register error

Signed-off-by: Kevin Su <[email protected]>

* Added tests

Signed-off-by: Kevin Su <[email protected]>

* Fixed lint

Signed-off-by: Kevin Su <[email protected]>

* Handle list of pickle

Signed-off-by: Kevin Su <[email protected]>

* Fixed lint

Signed-off-by: Kevin Su <[email protected]>

* Added assert_type

Signed-off-by: Kevin Su <[email protected]>

* Updated comment

Signed-off-by: Kevin Su <[email protected]>

* Remove unnecessary cast

Signed-off-by: Kevin Su <[email protected]>

* Added more tests

Signed-off-by: Kevin Su <[email protected]>

* Fixed test

Signed-off-by: Kevin Su <[email protected]>

* Address comment

Signed-off-by: Kevin Su <[email protected]>

* Update list of pickle

Signed-off-by: Kevin Su <[email protected]>

* Update list of pickle

Signed-off-by: Kevin Su <[email protected]>

* Update list of pickle

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fixed test

Signed-off-by: Kevin Su <[email protected]>

* Fixed test

Signed-off-by: Kevin Su <[email protected]>

* Fixed test

Signed-off-by: Kevin Su <[email protected]>

* add test

Signed-off-by: Yee Hing Tong <[email protected]>

* add test

Signed-off-by: Yee Hing Tong <[email protected]>

* Fixed lint

Signed-off-by: Kevin Su <[email protected]>

* Fixed import

Signed-off-by: Kevin Su <[email protected]>

* Handle nested list and dict

Signed-off-by: Kevin Su <[email protected]>

* Added tests

Signed-off-by: Kevin Su <[email protected]>

* Fixed lint

Signed-off-by: Kevin Su <[email protected]>

Co-authored-by: Yee Hing Tong <[email protected]>
  • Loading branch information
pingsutw and wild-endeavor authored Oct 27, 2021
1 parent 895cded commit 88b590c
Show file tree
Hide file tree
Showing 14 changed files with 271 additions and 14 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.7, 3.8, 3.9]
spark-version-suffix: ["", "-spark2"]
Expand Down
1 change: 0 additions & 1 deletion flytekit/common/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def get_serializable_task(
)
if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask):
entity.reset_command_fn()

return task_models.TaskSpec(template=tt)


Expand Down
42 changes: 38 additions & 4 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections
import copy
import inspect
import logging as _logging
import typing
from collections import OrderedDict
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union
Expand All @@ -13,6 +14,9 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.loggers import logger
from flytekit.models.core import interface as _interface_models
from flytekit.types.pickle import FlytePickle

T = typing.TypeVar("T")


class Interface(object):
Expand Down Expand Up @@ -244,20 +248,41 @@ def transform_interface_to_list_interface(interface: Interface) -> Interface:
return Interface(inputs=map_inputs, outputs=map_outputs)


def _change_unrecognized_type_to_pickle(t: Type[T]) -> Type[T]:
try:
if hasattr(t, "__origin__") and hasattr(t, "__args__"):
if t.__origin__ == list:
return typing.List[_change_unrecognized_type_to_pickle(t.__args__[0])]
elif t.__origin__ == dict and t.__args__[0] == str:
return typing.Dict[str, _change_unrecognized_type_to_pickle(t.__args__[1])]
else:
TypeEngine.get_transformer(t)
except ValueError:
_logging.warning(
f"Unsupported Type {t} found, Flyte will default to use PickleFile as the transport. "
f"Pickle can only be used to send objects between the exact same version of Python, "
f"and we strongly recommend to use python type that flyte support."
)
return FlytePickle[t]
return t


def transform_signature_to_interface(signature: inspect.Signature, docstring: Optional[Docstring] = None) -> Interface:
"""
From the annotations on a task function that the user should have provided, and the output names they want to use
for each output parameter, construct the TypedInterface object
For now the fancy object, maybe in the future a dumb object.
"""
outputs = extract_return_annotation(signature.return_annotation)

for k, v in outputs.items():
outputs[k] = _change_unrecognized_type_to_pickle(v)
inputs = OrderedDict()
for k, v in signature.parameters.items():
annotation = v.annotation
default = v.default if v.default is not inspect.Parameter.empty else None
# Inputs with default values are currently ignored, we may want to look into that in the future
inputs[k] = (v.annotation, v.default if v.default is not inspect.Parameter.empty else None)
inputs[k] = (_change_unrecognized_type_to_pickle(annotation), default)

# This is just for typing.NamedTuples - in those cases, the user can select a name to call the NamedTuple. We
# would like to preserve that name in our custom collections.namedtuple.
Expand All @@ -273,7 +298,8 @@ def transform_signature_to_interface(signature: inspect.Signature, docstring: Op


def transform_variable_map(
variable_map: Dict[str, type], descriptions: Dict[str, str] = {}
variable_map: Dict[str, type],
descriptions: Dict[str, str] = {},
) -> Dict[str, _interface_models.Variable]:
"""
Given a map of str (names of inputs for instance) to their Python native types, return a map of the name to a
Expand All @@ -283,6 +309,14 @@ def transform_variable_map(
if variable_map:
for k, v in variable_map.items():
res[k] = transform_type(v, descriptions.get(k, k))
sub_type: Type[T] = v
if hasattr(v, "__origin__") and hasattr(v, "__args__"):
if v.__origin__ is list:
sub_type = v.__args__[0]
elif v.__origin__ is dict:
sub_type = v.__args__[1]
if hasattr(sub_type, "__origin__") and sub_type.__origin__ is FlytePickle:
res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__}

return res

Expand Down
1 change: 1 addition & 0 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def my_wf(in1: int, in2: int) -> int:
def extract_value(
ctx: FlyteContext, input_val: Any, val_type: type, flyte_literal_type: flytekit.models.core.types.LiteralType
) -> _literal_models.Literal:

if isinstance(input_val, list):
if flyte_literal_type.collection_type is None:
raise TypeError(f"Not a collection type {flyte_literal_type} but got a list {input_val}")
Expand Down
2 changes: 1 addition & 1 deletion flytekit/extras/cloud_pickle_resolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from base64 import b64decode, b64encode
from typing import List

import cloudpickle # intentionally not yet part of setup.py
import cloudpickle

from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.context_manager import SerializationSettings
Expand Down
4 changes: 4 additions & 0 deletions flytekit/models/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def metadata(self):
"""
return self._metadata

@metadata.setter
def metadata(self, value):
self._metadata = value

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.types_pb2.LiteralType
Expand Down
12 changes: 12 additions & 0 deletions flytekit/types/pickle/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
Flytekit Pickle Type
==========================================================
.. currentmodule:: flytekit.types.pickle
.. autosummary::
:toctree: generated/
FlytePickle
"""

from .pickle import FlytePickle
89 changes: 89 additions & 0 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os
import typing
from typing import Type

import cloudpickle

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.models.core import types as _core_types
from flytekit.models.core.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.core.types import LiteralType

T = typing.TypeVar("T")


class FlytePickle(typing.Generic[T]):
"""
This type is only used by flytekit internally. User should not use this type.
Any type that flyte can't recognize will become FlytePickle
"""

@classmethod
def python_type(cls) -> None:
return None

def __class_getitem__(cls, python_type: typing.Type) -> typing.Type[T]:
if python_type is None:
return cls

class _SpecificFormatClass(FlytePickle):
# Get the type engine to see this as kind of a generic
__origin__ = FlytePickle

@classmethod
def python_type(cls) -> typing.Type:
return python_type

return _SpecificFormatClass


class FlytePickleTransformer(TypeTransformer[FlytePickle]):
PYTHON_PICKLE_FORMAT = "PythonPickle"

def __init__(self):
super().__init__(name="FlytePickle", t=FlytePickle)

def assert_type(self, t: Type[T], v: T):
# Every type can serialize to pickle, so we don't need to check the type here.
...

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
uri = lv.scalar.blob.uri
# Deserialize the pickle, and return data in the pickle,
# and download pickle file to local first if file is not in the local file systems.
if ctx.file_access.is_remote(uri):
local_path = ctx.file_access.get_random_local_path()
ctx.file_access.get_data(uri, local_path, False)
uri = local_path
with open(uri, "rb") as infile:
data = cloudpickle.load(infile)
return data

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
meta = BlobMetadata(
type=_core_types.BlobType(
format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
)
# Dump the task output into pickle
local_dir = ctx.file_access.get_random_local_directory()
os.makedirs(local_dir, exist_ok=True)
local_path = ctx.file_access.get_random_local_path()
uri = os.path.join(local_dir, local_path)
with open(uri, "w+b") as outfile:
cloudpickle.dump(python_val, outfile)

remote_path = ctx.file_access.get_random_remote_path(uri)
ctx.file_access.put_data(uri, remote_path, is_multipart=False)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

def get_literal_type(self, t: Type[T]) -> LiteralType:
return _core_types.LiteralType(
blob=_core_types.BlobType(
format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
)


TypeEngine.register(FlytePickleTransformer())
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
"docstring-parser>=0.9.0",
"diskcache>=5.2.1",
"checksumdir>=1.2.0",
"cloudpickle>=2.0.0",
],
extras_require=extras_require,
scripts=[
Expand Down
6 changes: 3 additions & 3 deletions tests/flytekit/unit/core/functools/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def test_unwrapped_task():
)
error = completed_process.stderr
error_str = error.strip().split("\n")[-1]
assert error_str == (
"ValueError: Type <class 'inspect._empty'> not supported currently in Flytekit. "
"Please register a new transformer"
assert (
"TaskFunction cannot be a nested/inner or local function."
" It should be accessible at a module level for Flyte to execute it." in error_str
)


Expand Down
80 changes: 80 additions & 0 deletions tests/flytekit/unit/core/test_flyte_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from collections import OrderedDict
from typing import Dict, List

from flytekit.common.translator import get_serializable
from flytekit.core import context_manager
from flytekit.core.context_manager import Image, ImageConfig
from flytekit.core.task import task
from flytekit.models.core.literals import BlobMetadata
from flytekit.models.core.types import BlobType, LiteralType
from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = context_manager.SerializationSettings(
project="project",
domain="domain",
version="version",
env=None,
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


def test_to_python_value_and_literal():
ctx = context_manager.FlyteContext.current_context()
tf = FlytePickleTransformer()
python_val = "fake_output"
lt = tf.get_literal_type(FlytePickle)

lv = tf.to_literal(ctx, python_val, type(python_val), lt) # type: ignore
assert lv.scalar.blob.metadata == BlobMetadata(
type=BlobType(
format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT,
dimensionality=BlobType.BlobDimensionality.SINGLE,
)
)
assert lv.scalar.blob.uri is not None

output = tf.to_python_value(ctx, lv, str)
assert output == python_val


def test_get_literal_type():
tf = FlytePickleTransformer()
lt = tf.get_literal_type(FlytePickle)
assert lt == LiteralType(
blob=BlobType(
format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE
)
)


def test_nested():
class Foo(object):
def __init__(self, number: int):
self.number = number

@task
def t1(a: int) -> List[List[Foo]]:
return [[Foo(number=a)]]

task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
assert (
task_spec.template.interface.outputs["o0"].type.collection_type.collection_type.blob.format
is FlytePickleTransformer.PYTHON_PICKLE_FORMAT
)


def test_nested2():
class Foo(object):
def __init__(self, number: int):
self.number = number

@task
def t1(a: int) -> List[Dict[str, Foo]]:
return [{"a": Foo(number=a)}]

task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
assert (
task_spec.template.interface.outputs["o0"].type.collection_type.map_value_type.blob.format
is FlytePickleTransformer.PYTHON_PICKLE_FORMAT
)
19 changes: 19 additions & 0 deletions tests/flytekit/unit/core/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from flytekit.models.core import types as _core_types
from flytekit.types.file import FlyteFile
from flytekit.types.pickle import FlytePickle


def test_extract_only():
Expand Down Expand Up @@ -269,3 +270,21 @@ def z(a: int, b: str) -> typing.NamedTuple("NT", x_str=str, y_int=int):
assert typed_interface.inputs.get("b").description == "bar"
assert typed_interface.outputs.get("x_str").description == "description for x_str"
assert typed_interface.outputs.get("y_int").description == "description for y_int"


def test_parameter_change_to_pickle_type():
ctx = context_manager.FlyteContext.current_context()

class Foo:
def __init__(self, name):
self.name = name

def z(a: Foo) -> Foo:
...

our_interface = transform_signature_to_interface(inspect.signature(z))
params = transform_inputs_to_parameters(ctx, our_interface)
assert params.parameters["a"].required
assert params.parameters["a"].default is None
assert our_interface.outputs["o0"].__origin__ == FlytePickle
assert our_interface.inputs["a"].__origin__ == FlytePickle
3 changes: 3 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from flytekit.types.directory.types import FlyteDirectory
from flytekit.types.file import JPEGImageFile
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
from flytekit.types.pickle import FlytePickle
from flytekit.types.pickle.pickle import FlytePickleTransformer


def test_type_engine():
Expand Down Expand Up @@ -60,6 +62,7 @@ def test_type_resolution():
assert type(TypeEngine.get_transformer(int)) == SimpleTransformer

assert type(TypeEngine.get_transformer(os.PathLike)) == FlyteFilePathTransformer
assert type(TypeEngine.get_transformer(FlytePickle)) == FlytePickleTransformer

with pytest.raises(ValueError):
TypeEngine.get_transformer(typing.Any)
Expand Down
Loading

0 comments on commit 88b590c

Please sign in to comment.