Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pickle as a default Transformer rather than a special case of function signatures #1661

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.type_engine import TypeEngine
from flytekit.core.type_engine import FlytePickleTransformer, TypeEngine
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
from flytekit.models import literals
from flytekit.models.interface import Variable
Expand All @@ -46,7 +46,6 @@
from flytekit.tools import module_loader, script_mode
from flytekit.tools.script_mode import _find_project_root
from flytekit.tools.translator import Options
from flytekit.types.pickle.pickle import FlytePickleTransformer

REMOTE_FLAG_KEY = "remote"
RUN_LEVEL_PARAMS_KEY = "run_level_params"
Expand Down
46 changes: 2 additions & 44 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections import OrderedDict
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union, cast

from typing_extensions import Annotated, get_args, get_origin, get_type_hints
from typing_extensions import get_args, get_origin, get_type_hints

from flytekit.core import context_manager
from flytekit.core.docstring import Docstring
Expand All @@ -16,7 +16,6 @@
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models.literals import Void
from flytekit.types.pickle import FlytePickle

T = typing.TypeVar("T")

Expand Down Expand Up @@ -294,31 +293,6 @@ def transform_interface_to_list_interface(interface: Interface, bound_inputs: ty
return Interface(inputs=map_inputs, outputs=map_outputs)


def _change_unrecognized_type_to_pickle(t: Type[T]) -> typing.Union[Tuple[Type[T]], Type[T]]:
try:
if hasattr(t, "__origin__") and hasattr(t, "__args__"):
ot = get_origin(t)
args = getattr(t, "__args__")
if ot is list:
return typing.List[_change_unrecognized_type_to_pickle(args[0])] # type: ignore
elif ot is dict and args[0] == str:
return typing.Dict[str, _change_unrecognized_type_to_pickle(args[1])] # type: ignore
elif ot is typing.Union:
return typing.Union[tuple(_change_unrecognized_type_to_pickle(v) for v in get_args(t))] # type: ignore
elif ot is Annotated:
base_type, *config = get_args(t)
return Annotated[(_change_unrecognized_type_to_pickle(base_type), *config)] # type: ignore
TypeEngine.get_transformer(t)
except ValueError:
logger.warning(
f"Unsupported Type {t} found, Flyte will default to use PickleFile as the transport. "
f"Pickle can only be used to send objects between the exact same version of Python, "
f"and we strongly recommend to use python type that flyte support."
)
return FlytePickle[t]
return t


def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Docstring] = None) -> Interface:
"""
From the annotations on a task function that the user should have provided, and the output names they want to use
Expand All @@ -332,14 +306,12 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc
return_annotation = type_hints.get("return", None)

outputs = extract_return_annotation(return_annotation)
for k, v in outputs.items():
outputs[k] = _change_unrecognized_type_to_pickle(v) # type: ignore
inputs: Dict[str, Tuple[Type, Any]] = OrderedDict()
for k, v in signature.parameters.items(): # type: ignore
annotation = type_hints.get(k, None)
default = v.default if v.default is not inspect.Parameter.empty else None
# Inputs with default values are currently ignored, we may want to look into that in the future
inputs[k] = (_change_unrecognized_type_to_pickle(annotation), default) # type: ignore
inputs[k] = (annotation, default) # type: ignore

# This is just for typing.NamedTuples - in those cases, the user can select a name to call the NamedTuple. We
# would like to preserve that name in our custom collections.namedtuple.
Expand All @@ -365,20 +337,6 @@ def transform_variable_map(
if variable_map:
for k, v in variable_map.items():
res[k] = transform_type(v, descriptions.get(k, k))
sub_type: type = v
if hasattr(v, "__origin__") and hasattr(v, "__args__"):
if getattr(v, "__origin__") is list:
sub_type = getattr(v, "__args__")[0]
elif getattr(v, "__origin__") is dict:
sub_type = getattr(v, "__args__")[1]
if hasattr(sub_type, "__origin__") and getattr(sub_type, "__origin__") is FlytePickle:
original_type = cast(FlytePickle, sub_type).python_type()
if hasattr(original_type, "__name__"):
res[k].type.metadata = {"python_class_name": original_type.__name__}
elif hasattr(original_type, "_name"):
# If the class doesn't have the __name__ attribute, like typing.Sequence, use _name instead.
res[k].type.metadata = {"python_class_name": original_type._name}

return res


Expand Down
6 changes: 0 additions & 6 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,6 @@ def extract_value(
if len(input_val) == 0:
raise
sub_type = type(input_val[0])
# To maintain consistency between translate_inputs_to_literals and ListTransformer.to_literal for batchable types,
# directly call ListTransformer.to_literal to batch process the list items. This is necessary because processing
# each list item separately could lead to errors since ListTransformer.to_python_value may treat the literal
# as it is batched for batchable types.
if ListTransformer.is_batchable(python_type):
return TypeEngine.to_literal(ctx, input_val, python_type, lt)
literal_list = [extract_value(ctx, v, sub_type, lt.collection_type) for v in input_val]
return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=literal_list))
elif isinstance(input_val, dict):
Expand Down
139 changes: 84 additions & 55 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import inspect
import json as _json
import mimetypes
import os
import textwrap
import typing
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Dict, NamedTuple, Optional, Type, cast

import cloudpickle
from dataclasses_json import DataClassJsonMixin, dataclass_json
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
Expand Down Expand Up @@ -663,6 +665,64 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


class FlytePickleTransformer(TypeTransformer[T]):
PYTHON_PICKLE_FORMAT = "PythonPickle"

def __init__(self, t: Type[T]):
super().__init__(name="FlytePickle", t=t)

def assert_type(self, t: Type[T], v: T):
if not isinstance(v, self.python_type):
raise ValueError

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
uri = lv.scalar.blob.uri
# Deserialize the pickle, and return data in the pickle,
# and download pickle file to local first if file is not in the local file systems.
if ctx.file_access.is_remote(uri):
local_path = ctx.file_access.get_random_local_path()
ctx.file_access.get_data(uri, local_path, False)
uri = local_path
with open(uri, "rb") as infile:
data = cloudpickle.load(infile)
return data

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
meta = BlobMetadata(
type=_core_types.BlobType(
format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
)
# Dump the task output into pickle
local_dir = ctx.file_access.get_random_local_directory()
os.makedirs(local_dir, exist_ok=True)
local_path = ctx.file_access.get_random_local_path()
uri = os.path.join(local_dir, local_path)
with open(uri, "w+b") as outfile:
cloudpickle.dump(python_val, outfile)

remote_path = ctx.file_access.get_random_remote_path(uri)
ctx.file_access.put_data(uri, remote_path, is_multipart=False)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[T]:
if (
literal_type.blob is not None
and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
and literal_type.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
):
return self.python_type

raise ValueError(f"Transformer {self} cannot reverse {literal_type}")

def get_literal_type(self, t: Type[T]) -> LiteralType:
return LiteralType(
blob=_core_types.BlobType(
format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
)


class TypeEngine(typing.Generic[T]):
"""
Core Extensible TypeEngine of Flytekit. This should be used to extend the capabilities of FlyteKits type system.
Expand Down Expand Up @@ -733,6 +793,9 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:

Step 5:
if v is of type data class, use the dataclass transformer

Step 5:
Otherwise, default to Pickling the object with a warning
"""
cls.lazy_import_transformers()
# Step 1
Expand All @@ -757,8 +820,6 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
if python_type.__origin__ in cls._REGISTRY:
return cls._REGISTRY[python_type.__origin__]

raise ValueError(f"Generic Type {python_type.__origin__} not supported currently in Flytekit.")

# Step 3
# To facilitate cases where users may specify one transformer for multiple types that all inherit from one
# parent.
Expand All @@ -779,7 +840,11 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
if dataclasses.is_dataclass(python_type):
return cls._DATACLASS_TRANSFORMER

raise ValueError(f"Type {python_type} not supported currently in Flytekit. Please register a new transformer")
logger.warning(
f"Using pickling for type {python_type}. This may be unsafe if passing between tasks with"
" different python environments."
)
return FlytePickleTransformer(t=python_type)

@classmethod
def lazy_import_transformers(cls):
Expand Down Expand Up @@ -1021,62 +1086,22 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
except Exception as e:
raise ValueError(f"Type of Generic List type is not supported, {e}")

@staticmethod
def is_batchable(t: Type):
"""
This function evaluates whether the provided type is batchable or not.
It returns True only if the type is either List or Annotated(List) and the List subtype is FlytePickle.
"""
from flytekit.types.pickle import FlytePickle

if is_annotated(t):
return ListTransformer.is_batchable(get_args(t)[0])
if get_origin(t) is list:
subtype = get_args(t)[0]
if subtype == FlytePickle or (hasattr(subtype, "__origin__") and subtype.__origin__ == FlytePickle):
return True
return False

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
if type(python_val) != list:
raise TypeTransformerFailedError("Expected a list")

if ListTransformer.is_batchable(python_type):
from flytekit.types.pickle.pickle import BatchSize, FlytePickle

batch_size = len(python_val) # default batch size
# parse annotated to get the number of items saved in a pickle file.
if is_annotated(python_type):
for annotation in get_args(python_type)[1:]:
if isinstance(annotation, BatchSize):
batch_size = annotation.val
break
if batch_size > 0:
lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batch_size], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batch_size)] # type: ignore
else:
lit_list = []
else:
t = self.get_sub_type(python_type)
lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore
t = self.get_sub_type(python_type)
lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore
return Literal(collection=LiteralCollection(literals=lit_list))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore
try:
lits = lv.collection.literals
except AttributeError:
raise TypeTransformerFailedError()
if self.is_batchable(expected_python_type):
from flytekit.types.pickle import FlytePickle

batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits]
if len(batch_list) > 0 and type(batch_list[0]) is list:
# Make it have backward compatibility. The upstream task may use old version of Flytekit that
# won't merge the elements in the list. Therefore, we should check if the batch_list[0] is the list first.
return [item for batch in batch_list for item in batch]
return batch_list
else:
st = self.get_sub_type(expected_python_type)
return [TypeEngine.to_python_value(ctx, x, st) for x in lits]

st = self.get_sub_type(expected_python_type)
return [TypeEngine.to_python_value(ctx, x, st) for x in lits]

def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore
if literal_type.collection_type:
Expand Down Expand Up @@ -1206,7 +1231,7 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
python_type = get_underlying_type(python_type)

found_res = False
found_res = None
res = None
res_type = None
for t in get_args(python_type):
Expand All @@ -1215,10 +1240,14 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp

res = trans.to_literal(ctx, python_val, t, expected)
res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name)
# Triggered if there are two valid encoders for two types in the Union type.
# This only happens in the case where there is an `Annotated` override
if found_res:
# Should really never happen, sanity check
raise TypeError("Ambiguous choice of variant for union type")
found_res = True
raise TypeError(
f"Ambiguous choice of variant for union type: {python_val} can be encoded as either"
f"{found_res} or {t}. Please remove one from your Union type."
)
found_res = t
except (TypeTransformerFailedError, AttributeError, ValueError, AssertionError) as e:
logger.debug(f"Failed to convert from {python_val} to {t}", e)
continue
Expand All @@ -1238,7 +1267,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
if union_type.structure is not None:
union_tag = union_type.structure.tag

found_res = False
found_res = None
res = None
res_tag = None
for v in get_args(expected_python_type):
Expand All @@ -1260,9 +1289,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
if found_res:
raise TypeError(
"Ambiguous choice of variant for union type. "
+ f"Both {res_tag} and {trans.name} transformers match"
+ f"Can be decoded as either {found_res} or {v}. Please remove one from your Union type."
)
found_res = True
found_res = v
else:
res = trans.to_python_value(ctx, lv, v)
if found_res:
Expand Down
10 changes: 4 additions & 6 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

from dataclasses_json import config, dataclass_json
from marshmallow import fields
from typing_extensions import Annotated, get_args, get_origin

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_underlying_type
from flytekit.core.type_engine import FlytePickleTransformer, TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.loggers import logger
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types.pickle.pickle import FlytePickleTransformer


def noop():
Expand Down Expand Up @@ -336,7 +336,8 @@ def to_literal(
raise TypeTransformerFailedError("None value cannot be converted to a file.")

# Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type
python_type = get_underlying_type(python_type)
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]

if not (python_type is os.PathLike or issubclass(python_type, FlyteFile)):
raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike")
Expand Down Expand Up @@ -411,9 +412,6 @@ def to_python_value(
if expected_python_type is os.PathLike:
return FlyteFile(uri)

# Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type
expected_python_type = get_underlying_type(expected_python_type)

# The rest of the logic is only for FlyteFile types.
if not issubclass(expected_python_type, FlyteFile): # type: ignore
raise TypeError(f"Neither os.PathLike nor FlyteFile specified {expected_python_type}")
Expand Down
12 changes: 0 additions & 12 deletions flytekit/types/pickle/__init__.py

This file was deleted.

Loading