From fe9434f6b8db8d0ba7c74292a2400dcebda7f446 Mon Sep 17 00:00:00 2001 From: Adrian Rumpold Date: Sat, 20 May 2023 01:23:18 +0200 Subject: [PATCH 01/55] Allow annotated FlyteFile as task input argument (#1632) * fix: Allow annotated FlyteFile as task input argument Using an annotated FlyteFile type as an input to a task was previously impossible due to an exception being raised in `FlyteFilePathTransformer.to_python_value`. This commit applies the fix previously used in `FlyteFilePathTransformer.to_literal` to permit using annotated FlyteFiles as either inputs and outputs of a task. Issue: #3424 Signed-off-by: Adrian Rumpold * refactor: Unified handling of annotated types in type engine Issue: #3424 Signed-off-by: Adrian Rumpold * fix: Use py3.8-compatible types in type engine tests Issue: #3424 Signed-off-by: Adrian Rumpold --------- Signed-off-by: Adrian Rumpold Signed-off-by: Arthur --- flytekit/core/type_engine.py | 47 +++++++++++--------- flytekit/types/file/file.py | 9 ++-- tests/flytekit/unit/core/test_flyte_file.py | 13 ++++-- tests/flytekit/unit/core/test_type_engine.py | 27 +++++++++++ 4 files changed, 69 insertions(+), 27 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 5994390c8d..e01cb9a343 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -173,8 +173,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp return self._to_literal_transformer(python_val) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: - if get_origin(expected_python_type) is Annotated: - expected_python_type = get_args(expected_python_type)[0] + expected_python_type = get_underlying_type(expected_python_type) if expected_python_type != self._type: raise TypeTransformerFailedError( @@ -311,7 +310,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: Extracts the Literal type definition for a Dataclass and returns a type Struct. If possible also extracts the JSONSchema for the dataclass. """ - if get_origin(t) is Annotated: + if is_annotated(t): raise ValueError( "Flytekit does not currently have support for FlyteAnnotations applied to Dataclass." f"Type {t} cannot be parsed." 
@@ -368,7 +367,7 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: self._get_origin_type_in_annotation(get_args(python_type)[0]), self._get_origin_type_in_annotation(get_args(python_type)[1]), ] - elif get_origin(python_type) is Annotated: + elif is_annotated(python_type): return get_args(python_type)[0] elif dataclasses.is_dataclass(python_type): for field in dataclasses.fields(copy.deepcopy(python_type)): @@ -737,7 +736,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: """ cls.lazy_import_transformers() # Step 1 - if get_origin(python_type) is Annotated: + if is_annotated(python_type): args = get_args(python_type) for annotation in args: if isinstance(annotation, TypeTransformer): @@ -752,7 +751,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: if hasattr(python_type, "__origin__"): # Handling of annotated generics, eg: # Annotated[typing.List[int], 'foo'] - if get_origin(python_type) is Annotated: + if is_annotated(python_type): return cls.get_transformer(get_args(python_type)[0]) if python_type.__origin__ in cls._REGISTRY: @@ -823,7 +822,7 @@ def to_literal_type(cls, python_type: Type) -> LiteralType: transformer = cls.get_transformer(python_type) res = transformer.get_literal_type(python_type) data = None - if get_origin(python_type) is Annotated: + if is_annotated(python_type): for x in get_args(python_type)[1:]: if not isinstance(x, FlyteAnnotation): continue @@ -851,9 +850,9 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type # In case the value is an annotated type we inspect the annotations and look for hash-related annotations. hash = None - if get_origin(python_type) is Annotated: + if is_annotated(python_type): # We are now dealing with one of two cases: - # 1. The annotated type is a `HashMethod`, which indicates that we should we should produce the hash using + # 1. The annotated type is a `HashMethod`, which indicates that we should produce the hash using # the method indicated in the annotation. # 2. The annotated type is being used for a different purpose other than calculating hash values, in which case # we should just continue. @@ -880,7 +879,7 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T @classmethod def to_html(cls, ctx: FlyteContext, python_val: typing.Any, expected_python_type: Type[typing.Any]) -> str: transformer = cls.get_transformer(expected_python_type) - if get_origin(expected_python_type) is Annotated: + if is_annotated(expected_python_type): expected_python_type, *annotate_args = get_args(expected_python_type) from flytekit.deck.renderer import Renderable @@ -1004,7 +1003,7 @@ def get_sub_type(t: Type[T]) -> Type[T]: if hasattr(t, "__origin__"): # Handle annotation on list generic, eg: # Annotated[typing.List[int], 'foo'] - if get_origin(t) is Annotated: + if is_annotated(t): return ListTransformer.get_sub_type(get_args(t)[0]) if getattr(t, "__origin__") is list and hasattr(t, "__args__"): @@ -1030,7 +1029,7 @@ def is_batchable(t: Type): """ from flytekit.types.pickle import FlytePickle - if get_origin(t) is Annotated: + if is_annotated(t): return ListTransformer.is_batchable(get_args(t)[0]) if get_origin(t) is list: subtype = get_args(t)[0] @@ -1047,7 +1046,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp batch_size = len(python_val) # default batch size # parse annotated to get the number of items saved in a pickle file. 
- if get_origin(python_type) is Annotated: + if is_annotated(python_type): for annotation in get_args(python_type)[1:]: if isinstance(annotation, BatchSize): batch_size = annotation.val @@ -1191,8 +1190,7 @@ def get_sub_type_in_optional(t: Type[T]) -> Type[T]: return get_args(t)[0] def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: - if get_origin(t) is Annotated: - t = get_args(t)[0] + t = get_underlying_type(t) try: trans: typing.List[typing.Tuple[TypeTransformer, typing.Any]] = [ @@ -1206,8 +1204,7 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: raise ValueError(f"Type of Generic Union type is not supported, {e}") def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: - if get_origin(python_type) is Annotated: - python_type = get_args(python_type)[0] + python_type = get_underlying_type(python_type) found_res = False res = None @@ -1232,8 +1229,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp raise TypeTransformerFailedError(f"Cannot convert from {python_val} to {python_type}") def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[typing.Any]: - if get_origin(expected_python_type) is Annotated: - expected_python_type = get_args(expected_python_type)[0] + expected_python_type = get_underlying_type(expected_python_type) union_tag = None union_type = None @@ -1468,7 +1464,7 @@ def __init__(self): super().__init__(name="DefaultEnumTransformer", t=enum.Enum) def get_literal_type(self, t: Type[T]) -> LiteralType: - if get_origin(t) is Annotated: + if is_annotated(t): raise ValueError( f"Flytekit does not currently have support \ for FlyteAnnotations applied to enums. {t} cannot be \ @@ -1782,3 +1778,14 @@ def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: _register_default_type_transformers() + + +def is_annotated(t: Type) -> bool: + return get_origin(t) is Annotated + + +def get_underlying_type(t: Type) -> Type: + """Return the underlying type for annotated types or the type itself""" + if is_annotated(t): + return get_args(t)[0] + return t diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index bb8feb3d9c..d78ec152d7 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -8,10 +8,9 @@ from dataclasses_json import config, dataclass_json from marshmallow import fields -from typing_extensions import Annotated, get_args, get_origin from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_underlying_type from flytekit.loggers import logger from flytekit.models.core.types import BlobType from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar @@ -337,8 +336,7 @@ def to_literal( raise TypeTransformerFailedError("None value cannot be converted to a file.") # Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type - if get_origin(python_type) is Annotated: - python_type = get_args(python_type)[0] + python_type = get_underlying_type(python_type) if not (python_type is os.PathLike or issubclass(python_type, FlyteFile)): raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike") @@ -413,6 +411,9 @@ def to_python_value( if expected_python_type is os.PathLike: return FlyteFile(uri) + 
# Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type + expected_python_type = get_underlying_type(expected_python_type) + # The rest of the logic is only for FlyteFile types. if not issubclass(expected_python_type, FlyteFile): # type: ignore raise TypeError(f"Neither os.PathLike nor FlyteFile specified {expected_python_type}") diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index b7f0a1aeee..5dd05cdffd 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -439,13 +439,20 @@ def test_flyte_file_annotated_hashmethod(local_dummy_file): def calc_hash(ff: FlyteFile) -> str: return str(ff.path) + HashedFlyteFile = Annotated[FlyteFile, HashMethod(calc_hash)] + @task - def t1(path: str) -> Annotated[FlyteFile, HashMethod(calc_hash)]: - return FlyteFile(path) + def t1(path: str) -> HashedFlyteFile: + return HashedFlyteFile(path) + + @task + def t2(ff: HashedFlyteFile) -> None: + print(ff.path) @workflow def wf(path: str) -> None: - t1(path=path) + ff = t1(path=path) + t2(ff=ff) wf(path=local_dummy_file) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 6d1b6829d5..2e52fdcf9d 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -41,6 +41,8 @@ UnionTransformer, convert_json_schema_to_python_class, dataclass_from_dict, + get_underlying_type, + is_annotated, ) from flytekit.exceptions import user as user_exceptions from flytekit.models import types as model_types @@ -1685,3 +1687,28 @@ def test_batch_pickle_list(python_val, python_type, expected_list_length): # data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)] # task1(data=data) # task1(data: typing.List[FlytePickle]) assert pv == python_val + + +@pytest.mark.parametrize( + "t,expected", + [ + (list, False), + (Annotated[int, "tag"], True), + (Annotated[typing.List[str], "a", "b"], True), + (Annotated[typing.Dict[int, str], FlyteAnnotation({"foo": "bar"})], True), + ], +) +def test_is_annotated(t, expected): + assert is_annotated(t) == expected + + +@pytest.mark.parametrize( + "t,expected", + [ + (typing.List, typing.List), + (Annotated[int, "tag"], int), + (Annotated[typing.List[str], "a", "b"], typing.List[str]), + ], +) +def test_get_underlying_type(t, expected): + assert get_underlying_type(t) == expected From b4e6f8080f5fcc1bed10c8e46744d685ad6d0672 Mon Sep 17 00:00:00 2001 From: wirthual Date: Mon, 22 May 2023 07:58:04 -0700 Subject: [PATCH 02/55] Use logger instead of print statement in sqlalchemy plugin (#1651) * use logging info instead of print Signed-off-by: wirthual * isorted files Signed-off-by: wirthual * import root logger from flytekit Signed-off-by: wirthual --------- Signed-off-by: wirthual Signed-off-by: Arthur --- .../flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py index 8541bc6aed..8e8c464bd4 100644 --- a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py +++ b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py @@ -11,6 +11,7 @@ from flytekit.core.base_sql_task import SQLTask from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask from flytekit.core.shim_task import ShimTaskExecutor +from 
flytekit.loggers import logger from flytekit.models import task as task_models from flytekit.models.security import Secret from flytekit.types.schema import FlyteSchema @@ -126,10 +127,10 @@ def execute_from_model(self, tt: task_models.TaskTemplate, **kwargs) -> typing.A tt.custom["connect_args"][key] = value engine = create_engine(tt.custom["uri"], connect_args=tt.custom["connect_args"], echo=False) - print(f"Connecting to db {tt.custom['uri']}") + logger.info(f"Connecting to db {tt.custom['uri']}") interpolated_query = SQLAlchemyTask.interpolate_query(tt.custom["query_template"], **kwargs) - print(f"Interpolated query {interpolated_query}") + logger.info(f"Interpolated query {interpolated_query}") with engine.begin() as connection: df = None if tt.interface.outputs: From ff734640260eb4865fdd4f98f24c0ccb4ba6232a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 23 May 2023 21:48:30 -0700 Subject: [PATCH 03/55] Map over notebook task (#1650) * map over notebook Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * tests Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * add a flag Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su * Fix tests Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su --------- Signed-off-by: Kevin Su Signed-off-by: Arthur --- flytekit/core/map_task.py | 17 ++++++++++++----- .../flytekitplugins/papermill/task.py | 8 +++++++- plugins/flytekit-papermill/setup.py | 2 +- plugins/flytekit-papermill/tests/test_task.py | 19 ++++++++++++++++++- 4 files changed, 38 insertions(+), 8 deletions(-) diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index b40b5029bb..52325ecb59 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -16,7 +16,7 @@ from flytekit.core.constants import SdkTaskType from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.interface import transform_interface_to_list_interface -from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask from flytekit.core.tracker import TrackedInstance from flytekit.core.utils import timeit from flytekit.exceptions import scopes as exception_scopes @@ -34,7 +34,7 @@ class MapPythonTask(PythonTask): def __init__( self, - python_function_task: typing.Union[PythonFunctionTask, functools.partial], + python_function_task: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial], concurrency: Optional[int] = None, min_success_ratio: Optional[float] = None, bound_inputs: Optional[Set[str]] = None, @@ -65,7 +65,10 @@ def __init__( actual_task = python_function_task if not isinstance(actual_task, PythonFunctionTask): - raise ValueError("Map tasks can only compose of Python Functon Tasks currently") + if isinstance(actual_task, PythonInstanceTask): + pass + else: + raise ValueError("Map tasks can only compose of PythonFuncton and PythonInstanceTasks currently") if len(actual_task.python_interface.outputs.keys()) > 1: raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs") @@ -76,7 +79,11 @@ def __init__( collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs) self._run_task: PythonFunctionTask = actual_task - _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) + if 
isinstance(actual_task, PythonInstanceTask): + mod = actual_task.task_type + f = actual_task.lhs + else: + _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest() name = f"{mod}.map_{f}_{h}" @@ -271,7 +278,7 @@ def _raw_execute(self, **kwargs) -> Any: def map_task( - task_function: typing.Union[PythonFunctionTask, functools.partial], + task_function: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial], concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs, diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index b1f472e99a..6f4ed6886c 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -133,6 +133,7 @@ def __init__( task_config: T = None, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, outputs: typing.Optional[typing.Dict[str, typing.Type]] = None, + output_notebooks: typing.Optional[bool] = True, **kwargs, ): # Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used @@ -165,13 +166,16 @@ def __init__( if not os.path.exists(self._notebook_path): raise ValueError(f"Illegal notebook path passed in {self._notebook_path}") - if outputs: + if output_notebooks: + if outputs is None: + outputs = {} outputs.update( { self._IMPLICIT_OP_NOTEBOOK: self._IMPLICIT_OP_NOTEBOOK_TYPE, self._IMPLICIT_RENDERED_NOTEBOOK: self._IMPLICIT_RENDERED_NOTEBOOK_TYPE, } ) + super().__init__( name, task_config, @@ -287,6 +291,8 @@ def execute(self, **kwargs) -> Any: else: raise TypeError(f"Expected output {k} of type {type_v} not found in the notebook outputs") + if len(output_list) == 1: + return output_list[0] return tuple(output_list) def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: diff --git a/plugins/flytekit-papermill/setup.py b/plugins/flytekit-papermill/setup.py index 33b9816081..538946a6d7 100644 --- a/plugins/flytekit-papermill/setup.py +++ b/plugins/flytekit-papermill/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.3.0b2,<2.0.0", + "flytekit", "papermill>=1.2.0", "nbconvert>=6.0.7", "ipykernel>=5.0.0", diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index 0e54e7082e..47db35793d 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -1,6 +1,7 @@ import datetime import os import tempfile +import typing import pandas as pd from flytekitplugins.papermill import NotebookTask @@ -8,7 +9,7 @@ from kubernetes.client import V1Container, V1PodSpec import flytekit -from flytekit import StructuredDataset, kwtypes, task +from flytekit import StructuredDataset, kwtypes, map_task, task, workflow from flytekit.configuration import Image, ImageConfig from flytekit.types.directory import FlyteDirectory from flytekit.types.file import FlyteFile, PythonNotebook @@ -33,6 +34,14 @@ def _get_nb_path(name: str, suffix: str = "", abs: bool = True, ext: str = ".ipy outputs=kwtypes(square=float), ) +nb_sub_task = NotebookTask( + name="test", + notebook_path=_get_nb_path(nb_name, abs=False), + inputs=kwtypes(a=float), + outputs=kwtypes(square=float), + output_notebooks=False, +) + def test_notebook_task_simple(): serialization_settings = 
flytekit.configuration.SerializationSettings( @@ -172,3 +181,11 @@ def create_sd() -> StructuredDataset: ) success, out, render = nb_types.execute(ff=ff, fd=fd, sd=sd) assert success is True, "Notebook execution failed" + + +def test_map_over_notebook_task(): + @workflow + def wf(a: float) -> typing.List[float]: + return map_task(nb_sub_task)(a=[a, a]) + + assert wf(a=3.14) == [9.8596, 9.8596] From e9a714bebfa6095b3cfd2ea0351a8f8f9cbb6e14 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 24 May 2023 11:00:43 -0700 Subject: [PATCH 04/55] Support single literals in tiny url (#1654) Signed-off-by: Yee Hing Tong Signed-off-by: Arthur --- doc-requirements.txt | 2 +- flytekit/remote/remote.py | 11 +++++++++-- setup.py | 3 +-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/doc-requirements.txt b/doc-requirements.txt index 1929925e84..5264673f4f 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -244,7 +244,7 @@ flask==2.2.3 # via mlflow flatbuffers==23.1.21 # via tensorflow -flyteidl==1.5.4 +flyteidl==1.5.6 # via flytekit fonttools==4.38.0 # via matplotlib diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 8b05ba69dc..e0a411de50 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -227,7 +227,12 @@ def file_access(self) -> FileAccessProvider: def get( self, flyte_uri: typing.Optional[str] = None - ) -> typing.Optional[typing.Union[LiteralsResolver, HTML, bytes]]: + ) -> typing.Optional[typing.Union[LiteralsResolver, Literal, HTML, bytes]]: + """ + General function that works with flyte tiny urls. This can return outputs (in the form of LiteralsResolver, or + individual Literals for singular requests), or HTML if passed a deck link, or bytes containing HTML, + if ipython is not available locally. 
+ """ if flyte_uri is None: raise user_exceptions.FlyteUserException("flyte_uri cannot be empty") ctx = self._ctx or FlyteContextManager.current_context() @@ -237,6 +242,8 @@ def get( if data_response.HasField("literal_map"): lm = LiteralMap.from_flyte_idl(data_response.literal_map) return LiteralsResolver(lm.literals) + elif data_response.HasField("literal"): + return data_response.literal elif data_response.HasField("pre_signed_urls"): if len(data_response.pre_signed_urls.signed_url) == 0: raise ValueError(f"Flyte url {flyte_uri} resolved to empty download link") @@ -258,7 +265,7 @@ def get( except user_exceptions.FlyteUserException as e: remote_logger.info(f"Error from Flyte backend when trying to fetch data: {e.__cause__}") - remote_logger.debug(f"Nothing found from {flyte_uri}") + remote_logger.info(f"Nothing found from {flyte_uri}") def remote_context(self): """Context manager with remote-specific configuration.""" diff --git a/setup.py b/setup.py index 7215101600..ed8a90bb33 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ }, install_requires=[ "googleapis-common-protos>=1.57", - "flyteidl>=1.5.4", + "flyteidl>=1.5.6", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", @@ -64,7 +64,6 @@ "marshmallow-jsonschema>=0.12.0", "natsort>=7.0.1", "docker-image-py>=0.1.10", - "singledispatchmethod; python_version < '3.8.0'", "typing_extensions", "docstring-parser>=0.9.0", "diskcache>=5.2.1", From 17f3441d96ac753493802905ff76495ecfa9486e Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Wed, 24 May 2023 11:02:16 -0700 Subject: [PATCH 05/55] Skip grpcio 1.55.0 (#1653) Signed-off-by: eduardo apolinario Co-authored-by: eduardo apolinario Signed-off-by: Arthur --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index ed8a90bb33..5273590d9f 100644 --- a/setup.py +++ b/setup.py @@ -40,8 +40,8 @@ "python-dateutil>=2.1", # Restrict grpcio and grpcio-status. 
Version 1.50.0 pulls in a version of protobuf that is not compatible # with the old protobuf library (as described in https://developers.google.com/protocol-buffers/docs/news/2022-05-06) - "grpcio>=1.50.0,<2.0", - "grpcio-status>=1.50.0,<2.0", + "grpcio>=1.50.0,!=1.55.0,<2.0", + "grpcio-status>=1.50.0,!=1.55.0,<2.0", "importlib-metadata", "fsspec>=2023.3.0", "adlfs", From 594026aa94a60e29d5c0a6fee9044638772a1299 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 24 May 2023 11:21:49 -0700 Subject: [PATCH 06/55] Add support overriding image (#1652) Signed-off-by: Kevin Su Signed-off-by: Arthur --- flytekit/core/node.py | 2 ++ tests/flytekit/unit/core/test_node_creation.py | 13 +++++++++++++ 2 files changed, 15 insertions(+) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 4f7838d2b6..bf5c97ba60 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -128,6 +128,8 @@ def with_overrides(self, *args, **kwargs): if not isinstance(new_task_config, type(self.flyte_entity._task_config)): raise ValueError("can't change the type of the task config") self.flyte_entity._task_config = new_task_config + if "container_image" in kwargs: + self.flyte_entity._container_image = kwargs["container_image"] return self diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index da708a8571..81621ef3fc 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -452,3 +452,16 @@ def my_wf(a: str) -> str: return t1(a=a).with_overrides(task_config=None) my_wf() + + +def test_override_image(): + @task + def bar(): + print("hello") + + @workflow + def wf() -> str: + bar().with_overrides(container_image="hello/world") + return "hi" + + assert wf.nodes[0].flyte_entity.container_image == "hello/world" From 1cf2556a621f1c966dd5ed760667229a47012650 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Thu, 4 May 2023 22:13:27 +0200 Subject: [PATCH 07/55] Add setup.py Signed-off-by: Arthur --- plugins/flytekit-pydantic/setup.py | 39 ++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 plugins/flytekit-pydantic/setup.py diff --git a/plugins/flytekit-pydantic/setup.py b/plugins/flytekit-pydantic/setup.py new file mode 100644 index 0000000000..1ad734c482 --- /dev/null +++ b/plugins/flytekit-pydantic/setup.py @@ -0,0 +1,39 @@ +from setuptools import setup + +PLUGIN_NAME = "pydantic" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "pydantic"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="Plugin adding type support for Pydantic models", + url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-pydantic", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software 
Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) From 82d07731f8195de8f9037d9df6f80020467c27c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Thu, 4 May 2023 22:19:55 +0200 Subject: [PATCH 08/55] Add readme Signed-off-by: Arthur --- plugins/flytekit-pydantic/README.md | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 plugins/flytekit-pydantic/README.md diff --git a/plugins/flytekit-pydantic/README.md b/plugins/flytekit-pydantic/README.md new file mode 100644 index 0000000000..ef67727d6a --- /dev/null +++ b/plugins/flytekit-pydantic/README.md @@ -0,0 +1,28 @@ +# Flytekit Pydantic Plugin + +Pydantic is a data validation and settings management library that uses Python type annotations to enforce type hints at runtime and provide user-friendly errors when data is invalid. Pydantic models are classes that inherit from `pydantic.BaseModel` and are used to define the structure and validation of data using Python type annotations. + +The plugin adds type support for pydantic models. + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-pydantic +``` + + +## Type Example +```python +from pydantic import BaseModel +import flytekitplugins.pydantic + + +class Config(BaseModel): + lr: float = 1e-3 + batch_size: int = 32 + + +@task +def train(cfg: Config): + ... +``` From c6d14a4660e9d3075f8cbf169e3a59646cbc3370 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Thu, 4 May 2023 22:28:54 +0200 Subject: [PATCH 09/55] Add and compile requirements Signed-off-by: Arthur --- plugins/flytekit-pydantic/requirements.in | 2 + plugins/flytekit-pydantic/requirements.txt | 347 +++++++++++++++++++++ 2 files changed, 349 insertions(+) create mode 100644 plugins/flytekit-pydantic/requirements.in create mode 100644 plugins/flytekit-pydantic/requirements.txt diff --git a/plugins/flytekit-pydantic/requirements.in b/plugins/flytekit-pydantic/requirements.in new file mode 100644 index 0000000000..44f25884d7 --- /dev/null +++ b/plugins/flytekit-pydantic/requirements.in @@ -0,0 +1,2 @@ +. 
+-e file:.#egg=flytekitplugins-pydantic diff --git a/plugins/flytekit-pydantic/requirements.txt b/plugins/flytekit-pydantic/requirements.txt new file mode 100644 index 0000000000..68acf7008a --- /dev/null +++ b/plugins/flytekit-pydantic/requirements.txt @@ -0,0 +1,347 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-pydantic + # via -r requirements.in +adal==1.2.7 + # via azure-datalake-store +adlfs==2023.4.0 + # via flytekit +aiobotocore==2.5.0 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp +arrow==1.2.3 + # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.26.4 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.52 + # via adlfs +azure-identity==1.12.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs +binaryornot==0.4.4 + # via cookiecutter +botocore==1.29.76 + # via aiobotocore +cachetools==5.3.0 + # via google-auth +certifi==2022.12.7 + # via + # kubernetes + # requests +cffi==1.15.1 + # via + # azure-datalake-store + # cryptography +chardet==5.1.0 + # via binaryornot +charset-normalizer==3.1.0 + # via + # aiohttp + # requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.2.1 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.14 + # via flytekit +cryptography==40.0.2 + # via + # adal + # azure-identity + # azure-storage-blob + # msal + # pyjwt + # pyopenssl +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via gcsfs +deprecated==1.2.13 + # via flytekit +diskcache==5.6.1 + # via flytekit +docker==6.0.1 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.15 + # via flytekit +flyteidl==1.3.20 + # via flytekit +flytekit==1.5.0 + # via flytekitplugins-pydantic +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.4.0 + # via + # adlfs + # flytekit + # gcsfs + # s3fs +gcsfs==2023.4.0 + # via flytekit +gitdb==4.0.10 + # via gitpython +gitpython==3.1.31 + # via flytekit +google-api-core==2.11.0 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.17.3 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.2 + # via google-cloud-storage +google-cloud-storage==2.9.0 + # via gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage +googleapis-common-protos==1.59.0 + # via + # flyteidl + # flytekit + # google-api-core + # grpcio-status +grpcio==1.54.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.54.0 + # via flytekit +idna==3.4 + # via + # requests + # yarl +importlib-metadata==6.6.0 + # via + # flytekit + # keyring +isodate==0.6.1 + # via azure-storage-blob +jaraco-classes==3.2.3 + # via keyring +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +jmespath==1.0.1 + # via botocore +joblib==1.2.0 + # via flytekit +keyring==23.13.1 + # via flytekit +kubernetes==26.1.0 + # via flytekit +markupsafe==2.1.2 + # via jinja2 +marshmallow==3.19.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit 
+more-itertools==9.1.0 + # via jaraco-classes +msal==1.22.0 + # via + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl +mypy-extensions==1.0.0 + # via typing-inspect +natsort==8.3.1 + # via flytekit +numpy==1.24.3 + # via + # flytekit + # pandas + # pyarrow +oauthlib==3.2.2 + # via requests-oauthlib +packaging==23.1 + # via + # docker + # marshmallow +pandas==1.5.3 + # via flytekit +portalocker==2.7.0 + # via msal-extensions +protobuf==4.22.3 + # via + # flyteidl + # google-api-core + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +pyarrow==10.0.1 + # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth +pycparser==2.21 + # via cffi +pydantic==1.10.7 + # via flytekitplugins-pydantic +pyjwt[crypto]==2.6.0 + # via + # adal + # msal +pyopenssl==23.1.1 + # via flytekit +python-dateutil==2.8.2 + # via + # adal + # arrow + # botocore + # croniter + # flytekit + # kubernetes + # pandas +python-json-logger==2.0.7 + # via flytekit +python-slugify==8.0.1 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2023.3 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit + # kubernetes + # responses +regex==2023.5.5 + # via docker-image-py +requests==2.30.0 + # via + # adal + # azure-core + # azure-datalake-store + # cookiecutter + # docker + # flytekit + # gcsfs + # google-api-core + # google-cloud-storage + # kubernetes + # msal + # requests-oauthlib + # responses +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes +responses==0.23.1 + # via flytekit +rsa==4.9 + # via google-auth +s3fs==2023.4.0 + # via flytekit +six==1.16.0 + # via + # azure-core + # azure-identity + # google-auth + # isodate + # kubernetes + # python-dateutil +smmap==5.0.0 + # via gitdb +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +types-pyyaml==6.0.12.9 + # via responses +typing-extensions==4.5.0 + # via + # azure-core + # azure-storage-blob + # flytekit + # pydantic + # typing-inspect +typing-inspect==0.8.0 + # via dataclasses-json +urllib3==1.26.15 + # via + # botocore + # docker + # flytekit + # kubernetes + # requests + # responses +websocket-client==1.5.1 + # via + # docker + # kubernetes +wheel==0.40.0 + # via flytekit +wrapt==1.15.0 + # via + # aiobotocore + # deprecated + # flytekit +yarl==1.9.2 + # via aiohttp +zipp==3.15.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools From 1433cd21344fe30cee3ab4f90cdc5bae60673a1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Thu, 4 May 2023 22:55:30 +0200 Subject: [PATCH 10/55] Add tests for type transformer Signed-off-by: Arthur --- .../tests/test_type_transformer.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 plugins/flytekit-pydantic/tests/test_type_transformer.py diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py new file mode 100644 index 0000000000..946db9900e --- /dev/null +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -0,0 +1,105 @@ +from typing import Any, Optional, Type, Union + +import flytekitplugins.pydantic # noqa F401 +import pytest +from flytekitplugins.pydantic import BaseModelTransformer +from pydantic import 
BaseModel, Extra + +from flytekit import task, workflow +from flytekit.core.type_engine import TypeTransformerFailedError + + +class TrainConfig(BaseModel): + """Config BaseModel for testing purposes.""" + + batch_size: int = 32 + lr: float = 1e-3 + loss: str = "cross_entropy" + + class Config: + extra = Extra.forbid + + +class Config(BaseModel): + """Config BaseModel for testing purposes with an optional type hint.""" + + model_config: Optional[Union[dict[str, TrainConfig], TrainConfig]] = TrainConfig() + + +class ConfigRequired(BaseModel): + """Config BaseModel for testing purposes with required attribute.""" + + model_config: Union[dict[str, TrainConfig], TrainConfig] + + +class ChildConfig(Config): + """Child class config BaseModel for testing purposes.""" + + d: list[int] = [1, 2, 3] + + +@pytest.mark.parametrize( + "python_type,kwargs", + [(Config, {}), (ConfigRequired, {"model_config": TrainConfig()}), (TrainConfig, {}), (TrainConfig, {})], +) +def test_transform_round_trip(python_type: Type, kwargs: dict[str, Any]): + """Test that a (de-)serialization roundtrip results in the identical BaseModel.""" + from flytekit.core.context_manager import FlyteContextManager + + ctx = FlyteContextManager().current_context() + + type_transformer = BaseModelTransformer() + + python_value = python_type(**kwargs) + + literal_value = type_transformer.to_literal( + ctx, + python_value, + python_type, + type_transformer.get_literal_type(python_value), + ) + + reconstructed_value = type_transformer.to_python_value(ctx, literal_value, type(python_value)) + + assert reconstructed_value == python_value + assert reconstructed_value.schema() == python_value.schema() + + +@pytest.mark.parametrize( + "config_type,kwargs", + [ + (Config, {"model_config": {"foo": TrainConfig(loss="mse")}}), + (ConfigRequired, {"model_config": {"foo": TrainConfig(loss="mse")}}), + ], +) +def test_pass_to_workflow(config_type: Type, kwargs: dict[str, Any]): + """Test passing a BaseModel instance to a workflow works.""" + cfg = config_type(**kwargs) + + @task + def train(cfg: config_type) -> config_type: + return cfg + + @workflow + def wf(cfg: config_type) -> config_type: + return train(cfg=cfg) + + returned_cfg = wf(cfg=cfg) + + assert cfg == returned_cfg + + +def test_pass_wrong_type_to_workflow(): + """Test passing the wrong type raises exception.""" + cfg = ChildConfig() + + @task + def train(cfg: Config) -> Config: + return cfg + + @workflow + def wf(cfg: Config) -> Config: + return train(cfg=cfg) + + with pytest.raises(TypeTransformerFailedError, match="The schema"): + wf(cfg=cfg) From 640d17a86b1859b2e6b5f934baadfdfaf4b1ce5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Thu, 4 May 2023 22:56:09 +0200 Subject: [PATCH 11/55] Add type transformer Signed-off-by: Arthur --- .../flytekitplugins/pydantic/__init__.py | 1 + .../flytekitplugins/pydantic/schema.py | 51 +++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py create mode 100644 plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py new file mode 100644 index 0000000000..8b448cd258 --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py @@ -0,0 +1 @@ +from .schema import BaseModelTransformer diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py 
b/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py new file mode 100644 index 0000000000..4aaafc375d --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py @@ -0,0 +1,51 @@ +from typing import Type + +from google.protobuf.json_format import MessageToDict +from google.protobuf.struct_pb2 import Struct +from pydantic import BaseModel + +from flytekit import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.literals import Literal, Scalar +from flytekit.models.types import LiteralType, SimpleType + + +class BaseModelTransformer(TypeTransformer[BaseModel]): + _TYPE_INFO = LiteralType(simple=SimpleType.STRUCT) + + def __init__(self): + """Construct BaseModelTransformer.""" + super().__init__(name="basemodel-transform", t=BaseModel) + + def get_literal_type(self, t: Type[BaseModel]) -> LiteralType: + return LiteralType(simple=SimpleType.STRUCT) + + def to_literal( + self, + ctx: FlyteContext, + python_val: BaseModel, + python_type: Type[BaseModel], + expected: LiteralType, + ) -> Literal: + """This method is used to convert from given python type object pydantic ``BaseModel`` to the Literal representation.""" + s = Struct() + + s.update({"schema": python_val.schema(), "data": python_val.dict()}) + + return Literal(scalar=Scalar(generic=s)) + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[BaseModel]) -> BaseModel: + """In this function we want to be able to re-hydrate the pydantic BaseModel object from Flyte Literal value.""" + base_model = MessageToDict(lv.scalar.generic) + schema = base_model["schema"] + data = base_model["data"] + + if (expected_schema := expected_python_type.schema()) != schema: + raise TypeTransformerFailedError( + f"The schema `{expected_schema}` of the expected python type {expected_python_type} is not equal to the received schema `{schema}`." 
+ ) + + return expected_python_type.parse_obj(data) + + +TypeEngine.register(BaseModelTransformer()) From 8325a627633bfe4332e123e87d39ec923b84774f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Fri, 5 May 2023 15:32:46 +0200 Subject: [PATCH 12/55] Make docstring more concise Signed-off-by: Arthur --- plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py index 4aaafc375d..cf29f03865 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py @@ -35,7 +35,7 @@ def to_literal( return Literal(scalar=Scalar(generic=s)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[BaseModel]) -> BaseModel: - """In this function we want to be able to re-hydrate the pydantic BaseModel object from Flyte Literal value.""" + """Re-hydrate the pydantic BaseModel object from Flyte Literal value.""" base_model = MessageToDict(lv.scalar.generic) schema = base_model["schema"] data = base_model["data"] From f2140b69b949720e6317e651f12027f63b0cb380 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 25 May 2023 19:07:11 -0700 Subject: [PATCH 13/55] pydantic with flytepath and flytedirectory Signed-off-by: Arthur --- flytekit/clis/sdk_in_container/run.py | 5 +- plugins/flytekit-pydantic/README.md | 4 +- .../flytekitplugins/pydantic/__init__.py | 1 + .../pydantic/basemodel_extensions.py | 152 ++++++++++++++++++ .../pydantic/flytepath_creation.py | 67 ++++++++ .../flytekitplugins/pydantic/schema.py | 70 ++++---- 6 files changed, 266 insertions(+), 33 deletions(-) create mode 100644 plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py create mode 100644 plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 336ffbdad6..0157c5e9ed 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -399,7 +399,10 @@ def convert_to_struct( Convert the loaded json object to a Flyte Literal struct type. 
""" if type(value) != self._python_type: - o = cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(value)) + if hasattr(self._python_type, "parse_raw"): # e.g pydantic basemodel + o = self._python_type.parse_raw(json.dumps(value)) + else: + o = cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(value)) else: o = value return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type) diff --git a/plugins/flytekit-pydantic/README.md b/plugins/flytekit-pydantic/README.md index ef67727d6a..04ffaa849f 100644 --- a/plugins/flytekit-pydantic/README.md +++ b/plugins/flytekit-pydantic/README.md @@ -17,9 +17,11 @@ from pydantic import BaseModel import flytekitplugins.pydantic -class Config(BaseModel): +class Config(BaseModel, **flytekitplugins.pydantic.pydantic_flyteobject_config): lr: float = 1e-3 batch_size: int = 32 + files: List[FlyteFile] + directories: List[FlyteDirectory] @task diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py index 8b448cd258..41f692637d 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py @@ -1 +1,2 @@ +from .basemodel_extensions import pydantic_flyteobject_config from .schema import BaseModelTransformer diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py new file mode 100644 index 0000000000..988ada8b75 --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py @@ -0,0 +1,152 @@ +import abc +from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union, cast + +from typing_extensions import TypedDict + +from flytekit.types.directory import types as flyte_directory_types +from flytekit.types.file import file as flyte_file + +from . 
import flytepath_creation + +## ==================================================================================== +# Base class +## ==================================================================================== + + +Serializable = TypeVar("Serializable") +SerializedFlyteType = TypeVar("SerializedFlyteType") + + +def get_pydantic_flyteobject_config() -> Dict[str, Any]: + """ + Returns the pydantic config for the serializers/deserializers + """ + return { + "json_encoders": { + serializer_deserializer()._t: serializer_deserializer().serialize # type: ignore + for serializer_deserializer in PydanticSerializerDeserializerBase.__subclasses__() + }, + } + + +class PydanticSerializerDeserializerBase( + abc.ABC, + Generic[Serializable, SerializedFlyteType], +): + """Base class for object serializers/deserializers""" + + @abc.abstractmethod + def serialize(self, obj: Serializable) -> SerializedFlyteType: + pass + + @abc.abstractmethod + def deserialize(self, obj: SerializedFlyteType) -> Serializable: + pass + + def __init__(self, t: Type[Serializable]): + self._t = t + self._set_validator_on_serializeable(t) + + def _set_validator_on_serializeable(self, serializeable: Type[Serializable]) -> None: + """ + Sets the validator on the pydantic model for the + type that is being serialized/deserialized + """ + setattr(serializeable, "__get_validators__", lambda *_: (self.deserialize,)) + + +## ==================================================================================== +# FlyteDir +## ==================================================================================== + + +class FlyteDirJsonEncoded(TypedDict): + """JSON representation of a FlyteDirectory""" + + remote_source: str + + +FLYTEDIR_DESERIALIZABLE_TYPES = Union[flyte_directory_types.FlyteDirectory, str, FlyteDirJsonEncoded] + + +class FlyteDirSerializerDeserializer( + PydanticSerializerDeserializerBase[flyte_directory_types.FlyteDirectory, FlyteDirJsonEncoded] +): + def __init__( + self, + t: Type[flyte_directory_types.FlyteDirectory] = flyte_directory_types.FlyteDirectory, + ): + super().__init__(t) + + def serialize(self, obj: flyte_directory_types.FlyteDirectory) -> FlyteDirJsonEncoded: + return {"remote_source": obj.remote_source} + + def deserialize(self, obj: FLYTEDIR_DESERIALIZABLE_TYPES) -> flyte_directory_types.FlyteDirectory: + flytedir = validate_flytedir(obj) + if flytedir is None: + raise ValueError(f"Could not deserialize {obj} to FlyteDirectory") + return flytedir + + +def validate_flytedir( + flytedir: FLYTEDIR_DESERIALIZABLE_TYPES, +) -> Optional[flyte_directory_types.FlyteDirectory]: + """validator for flytedir (i.e. deserializer)""" + if isinstance(flytedir, dict): # this is a json encoded flytedir + flytedir = cast(FlyteDirJsonEncoded, flytedir) + path = flytedir["remote_source"] + return flytepath_creation.make_flytepath(path, flyte_directory_types.FlyteDirectory) + elif isinstance(flytedir, str): # when e.g. 
initializing from config + return flytepath_creation.make_flytepath(flytedir, flyte_directory_types.FlyteDirectory) + elif isinstance(flytedir, flyte_directory_types.FlyteDirectory): + return flytedir + else: + raise ValueError(f"Invalid type for flytedir: {type(flytedir)}") + + +## ==================================================================================== +# FlyteFile +## ==================================================================================== + + +class FlyteFileJsonEncoded(TypedDict): + """JSON representation of a FlyteFile""" + + remote_source: str + + +FLYTEFILE_DESERIALIZABLE_TYPES = Union[flyte_directory_types.FlyteFile, str, FlyteFileJsonEncoded] + + +class FlyteFileSerializerDeserializer(PydanticSerializerDeserializerBase[flyte_file.FlyteFile, FlyteFileJsonEncoded]): + def __init__(self, t: Type[flyte_file.FlyteFile] = flyte_file.FlyteFile): + super().__init__(t) + + def serialize(self, obj: flyte_file.FlyteFile) -> FlyteFileJsonEncoded: + return {"remote_source": obj.remote_source} + + def deserialize(self, obj: FLYTEFILE_DESERIALIZABLE_TYPES) -> flyte_file.FlyteFile: + flyte_file = validate_flytefile(obj) + if flyte_file is None: + raise ValueError(f"Could not deserialize {obj} to FlyteFile") + return flyte_file + + +def validate_flytefile( + flytedir: FLYTEFILE_DESERIALIZABLE_TYPES, +) -> Optional[flyte_file.FlyteFile]: + """validator for flytedir (i.e. deserializer)""" + if isinstance(flytedir, dict): # this is a json encoded flytedir + flytedir = cast(FlyteFileJsonEncoded, flytedir) + path = flytedir["remote_source"] + return flytepath_creation.make_flytepath(path, flyte_file.FlyteFile) + elif isinstance(flytedir, str): # when e.g. initializing from config + return flytepath_creation.make_flytepath(flytedir, flyte_file.FlyteFile) + elif isinstance(flytedir, flyte_directory_types.FlyteFile): + return flytedir + else: + raise ValueError(f"Invalid type for flytedir: {type(flytedir)}") + + +# add these to your basemodel config to enable serialization/deserialization of flyte objects. 
+pydantic_flyteobject_config = get_pydantic_flyteobject_config() diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py new file mode 100644 index 0000000000..e5c7d0a679 --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py @@ -0,0 +1,67 @@ +## ==================================================================================== +# Flyte directory and files creation +## ==================================================================================== + +import os +from typing import Any, Dict, Optional, Type, TypeVar, Union + +from flytekit.core import context_manager, type_engine +from flytekit.models import literals +from flytekit.models.core import types as core_types +from flytekit.types.directory import types as flyte_directory_types +from flytekit.types.file import file as flyte_file + +FlytePath = TypeVar("FlytePath", flyte_file.FlyteFile, flyte_directory_types.FlyteDirectory) + + +def make_flytepath(path: Union[str, os.PathLike], flyte_type: Type[FlytePath]) -> Optional[FlytePath]: + """create a FlyteDirectory from a path""" + context = context_manager.FlyteContextManager.current_context() + transformer = get_flyte_transformer(flyte_type) + dimensionality = get_flyte_dimensionality(flyte_type) + literal = make_literal(uri=path, dimensionality=dimensionality) + out_dir = transformer.to_python_value(context, literal, flyte_type) + return out_dir + + +def get_flyte_transformer( + flyte_type: Type[FlytePath], +) -> type_engine.TypeTransformer[FlytePath]: + """get the transformer for a given flyte type""" + return FLYTE_TRANSFORMERS[flyte_type] + + +FLYTE_TRANSFORMERS: Dict[Type, type_engine.TypeTransformer] = { + flyte_file.FlyteFile: flyte_file.FlyteFilePathTransformer(), + flyte_directory_types.FlyteDirectory: flyte_directory_types.FlyteDirToMultipartBlobTransformer(), +} + + +def get_flyte_dimensionality( + flyte_type: Type[FlytePath], +) -> str: + """get the transformer for a given flyte type""" + return FLYTE_DIMENSIONALITY[flyte_type] + + +FLYTE_DIMENSIONALITY: Dict[Type, Any] = { + flyte_file.FlyteFile: core_types.BlobType.BlobDimensionality.SINGLE, + flyte_directory_types.FlyteDirectory: core_types.BlobType.BlobDimensionality.MULTIPART, +} + + +def make_literal( + uri: Union[str, os.PathLike], + dimensionality, +) -> literals.Literal: + scalar = make_scalar(uri, dimensionality) + return literals.Literal(scalar=scalar) # type: ignore + + +def make_scalar( + uri: Union[str, os.PathLike], + dimensionality, +) -> literals.Scalar: + blobtype = core_types.BlobType(format="", dimensionality=dimensionality) + blob = literals.Blob(metadata=literals.BlobMetadata(type=blobtype), uri=uri) + return literals.Scalar(blob=blob) # type: ignore diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py index cf29f03865..08188039cb 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py @@ -1,51 +1,59 @@ from typing import Type -from google.protobuf.json_format import MessageToDict -from google.protobuf.struct_pb2 import Struct -from pydantic import BaseModel +import pydantic +from google.protobuf import json_format, struct_pb2 from flytekit import FlyteContext -from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError -from flytekit.models.literals import 
Literal, Scalar -from flytekit.models.types import LiteralType, SimpleType +from flytekit.core import type_engine +from flytekit.models import literals, types +""" +Serializes & deserializes the pydantic basemodels +""" -class BaseModelTransformer(TypeTransformer[BaseModel]): - _TYPE_INFO = LiteralType(simple=SimpleType.STRUCT) + +class BaseModelTransformer(type_engine.TypeTransformer[pydantic.BaseModel]): + _TYPE_INFO = types.LiteralType(simple=types.SimpleType.STRUCT) def __init__(self): - """Construct BaseModelTransformer.""" - super().__init__(name="basemodel-transform", t=BaseModel) + """Construct pydantic.BaseModelTransformer.""" + super().__init__(name="basemodel-transform", t=pydantic.BaseModel) - def get_literal_type(self, t: Type[BaseModel]) -> LiteralType: - return LiteralType(simple=SimpleType.STRUCT) + def get_literal_type(self, t: Type[pydantic.BaseModel]) -> types.LiteralType: + return types.LiteralType(simple=types.SimpleType.STRUCT) def to_literal( self, ctx: FlyteContext, - python_val: BaseModel, - python_type: Type[BaseModel], - expected: LiteralType, - ) -> Literal: - """This method is used to convert from given python type object pydantic ``BaseModel`` to the Literal representation.""" - s = Struct() - - s.update({"schema": python_val.schema(), "data": python_val.dict()}) - - return Literal(scalar=Scalar(generic=s)) - - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[BaseModel]) -> BaseModel: - """Re-hydrate the pydantic BaseModel object from Flyte Literal value.""" - base_model = MessageToDict(lv.scalar.generic) + python_val: pydantic.BaseModel, + python_type: Type[pydantic.BaseModel], + expected: types.LiteralType, + ) -> literals.Literal: + """This method is used to convert from given python type object pydantic ``pydantic.BaseModel`` to the Literal representation.""" + + s = struct_pb2.Struct() + schema = python_val.schema_json() + data = python_val.json() + s.update({"schema": schema, "data": data}) + literal = literals.Literal(scalar=literals.Scalar(generic=s)) # type: ignore + return literal + + def to_python_value( + self, + ctx: FlyteContext, + lv: literals.Literal, + expected_python_type: Type[pydantic.BaseModel], + ) -> pydantic.BaseModel: + """Re-hydrate the pydantic pydantic.BaseModel object from Flyte Literal value.""" + base_model = json_format.MessageToDict(lv.scalar.generic) schema = base_model["schema"] data = base_model["data"] - - if (expected_schema := expected_python_type.schema()) != schema: - raise TypeTransformerFailedError( + if (expected_schema := expected_python_type.schema_json()) != schema: + raise type_engine.TypeTransformerFailedError( f"The schema `{expected_schema}` of the expected python type {expected_python_type} is not equal to the received schema `{schema}`." 
) - return expected_python_type.parse_obj(data) + return expected_python_type.parse_raw(data) -TypeEngine.register(BaseModelTransformer()) +type_engine.TypeEngine.register(BaseModelTransformer()) From 2ea5b44e1e907300e05afa9e7c2d9e39ea18d683 Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 26 May 2023 10:53:43 -0700 Subject: [PATCH 14/55] added example for how to use with BaseModel Config class Signed-off-by: Arthur --- plugins/flytekit-pydantic/README.md | 13 +++++++++++-- .../flytekitplugins/pydantic/__init__.py | 2 +- .../pydantic/basemodel_extensions.py | 1 + 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-pydantic/README.md b/plugins/flytekit-pydantic/README.md index 04ffaa849f..6f0c630106 100644 --- a/plugins/flytekit-pydantic/README.md +++ b/plugins/flytekit-pydantic/README.md @@ -17,14 +17,23 @@ from pydantic import BaseModel import flytekitplugins.pydantic -class Config(BaseModel, **flytekitplugins.pydantic.pydantic_flyteobject_config): +class TrainConfig(BaseModel, **flytekitplugins.pydantic.pydantic_flyteobject_config): lr: float = 1e-3 batch_size: int = 32 files: List[FlyteFile] directories: List[FlyteDirectory] +# or alternatively +class TrainConfig(BaseModel): + lr: float = 1e-3 + batch_size: int = 32 + files: List[FlyteFile] + directories: List[FlyteDirectory] + + class Config: + json_encoders = flytekitplugins.pydantic.flyteobject_json_encoders @task -def train(cfg: Config): +def train(cfg: TrainConfig): ... ``` diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py index 41f692637d..8651572310 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py @@ -1,2 +1,2 @@ -from .basemodel_extensions import pydantic_flyteobject_config +from .basemodel_extensions import pydantic_flyteobject_config, flyteobject_json_encoders from .schema import BaseModelTransformer diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py index 988ada8b75..674fa311ee 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py @@ -150,3 +150,4 @@ def validate_flytefile( # add these to your basemodel config to enable serialization/deserialization of flyte objects. 
pydantic_flyteobject_config = get_pydantic_flyteobject_config() +flyteobject_json_encoders = pydantic_flyteobject_config["json_encoders"] From ccbb70bc8cb0ee2891533ea2a95c873997ed0a19 Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 2 Jun 2023 13:04:15 -0700 Subject: [PATCH 15/55] moved setting of json-encoders in to pydantic transformer --- plugins/flytekit-pydantic/README.md | 12 +----------- .../flytekitplugins/pydantic/__init__.py | 2 +- .../flytekitplugins/pydantic/schema.py | 3 +++ .../flytekit-pydantic/tests/test_type_transformer.py | 12 ++++++------ 4 files changed, 11 insertions(+), 18 deletions(-) diff --git a/plugins/flytekit-pydantic/README.md b/plugins/flytekit-pydantic/README.md index 6f0c630106..16135bc9c9 100644 --- a/plugins/flytekit-pydantic/README.md +++ b/plugins/flytekit-pydantic/README.md @@ -14,25 +14,15 @@ pip install flytekitplugins-pydantic ## Type Example ```python from pydantic import BaseModel -import flytekitplugins.pydantic +import flytekitplugins.pydantic # This import will enable you to add FlyteFiles and FlyteDirectories to you BaseModels -class TrainConfig(BaseModel, **flytekitplugins.pydantic.pydantic_flyteobject_config): - lr: float = 1e-3 - batch_size: int = 32 - files: List[FlyteFile] - directories: List[FlyteDirectory] - -# or alternatively class TrainConfig(BaseModel): lr: float = 1e-3 batch_size: int = 32 files: List[FlyteFile] directories: List[FlyteDirectory] - class Config: - json_encoders = flytekitplugins.pydantic.flyteobject_json_encoders - @task def train(cfg: TrainConfig): ... diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py index 8651572310..50e400af81 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py @@ -1,2 +1,2 @@ -from .basemodel_extensions import pydantic_flyteobject_config, flyteobject_json_encoders +from .basemodel_extensions import flyteobject_json_encoders, pydantic_flyteobject_config from .schema import BaseModelTransformer diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py index 08188039cb..a4e5df5dbe 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py @@ -7,6 +7,8 @@ from flytekit.core import type_engine from flytekit.models import literals, types +from . 
import basemodel_extensions + """ Serializes & deserializes the pydantic basemodels """ @@ -32,6 +34,7 @@ def to_literal( """This method is used to convert from given python type object pydantic ``pydantic.BaseModel`` to the Literal representation.""" s = struct_pb2.Struct() + python_val.__config__.json_encoders.update(basemodel_extensions.flyteobject_json_encoders) schema = python_val.schema_json() data = python_val.json() s.update({"schema": schema, "data": data}) diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 946db9900e..a8ed6ae39f 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type, Union import flytekitplugins.pydantic # noqa F401 import pytest @@ -23,26 +23,26 @@ class Config: class Config(BaseModel): """Config BaseModel for testing purposes with an optional type hint.""" - model_config: Optional[Union[dict[str, TrainConfig], TrainConfig]] = TrainConfig() + model_config: Optional[Union[Dict[str, TrainConfig], TrainConfig]] = TrainConfig() class ConfigRequired(BaseModel): """Config BaseModel for testing purposes with required attribute.""" - model_config: Union[dict[str, TrainConfig], TrainConfig] + model_config: Union[Dict[str, TrainConfig], TrainConfig] class ChildConfig(Config): """Child class config BaseModel for testing purposes.""" - d: list[int] = [1, 2, 3] + d: List[int] = [1, 2, 3] @pytest.mark.parametrize( "python_type,kwargs", [(Config, {}), (ConfigRequired, {"model_config": TrainConfig()}), (TrainConfig, {}), (TrainConfig, {})], ) -def test_transform_round_trip(python_type: Type, kwargs: dict[str, Any]): +def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): """Test that a (de-)serialization roundtrip results in the identical BaseModel.""" from flytekit.core.context_manager import FlyteContextManager @@ -72,7 +72,7 @@ def test_transform_round_trip(python_type: Type, kwargs: dict[str, Any]): (ConfigRequired, {"model_config": {"foo": TrainConfig(loss="mse")}}), ], ) -def test_pass_to_workflow(config_type: Type, kwargs: dict[str, Any]): +def test_pass_to_workflow(config_type: Type, kwargs: Dict[str, Any]): """Test passing a BaseModel instance to a workflow works.""" cfg = config_type(**kwargs) From b671d3d32fd6d1aafcbd76a9f26bc2ab71d57fb4 Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 2 Jun 2023 16:05:50 -0700 Subject: [PATCH 16/55] added test for flytepath --- .../pydantic/basemodel_extensions.py | 13 ++++---- plugins/flytekit-pydantic/tests/test_file.txt | 1 + .../tests/test_type_transformer.py | 30 +++++++++++++++++++ 3 files changed, 37 insertions(+), 7 deletions(-) create mode 100644 plugins/flytekit-pydantic/tests/test_file.txt diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py index 674fa311ee..51008aecc8 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py @@ -62,8 +62,7 @@ def _set_validator_on_serializeable(self, serializeable: Type[Serializable]) -> class FlyteDirJsonEncoded(TypedDict): """JSON representation of a FlyteDirectory""" - - remote_source: str + path: str FLYTEDIR_DESERIALIZABLE_TYPES = 
Union[flyte_directory_types.FlyteDirectory, str, FlyteDirJsonEncoded] @@ -79,7 +78,7 @@ def __init__( super().__init__(t) def serialize(self, obj: flyte_directory_types.FlyteDirectory) -> FlyteDirJsonEncoded: - return {"remote_source": obj.remote_source} + return {"path": obj.remote_source if obj.remote_source else obj.path} def deserialize(self, obj: FLYTEDIR_DESERIALIZABLE_TYPES) -> flyte_directory_types.FlyteDirectory: flytedir = validate_flytedir(obj) @@ -94,7 +93,7 @@ def validate_flytedir( """validator for flytedir (i.e. deserializer)""" if isinstance(flytedir, dict): # this is a json encoded flytedir flytedir = cast(FlyteDirJsonEncoded, flytedir) - path = flytedir["remote_source"] + path = flytedir["path"] return flytepath_creation.make_flytepath(path, flyte_directory_types.FlyteDirectory) elif isinstance(flytedir, str): # when e.g. initializing from config return flytepath_creation.make_flytepath(flytedir, flyte_directory_types.FlyteDirectory) @@ -112,7 +111,7 @@ def validate_flytedir( class FlyteFileJsonEncoded(TypedDict): """JSON representation of a FlyteFile""" - remote_source: str + path: str FLYTEFILE_DESERIALIZABLE_TYPES = Union[flyte_directory_types.FlyteFile, str, FlyteFileJsonEncoded] @@ -123,7 +122,7 @@ def __init__(self, t: Type[flyte_file.FlyteFile] = flyte_file.FlyteFile): super().__init__(t) def serialize(self, obj: flyte_file.FlyteFile) -> FlyteFileJsonEncoded: - return {"remote_source": obj.remote_source} + return {"path": obj.remote_source if obj.remote_source else obj.path} def deserialize(self, obj: FLYTEFILE_DESERIALIZABLE_TYPES) -> flyte_file.FlyteFile: flyte_file = validate_flytefile(obj) @@ -138,7 +137,7 @@ def validate_flytefile( """validator for flytedir (i.e. deserializer)""" if isinstance(flytedir, dict): # this is a json encoded flytedir flytedir = cast(FlyteFileJsonEncoded, flytedir) - path = flytedir["remote_source"] + path = flytedir["path"] return flytepath_creation.make_flytepath(path, flyte_file.FlyteFile) elif isinstance(flytedir, str): # when e.g. 
initializing from config return flytepath_creation.make_flytepath(flytedir, flyte_file.FlyteFile) diff --git a/plugins/flytekit-pydantic/tests/test_file.txt b/plugins/flytekit-pydantic/tests/test_file.txt new file mode 100644 index 0000000000..83ebd4399d --- /dev/null +++ b/plugins/flytekit-pydantic/tests/test_file.txt @@ -0,0 +1 @@ +love sosa \ No newline at end of file diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index a8ed6ae39f..01547ed099 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -1,4 +1,5 @@ from typing import Any, Dict, List, Optional, Type, Union +from flytekit.types.file import file import flytekitplugins.pydantic # noqa F401 import pytest @@ -31,6 +32,10 @@ class ConfigRequired(BaseModel): model_config: Union[Dict[str, TrainConfig], TrainConfig] +class ConfigWithFlyteFiles(BaseModel): + """Config BaseModel for testing purposes with flytekit.files.FlyteFile type hint.""" + + flytefiles: List[file.FlyteFile] class ChildConfig(Config): """Child class config BaseModel for testing purposes.""" @@ -70,6 +75,7 @@ def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): [ (Config, {"model_config": {"foo": TrainConfig(loss="mse")}}), (ConfigRequired, {"model_config": {"foo": TrainConfig(loss="mse")}}), + (ConfigWithFlyteFiles, {"flytefiles": ['s3://foo/bar']}) ], ) def test_pass_to_workflow(config_type: Type, kwargs: Dict[str, Any]): @@ -89,6 +95,30 @@ def wf(cfg: config_type) -> config_type: assert cfg == returned_cfg +@pytest.mark.parametrize( + "kwargs", + [ + {"flytefiles": ['tests/test_file.txt']} + ], +) +def test_pass_to_workflow(kwargs: Dict[str, Any]): + """Test passing a BaseModel instance to a workflow works.""" + cfg = ConfigWithFlyteFiles(**kwargs) + + @task + def read(cfg: ConfigWithFlyteFiles) -> str: + with open (cfg.flytefiles[0], 'r') as f: + return f.read() + + @workflow + def wf(cfg: ConfigWithFlyteFiles) -> str: + return read(cfg=cfg) + + string = wf(cfg=cfg) + assert string == 'love sosa' + + + def test_pass_wrong_type_to_workflow(): """Test passing the wrong type raises exception.""" cfg = ChildConfig() From 240aa300b14cbb0ab47f7e589481e1ed5dde0b2f Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 2 Jun 2023 16:15:10 -0700 Subject: [PATCH 17/55] added flytedir test --- .../{test_file.txt => folder/test_file1.txt} | 0 .../tests/folder/test_file2.txt | 1 + .../tests/test_type_transformer.py | 36 +++++++++++++++++-- 3 files changed, 34 insertions(+), 3 deletions(-) rename plugins/flytekit-pydantic/tests/{test_file.txt => folder/test_file1.txt} (100%) create mode 100644 plugins/flytekit-pydantic/tests/folder/test_file2.txt diff --git a/plugins/flytekit-pydantic/tests/test_file.txt b/plugins/flytekit-pydantic/tests/folder/test_file1.txt similarity index 100% rename from plugins/flytekit-pydantic/tests/test_file.txt rename to plugins/flytekit-pydantic/tests/folder/test_file1.txt diff --git a/plugins/flytekit-pydantic/tests/folder/test_file2.txt b/plugins/flytekit-pydantic/tests/folder/test_file2.txt new file mode 100644 index 0000000000..83ebd4399d --- /dev/null +++ b/plugins/flytekit-pydantic/tests/folder/test_file2.txt @@ -0,0 +1 @@ +love sosa \ No newline at end of file diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 01547ed099..fecfa0567f 100644 --- 
a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -1,4 +1,6 @@ +import os from typing import Any, Dict, List, Optional, Type, Union +from flytekit.types import directory from flytekit.types.file import file import flytekitplugins.pydantic # noqa F401 @@ -37,6 +39,12 @@ class ConfigWithFlyteFiles(BaseModel): flytefiles: List[file.FlyteFile] + +class ConfigWithFlyteDirs(BaseModel): + """Config BaseModel for testing purposes with flytekit.files.FlyteFile type hint.""" + + flytedirs: List[directory.FlyteDirectory] + class ChildConfig(Config): """Child class config BaseModel for testing purposes.""" @@ -75,7 +83,8 @@ def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): [ (Config, {"model_config": {"foo": TrainConfig(loss="mse")}}), (ConfigRequired, {"model_config": {"foo": TrainConfig(loss="mse")}}), - (ConfigWithFlyteFiles, {"flytefiles": ['s3://foo/bar']}) + (ConfigWithFlyteFiles, {"flytefiles": ['s3://foo/bar']}), + (ConfigWithFlyteDirs, {"flytedirs": ['s3://foo/bar']}) ], ) def test_pass_to_workflow(config_type: Type, kwargs: Dict[str, Any]): @@ -98,10 +107,10 @@ def wf(cfg: config_type) -> config_type: @pytest.mark.parametrize( "kwargs", [ - {"flytefiles": ['tests/test_file.txt']} + {"flytefiles": ['tests/folder/test_file1.txt', 'tests/folder/test_file2.txt']}, ], ) -def test_pass_to_workflow(kwargs: Dict[str, Any]): +def test_flytefiles_in_wf(kwargs: Dict[str, Any]): """Test passing a BaseModel instance to a workflow works.""" cfg = ConfigWithFlyteFiles(**kwargs) @@ -117,6 +126,27 @@ def wf(cfg: ConfigWithFlyteFiles) -> str: string = wf(cfg=cfg) assert string == 'love sosa' +@pytest.mark.parametrize( + "kwargs", + [ + {"flytedirs": ['tests/folder/']}, + ], +) +def test_flytedirs_in_wf(kwargs: Dict[str, Any]): + """Test passing a BaseModel instance to a workflow works.""" + cfg = ConfigWithFlyteDirs(**kwargs) + + @task + def listdir(cfg: ConfigWithFlyteDirs) -> List[str]: + return os.listdir(cfg.flytedirs[0]) + + @workflow + def wf(cfg: ConfigWithFlyteDirs) -> List[str]: + return listdir(cfg=cfg) + + dirs = wf(cfg=cfg) + assert len(dirs) == 2 + def test_pass_wrong_type_to_workflow(): From 34b1576b8df460beee3e779e4a357f746c3993b6 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 5 Jun 2023 10:08:58 -0700 Subject: [PATCH 18/55] flytekit will auto load the transformer if the plugin is installed. 
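For reference, a self-contained version of the README snippet after this change could look roughly as follows. The `typing` and flytekit imports are assumptions filled in here (the README example elides them); with the plugin installed, no explicit `import flytekitplugins.pydantic` is needed for the transformer to be registered.

```python
from typing import List

from pydantic import BaseModel

from flytekit import task
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile


class TrainConfig(BaseModel):
    lr: float = 1e-3
    batch_size: int = 32
    files: List[FlyteFile]
    directories: List[FlyteDirectory]


@task
def train(cfg: TrainConfig) -> None:
    # the BaseModel (including its FlyteFile/FlyteDirectory fields) arrives re-hydrated
    ...
```
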
--- plugins/flytekit-pydantic/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/flytekit-pydantic/README.md b/plugins/flytekit-pydantic/README.md index 16135bc9c9..8eb7267100 100644 --- a/plugins/flytekit-pydantic/README.md +++ b/plugins/flytekit-pydantic/README.md @@ -14,7 +14,6 @@ pip install flytekitplugins-pydantic ## Type Example ```python from pydantic import BaseModel -import flytekitplugins.pydantic # This import will enable you to add FlyteFiles and FlyteDirectories to you BaseModels class TrainConfig(BaseModel): From bd686ad825675671a487c16e275ba7e25856d8fc Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 5 Jun 2023 11:30:01 -0700 Subject: [PATCH 19/55] added upload to s3 --- .../pydantic/basemodel_extensions.py | 5 ++++- .../pydantic/flytepath_creation.py | 10 ++++++++++ .../tests/test_type_transformer.py | 20 ++++++++++--------- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py index 51008aecc8..0946ab50c3 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py @@ -62,7 +62,8 @@ def _set_validator_on_serializeable(self, serializeable: Type[Serializable]) -> class FlyteDirJsonEncoded(TypedDict): """JSON representation of a FlyteDirectory""" - path: str + + path: str FLYTEDIR_DESERIALIZABLE_TYPES = Union[flyte_directory_types.FlyteDirectory, str, FlyteDirJsonEncoded] @@ -78,6 +79,7 @@ def __init__( super().__init__(t) def serialize(self, obj: flyte_directory_types.FlyteDirectory) -> FlyteDirJsonEncoded: + flytepath_creation.upload_to_s3(obj) return {"path": obj.remote_source if obj.remote_source else obj.path} def deserialize(self, obj: FLYTEDIR_DESERIALIZABLE_TYPES) -> flyte_directory_types.FlyteDirectory: @@ -122,6 +124,7 @@ def __init__(self, t: Type[flyte_file.FlyteFile] = flyte_file.FlyteFile): super().__init__(t) def serialize(self, obj: flyte_file.FlyteFile) -> FlyteFileJsonEncoded: + flytepath_creation.upload_to_s3(obj) return {"path": obj.remote_source if obj.remote_source else obj.path} def deserialize(self, obj: FLYTEFILE_DESERIALIZABLE_TYPES) -> flyte_file.FlyteFile: diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py index e5c7d0a679..ff5de2126d 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py @@ -14,6 +14,16 @@ FlytePath = TypeVar("FlytePath", flyte_file.FlyteFile, flyte_directory_types.FlyteDirectory) +def upload_to_s3(flytepath: FlytePath, ctx: Optional[context_manager.FlyteContext] = None) -> None: + """Upload a FlytePath to S3""" + if ctx is None: + ctx = context_manager.FlyteContextManager.current_context() + if flytepath.remote_path is None: + flytepath.remote_path = remote_path = ctx.file_access.get_random_remote_path(flytepath.path) + is_multipart = isinstance(flytepath, flyte_directory_types.FlyteDirectory) + ctx.file_access.put_data(flytepath.path, remote_path, is_multipart=is_multipart) + + def make_flytepath(path: Union[str, os.PathLike], flyte_type: Type[FlytePath]) -> Optional[FlytePath]: """create a FlyteDirectory from a path""" context = context_manager.FlyteContextManager.current_context() diff --git 
a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index fecfa0567f..8d8aa0ed05 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -1,7 +1,5 @@ import os from typing import Any, Dict, List, Optional, Type, Union -from flytekit.types import directory -from flytekit.types.file import file import flytekitplugins.pydantic # noqa F401 import pytest @@ -10,6 +8,8 @@ from flytekit import task, workflow from flytekit.core.type_engine import TypeTransformerFailedError +from flytekit.types import directory +from flytekit.types.file import file class TrainConfig(BaseModel): @@ -34,6 +34,7 @@ class ConfigRequired(BaseModel): model_config: Union[Dict[str, TrainConfig], TrainConfig] + class ConfigWithFlyteFiles(BaseModel): """Config BaseModel for testing purposes with flytekit.files.FlyteFile type hint.""" @@ -45,6 +46,7 @@ class ConfigWithFlyteDirs(BaseModel): flytedirs: List[directory.FlyteDirectory] + class ChildConfig(Config): """Child class config BaseModel for testing purposes.""" @@ -83,8 +85,8 @@ def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): [ (Config, {"model_config": {"foo": TrainConfig(loss="mse")}}), (ConfigRequired, {"model_config": {"foo": TrainConfig(loss="mse")}}), - (ConfigWithFlyteFiles, {"flytefiles": ['s3://foo/bar']}), - (ConfigWithFlyteDirs, {"flytedirs": ['s3://foo/bar']}) + (ConfigWithFlyteFiles, {"flytefiles": ["s3://foo/bar"]}), + (ConfigWithFlyteDirs, {"flytedirs": ["s3://foo/bar"]}), ], ) def test_pass_to_workflow(config_type: Type, kwargs: Dict[str, Any]): @@ -107,7 +109,7 @@ def wf(cfg: config_type) -> config_type: @pytest.mark.parametrize( "kwargs", [ - {"flytefiles": ['tests/folder/test_file1.txt', 'tests/folder/test_file2.txt']}, + {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, ], ) def test_flytefiles_in_wf(kwargs: Dict[str, Any]): @@ -116,7 +118,7 @@ def test_flytefiles_in_wf(kwargs: Dict[str, Any]): @task def read(cfg: ConfigWithFlyteFiles) -> str: - with open (cfg.flytefiles[0], 'r') as f: + with open(cfg.flytefiles[0], "r") as f: return f.read() @workflow @@ -124,12 +126,13 @@ def wf(cfg: ConfigWithFlyteFiles) -> str: return read(cfg=cfg) string = wf(cfg=cfg) - assert string == 'love sosa' + assert string == "love sosa" + @pytest.mark.parametrize( "kwargs", [ - {"flytedirs": ['tests/folder/']}, + {"flytedirs": ["tests/folder/"]}, ], ) def test_flytedirs_in_wf(kwargs: Dict[str, Any]): @@ -148,7 +151,6 @@ def wf(cfg: ConfigWithFlyteDirs) -> List[str]: assert len(dirs) == 2 - def test_pass_wrong_type_to_workflow(): """Test passing the wrong type raises exception.""" cfg = ChildConfig() From 67da286bda1fbe697973abd7d462c4f43b4442e7 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 7 Jun 2023 09:51:04 -0700 Subject: [PATCH 20/55] removed ctx as input to upload_to_s3 --- .../flytekitplugins/pydantic/flytepath_creation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py index ff5de2126d..31fc97903d 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py @@ -14,10 +14,9 @@ FlytePath = TypeVar("FlytePath", flyte_file.FlyteFile, flyte_directory_types.FlyteDirectory) -def upload_to_s3(flytepath: 
FlytePath, ctx: Optional[context_manager.FlyteContext] = None) -> None: +def upload_to_s3(flytepath: FlytePath) -> None: """Upload a FlytePath to S3""" - if ctx is None: - ctx = context_manager.FlyteContextManager.current_context() + ctx = context_manager.FlyteContextManager.current_context() if flytepath.remote_path is None: flytepath.remote_path = remote_path = ctx.file_access.get_random_remote_path(flytepath.path) is_multipart = isinstance(flytepath, flyte_directory_types.FlyteDirectory) From de92c75524b2cfa844a0a67b1865a4caddfe257f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arthur=20B=C3=B6=C3=B6k?= <49250723+ArthurBook@users.noreply.github.com> Date: Mon, 12 Jun 2023 09:30:50 -0700 Subject: [PATCH 21/55] Update plugins/flytekit-pydantic/tests/test_type_transformer.py good catch Co-authored-by: Fabio M. Graetz, Ph.D. --- plugins/flytekit-pydantic/tests/test_type_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 8d8aa0ed05..3f45de2401 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -42,7 +42,7 @@ class ConfigWithFlyteFiles(BaseModel): class ConfigWithFlyteDirs(BaseModel): - """Config BaseModel for testing purposes with flytekit.files.FlyteFile type hint.""" + """Config BaseModel for testing purposes with flytekit.directory.FlyteDirectory type hint.""" flytedirs: List[directory.FlyteDirectory] From 24a415284a08e675db52c99bcd501d7eba3ef377 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 12 Jun 2023 09:43:49 -0700 Subject: [PATCH 22/55] Revert "Allow annotated FlyteFile as task input argument (#1632)" This reverts commit fe9434f6b8db8d0ba7c74292a2400dcebda7f446. --- flytekit/core/type_engine.py | 47 +++++++++----------- flytekit/types/file/file.py | 9 ++-- tests/flytekit/unit/core/test_flyte_file.py | 13 ++---- tests/flytekit/unit/core/test_type_engine.py | 27 ----------- 4 files changed, 27 insertions(+), 69 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e01cb9a343..5994390c8d 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -173,7 +173,8 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp return self._to_literal_transformer(python_val) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: - expected_python_type = get_underlying_type(expected_python_type) + if get_origin(expected_python_type) is Annotated: + expected_python_type = get_args(expected_python_type)[0] if expected_python_type != self._type: raise TypeTransformerFailedError( @@ -310,7 +311,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: Extracts the Literal type definition for a Dataclass and returns a type Struct. If possible also extracts the JSONSchema for the dataclass. """ - if is_annotated(t): + if get_origin(t) is Annotated: raise ValueError( "Flytekit does not currently have support for FlyteAnnotations applied to Dataclass." f"Type {t} cannot be parsed." 
@@ -367,7 +368,7 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: self._get_origin_type_in_annotation(get_args(python_type)[0]), self._get_origin_type_in_annotation(get_args(python_type)[1]), ] - elif is_annotated(python_type): + elif get_origin(python_type) is Annotated: return get_args(python_type)[0] elif dataclasses.is_dataclass(python_type): for field in dataclasses.fields(copy.deepcopy(python_type)): @@ -736,7 +737,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: """ cls.lazy_import_transformers() # Step 1 - if is_annotated(python_type): + if get_origin(python_type) is Annotated: args = get_args(python_type) for annotation in args: if isinstance(annotation, TypeTransformer): @@ -751,7 +752,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: if hasattr(python_type, "__origin__"): # Handling of annotated generics, eg: # Annotated[typing.List[int], 'foo'] - if is_annotated(python_type): + if get_origin(python_type) is Annotated: return cls.get_transformer(get_args(python_type)[0]) if python_type.__origin__ in cls._REGISTRY: @@ -822,7 +823,7 @@ def to_literal_type(cls, python_type: Type) -> LiteralType: transformer = cls.get_transformer(python_type) res = transformer.get_literal_type(python_type) data = None - if is_annotated(python_type): + if get_origin(python_type) is Annotated: for x in get_args(python_type)[1:]: if not isinstance(x, FlyteAnnotation): continue @@ -850,9 +851,9 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type # In case the value is an annotated type we inspect the annotations and look for hash-related annotations. hash = None - if is_annotated(python_type): + if get_origin(python_type) is Annotated: # We are now dealing with one of two cases: - # 1. The annotated type is a `HashMethod`, which indicates that we should produce the hash using + # 1. The annotated type is a `HashMethod`, which indicates that we should we should produce the hash using # the method indicated in the annotation. # 2. The annotated type is being used for a different purpose other than calculating hash values, in which case # we should just continue. @@ -879,7 +880,7 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T @classmethod def to_html(cls, ctx: FlyteContext, python_val: typing.Any, expected_python_type: Type[typing.Any]) -> str: transformer = cls.get_transformer(expected_python_type) - if is_annotated(expected_python_type): + if get_origin(expected_python_type) is Annotated: expected_python_type, *annotate_args = get_args(expected_python_type) from flytekit.deck.renderer import Renderable @@ -1003,7 +1004,7 @@ def get_sub_type(t: Type[T]) -> Type[T]: if hasattr(t, "__origin__"): # Handle annotation on list generic, eg: # Annotated[typing.List[int], 'foo'] - if is_annotated(t): + if get_origin(t) is Annotated: return ListTransformer.get_sub_type(get_args(t)[0]) if getattr(t, "__origin__") is list and hasattr(t, "__args__"): @@ -1029,7 +1030,7 @@ def is_batchable(t: Type): """ from flytekit.types.pickle import FlytePickle - if is_annotated(t): + if get_origin(t) is Annotated: return ListTransformer.is_batchable(get_args(t)[0]) if get_origin(t) is list: subtype = get_args(t)[0] @@ -1046,7 +1047,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp batch_size = len(python_val) # default batch size # parse annotated to get the number of items saved in a pickle file. 
- if is_annotated(python_type): + if get_origin(python_type) is Annotated: for annotation in get_args(python_type)[1:]: if isinstance(annotation, BatchSize): batch_size = annotation.val @@ -1190,7 +1191,8 @@ def get_sub_type_in_optional(t: Type[T]) -> Type[T]: return get_args(t)[0] def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: - t = get_underlying_type(t) + if get_origin(t) is Annotated: + t = get_args(t)[0] try: trans: typing.List[typing.Tuple[TypeTransformer, typing.Any]] = [ @@ -1204,7 +1206,8 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: raise ValueError(f"Type of Generic Union type is not supported, {e}") def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: - python_type = get_underlying_type(python_type) + if get_origin(python_type) is Annotated: + python_type = get_args(python_type)[0] found_res = False res = None @@ -1229,7 +1232,8 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp raise TypeTransformerFailedError(f"Cannot convert from {python_val} to {python_type}") def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[typing.Any]: - expected_python_type = get_underlying_type(expected_python_type) + if get_origin(expected_python_type) is Annotated: + expected_python_type = get_args(expected_python_type)[0] union_tag = None union_type = None @@ -1464,7 +1468,7 @@ def __init__(self): super().__init__(name="DefaultEnumTransformer", t=enum.Enum) def get_literal_type(self, t: Type[T]) -> LiteralType: - if is_annotated(t): + if get_origin(t) is Annotated: raise ValueError( f"Flytekit does not currently have support \ for FlyteAnnotations applied to enums. {t} cannot be \ @@ -1778,14 +1782,3 @@ def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: _register_default_type_transformers() - - -def is_annotated(t: Type) -> bool: - return get_origin(t) is Annotated - - -def get_underlying_type(t: Type) -> Type: - """Return the underlying type for annotated types or the type itself""" - if is_annotated(t): - return get_args(t)[0] - return t diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index d78ec152d7..bb8feb3d9c 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -8,9 +8,10 @@ from dataclasses_json import config, dataclass_json from marshmallow import fields +from typing_extensions import Annotated, get_args, get_origin from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_underlying_type +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError from flytekit.loggers import logger from flytekit.models.core.types import BlobType from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar @@ -336,7 +337,8 @@ def to_literal( raise TypeTransformerFailedError("None value cannot be converted to a file.") # Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type - python_type = get_underlying_type(python_type) + if get_origin(python_type) is Annotated: + python_type = get_args(python_type)[0] if not (python_type is os.PathLike or issubclass(python_type, FlyteFile)): raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike") @@ -411,9 +413,6 @@ def to_python_value( if expected_python_type is os.PathLike: return FlyteFile(uri) - 
# Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type - expected_python_type = get_underlying_type(expected_python_type) - # The rest of the logic is only for FlyteFile types. if not issubclass(expected_python_type, FlyteFile): # type: ignore raise TypeError(f"Neither os.PathLike nor FlyteFile specified {expected_python_type}") diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 5dd05cdffd..b7f0a1aeee 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -439,20 +439,13 @@ def test_flyte_file_annotated_hashmethod(local_dummy_file): def calc_hash(ff: FlyteFile) -> str: return str(ff.path) - HashedFlyteFile = Annotated[FlyteFile, HashMethod(calc_hash)] - - @task - def t1(path: str) -> HashedFlyteFile: - return HashedFlyteFile(path) - @task - def t2(ff: HashedFlyteFile) -> None: - print(ff.path) + def t1(path: str) -> Annotated[FlyteFile, HashMethod(calc_hash)]: + return FlyteFile(path) @workflow def wf(path: str) -> None: - ff = t1(path=path) - t2(ff=ff) + t1(path=path) wf(path=local_dummy_file) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 2e52fdcf9d..6d1b6829d5 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -41,8 +41,6 @@ UnionTransformer, convert_json_schema_to_python_class, dataclass_from_dict, - get_underlying_type, - is_annotated, ) from flytekit.exceptions import user as user_exceptions from flytekit.models import types as model_types @@ -1687,28 +1685,3 @@ def test_batch_pickle_list(python_val, python_type, expected_list_length): # data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)] # task1(data=data) # task1(data: typing.List[FlytePickle]) assert pv == python_val - - -@pytest.mark.parametrize( - "t,expected", - [ - (list, False), - (Annotated[int, "tag"], True), - (Annotated[typing.List[str], "a", "b"], True), - (Annotated[typing.Dict[int, str], FlyteAnnotation({"foo": "bar"})], True), - ], -) -def test_is_annotated(t, expected): - assert is_annotated(t) == expected - - -@pytest.mark.parametrize( - "t,expected", - [ - (typing.List, typing.List), - (Annotated[int, "tag"], int), - (Annotated[typing.List[str], "a", "b"], typing.List[str]), - ], -) -def test_get_underlying_type(t, expected): - assert get_underlying_type(t) == expected From 1880913de9045e037a34c64653e7e7bdcc7405ef Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 12 Jun 2023 09:44:14 -0700 Subject: [PATCH 23/55] Revert "Use logger instead of print statement in sqlalchemy plugin (#1651)" This reverts commit b4e6f8080f5fcc1bed10c8e46744d685ad6d0672. 
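For context, the two logging styles touched by this revert differ only in where the message is routed. A minimal, self-contained sketch (the database URI below is made up purely for illustration; `flytekit.loggers.logger` is the import the revert removes from the plugin):

```python
from flytekit.loggers import logger  # the import this revert drops from the sqlalchemy plugin

uri = "sqlite:///example.db"  # hypothetical URI, for illustration only
logger.info(f"Connecting to db {uri}")  # pre-revert style: goes through flytekit's logging configuration
print(f"Connecting to db {uri}")        # post-revert style: unconditional print to stdout
```
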
--- .../flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py index 8e8c464bd4..8541bc6aed 100644 --- a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py +++ b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py @@ -11,7 +11,6 @@ from flytekit.core.base_sql_task import SQLTask from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask from flytekit.core.shim_task import ShimTaskExecutor -from flytekit.loggers import logger from flytekit.models import task as task_models from flytekit.models.security import Secret from flytekit.types.schema import FlyteSchema @@ -127,10 +126,10 @@ def execute_from_model(self, tt: task_models.TaskTemplate, **kwargs) -> typing.A tt.custom["connect_args"][key] = value engine = create_engine(tt.custom["uri"], connect_args=tt.custom["connect_args"], echo=False) - logger.info(f"Connecting to db {tt.custom['uri']}") + print(f"Connecting to db {tt.custom['uri']}") interpolated_query = SQLAlchemyTask.interpolate_query(tt.custom["query_template"], **kwargs) - logger.info(f"Interpolated query {interpolated_query}") + print(f"Interpolated query {interpolated_query}") with engine.begin() as connection: df = None if tt.interface.outputs: From 595183d875e2c99f21e7e2a8e8d9bbfcb145dc94 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 12 Jun 2023 09:44:33 -0700 Subject: [PATCH 24/55] Revert "Map over notebook task (#1650)" This reverts commit ff734640260eb4865fdd4f98f24c0ccb4ba6232a. --- flytekit/core/map_task.py | 17 +++++------------ .../flytekitplugins/papermill/task.py | 8 +------- plugins/flytekit-papermill/setup.py | 2 +- plugins/flytekit-papermill/tests/test_task.py | 19 +------------------ 4 files changed, 8 insertions(+), 38 deletions(-) diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 52325ecb59..b40b5029bb 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -16,7 +16,7 @@ from flytekit.core.constants import SdkTaskType from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.interface import transform_interface_to_list_interface -from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask +from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.tracker import TrackedInstance from flytekit.core.utils import timeit from flytekit.exceptions import scopes as exception_scopes @@ -34,7 +34,7 @@ class MapPythonTask(PythonTask): def __init__( self, - python_function_task: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial], + python_function_task: typing.Union[PythonFunctionTask, functools.partial], concurrency: Optional[int] = None, min_success_ratio: Optional[float] = None, bound_inputs: Optional[Set[str]] = None, @@ -65,10 +65,7 @@ def __init__( actual_task = python_function_task if not isinstance(actual_task, PythonFunctionTask): - if isinstance(actual_task, PythonInstanceTask): - pass - else: - raise ValueError("Map tasks can only compose of PythonFuncton and PythonInstanceTasks currently") + raise ValueError("Map tasks can only compose of Python Functon Tasks currently") if len(actual_task.python_interface.outputs.keys()) > 1: raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs") @@ -79,11 +76,7 @@ def __init__( 
collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs) self._run_task: PythonFunctionTask = actual_task - if isinstance(actual_task, PythonInstanceTask): - mod = actual_task.task_type - f = actual_task.lhs - else: - _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) + _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest() name = f"{mod}.map_{f}_{h}" @@ -278,7 +271,7 @@ def _raw_execute(self, **kwargs) -> Any: def map_task( - task_function: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial], + task_function: typing.Union[PythonFunctionTask, functools.partial], concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs, diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index 6f4ed6886c..b1f472e99a 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -133,7 +133,6 @@ def __init__( task_config: T = None, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, outputs: typing.Optional[typing.Dict[str, typing.Type]] = None, - output_notebooks: typing.Optional[bool] = True, **kwargs, ): # Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used @@ -166,16 +165,13 @@ def __init__( if not os.path.exists(self._notebook_path): raise ValueError(f"Illegal notebook path passed in {self._notebook_path}") - if output_notebooks: - if outputs is None: - outputs = {} + if outputs: outputs.update( { self._IMPLICIT_OP_NOTEBOOK: self._IMPLICIT_OP_NOTEBOOK_TYPE, self._IMPLICIT_RENDERED_NOTEBOOK: self._IMPLICIT_RENDERED_NOTEBOOK_TYPE, } ) - super().__init__( name, task_config, @@ -291,8 +287,6 @@ def execute(self, **kwargs) -> Any: else: raise TypeError(f"Expected output {k} of type {type_v} not found in the notebook outputs") - if len(output_list) == 1: - return output_list[0] return tuple(output_list) def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: diff --git a/plugins/flytekit-papermill/setup.py b/plugins/flytekit-papermill/setup.py index 538946a6d7..33b9816081 100644 --- a/plugins/flytekit-papermill/setup.py +++ b/plugins/flytekit-papermill/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit", + "flytekit>=1.3.0b2,<2.0.0", "papermill>=1.2.0", "nbconvert>=6.0.7", "ipykernel>=5.0.0", diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index 47db35793d..0e54e7082e 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -1,7 +1,6 @@ import datetime import os import tempfile -import typing import pandas as pd from flytekitplugins.papermill import NotebookTask @@ -9,7 +8,7 @@ from kubernetes.client import V1Container, V1PodSpec import flytekit -from flytekit import StructuredDataset, kwtypes, map_task, task, workflow +from flytekit import StructuredDataset, kwtypes, task from flytekit.configuration import Image, ImageConfig from flytekit.types.directory import FlyteDirectory from flytekit.types.file import FlyteFile, PythonNotebook @@ -34,14 +33,6 @@ def _get_nb_path(name: str, suffix: str = "", abs: bool = True, ext: str = ".ipy outputs=kwtypes(square=float), ) -nb_sub_task = NotebookTask( - name="test", 
- notebook_path=_get_nb_path(nb_name, abs=False), - inputs=kwtypes(a=float), - outputs=kwtypes(square=float), - output_notebooks=False, -) - def test_notebook_task_simple(): serialization_settings = flytekit.configuration.SerializationSettings( @@ -181,11 +172,3 @@ def create_sd() -> StructuredDataset: ) success, out, render = nb_types.execute(ff=ff, fd=fd, sd=sd) assert success is True, "Notebook execution failed" - - -def test_map_over_notebook_task(): - @workflow - def wf(a: float) -> typing.List[float]: - return map_task(nb_sub_task)(a=[a, a]) - - assert wf(a=3.14) == [9.8596, 9.8596] From 75ea89becee579b0d49243c63b7037f36efd1fb0 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 12 Jun 2023 09:44:52 -0700 Subject: [PATCH 25/55] Revert "Support single literals in tiny url (#1654)" This reverts commit e9a714bebfa6095b3cfd2ea0351a8f8f9cbb6e14. --- doc-requirements.txt | 2 +- flytekit/remote/remote.py | 11 ++--------- setup.py | 3 ++- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/doc-requirements.txt b/doc-requirements.txt index 5264673f4f..1929925e84 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -244,7 +244,7 @@ flask==2.2.3 # via mlflow flatbuffers==23.1.21 # via tensorflow -flyteidl==1.5.6 +flyteidl==1.5.4 # via flytekit fonttools==4.38.0 # via matplotlib diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index e0a411de50..8b05ba69dc 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -227,12 +227,7 @@ def file_access(self) -> FileAccessProvider: def get( self, flyte_uri: typing.Optional[str] = None - ) -> typing.Optional[typing.Union[LiteralsResolver, Literal, HTML, bytes]]: - """ - General function that works with flyte tiny urls. This can return outputs (in the form of LiteralsResolver, or - individual Literals for singular requests), or HTML if passed a deck link, or bytes containing HTML, - if ipython is not available locally. 
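For reference, the deleted test (shown in the diff below) exercised the reverted capability, a per-node container-image override at call time. The snippet mirrors that test and is only a sketch of the pre-revert behavior; after this revert the `container_image` keyword is no longer handled by `with_overrides`.

```python
from flytekit import task, workflow


@task
def bar():
    print("hello")


@workflow
def wf() -> str:
    # prior to this revert, the keyword below was copied onto the underlying task's
    # container image; with the revert it is no longer picked up by with_overrides
    bar().with_overrides(container_image="hello/world")
    return "hi"


# the removed test asserted that the override took effect:
# assert wf.nodes[0].flyte_entity.container_image == "hello/world"
```
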
- """ + ) -> typing.Optional[typing.Union[LiteralsResolver, HTML, bytes]]: if flyte_uri is None: raise user_exceptions.FlyteUserException("flyte_uri cannot be empty") ctx = self._ctx or FlyteContextManager.current_context() @@ -242,8 +237,6 @@ def get( if data_response.HasField("literal_map"): lm = LiteralMap.from_flyte_idl(data_response.literal_map) return LiteralsResolver(lm.literals) - elif data_response.HasField("literal"): - return data_response.literal elif data_response.HasField("pre_signed_urls"): if len(data_response.pre_signed_urls.signed_url) == 0: raise ValueError(f"Flyte url {flyte_uri} resolved to empty download link") @@ -265,7 +258,7 @@ def get( except user_exceptions.FlyteUserException as e: remote_logger.info(f"Error from Flyte backend when trying to fetch data: {e.__cause__}") - remote_logger.info(f"Nothing found from {flyte_uri}") + remote_logger.debug(f"Nothing found from {flyte_uri}") def remote_context(self): """Context manager with remote-specific configuration.""" diff --git a/setup.py b/setup.py index 5273590d9f..8e057f4b91 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ }, install_requires=[ "googleapis-common-protos>=1.57", - "flyteidl>=1.5.6", + "flyteidl>=1.5.4", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", @@ -64,6 +64,7 @@ "marshmallow-jsonschema>=0.12.0", "natsort>=7.0.1", "docker-image-py>=0.1.10", + "singledispatchmethod; python_version < '3.8.0'", "typing_extensions", "docstring-parser>=0.9.0", "diskcache>=5.2.1", From 39acb76013d999575d7ea01befbe134c787d48c2 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 12 Jun 2023 09:45:07 -0700 Subject: [PATCH 26/55] Revert "Skip grpcio 1.55.0 (#1653)" This reverts commit 17f3441d96ac753493802905ff76495ecfa9486e. --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 8e057f4b91..7215101600 100644 --- a/setup.py +++ b/setup.py @@ -40,8 +40,8 @@ "python-dateutil>=2.1", # Restrict grpcio and grpcio-status. Version 1.50.0 pulls in a version of protobuf that is not compatible # with the old protobuf library (as described in https://developers.google.com/protocol-buffers/docs/news/2022-05-06) - "grpcio>=1.50.0,!=1.55.0,<2.0", - "grpcio-status>=1.50.0,!=1.55.0,<2.0", + "grpcio>=1.50.0,<2.0", + "grpcio-status>=1.50.0,<2.0", "importlib-metadata", "fsspec>=2023.3.0", "adlfs", From ab448d71c128f9c0f96c4545bc29e85f2470cf38 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 12 Jun 2023 09:45:30 -0700 Subject: [PATCH 27/55] Revert "Add support overriding image (#1652)" This reverts commit 594026aa94a60e29d5c0a6fee9044638772a1299. 
--- flytekit/core/node.py | 2 -- tests/flytekit/unit/core/test_node_creation.py | 13 ------------- 2 files changed, 15 deletions(-) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index bf5c97ba60..4f7838d2b6 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -128,8 +128,6 @@ def with_overrides(self, *args, **kwargs): if not isinstance(new_task_config, type(self.flyte_entity._task_config)): raise ValueError("can't change the type of the task config") self.flyte_entity._task_config = new_task_config - if "container_image" in kwargs: - self.flyte_entity._container_image = kwargs["container_image"] return self diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 81621ef3fc..da708a8571 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -452,16 +452,3 @@ def my_wf(a: str) -> str: return t1(a=a).with_overrides(task_config=None) my_wf() - - -def test_override_image(): - @task - def bar(): - print("hello") - - @workflow - def wf() -> str: - bar().with_overrides(container_image="hello/world") - return "hi" - - assert wf.nodes[0].flyte_entity.container_image == "hello/world" From f9ac9e9cecf64691b8a4f5b49cdd584ae24995d3 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 14 Jun 2023 18:50:33 -0700 Subject: [PATCH 28/55] full revamp and V2 --- .../flytekitplugins/pydantic/__init__.py | 6 +- .../pydantic/basemodel_extensions.py | 155 ------------------ .../{schema.py => basemodel_transformer.py} | 51 +++--- .../pydantic/deserialization.py | 77 +++++++++ .../pydantic/flytepath_creation.py | 76 --------- .../flytekitplugins/pydantic/object_store.py | 107 ++++++++++++ .../flytekitplugins/pydantic/serialization.py | 67 ++++++++ .../tests/folder/test_file1.txt | 2 +- .../tests/folder/test_file2.txt | 2 +- .../tests/test_type_transformer.py | 66 +++++--- 10 files changed, 328 insertions(+), 281 deletions(-) delete mode 100644 plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py rename plugins/flytekit-pydantic/flytekitplugins/pydantic/{schema.py => basemodel_transformer.py} (51%) create mode 100644 plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py delete mode 100644 plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py create mode 100644 plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py create mode 100644 plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py index 50e400af81..23e7e341bd 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py @@ -1,2 +1,4 @@ -from .basemodel_extensions import flyteobject_json_encoders, pydantic_flyteobject_config -from .schema import BaseModelTransformer +from .basemodel_transformer import BaseModelTransformer +from .deserialization import set_validators_on_supported_flyte_types as _set_validators_on_supported_flyte_types + +_set_validators_on_supported_flyte_types() # enables you to use flytekit.types in pydantic model diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py deleted file mode 100644 index 0946ab50c3..0000000000 --- 
a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_extensions.py +++ /dev/null @@ -1,155 +0,0 @@ -import abc -from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union, cast - -from typing_extensions import TypedDict - -from flytekit.types.directory import types as flyte_directory_types -from flytekit.types.file import file as flyte_file - -from . import flytepath_creation - -## ==================================================================================== -# Base class -## ==================================================================================== - - -Serializable = TypeVar("Serializable") -SerializedFlyteType = TypeVar("SerializedFlyteType") - - -def get_pydantic_flyteobject_config() -> Dict[str, Any]: - """ - Returns the pydantic config for the serializers/deserializers - """ - return { - "json_encoders": { - serializer_deserializer()._t: serializer_deserializer().serialize # type: ignore - for serializer_deserializer in PydanticSerializerDeserializerBase.__subclasses__() - }, - } - - -class PydanticSerializerDeserializerBase( - abc.ABC, - Generic[Serializable, SerializedFlyteType], -): - """Base class for object serializers/deserializers""" - - @abc.abstractmethod - def serialize(self, obj: Serializable) -> SerializedFlyteType: - pass - - @abc.abstractmethod - def deserialize(self, obj: SerializedFlyteType) -> Serializable: - pass - - def __init__(self, t: Type[Serializable]): - self._t = t - self._set_validator_on_serializeable(t) - - def _set_validator_on_serializeable(self, serializeable: Type[Serializable]) -> None: - """ - Sets the validator on the pydantic model for the - type that is being serialized/deserialized - """ - setattr(serializeable, "__get_validators__", lambda *_: (self.deserialize,)) - - -## ==================================================================================== -# FlyteDir -## ==================================================================================== - - -class FlyteDirJsonEncoded(TypedDict): - """JSON representation of a FlyteDirectory""" - - path: str - - -FLYTEDIR_DESERIALIZABLE_TYPES = Union[flyte_directory_types.FlyteDirectory, str, FlyteDirJsonEncoded] - - -class FlyteDirSerializerDeserializer( - PydanticSerializerDeserializerBase[flyte_directory_types.FlyteDirectory, FlyteDirJsonEncoded] -): - def __init__( - self, - t: Type[flyte_directory_types.FlyteDirectory] = flyte_directory_types.FlyteDirectory, - ): - super().__init__(t) - - def serialize(self, obj: flyte_directory_types.FlyteDirectory) -> FlyteDirJsonEncoded: - flytepath_creation.upload_to_s3(obj) - return {"path": obj.remote_source if obj.remote_source else obj.path} - - def deserialize(self, obj: FLYTEDIR_DESERIALIZABLE_TYPES) -> flyte_directory_types.FlyteDirectory: - flytedir = validate_flytedir(obj) - if flytedir is None: - raise ValueError(f"Could not deserialize {obj} to FlyteDirectory") - return flytedir - - -def validate_flytedir( - flytedir: FLYTEDIR_DESERIALIZABLE_TYPES, -) -> Optional[flyte_directory_types.FlyteDirectory]: - """validator for flytedir (i.e. deserializer)""" - if isinstance(flytedir, dict): # this is a json encoded flytedir - flytedir = cast(FlyteDirJsonEncoded, flytedir) - path = flytedir["path"] - return flytepath_creation.make_flytepath(path, flyte_directory_types.FlyteDirectory) - elif isinstance(flytedir, str): # when e.g. 
initializing from config - return flytepath_creation.make_flytepath(flytedir, flyte_directory_types.FlyteDirectory) - elif isinstance(flytedir, flyte_directory_types.FlyteDirectory): - return flytedir - else: - raise ValueError(f"Invalid type for flytedir: {type(flytedir)}") - - -## ==================================================================================== -# FlyteFile -## ==================================================================================== - - -class FlyteFileJsonEncoded(TypedDict): - """JSON representation of a FlyteFile""" - - path: str - - -FLYTEFILE_DESERIALIZABLE_TYPES = Union[flyte_directory_types.FlyteFile, str, FlyteFileJsonEncoded] - - -class FlyteFileSerializerDeserializer(PydanticSerializerDeserializerBase[flyte_file.FlyteFile, FlyteFileJsonEncoded]): - def __init__(self, t: Type[flyte_file.FlyteFile] = flyte_file.FlyteFile): - super().__init__(t) - - def serialize(self, obj: flyte_file.FlyteFile) -> FlyteFileJsonEncoded: - flytepath_creation.upload_to_s3(obj) - return {"path": obj.remote_source if obj.remote_source else obj.path} - - def deserialize(self, obj: FLYTEFILE_DESERIALIZABLE_TYPES) -> flyte_file.FlyteFile: - flyte_file = validate_flytefile(obj) - if flyte_file is None: - raise ValueError(f"Could not deserialize {obj} to FlyteFile") - return flyte_file - - -def validate_flytefile( - flytedir: FLYTEFILE_DESERIALIZABLE_TYPES, -) -> Optional[flyte_file.FlyteFile]: - """validator for flytedir (i.e. deserializer)""" - if isinstance(flytedir, dict): # this is a json encoded flytedir - flytedir = cast(FlyteFileJsonEncoded, flytedir) - path = flytedir["path"] - return flytepath_creation.make_flytepath(path, flyte_file.FlyteFile) - elif isinstance(flytedir, str): # when e.g. initializing from config - return flytepath_creation.make_flytepath(flytedir, flyte_file.FlyteFile) - elif isinstance(flytedir, flyte_directory_types.FlyteFile): - return flytedir - else: - raise ValueError(f"Invalid type for flytedir: {type(flytedir)}") - - -# add these to your basemodel config to enable serialization/deserialization of flyte objects. -pydantic_flyteobject_config = get_pydantic_flyteobject_config() -flyteobject_json_encoders = pydantic_flyteobject_config["json_encoders"] diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py similarity index 51% rename from plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py rename to plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index a4e5df5dbe..5dc3d3cca3 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/schema.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -1,18 +1,26 @@ from typing import Type - +from typing_extensions import Annotated import pydantic -from google.protobuf import json_format, struct_pb2 from flytekit import FlyteContext from flytekit.core import type_engine from flytekit.models import literals, types -from . import basemodel_extensions +from . 
import object_store, serialization """ Serializes & deserializes the pydantic basemodels """ +BaseModelLiteralValue = Annotated[ + literals.LiteralMap, + """ + BaseModel serialized to a LiteralMap consisting of: + 1) the basemodel json with placeholders for flyte types + 2) a mapping from placeholders to flyte object store with the flyte types + """, +] + class BaseModelTransformer(type_engine.TypeTransformer[pydantic.BaseModel]): _TYPE_INFO = types.LiteralType(simple=types.SimpleType.STRUCT) @@ -30,33 +38,32 @@ def to_literal( python_val: pydantic.BaseModel, python_type: Type[pydantic.BaseModel], expected: types.LiteralType, - ) -> literals.Literal: + ) -> BaseModelLiteralValue: """This method is used to convert from given python type object pydantic ``pydantic.BaseModel`` to the Literal representation.""" - - s = struct_pb2.Struct() - python_val.__config__.json_encoders.update(basemodel_extensions.flyteobject_json_encoders) - schema = python_val.schema_json() - data = python_val.json() - s.update({"schema": schema, "data": data}) - literal = literals.Literal(scalar=literals.Scalar(generic=s)) # type: ignore - return literal + return serialization.serialize_basemodel(python_val) def to_python_value( self, ctx: FlyteContext, - lv: literals.Literal, + lv: BaseModelLiteralValue, expected_python_type: Type[pydantic.BaseModel], ) -> pydantic.BaseModel: """Re-hydrate the pydantic pydantic.BaseModel object from Flyte Literal value.""" - base_model = json_format.MessageToDict(lv.scalar.generic) - schema = base_model["schema"] - data = base_model["data"] - if (expected_schema := expected_python_type.schema_json()) != schema: - raise type_engine.TypeTransformerFailedError( - f"The schema `{expected_schema}` of the expected python type {expected_python_type} is not equal to the received schema `{schema}`." - ) - - return expected_python_type.parse_raw(data) + update_objectstore_from_serialized_basemodel(lv) + basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(lv) + return expected_python_type.parse_raw(basemodel_json_w_placeholders) + + +def read_basemodel_json_from_literalmap(lv: BaseModelLiteralValue) -> serialization.SerializedBaseModel: + """ + Given a LiteralMap, re-hydrate the pydantic BaseModel object from Flyte Literal value. + """ + basemodel_literal: literals.Literal = lv.literals[serialization.BASEMODEL_KEY] + return object_store.deserialize_flyte_literal(basemodel_literal, str) + + +def update_objectstore_from_serialized_basemodel(lv: BaseModelLiteralValue) -> None: + object_store.FlyteObjectStore.read_literalmap(lv.literals[serialization.FLYTETYPES_KEY]) type_engine.TypeEngine.register(BaseModelTransformer()) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py new file mode 100644 index 0000000000..a67283e959 --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -0,0 +1,77 @@ +from typing import Any, Callable, Dict, Iterator, List, Type, TypeVar, Union + +import pydantic + +from flytekit.types import directory, file + +from . 
import object_store + +# this field is used by pydantic to get the validator method +PYDANTIC_VALIDATOR_METHOD_NAME = pydantic.BaseModel.__get_validators__.__name__ +PythonType = TypeVar("PythonType") # target type of the deserialization +Serializable = TypeVar("Serializable") # flyte object type + + +def set_validators_on_supported_flyte_types() -> None: + """ + Sets the validator on the pydantic model for the + type that is being serialized/deserialized + """ + [set_validators_on_flyte_type(flyte_type) for flyte_type in object_store.PYDANTIC_SUPPORTED_FLYTE_TYPES] + + +def set_validators_on_flyte_type(flyte_type: Type) -> None: + """ + Sets the validator on the pydantic model for the + type that is being serialized/deserialized + """ + setattr(flyte_type, PYDANTIC_VALIDATOR_METHOD_NAME, make_validators_for_type(flyte_type)) + + +def make_validators_for_type( + flyte_obj_type: Type[Serializable], +) -> Callable[[Any], Iterator[Callable[[Any], Serializable]]]: + """ + Returns a validator that can be used by pydantic to deserialize the object + """ + + def validator(object_uid_maybe: Union[object_store.LiteralObjID, Any]) -> Union[Serializable, Any]: + """partial of deserialize_flyte_literal with the object_type fixed""" + if not isinstance(object_uid_maybe, str): + return object_uid_maybe # this validator should only trigger for the placholders + if object_uid_maybe not in object_store.FlyteObjectStore.get_literal_store(): + return object_uid_maybe # if not in the store pass to the next validator to resolve + return object_store.FlyteObjectStore.get_python_object(object_uid_maybe, flyte_obj_type) + + def validator_generator(*args, **kwags) -> Iterator[Callable[[Any], Serializable]]: + """Generator that returns the validator""" + yield validator + yield from additional_flytetype_validators.get(flyte_obj_type, []) + + return validator_generator + + +def validate_flytefile(flytefile: Union[str, file.FlyteFile]) -> file.FlyteFile: + """validator for flytefile (i.e. deserializer)""" + if isinstance(flytefile, file.FlyteFile): + return flytefile + if isinstance(flytefile, str): # when e.g. initializing from config + return file.FlyteFile(flytefile) + else: + raise ValueError(f"Invalid type for flytefile: {type(flytefile)}") + + +def validate_flytedir(flytedir: Union[str, directory.FlyteDirectory]) -> directory.FlyteDirectory: + """validator for flytedir (i.e. deserializer)""" + if isinstance(flytedir, directory.FlyteDirectory): + return flytedir + if isinstance(flytedir, str): # when e.g. 
initializing from config + return directory.FlyteDirectory(flytedir) + else: + raise ValueError(f"Invalid type for flytedir: {type(flytedir)}") + + +additional_flytetype_validators: Dict[Type, List[Callable[[Any], Any]]] = { + file.FlyteFile: [validate_flytefile], + directory.FlyteDirectory: [validate_flytedir], +} diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py deleted file mode 100644 index 31fc97903d..0000000000 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/flytepath_creation.py +++ /dev/null @@ -1,76 +0,0 @@ -## ==================================================================================== -# Flyte directory and files creation -## ==================================================================================== - -import os -from typing import Any, Dict, Optional, Type, TypeVar, Union - -from flytekit.core import context_manager, type_engine -from flytekit.models import literals -from flytekit.models.core import types as core_types -from flytekit.types.directory import types as flyte_directory_types -from flytekit.types.file import file as flyte_file - -FlytePath = TypeVar("FlytePath", flyte_file.FlyteFile, flyte_directory_types.FlyteDirectory) - - -def upload_to_s3(flytepath: FlytePath) -> None: - """Upload a FlytePath to S3""" - ctx = context_manager.FlyteContextManager.current_context() - if flytepath.remote_path is None: - flytepath.remote_path = remote_path = ctx.file_access.get_random_remote_path(flytepath.path) - is_multipart = isinstance(flytepath, flyte_directory_types.FlyteDirectory) - ctx.file_access.put_data(flytepath.path, remote_path, is_multipart=is_multipart) - - -def make_flytepath(path: Union[str, os.PathLike], flyte_type: Type[FlytePath]) -> Optional[FlytePath]: - """create a FlyteDirectory from a path""" - context = context_manager.FlyteContextManager.current_context() - transformer = get_flyte_transformer(flyte_type) - dimensionality = get_flyte_dimensionality(flyte_type) - literal = make_literal(uri=path, dimensionality=dimensionality) - out_dir = transformer.to_python_value(context, literal, flyte_type) - return out_dir - - -def get_flyte_transformer( - flyte_type: Type[FlytePath], -) -> type_engine.TypeTransformer[FlytePath]: - """get the transformer for a given flyte type""" - return FLYTE_TRANSFORMERS[flyte_type] - - -FLYTE_TRANSFORMERS: Dict[Type, type_engine.TypeTransformer] = { - flyte_file.FlyteFile: flyte_file.FlyteFilePathTransformer(), - flyte_directory_types.FlyteDirectory: flyte_directory_types.FlyteDirToMultipartBlobTransformer(), -} - - -def get_flyte_dimensionality( - flyte_type: Type[FlytePath], -) -> str: - """get the transformer for a given flyte type""" - return FLYTE_DIMENSIONALITY[flyte_type] - - -FLYTE_DIMENSIONALITY: Dict[Type, Any] = { - flyte_file.FlyteFile: core_types.BlobType.BlobDimensionality.SINGLE, - flyte_directory_types.FlyteDirectory: core_types.BlobType.BlobDimensionality.MULTIPART, -} - - -def make_literal( - uri: Union[str, os.PathLike], - dimensionality, -) -> literals.Literal: - scalar = make_scalar(uri, dimensionality) - return literals.Literal(scalar=scalar) # type: ignore - - -def make_scalar( - uri: Union[str, os.PathLike], - dimensionality, -) -> literals.Scalar: - blobtype = core_types.BlobType(format="", dimensionality=dimensionality) - blob = literals.Blob(metadata=literals.BlobMetadata(type=blobtype), uri=uri) - return literals.Scalar(blob=blob) # type: ignore diff --git 
a/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py new file mode 100644 index 0000000000..d6465e101e --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py @@ -0,0 +1,107 @@ +import uuid +from typing import Any, Dict, Type, TypeVar, cast + +import pandas as pd +import torch.nn as nn +from typing_extensions import Annotated, NewType + +from flytekit.core import context_manager, type_engine +from flytekit.models import literals +from flytekit.types import directory +from flytekit.types.file import file + +PYDANTIC_SUPPORTED_FLYTE_TYPES = ( + nn.Module, + pd.DataFrame, + file.FlyteFile, + directory.FlyteDirectory, + # TODO - add all supported types +) + +LiteralObjID = Annotated[str, "Key for unique object in literal map."] +PythonType = TypeVar("PythonType") # target type of the deserialization + + +class FlyteObjectStore: + """ + This class is an intermediate store for python objects that are being serialized/deserialized. + + On serialization of a basemodel, flyte objects are serialized and stored in this object store. + On deserialization of a basemodel, flyte objects are deserialized from this object store. + """ + + _literal_store: Dict[LiteralObjID, literals.Literal] = {} + + def __setattr__(self, name: str, value: Any) -> None: + raise Exception("Attributes should not be set on the FlyteObjectStore.") + + def __init__(self) -> None: + raise Exception("This should not be instantiated, it is a singleton object store.") + + def __contains__(self, item: LiteralObjID) -> bool: + return item in self.get_literal_store() + + @classmethod + def get_literal_store(cls): + """Accessor to the class variable""" + return cls._literal_store + + @classmethod + def register_python_object(cls, python_object: object) -> LiteralObjID: + """serializes to literal and returns a unique identifier""" + serialized_item = serialize_to_flyte_literal(python_object) + identifier = make_identifier(python_object) + cls.get_literal_store()[identifier] = serialized_item + return identifier + + @classmethod + def get_python_object(cls, identifier: LiteralObjID, expected_type: Type[PythonType]) -> PythonType: + """deserializes a literal and returns the python object""" + literal = cls.get_literal_store()[identifier] + python_object = deserialize_flyte_literal(literal, expected_type) + return python_object + + @classmethod + def as_literalmap(cls) -> literals.LiteralMap: + """ + Converts the object store to a literal map + """ + return literals.LiteralMap(literals=cls.get_literal_store()) + + @classmethod + def read_literalmap(cls, literal_map: literals.LiteralMap) -> None: + """ + Reads a literal map and populates the object store from it + """ + literal_store = cls.get_literal_store() + literal_store.update(literal_map.literals) + + +def deserialize_flyte_literal(flyteobj_literal: literals.Literal, python_type: Type[PythonType]) -> PythonType: + """ + Deserializes a Flyte Literal into the python object instance. + """ + ctx = context_manager.FlyteContext.current_context() + transformer = type_engine.TypeEngine.get_transformer(python_type) + python_obj = transformer.to_python_value(ctx, flyteobj_literal, python_type) + return cast(PythonType, python_obj) + + +def serialize_to_flyte_literal(python_obj) -> literals.Literal: + """ + Use the Flyte TypeEngine to serialize a python object to a Flyte Literal. 
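    For illustration, a round trip through these two helpers looks roughly like this
    (the local path is only a placeholder):

        literal = serialize_to_flyte_literal(file.FlyteFile("/tmp/example.txt"))
        restored = deserialize_flyte_literal(literal, file.FlyteFile)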
+ """ + python_type = type(python_obj) + ctx = context_manager.FlyteContextManager().current_context() + literal_type = type_engine.TypeEngine.to_literal_type(python_type) + literal_obj = type_engine.TypeEngine.to_literal(ctx, python_obj, python_type, literal_type) + return literal_obj + + +def make_identifier(python_type: object) -> LiteralObjID: + """ + Create a unique identifier for a python object. + """ + # TODO - human readable way to identify the object + unique_id = f"{type(python_type).__name__}_{uuid.uuid4().hex}" + return cast(LiteralObjID, unique_id) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py new file mode 100644 index 0000000000..ae6ee6fae2 --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py @@ -0,0 +1,67 @@ +""" +Logic for serializing a basemodel to a literalmap that can be passed between container + +The serialization process is as follows: + +1. Serialize the basemodel to json, replacing all flyte types with unique placeholder strings +2. Serialize the flyte types to separate literals and store them in the flyte object store (a singleton object) +3. Return a literal map with the json and the flyte object store represented as a literalmap {placeholder: flyte type} + +""" +from typing import Any, NamedTuple, Union, cast +from typing_extensions import Annotated + +import pydantic +from google.protobuf import struct_pb2 + +from flytekit.models import literals +from flytekit.core import context_manager, type_engine + +from . import object_store + + +BASEMODEL_KEY = cast(object_store.LiteralObjID, "BaseModel") +FLYTETYPES_KEY = cast(object_store.LiteralObjID, "Flyte Object Store") + +SerializedBaseModel = Annotated[str, "A pydantic BaseModel that has been serialized with placeholders for Flyte types."] + + +def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.LiteralMap: + """ + Serializes a given pydantic BaseModel instance into a LiteralMap. + The BaseModel is first serialized into a JSON format, where all Flyte types are replaced with unique placeholder strings. + The Flyte Types are serialized into separate Flyte literals + """ + + basemodel_json = serialize_basemodel_to_json_and_store(basemodel) + basemodel_literal = make_literal_from_json(basemodel_json) + return literals.LiteralMap( + { + BASEMODEL_KEY: basemodel_literal, # json with flyte types replaced with placeholders + FLYTETYPES_KEY: object_store.FlyteObjectStore.as_literalmap(), # placeholders mapped to flyte types + } + ) + + +def make_literal_from_json(json: str) -> literals.Literal: + """ + Converts the json representation of a pydantic BaseModel to a Flyte Literal. + """ + # serialize as a string literal + ctx = context_manager.FlyteContext.current_context() + string_transformer = type_engine.TypeEngine.get_transformer(str) + return string_transformer.to_literal(ctx, json, str, string_transformer.get_literal_type(str)) + + +def serialize_basemodel_to_json_and_store(basemodel: pydantic.BaseModel) -> SerializedBaseModel: + """ + Serialize a pydantic BaseModel to json and protobuf, separating out the Flyte types into a separate store. + On deserialization, the store is used to reconstruct the Flyte types. 
+ """ + + def encoder(obj: Any) -> Union[str, object_store.LiteralObjID]: + if isinstance(obj, object_store.PYDANTIC_SUPPORTED_FLYTE_TYPES): + return object_store.FlyteObjectStore.register_python_object(obj) + return basemodel.__json_encoder__(obj) + + return basemodel.json(encoder=encoder) diff --git a/plugins/flytekit-pydantic/tests/folder/test_file1.txt b/plugins/flytekit-pydantic/tests/folder/test_file1.txt index 83ebd4399d..1910281566 100644 --- a/plugins/flytekit-pydantic/tests/folder/test_file1.txt +++ b/plugins/flytekit-pydantic/tests/folder/test_file1.txt @@ -1 +1 @@ -love sosa \ No newline at end of file +foo \ No newline at end of file diff --git a/plugins/flytekit-pydantic/tests/folder/test_file2.txt b/plugins/flytekit-pydantic/tests/folder/test_file2.txt index 83ebd4399d..ba0e162e1c 100644 --- a/plugins/flytekit-pydantic/tests/folder/test_file2.txt +++ b/plugins/flytekit-pydantic/tests/folder/test_file2.txt @@ -1 +1 @@ -love sosa \ No newline at end of file +bar \ No newline at end of file diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 3f45de2401..47de3b083d 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -1,13 +1,13 @@ import os from typing import Any, Dict, List, Optional, Type, Union +import pandas as pd -import flytekitplugins.pydantic # noqa F401 import pytest from flytekitplugins.pydantic import BaseModelTransformer from pydantic import BaseModel, Extra -from flytekit import task, workflow -from flytekit.core.type_engine import TypeTransformerFailedError +import flytekit +from flytekit.core import type_engine from flytekit.types import directory from flytekit.types.file import file @@ -47,6 +47,12 @@ class ConfigWithFlyteDirs(BaseModel): flytedirs: List[directory.FlyteDirectory] +class ConfigWithPandasDataFrame(BaseModel): + """Config BaseModel for testing purposes with pandas.DataFrame type hint.""" + + df: pd.DataFrame + + class ChildConfig(Config): """Child class config BaseModel for testing purposes.""" @@ -55,7 +61,14 @@ class ChildConfig(Config): @pytest.mark.parametrize( "python_type,kwargs", - [(Config, {}), (ConfigRequired, {"model_config": TrainConfig()}), (TrainConfig, {}), (TrainConfig, {})], + [ + (Config, {}), + (ConfigRequired, {"model_config": TrainConfig()}), + (TrainConfig, {}), + (ConfigWithFlyteFiles, {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}), + (ConfigWithFlyteDirs, {"flytedirs": ["tests/folder/"]}), + (ConfigWithPandasDataFrame, {"df": {"a": [1, 2, 3], "b": [4, 5, 6]}}), + ], ) def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): """Test that a (de-)serialization roundtrip results in the identical BaseModel.""" @@ -76,8 +89,7 @@ def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): reconstructed_value = type_transformer.to_python_value(ctx, literal_value, type(python_value)) - assert reconstructed_value == python_value - assert reconstructed_value.schema() == python_value.schema() + # assert reconstructed_value == python_value @pytest.mark.parametrize( @@ -87,23 +99,25 @@ def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): (ConfigRequired, {"model_config": {"foo": TrainConfig(loss="mse")}}), (ConfigWithFlyteFiles, {"flytefiles": ["s3://foo/bar"]}), (ConfigWithFlyteDirs, {"flytedirs": ["s3://foo/bar"]}), + (ConfigWithPandasDataFrame, {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 
5, 6]})}), ], ) def test_pass_to_workflow(config_type: Type, kwargs: Dict[str, Any]): """Test passing a BaseModel instance to a workflow works.""" cfg = config_type(**kwargs) - @task + @flytekit.task def train(cfg: config_type) -> config_type: return cfg - @workflow + @flytekit.workflow def wf(cfg: config_type) -> config_type: return train(cfg=cfg) returned_cfg = wf(cfg=cfg) - assert cfg == returned_cfg + # assert returned_cfg == cfg + # TODO these assertions are not valid for all types @pytest.mark.parametrize( @@ -116,17 +130,17 @@ def test_flytefiles_in_wf(kwargs: Dict[str, Any]): """Test passing a BaseModel instance to a workflow works.""" cfg = ConfigWithFlyteFiles(**kwargs) - @task + @flytekit.task def read(cfg: ConfigWithFlyteFiles) -> str: with open(cfg.flytefiles[0], "r") as f: return f.read() - @workflow + @flytekit.workflow def wf(cfg: ConfigWithFlyteFiles) -> str: - return read(cfg=cfg) + return read(cfg=cfg) # type: ignore string = wf(cfg=cfg) - assert string == "love sosa" + assert string in {"foo", "bar"} # type: ignore @pytest.mark.parametrize( @@ -139,29 +153,33 @@ def test_flytedirs_in_wf(kwargs: Dict[str, Any]): """Test passing a BaseModel instance to a workflow works.""" cfg = ConfigWithFlyteDirs(**kwargs) - @task + @flytekit.task def listdir(cfg: ConfigWithFlyteDirs) -> List[str]: return os.listdir(cfg.flytedirs[0]) - @workflow + @flytekit.workflow def wf(cfg: ConfigWithFlyteDirs) -> List[str]: - return listdir(cfg=cfg) + return listdir(cfg=cfg) # type: ignore dirs = wf(cfg=cfg) - assert len(dirs) == 2 + assert len(dirs) == 2 # type: ignore +# TODO: //Arthur to Fabio this was differente before but now im unsure what the test is doing +# previously a pattern match error was checked that its raised, but isnt it OK that the ChildConfig +# is passed since its a subclass of Config? 
+# I modified the test to work the other way around, but im not sure if this is what you intended def test_pass_wrong_type_to_workflow(): """Test passing the wrong type raises exception.""" - cfg = ChildConfig() + cfg = Config() - @task - def train(cfg: Config) -> Config: + @flytekit.task + def train(cfg: ChildConfig) -> ChildConfig: return cfg - @workflow - def wf(cfg: Config) -> Config: - return train(cfg=cfg) + @flytekit.workflow + def wf(cfg: ChildConfig) -> ChildConfig: + return train(cfg=cfg) # type: ignore - with pytest.raises(TypeTransformerFailedError, match="The schema"): + with pytest.raises(TypeError): # type: ignore wf(cfg=cfg) From da81828535421b14943f0b1eef96ed744b242852 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 15 Jun 2023 11:59:11 -0700 Subject: [PATCH 29/55] made pydantic basemodel check explicit --- flytekit/clis/sdk_in_container/run.py | 8 +++----- .../flytekitplugins/pydantic/basemodel_transformer.py | 3 --- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 0157c5e9ed..f67f74e671 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -10,6 +10,7 @@ from typing import cast import cloudpickle +import pydantic import rich_click as click import yaml from dataclasses_json import DataClassJsonMixin @@ -111,7 +112,6 @@ class PickleParamType(click.ParamType): def convert( self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] ) -> typing.Any: - uri = FlyteContextManager.current_context().file_access.get_random_local_path() with open(uri, "w+b") as outfile: cloudpickle.dump(value, outfile) @@ -119,7 +119,6 @@ def convert( class DateTimeType(click.DateTime): - _NOW_FMT = "now" _ADDITONAL_FORMATS = [_NOW_FMT] @@ -276,7 +275,6 @@ def get_uri_for_dir( def convert_to_structured_dataset( self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: Directory ) -> Literal: - uri = self.get_uri_for_dir(ctx, value, "00000.parquet") lit = Literal( @@ -338,7 +336,7 @@ def convert_to_union( python_val = converter._click_type.convert(value, param, ctx) literal = converter.convert_to_literal(ctx, param, python_val) return Literal(scalar=Scalar(union=Union(literal, variant))) - except (Exception or AttributeError) as e: + except Exception or AttributeError as e: logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e) raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}") @@ -399,7 +397,7 @@ def convert_to_struct( Convert the loaded json object to a Flyte Literal struct type. 
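        For example, a pydantic model passed on the command line as JSON is rehydrated
        roughly like this (illustrative model, not part of flytekit):

            class TrainArgs(pydantic.BaseModel):
                lr: float = 0.01

            TrainArgs.parse_raw(json.dumps({"lr": 0.1}))  # -> TrainArgs(lr=0.1)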
""" if type(value) != self._python_type: - if hasattr(self._python_type, "parse_raw"): # e.g pydantic basemodel + if issubclass(self._python_type, pydantic.BaseModel): o = self._python_type.parse_raw(json.dumps(value)) else: o = cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(value)) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index 5dc3d3cca3..fbb538af15 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -55,9 +55,6 @@ def to_python_value( def read_basemodel_json_from_literalmap(lv: BaseModelLiteralValue) -> serialization.SerializedBaseModel: - """ - Given a LiteralMap, re-hydrate the pydantic BaseModel object from Flyte Literal value. - """ basemodel_literal: literals.Literal = lv.literals[serialization.BASEMODEL_KEY] return object_store.deserialize_flyte_literal(basemodel_literal, str) From 95edbabd0826657ae60842d9627e91f0f8fcdc7a Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 15 Jun 2023 12:32:47 -0700 Subject: [PATCH 30/55] dynamic retrieval of supported flytetypes --- .../pydantic/deserialization.py | 14 ++------- .../flytekitplugins/pydantic/object_store.py | 31 +++++++++++-------- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py index a67283e959..894a060cbf 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -14,18 +14,10 @@ def set_validators_on_supported_flyte_types() -> None: """ - Sets the validator on the pydantic model for the - type that is being serialized/deserialized + Set validator on the pydantic model for the type that is being (de-)serialized """ - [set_validators_on_flyte_type(flyte_type) for flyte_type in object_store.PYDANTIC_SUPPORTED_FLYTE_TYPES] - - -def set_validators_on_flyte_type(flyte_type: Type) -> None: - """ - Sets the validator on the pydantic model for the - type that is being serialized/deserialized - """ - setattr(flyte_type, PYDANTIC_VALIDATOR_METHOD_NAME, make_validators_for_type(flyte_type)) + for flyte_type in object_store.PYDANTIC_SUPPORTED_FLYTE_TYPES: + setattr(flyte_type, PYDANTIC_VALIDATOR_METHOD_NAME, make_validators_for_type(flyte_type)) def make_validators_for_type( diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py index d6465e101e..d68b7a9017 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py @@ -1,23 +1,28 @@ +import builtins +import datetime +import typing import uuid from typing import Any, Dict, Type, TypeVar, cast -import pandas as pd -import torch.nn as nn -from typing_extensions import Annotated, NewType +from typing_extensions import Annotated from flytekit.core import context_manager, type_engine from flytekit.models import literals -from flytekit.types import directory -from flytekit.types.file import file - -PYDANTIC_SUPPORTED_FLYTE_TYPES = ( - nn.Module, - pd.DataFrame, - file.FlyteFile, - directory.FlyteDirectory, - # TODO - add all supported types -) +MODULES_TO_EXLCLUDE_FROM_FLYTE_TYPES = {m.__name__ for 
m in [builtins, typing, datetime]} + + +def include_in_flyte_types(t: type) -> bool: + if t is None: + return False + if t.__module__ in MODULES_TO_EXLCLUDE_FROM_FLYTE_TYPES: + return False + return True + + +PYDANTIC_SUPPORTED_FLYTE_TYPES = tuple( + filter(include_in_flyte_types, type_engine.TypeEngine.get_available_transformers()) +) LiteralObjID = Annotated[str, "Key for unique object in literal map."] PythonType = TypeVar("PythonType") # target type of the deserialization From b42bb4d7ecc3c01b823b51b687171295e249d556 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 15 Jun 2023 12:33:59 -0700 Subject: [PATCH 31/55] renamed FlyteObjectStore -> PydanticTransformerLiteralStore --- .../flytekitplugins/pydantic/basemodel_transformer.py | 2 +- .../flytekitplugins/pydantic/deserialization.py | 4 ++-- .../flytekitplugins/pydantic/object_store.py | 2 +- .../flytekitplugins/pydantic/serialization.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index fbb538af15..288545789d 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -60,7 +60,7 @@ def read_basemodel_json_from_literalmap(lv: BaseModelLiteralValue) -> serializat def update_objectstore_from_serialized_basemodel(lv: BaseModelLiteralValue) -> None: - object_store.FlyteObjectStore.read_literalmap(lv.literals[serialization.FLYTETYPES_KEY]) + object_store.PydanticTransformerLiteralStore.read_literalmap(lv.literals[serialization.FLYTETYPES_KEY]) type_engine.TypeEngine.register(BaseModelTransformer()) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py index 894a060cbf..336d37bae4 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -31,9 +31,9 @@ def validator(object_uid_maybe: Union[object_store.LiteralObjID, Any]) -> Union[ """partial of deserialize_flyte_literal with the object_type fixed""" if not isinstance(object_uid_maybe, str): return object_uid_maybe # this validator should only trigger for the placholders - if object_uid_maybe not in object_store.FlyteObjectStore.get_literal_store(): + if object_uid_maybe not in object_store.PydanticTransformerLiteralStore.get_literal_store(): return object_uid_maybe # if not in the store pass to the next validator to resolve - return object_store.FlyteObjectStore.get_python_object(object_uid_maybe, flyte_obj_type) + return object_store.PydanticTransformerLiteralStore.get_python_object(object_uid_maybe, flyte_obj_type) def validator_generator(*args, **kwags) -> Iterator[Callable[[Any], Serializable]]: """Generator that returns the validator""" diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py index d68b7a9017..e60d837bba 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py @@ -27,7 +27,7 @@ def include_in_flyte_types(t: type) -> bool: PythonType = TypeVar("PythonType") # target type of the deserialization -class FlyteObjectStore: +class PydanticTransformerLiteralStore: """ This class is an intermediate 
store for python objects that are being serialized/deserialized. diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py index ae6ee6fae2..45eefa80ee 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py @@ -38,7 +38,7 @@ def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.LiteralMap: return literals.LiteralMap( { BASEMODEL_KEY: basemodel_literal, # json with flyte types replaced with placeholders - FLYTETYPES_KEY: object_store.FlyteObjectStore.as_literalmap(), # placeholders mapped to flyte types + FLYTETYPES_KEY: object_store.PydanticTransformerLiteralStore.as_literalmap(), # placeholders mapped to flyte types } ) @@ -61,7 +61,7 @@ def serialize_basemodel_to_json_and_store(basemodel: pydantic.BaseModel) -> Seri def encoder(obj: Any) -> Union[str, object_store.LiteralObjID]: if isinstance(obj, object_store.PYDANTIC_SUPPORTED_FLYTE_TYPES): - return object_store.FlyteObjectStore.register_python_object(obj) + return object_store.PydanticTransformerLiteralStore.register_python_object(obj) return basemodel.__json_encoder__(obj) return basemodel.json(encoder=encoder) From 7ccaa959c867b0b594a84544a65bc9eaa5fb29fc Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 15 Jun 2023 13:56:32 -0700 Subject: [PATCH 32/55] nit about a typehint --- flytekit/core/type_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 5994390c8d..428fae98bd 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -710,7 +710,7 @@ def register_additional_type(cls, transformer: TypeTransformer, additional_type: cls._REGISTRY[additional_type] = transformer @classmethod - def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: + def get_transformer(cls, python_type: T) -> TypeTransformer[T]: """ The TypeEngine hierarchy for flyteKit. This method looksup and selects the type transformer. 
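        For example (illustrative):

            TypeEngine.get_transformer(str)                     # the registered str transformer
            TypeEngine.get_transformer(Annotated[int, "meta"])  # unwraps the annotation, returns the int transformer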
The algorithm is as follows From 515a7488be2aac35b538dbdf188590c994bf60f0 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 15 Jun 2023 13:59:17 -0700 Subject: [PATCH 33/55] small changes to docstrings and types --- .../flytekitplugins/pydantic/object_store.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py index e60d837bba..4eab2ded75 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py @@ -2,7 +2,7 @@ import datetime import typing import uuid -from typing import Any, Dict, Type, TypeVar, cast +from typing import Any, Dict, Optional, Type, TypeVar, cast from typing_extensions import Annotated @@ -48,20 +48,20 @@ def __contains__(self, item: LiteralObjID) -> bool: @classmethod def get_literal_store(cls): - """Accessor to the class variable""" + """Access the literal store""" return cls._literal_store @classmethod def register_python_object(cls, python_object: object) -> LiteralObjID: - """serializes to literal and returns a unique identifier""" + """Serialize to literal and return a unique identifier.""" serialized_item = serialize_to_flyte_literal(python_object) identifier = make_identifier(python_object) cls.get_literal_store()[identifier] = serialized_item return identifier @classmethod - def get_python_object(cls, identifier: LiteralObjID, expected_type: Type[PythonType]) -> PythonType: - """deserializes a literal and returns the python object""" + def get_python_object(cls, identifier: LiteralObjID, expected_type: Type[PythonType]) -> Optional[PythonType]: + """Deserialize a literal and return the python object""" literal = cls.get_literal_store()[identifier] python_object = deserialize_flyte_literal(literal, expected_type) return python_object @@ -76,20 +76,20 @@ def as_literalmap(cls) -> literals.LiteralMap: @classmethod def read_literalmap(cls, literal_map: literals.LiteralMap) -> None: """ - Reads a literal map and populates the object store from it + Read a literal map and populate the object store from it """ literal_store = cls.get_literal_store() literal_store.update(literal_map.literals) -def deserialize_flyte_literal(flyteobj_literal: literals.Literal, python_type: Type[PythonType]) -> PythonType: - """ - Deserializes a Flyte Literal into the python object instance. 
- """ +def deserialize_flyte_literal( + flyteobj_literal: literals.Literal, python_type: Type[PythonType] +) -> Optional[PythonType]: + """Deserialize a Flyte Literal into the python object instance.""" ctx = context_manager.FlyteContext.current_context() transformer = type_engine.TypeEngine.get_transformer(python_type) python_obj = transformer.to_python_value(ctx, flyteobj_literal, python_type) - return cast(PythonType, python_obj) + return python_obj def serialize_to_flyte_literal(python_obj) -> literals.Literal: From 24a4d0e7d43165b62aaeaccbbaf3c79c1a76655d Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 15 Jun 2023 15:47:48 -0700 Subject: [PATCH 34/55] v2.1 w/ basemodel specific object store --- .../pydantic/basemodel_transformer.py | 17 ++-- .../pydantic/deserialization.py | 6 +- .../flytekitplugins/pydantic/object_store.py | 94 ++++++++++++------- .../flytekitplugins/pydantic/serialization.py | 33 +++++-- .../tests/test_type_transformer.py | 20 +++- 5 files changed, 114 insertions(+), 56 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index 288545789d..c3d9eb0472 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Optional, Type from typing_extensions import Annotated import pydantic @@ -49,18 +49,17 @@ def to_python_value( expected_python_type: Type[pydantic.BaseModel], ) -> pydantic.BaseModel: """Re-hydrate the pydantic pydantic.BaseModel object from Flyte Literal value.""" - update_objectstore_from_serialized_basemodel(lv) basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(lv) - return expected_python_type.parse_raw(basemodel_json_w_placeholders) + flyte_obj_literalmap = lv.literals[serialization.FLYTETYPE_OBJSTORE__KEY] + with object_store.PydanticTransformerLiteralStore.attach_to_literalmap(flyte_obj_literalmap): + return expected_python_type.parse_raw(basemodel_json_w_placeholders) def read_basemodel_json_from_literalmap(lv: BaseModelLiteralValue) -> serialization.SerializedBaseModel: - basemodel_literal: literals.Literal = lv.literals[serialization.BASEMODEL_KEY] - return object_store.deserialize_flyte_literal(basemodel_literal, str) - - -def update_objectstore_from_serialized_basemodel(lv: BaseModelLiteralValue) -> None: - object_store.PydanticTransformerLiteralStore.read_literalmap(lv.literals[serialization.FLYTETYPES_KEY]) + basemodel_literal: literals.Literal = lv.literals[serialization.BASEMODEL_JSON_KEY] + basemodel_json_w_placeholders = object_store.deserialize_flyte_literal(basemodel_literal, str) + assert isinstance(basemodel_json_w_placeholders, str) + return basemodel_json_w_placeholders type_engine.TypeEngine.register(BaseModelTransformer()) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py index 336d37bae4..843e6f598a 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -4,7 +4,7 @@ from flytekit.types import directory, file -from . 
import object_store +from flytekitplugins.pydantic import object_store # this field is used by pydantic to get the validator method PYDANTIC_VALIDATOR_METHOD_NAME = pydantic.BaseModel.__get_validators__.__name__ @@ -31,8 +31,8 @@ def validator(object_uid_maybe: Union[object_store.LiteralObjID, Any]) -> Union[ """partial of deserialize_flyte_literal with the object_type fixed""" if not isinstance(object_uid_maybe, str): return object_uid_maybe # this validator should only trigger for the placholders - if object_uid_maybe not in object_store.PydanticTransformerLiteralStore.get_literal_store(): - return object_uid_maybe # if not in the store pass to the next validator to resolve + if not object_store.PydanticTransformerLiteralStore.is_attached(): + return object_uid_maybe return object_store.PydanticTransformerLiteralStore.get_python_object(object_uid_maybe, flyte_obj_type) def validator_generator(*args, **kwags) -> Iterator[Callable[[Any], Serializable]]: diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py index 4eab2ded75..67395e5166 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py @@ -1,11 +1,14 @@ import builtins +import contextlib import datetime import typing import uuid -from typing import Any, Dict, Optional, Type, TypeVar, cast +from typing import Any, Dict, Generator, Optional, Type, TypeVar, cast from typing_extensions import Annotated +import pydantic + from flytekit.core import context_manager, type_engine from flytekit.models import literals @@ -23,7 +26,9 @@ def include_in_flyte_types(t: type) -> bool: PYDANTIC_SUPPORTED_FLYTE_TYPES = tuple( filter(include_in_flyte_types, type_engine.TypeEngine.get_available_transformers()) ) +ObjectStoreID = Annotated[str, "Key for unique literalmap of a serialized basemodel."] LiteralObjID = Annotated[str, "Key for unique object in literal map."] +LiteralStore = Annotated[Dict[LiteralObjID, literals.Literal], "uid to literals for a serialized BaseModel"] PythonType = TypeVar("PythonType") # target type of the deserialization @@ -35,51 +40,67 @@ class PydanticTransformerLiteralStore: On deserialization of a basemodel, flyte objects are deserialized from this object store. 
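    A rough sketch of the intended flow, where `flyte_obj` stands in for any supported
    Flyte type instance (e.g. a FlyteFile):

        # serialization: one store per basemodel
        store = PydanticTransformerLiteralStore.from_basemodel(model)
        placeholder = store.register_python_object(flyte_obj)
        literal_map = store.as_literalmap()

        # deserialization: attach to the transported literal map, then resolve placeholders
        with PydanticTransformerLiteralStore.attach_to_literalmap(literal_map):
            restored = PydanticTransformerLiteralStore.get_python_object(placeholder, type(flyte_obj))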
""" - _literal_store: Dict[LiteralObjID, literals.Literal] = {} + _literal_stores: Dict[ObjectStoreID, LiteralStore] = {} # for each basemodel, one literal store is kept + _attached_literalstore: Optional[LiteralStore] = None # when deserializing we attach to this with a context manager - def __setattr__(self, name: str, value: Any) -> None: - raise Exception("Attributes should not be set on the FlyteObjectStore.") + ### When serializing, we instantiate a new literal store for each basemodel + def __init__(self, uid: ObjectStoreID) -> None: + self._basemodel_uid = uid - def __init__(self) -> None: - raise Exception("This should not be instantiated, it is a singleton object store.") + def as_literalmap(self) -> literals.LiteralMap: + """ + Converts the object store to a literal map + """ + return literals.LiteralMap(literals=self._get_literal_store()) - def __contains__(self, item: LiteralObjID) -> bool: - return item in self.get_literal_store() + def register_python_object(self, python_object: object) -> LiteralObjID: + """Serialize to literal and return a unique identifier.""" + serialized_item = serialize_to_flyte_literal(python_object) + identifier = make_identifier_for_serializeable(python_object) + self._get_literal_store()[identifier] = serialized_item + return identifier @classmethod - def get_literal_store(cls): - """Access the literal store""" - return cls._literal_store + def from_basemodel(cls, basemodel: pydantic.BaseModel) -> "PydanticTransformerLiteralStore": + """Attach to a BaseModel to write to the literal store""" + internal_basemodel_uid = cls._basemodel_uid = make_identifier_for_basemodel(basemodel) + assert internal_basemodel_uid not in cls._literal_stores, "Every serialization must have unique basemodel uid" + cls._literal_stores[internal_basemodel_uid] = {} + return cls(internal_basemodel_uid) + + ## When deserializing, we attach to the literal store @classmethod - def register_python_object(cls, python_object: object) -> LiteralObjID: - """Serialize to literal and return a unique identifier.""" - serialized_item = serialize_to_flyte_literal(python_object) - identifier = make_identifier(python_object) - cls.get_literal_store()[identifier] = serialized_item - return identifier + @contextlib.contextmanager + def attach_to_literalmap(cls, literal_map: literals.LiteralMap) -> Generator[None, None, None]: + """ + Read a literal map and populate the object store from it + """ + # TODO make thread safe? 
+ try: + cls._attached_literalstore = literal_map.literals + yield + finally: + cls._attached_literalstore = None + + @classmethod + def is_attached(cls) -> bool: + return cls._attached_literalstore is not None @classmethod def get_python_object(cls, identifier: LiteralObjID, expected_type: Type[PythonType]) -> Optional[PythonType]: """Deserialize a literal and return the python object""" - literal = cls.get_literal_store()[identifier] + literal = cls._get_literal_store()[identifier] python_object = deserialize_flyte_literal(literal, expected_type) return python_object + ## Private methods @classmethod - def as_literalmap(cls) -> literals.LiteralMap: - """ - Converts the object store to a literal map - """ - return literals.LiteralMap(literals=cls.get_literal_store()) - - @classmethod - def read_literalmap(cls, literal_map: literals.LiteralMap) -> None: - """ - Read a literal map and populate the object store from it - """ - literal_store = cls.get_literal_store() - literal_store.update(literal_map.literals) + def _get_literal_store(cls) -> LiteralStore: + if cls.is_attached(): + return cls._attached_literalstore # type: ignore there is always a literal store when attached + else: + return cls._literal_stores[cls._basemodel_uid] def deserialize_flyte_literal( @@ -103,10 +124,17 @@ def serialize_to_flyte_literal(python_obj) -> literals.Literal: return literal_obj -def make_identifier(python_type: object) -> LiteralObjID: +def make_identifier_for_basemodel(basemodel: pydantic.BaseModel) -> ObjectStoreID: + """ + Create a unique identifier for a basemodel. + """ + unique_id = f"{basemodel.__class__.__name__}_{uuid.uuid4().hex}" + return cast(ObjectStoreID, unique_id) + + +def make_identifier_for_serializeable(python_type: object) -> LiteralObjID: """ Create a unique identifier for a python object. """ - # TODO - human readable way to identify the object unique_id = f"{type(python_type).__name__}_{uuid.uuid4().hex}" return cast(LiteralObjID, unique_id) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py index 45eefa80ee..c23993d0f4 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py @@ -20,8 +20,8 @@ from . import object_store -BASEMODEL_KEY = cast(object_store.LiteralObjID, "BaseModel") -FLYTETYPES_KEY = cast(object_store.LiteralObjID, "Flyte Object Store") +BASEMODEL_JSON_KEY = "BaseModel JSON" +FLYTETYPE_OBJSTORE__KEY = "Flyte Object Store" SerializedBaseModel = Annotated[str, "A pydantic BaseModel that has been serialized with placeholders for Flyte types."] @@ -32,28 +32,41 @@ def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.LiteralMap: The BaseModel is first serialized into a JSON format, where all Flyte types are replaced with unique placeholder strings. 
The Flyte Types are serialized into separate Flyte literals """ - - basemodel_json = serialize_basemodel_to_json_and_store(basemodel) - basemodel_literal = make_literal_from_json(basemodel_json) + store = object_store.PydanticTransformerLiteralStore.from_basemodel(basemodel) + basemodel_literal = serialize_basemodel_to_literal(basemodel, store) + store_literal = store.as_literalmap() return literals.LiteralMap( { - BASEMODEL_KEY: basemodel_literal, # json with flyte types replaced with placeholders - FLYTETYPES_KEY: object_store.PydanticTransformerLiteralStore.as_literalmap(), # placeholders mapped to flyte types + BASEMODEL_JSON_KEY: basemodel_literal, # json with flyte types replaced with placeholders + FLYTETYPE_OBJSTORE__KEY: store_literal, # placeholders mapped to flyte types } ) +def serialize_basemodel_to_literal( + basemodel: pydantic.BaseModel, + flyteobject_store: object_store.PydanticTransformerLiteralStore, +) -> literals.Literal: + """ """ + basemodel_json = serialize_basemodel_to_json_and_store(basemodel, flyteobject_store) + basemodel_literal = make_literal_from_json(basemodel_json) + return basemodel_literal + + def make_literal_from_json(json: str) -> literals.Literal: """ Converts the json representation of a pydantic BaseModel to a Flyte Literal. """ # serialize as a string literal ctx = context_manager.FlyteContext.current_context() - string_transformer = type_engine.TypeEngine.get_transformer(str) + string_transformer = type_engine.TypeEngine.get_transformer(json) return string_transformer.to_literal(ctx, json, str, string_transformer.get_literal_type(str)) -def serialize_basemodel_to_json_and_store(basemodel: pydantic.BaseModel) -> SerializedBaseModel: +def serialize_basemodel_to_json_and_store( + basemodel: pydantic.BaseModel, + flyteobject_store: object_store.PydanticTransformerLiteralStore, +) -> SerializedBaseModel: """ Serialize a pydantic BaseModel to json and protobuf, separating out the Flyte types into a separate store. On deserialization, the store is used to reconstruct the Flyte types. 
@@ -61,7 +74,7 @@ def serialize_basemodel_to_json_and_store(basemodel: pydantic.BaseModel) -> Seri def encoder(obj: Any) -> Union[str, object_store.LiteralObjID]: if isinstance(obj, object_store.PYDANTIC_SUPPORTED_FLYTE_TYPES): - return object_store.PydanticTransformerLiteralStore.register_python_object(obj) + return flyteobject_store.register_python_object(obj) return basemodel.__json_encoder__(obj) return basemodel.json(encoder=encoder) diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 47de3b083d..4786da26e0 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -7,7 +7,6 @@ from pydantic import BaseModel, Extra import flytekit -from flytekit.core import type_engine from flytekit.types import directory from flytekit.types.file import file @@ -52,6 +51,9 @@ class ConfigWithPandasDataFrame(BaseModel): df: pd.DataFrame + class Config: + arbitrary_types_allowed = True + class ChildConfig(Config): """Child class config BaseModel for testing purposes.""" @@ -165,6 +167,22 @@ def wf(cfg: ConfigWithFlyteDirs) -> List[str]: assert len(dirs) == 2 # type: ignore +def test_double_config_in_wf(): + """Test passing a BaseModel instance to a workflow works.""" + cfg1 = TrainConfig(batch_size=13) + cfg2 = TrainConfig(batch_size=31) + + @flytekit.task + def are_different(cfg1: TrainConfig, cfg2: TrainConfig) -> bool: + return cfg1 != cfg2 + + @flytekit.workflow + def wf(cfg1: TrainConfig, cfg2: TrainConfig) -> bool: + return are_different(cfg1=cfg1, cfg2=cfg2) # type: ignore + + assert wf(cfg1=cfg1, cfg2=cfg2), wf(cfg1=cfg1, cfg2=cfg2) # type: ignore + + # TODO: //Arthur to Fabio this was differente before but now im unsure what the test is doing # previously a pattern match error was checked that its raised, but isnt it OK that the ChildConfig # is passed since its a subclass of Config? From 68a28440a9fcfc7dbe3749f0b04de68b20cdb66b Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 15 Jun 2023 16:13:04 -0700 Subject: [PATCH 35/55] refactored --- .../pydantic/basemodel_transformer.py | 8 +- .../flytekitplugins/pydantic/commons.py | 27 ++++ .../pydantic/deserialization.py | 74 ++++++++- .../flytekitplugins/pydantic/object_store.py | 140 ------------------ .../flytekitplugins/pydantic/serialization.py | 62 +++++++- 5 files changed, 154 insertions(+), 157 deletions(-) create mode 100644 plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py delete mode 100644 plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index c3d9eb0472..71b4b3fcd1 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -1,4 +1,4 @@ -from typing import Optional, Type +from typing import Type from typing_extensions import Annotated import pydantic @@ -6,7 +6,7 @@ from flytekit.core import type_engine from flytekit.models import literals, types -from . import object_store, serialization +from . 
import serialization, deserialization """ Serializes & deserializes the pydantic basemodels @@ -51,13 +51,13 @@ def to_python_value( """Re-hydrate the pydantic pydantic.BaseModel object from Flyte Literal value.""" basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(lv) flyte_obj_literalmap = lv.literals[serialization.FLYTETYPE_OBJSTORE__KEY] - with object_store.PydanticTransformerLiteralStore.attach_to_literalmap(flyte_obj_literalmap): + with deserialization.PydanticDeserializationLiteralStore.attach(flyte_obj_literalmap): return expected_python_type.parse_raw(basemodel_json_w_placeholders) def read_basemodel_json_from_literalmap(lv: BaseModelLiteralValue) -> serialization.SerializedBaseModel: basemodel_literal: literals.Literal = lv.literals[serialization.BASEMODEL_JSON_KEY] - basemodel_json_w_placeholders = object_store.deserialize_flyte_literal(basemodel_literal, str) + basemodel_json_w_placeholders = deserialization.deserialize_flyte_literal(basemodel_literal, str) assert isinstance(basemodel_json_w_placeholders, str) return basemodel_json_w_placeholders diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py new file mode 100644 index 0000000000..fcbde5ac05 --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py @@ -0,0 +1,27 @@ +import builtins +import datetime +import typing + +from typing_extensions import Annotated + + +from flytekit.core import type_engine + +MODULES_TO_EXLCLUDE_FROM_FLYTE_TYPES = {m.__name__ for m in [builtins, typing, datetime]} + + +def include_in_flyte_types(t: type) -> bool: + if t is None: + return False + if t.__module__ in MODULES_TO_EXLCLUDE_FROM_FLYTE_TYPES: + return False + return True + + +PYDANTIC_SUPPORTED_FLYTE_TYPES = tuple( + filter(include_in_flyte_types, type_engine.TypeEngine.get_available_transformers()) +) + +# this is the UUID placeholder that is set in the serialized basemodel JSON, connecting that field to +# the literal map that holds the actual object that needs to be deserialized (w/ protobuf) +LiteralObjID = Annotated[str, "Key for unique object in literal map."] diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py index 843e6f598a..730bb70806 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -1,10 +1,13 @@ -from typing import Any, Callable, Dict, Iterator, List, Type, TypeVar, Union +import contextlib +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Type, TypeVar, Union import pydantic +from flytekit.core import context_manager, type_engine +from flytekit.models import literals from flytekit.types import directory, file -from flytekitplugins.pydantic import object_store +from flytekitplugins.pydantic import commons, serialization # this field is used by pydantic to get the validator method PYDANTIC_VALIDATOR_METHOD_NAME = pydantic.BaseModel.__get_validators__.__name__ @@ -12,11 +15,60 @@ Serializable = TypeVar("Serializable") # flyte object type +class PydanticDeserializationLiteralStore: + """ + The purpose of this class is to provide a context manager that can be used to deserialize a basemodel from a + literal map. 
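    Typical use mirrors the transformer's `to_python_value`, roughly (names are placeholders):

        with PydanticDeserializationLiteralStore.attach(flyte_obj_literalmap):
            model = MyModel.parse_raw(basemodel_json_with_placeholders)

    While attached, the validators installed on the supported Flyte types resolve
    placeholder ids through this store.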
+ + Because pydantic validators are grabbed when subclassing a BaseModel, this object is a singleton that + serves as a namesspace that can be set with the attach_to_literalmap context manager for the time that + a basemode is being deserialized. The validators are then accessing this namespace for the flyteobj + placeholders that it is trying to deserialize. + """ + + literal_store: Optional[serialization.LiteralStore] = None # attachement point for the literal map + + def __init__(self) -> None: + raise Exception("This class should not be instantiated") + + def __init_subclass__(cls) -> None: + raise Exception("This class should not be subclassed") + + @classmethod + @contextlib.contextmanager + def attach(cls, literal_map: literals.LiteralMap) -> Generator[None, None, None]: + """ + Read a literal map and populate the object store from it + """ + # TODO make thread safe? + try: + cls.literal_store = literal_map.literals + yield + finally: + cls.literal_store = None + + @classmethod + def is_attached(cls) -> bool: + return cls.literal_store is not None + + @classmethod + def get_python_object( + cls, identifier: commons.LiteralObjID, expected_type: Type[PythonType] + ) -> Optional[PythonType]: + """Deserialize a literal and return the python object""" + if not cls.is_attached(): + raise Exception("Must attach to a literal map before deserializing") + assert cls.literal_store is not None + literal = cls.literal_store[identifier] + python_object = deserialize_flyte_literal(literal, expected_type) + return python_object + + def set_validators_on_supported_flyte_types() -> None: """ Set validator on the pydantic model for the type that is being (de-)serialized """ - for flyte_type in object_store.PYDANTIC_SUPPORTED_FLYTE_TYPES: + for flyte_type in commons.PYDANTIC_SUPPORTED_FLYTE_TYPES: setattr(flyte_type, PYDANTIC_VALIDATOR_METHOD_NAME, make_validators_for_type(flyte_type)) @@ -27,13 +79,13 @@ def make_validators_for_type( Returns a validator that can be used by pydantic to deserialize the object """ - def validator(object_uid_maybe: Union[object_store.LiteralObjID, Any]) -> Union[Serializable, Any]: + def validator(object_uid_maybe: Union[commons.LiteralObjID, Any]) -> Union[Serializable, Any]: """partial of deserialize_flyte_literal with the object_type fixed""" if not isinstance(object_uid_maybe, str): return object_uid_maybe # this validator should only trigger for the placholders - if not object_store.PydanticTransformerLiteralStore.is_attached(): + if not PydanticDeserializationLiteralStore.is_attached(): return object_uid_maybe - return object_store.PydanticTransformerLiteralStore.get_python_object(object_uid_maybe, flyte_obj_type) + return PydanticDeserializationLiteralStore.get_python_object(object_uid_maybe, flyte_obj_type) def validator_generator(*args, **kwags) -> Iterator[Callable[[Any], Serializable]]: """Generator that returns the validator""" @@ -67,3 +119,13 @@ def validate_flytedir(flytedir: Union[str, directory.FlyteDirectory]) -> directo file.FlyteFile: [validate_flytefile], directory.FlyteDirectory: [validate_flytedir], } + + +def deserialize_flyte_literal( + flyteobj_literal: literals.Literal, python_type: Type[PythonType] +) -> Optional[PythonType]: + """Deserialize a Flyte Literal into the python object instance.""" + ctx = context_manager.FlyteContext.current_context() + transformer = type_engine.TypeEngine.get_transformer(python_type) + python_obj = transformer.to_python_value(ctx, flyteobj_literal, python_type) + return python_obj diff --git 
a/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py deleted file mode 100644 index 67395e5166..0000000000 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/object_store.py +++ /dev/null @@ -1,140 +0,0 @@ -import builtins -import contextlib -import datetime -import typing -import uuid -from typing import Any, Dict, Generator, Optional, Type, TypeVar, cast - -from typing_extensions import Annotated - -import pydantic - -from flytekit.core import context_manager, type_engine -from flytekit.models import literals - -MODULES_TO_EXLCLUDE_FROM_FLYTE_TYPES = {m.__name__ for m in [builtins, typing, datetime]} - - -def include_in_flyte_types(t: type) -> bool: - if t is None: - return False - if t.__module__ in MODULES_TO_EXLCLUDE_FROM_FLYTE_TYPES: - return False - return True - - -PYDANTIC_SUPPORTED_FLYTE_TYPES = tuple( - filter(include_in_flyte_types, type_engine.TypeEngine.get_available_transformers()) -) -ObjectStoreID = Annotated[str, "Key for unique literalmap of a serialized basemodel."] -LiteralObjID = Annotated[str, "Key for unique object in literal map."] -LiteralStore = Annotated[Dict[LiteralObjID, literals.Literal], "uid to literals for a serialized BaseModel"] -PythonType = TypeVar("PythonType") # target type of the deserialization - - -class PydanticTransformerLiteralStore: - """ - This class is an intermediate store for python objects that are being serialized/deserialized. - - On serialization of a basemodel, flyte objects are serialized and stored in this object store. - On deserialization of a basemodel, flyte objects are deserialized from this object store. - """ - - _literal_stores: Dict[ObjectStoreID, LiteralStore] = {} # for each basemodel, one literal store is kept - _attached_literalstore: Optional[LiteralStore] = None # when deserializing we attach to this with a context manager - - ### When serializing, we instantiate a new literal store for each basemodel - def __init__(self, uid: ObjectStoreID) -> None: - self._basemodel_uid = uid - - def as_literalmap(self) -> literals.LiteralMap: - """ - Converts the object store to a literal map - """ - return literals.LiteralMap(literals=self._get_literal_store()) - - def register_python_object(self, python_object: object) -> LiteralObjID: - """Serialize to literal and return a unique identifier.""" - serialized_item = serialize_to_flyte_literal(python_object) - identifier = make_identifier_for_serializeable(python_object) - self._get_literal_store()[identifier] = serialized_item - return identifier - - @classmethod - def from_basemodel(cls, basemodel: pydantic.BaseModel) -> "PydanticTransformerLiteralStore": - """Attach to a BaseModel to write to the literal store""" - internal_basemodel_uid = cls._basemodel_uid = make_identifier_for_basemodel(basemodel) - assert internal_basemodel_uid not in cls._literal_stores, "Every serialization must have unique basemodel uid" - cls._literal_stores[internal_basemodel_uid] = {} - return cls(internal_basemodel_uid) - - ## When deserializing, we attach to the literal store - - @classmethod - @contextlib.contextmanager - def attach_to_literalmap(cls, literal_map: literals.LiteralMap) -> Generator[None, None, None]: - """ - Read a literal map and populate the object store from it - """ - # TODO make thread safe? 
- try: - cls._attached_literalstore = literal_map.literals - yield - finally: - cls._attached_literalstore = None - - @classmethod - def is_attached(cls) -> bool: - return cls._attached_literalstore is not None - - @classmethod - def get_python_object(cls, identifier: LiteralObjID, expected_type: Type[PythonType]) -> Optional[PythonType]: - """Deserialize a literal and return the python object""" - literal = cls._get_literal_store()[identifier] - python_object = deserialize_flyte_literal(literal, expected_type) - return python_object - - ## Private methods - @classmethod - def _get_literal_store(cls) -> LiteralStore: - if cls.is_attached(): - return cls._attached_literalstore # type: ignore there is always a literal store when attached - else: - return cls._literal_stores[cls._basemodel_uid] - - -def deserialize_flyte_literal( - flyteobj_literal: literals.Literal, python_type: Type[PythonType] -) -> Optional[PythonType]: - """Deserialize a Flyte Literal into the python object instance.""" - ctx = context_manager.FlyteContext.current_context() - transformer = type_engine.TypeEngine.get_transformer(python_type) - python_obj = transformer.to_python_value(ctx, flyteobj_literal, python_type) - return python_obj - - -def serialize_to_flyte_literal(python_obj) -> literals.Literal: - """ - Use the Flyte TypeEngine to serialize a python object to a Flyte Literal. - """ - python_type = type(python_obj) - ctx = context_manager.FlyteContextManager().current_context() - literal_type = type_engine.TypeEngine.to_literal_type(python_type) - literal_obj = type_engine.TypeEngine.to_literal(ctx, python_obj, python_type, literal_type) - return literal_obj - - -def make_identifier_for_basemodel(basemodel: pydantic.BaseModel) -> ObjectStoreID: - """ - Create a unique identifier for a basemodel. - """ - unique_id = f"{basemodel.__class__.__name__}_{uuid.uuid4().hex}" - return cast(ObjectStoreID, unique_id) - - -def make_identifier_for_serializeable(python_type: object) -> LiteralObjID: - """ - Create a unique identifier for a python object. - """ - unique_id = f"{type(python_type).__name__}_{uuid.uuid4().hex}" - return cast(LiteralObjID, unique_id) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py index c23993d0f4..8ff8e27a75 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py @@ -8,8 +8,9 @@ 3. Return a literal map with the json and the flyte object store represented as a literalmap {placeholder: flyte type} """ -from typing import Any, NamedTuple, Union, cast +from typing import Any, Dict, NamedTuple, Union, cast from typing_extensions import Annotated +import uuid import pydantic from google.protobuf import struct_pb2 @@ -17,7 +18,7 @@ from flytekit.models import literals from flytekit.core import context_manager, type_engine -from . import object_store +from . 
import commons BASEMODEL_JSON_KEY = "BaseModel JSON" @@ -25,6 +26,34 @@ SerializedBaseModel = Annotated[str, "A pydantic BaseModel that has been serialized with placeholders for Flyte types."] +ObjectStoreID = Annotated[str, "Key for unique literalmap of a serialized basemodel."] +LiteralObjID = Annotated[str, "Key for unique object in literal map."] +LiteralStore = Annotated[Dict[LiteralObjID, literals.Literal], "uid to literals for a serialized BaseModel"] + + +class BaseModelFlyteObjectStore: + """ + This class is an intermediate store for python objects that are being serialized/deserialized. + + On serialization of a basemodel, flyte objects are serialized and stored in this object store. + """ + + def __init__(self) -> None: + self.literal_store: LiteralStore = dict() + + def register_python_object(self, python_object: object) -> LiteralObjID: + """Serialize to literal and return a unique identifier.""" + serialized_item = serialize_to_flyte_literal(python_object) + identifier = make_identifier_for_serializeable(python_object) + self.literal_store[identifier] = serialized_item + return identifier + + def as_literalmap(self) -> literals.LiteralMap: + """ + Converts the object store to a literal map + """ + return literals.LiteralMap(literals=self.literal_store) + def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.LiteralMap: """ @@ -32,7 +61,7 @@ def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.LiteralMap: The BaseModel is first serialized into a JSON format, where all Flyte types are replaced with unique placeholder strings. The Flyte Types are serialized into separate Flyte literals """ - store = object_store.PydanticTransformerLiteralStore.from_basemodel(basemodel) + store = BaseModelFlyteObjectStore() basemodel_literal = serialize_basemodel_to_literal(basemodel, store) store_literal = store.as_literalmap() return literals.LiteralMap( @@ -45,7 +74,7 @@ def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.LiteralMap: def serialize_basemodel_to_literal( basemodel: pydantic.BaseModel, - flyteobject_store: object_store.PydanticTransformerLiteralStore, + flyteobject_store: BaseModelFlyteObjectStore, ) -> literals.Literal: """ """ basemodel_json = serialize_basemodel_to_json_and_store(basemodel, flyteobject_store) @@ -65,16 +94,35 @@ def make_literal_from_json(json: str) -> literals.Literal: def serialize_basemodel_to_json_and_store( basemodel: pydantic.BaseModel, - flyteobject_store: object_store.PydanticTransformerLiteralStore, + flyteobject_store: BaseModelFlyteObjectStore, ) -> SerializedBaseModel: """ Serialize a pydantic BaseModel to json and protobuf, separating out the Flyte types into a separate store. On deserialization, the store is used to reconstruct the Flyte types. """ - def encoder(obj: Any) -> Union[str, object_store.LiteralObjID]: - if isinstance(obj, object_store.PYDANTIC_SUPPORTED_FLYTE_TYPES): + def encoder(obj: Any) -> Union[str, commons.LiteralObjID]: + if isinstance(obj, commons.PYDANTIC_SUPPORTED_FLYTE_TYPES): return flyteobject_store.register_python_object(obj) return basemodel.__json_encoder__(obj) return basemodel.json(encoder=encoder) + + +def serialize_to_flyte_literal(python_obj: object) -> literals.Literal: + """ + Use the Flyte TypeEngine to serialize a python object to a Flyte Literal. 
+ """ + python_type = type(python_obj) + ctx = context_manager.FlyteContextManager().current_context() + literal_type = type_engine.TypeEngine.to_literal_type(python_type) + literal_obj = type_engine.TypeEngine.to_literal(ctx, python_obj, python_type, literal_type) + return literal_obj + + +def make_identifier_for_serializeable(python_type: object) -> LiteralObjID: + """ + Create a unique identifier for a python object. + """ + unique_id = f"{type(python_type).__name__}_{uuid.uuid4().hex}" + return cast(LiteralObjID, unique_id) From 1466ff86449e8b9486b21d573865ed133c25a5c7 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 15 Jun 2023 17:05:17 -0700 Subject: [PATCH 36/55] accomodate case where user has set validators on the type themselves --- .../flytekitplugins/pydantic/deserialization.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py index 730bb70806..32c25aa71e 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -79,6 +79,8 @@ def make_validators_for_type( Returns a validator that can be used by pydantic to deserialize the object """ + previous_validators = getattr(flyte_obj_type, PYDANTIC_VALIDATOR_METHOD_NAME, lambda *_: [])() + def validator(object_uid_maybe: Union[commons.LiteralObjID, Any]) -> Union[Serializable, Any]: """partial of deserialize_flyte_literal with the object_type fixed""" if not isinstance(object_uid_maybe, str): @@ -90,6 +92,7 @@ def validator(object_uid_maybe: Union[commons.LiteralObjID, Any]) -> Union[Seria def validator_generator(*args, **kwags) -> Iterator[Callable[[Any], Serializable]]: """Generator that returns the validator""" yield validator + yield from previous_validators yield from additional_flytetype_validators.get(flyte_obj_type, []) return validator_generator From 4c022decb7cb9120050de3ab7ea281dd722b1138 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arthur=20B=C3=B6=C3=B6k?= <49250723+ArthurBook@users.noreply.github.com> Date: Tue, 20 Jun 2023 08:56:22 -0700 Subject: [PATCH 37/55] Update plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py Co-authored-by: Fabio M. Graetz, Ph.D. --- .../flytekitplugins/pydantic/basemodel_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index 71b4b3fcd1..66e89b3e3d 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -39,7 +39,7 @@ def to_literal( python_type: Type[pydantic.BaseModel], expected: types.LiteralType, ) -> BaseModelLiteralValue: - """This method is used to convert from given python type object pydantic ``pydantic.BaseModel`` to the Literal representation.""" + """Convert a given ``pydantic.BaseModel`` to the Literal representation.""" return serialization.serialize_basemodel(python_val) def to_python_value( From 96ef7e4284e7895042b76d3f80f8148330fd9a7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arthur=20B=C3=B6=C3=B6k?= <49250723+ArthurBook@users.noreply.github.com> Date: Tue, 20 Jun 2023 08:56:30 -0700 Subject: [PATCH 38/55] Update plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py Co-authored-by: Fabio M. 
Graetz, Ph.D. --- .../flytekitplugins/pydantic/basemodel_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index 66e89b3e3d..4c00a9e238 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -48,7 +48,7 @@ def to_python_value( lv: BaseModelLiteralValue, expected_python_type: Type[pydantic.BaseModel], ) -> pydantic.BaseModel: - """Re-hydrate the pydantic pydantic.BaseModel object from Flyte Literal value.""" + """Re-hydrate the pydantic BaseModel object from Flyte Literal value.""" basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(lv) flyte_obj_literalmap = lv.literals[serialization.FLYTETYPE_OBJSTORE__KEY] with deserialization.PydanticDeserializationLiteralStore.attach(flyte_obj_literalmap): From a477d908356f7cb07e0bb136657077fd0696a20d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arthur=20B=C3=B6=C3=B6k?= <49250723+ArthurBook@users.noreply.github.com> Date: Tue, 20 Jun 2023 08:56:59 -0700 Subject: [PATCH 39/55] Update plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py Co-authored-by: Fabio M. Graetz, Ph.D. --- .../flytekitplugins/pydantic/deserialization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py index 32c25aa71e..b945155106 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -9,6 +9,7 @@ from flytekitplugins.pydantic import commons, serialization + # this field is used by pydantic to get the validator method PYDANTIC_VALIDATOR_METHOD_NAME = pydantic.BaseModel.__get_validators__.__name__ PythonType = TypeVar("PythonType") # target type of the deserialization From 056b069417765521a6324a22c380396ccaf51e00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arthur=20B=C3=B6=C3=B6k?= <49250723+ArthurBook@users.noreply.github.com> Date: Tue, 20 Jun 2023 08:57:19 -0700 Subject: [PATCH 40/55] Update plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py Co-authored-by: Fabio M. Graetz, Ph.D. --- .../flytekitplugins/pydantic/deserialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py index b945155106..5e9a5ccbbb 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -21,7 +21,7 @@ class PydanticDeserializationLiteralStore: The purpose of this class is to provide a context manager that can be used to deserialize a basemodel from a literal map. - Because pydantic validators are grabbed when subclassing a BaseModel, this object is a singleton that + Because pydantic validators are fixed when subclassing a BaseModel, this object is a singleton that serves as a namesspace that can be set with the attach_to_literalmap context manager for the time that a basemode is being deserialized. The validators are then accessing this namespace for the flyteobj placeholders that it is trying to deserialize. 
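[Editor's note on patches 36-40, before the next patch: taken together, these changes pin down the deserialization contract of the plugin. Every supported Flyte type gets a __get_validators__ hook that swaps known placeholder strings for objects held in the attached literal store, while still yielding any validators the user had already defined on the type. The snippet below is a minimal, self-contained sketch of that pattern against the pydantic v1 API; MyAsset and FAKE_STORE are invented stand-ins for a Flyte type and the attached literal store, and are not code from this patch series.]

from typing import Any, Dict

import pydantic  # sketch assumes pydantic < 2.0, i.e. the __get_validators__ protocol


class MyAsset:
    """Invented stand-in for a Flyte type (e.g. FlyteFile) that pydantic cannot validate natively."""

    def __init__(self, uri: str) -> None:
        self.uri = uri


FAKE_STORE: Dict[str, MyAsset] = {}  # plays the role of the attached literal store


def resolve_placeholder(value: Any) -> Any:
    # Only strings that are known placeholders get swapped for the stored object;
    # anything else passes through untouched, mirroring the safety checks in the plugin's validator.
    if isinstance(value, str) and value in FAKE_STORE:
        return FAKE_STORE[value]
    return value


def make_validator_generator():
    # The plugin yields its resolver first and then any pre-existing validators;
    # this sketch only has the resolver.
    def generator(*args, **kwargs):
        yield resolve_placeholder

    return generator


# Monkey-patch the hook, analogous to what set_validators_on_supported_flyte_types does for Flyte types.
setattr(MyAsset, "__get_validators__", make_validator_generator())


class DemoConfig(pydantic.BaseModel):
    asset: MyAsset


FAKE_STORE["MyAsset_1234"] = MyAsset("s3://bucket/data.csv")
cfg = DemoConfig.parse_obj({"asset": "MyAsset_1234"})  # placeholder string is re-hydrated
assert cfg.asset.uri == "s3://bucket/data.csv"

[The real plugin additionally chains previous_validators and per-type validators such as validate_flytefile, and resolves placeholders through PydanticDeserializationLiteralStore rather than a module-level dict.]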
From aa33e316b5e0a4e57a2ec5116ec81f504e72eaef Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 20 Jun 2023 09:50:03 -0700 Subject: [PATCH 41/55] fixed assumption that pydantic is installed in typecheck --- flytekit/clis/sdk_in_container/run.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index f67f74e671..2277b18027 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -10,7 +10,6 @@ from typing import cast import cloudpickle -import pydantic import rich_click as click import yaml from dataclasses_json import DataClassJsonMixin @@ -397,8 +396,8 @@ def convert_to_struct( Convert the loaded json object to a Flyte Literal struct type. """ if type(value) != self._python_type: - if issubclass(self._python_type, pydantic.BaseModel): - o = self._python_type.parse_raw(json.dumps(value)) + if is_pydantic_basemodel(self._python_type): + o = self._python_type.parse_raw(json.dumps(value)) # type: ignore else: o = cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(value)) else: @@ -447,6 +446,15 @@ def convert(self, ctx, param, value) -> typing.Union[Literal, typing.Any]: raise click.BadParameter(f"Failed to convert param {param}, {value} to {self._python_type}") from e +def is_pydantic_basemodel(python_type: typing.Type) -> bool: + try: + import pydantic + except ImportError: + return False + else: + return issubclass(python_type, pydantic.BaseModel) + + def to_click_option( ctx: click.Context, flyte_ctx: FlyteContext, From deb42bbb759588e00c81b78534037fe0ce322087 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 20 Jun 2023 11:40:58 -0700 Subject: [PATCH 42/55] comments revised --- .../pydantic/basemodel_transformer.py | 10 ++- .../pydantic/deserialization.py | 62 ++++++++++++------- .../flytekitplugins/pydantic/serialization.py | 47 ++++++-------- .../tests/test_type_transformer.py | 30 +++++++++ 4 files changed, 91 insertions(+), 58 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index 4c00a9e238..6324ba0fbe 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -1,3 +1,5 @@ +"""Serializes & deserializes the pydantic basemodels """ + from typing import Type from typing_extensions import Annotated import pydantic @@ -8,16 +10,12 @@ from . 
import serialization, deserialization -""" -Serializes & deserializes the pydantic basemodels -""" - BaseModelLiteralValue = Annotated[ literals.LiteralMap, """ BaseModel serialized to a LiteralMap consisting of: 1) the basemodel json with placeholders for flyte types - 2) a mapping from placeholders to flyte object store with the flyte types + 2) mapping from placeholders to serialized flyte type values in the object store """, ] @@ -50,7 +48,7 @@ def to_python_value( ) -> pydantic.BaseModel: """Re-hydrate the pydantic BaseModel object from Flyte Literal value.""" basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(lv) - flyte_obj_literalmap = lv.literals[serialization.FLYTETYPE_OBJSTORE__KEY] + flyte_obj_literalmap = lv.literals[serialization.FLYTETYPE_OBJSTORE_KEY] with deserialization.PydanticDeserializationLiteralStore.attach(flyte_obj_literalmap): return expected_python_type.parse_raw(basemodel_json_w_placeholders) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py index 5e9a5ccbbb..70a56add1a 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -11,9 +11,12 @@ # this field is used by pydantic to get the validator method -PYDANTIC_VALIDATOR_METHOD_NAME = pydantic.BaseModel.__get_validators__.__name__ +PYDANTIC_VALIDATOR_METHOD_NAME = ( + pydantic.BaseModel.__get_validators__.__name__ + if pydantic.__version__ < "2.0.0" + else pydantic.BaseModel.__get_pydantic_core_schema__.__name___ # type: ignore +) PythonType = TypeVar("PythonType") # target type of the deserialization -Serializable = TypeVar("Serializable") # flyte object type class PydanticDeserializationLiteralStore: @@ -22,8 +25,8 @@ class PydanticDeserializationLiteralStore: literal map. Because pydantic validators are fixed when subclassing a BaseModel, this object is a singleton that - serves as a namesspace that can be set with the attach_to_literalmap context manager for the time that - a basemode is being deserialized. The validators are then accessing this namespace for the flyteobj + serves as a namespace that can be set with the attach_to_literalmap context manager for the time that + a basemodel is being deserialized. The validators are then accessing this namespace for the flyteobj placeholders that it is trying to deserialize. """ @@ -39,15 +42,24 @@ def __init_subclass__(cls) -> None: @contextlib.contextmanager def attach(cls, literal_map: literals.LiteralMap) -> Generator[None, None, None]: """ - Read a literal map and populate the object store from it + Read a literal map and populate the object store from it. + + This can be used as a context manager to attach to a literal map for the duration of a deserialization + Note that this is not threadsafe, and designed to manage a single deserialization at a time. """ - # TODO make thread safe? + assert not cls.is_attached(), "can only be attached to one literal map at a time." 
try: cls.literal_store = literal_map.literals yield finally: cls.literal_store = None + @classmethod + def contains(cls, item: commons.LiteralObjID) -> bool: + assert cls.is_attached(), "can only check for existence of a literal when attached to a literal map" + assert cls.literal_store is not None + return item in cls.literal_store + @classmethod def is_attached(cls) -> bool: return cls.literal_store is not None @@ -56,7 +68,7 @@ def is_attached(cls) -> bool: def get_python_object( cls, identifier: commons.LiteralObjID, expected_type: Type[PythonType] ) -> Optional[PythonType]: - """Deserialize a literal and return the python object""" + """Deserialize a flyte literal and return the python object.""" if not cls.is_attached(): raise Exception("Must attach to a literal map before deserializing") assert cls.literal_store is not None @@ -67,40 +79,42 @@ def get_python_object( def set_validators_on_supported_flyte_types() -> None: """ - Set validator on the pydantic model for the type that is being (de-)serialized + Set pydantic validator for the flyte types supported by this plugin. """ for flyte_type in commons.PYDANTIC_SUPPORTED_FLYTE_TYPES: - setattr(flyte_type, PYDANTIC_VALIDATOR_METHOD_NAME, make_validators_for_type(flyte_type)) + setattr(flyte_type, PYDANTIC_VALIDATOR_METHOD_NAME, add_flyte_validators_for_type(flyte_type)) -def make_validators_for_type( - flyte_obj_type: Type[Serializable], -) -> Callable[[Any], Iterator[Callable[[Any], Serializable]]]: +def add_flyte_validators_for_type( + flyte_obj_type: Type[type_engine.T], +) -> Callable[[Any], Iterator[Callable[[Any], type_engine.T]]]: """ - Returns a validator that can be used by pydantic to deserialize the object + Add flyte deserialisation validators to a type. """ previous_validators = getattr(flyte_obj_type, PYDANTIC_VALIDATOR_METHOD_NAME, lambda *_: [])() - def validator(object_uid_maybe: Union[commons.LiteralObjID, Any]) -> Union[Serializable, Any]: - """partial of deserialize_flyte_literal with the object_type fixed""" - if not isinstance(object_uid_maybe, str): - return object_uid_maybe # this validator should only trigger for the placholders + def validator(object_uid_maybe: Union[commons.LiteralObjID, Any]) -> Union[type_engine.T, Any]: + """Partial of deserialize_flyte_literal with the object_type fixed""" if not PydanticDeserializationLiteralStore.is_attached(): - return object_uid_maybe + return object_uid_maybe # this validator should only trigger when we are deserializeing + if not isinstance(object_uid_maybe, str): + return object_uid_maybe # object uids are strings and we dont want to trigger on other types + if not PydanticDeserializationLiteralStore.contains(object_uid_maybe): + return object_uid_maybe # final safety check to make sure that the object uid is in the literal map return PydanticDeserializationLiteralStore.get_python_object(object_uid_maybe, flyte_obj_type) - def validator_generator(*args, **kwags) -> Iterator[Callable[[Any], Serializable]]: - """Generator that returns the validator""" + def validator_generator(*args, **kwags) -> Iterator[Callable[[Any], type_engine.T]]: + """Generator that returns validators.""" yield validator yield from previous_validators - yield from additional_flytetype_validators.get(flyte_obj_type, []) + yield from ADDITIONAL_FLYTETYPE_VALIDATORS.get(flyte_obj_type, []) return validator_generator def validate_flytefile(flytefile: Union[str, file.FlyteFile]) -> file.FlyteFile: - """validator for flytefile (i.e. deserializer)""" + """Validate a flytefile (i.e. 
deserialize).""" if isinstance(flytefile, file.FlyteFile): return flytefile if isinstance(flytefile, str): # when e.g. initializing from config @@ -110,7 +124,7 @@ def validate_flytefile(flytefile: Union[str, file.FlyteFile]) -> file.FlyteFile: def validate_flytedir(flytedir: Union[str, directory.FlyteDirectory]) -> directory.FlyteDirectory: - """validator for flytedir (i.e. deserializer)""" + """Validate a flytedir (i.e. deserialize).""" if isinstance(flytedir, directory.FlyteDirectory): return flytedir if isinstance(flytedir, str): # when e.g. initializing from config @@ -119,7 +133,7 @@ def validate_flytedir(flytedir: Union[str, directory.FlyteDirectory]) -> directo raise ValueError(f"Invalid type for flytedir: {type(flytedir)}") -additional_flytetype_validators: Dict[Type, List[Callable[[Any], Any]]] = { +ADDITIONAL_FLYTETYPE_VALIDATORS: Dict[Type, List[Callable[[Any], Any]]] = { file.FlyteFile: [validate_flytefile], directory.FlyteDirectory: [validate_flytedir], } diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py index 8ff8e27a75..631e84d717 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py @@ -1,5 +1,5 @@ """ -Logic for serializing a basemodel to a literalmap that can be passed between container +Logic for serializing a basemodel to a literalmap that can be passed between flyte tasks. The serialization process is as follows: @@ -8,21 +8,21 @@ 3. Return a literal map with the json and the flyte object store represented as a literalmap {placeholder: flyte type} """ -from typing import Any, Dict, NamedTuple, Union, cast -from typing_extensions import Annotated import uuid +from typing import Any, Dict, Union, cast import pydantic from google.protobuf import struct_pb2 +from typing_extensions import Annotated -from flytekit.models import literals from flytekit.core import context_manager, type_engine +from flytekit.models import literals from . import commons - BASEMODEL_JSON_KEY = "BaseModel JSON" -FLYTETYPE_OBJSTORE__KEY = "Flyte Object Store" +FLYTETYPE_OBJSTORE_KEY = "Flyte Object Store" + SerializedBaseModel = Annotated[str, "A pydantic BaseModel that has been serialized with placeholders for Flyte types."] @@ -67,7 +67,7 @@ def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.LiteralMap: return literals.LiteralMap( { BASEMODEL_JSON_KEY: basemodel_literal, # json with flyte types replaced with placeholders - FLYTETYPE_OBJSTORE__KEY: store_literal, # placeholders mapped to flyte types + FLYTETYPE_OBJSTORE_KEY: store_literal, # placeholders mapped to flyte types } ) @@ -76,26 +76,6 @@ def serialize_basemodel_to_literal( basemodel: pydantic.BaseModel, flyteobject_store: BaseModelFlyteObjectStore, ) -> literals.Literal: - """ """ - basemodel_json = serialize_basemodel_to_json_and_store(basemodel, flyteobject_store) - basemodel_literal = make_literal_from_json(basemodel_json) - return basemodel_literal - - -def make_literal_from_json(json: str) -> literals.Literal: - """ - Converts the json representation of a pydantic BaseModel to a Flyte Literal. 
- """ - # serialize as a string literal - ctx = context_manager.FlyteContext.current_context() - string_transformer = type_engine.TypeEngine.get_transformer(json) - return string_transformer.to_literal(ctx, json, str, string_transformer.get_literal_type(str)) - - -def serialize_basemodel_to_json_and_store( - basemodel: pydantic.BaseModel, - flyteobject_store: BaseModelFlyteObjectStore, -) -> SerializedBaseModel: """ Serialize a pydantic BaseModel to json and protobuf, separating out the Flyte types into a separate store. On deserialization, the store is used to reconstruct the Flyte types. @@ -106,7 +86,8 @@ def encoder(obj: Any) -> Union[str, commons.LiteralObjID]: return flyteobject_store.register_python_object(obj) return basemodel.__json_encoder__(obj) - return basemodel.json(encoder=encoder) + basemodel_json = basemodel.json(encoder=encoder) + return make_literal_from_json(basemodel_json) def serialize_to_flyte_literal(python_obj: object) -> literals.Literal: @@ -120,6 +101,16 @@ def serialize_to_flyte_literal(python_obj: object) -> literals.Literal: return literal_obj +def make_literal_from_json(json: str) -> literals.Literal: + """ + Converts the json representation of a pydantic BaseModel to a Flyte Literal. + """ + # serialize as a string literal + base_model_literal = struct_pb2.Struct() + base_model_literal.update({BASEMODEL_JSON_KEY: json}) + return literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(string_value=json))) # type: ignore + + def make_identifier_for_serializeable(python_type: object) -> LiteralObjID: """ Create a unique identifier for a python object. diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 4786da26e0..95acb0aaa4 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -28,6 +28,14 @@ class Config(BaseModel): model_config: Optional[Union[Dict[str, TrainConfig], TrainConfig]] = TrainConfig() +class NestedConfig(BaseModel): + """Nested config BaseModel for testing purposes.""" + + files: "ConfigWithFlyteFiles" + dirs: "ConfigWithFlyteDirs" + df: "ConfigWithPandasDataFrame" + + class ConfigRequired(BaseModel): """Config BaseModel for testing purposes with required attribute.""" @@ -61,6 +69,9 @@ class ChildConfig(Config): d: List[int] = [1, 2, 3] +NestedConfig.update_forward_refs() + + @pytest.mark.parametrize( "python_type,kwargs", [ @@ -70,6 +81,14 @@ class ChildConfig(Config): (ConfigWithFlyteFiles, {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}), (ConfigWithFlyteDirs, {"flytedirs": ["tests/folder/"]}), (ConfigWithPandasDataFrame, {"df": {"a": [1, 2, 3], "b": [4, 5, 6]}}), + ( + NestedConfig, + { + "files": {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, + "dirs": {"flytedirs": ["tests/folder/"]}, + "df": {"df": {"a": [1, 2, 3], "b": [4, 5, 6]}}, + }, + ), ], ) def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): @@ -102,6 +121,14 @@ def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): (ConfigWithFlyteFiles, {"flytefiles": ["s3://foo/bar"]}), (ConfigWithFlyteDirs, {"flytedirs": ["s3://foo/bar"]}), (ConfigWithPandasDataFrame, {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}), + ( + NestedConfig, + { + "files": {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, + "dirs": {"flytedirs": ["tests/folder/"]}, + "df": {"df": {"a": [1, 2, 3], "b": 
[4, 5, 6]}}, + }, + ), ], ) def test_pass_to_workflow(config_type: Type, kwargs: Dict[str, Any]): @@ -201,3 +228,6 @@ def wf(cfg: ChildConfig) -> ChildConfig: with pytest.raises(TypeError): # type: ignore wf(cfg=cfg) + + +test_transform_round_trip(ConfigWithFlyteDirs, {"flytedirs": ["s3://foo/bar"]}) From d55b251711862135ddb7c02ac3786d28be0577c2 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 20 Jun 2023 12:28:41 -0700 Subject: [PATCH 43/55] improved tests to work with flyte types --- .../flytekitplugins/pydantic/commons.py | 4 +- .../tests/test_type_transformer.py | 40 ++++++++++++++----- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py index fcbde5ac05..e309f681a0 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py @@ -4,6 +4,8 @@ from typing_extensions import Annotated +import pandas as pd + from flytekit.core import type_engine @@ -20,7 +22,7 @@ def include_in_flyte_types(t: type) -> bool: PYDANTIC_SUPPORTED_FLYTE_TYPES = tuple( filter(include_in_flyte_types, type_engine.TypeEngine.get_available_transformers()) -) +) + (pd.DataFrame,) # this is the UUID placeholder that is set in the serialized basemodel JSON, connecting that field to # the literal map that holds the actual object that needs to be deserialized (w/ protobuf) diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 95acb0aaa4..2d4bf95601 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -1,4 +1,5 @@ import os +import pathlib from typing import Any, Dict, List, Optional, Type, Union import pandas as pd @@ -35,6 +36,11 @@ class NestedConfig(BaseModel): dirs: "ConfigWithFlyteDirs" df: "ConfigWithPandasDataFrame" + def __eq__(self, __value: object) -> bool: + return isinstance(__value, NestedConfig) and all( + getattr(self, attr) == getattr(__value, attr) for attr in ["files", "dirs", "df"] + ) + class ConfigRequired(BaseModel): """Config BaseModel for testing purposes with required attribute.""" @@ -47,12 +53,24 @@ class ConfigWithFlyteFiles(BaseModel): flytefiles: List[file.FlyteFile] + def __eq__(self, __value: object) -> bool: + return isinstance(__value, ConfigWithFlyteFiles) and all( + pathlib.Path(self_file).read_text() == pathlib.Path(other_file).read_text() + for self_file, other_file in zip(self.flytefiles, __value.flytefiles) + ) + class ConfigWithFlyteDirs(BaseModel): """Config BaseModel for testing purposes with flytekit.directory.FlyteDirectory type hint.""" flytedirs: List[directory.FlyteDirectory] + def __eq__(self, __value: object) -> bool: + return isinstance(__value, ConfigWithFlyteDirs) and all( + os.listdir(self_dir) == os.listdir(other_dir) + for self_dir, other_dir in zip(self.flytedirs, __value.flytedirs) + ) + class ConfigWithPandasDataFrame(BaseModel): """Config BaseModel for testing purposes with pandas.DataFrame type hint.""" @@ -62,6 +80,9 @@ class ConfigWithPandasDataFrame(BaseModel): class Config: arbitrary_types_allowed = True + def __eq__(self, __value: object) -> bool: + return isinstance(__value, ConfigWithPandasDataFrame) and self.df.equals(__value.df) + class ChildConfig(Config): """Child class config BaseModel for testing purposes.""" @@ -80,13 +101,13 @@ class ChildConfig(Config): (TrainConfig, 
{}), (ConfigWithFlyteFiles, {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}), (ConfigWithFlyteDirs, {"flytedirs": ["tests/folder/"]}), - (ConfigWithPandasDataFrame, {"df": {"a": [1, 2, 3], "b": [4, 5, 6]}}), + (ConfigWithPandasDataFrame, {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}), ( NestedConfig, { "files": {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, "dirs": {"flytedirs": ["tests/folder/"]}, - "df": {"df": {"a": [1, 2, 3], "b": [4, 5, 6]}}, + "df": {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, }, ), ], @@ -110,7 +131,7 @@ def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): reconstructed_value = type_transformer.to_python_value(ctx, literal_value, type(python_value)) - # assert reconstructed_value == python_value + assert reconstructed_value == python_value @pytest.mark.parametrize( @@ -118,15 +139,15 @@ def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): [ (Config, {"model_config": {"foo": TrainConfig(loss="mse")}}), (ConfigRequired, {"model_config": {"foo": TrainConfig(loss="mse")}}), - (ConfigWithFlyteFiles, {"flytefiles": ["s3://foo/bar"]}), - (ConfigWithFlyteDirs, {"flytedirs": ["s3://foo/bar"]}), + (ConfigWithFlyteFiles, {"flytefiles": ["tests/folder/test_file1.txt"]}), + (ConfigWithFlyteDirs, {"flytedirs": ["tests/folder/"]}), (ConfigWithPandasDataFrame, {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}), ( NestedConfig, { "files": {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, "dirs": {"flytedirs": ["tests/folder/"]}, - "df": {"df": {"a": [1, 2, 3], "b": [4, 5, 6]}}, + "df": {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, }, ), ], @@ -143,9 +164,9 @@ def train(cfg: config_type) -> config_type: def wf(cfg: config_type) -> config_type: return train(cfg=cfg) - returned_cfg = wf(cfg=cfg) + returned_cfg = wf(cfg=cfg) # type: ignore - # assert returned_cfg == cfg + assert returned_cfg == cfg # TODO these assertions are not valid for all types @@ -228,6 +249,3 @@ def wf(cfg: ChildConfig) -> ChildConfig: with pytest.raises(TypeError): # type: ignore wf(cfg=cfg) - - -test_transform_round_trip(ConfigWithFlyteDirs, {"flytedirs": ["s3://foo/bar"]}) From 0c4ce8c3dac1b4ce8b1339f567d261509cb48555 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 20 Jun 2023 12:32:32 -0700 Subject: [PATCH 44/55] arbitrary_types_allowed is not needed when we have the __get_validators__ set on the class --- plugins/flytekit-pydantic/tests/test_type_transformer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 2d4bf95601..38d0c02fdd 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -77,9 +77,6 @@ class ConfigWithPandasDataFrame(BaseModel): df: pd.DataFrame - class Config: - arbitrary_types_allowed = True - def __eq__(self, __value: object) -> bool: return isinstance(__value, ConfigWithPandasDataFrame) and self.df.equals(__value.df) From bc27331e465279652a438507e9fa5631d9f224ee Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 20 Jun 2023 14:25:00 -0700 Subject: [PATCH 45/55] more tests --- .../pydantic/deserialization.py | 3 +- .../tests/test_type_transformer.py | 37 ++++++++++++++++++- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py 
b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py index 70a56add1a..b3860ad600 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -71,8 +71,7 @@ def get_python_object( """Deserialize a flyte literal and return the python object.""" if not cls.is_attached(): raise Exception("Must attach to a literal map before deserializing") - assert cls.literal_store is not None - literal = cls.literal_store[identifier] + literal = cls.literal_store[identifier] # type: ignore python_object = deserialize_flyte_literal(literal, expected_type) return python_object diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 38d0c02fdd..325398e89c 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -1,8 +1,9 @@ +import datetime as dt import os import pathlib from typing import Any, Dict, List, Optional, Type, Union -import pandas as pd +import pandas as pd import pytest from flytekitplugins.pydantic import BaseModelTransformer from pydantic import BaseModel, Extra @@ -29,12 +30,19 @@ class Config(BaseModel): model_config: Optional[Union[Dict[str, TrainConfig], TrainConfig]] = TrainConfig() +class ConfigWithDatetime(BaseModel): + """Config BaseModel for testing purposes with datetime type hint.""" + + datetime: dt.datetime = dt.datetime.now() + + class NestedConfig(BaseModel): """Nested config BaseModel for testing purposes.""" files: "ConfigWithFlyteFiles" dirs: "ConfigWithFlyteDirs" df: "ConfigWithPandasDataFrame" + datetime: "ConfigWithDatetime" = ConfigWithDatetime() def __eq__(self, __value: object) -> bool: return isinstance(__value, NestedConfig) and all( @@ -246,3 +254,30 @@ def wf(cfg: ChildConfig) -> ChildConfig: with pytest.raises(TypeError): # type: ignore wf(cfg=cfg) + + +python_value = NestedConfig( + **{ + "files": {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, + "dirs": {"flytedirs": ["tests/folder/"]}, + "df": {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, + }, +) + +from flytekit.core.context_manager import FlyteContextManager + +ctx = FlyteContextManager().current_context() + +type_transformer = BaseModelTransformer() + + +literal_value = type_transformer.to_literal( + ctx, + python_value, + NestedConfig, + type_transformer.get_literal_type(python_value), +) + +ConfigWithDatetime(**{"datetime": "2023-06-20T13:19:48.820609"}) + +Config.parse_raw \ No newline at end of file From 202669862635e82524ca579bfc345eddc242713d Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 20 Jun 2023 14:27:46 -0700 Subject: [PATCH 46/55] removed some leftover test code --- .../tests/test_type_transformer.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 325398e89c..9d4575b602 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -263,21 +263,3 @@ def wf(cfg: ChildConfig) -> ChildConfig: "df": {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, }, ) - -from flytekit.core.context_manager import FlyteContextManager - -ctx = FlyteContextManager().current_context() - -type_transformer = BaseModelTransformer() - - -literal_value = type_transformer.to_literal( - ctx, - 
python_value, - NestedConfig, - type_transformer.get_literal_type(python_value), -) - -ConfigWithDatetime(**{"datetime": "2023-06-20T13:19:48.820609"}) - -Config.parse_raw \ No newline at end of file From 0e9f21559ee688f907e651777e70cb070efce91a Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 23 Jun 2023 10:32:47 -0700 Subject: [PATCH 47/55] changed serialization to align w dataclass_json logic in core type engine. + small cleanups --- .../pydantic/basemodel_transformer.py | 8 +++-- .../flytekitplugins/pydantic/serialization.py | 7 ++-- .../tests/test_type_transformer.py | 36 ++++--------------- 3 files changed, 15 insertions(+), 36 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index 6324ba0fbe..d3301860d2 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -1,14 +1,16 @@ """Serializes & deserializes the pydantic basemodels """ from typing import Type -from typing_extensions import Annotated + import pydantic +from google.protobuf import json_format +from typing_extensions import Annotated from flytekit import FlyteContext from flytekit.core import type_engine from flytekit.models import literals, types -from . import serialization, deserialization +from . import deserialization, serialization BaseModelLiteralValue = Annotated[ literals.LiteralMap, @@ -55,7 +57,7 @@ def to_python_value( def read_basemodel_json_from_literalmap(lv: BaseModelLiteralValue) -> serialization.SerializedBaseModel: basemodel_literal: literals.Literal = lv.literals[serialization.BASEMODEL_JSON_KEY] - basemodel_json_w_placeholders = deserialization.deserialize_flyte_literal(basemodel_literal, str) + basemodel_json_w_placeholders = json_format.MessageToJson(basemodel_literal.scalar.generic) assert isinstance(basemodel_json_w_placeholders, str) return basemodel_json_w_placeholders diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py index 631e84d717..946d4f56c5 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py @@ -13,6 +13,7 @@ import pydantic from google.protobuf import struct_pb2 +from google.protobuf import json_format from typing_extensions import Annotated from flytekit.core import context_manager, type_engine @@ -45,6 +46,7 @@ def register_python_object(self, python_object: object) -> LiteralObjID: """Serialize to literal and return a unique identifier.""" serialized_item = serialize_to_flyte_literal(python_object) identifier = make_identifier_for_serializeable(python_object) + assert identifier not in self.literal_store self.literal_store[identifier] = serialized_item return identifier @@ -105,10 +107,7 @@ def make_literal_from_json(json: str) -> literals.Literal: """ Converts the json representation of a pydantic BaseModel to a Flyte Literal. 
""" - # serialize as a string literal - base_model_literal = struct_pb2.Struct() - base_model_literal.update({BASEMODEL_JSON_KEY: json}) - return literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(string_value=json))) # type: ignore + return literals.Literal( scalar=literals.Scalar(generic=json_format.Parse(json, struct_pb2.Struct())) ) # type: ignore def make_identifier_for_serializeable(python_type: object) -> LiteralObjID: diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 9d4575b602..9a3c05d2b8 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -46,7 +46,7 @@ class NestedConfig(BaseModel): def __eq__(self, __value: object) -> bool: return isinstance(__value, NestedConfig) and all( - getattr(self, attr) == getattr(__value, attr) for attr in ["files", "dirs", "df"] + getattr(self, attr) == getattr(__value, attr) for attr in ["files", "dirs", "df", 'datetime'] ) @@ -235,31 +235,9 @@ def wf(cfg1: TrainConfig, cfg2: TrainConfig) -> bool: assert wf(cfg1=cfg1, cfg2=cfg2), wf(cfg1=cfg1, cfg2=cfg2) # type: ignore - -# TODO: //Arthur to Fabio this was differente before but now im unsure what the test is doing -# previously a pattern match error was checked that its raised, but isnt it OK that the ChildConfig -# is passed since its a subclass of Config? -# I modified the test to work the other way around, but im not sure if this is what you intended -def test_pass_wrong_type_to_workflow(): - """Test passing the wrong type raises exception.""" - cfg = Config() - - @flytekit.task - def train(cfg: ChildConfig) -> ChildConfig: - return cfg - - @flytekit.workflow - def wf(cfg: ChildConfig) -> ChildConfig: - return train(cfg=cfg) # type: ignore - - with pytest.raises(TypeError): # type: ignore - wf(cfg=cfg) - - -python_value = NestedConfig( - **{ - "files": {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, - "dirs": {"flytedirs": ["tests/folder/"]}, - "df": {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, - }, -) +if __name__ == "__main__": + # debugging + test_transform_round_trip( + ConfigWithPandasDataFrame, + {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, + ) \ No newline at end of file From 4c463dee4088e5e44dd1086cff6d7d2903df724d Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 7 Jul 2023 10:43:44 -0700 Subject: [PATCH 48/55] addressed comments --- flytekit/core/type_engine.py | 2 +- .../pydantic/basemodel_transformer.py | 13 +++++++------ .../flytekitplugins/pydantic/commons.py | 14 ++++++++------ .../flytekitplugins/pydantic/deserialization.py | 6 ++---- .../flytekitplugins/pydantic/serialization.py | 11 ++++++----- plugins/flytekit-pydantic/setup.py | 3 ++- .../tests/test_type_transformer.py | 11 +++++++---- 7 files changed, 33 insertions(+), 27 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 428fae98bd..c123a2efa6 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -710,7 +710,7 @@ def register_additional_type(cls, transformer: TypeTransformer, additional_type: cls._REGISTRY[additional_type] = transformer @classmethod - def get_transformer(cls, python_type: T) -> TypeTransformer[T]: + def get_transformer(cls, python_type: Type[T]) -> TypeTransformer[T]: """ The TypeEngine hierarchy for flyteKit. This method looksup and selects the type transformer. 
The algorithm is as follows diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index d3301860d2..5998991afa 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -12,7 +12,7 @@ from . import deserialization, serialization -BaseModelLiteralValue = Annotated[ +BaseModelLiteralMap = Annotated[ literals.LiteralMap, """ BaseModel serialized to a LiteralMap consisting of: @@ -38,24 +38,25 @@ def to_literal( python_val: pydantic.BaseModel, python_type: Type[pydantic.BaseModel], expected: types.LiteralType, - ) -> BaseModelLiteralValue: + ) -> literals.Literal: """Convert a given ``pydantic.BaseModel`` to the Literal representation.""" return serialization.serialize_basemodel(python_val) def to_python_value( self, ctx: FlyteContext, - lv: BaseModelLiteralValue, + lv: literals.Literal, expected_python_type: Type[pydantic.BaseModel], ) -> pydantic.BaseModel: """Re-hydrate the pydantic BaseModel object from Flyte Literal value.""" - basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(lv) - flyte_obj_literalmap = lv.literals[serialization.FLYTETYPE_OBJSTORE_KEY] + literalmap: BaseModelLiteralMap = lv.map + basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(literalmap) + flyte_obj_literalmap = literalmap.literals[serialization.FLYTETYPE_OBJSTORE_KEY] with deserialization.PydanticDeserializationLiteralStore.attach(flyte_obj_literalmap): return expected_python_type.parse_raw(basemodel_json_w_placeholders) -def read_basemodel_json_from_literalmap(lv: BaseModelLiteralValue) -> serialization.SerializedBaseModel: +def read_basemodel_json_from_literalmap(lv: BaseModelLiteralMap) -> serialization.SerializedBaseModel: basemodel_literal: literals.Literal = lv.literals[serialization.BASEMODEL_JSON_KEY] basemodel_json_w_placeholders = json_format.MessageToJson(basemodel_literal.scalar.generic) assert isinstance(basemodel_json_w_placeholders, str) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py index e309f681a0..238e78c84d 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py @@ -1,28 +1,30 @@ import builtins import datetime import typing +from typing import Set +import numpy +import pyarrow from typing_extensions import Annotated -import pandas as pd - - from flytekit.core import type_engine -MODULES_TO_EXLCLUDE_FROM_FLYTE_TYPES = {m.__name__ for m in [builtins, typing, datetime]} +MODULES_TO_EXCLUDE_FROM_FLYTE_TYPES: Set[str] = {m.__name__ for m in [builtins, typing, datetime, pyarrow, numpy]} def include_in_flyte_types(t: type) -> bool: if t is None: return False - if t.__module__ in MODULES_TO_EXLCLUDE_FROM_FLYTE_TYPES: + object_module = t.__module__ + if any(object_module.startswith(module) for module in MODULES_TO_EXCLUDE_FROM_FLYTE_TYPES): return False return True +type_engine.TypeEngine.lazy_import_transformers() # loads all transformers PYDANTIC_SUPPORTED_FLYTE_TYPES = tuple( filter(include_in_flyte_types, type_engine.TypeEngine.get_available_transformers()) -) + (pd.DataFrame,) +) # this is the UUID placeholder that is set in the serialized basemodel JSON, connecting that field to # the literal map that holds the actual object that needs to be deserialized (w/ 
protobuf) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py index b3860ad600..0ad903f396 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -2,14 +2,12 @@ from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Type, TypeVar, Union import pydantic +from flytekitplugins.pydantic import commons, serialization + from flytekit.core import context_manager, type_engine from flytekit.models import literals - from flytekit.types import directory, file -from flytekitplugins.pydantic import commons, serialization - - # this field is used by pydantic to get the validator method PYDANTIC_VALIDATOR_METHOD_NAME = ( pydantic.BaseModel.__get_validators__.__name__ diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py index 946d4f56c5..b481cbc405 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py @@ -12,8 +12,7 @@ from typing import Any, Dict, Union, cast import pydantic -from google.protobuf import struct_pb2 -from google.protobuf import json_format +from google.protobuf import json_format, struct_pb2 from typing_extensions import Annotated from flytekit.core import context_manager, type_engine @@ -57,7 +56,7 @@ def as_literalmap(self) -> literals.LiteralMap: return literals.LiteralMap(literals=self.literal_store) -def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.LiteralMap: +def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.Literal: """ Serializes a given pydantic BaseModel instance into a LiteralMap. The BaseModel is first serialized into a JSON format, where all Flyte types are replaced with unique placeholder strings. @@ -66,12 +65,14 @@ def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.LiteralMap: store = BaseModelFlyteObjectStore() basemodel_literal = serialize_basemodel_to_literal(basemodel, store) store_literal = store.as_literalmap() - return literals.LiteralMap( + basemodel_literalmap = literals.LiteralMap( { BASEMODEL_JSON_KEY: basemodel_literal, # json with flyte types replaced with placeholders FLYTETYPE_OBJSTORE_KEY: store_literal, # placeholders mapped to flyte types } ) + literal = literals.Literal(map=basemodel_literalmap) # type: ignore + return literal def serialize_basemodel_to_literal( @@ -107,7 +108,7 @@ def make_literal_from_json(json: str) -> literals.Literal: """ Converts the json representation of a pydantic BaseModel to a Flyte Literal. 
""" - return literals.Literal( scalar=literals.Scalar(generic=json_format.Parse(json, struct_pb2.Struct())) ) # type: ignore + return literals.Literal(scalar=literals.Scalar(generic=json_format.Parse(json, struct_pb2.Struct()))) # type: ignore def make_identifier_for_serializeable(python_type: object) -> LiteralObjID: diff --git a/plugins/flytekit-pydantic/setup.py b/plugins/flytekit-pydantic/setup.py index 1ad734c482..313c574dd1 100644 --- a/plugins/flytekit-pydantic/setup.py +++ b/plugins/flytekit-pydantic/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "pydantic"] +plugin_requires = ["flytekit>=1.7.0b0,<2.0.0", "pydantic"] __version__ = "0.0.0+develop" @@ -29,6 +29,7 @@ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development", diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 9a3c05d2b8..2b3344ca06 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Extra import flytekit +from flytekit.core import context_manager from flytekit.types import directory from flytekit.types.file import file @@ -46,7 +47,7 @@ class NestedConfig(BaseModel): def __eq__(self, __value: object) -> bool: return isinstance(__value, NestedConfig) and all( - getattr(self, attr) == getattr(__value, attr) for attr in ["files", "dirs", "df", 'datetime'] + getattr(self, attr) == getattr(__value, attr) for attr in ["files", "dirs", "df", "datetime"] ) @@ -119,9 +120,8 @@ class ChildConfig(Config): ) def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): """Test that a (de-)serialization roundtrip results in the identical BaseModel.""" - from flytekit.core.context_manager import FlyteContextManager - ctx = FlyteContextManager().current_context() + ctx = context_manager.FlyteContextManager().current_context() type_transformer = BaseModelTransformer() @@ -235,9 +235,12 @@ def wf(cfg1: TrainConfig, cfg2: TrainConfig) -> bool: assert wf(cfg1=cfg1, cfg2=cfg2), wf(cfg1=cfg1, cfg2=cfg2) # type: ignore +import pydantic +print( pydantic.__version__) + if __name__ == "__main__": # debugging test_transform_round_trip( ConfigWithPandasDataFrame, {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, - ) \ No newline at end of file + ) From cf753266dafc2d927e7fa57a37f82caf115e9a80 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 29 Aug 2023 01:21:47 +0000 Subject: [PATCH 49/55] removed nested literalmap that caused issues during serialization --- .../pydantic/basemodel_transformer.py | 3 +-- .../flytekitplugins/pydantic/serialization.py | 11 ++--------- .../tests/test_type_transformer.py | 19 +++++++++++++++++-- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index 5998991afa..59d1f52e83 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -51,8 +51,7 @@ def to_python_value( """Re-hydrate the pydantic BaseModel 
object from Flyte Literal value.""" literalmap: BaseModelLiteralMap = lv.map basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(literalmap) - flyte_obj_literalmap = literalmap.literals[serialization.FLYTETYPE_OBJSTORE_KEY] - with deserialization.PydanticDeserializationLiteralStore.attach(flyte_obj_literalmap): + with deserialization.PydanticDeserializationLiteralStore.attach(literalmap): return expected_python_type.parse_raw(basemodel_json_w_placeholders) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py index b481cbc405..2a9e0fa435 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py @@ -21,7 +21,6 @@ from . import commons BASEMODEL_JSON_KEY = "BaseModel JSON" -FLYTETYPE_OBJSTORE_KEY = "Flyte Object Store" SerializedBaseModel = Annotated[str, "A pydantic BaseModel that has been serialized with placeholders for Flyte types."] @@ -49,12 +48,6 @@ def register_python_object(self, python_object: object) -> LiteralObjID: self.literal_store[identifier] = serialized_item return identifier - def as_literalmap(self) -> literals.LiteralMap: - """ - Converts the object store to a literal map - """ - return literals.LiteralMap(literals=self.literal_store) - def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.Literal: """ @@ -64,11 +57,11 @@ def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.Literal: """ store = BaseModelFlyteObjectStore() basemodel_literal = serialize_basemodel_to_literal(basemodel, store) - store_literal = store.as_literalmap() + assert BASEMODEL_JSON_KEY not in store.literal_store, "literal map key already exists" basemodel_literalmap = literals.LiteralMap( { BASEMODEL_JSON_KEY: basemodel_literal, # json with flyte types replaced with placeholders - FLYTETYPE_OBJSTORE_KEY: store_literal, # placeholders mapped to flyte types + **store.literal_store, # flyte type-engine serialized types } ) literal = literals.Literal(map=basemodel_literalmap) # type: ignore diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 2b3344ca06..9c9eadd7b6 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -235,8 +235,23 @@ def wf(cfg1: TrainConfig, cfg2: TrainConfig) -> bool: assert wf(cfg1=cfg1, cfg2=cfg2), wf(cfg1=cfg1, cfg2=cfg2) # type: ignore -import pydantic -print( pydantic.__version__) + +def test_dynamic(): + class Config(BaseModel): + path: str + + @flytekit.task + def train(cfg: Config): + print(cfg) + + @flytekit.dynamic(cache=True, cache_version="0.3") + def sub_wf(cfg: Config): + train(cfg=cfg) + + @flytekit.workflow + def wf(): + sub_wf(cfg=Config(path="bar")) + if __name__ == "__main__": # debugging From eec501d322f7021434d313cd3f8f7c33c6b74f11 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 29 Aug 2023 02:01:36 +0000 Subject: [PATCH 50/55] expanded dynamic task test --- .../tests/test_type_transformer.py | 32 +++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 9c9eadd7b6..631d326055 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -236,21 +236,41 
@@ def wf(cfg1: TrainConfig, cfg2: TrainConfig) -> bool: assert wf(cfg1=cfg1, cfg2=cfg2), wf(cfg1=cfg1, cfg2=cfg2) # type: ignore -def test_dynamic(): - class Config(BaseModel): - path: str +@pytest.mark.parametrize( + "python_type,config_kwargs", + [ + (Config, {}), + (ConfigRequired, {"model_config": TrainConfig()}), + (TrainConfig, {}), + (ConfigWithFlyteFiles, {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}), + (ConfigWithFlyteDirs, {"flytedirs": ["tests/folder/"]}), + (ConfigWithPandasDataFrame, {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}), + ( + NestedConfig, + { + "files": {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, + "dirs": {"flytedirs": ["tests/folder/"]}, + "df": {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, + }, + ), + ], +) +def test_dynamic(python_type: Type[BaseModel], config_kwargs: Dict[str, Any]): + config_instance = python_type(**config_kwargs) @flytekit.task - def train(cfg: Config): + def train(cfg: BaseModel): print(cfg) @flytekit.dynamic(cache=True, cache_version="0.3") - def sub_wf(cfg: Config): + def sub_wf(cfg: BaseModel): train(cfg=cfg) @flytekit.workflow def wf(): - sub_wf(cfg=Config(path="bar")) + sub_wf(cfg=config_instance) + + wf() if __name__ == "__main__": From 0c0a4832141a39e686d80f874e06a2ee16b4ed28 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 31 Aug 2023 18:52:48 +0000 Subject: [PATCH 51/55] changed serialization from flat literalmap to nested --- .../pydantic/basemodel_transformer.py | 16 ++++++++-------- .../flytekitplugins/pydantic/serialization.py | 10 +++++++--- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index 59d1f52e83..98d38ee6bd 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -1,6 +1,6 @@ """Serializes & deserializes the pydantic basemodels """ -from typing import Type +from typing import Dict, Type import pydantic from google.protobuf import json_format @@ -12,8 +12,8 @@ from . 
import deserialization, serialization -BaseModelLiteralMap = Annotated[ - literals.LiteralMap, +BaseModelLiterals = Annotated[ + Dict[str, literals.Literal], """ BaseModel serialized to a LiteralMap consisting of: 1) the basemodel json with placeholders for flyte types @@ -49,14 +49,14 @@ def to_python_value( expected_python_type: Type[pydantic.BaseModel], ) -> pydantic.BaseModel: """Re-hydrate the pydantic BaseModel object from Flyte Literal value.""" - literalmap: BaseModelLiteralMap = lv.map - basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(literalmap) - with deserialization.PydanticDeserializationLiteralStore.attach(literalmap): + basemodel_literals: BaseModelLiterals = lv.map.literals + basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(basemodel_literals) + with deserialization.PydanticDeserializationLiteralStore.attach(basemodel_literals[serialization.OBJECTS_KEY].map): return expected_python_type.parse_raw(basemodel_json_w_placeholders) -def read_basemodel_json_from_literalmap(lv: BaseModelLiteralMap) -> serialization.SerializedBaseModel: - basemodel_literal: literals.Literal = lv.literals[serialization.BASEMODEL_JSON_KEY] +def read_basemodel_json_from_literalmap(lv: BaseModelLiterals) -> serialization.SerializedBaseModel: + basemodel_literal: literals.Literal = lv[serialization.BASEMODEL_JSON_KEY] basemodel_json_w_placeholders = json_format.MessageToJson(basemodel_literal.scalar.generic) assert isinstance(basemodel_json_w_placeholders, str) return basemodel_json_w_placeholders diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py index 2a9e0fa435..c505af12f2 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py @@ -12,6 +12,7 @@ from typing import Any, Dict, Union, cast import pydantic +from flyteidl.core import literals_pb2 from google.protobuf import json_format, struct_pb2 from typing_extensions import Annotated @@ -21,7 +22,7 @@ from . 
import commons BASEMODEL_JSON_KEY = "BaseModel JSON" - +OBJECTS_KEY = "Serialized Flyte Objects" SerializedBaseModel = Annotated[str, "A pydantic BaseModel that has been serialized with placeholders for Flyte types."] @@ -48,6 +49,10 @@ def register_python_object(self, python_object: object) -> LiteralObjID: self.literal_store[identifier] = serialized_item return identifier + def to_literal(self) -> literals.Literal: + """Convert the object store to a literal map.""" + return literals.Literal(map=literals.LiteralMap(literals=self.literal_store)) + def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.Literal: """ @@ -57,11 +62,10 @@ def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.Literal: """ store = BaseModelFlyteObjectStore() basemodel_literal = serialize_basemodel_to_literal(basemodel, store) - assert BASEMODEL_JSON_KEY not in store.literal_store, "literal map key already exists" basemodel_literalmap = literals.LiteralMap( { BASEMODEL_JSON_KEY: basemodel_literal, # json with flyte types replaced with placeholders - **store.literal_store, # flyte type-engine serialized types + OBJECTS_KEY: store.to_literal(), # flyte type-engine serialized types } ) literal = literals.Literal(map=basemodel_literalmap) # type: ignore From 01c065e86a891b06869f794a4a58311c648009f8 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Sun, 3 Sep 2023 10:12:45 -0700 Subject: [PATCH 52/55] add a unit test Signed-off-by: Yee Hing Tong --- .../pydantic/basemodel_transformer.py | 4 ++- .../flytekitplugins/pydantic/serialization.py | 1 - .../tests/test_type_transformer.py | 25 +++++++++++++++---- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index 98d38ee6bd..a3359cbd2e 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -51,7 +51,9 @@ def to_python_value( """Re-hydrate the pydantic BaseModel object from Flyte Literal value.""" basemodel_literals: BaseModelLiterals = lv.map.literals basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(basemodel_literals) - with deserialization.PydanticDeserializationLiteralStore.attach(basemodel_literals[serialization.OBJECTS_KEY].map): + with deserialization.PydanticDeserializationLiteralStore.attach( + basemodel_literals[serialization.OBJECTS_KEY].map + ): return expected_python_type.parse_raw(basemodel_json_w_placeholders) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py index c505af12f2..cd5b149fd9 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py @@ -12,7 +12,6 @@ from typing import Any, Dict, Union, cast import pydantic -from flyteidl.core import literals_pb2 from google.protobuf import json_format, struct_pb2 from typing_extensions import Annotated diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index 631d326055..b49af1e204 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -5,11 +5,14 @@ import pandas as pd import pytest +from flyteidl.core.types_pb2 import SimpleType from 
flytekitplugins.pydantic import BaseModelTransformer +from flytekitplugins.pydantic.commons import PYDANTIC_SUPPORTED_FLYTE_TYPES from pydantic import BaseModel, Extra import flytekit from flytekit.core import context_manager +from flytekit.core.type_engine import TypeEngine from flytekit.types import directory from flytekit.types.file import file @@ -273,9 +276,21 @@ def wf(): wf() -if __name__ == "__main__": - # debugging - test_transform_round_trip( - ConfigWithPandasDataFrame, - {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, +def test_supported(): + assert len(PYDANTIC_SUPPORTED_FLYTE_TYPES) == 9 + + +def test_single_df(): + ctx = context_manager.FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(ConfigWithPandasDataFrame) + assert lt.simple == SimpleType.STRUCT + + pyd = ConfigWithPandasDataFrame(df=pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) + lit = TypeEngine.to_literal(ctx, pyd, ConfigWithPandasDataFrame, lt) + assert lit.map is not None + offloaded_keys = list(lit.map.literals["Serialized Flyte Objects"].map.literals.keys()) + assert len(offloaded_keys) == 1 + assert ( + lit.map.literals["Serialized Flyte Objects"].map.literals[offloaded_keys[0]].scalar.structured_dataset + is not None ) From 11235fbbee726042c6796ca31ec3890e529ec885 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Sep 2023 20:42:16 +0000 Subject: [PATCH 53/55] linting issue fixed --- .../flytekitplugins/pydantic/deserialization.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py index 0ad903f396..24fe5afc1e 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -1,5 +1,5 @@ import contextlib -from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Type, TypeVar, Union, cast import pydantic from flytekitplugins.pydantic import commons, serialization @@ -9,11 +9,7 @@ from flytekit.types import directory, file # this field is used by pydantic to get the validator method -PYDANTIC_VALIDATOR_METHOD_NAME = ( - pydantic.BaseModel.__get_validators__.__name__ - if pydantic.__version__ < "2.0.0" - else pydantic.BaseModel.__get_pydantic_core_schema__.__name___ # type: ignore -) +PYDANTIC_VALIDATOR_METHOD_NAME = pydantic.BaseModel.__get_validators__.__name__ PythonType = TypeVar("PythonType") # target type of the deserialization @@ -89,7 +85,10 @@ def add_flyte_validators_for_type( Add flyte deserialisation validators to a type. 
""" - previous_validators = getattr(flyte_obj_type, PYDANTIC_VALIDATOR_METHOD_NAME, lambda *_: [])() + previous_validators = cast( + Iterator[Callable[[Any], type_engine.T]], + getattr(flyte_obj_type, PYDANTIC_VALIDATOR_METHOD_NAME, lambda *_: [])(), + ) def validator(object_uid_maybe: Union[commons.LiteralObjID, Any]) -> Union[type_engine.T, Any]: """Partial of deserialize_flyte_literal with the object_type fixed""" From 1e946bb26f9c8871508b4c73c030c6fd2a92f742 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Sep 2023 22:05:34 +0000 Subject: [PATCH 54/55] revert typehint change in type engine --- flytekit/core/type_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index c123a2efa6..5994390c8d 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -710,7 +710,7 @@ def register_additional_type(cls, transformer: TypeTransformer, additional_type: cls._REGISTRY[additional_type] = transformer @classmethod - def get_transformer(cls, python_type: Type[T]) -> TypeTransformer[T]: + def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: """ The TypeEngine hierarchy for flyteKit. This method looksup and selects the type transformer. The algorithm is as follows From f58a55b1b619198b7a6add5472a934437618444a Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 11 Sep 2023 16:53:20 +0000 Subject: [PATCH 55/55] more lint fixes --- .../flytekitplugins/pydantic/basemodel_transformer.py | 4 ++-- plugins/flytekit-pydantic/tests/folder/test_file1.txt | 2 +- plugins/flytekit-pydantic/tests/folder/test_file2.txt | 2 +- plugins/flytekit-pydantic/tests/test_type_transformer.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py index a3359cbd2e..325da8e500 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -15,8 +15,8 @@ BaseModelLiterals = Annotated[ Dict[str, literals.Literal], """ - BaseModel serialized to a LiteralMap consisting of: - 1) the basemodel json with placeholders for flyte types + BaseModel serialized to a LiteralMap consisting of: + 1) the basemodel json with placeholders for flyte types 2) mapping from placeholders to serialized flyte type values in the object store """, ] diff --git a/plugins/flytekit-pydantic/tests/folder/test_file1.txt b/plugins/flytekit-pydantic/tests/folder/test_file1.txt index 1910281566..257cc5642c 100644 --- a/plugins/flytekit-pydantic/tests/folder/test_file1.txt +++ b/plugins/flytekit-pydantic/tests/folder/test_file1.txt @@ -1 +1 @@ -foo \ No newline at end of file +foo diff --git a/plugins/flytekit-pydantic/tests/folder/test_file2.txt b/plugins/flytekit-pydantic/tests/folder/test_file2.txt index ba0e162e1c..5716ca5987 100644 --- a/plugins/flytekit-pydantic/tests/folder/test_file2.txt +++ b/plugins/flytekit-pydantic/tests/folder/test_file2.txt @@ -1 +1 @@ -bar \ No newline at end of file +bar diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py index b49af1e204..3c02dcb3f1 100644 --- a/plugins/flytekit-pydantic/tests/test_type_transformer.py +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -198,7 +198,7 @@ def wf(cfg: ConfigWithFlyteFiles) -> str: return read(cfg=cfg) # 
type: ignore string = wf(cfg=cfg) - assert string in {"foo", "bar"} # type: ignore + assert string in {"foo\n", "bar\n"} # type: ignore @pytest.mark.parametrize(
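
End-to-end, the layout these last patches settle on can be sketched as follows (a minimal illustration, not part of the patch series itself: the model name is made up, and the file path reuses the tests/folder fixture, so it assumes the working directory is plugins/flytekit-pydantic). Serializing a BaseModel yields a single Literal whose map holds the placeholder JSON under "BaseModel JSON" and a nested map of type-engine-serialized objects under "Serialized Flyte Objects"; to_python_value re-attaches that nested map and re-hydrates the model.

import pydantic

from flytekit.core import context_manager
from flytekit.core.type_engine import TypeEngine
from flytekit.types.file import file
from flytekitplugins.pydantic import BaseModelTransformer  # noqa: F401 -- importing the plugin registers the transformer and validators


class ConfigWithFlyteFile(pydantic.BaseModel):
    # FlyteFile fields are replaced by placeholder strings in the JSON and
    # offloaded into the nested "Serialized Flyte Objects" literal map.
    flytefile: file.FlyteFile


ctx = context_manager.FlyteContextManager.current_context()
cfg = ConfigWithFlyteFile(flytefile="tests/folder/test_file1.txt")

lt = TypeEngine.to_literal_type(ConfigWithFlyteFile)
lit = TypeEngine.to_literal(ctx, cfg, ConfigWithFlyteFile, lt)

# Exactly two top-level keys: the placeholder JSON and the nested object store.
assert set(lit.map.literals.keys()) == {"BaseModel JSON", "Serialized Flyte Objects"}

# Round trip back through the transformer.
restored = TypeEngine.to_python_value(ctx, lit, ConfigWithFlyteFile)
assert isinstance(restored, ConfigWithFlyteFile)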