Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add type support for pydantic BaseModels #1660

Merged
Merged
Show file tree
Hide file tree
Changes from 48 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
fe9434f
Allow annotated FlyteFile as task input argument (#1632)
AdrianoKF May 19, 2023
b4e6f80
Use logger instead of print statement in sqlalchemy plugin (#1651)
wirthual May 22, 2023
ff73464
Map over notebook task (#1650)
pingsutw May 24, 2023
e9a714b
Support single literals in tiny url (#1654)
wild-endeavor May 24, 2023
17f3441
Skip grpcio 1.55.0 (#1653)
eapolinario May 24, 2023
594026a
Add support overriding image (#1652)
pingsutw May 24, 2023
1cf2556
Add setup.py
May 4, 2023
82d0773
Add readme
May 4, 2023
c6d14a4
Add and compile requirements
May 4, 2023
1433cd2
Add tests for type transformer
May 4, 2023
640d17a
Add type transformer
May 4, 2023
8325a62
Make docstring more concise
May 5, 2023
f2140b6
pydantic with flytepath and flytedirectory
ArthurBook May 26, 2023
2ea5b44
added example for how to use with BaseModel Config class
ArthurBook May 26, 2023
ccbb70b
moved setting of json-encoders in to pydantic transformer
ArthurBook Jun 2, 2023
b671d3d
added test for flytepath
ArthurBook Jun 2, 2023
240aa30
added flytedir test
ArthurBook Jun 2, 2023
34b1576
flytekit will auto load the transformer if the plugin is installed.
ArthurBook Jun 5, 2023
bd686ad
added upload to s3
ArthurBook Jun 5, 2023
67da286
removed ctx as input to upload_to_s3
ArthurBook Jun 7, 2023
de92c75
Update plugins/flytekit-pydantic/tests/test_type_transformer.py
ArthurBook Jun 12, 2023
24a4152
Revert "Allow annotated FlyteFile as task input argument (#1632)"
ArthurBook Jun 12, 2023
1880913
Revert "Use logger instead of print statement in sqlalchemy plugin (#…
ArthurBook Jun 12, 2023
595183d
Revert "Map over notebook task (#1650)"
ArthurBook Jun 12, 2023
75ea89b
Revert "Support single literals in tiny url (#1654)"
ArthurBook Jun 12, 2023
39acb76
Revert "Skip grpcio 1.55.0 (#1653)"
ArthurBook Jun 12, 2023
ab448d7
Revert "Add support overriding image (#1652)"
ArthurBook Jun 12, 2023
249f43f
Merge branch 'arthur/pydantic-and-flytepaths' of github.com:ArthurBoo…
ArthurBook Jun 12, 2023
f9ac9e9
full revamp and V2
ArthurBook Jun 15, 2023
da81828
made pydantic basemodel check explicit
ArthurBook Jun 15, 2023
95edbab
dynamic retrieval of supported flytetypes
ArthurBook Jun 15, 2023
b42bb4d
renamed FlyteObjectStore -> PydanticTransformerLiteralStore
ArthurBook Jun 15, 2023
7ccaa95
nit about a typehint
ArthurBook Jun 15, 2023
515a748
small changes to docstrings and types
ArthurBook Jun 15, 2023
24a4d0e
v2.1 w/ basemodel specific object store
ArthurBook Jun 15, 2023
68a2844
refactored
ArthurBook Jun 15, 2023
1466ff8
accomodate case where user has set validators on the type themselves
ArthurBook Jun 16, 2023
4c022de
Update plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_t…
ArthurBook Jun 20, 2023
96ef7e4
Update plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_t…
ArthurBook Jun 20, 2023
a477d90
Update plugins/flytekit-pydantic/flytekitplugins/pydantic/deserializa…
ArthurBook Jun 20, 2023
056b069
Update plugins/flytekit-pydantic/flytekitplugins/pydantic/deserializa…
ArthurBook Jun 20, 2023
aa33e31
fixed assumption that pydantic is installed in typecheck
ArthurBook Jun 20, 2023
deb42bb
comments revised
ArthurBook Jun 20, 2023
d55b251
improved tests to work with flyte types
ArthurBook Jun 20, 2023
0c4ce8c
arbitrary_types_allowed is not needed when we have the __get_validato…
ArthurBook Jun 20, 2023
bc27331
more tests
ArthurBook Jun 20, 2023
2026698
removed some leftover test code
ArthurBook Jun 20, 2023
0e9f215
changed serialization to align w dataclass_json logic in core type en…
ArthurBook Jun 23, 2023
4c463de
addressed comments
ArthurBook Jul 7, 2023
cf75326
removed nested literalmap that caused issues during serialization
ArthurBook Aug 29, 2023
eec501d
expanded dynamic task test
ArthurBook Aug 29, 2023
0c0a483
changed serialization from flat literalmap to nested
ArthurBook Aug 31, 2023
01c065e
add a unit test
wild-endeavor Sep 3, 2023
68ddfa5
Merge remote-tracking branch 'upstream/pr/1660' into arthur/pydantic-…
ArthurBook Sep 3, 2023
11235fb
linting issue fixed
ArthurBook Sep 5, 2023
1e946bb
revert typehint change in type engine
ArthurBook Sep 5, 2023
f58a55b
more lint fixes
ArthurBook Sep 11, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,13 @@ class PickleParamType(click.ParamType):
def convert(
self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
) -> typing.Any:

uri = FlyteContextManager.current_context().file_access.get_random_local_path()
with open(uri, "w+b") as outfile:
cloudpickle.dump(value, outfile)
return FileParam(filepath=str(pathlib.Path(uri).resolve()))


class DateTimeType(click.DateTime):

_NOW_FMT = "now"
_ADDITONAL_FORMATS = [_NOW_FMT]

Expand Down Expand Up @@ -276,7 +274,6 @@ def get_uri_for_dir(
def convert_to_structured_dataset(
self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: Directory
) -> Literal:

uri = self.get_uri_for_dir(ctx, value, "00000.parquet")

lit = Literal(
Expand Down Expand Up @@ -338,7 +335,7 @@ def convert_to_union(
python_val = converter._click_type.convert(value, param, ctx)
literal = converter.convert_to_literal(ctx, param, python_val)
return Literal(scalar=Scalar(union=Union(literal, variant)))
except (Exception or AttributeError) as e:
except Exception or AttributeError as e:
logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e)
raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}")

Expand Down Expand Up @@ -399,7 +396,10 @@ def convert_to_struct(
Convert the loaded json object to a Flyte Literal struct type.
"""
if type(value) != self._python_type:
o = cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(value))
if is_pydantic_basemodel(self._python_type):
o = self._python_type.parse_raw(json.dumps(value)) # type: ignore
else:
o = cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(value))
else:
o = value
return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type)
Expand Down Expand Up @@ -446,6 +446,15 @@ def convert(self, ctx, param, value) -> typing.Union[Literal, typing.Any]:
raise click.BadParameter(f"Failed to convert param {param}, {value} to {self._python_type}") from e


def is_pydantic_basemodel(python_type: typing.Type) -> bool:
try:
import pydantic
except ImportError:
return False
else:
return issubclass(python_type, pydantic.BaseModel)


def to_click_option(
ctx: click.Context,
flyte_ctx: FlyteContext,
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ def register_additional_type(cls, transformer: TypeTransformer, additional_type:
cls._REGISTRY[additional_type] = transformer

@classmethod
def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
def get_transformer(cls, python_type: T) -> TypeTransformer[T]:
ArthurBook marked this conversation as resolved.
Show resolved Hide resolved
"""
The TypeEngine hierarchy for flyteKit. This method looksup and selects the type transformer. The algorithm is
as follows
Expand Down
28 changes: 28 additions & 0 deletions plugins/flytekit-pydantic/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Flytekit Pydantic Plugin

Pydantic is a data validation and settings management library that uses Python type annotations to enforce type hints at runtime and provide user-friendly errors when data is invalid. Pydantic models are classes that inherit from `pydantic.BaseModel` and are used to define the structure and validation of data using Python type annotations.

The plugin adds type support for pydantic models.

To install the plugin, run the following command:

```bash
pip install flytekitplugins-pydantic
```


## Type Example
```python
from pydantic import BaseModel


class TrainConfig(BaseModel):
lr: float = 1e-3
batch_size: int = 32
files: List[FlyteFile]
directories: List[FlyteDirectory]

@task
def train(cfg: TrainConfig):
...
ArthurBook marked this conversation as resolved.
Show resolved Hide resolved
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .basemodel_transformer import BaseModelTransformer
from .deserialization import set_validators_on_supported_flyte_types as _set_validators_on_supported_flyte_types

_set_validators_on_supported_flyte_types() # enables you to use flytekit.types in pydantic model
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Serializes & deserializes the pydantic basemodels """

from typing import Type

import pydantic
from google.protobuf import json_format
from typing_extensions import Annotated

from flytekit import FlyteContext
from flytekit.core import type_engine
from flytekit.models import literals, types

from . import deserialization, serialization

BaseModelLiteralValue = Annotated[
literals.LiteralMap,
"""
BaseModel serialized to a LiteralMap consisting of:
1) the basemodel json with placeholders for flyte types
2) mapping from placeholders to serialized flyte type values in the object store
""",
]


class BaseModelTransformer(type_engine.TypeTransformer[pydantic.BaseModel]):
_TYPE_INFO = types.LiteralType(simple=types.SimpleType.STRUCT)

def __init__(self):
"""Construct pydantic.BaseModelTransformer."""
super().__init__(name="basemodel-transform", t=pydantic.BaseModel)

def get_literal_type(self, t: Type[pydantic.BaseModel]) -> types.LiteralType:
return types.LiteralType(simple=types.SimpleType.STRUCT)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is my only worry. I don't think I've seen this pattern before (cc @EngHabu @hamersaw @eapolinario) and I'm not sure if there's code that will break in the rest of flyte/flytekit. The literal type returned by this transformer is a simple but the value given will be a Literal containing a LiteralMap. Is there any reason why this won't work? I feel like it should be fine, it's still just one binding, but wanted to check. Is this a dumb question? does this pattern already exist elsewhere that I'm forgetting?

The reason it's a literalmap is because flyte types are serialized to a separate dict, and a pointer (effectively think of it as a memory pointer) is placed instead - should we investigate adding a metadata field to the Literal message @EngHabu?

@ArthurBook/ @fg91 i guess my main question is why the need for the separate store? is there a binary/byte type in pydantic that can be somehow flagged that we can just serialize the pb message into?

[side note: the broader pattern here is offloading and returning an artifact, which we can of course make json serializable, and then just serialize the Artifact. this complicates the network call of course and this concept should be held off until we turn flytekit's type engine async.]

Copy link
Member

@fg91 fg91 Jul 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have (naively) used types.LiteralType(simple=types.SimpleType.STRUCT) before in an internal plugin and never saw any suspicious behavior. Doesn't mean there isn't anything I have missed though.

i guess my main question is why the need for the separate store? is there a binary/byte type in pydantic that can be somehow flagged that we can just serialize the pb message into?

The pydantic base model is serialized into a json string by pydantic. The offloaded objects are stored in a dict[str, Literal] by the logic of this type transformer.

Parsing the json string into a protobuf struct (see here) and then saving both the "json struct" and the "store" into a LiteralMap was my first idea to put these two into a single literal but I'm open to any other way as well!


def to_literal(
self,
ctx: FlyteContext,
python_val: pydantic.BaseModel,
python_type: Type[pydantic.BaseModel],
expected: types.LiteralType,
) -> BaseModelLiteralValue:
"""Convert a given ``pydantic.BaseModel`` to the Literal representation."""
return serialization.serialize_basemodel(python_val)

def to_python_value(
self,
ctx: FlyteContext,
lv: BaseModelLiteralValue,
expected_python_type: Type[pydantic.BaseModel],
) -> pydantic.BaseModel:
"""Re-hydrate the pydantic BaseModel object from Flyte Literal value."""
basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(lv)
flyte_obj_literalmap = lv.literals[serialization.FLYTETYPE_OBJSTORE_KEY]
with deserialization.PydanticDeserializationLiteralStore.attach(flyte_obj_literalmap):
return expected_python_type.parse_raw(basemodel_json_w_placeholders)


def read_basemodel_json_from_literalmap(lv: BaseModelLiteralValue) -> serialization.SerializedBaseModel:
basemodel_literal: literals.Literal = lv.literals[serialization.BASEMODEL_JSON_KEY]
basemodel_json_w_placeholders = json_format.MessageToJson(basemodel_literal.scalar.generic)
assert isinstance(basemodel_json_w_placeholders, str)
return basemodel_json_w_placeholders


type_engine.TypeEngine.register(BaseModelTransformer())
29 changes: 29 additions & 0 deletions plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import builtins
import datetime
import typing

from typing_extensions import Annotated

import pandas as pd


from flytekit.core import type_engine

MODULES_TO_EXLCLUDE_FROM_FLYTE_TYPES = {m.__name__ for m in [builtins, typing, datetime]}


def include_in_flyte_types(t: type) -> bool:
if t is None:
return False
if t.__module__ in MODULES_TO_EXLCLUDE_FROM_FLYTE_TYPES:
return False
return True


PYDANTIC_SUPPORTED_FLYTE_TYPES = tuple(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change request (blocking): Being fully aware that type_engine.TypeEngine.get_available_transformers() was my suggestion, I don't think it gives us the desired result:

dict_keys([<class 'int'>, <class 'float'>, <class 'bool'>, <class 'str'>, <class 'datetime.datetime'>, <class 'datetime.timedelta'>, <class 'NoneType'>, None, <class 'list'>, typing.Union, <class 'dict'>, <class 'typing.TextIO'>, <class 'typing.BinaryIO'>, <enum 'Enum'>, <class 'google.protobuf.message.Message'>, <class 'tuple'>, typing.Tuple, <function NamedTuple at 0x1009d3370>, <class 'flytekit.types.pickle.pickle.FlytePickle'>, <class 'flytekit.types.file.file.FlyteFile'>, <class 'os.PathLike'>, <class 'flytekit.types.directory.types.FlyteDirectory'>, <class 'flytekit.types.structured.structured_dataset.StructuredDataset'>])

I would have expected supported types like pandas dataframe or torch tensor/module to be contained in this list. (Also ints, floats, ... should be excluded as well. What about BaseModels themselves?)

We might have to go with an explicit list after all.

@pingsutw do you know any other mechanism to discover all type transformers, including ones coming from (potentially private) plugins.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, think we can figure out something better here.
I'll leave this as an open for now waiting for a reply from @pingsutw

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We lazily loaded some transformers in the new flytekit, so you didn't see pandas, and tensor, etc. we can add lazy_load_transformer to get_available_transformers. therefore, we'll be able to see the pandas in the _REGISTRY.keys.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pingsutw is out for a bit. could you help me understand the reasoning for being aware of the transformers registered? I don't think knowing a priori all the types is necessarily the right approach. Would it be possible to make this registerable? Like, stock it comes with a certain set of types. If for some reason during the operation of a task, a user wants to make Pydantic aware of a certain type, they can add a flytekit TypeEngine transformer for that type, and then register that type with this pydantic engine. that's kinda how the structured dataset transformer works. what do you think?

filter(include_in_flyte_types, type_engine.TypeEngine.get_available_transformers())
) + (pd.DataFrame,)

# this is the UUID placeholder that is set in the serialized basemodel JSON, connecting that field to
# the literal map that holds the actual object that needs to be deserialized (w/ protobuf)
LiteralObjID = Annotated[str, "Key for unique object in literal map."]
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import contextlib
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Type, TypeVar, Union

import pydantic
from flytekit.core import context_manager, type_engine
from flytekit.models import literals

from flytekit.types import directory, file

from flytekitplugins.pydantic import commons, serialization
ArthurBook marked this conversation as resolved.
Show resolved Hide resolved


# this field is used by pydantic to get the validator method
PYDANTIC_VALIDATOR_METHOD_NAME = (
pydantic.BaseModel.__get_validators__.__name__
if pydantic.__version__ < "2.0.0"
else pydantic.BaseModel.__get_pydantic_core_schema__.__name___ # type: ignore
)
PythonType = TypeVar("PythonType") # target type of the deserialization


class PydanticDeserializationLiteralStore:
"""
The purpose of this class is to provide a context manager that can be used to deserialize a basemodel from a
literal map.

Because pydantic validators are fixed when subclassing a BaseModel, this object is a singleton that
serves as a namespace that can be set with the attach_to_literalmap context manager for the time that
a basemodel is being deserialized. The validators are then accessing this namespace for the flyteobj
placeholders that it is trying to deserialize.
"""

literal_store: Optional[serialization.LiteralStore] = None # attachement point for the literal map

def __init__(self) -> None:
raise Exception("This class should not be instantiated")

def __init_subclass__(cls) -> None:
raise Exception("This class should not be subclassed")

@classmethod
@contextlib.contextmanager
def attach(cls, literal_map: literals.LiteralMap) -> Generator[None, None, None]:
"""
Read a literal map and populate the object store from it.

This can be used as a context manager to attach to a literal map for the duration of a deserialization
Note that this is not threadsafe, and designed to manage a single deserialization at a time.
"""
assert not cls.is_attached(), "can only be attached to one literal map at a time."
try:
cls.literal_store = literal_map.literals
yield
finally:
cls.literal_store = None
ArthurBook marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def contains(cls, item: commons.LiteralObjID) -> bool:
assert cls.is_attached(), "can only check for existence of a literal when attached to a literal map"
assert cls.literal_store is not None
return item in cls.literal_store

@classmethod
def is_attached(cls) -> bool:
return cls.literal_store is not None

@classmethod
def get_python_object(
cls, identifier: commons.LiteralObjID, expected_type: Type[PythonType]
) -> Optional[PythonType]:
"""Deserialize a flyte literal and return the python object."""
if not cls.is_attached():
raise Exception("Must attach to a literal map before deserializing")
literal = cls.literal_store[identifier] # type: ignore
python_object = deserialize_flyte_literal(literal, expected_type)
return python_object


def set_validators_on_supported_flyte_types() -> None:
"""
Set pydantic validator for the flyte types supported by this plugin.
"""
for flyte_type in commons.PYDANTIC_SUPPORTED_FLYTE_TYPES:
setattr(flyte_type, PYDANTIC_VALIDATOR_METHOD_NAME, add_flyte_validators_for_type(flyte_type))


def add_flyte_validators_for_type(
flyte_obj_type: Type[type_engine.T],
) -> Callable[[Any], Iterator[Callable[[Any], type_engine.T]]]:
"""
Add flyte deserialisation validators to a type.
"""

previous_validators = getattr(flyte_obj_type, PYDANTIC_VALIDATOR_METHOD_NAME, lambda *_: [])()

def validator(object_uid_maybe: Union[commons.LiteralObjID, Any]) -> Union[type_engine.T, Any]:
"""Partial of deserialize_flyte_literal with the object_type fixed"""
if not PydanticDeserializationLiteralStore.is_attached():
return object_uid_maybe # this validator should only trigger when we are deserializeing
if not isinstance(object_uid_maybe, str):
return object_uid_maybe # object uids are strings and we dont want to trigger on other types
if not PydanticDeserializationLiteralStore.contains(object_uid_maybe):
return object_uid_maybe # final safety check to make sure that the object uid is in the literal map
return PydanticDeserializationLiteralStore.get_python_object(object_uid_maybe, flyte_obj_type)

def validator_generator(*args, **kwags) -> Iterator[Callable[[Any], type_engine.T]]:
"""Generator that returns validators."""
yield validator
yield from previous_validators
yield from ADDITIONAL_FLYTETYPE_VALIDATORS.get(flyte_obj_type, [])

return validator_generator


def validate_flytefile(flytefile: Union[str, file.FlyteFile]) -> file.FlyteFile:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't really understand these validators. is this really required? is this a pydantic thing? i feel like we should let the base flytekit type engine handle the transformations

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC one could consider validators as the to_python_value function of pydantic when deserializing.

"""Validate a flytefile (i.e. deserialize)."""
if isinstance(flytefile, file.FlyteFile):
return flytefile
if isinstance(flytefile, str): # when e.g. initializing from config
return file.FlyteFile(flytefile)
fg91 marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError(f"Invalid type for flytefile: {type(flytefile)}")


def validate_flytedir(flytedir: Union[str, directory.FlyteDirectory]) -> directory.FlyteDirectory:
"""Validate a flytedir (i.e. deserialize)."""
if isinstance(flytedir, directory.FlyteDirectory):
return flytedir
if isinstance(flytedir, str): # when e.g. initializing from config
return directory.FlyteDirectory(flytedir)
else:
raise ValueError(f"Invalid type for flytedir: {type(flytedir)}")


ADDITIONAL_FLYTETYPE_VALIDATORS: Dict[Type, List[Callable[[Any], Any]]] = {
file.FlyteFile: [validate_flytefile],
directory.FlyteDirectory: [validate_flytedir],
}


def deserialize_flyte_literal(
flyteobj_literal: literals.Literal, python_type: Type[PythonType]
) -> Optional[PythonType]:
"""Deserialize a Flyte Literal into the python object instance."""
ctx = context_manager.FlyteContext.current_context()
transformer = type_engine.TypeEngine.get_transformer(python_type)
python_obj = transformer.to_python_value(ctx, flyteobj_literal, python_type)
return python_obj
Loading