From 9edda50a32daeeba205b81f0213d33c555fcb3e2 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 13 Feb 2023 12:07:54 -0800 Subject: [PATCH 1/6] Get the origin type when serializing dataclass Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 12 ++++++++++++ tests/flytekit/unit/core/test_type_engine.py | 4 ++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 61c448b365..8c3a942481 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -320,6 +320,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: ) schema = None try: + t = self._get_origin_type_in_annotation(t) s = cast(DataClassJsonMixin, t).schema() for _, v in s.fields.items(): # marshmallow-jsonschema only supports enums loaded by name. @@ -352,6 +353,17 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct())) ) + def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: + # dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is + # not hashable, such as Annotated[StructuredDataset, kwtypes(...)]. Therefore, we should just extract the origin + # type from annotated. + for field in dataclasses.fields(python_type): + if get_origin(field.type) is Annotated: + field.type = get_args(field.type)[0] + elif dataclasses.is_dataclass(field.type): + field.type = self._get_origin_type_in_annotation(field.type) + return python_type + def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.Any: """ If any field inside the dataclass is flyte type, we should use flyte type transformer for that field. diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index eb38a8d80b..1bc429db37 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -753,14 +753,14 @@ def test_structured_dataset_in_dataclass(): @dataclass_json @dataclass class InnerDatasetStruct(object): - a: StructuredDataset + a: Annotated[StructuredDataset, kwtypes(Name=str, Age=int)] b: typing.List[StructuredDataset] c: typing.Dict[str, StructuredDataset] @dataclass_json @dataclass class DatasetStruct(object): - a: StructuredDataset + a: Annotated[StructuredDataset, kwtypes(Name=str, Age=int)] b: InnerDatasetStruct sd = StructuredDataset(dataframe=df, file_format="parquet") From ab2b59242098b814bddfc885836db58d7a10622e Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 14 Feb 2023 16:43:30 -0800 Subject: [PATCH 2/6] test Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 48 +++++++++++++++++--- tests/flytekit/unit/core/test_type_engine.py | 9 ++-- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 8c3a942481..9a85e3323a 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1,6 +1,7 @@ from __future__ import annotations import collections +import copy import dataclasses import datetime as _datetime import enum @@ -320,8 +321,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: ) schema = None try: - t = self._get_origin_type_in_annotation(t) - s = cast(DataClassJsonMixin, t).schema() + s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema() for _, v in s.fields.items(): # marshmallow-jsonschema only supports enums loaded by name. # https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 @@ -357,13 +357,42 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: # dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is # not hashable, such as Annotated[StructuredDataset, kwtypes(...)]. Therefore, we should just extract the origin # type from annotated. - for field in dataclasses.fields(python_type): - if get_origin(field.type) is Annotated: - field.type = get_args(field.type)[0] - elif dataclasses.is_dataclass(field.type): + if get_origin(python_type) is list: + return typing.List[self._get_origin_type_in_annotation(get_args(python_type)[0])] + elif get_origin(python_type) is dict: + return typing.Dict[ + self._get_origin_type_in_annotation(get_args(python_type)[0]), + self._get_origin_type_in_annotation(get_args(python_type)[1]), + ] + elif get_origin(python_type) is Annotated: + return get_args(python_type)[0] + elif dataclasses.is_dataclass(python_type): + for field in dataclasses.fields(copy.deepcopy(python_type)): field.type = self._get_origin_type_in_annotation(field.type) return python_type + def _fix_structured_dataset_type(self, python_type: Type[T], python_val: T) -> T: + # In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict, + # so here we convert it back to the Structured Dataset. + from flytekit import StructuredDataset + + if python_type == StructuredDataset and type(python_val) == dict: + return StructuredDataset(**python_val) + elif get_origin(python_type) is list: + return [self._fix_structured_dataset_type(get_args(python_type)[0], v) for v in python_val] + elif get_origin(python_type) is dict: + return { + self._fix_structured_dataset_type(get_args(python_type)[0], k): self._fix_structured_dataset_type( + get_args(python_type)[1], v + ) + for k, v in python_val.items() + } + elif dataclasses.is_dataclass(python_type): + for field in dataclasses.fields(python_type): + val = python_val.__getattribute__(field.name) + python_val.__setattr__(field.name, self._fix_structured_dataset_type(field.type, val)) + return python_val + def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.Any: """ If any field inside the dataclass is flyte type, we should use flyte type transformer for that field. @@ -511,6 +540,9 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: if val is None: return val + if get_origin(t) == Annotated: + t = get_args(t)[0] + if get_origin(t) is typing.Union and type(None) in get_args(t): # Handle optional type. e.g. Optional[int], Optional[dataclass] # Marshmallow doesn't support union type, so the type here is always an optional type. @@ -559,9 +591,11 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"Dataclass {expected_python_type} should be decorated with @dataclass_json to be " f"serialized correctly" ) - json_str = _json_format.MessageToJson(lv.scalar.generic) dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str) + dc = self._fix_structured_dataset_type(expected_python_type, dc) + # from flytekit import StructuredDataset + # dc.a = StructuredDataset(**dc.a) return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type)) # This ensures that calls with the same literal type returns the same dataclass. For example, `pyflyte run`` diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 1bc429db37..6231402aa8 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -749,18 +749,19 @@ class TestFileStruct(object): def test_structured_dataset_in_dataclass(): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + People = Annotated[StructuredDataset, "parquet", kwtypes(Name=str, Age=int)] @dataclass_json @dataclass class InnerDatasetStruct(object): - a: Annotated[StructuredDataset, kwtypes(Name=str, Age=int)] - b: typing.List[StructuredDataset] - c: typing.Dict[str, StructuredDataset] + a: People + b: typing.List[People] + c: typing.Dict[str, People] @dataclass_json @dataclass class DatasetStruct(object): - a: Annotated[StructuredDataset, kwtypes(Name=str, Age=int)] + a: People b: InnerDatasetStruct sd = StructuredDataset(dataframe=df, file_format="parquet") From 7e10447be6d134a85f137975e2e6ac98e262f0bd Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 14 Feb 2023 17:12:59 -0800 Subject: [PATCH 3/6] nit Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 9a85e3323a..0b459dc4b4 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -540,9 +540,6 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: if val is None: return val - if get_origin(t) == Annotated: - t = get_args(t)[0] - if get_origin(t) is typing.Union and type(None) in get_args(t): # Handle optional type. e.g. Optional[int], Optional[dataclass] # Marshmallow doesn't support union type, so the type here is always an optional type. @@ -594,8 +591,6 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: json_str = _json_format.MessageToJson(lv.scalar.generic) dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str) dc = self._fix_structured_dataset_type(expected_python_type, dc) - # from flytekit import StructuredDataset - # dc.a = StructuredDataset(**dc.a) return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type)) # This ensures that calls with the same literal type returns the same dataclass. For example, `pyflyte run`` From e825e3881eb7b7b0480dc18580181284a7937b3c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 16 Feb 2023 00:51:44 -0800 Subject: [PATCH 4/6] update test Signed-off-by: Kevin Su --- tests/flytekit/unit/core/test_type_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 4e72fa0e79..997000bc5a 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -755,8 +755,8 @@ def test_structured_dataset_in_dataclass(): @dataclass class InnerDatasetStruct(object): a: People - b: typing.List[People] - c: typing.Dict[str, People] + b: typing.List[Annotated[StructuredDataset, "parquet"]] + c: typing.Dict[str, Annotated[StructuredDataset, kwtypes(Name=str, Age=int)]] @dataclass_json @dataclass From 6b6d06362afe95f2441dcc203d42ff3a28d5970d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 16 Feb 2023 10:54:37 -0800 Subject: [PATCH 5/6] lint Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 0192caefea..3d9b64a2bf 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -358,9 +358,9 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: # not hashable, such as Annotated[StructuredDataset, kwtypes(...)]. Therefore, we should just extract the origin # type from annotated. if get_origin(python_type) is list: - return typing.List[self._get_origin_type_in_annotation(get_args(python_type)[0])] + return typing.List[self._get_origin_type_in_annotation(get_args(python_type)[0])] # type: ignore elif get_origin(python_type) is dict: - return typing.Dict[ + return typing.Dict[ # type: ignore self._get_origin_type_in_annotation(get_args(python_type)[0]), self._get_origin_type_in_annotation(get_args(python_type)[1]), ] @@ -371,7 +371,7 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: field.type = self._get_origin_type_in_annotation(field.type) return python_type - def _fix_structured_dataset_type(self, python_type: Type[T], python_val: T) -> T: + def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T: # In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict, # so here we convert it back to the Structured Dataset. from flytekit import StructuredDataset @@ -379,9 +379,9 @@ def _fix_structured_dataset_type(self, python_type: Type[T], python_val: T) -> T if python_type == StructuredDataset and type(python_val) == dict: return StructuredDataset(**python_val) elif get_origin(python_type) is list: - return [self._fix_structured_dataset_type(get_args(python_type)[0], v) for v in python_val] + return [self._fix_structured_dataset_type(get_args(python_type)[0], v) for v in python_val] # type: ignore elif get_origin(python_type) is dict: - return { + return { # type: ignore self._fix_structured_dataset_type(get_args(python_type)[0], k): self._fix_structured_dataset_type( get_args(python_type)[1], v ) From 75a6090250595b1c0c58bc0471c8beadbbedc107 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 16 Feb 2023 10:58:42 -0800 Subject: [PATCH 6/6] nit Signed-off-by: Kevin Su --- tests/flytekit/unit/core/test_type_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 997000bc5a..842ae7a98c 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -754,7 +754,7 @@ def test_structured_dataset_in_dataclass(): @dataclass_json @dataclass class InnerDatasetStruct(object): - a: People + a: StructuredDataset b: typing.List[Annotated[StructuredDataset, "parquet"]] c: typing.Dict[str, Annotated[StructuredDataset, kwtypes(Name=str, Age=int)]]