Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get the origin type when serializing dataclass #1508

Merged
merged 10 commits into from
Feb 16, 2023
Merged
45 changes: 43 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import collections
import copy
import dataclasses
import datetime as _datetime
import enum
Expand Down Expand Up @@ -320,7 +321,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
)
schema = None
try:
s = cast(DataClassJsonMixin, t).schema()
s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema()
for _, v in s.fields.items():
# marshmallow-jsonschema only supports enums loaded by name.
# https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228
Expand Down Expand Up @@ -352,6 +353,46 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct()))
)

def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
    """
    Return ``python_type`` with every ``Annotated[...]`` wrapper replaced by the type it wraps.

    ``dataclass.schema()`` tries to hash the python types it finds, but some annotated types are not
    hashable, such as ``Annotated[StructuredDataset, kwtypes(...)]``. Extracting the origin type from the
    annotation avoids that failure.

    :param python_type: A plain type, a ``List[...]``/``Dict[..., ...]`` generic, an ``Annotated[...]`` type, or
        a dataclass whose fields may use any of these.
    :return: The equivalent type without ``Annotated`` metadata. Parameterized lists and dicts are rebuilt as
        ``typing.List``/``typing.Dict``. A dataclass is returned as the same object, with the ``type`` of each
        of its fields rewritten in place.
    """
    origin = get_origin(python_type)
    args = get_args(python_type)
    strip = self._get_origin_type_in_annotation

    if origin is Annotated:
        # Recurse so that metadata nested below the annotated type, e.g.
        # Annotated[List[Annotated[X, ...]], ...], is removed as well.
        return strip(args[0])
    # An unparameterized typing.List / typing.Dict has no args; it carries no metadata, so it is returned as is.
    if origin is list and args:
        return typing.List[strip(args[0])]  # type: ignore
    if origin is dict and len(args) == 2:
        return typing.Dict[strip(args[0]), strip(args[1])]  # type: ignore
    if dataclasses.is_dataclass(python_type):
        # The Field objects belong to the dataclass itself, so this rewrites the field types in place.
        # (Deep-copying the class first would not change that: copy.deepcopy returns classes unchanged.)
        for field in dataclasses.fields(python_type):
            field.type = strip(field.type)
    return python_type

def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T:
    """
    Convert dict values that stand in for a ``StructuredDataset`` back into ``StructuredDataset`` objects.

    In python 3.7 and 3.8, DataclassJson deserializes ``Annotated[StructuredDataset, kwtypes(..)]`` to a dict,
    so the value is walked alongside its declared type and every such dict is converted back.

    :param python_type: The declared type of ``python_val``: ``StructuredDataset``, a ``List[...]`` or
        ``Dict[..., ...]`` generic, a dataclass, or any other type.
    :param python_val: The deserialized value. Dataclass instances are updated in place via ``setattr``.
    :return: A ``StructuredDataset`` when ``python_type`` is ``StructuredDataset`` and the value is a dict,
        a new list or dict for parameterized list/dict types, otherwise ``python_val`` itself.
    """
    # None has nothing to convert. Returning early keeps unset fields from failing in the container and
    # dataclass branches below, which iterate over the value or read attributes from it.
    if python_val is None:
        return python_val  # type: ignore

    from flytekit import StructuredDataset

    fix = self._fix_structured_dataset_type
    if python_type == StructuredDataset and isinstance(python_val, dict):
        return StructuredDataset(**python_val)

    origin = get_origin(python_type)
    args = get_args(python_type)
    # Unparameterized List / Dict give no element type to recurse with, so the value is returned untouched.
    if origin is list and args:
        return [fix(args[0], item) for item in python_val]  # type: ignore
    if origin is dict and len(args) == 2:
        return {fix(args[0], key): fix(args[1], item) for key, item in python_val.items()}  # type: ignore
    if dataclasses.is_dataclass(python_type):
        for field in dataclasses.fields(python_type):
            setattr(python_val, field.name, fix(field.type, getattr(python_val, field.name)))
    return python_val

def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.Any:
"""
If any field inside the dataclass is flyte type, we should use flyte type transformer for that field.
Expand Down Expand Up @@ -559,9 +600,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"Dataclass {expected_python_type} should be decorated with @dataclass_json to be "
f"serialized correctly"
)

json_str = _json_format.MessageToJson(lv.scalar.generic)
dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str)
dc = self._fix_structured_dataset_type(expected_python_type, dc)
return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type))

# This ensures that calls with the same literal type return the same dataclass. For example, `pyflyte run`
Expand Down
7 changes: 4 additions & 3 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,18 +749,19 @@ class TestFileStruct(object):

def test_structured_dataset_in_dataclass():
df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
People = Annotated[StructuredDataset, "parquet", kwtypes(Name=str, Age=int)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: preserve the bare StructuredDataset case.

Do we need to test other Annotated combinations, such as Annotated[StructuredDataset, kwtypes(...)] without specifying the file format?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, updated it


@dataclass_json
@dataclass
class InnerDatasetStruct(object):
a: StructuredDataset
b: typing.List[StructuredDataset]
c: typing.Dict[str, StructuredDataset]
b: typing.List[Annotated[StructuredDataset, "parquet"]]
c: typing.Dict[str, Annotated[StructuredDataset, kwtypes(Name=str, Age=int)]]

@dataclass_json
@dataclass
class DatasetStruct(object):
a: StructuredDataset
a: People
b: InnerDatasetStruct

sd = StructuredDataset(dataframe=df, file_format="parquet")
Expand Down