Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use mashumaro to serialize/deserialize dataclass #1735

Merged
merged 38 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
c6a440c
Add encoder/decoder in structureDataset for snowflake.
hhcs9527 Aug 30, 2023
45f1c1f
add unit-test for snowflake structure dataset encoder/decoder
hhcs9527 Aug 31, 2023
cc450f3
add unit-test for snowflake structure dataset encoder/decoder
hhcs9527 Aug 31, 2023
9ad6cb1
let lazy_import_transformers force load the snowflake-connector
hhcs9527 Sep 1, 2023
958c2b6
add mock get_private_key for unit-test
hhcs9527 Sep 1, 2023
7fac349
add snowflake-connector-python in dev-requirements.in
hhcs9527 Sep 1, 2023
c9d03bc
add setup.py
hhcs9527 Jul 26, 2023
42d3bbb
add import DataClassJSONMixin
hhcs9527 Jul 26, 2023
c74951c
support DataClassJSONMixin to json in DataClassTransformer
hhcs9527 Jul 26, 2023
de73c95
support DataClassJSONMixin from json in DataClassTransformer
hhcs9527 Jul 26, 2023
449044e
fix Json Schema with to_dict
hhcs9527 Jul 26, 2023
262d92d
fix lint issue files
hhcs9527 Jul 26, 2023
b1c7875
add new test for testing mashumaro dataclass
hhcs9527 Jul 28, 2023
631602d
support structure dataclass and flytescheme with DataClassJSONMixin
hhcs9527 Jul 30, 2023
c637a5e
add test && fix lint
hhcs9527 Jul 31, 2023
c242bbd
support mashumaro with latest version
hhcs9527 Aug 6, 2023
e70ae1c
add files generated by the make requirement
hhcs9527 Aug 8, 2023
7a898f0
fix test issues in type engine
hhcs9527 Aug 15, 2023
adcaedf
Update the code with advise
hhcs9527 Aug 16, 2023
7be6ced
support new version of mashumaro
hhcs9527 Aug 18, 2023
01c505b
remove requirement
hhcs9527 Aug 23, 2023
a84f6c4
fix lint
hhcs9527 Aug 23, 2023
81b8a3d
Inherit directly from DataClassJsonMixin instead of using @dataclass_…
ringohoffman Aug 21, 2023
b57a82b
add import DataClassJSONMixin
hhcs9527 Jul 26, 2023
e3ccbf3
support DataClassJSONMixin to json in DataClassTransformer
hhcs9527 Jul 26, 2023
457d0d8
fix lint issue files
hhcs9527 Jul 26, 2023
11f7616
add new test for testing mashumaro dataclass
hhcs9527 Jul 28, 2023
b9cb5e9
support structure dataclass and flytescheme with DataClassJSONMixin
hhcs9527 Jul 30, 2023
e989f2c
add test && fix lint
hhcs9527 Jul 31, 2023
fcf0401
Update the code with advise
hhcs9527 Aug 16, 2023
84c8c14
fix type engine bugs
hhcs9527 Aug 23, 2023
6d188fa
Split convert_json_schema_to_python_class to convert_mashumaro_json_s…
hhcs9527 Aug 30, 2023
c19bdc9
Fix lint
hhcs9527 Aug 31, 2023
5e7e72b
remove un-relevant changes
hhcs9527 Sep 7, 2023
200300d
remove un-relevant changes
hhcs9527 Sep 7, 2023
fb265dd
fix the suggestion part
hhcs9527 Sep 8, 2023
a322dec
add test to cover the code cov && fix some schema name in type_engine
hhcs9527 Sep 8, 2023
a05dd4f
remove redundant branch condition in type_engine and import in dev-re…
hhcs9527 Sep 9, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pre-commit
codespell
google-cloud-bigquery
google-cloud-bigquery-storage
snowflake-connector-python
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
IPython
keyrings.alt

Expand Down
5 changes: 3 additions & 2 deletions doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ absl-py==1.4.0
# tensorflow
adlfs==2023.4.0
# via flytekit
aiobotocore==2.5.2
aiobotocore==2.5.3
# via s3fs
aiohttp==3.8.5
# via
Expand Down Expand Up @@ -98,7 +98,7 @@ bleach==6.0.0
# via nbconvert
blinker==1.6.2
# via flask
botocore==1.29.161
botocore==1.31.17
# via
# -r doc-requirements.in
# aiobotocore
Expand Down Expand Up @@ -1284,6 +1284,7 @@ typing-extensions==4.5.0
# flytekit
# great-expectations
# ipython
# mashumaro
# pydantic
# python-utils
# sqlalchemy
Expand Down
149 changes: 112 additions & 37 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import typing
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
from typing import Dict, List, NamedTuple, Optional, Type, cast

from dataclasses_json import DataClassJsonMixin, dataclass_json
from google.protobuf import json_format as _json_format
Expand All @@ -22,6 +22,7 @@
from google.protobuf.message import Message
from google.protobuf.struct_pb2 import Struct
from marshmallow_enum import EnumField, LoadDumpOptions
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, get_args, get_origin

from flytekit.core.annotation import FlyteAnnotation
Expand Down Expand Up @@ -53,6 +54,7 @@

T = typing.TypeVar("T")
DEFINITIONS = "definitions"
TITLE = "title"


class BatchSize:
Expand Down Expand Up @@ -344,22 +346,28 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
f"Type {t} cannot be parsed."
)

if not issubclass(t, DataClassJsonMixin):
if not issubclass(t, DataClassJsonMixin) and not issubclass(t, DataClassJSONMixin):
raise AssertionError(
f"Dataclass {t} should be decorated with @dataclass_json or be a subclass of DataClassJsonMixin to be "
"serialized correctly"
f"Dataclass {t} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be "
f"serialized correctly"
)
schema = None
try:
s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema()
for _, v in s.fields.items():
# marshmallow-jsonschema only supports enums loaded by name.
# https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228
if isinstance(v, EnumField):
v.load_by = LoadDumpOptions.name
from marshmallow_jsonschema import JSONSchema

schema = JSONSchema().dump(s)
if issubclass(t, DataClassJsonMixin):
s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema()
for _, v in s.fields.items():
# marshmallow-jsonschema only supports enums loaded by name.
# https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228
if isinstance(v, EnumField):
v.load_by = LoadDumpOptions.name
# check if DataClass mixin
from marshmallow_jsonschema import JSONSchema

schema = JSONSchema().dump(s)
else: # DataClassJSONMixin
from mashumaro.jsonschema import build_json_schema

schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict()
except Exception as e:
# https://github.com/lovasoa/marshmallow_dataclass/issues/13
logger.warning(
Expand All @@ -376,15 +384,21 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for "
f"user defined datatypes in Flytekit"
)
if not issubclass(type(python_val), DataClassJsonMixin):
if not issubclass(type(python_val), DataClassJsonMixin) and not issubclass(
type(python_val), DataClassJSONMixin
):
hhcs9527 marked this conversation as resolved.
Show resolved Hide resolved
raise TypeTransformerFailedError(
f"Dataclass {python_type} should be decorated with @dataclass_json or be a subclass of "
"DataClassJsonMixin to be serialized correctly"
f"Dataclass {python_type} should be decorated with @dataclass_json or inherit DataClassJSONMixin to be "
f"serialized correctly"
)
self._serialize_flyte_type(python_val, python_type)
return Literal(
scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct()))
)

if issubclass(type(python_val), DataClassJsonMixin):
json_str = cast(DataClassJsonMixin, python_val).to_json() # type: ignore
else:
json_str = cast(DataClassJSONMixin, python_val).to_json() # type: ignore
hhcs9527 marked this conversation as resolved.
Show resolved Hide resolved
pingsutw marked this conversation as resolved.
Show resolved Hide resolved

return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore

def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
# dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is
Expand Down Expand Up @@ -628,13 +642,16 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"{expected_python_type} is not of type @dataclass, only Dataclasses are supported for "
"user defined datatypes in Flytekit"
)
if not issubclass(expected_python_type, DataClassJsonMixin):
if not issubclass(expected_python_type, DataClassJsonMixin) and not issubclass(
expected_python_type, DataClassJSONMixin
):
raise TypeTransformerFailedError(
f"Dataclass {expected_python_type} should be decorated with @dataclass_json or be a subclass of "
"DataClassJsonMixin to be serialized correctly"
f"Dataclass {expected_python_type} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be "
f"serialized correctly"
)
json_str = _json_format.MessageToJson(lv.scalar.generic)
dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str)
dc = expected_python_type.from_json(json_str) # type: ignore

dc = self._fix_structured_dataset_type(expected_python_type, dc)
return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type))

Expand All @@ -645,10 +662,15 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
@lru_cache(typed=True)
def guess_python_type(self, literal_type: LiteralType) -> Type[T]: # type: ignore
if literal_type.simple == SimpleType.STRUCT:
if literal_type.metadata is not None and DEFINITIONS in literal_type.metadata:
schema_name = literal_type.metadata["$ref"].split("/")[-1]
return convert_json_schema_to_python_class(literal_type.metadata[DEFINITIONS], schema_name)

if literal_type.metadata is not None:
if DEFINITIONS in literal_type.metadata:
schema_name = literal_type.metadata["$ref"].split("/")[-1]
return convert_marshmallow_json_schema_to_python_class(
literal_type.metadata[DEFINITIONS], schema_name
)
elif TITLE in literal_type.metadata:
schema_name = literal_type.metadata[TITLE]
return convert_mashumaro_json_schema_to_python_class(literal_type.metadata, schema_name)
raise ValueError(f"Dataclass transformer cannot reverse {literal_type}")


Expand Down Expand Up @@ -1550,13 +1572,45 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
return expected_python_type(lv.scalar.primitive.string_value) # type: ignore


def convert_json_schema_to_python_class(schema: Dict[str, Any], schema_name: str) -> Type[Any]:
"""
Generate a model class based on the provided JSON Schema
:param schema: dict representing valid JSON schema
:param schema_name: dataclass name of return type
"""
attribute_list: List[Tuple[str, type]] = []
def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any):
    """
    Build the ``(name, type)`` field list for a dataclass described by a mashumaro-style JSON schema.

    :param schema: JSON schema dict; must contain a ``properties`` mapping
    :param schema_name: name of the dataclass being rebuilt. Not read by this function; kept so the signature
        matches generate_attribute_list_from_dataclass_json.
    :return: list of ``(property_key, python_type)`` tuples, usable as the ``fields`` argument of
        ``dataclasses.make_dataclass``
    :raises KeyError: if a property has neither ``anyOf``, ``enum`` nor ``type``, or if a nested object has no
        ``title`` to name the generated class
    """
    attribute_list = []
    for property_key, property_val in schema["properties"].items():
        any_of = property_val.get("anyOf")
        # When a property is wrapped in "anyOf", the first alternative is treated as the real definition.
        # Everything that describes the value (items, additionalProperties, title) lives on that alternative,
        # not on the wrapper, so resolve it once and read from it consistently.
        resolved = any_of[0] if any_of else property_val

        if any_of:
            property_type = resolved["type"]
        elif property_val.get("enum"):
            property_type = "enum"
        else:
            property_type = property_val["type"]

        # Handle list
        if property_type == "array":
            attribute_type = typing.List[_get_element_type(resolved["items"])]  # type: ignore
        # Handle dataclass and dict
        elif property_type == "object":
            if resolved.get("additionalProperties"):
                attribute_type = typing.Dict[str, _get_element_type(resolved["additionalProperties"])]  # type: ignore
            else:
                attribute_type = convert_mashumaro_json_schema_to_python_class(resolved, resolved["title"])
        # Enum members are rebuilt as plain strings
        elif property_type == "enum":
            attribute_type = str
        # Handle int, float, bool or str; the whole property is passed on so that _get_element_type can see
        # every "anyOf" alternative
        else:
            attribute_type = _get_element_type(property_val)

        # Always emit tuples so every entry has the same shape
        attribute_list.append((property_key, attribute_type))
    return attribute_list


def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typing.Any):
attribute_list = []
for property_key, property_val in schema[schema_name]["properties"].items():
property_type = property_val["type"]
# Handle list
Expand All @@ -1566,7 +1620,7 @@ def convert_json_schema_to_python_class(schema: Dict[str, Any], schema_name: str
elif property_type == "object":
if property_val.get("$ref"):
name = property_val["$ref"].split("/")[-1]
attribute_list.append((property_key, convert_json_schema_to_python_class(schema, name)))
attribute_list.append((property_key, convert_marshmallow_json_schema_to_python_class(schema, name)))
elif property_val.get("additionalProperties"):
attribute_list.append(
(property_key, Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore[misc,index]
Expand All @@ -1575,13 +1629,34 @@ def convert_json_schema_to_python_class(schema: Dict[str, Any], schema_name: str
attribute_list.append((property_key, Dict[str, _get_element_type(property_val)])) # type: ignore[misc,index]
# Handle int, float, bool or str
else:
attribute_list.append((property_key, _get_element_type(property_val)))
attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore
return attribute_list


def convert_marshmallow_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> Type[dataclasses.dataclass()]:  # type: ignore
    """
    Rebuild a dataclass from a marshmallow-generated JSON schema.

    :param schema: mapping of schema definitions, looked up by ``schema_name``
    :param schema_name: entry of ``schema`` to rebuild; also used as the name of the generated class
    :return: the generated dataclass, wrapped with ``dataclass_json``
    """
    fields = generate_attribute_list_from_dataclass_json(schema, schema_name)
    generated_cls = dataclasses.make_dataclass(schema_name, fields)
    return dataclass_json(generated_cls)


def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> Type[dataclasses.dataclass()]:  # type: ignore
    """
    Rebuild a dataclass from a mashumaro-generated JSON schema.

    :param schema: JSON schema dict for a single dataclass (its ``properties`` are read directly)
    :param schema_name: name given to the generated class
    :return: the generated dataclass, wrapped with ``dataclass_json``
    """
    fields = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name)
    generated_cls = dataclasses.make_dataclass(schema_name, fields)
    return dataclass_json(generated_cls)


def _get_element_type(element_property: typing.Dict[str, str]) -> Type:
element_type = element_property["type"]
element_type = [e_property["type"] for e_property in element_property["anyOf"]] if element_property.get("anyOf") else element_property["type"] # type: ignore
element_format = element_property["format"] if "format" in element_property else None

if type(element_type) == list:
Expand Down
5 changes: 3 additions & 2 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from contextlib import contextmanager
from dataclasses import dataclass, field

from dataclasses_json import DataClassJsonMixin, config
from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_underlying_type
Expand All @@ -26,7 +27,7 @@ def noop():


@dataclass
class FlyteFile(DataClassJsonMixin, os.PathLike, typing.Generic[T]):
class FlyteFile(os.PathLike, typing.Generic[T], DataClassJSONMixin):
path: typing.Union[str, os.PathLike] = field(
hhcs9527 marked this conversation as resolved.
Show resolved Hide resolved
default=None, metadata=config(mm_field=fields.String())
) # type: ignore
Expand Down
5 changes: 3 additions & 2 deletions flytekit/types/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@

import numpy as _np
import pandas
from dataclasses_json import DataClassJsonMixin, config
from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
Expand Down Expand Up @@ -180,7 +181,7 @@ def get_handler(cls, t: Type) -> SchemaHandler:


@dataclass
class FlyteSchema(DataClassJsonMixin):
class FlyteSchema(DataClassJSONMixin):
remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String()))
"""
This is the main schema class that users should use.
Expand Down
5 changes: 3 additions & 2 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from typing import Dict, Generator, Optional, Type, Union

import _datetime
from dataclasses_json import DataClassJsonMixin, config
from dataclasses_json import config
from fsspec.utils import get_protocol
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, TypeAlias, get_args, get_origin

from flytekit import lazy_module
Expand Down Expand Up @@ -44,7 +45,7 @@


@dataclass
class StructuredDataset(DataClassJsonMixin):
class StructuredDataset(DataClassJSONMixin):
"""
This is the user facing StructuredDataset class. Please don't confuse it with the literals.StructuredDataset
class (that is just a model, a Python class representation of the protobuf).
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
# TODO: remove upper-bound after fixing change in contract
"dataclasses-json>=0.5.2,<0.5.12",
"marshmallow-jsonschema>=0.12.0",
"mashumaro>=3.9.1",
"marshmallow-enum",
"natsort>=7.0.1",
"docker-image-py>=0.1.10",
Expand Down
Loading