Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for python pickle type in flytekit/flyte #667

Merged
merged 32 commits into from
Oct 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does this do?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we will not cancel all in-progress jobs if any unit test failed.
For example, Sometimes the test only failed with Python 3.9, but it will cancel the other tests with Python 3.7 or 3.8. It will be easy to debug If I know the test fail with which python version.

matrix:
python-version: [3.7, 3.8, 3.9]
spark-version-suffix: ["", "-spark2"]
Expand Down
1 change: 0 additions & 1 deletion flytekit/common/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def get_serializable_task(
)
if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask):
entity.reset_command_fn()

return task_models.TaskSpec(template=tt)


Expand Down
42 changes: 38 additions & 4 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections
import copy
import inspect
import logging as _logging
import typing
from collections import OrderedDict
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union
Expand All @@ -13,6 +14,9 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.loggers import logger
from flytekit.models.core import interface as _interface_models
from flytekit.types.pickle import FlytePickle

T = typing.TypeVar("T")


class Interface(object):
Expand Down Expand Up @@ -244,20 +248,41 @@ def transform_interface_to_list_interface(interface: Interface) -> Interface:
return Interface(inputs=map_inputs, outputs=map_outputs)


def _change_unrecognized_type_to_pickle(t: Type[T]) -> Type[T]:
try:
if hasattr(t, "__origin__") and hasattr(t, "__args__"):
if t.__origin__ == list:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having a hard time understanding why there needs to be special casing for list/dict. Other types don't have this right? like we don't need this for strings. So why do we need it for Pickle? I pushed a couple tests that are failing for me... could you take a look at them please? Thanks.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check the element in the list can be converted to literal, otherwise, we will get an error here.
So we have to change the type to pickle in compile-time and it also makes FlyteConsole work.

List[Foo] -> List[FlytePickle[Foo]]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated it to handle nested lists and dict.

return typing.List[_change_unrecognized_type_to_pickle(t.__args__[0])]
elif t.__origin__ == dict and t.__args__[0] == str:
return typing.Dict[str, _change_unrecognized_type_to_pickle(t.__args__[1])]
else:
TypeEngine.get_transformer(t)
except ValueError:
_logging.warning(
f"Unsupported Type {t} found, Flyte will default to use PickleFile as the transport. "
f"Pickle can only be used to send objects between the exact same version of Python, "
f"and we strongly recommend to use python type that flyte support."
)
return FlytePickle[t]
return t


def transform_signature_to_interface(signature: inspect.Signature, docstring: Optional[Docstring] = None) -> Interface:
"""
From the annotations on a task function that the user should have provided, and the output names they want to use
for each output parameter, construct the TypedInterface object

For now the fancy object, maybe in the future a dumb object.

"""
outputs = extract_return_annotation(signature.return_annotation)

for k, v in outputs.items():
outputs[k] = _change_unrecognized_type_to_pickle(v)
inputs = OrderedDict()
for k, v in signature.parameters.items():
annotation = v.annotation
default = v.default if v.default is not inspect.Parameter.empty else None
# Inputs with default values are currently ignored, we may want to look into that in the future
inputs[k] = (v.annotation, v.default if v.default is not inspect.Parameter.empty else None)
inputs[k] = (_change_unrecognized_type_to_pickle(annotation), default)

# This is just for typing.NamedTuples - in those cases, the user can select a name to call the NamedTuple. We
# would like to preserve that name in our custom collections.namedtuple.
Expand All @@ -273,7 +298,8 @@ def transform_signature_to_interface(signature: inspect.Signature, docstring: Op


def transform_variable_map(
variable_map: Dict[str, type], descriptions: Dict[str, str] = {}
variable_map: Dict[str, type],
descriptions: Dict[str, str] = {},
) -> Dict[str, _interface_models.Variable]:
"""
Given a map of str (names of inputs for instance) to their Python native types, return a map of the name to a
Expand All @@ -283,6 +309,14 @@ def transform_variable_map(
if variable_map:
for k, v in variable_map.items():
res[k] = transform_type(v, descriptions.get(k, k))
sub_type: Type[T] = v
if hasattr(v, "__origin__") and hasattr(v, "__args__"):
if v.__origin__ is list:
sub_type = v.__args__[0]
elif v.__origin__ is dict:
sub_type = v.__args__[1]
if hasattr(sub_type, "__origin__") and sub_type.__origin__ is FlytePickle:
res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__}

return res

Expand Down
1 change: 1 addition & 0 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def my_wf(in1: int, in2: int) -> int:
def extract_value(
ctx: FlyteContext, input_val: Any, val_type: type, flyte_literal_type: flytekit.models.core.types.LiteralType
) -> _literal_models.Literal:

if isinstance(input_val, list):
if flyte_literal_type.collection_type is None:
raise TypeError(f"Not a collection type {flyte_literal_type} but got a list {input_val}")
Expand Down
2 changes: 1 addition & 1 deletion flytekit/extras/cloud_pickle_resolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from base64 import b64decode, b64encode
from typing import List

import cloudpickle # intentionally not yet part of setup.py
import cloudpickle

from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.context_manager import SerializationSettings
Expand Down
4 changes: 4 additions & 0 deletions flytekit/models/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def metadata(self):
"""
return self._metadata

@metadata.setter
def metadata(self, value):
self._metadata = value

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.types_pb2.LiteralType
Expand Down
12 changes: 12 additions & 0 deletions flytekit/types/pickle/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
Flytekit Pickle Type
==========================================================
.. currentmodule:: flytekit.types.pickle

.. autosummary::
:toctree: generated/

FlytePickle
"""

from .pickle import FlytePickle
89 changes: 89 additions & 0 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os
import typing
from typing import Type

import cloudpickle

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.models.core import types as _core_types
from flytekit.models.core.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.core.types import LiteralType

T = typing.TypeVar("T")


class FlytePickle(typing.Generic[T]):
"""
This type is only used by flytekit internally. User should not use this type.
Any type that flyte can't recognize will become FlytePickle
"""

@classmethod
def python_type(cls) -> None:
return None

def __class_getitem__(cls, python_type: typing.Type) -> typing.Type[T]:
if python_type is None:
return cls

class _SpecificFormatClass(FlytePickle):
# Get the type engine to see this as kind of a generic
__origin__ = FlytePickle

@classmethod
def python_type(cls) -> typing.Type:
return python_type

return _SpecificFormatClass


class FlytePickleTransformer(TypeTransformer[FlytePickle]):
PYTHON_PICKLE_FORMAT = "PythonPickle"

def __init__(self):
super().__init__(name="FlytePickle", t=FlytePickle)

def assert_type(self, t: Type[T], v: T):
# Every type can serialize to pickle, so we don't need to check the type here.
...

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
uri = lv.scalar.blob.uri
# Deserialize the pickle, and return data in the pickle,
# and download pickle file to local first if file is not in the local file systems.
if ctx.file_access.is_remote(uri):
local_path = ctx.file_access.get_random_local_path()
ctx.file_access.get_data(uri, local_path, False)
uri = local_path
with open(uri, "rb") as infile:
data = cloudpickle.load(infile)
return data

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
meta = BlobMetadata(
type=_core_types.BlobType(
format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
)
# Dump the task output into pickle
local_dir = ctx.file_access.get_random_local_directory()
os.makedirs(local_dir, exist_ok=True)
local_path = ctx.file_access.get_random_local_path()
uri = os.path.join(local_dir, local_path)
with open(uri, "w+b") as outfile:
cloudpickle.dump(python_val, outfile)

remote_path = ctx.file_access.get_random_remote_path(uri)
ctx.file_access.put_data(uri, remote_path, is_multipart=False)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

def get_literal_type(self, t: Type[T]) -> LiteralType:
return _core_types.LiteralType(
blob=_core_types.BlobType(
format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
)


TypeEngine.register(FlytePickleTransformer())
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
"docstring-parser>=0.9.0",
"diskcache>=5.2.1",
"checksumdir>=1.2.0",
"cloudpickle>=2.0.0",
],
extras_require=extras_require,
scripts=[
Expand Down
6 changes: 3 additions & 3 deletions tests/flytekit/unit/core/functools/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def test_unwrapped_task():
)
error = completed_process.stderr
error_str = error.strip().split("\n")[-1]
assert error_str == (
"ValueError: Type <class 'inspect._empty'> not supported currently in Flytekit. "
"Please register a new transformer"
assert (
"TaskFunction cannot be a nested/inner or local function."
" It should be accessible at a module level for Flyte to execute it." in error_str
)


Expand Down
80 changes: 80 additions & 0 deletions tests/flytekit/unit/core/test_flyte_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from collections import OrderedDict
from typing import Dict, List

from flytekit.common.translator import get_serializable
from flytekit.core import context_manager
from flytekit.core.context_manager import Image, ImageConfig
from flytekit.core.task import task
from flytekit.models.core.literals import BlobMetadata
from flytekit.models.core.types import BlobType, LiteralType
from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = context_manager.SerializationSettings(
project="project",
domain="domain",
version="version",
env=None,
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


def test_to_python_value_and_literal():
ctx = context_manager.FlyteContext.current_context()
tf = FlytePickleTransformer()
python_val = "fake_output"
lt = tf.get_literal_type(FlytePickle)

lv = tf.to_literal(ctx, python_val, type(python_val), lt) # type: ignore
assert lv.scalar.blob.metadata == BlobMetadata(
type=BlobType(
format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT,
dimensionality=BlobType.BlobDimensionality.SINGLE,
)
)
assert lv.scalar.blob.uri is not None

output = tf.to_python_value(ctx, lv, str)
assert output == python_val

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add some more unit tests:

  1. a task that uses a list of Foo
  2. a task that uses a dict of str -> Foo
  3. a workflow that uses it, and testing that local workflow execution still works?


def test_get_literal_type():
tf = FlytePickleTransformer()
lt = tf.get_literal_type(FlytePickle)
assert lt == LiteralType(
blob=BlobType(
format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE
)
)


def test_nested():
class Foo(object):
def __init__(self, number: int):
self.number = number

@task
def t1(a: int) -> List[List[Foo]]:
return [[Foo(number=a)]]

task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
assert (
task_spec.template.interface.outputs["o0"].type.collection_type.collection_type.blob.format
is FlytePickleTransformer.PYTHON_PICKLE_FORMAT
)


def test_nested2():
class Foo(object):
def __init__(self, number: int):
self.number = number

@task
def t1(a: int) -> List[Dict[str, Foo]]:
return [{"a": Foo(number=a)}]

task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
assert (
task_spec.template.interface.outputs["o0"].type.collection_type.map_value_type.blob.format
is FlytePickleTransformer.PYTHON_PICKLE_FORMAT
)
19 changes: 19 additions & 0 deletions tests/flytekit/unit/core/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from flytekit.models.core import types as _core_types
from flytekit.types.file import FlyteFile
from flytekit.types.pickle import FlytePickle


def test_extract_only():
Expand Down Expand Up @@ -269,3 +270,21 @@ def z(a: int, b: str) -> typing.NamedTuple("NT", x_str=str, y_int=int):
assert typed_interface.inputs.get("b").description == "bar"
assert typed_interface.outputs.get("x_str").description == "description for x_str"
assert typed_interface.outputs.get("y_int").description == "description for y_int"


def test_parameter_change_to_pickle_type():
ctx = context_manager.FlyteContext.current_context()

class Foo:
def __init__(self, name):
self.name = name

def z(a: Foo) -> Foo:
...

our_interface = transform_signature_to_interface(inspect.signature(z))
params = transform_inputs_to_parameters(ctx, our_interface)
assert params.parameters["a"].required
assert params.parameters["a"].default is None
assert our_interface.outputs["o0"].__origin__ == FlytePickle
assert our_interface.inputs["a"].__origin__ == FlytePickle
3 changes: 3 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from flytekit.types.directory.types import FlyteDirectory
from flytekit.types.file import JPEGImageFile
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
from flytekit.types.pickle import FlytePickle
from flytekit.types.pickle.pickle import FlytePickleTransformer


def test_type_engine():
Expand Down Expand Up @@ -60,6 +62,7 @@ def test_type_resolution():
assert type(TypeEngine.get_transformer(int)) == SimpleTransformer

assert type(TypeEngine.get_transformer(os.PathLike)) == FlyteFilePathTransformer
assert type(TypeEngine.get_transformer(FlytePickle)) == FlytePickleTransformer

with pytest.raises(ValueError):
TypeEngine.get_transformer(typing.Any)
Expand Down
Loading