From c6a440cc67de5d1fc74b70130b653224a6b3199a Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 30 Aug 2023 10:26:27 +0800 Subject: [PATCH 01/38] Add encoder/decoder in structureDataset for snowflake. Signed-off-by: HH --- flytekit/core/type_engine.py | 3 + flytekit/types/structured/__init__.py | 13 +++ flytekit/types/structured/snowflake.py | 116 +++++++++++++++++++++++++ 3 files changed, 132 insertions(+) create mode 100644 flytekit/types/structured/snowflake.py diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index eacc3a15d8..23b29f645f 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -827,6 +827,7 @@ def lazy_import_transformers(cls): register_arrow_handlers, register_bigquery_handlers, register_pandas_handlers, + register_snowflake_handlers, ) if is_imported("tensorflow"): @@ -845,6 +846,8 @@ def lazy_import_transformers(cls): register_arrow_handlers() if is_imported("google.cloud.bigquery"): register_bigquery_handlers() + if is_imported("snowflake.connector"): + register_snowflake_handlers() if is_imported("numpy"): from flytekit.types import numpy # noqa: F401 diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 543117c865..407961717f 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -70,3 +70,16 @@ def register_bigquery_handlers(): "We won't register bigquery handler for structured dataset because " "we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery" ) + + +def register_snowflake_handlers(): + try: + from .snowflake import PandasToSnowflakeEncodingHandlers, SnowflakeToPandasDecodingHandler + + StructuredDatasetTransformerEngine.register(SnowflakeToPandasDecodingHandler()) + StructuredDatasetTransformerEngine.register(PandasToSnowflakeEncodingHandlers()) + + except ImportError: + logger.info( + "We won't register snowflake handler for structured dataset because " "we can't find package snowflake" + ) diff --git a/flytekit/types/structured/snowflake.py b/flytekit/types/structured/snowflake.py new file mode 100644 index 0000000000..b359ce424d --- /dev/null +++ b/flytekit/types/structured/snowflake.py @@ -0,0 +1,116 @@ +import re +import typing + +import pandas as pd +import pyarrow as pa +import snowflake.connector +from snowflake.connector.pandas_tools import write_pandas + +from flytekit import FlyteContext +from flytekit.models import literals +from flytekit.models.types import StructuredDatasetType +from flytekit.types.structured.structured_dataset import ( + StructuredDataset, + StructuredDatasetDecoder, + StructuredDatasetEncoder, + StructuredDatasetMetadata, +) + +SNOWFLAKE = "snowflake" + + +def get_private_key(): + import os + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.primitives.asymmetric import dsa + from cryptography.hazmat.primitives import serialization + import flytekit + + pk_path = flytekit.current_context().secrets.get_secrets_file(SNOWFLAKE, "rsa_key.p8") + + with open(pk_path, "rb") as key: + p_key= serialization.load_pem_private_key( + key.read(), + password=None, + backend=default_backend() + ) + + return p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption()) + + +def _write_to_sf(structured_dataset: StructuredDataset): + if structured_dataset.uri is None: + raise 
ValueError("structured_dataset.uri cannot be None.") + + uri = structured_dataset.uri + _, user, account, database, schema, warehouse, table = re.split("\\/|://|:", uri) + df = structured_dataset.dataframe + + conn = snowflake.connector.connect( + user=user, + account=account, + private_key=get_private_key(), + database=database, + schema=schema, + warehouse=warehouse + ) + + cs = conn.cursor() + write_pandas(conn, df, table) + + +def _read_from_sf( + flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata +) -> pd.DataFrame: + if flyte_value.uri is None: + raise ValueError("structured_dataset.uri cannot be None.") + + uri = flyte_value.uri + _, user, account, database, schema, warehouse, table = re.split("\\/|://|:", uri) + + conn = snowflake.connector.connect( + user=user, + account=account, + private_key=get_private_key(), + database=database, + schema=schema, + warehouse=warehouse + ) + + cs = conn.cursor() + cs.execute(f"select * from {table}") + + return cs.fetch_pandas_all() + + +class PandasToSnowflakeEncodingHandlers(StructuredDatasetEncoder): + def __init__(self): + super().__init__(pd.DataFrame, SNOWFLAKE, supported_format="") + + def encode( + self, + ctx: FlyteContext, + structured_dataset: StructuredDataset, + structured_dataset_type: StructuredDatasetType, + ) -> literals.StructuredDataset: + _write_to_sf(structured_dataset) + return literals.StructuredDataset( + uri=typing.cast(str, structured_dataset.uri), metadata=StructuredDatasetMetadata(structured_dataset_type) + ) + + +class SnowflakeToPandasDecodingHandler(StructuredDatasetDecoder): + def __init__(self): + super().__init__(pd.DataFrame, SNOWFLAKE, supported_format="") + + def decode( + self, + ctx: FlyteContext, + flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, + ) -> pd.DataFrame: + return _read_from_sf(flyte_value, current_task_metadata) From 45f1c1fa216e7b289e4ac6c4e88bf8a7c4b5b344 Mon Sep 17 00:00:00 2001 From: HH Date: Thu, 31 Aug 2023 12:21:37 +0800 Subject: [PATCH 02/38] add unit-test for snowflake structure dataset encoder/decoder Signed-off-by: HH --- .../structured_dataset/test_snowflake.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/flytekit/unit/types/structured_dataset/test_snowflake.py diff --git a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py new file mode 100644 index 0000000000..8ea85e9e17 --- /dev/null +++ b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py @@ -0,0 +1,48 @@ +import mock +import pytest +import pandas as pd +from typing_extensions import Annotated + +from flytekit import StructuredDataset, kwtypes, task, workflow + +pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) +my_cols = kwtypes(Name=str, Age=int) + + +@task +def gen_df() -> Annotated[pd.DataFrame, my_cols, "parquet"]: + return pd_df + + +@task +def t1(df: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]: + return StructuredDataset(dataframe=df, uri="snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table") + + +@task +def t2(sd: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame: + return sd.open(pd.DataFrame).all() + + +@workflow +def wf() -> pd.DataFrame: + df = gen_df() + sd = t1(df=df) + return t2(sd=sd) + + +@mock.patch("snowflake.connector.connect") +@pytest.mark.asyncio +async def test_sf_wf(mock_connect): + class mock_pages: + def 
to_dataframe(self): + return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + + class mock_rows: + pages = [mock_pages()] + + mock_connect_instance = mock_connect.return_value + mock_coursor_instance = mock_connect.cursor.return_value + mock_coursor_instance.fetch_pandas_all.return_value = mock_rows + + assert wf().equals(pd_df) From cc450f3ccd0f83b555be84280e28f625cde2bdbe Mon Sep 17 00:00:00 2001 From: HH Date: Thu, 31 Aug 2023 17:31:57 +0800 Subject: [PATCH 03/38] add unit-test for snowflake structure dataset encoder/decoder Signed-off-by: HH --- flytekit/types/structured/snowflake.py | 34 ++++++------------- .../structured_dataset/test_snowflake.py | 15 ++++---- 2 files changed, 17 insertions(+), 32 deletions(-) diff --git a/flytekit/types/structured/snowflake.py b/flytekit/types/structured/snowflake.py index b359ce424d..9f28734f43 100644 --- a/flytekit/types/structured/snowflake.py +++ b/flytekit/types/structured/snowflake.py @@ -2,7 +2,6 @@ import typing import pandas as pd -import pyarrow as pa import snowflake.connector from snowflake.connector.pandas_tools import write_pandas @@ -20,26 +19,21 @@ def get_private_key(): - import os from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives.asymmetric import rsa - from cryptography.hazmat.primitives.asymmetric import dsa from cryptography.hazmat.primitives import serialization + import flytekit pk_path = flytekit.current_context().secrets.get_secrets_file(SNOWFLAKE, "rsa_key.p8") with open(pk_path, "rb") as key: - p_key= serialization.load_pem_private_key( - key.read(), - password=None, - backend=default_backend() - ) + p_key = serialization.load_pem_private_key(key.read(), password=None, backend=default_backend()) return p_key.private_bytes( encoding=serialization.Encoding.DER, format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption()) + encryption_algorithm=serialization.NoEncryption(), + ) def _write_to_sf(structured_dataset: StructuredDataset): @@ -51,15 +45,9 @@ def _write_to_sf(structured_dataset: StructuredDataset): df = structured_dataset.dataframe conn = snowflake.connector.connect( - user=user, - account=account, - private_key=get_private_key(), - database=database, - schema=schema, - warehouse=warehouse + user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse ) - cs = conn.cursor() write_pandas(conn, df, table) @@ -73,18 +61,16 @@ def _read_from_sf( _, user, account, database, schema, warehouse, table = re.split("\\/|://|:", uri) conn = snowflake.connector.connect( - user=user, - account=account, - private_key=get_private_key(), - database=database, - schema=schema, - warehouse=warehouse + user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse ) cs = conn.cursor() cs.execute(f"select * from {table}") - return cs.fetch_pandas_all() + dff = cs.fetch_pandas_all() + print("cs", cs) + print("dff", dff) + return dff class PandasToSnowflakeEncodingHandlers(StructuredDatasetEncoder): diff --git a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py index 8ea85e9e17..0c88be40d5 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py +++ b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py @@ -1,6 +1,6 @@ import mock -import pytest import pandas as pd +import pytest from typing_extensions import Annotated from flytekit import 
StructuredDataset, kwtypes, task, workflow @@ -16,7 +16,9 @@ def gen_df() -> Annotated[pd.DataFrame, my_cols, "parquet"]: @task def t1(df: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]: - return StructuredDataset(dataframe=df, uri="snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table") + return StructuredDataset( + dataframe=df, uri="snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table" + ) @task @@ -34,15 +36,12 @@ def wf() -> pd.DataFrame: @mock.patch("snowflake.connector.connect") @pytest.mark.asyncio async def test_sf_wf(mock_connect): - class mock_pages: + class mock_dataframe: def to_dataframe(self): return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) - class mock_rows: - pages = [mock_pages()] - mock_connect_instance = mock_connect.return_value - mock_coursor_instance = mock_connect.cursor.return_value - mock_coursor_instance.fetch_pandas_all.return_value = mock_rows + mock_coursor_instance = mock_connect_instance.cursor.return_value + mock_coursor_instance.fetch_pandas_all.return_value = mock_dataframe().to_dataframe() assert wf().equals(pd_df) From 9ad6cb16fef33d1c2cf3a7f9771dc877669d61d7 Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 1 Sep 2023 12:27:18 +0800 Subject: [PATCH 04/38] let lazy_import_transformers force load the snowflake-connector Signed-off-by: HH --- flytekit/core/type_engine.py | 7 +++++-- flytekit/types/structured/__init__.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 23b29f645f..b5b14bdac0 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -846,11 +846,14 @@ def lazy_import_transformers(cls): register_arrow_handlers() if is_imported("google.cloud.bigquery"): register_bigquery_handlers() - if is_imported("snowflake.connector"): - register_snowflake_handlers() if is_imported("numpy"): from flytekit.types import numpy # noqa: F401 + try: + register_snowflake_handlers() + except ValueError as e: + logger.debug(f"Attempted to register the Snowflake handler but failed due to: {str(e)}") + @classmethod def to_literal_type(cls, python_type: Type) -> LiteralType: """ diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 407961717f..617e4bcafa 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -81,5 +81,6 @@ def register_snowflake_handlers(): except ImportError: logger.info( - "We won't register snowflake handler for structured dataset because " "we can't find package snowflake" + "We won't register snowflake handler for structured dataset because " + "we can't find package snowflakee-connector-python" ) From 958c2b6ed0348ea7cbdaad477d853bd7f18fe348 Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 1 Sep 2023 12:33:50 +0800 Subject: [PATCH 05/38] add mock get_private_key for unit-test Signed-off-by: HH --- tests/flytekit/unit/types/structured_dataset/test_snowflake.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py index 0c88be40d5..c957c0bbce 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py +++ b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py @@ -33,9 +33,10 @@ def wf() -> pd.DataFrame: return t2(sd=sd) +@mock.patch("flytekit.types.structured.snowflake.get_private_key", 
return_value="pb") @mock.patch("snowflake.connector.connect") @pytest.mark.asyncio -async def test_sf_wf(mock_connect): +async def test_sf_wf(mock_connect, mock_get_private_key): class mock_dataframe: def to_dataframe(self): return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) From 7fac349fc0ad7c32ace4f17ca4c0fe47ef1d5b32 Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 1 Sep 2023 12:36:33 +0800 Subject: [PATCH 06/38] add snowflake-connector-python in dev-requirements.in Signed-off-by: HH --- dev-requirements.in | 1 + 1 file changed, 1 insertion(+) diff --git a/dev-requirements.in b/dev-requirements.in index 2c7ddd00c2..4a9df85e53 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -11,6 +11,7 @@ pre-commit codespell google-cloud-bigquery google-cloud-bigquery-storage +snowflake-connector-python IPython keyrings.alt From c9d03bc00282d140e8b3c575ab14044a2df3eb5a Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 26 Jul 2023 16:43:53 +0800 Subject: [PATCH 07/38] add setup.py Signed-off-by: HH --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index e78b815f5c..6496777d1d 100644 --- a/setup.py +++ b/setup.py @@ -62,6 +62,7 @@ # TODO: remove upper-bound after fixing change in contract "dataclasses-json>=0.5.2,<0.5.12", "marshmallow-jsonschema>=0.12.0", + "mashumaro>=3.8.1", "marshmallow-enum", "natsort>=7.0.1", "docker-image-py>=0.1.10", From 42d3bbb4cff0e059b56d63c5768d4836f46c36d9 Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 26 Jul 2023 16:44:42 +0800 Subject: [PATCH 08/38] add import DataClassJSONMixin Signed-off-by: HH --- flytekit/core/type_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index b5b14bdac0..fb342e6c9e 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -15,6 +15,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast from dataclasses_json import DataClassJsonMixin, dataclass_json +from mashumaro.mixins.json import DataClassJSONMixin from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct from google.protobuf.json_format import MessageToDict as _MessageToDict From c74951c2f940f65ca6b0657ed0fdac32b64b1982 Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 26 Jul 2023 16:45:58 +0800 Subject: [PATCH 09/38] support DataClassJSONMixin to json in DataClassTransformer Signed-off-by: HH --- flytekit/core/type_engine.py | 38 +++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index fb342e6c9e..c95e3205ca 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -345,22 +345,27 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: f"Type {t} cannot be parsed." ) - if not issubclass(t, DataClassJsonMixin): + if not issubclass(t, DataClassJsonMixin) and not issubclass(t, DataClassJSONMixin): raise AssertionError( f"Dataclass {t} should be decorated with @dataclass_json or be a subclass of DataClassJsonMixin to be " "serialized correctly" ) schema = None try: - s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema() - for _, v in s.fields.items(): - # marshmallow-jsonschema only supports enums loaded by name. 
- # https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 - if isinstance(v, EnumField): - v.load_by = LoadDumpOptions.name - from marshmallow_jsonschema import JSONSchema - - schema = JSONSchema().dump(s) + if issubclass(t, DataClassJsonMixin): + s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema() + for _, v in s.fields.items(): + # marshmallow-jsonschema only supports enums loaded by name. + # https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 + if isinstance(v, EnumField): + v.load_by = LoadDumpOptions.name + # check if DataClass mixin + from marshmallow_jsonschema import JSONSchema + + schema = JSONSchema().dump(s) + else: # DataClassJSONMixin + from mashumaro.jsonschema import build_json_schema + schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_json() except Exception as e: # https://github.com/lovasoa/marshmallow_dataclass/issues/13 logger.warning( @@ -377,14 +382,19 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for " f"user defined datatypes in Flytekit" ) - if not issubclass(type(python_val), DataClassJsonMixin): + if not issubclass(type(python_val), DataClassJsonMixin) and not issubclass(type(python_val), DataClassJSONMixin): raise TypeTransformerFailedError( - f"Dataclass {python_type} should be decorated with @dataclass_json or be a subclass of " - "DataClassJsonMixin to be serialized correctly" + f"Dataclass {python_type} should be decorated with @dataclass_json to be " f"serialized correctly" ) self._serialize_flyte_type(python_val, python_type) + + if issubclass(type(python_val), DataClassJsonMixin): + json_str = cast(DataClassJsonMixin, python_val).to_json() # type: ignore + else: + json_str = cast(DataClassJSONMixin, python_val).to_json() # type: ignore + return Literal( - scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct())) + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())) # type: ignore ) def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: From de73c95947f4ccdd16383d71bf1fd6fbbf877125 Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 26 Jul 2023 16:46:16 +0800 Subject: [PATCH 10/38] support DataClassJSONMixin from json in DataClassTransformer Signed-off-by: HH --- flytekit/core/type_engine.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index c95e3205ca..f7036db2f1 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -639,13 +639,18 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"{expected_python_type} is not of type @dataclass, only Dataclasses are supported for " "user defined datatypes in Flytekit" ) - if not issubclass(expected_python_type, DataClassJsonMixin): + if not issubclass(expected_python_type, DataClassJsonMixin) and not issubclass(expected_python_type, DataClassJSONMixin): raise TypeTransformerFailedError( f"Dataclass {expected_python_type} should be decorated with @dataclass_json or be a subclass of " "DataClassJsonMixin to be serialized correctly" ) json_str = _json_format.MessageToJson(lv.scalar.generic) - dc = cast(DataClassJsonMixin, 
expected_python_type).from_json(json_str) + + if issubclass(expected_python_type, DataClassJsonMixin): + dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str) # type: ignore + else: + dc = cast(DataClassJSONMixin, expected_python_type).from_json(json_str) # type: ignore + dc = self._fix_structured_dataset_type(expected_python_type, dc) return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type)) From 449044e6bf96c011706b19736c49857ed2f71905 Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 26 Jul 2023 17:42:36 +0800 Subject: [PATCH 11/38] fix Json Schema with to_dict Signed-off-by: HH --- flytekit/core/type_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index f7036db2f1..555c3972cb 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -365,7 +365,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: schema = JSONSchema().dump(s) else: # DataClassJSONMixin from mashumaro.jsonschema import build_json_schema - schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_json() + schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict() except Exception as e: # https://github.com/lovasoa/marshmallow_dataclass/issues/13 logger.warning( From 262d92df8a6b9d3759f700c386ef1ebfc3011ac6 Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Wed, 26 Jul 2023 17:01:52 +0000 Subject: [PATCH 12/38] fix lint issue files Signed-off-by: hhcs9527 --- flytekit/core/type_engine.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 555c3972cb..ba28869e6e 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -15,7 +15,6 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast from dataclasses_json import DataClassJsonMixin, dataclass_json -from mashumaro.mixins.json import DataClassJSONMixin from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct from google.protobuf.json_format import MessageToDict as _MessageToDict @@ -23,6 +22,7 @@ from google.protobuf.message import Message from google.protobuf.struct_pb2 import Struct from marshmallow_enum import EnumField, LoadDumpOptions +from mashumaro.mixins.json import DataClassJSONMixin from typing_extensions import Annotated, get_args, get_origin from flytekit.core.annotation import FlyteAnnotation @@ -363,8 +363,9 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: from marshmallow_jsonschema import JSONSchema schema = JSONSchema().dump(s) - else: # DataClassJSONMixin + else: # DataClassJSONMixin from mashumaro.jsonschema import build_json_schema + schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict() except Exception as e: # https://github.com/lovasoa/marshmallow_dataclass/issues/13 @@ -382,9 +383,12 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for " f"user defined datatypes in Flytekit" ) - if not issubclass(type(python_val), DataClassJsonMixin) and not issubclass(type(python_val), DataClassJSONMixin): + if not issubclass(type(python_val), DataClassJsonMixin) and not issubclass( + type(python_val), DataClassJSONMixin + ): raise 
TypeTransformerFailedError( - f"Dataclass {python_type} should be decorated with @dataclass_json to be " f"serialized correctly" + f"Dataclass {python_type} should be decorated with @dataclass_json or subclass of DataClassJSONMixin to be " + f"serialized correctly" ) self._serialize_flyte_type(python_val, python_type) @@ -393,9 +397,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp else: json_str = cast(DataClassJSONMixin, python_val).to_json() # type: ignore - return Literal( - scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())) # type: ignore - ) + return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: # dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is @@ -639,7 +641,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"{expected_python_type} is not of type @dataclass, only Dataclasses are supported for " "user defined datatypes in Flytekit" ) - if not issubclass(expected_python_type, DataClassJsonMixin) and not issubclass(expected_python_type, DataClassJSONMixin): + if not issubclass(expected_python_type, DataClassJsonMixin) and not issubclass( + expected_python_type, DataClassJSONMixin + ): raise TypeTransformerFailedError( f"Dataclass {expected_python_type} should be decorated with @dataclass_json or be a subclass of " "DataClassJsonMixin to be serialized correctly" From b1c7875ab05c0405c14f4588b6c1eaadf0925bd3 Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Fri, 28 Jul 2023 07:08:01 +0000 Subject: [PATCH 13/38] add new test for testing mashumaro dataclass Signed-off-by: hhcs9527 --- flytekit/core/type_engine.py | 47 +- tests/flytekit/unit/core/test_type_engine.py | 473 +++++++++++++++++++ 2 files changed, 512 insertions(+), 8 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index ba28869e6e..25ef683025 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -665,10 +665,13 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: @lru_cache(typed=True) def guess_python_type(self, literal_type: LiteralType) -> Type[T]: # type: ignore if literal_type.simple == SimpleType.STRUCT: - if literal_type.metadata is not None and DEFINITIONS in literal_type.metadata: - schema_name = literal_type.metadata["$ref"].split("/")[-1] - return convert_json_schema_to_python_class(literal_type.metadata[DEFINITIONS], schema_name) - + if literal_type.metadata is not None: + if DEFINITIONS in literal_type.metadata: + schema_name = literal_type.metadata["$ref"].split("/")[-1] + return convert_json_schema_to_python_class(literal_type.metadata[DEFINITIONS], schema_name) + else: + schema_name = literal_type.metadata["title"] + return convert_json_schema_to_python_class(literal_type.metadata, schema_name, True) raise ValueError(f"Dataclass transformer cannot reverse {literal_type}") @@ -1575,14 +1578,42 @@ def to_literal( def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: return expected_python_type(lv.scalar.primitive.string_value) # type: ignore +def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any): + attribute_list = [] + for property_key, property_val in schema["properties"].items(): + if property_val.get("anyOf"): + property_type = property_val["anyOf"][0]["type"] + elif 
property_val.get("enum") : + property_type = "enum" + else: + property_type = property_val["type"] + # Handle list + if property_type == "array": + attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore + # Handle dataclass and dict + elif property_type == "object": + if property_val.get("anyOf"): + attribute_list.append((property_key, convert_json_schema_to_python_class(property_val["anyOf"][0], schema_name, True))) + elif property_val.get("additionalProperties"): + attribute_list.append( + (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore + ) + else: + attribute_list.append((property_key, convert_json_schema_to_python_class(property_val, schema_name, True))) + elif property_type == "enum": + attribute_list.append([property_key, str]) # type: ignore + # Handle int, float, bool or str + else: + attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore + return attribute_list -def convert_json_schema_to_python_class(schema: Dict[str, Any], schema_name: str) -> Type[Any]: +def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: # type: ignore """ Generate a model class based on the provided JSON Schema :param schema: dict representing valid JSON schema :param schema_name: dataclass name of return type """ - attribute_list: List[Tuple[str, type]] = [] + attribute_list = [] for property_key, property_val in schema[schema_name]["properties"].items(): property_type = property_val["type"] # Handle list @@ -1601,13 +1632,13 @@ def convert_json_schema_to_python_class(schema: Dict[str, Any], schema_name: str attribute_list.append((property_key, Dict[str, _get_element_type(property_val)])) # type: ignore[misc,index] # Handle int, float, bool or str else: - attribute_list.append((property_key, _get_element_type(property_val))) + attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) def _get_element_type(element_property: typing.Dict[str, str]) -> Type: - element_type = element_property["type"] + element_type = [e_property["type"] for e_property in element_property["anyOf"]] if element_property.get("anyOf") else element_property["type"] element_format = element_property["format"] if "format" in element_property else None if type(element_type) == list: diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 68483cf430..0d09ce12ce 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -19,6 +19,8 @@ from google.protobuf import struct_pb2 as _struct from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema +from mashumaro.mixins.json import DataClassJSONMixin +import mashumaro from pandas._testing import assert_frame_equal from typing_extensions import Annotated, get_args, get_origin @@ -171,6 +173,31 @@ class Foo(DataClassJsonMixin): assert pv[0].b == Bar(v=[1, 2, 99], w=[3.1415, 2.7182]) +@dataclass +class Bar(DataClassJSONMixin): + v: typing.Optional[typing.List[int]] + w: typing.Optional[typing.List[float]] + + +@dataclass +class Foo(DataClassJSONMixin): + a: typing.Optional[typing.List[str]] + b: Bar + + +def test_list_of_single_dataclassjsonmixin(): + foo = Foo(a=["abc", "def"], b=Bar(v=[1, 2, 99], w=[3.1415, 2.7182])) + generic = 
_json_format.Parse(typing.cast(DataClassJSONMixin, foo).to_json(), _struct.Struct()) + lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))])) + + transformer = TypeEngine.get_transformer(typing.List) + ctx = FlyteContext.current_context() + + pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo]) + assert pv[0].a == ["abc", "def"] + assert pv[0].b == Bar(v=[1, 2, 99], w=[3.1415, 2.7182]) + + def test_annotated_type(): class JsonTypeTransformer(TypeTransformer[T]): LiteralType = LiteralType( @@ -264,6 +291,65 @@ class Foo(DataClassJsonMixin): assert pv[0] == dataclass_from_dict(Foo, asdict(guessed_pv[0])) +@dataclass +class Bar_getting_python_value(DataClassJSONMixin): + v: typing.Union[int, None] + w: typing.Optional[str] + x: float + y: str + z: typing.Dict[str, bool] + + +@dataclass +class Foo_getting_python_value(DataClassJSONMixin): + u: typing.Optional[int] + v: typing.Optional[int] + w: int + x: typing.List[int] + y: typing.Dict[str, str] + z: Bar_getting_python_value + + +def test_list_of_dataclassjsonmixin_getting_python_value(): + foo = Foo_getting_python_value( + u=5, + v=None, + w=1, + x=[1], + y={"hello": "10"}, + z=Bar_getting_python_value(v=3, w=None, x=1.0, y="hello", z={"world": False}), + ) + generic = _json_format.Parse(typing.cast(DataClassJSONMixin, foo).to_json(), _struct.Struct()) + lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))])) + + transformer = TypeEngine.get_transformer(typing.List) + ctx = FlyteContext.current_context() + + from mashumaro.jsonschema import build_json_schema + + schema = build_json_schema(typing.cast(DataClassJSONMixin, Foo_getting_python_value)).to_dict() + foo_class = convert_json_schema_to_python_class(schema, "FooSchema", True) + + guessed_pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class]) + pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo_getting_python_value]) + assert isinstance(guessed_pv, list) + assert guessed_pv[0].u == pv[0].u + assert guessed_pv[0].v == pv[0].v + assert guessed_pv[0].w == pv[0].w + assert guessed_pv[0].x == pv[0].x + assert guessed_pv[0].y == pv[0].y + assert guessed_pv[0].z.x == pv[0].z.x + assert type(guessed_pv[0].u) == int + assert guessed_pv[0].v is None + assert type(guessed_pv[0].w) == int + assert type(guessed_pv[0].z.v) == int + assert type(guessed_pv[0].z.x) == float + assert guessed_pv[0].z.v == pv[0].z.v + assert guessed_pv[0].z.y == pv[0].z.y + assert guessed_pv[0].z.z == pv[0].z.z + assert pv[0] == dataclass_from_dict(Foo_getting_python_value, asdict(guessed_pv[0])) + + def test_file_no_downloader_default(): # The idea of this test is to assert that if a FlyteFile is created with no download specified, # then it should return the set path itself. 
This matches if we use open method @@ -401,6 +487,25 @@ class Foo(DataClassJsonMixin): _ = foo.c +def test_convert_json_schema_to_python_class_with_dataclassjsonmixin(): + @dataclass + class Foo(DataClassJSONMixin): + x: int + y: str + + # schema = JSONSchema().dump(typing.cast(DataClassJSONMixin, Foo).schema()) + from mashumaro.jsonschema import build_json_schema + + schema = build_json_schema(typing.cast(DataClassJSONMixin, Foo)).to_dict() + foo_class = convert_json_schema_to_python_class(schema, "FooSchema", is_dataclass_json_mixin=True) + foo = foo_class(x=1, y="hello") + foo.x = 2 + assert foo.x == 2 + assert foo.y == "hello" + with pytest.raises(AttributeError): + _ = foo.c + + def test_list_transformer(): l0 = Literal(scalar=Scalar(primitive=Primitive(integer=3))) l1 = Literal(scalar=Scalar(primitive=Primitive(integer=4))) @@ -591,6 +696,123 @@ def test_dataclass_transformer(): assert t.metadata is None +@dataclass +class InnerStruct_transformer(DataClassJSONMixin): + a: int + b: typing.Optional[str] + c: typing.List[int] + + +@dataclass +class TestStruct_transformer(DataClassJSONMixin): + s: InnerStruct_transformer + m: typing.Dict[str, str] + + +@dataclass +class TestStructB_transformer(DataClassJSONMixin): + s: InnerStruct_transformer + m: typing.Dict[int, str] + n: typing.Optional[typing.List[typing.List[int]]] = None + o: typing.Optional[typing.Dict[int, typing.Dict[int, int]]] = None + + +@dataclass +class TestStructC_transformer(DataClassJSONMixin): + s: InnerStruct_transformer + m: typing.Dict[str, int] + + +@dataclass +class TestStructD_transformer(DataClassJSONMixin): + s: InnerStruct_transformer + m: typing.Dict[str, typing.List[int]] + + +@dataclass # to ask => not support => failed right away +class UnsupportedSchemaType_transformer: + _a:str="Hello" + + +def test_dataclass_transformer_with_dataclassjsonmixin(): + schema = { + "type": "object", + "title": "TestStruct_transformer", + "properties": { + "s": { + "type": "object", + "title": "InnerStruct_transformer", + "properties": { + "a": { + "type": "integer" + }, + "b": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ] + }, + "c": { + "type": "array", + "items": { + "type": "integer" + } + } + }, + "additionalProperties": False, + "required": [ + "a", + "b", + "c" + ] + }, + "m": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "propertyNames": { + "type": "string" + } + } + }, + "additionalProperties": False, + "required": [ + "s", + "m" + ] + } + + tf = DataclassTransformer() + t = tf.get_literal_type(TestStruct_transformer) + assert t is not None + assert t.simple is not None + assert t.simple == SimpleType.STRUCT + assert t.metadata is not None + assert t.metadata == schema + + t = TypeEngine.to_literal_type(TestStruct_transformer) + assert t is not None + assert t.simple is not None + assert t.simple == SimpleType.STRUCT + assert t.metadata is not None + assert t.metadata == schema + + +@pytest.mark.xfail(raises=mashumaro.exceptions.UnserializableField) +def test_unsupported_schema_type(): + # The code that is expected to raise the exception during class definition + @dataclass + class UnsupportedNestedStruct_transformer(DataClassJSONMixin): + a: int + s: UnsupportedSchemaType_transformer + + tf = DataclassTransformer() + t = tf.get_literal_type(UnsupportedNestedStruct_transformer) def test_dataclass_int_preserving(): ctx = FlyteContext.current_context() @@ -700,6 +922,90 @@ class TestFileStruct(DataClassJsonMixin): assert o.i_prime == A(a=99) +@dataclass +class 
A_optional_flytefile(DataClassJSONMixin): + a: int + + +@dataclass +class TestFileStruct_optional_flytefile(DataClassJSONMixin): + a: FlyteFile + b: typing.Optional[FlyteFile] + b_prime: typing.Optional[FlyteFile] + c: typing.Union[FlyteFile, None] + d: typing.List[FlyteFile] + e: typing.List[typing.Optional[FlyteFile]] + e_prime: typing.List[typing.Optional[FlyteFile]] + f: typing.Dict[str, FlyteFile] + g: typing.Dict[str, typing.Optional[FlyteFile]] + g_prime: typing.Dict[str, typing.Optional[FlyteFile]] + h: typing.Optional[FlyteFile] = None + h_prime: typing.Optional[FlyteFile] = None + i: typing.Optional[A_optional_flytefile] = None + i_prime: typing.Optional[A_optional_flytefile] = field(default_factory=lambda: A_optional_flytefile(a=99)) + + +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +def test_optional_flytefile_in_dataclassjsonmixin(mock_upload_dir): + mock_upload_dir.return_value = True + + remote_path = "s3://tmp/file" + with tempfile.TemporaryFile() as f: + f.write(b"abc") + f1 = FlyteFile("f1", remote_path=remote_path) + o = TestFileStruct_optional_flytefile( + a=f1, + b=f1, + b_prime=None, + c=f1, + d=[f1], + e=[f1], + e_prime=[None], + f={"a": f1}, + g={"a": f1}, + g_prime={"a": None}, + h=f1, + i=A_optional_flytefile(a=42), + ) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct_optional_flytefile) + lv = tf.to_literal(ctx, o, TestFileStruct_optional_flytefile, lt) + + assert lv.scalar.generic["a"] == remote_path + assert lv.scalar.generic["b"] == remote_path + assert lv.scalar.generic["b_prime"] is None + assert lv.scalar.generic["c"] == remote_path + assert lv.scalar.generic["d"].values[0].string_value == remote_path + assert lv.scalar.generic["e"].values[0].string_value == remote_path + assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value" + assert lv.scalar.generic["f"]["a"] == remote_path + assert lv.scalar.generic["g"]["a"] == remote_path + assert lv.scalar.generic["g_prime"]["a"] is None + assert lv.scalar.generic["h"] == remote_path + assert lv.scalar.generic["h_prime"] is None + assert lv.scalar.generic["i"]["a"] == 42 + assert lv.scalar.generic["i_prime"]["a"] == 99 + + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_optional_flytefile) + + assert o.a.path == ot.a.remote_source + assert o.b.path == ot.b.remote_source + assert ot.b_prime is None + assert o.c.path == ot.c.remote_source + assert o.d[0].path == ot.d[0].remote_source + assert o.e[0].path == ot.e[0].remote_source + assert o.e_prime == [None] + assert o.f["a"].path == ot.f["a"].remote_source + assert o.g["a"].path == ot.g["a"].remote_source + assert o.g_prime == {"a": None} + assert o.h.path == ot.h.remote_source + assert ot.h_prime is None + assert o.i == ot.i + assert o.i_prime == A_optional_flytefile(a=99) + + def test_flyte_file_in_dataclass(): @dataclass class TestInnerFileStruct(DataClassJsonMixin): @@ -743,6 +1049,53 @@ class TestFileStruct(DataClassJsonMixin): assert not ctx.file_access.is_remote(ot.b.e["hello"].path) +@dataclass +class TestInnerFileStruct_flyte_file(DataClassJSONMixin): + a: JPEGImageFile + b: typing.List[FlyteFile] + c: typing.Dict[str, FlyteFile] + d: typing.List[FlyteFile] + e: typing.Dict[str, FlyteFile] + + +@dataclass +class TestFileStruct_flyte_file(DataClassJSONMixin): + a: FlyteFile + b: TestInnerFileStruct_flyte_file + + +def test_flyte_file_in_dataclassjsonmixin(): + remote_path = "s3://tmp/file" + f1 = FlyteFile(remote_path) + 
f2 = FlyteFile("/tmp/file") + f2._remote_source = remote_path + o = TestFileStruct_flyte_file( + a=f1, + b=TestInnerFileStruct_flyte_file( + a=JPEGImageFile("s3://tmp/file.jpeg"), b=[f1], c={"hello": f1}, d=[f2], e={"hello": f2} + ), + ) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct_flyte_file) + lv = tf.to_literal(ctx, o, TestFileStruct_flyte_file, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_flyte_file) + assert ot.a._downloader is not noop + assert ot.b.a._downloader is not noop + assert ot.b.b[0]._downloader is not noop + assert ot.b.c["hello"]._downloader is not noop + + assert o.a.path == ot.a.remote_source + assert o.b.a.path == ot.b.a.remote_source + assert o.b.b[0].path == ot.b.b[0].remote_source + assert o.b.c["hello"].path == ot.b.c["hello"].remote_source + assert ot.b.d[0].remote_source == remote_path + assert not ctx.file_access.is_remote(ot.b.d[0].path) + assert ot.b.e["hello"].remote_source == remote_path + assert not ctx.file_access.is_remote(ot.b.e["hello"].path) + + def test_flyte_directory_in_dataclass(): @dataclass class TestInnerFileStruct(DataClassJsonMixin): @@ -789,6 +1142,56 @@ class TestFileStruct(DataClassJsonMixin): assert o.b.e["hello"].path == ot.b.e["hello"].remote_source +@dataclass +class TestInnerFileStruct_flyte_directory(DataClassJSONMixin): + a: TensorboardLogs + b: typing.List[FlyteDirectory] + c: typing.Dict[str, FlyteDirectory] + d: typing.List[FlyteDirectory] + e: typing.Dict[str, FlyteDirectory] + + +@dataclass +class TestFileStruct_flyte_directory(DataClassJSONMixin): + a: FlyteDirectory + b: TestInnerFileStruct_flyte_directory + + +def test_flyte_directory_in_dataclassjsonmixin(): + remote_path = "s3://tmp/file" + tempdir = tempfile.mkdtemp(prefix="flyte-") + f1 = FlyteDirectory(tempdir) + f1._remote_source = remote_path + f2 = FlyteDirectory(remote_path) + o = TestFileStruct_flyte_directory( + a=f1, + b=TestInnerFileStruct_flyte_directory( + a=TensorboardLogs("s3://tensorboard"), b=[f1], c={"hello": f1}, d=[f2], e={"hello": f2} + ), + ) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct_flyte_directory) + lv = tf.to_literal(ctx, o, TestFileStruct_flyte_directory, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_flyte_directory) + + assert ot.a._downloader is not noop + assert ot.b.a._downloader is not noop + assert ot.b.b[0]._downloader is not noop + assert ot.b.c["hello"]._downloader is not noop + + assert o.a.remote_directory == ot.a.remote_directory + assert not ctx.file_access.is_remote(ot.a.path) + assert o.b.a.path == ot.b.a.remote_source + assert o.b.b[0].remote_directory == ot.b.b[0].remote_directory + assert not ctx.file_access.is_remote(ot.b.b[0].path) + assert o.b.c["hello"].remote_directory == ot.b.c["hello"].remote_directory + assert not ctx.file_access.is_remote(ot.b.c["hello"].path) + assert o.b.d[0].path == ot.b.d[0].remote_source + assert o.b.e["hello"].path == ot.b.e["hello"].remote_source + + def test_structured_dataset_in_dataclass(): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) People = Annotated[StructuredDataset, "parquet", kwtypes(Name=str, Age=int)] @@ -973,6 +1376,35 @@ class Bar(DataClassJsonMixin): DataclassTransformer().assert_type(gt, pv) +@dataclass +class ArgsAssert(DataClassJSONMixin): + x: int + y: typing.Optional[str] + +@dataclass +class SchemaArgsAssert(DataClassJSONMixin): + x: 
typing.Optional[ArgsAssert] + + +def test_assert_dataclassjsonmixin_type(): + pt = SchemaArgsAssert + lt = TypeEngine.to_literal_type(pt) + gt = TypeEngine.guess_python_type(lt) + pv = SchemaArgsAssert(x=ArgsAssert(x=3, y="hello")) + DataclassTransformer().assert_type(gt, pv) + DataclassTransformer().assert_type(SchemaArgsAssert, pv) + + @dataclass + class Bar(DataClassJSONMixin): + x: int + + pv = Bar(x=3) + with pytest.raises( + TypeTransformerFailedError, match="Type of Val '' is not an instance of " + ): + DataclassTransformer().assert_type(gt, pv) + + def test_union_transformer(): assert UnionTransformer.is_optional_type(typing.Optional[int]) assert not UnionTransformer.is_optional_type(str) @@ -1288,6 +1720,28 @@ class Datum(DataClassJsonMixin): assert datum.y.value == pv.y +def test_enum_in_dataclassjsonmixin(): + @dataclass + class Datum(DataClassJSONMixin): + x: int + y: Color + + lt = TypeEngine.to_literal_type(Datum) + from mashumaro.jsonschema import build_json_schema + + schema = build_json_schema(typing.cast(DataClassJSONMixin, Datum)).to_dict() + assert lt.metadata == schema + + transformer = DataclassTransformer() + ctx = FlyteContext.current_context() + datum = Datum(5, Color.RED) + lv = transformer.to_literal(ctx, datum, Datum, lt) + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert datum.x == pv.x + assert datum.y.value == pv.y + + @pytest.mark.parametrize( "python_value,python_types,expected_literal_map", [ @@ -1569,6 +2023,25 @@ def hello(self): assert hasattr(lr.get("a", Foo), "hello") is True +def test_guess_of_dataclassjsonmixin(): + @dataclass + class Foo(DataClassJSONMixin): + x: int + y: str + z: typing.Dict[str, int] + + def hello(self): + ... + + lt = TypeEngine.to_literal_type(Foo) + foo = Foo(1, "hello", {"world": 3}) + lv = TypeEngine.to_literal(FlyteContext.current_context(), foo, Foo, lt) + lit_dict = {"a": lv} + lr = LiteralsResolver(lit_dict) + assert lr.get("a", Foo) == foo + assert hasattr(lr.get("a", Foo), "hello") is True + + def test_flyte_dir_in_union(): pt = typing.Union[str, FlyteDirectory, FlyteFile] lt = TypeEngine.to_literal_type(pt) From 631602d7348a3f61dcde5475e97e86af9e3dd2bb Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Sun, 30 Jul 2023 16:36:15 +0000 Subject: [PATCH 14/38] support structure dataclass and flytescheme with DataClassJSONMixin Signed-off-by: hhcs9527 --- flytekit/types/schema/types.py | 3 ++- flytekit/types/structured/structured_dataset.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index dc7ca816ba..e039832fef 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -19,6 +19,7 @@ from flytekit.loggers import logger from flytekit.models.literals import Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType +from mashumaro.mixins.json import DataClassJSONMixin T = typing.TypeVar("T") @@ -180,7 +181,7 @@ def get_handler(cls, t: Type) -> SchemaHandler: @dataclass -class FlyteSchema(DataClassJsonMixin): +class FlyteSchema(DataClassJSONMixin): remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) """ This is the main schema class that users should use. 
diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index a88de49974..cc61b04309 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -22,6 +22,7 @@ from flytekit.models import types as type_models from flytekit.models.literals import Literal, Scalar, StructuredDatasetMetadata from flytekit.models.types import LiteralType, SchemaType, StructuredDatasetType +from mashumaro.mixins.json import DataClassJSONMixin if typing.TYPE_CHECKING: import pandas as pd @@ -44,7 +45,7 @@ @dataclass -class StructuredDataset(DataClassJsonMixin): +class StructuredDataset(DataClassJSONMixin): """ This is the user facing StructuredDataset class. Please don't confuse it with the literals.StructuredDataset class (that is just a model, a Python class representation of the protobuf). From c637a5ea059c175139484d5fad1f5ae37f4d859f Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Mon, 31 Jul 2023 10:03:03 +0000 Subject: [PATCH 15/38] add test && fix lint Signed-off-by: hhcs9527 --- flytekit/core/type_engine.py | 38 +++-- flytekit/types/file/file.py | 5 +- flytekit/types/schema/types.py | 4 +- .../types/structured/structured_dataset.py | 4 +- tests/flytekit/unit/core/test_type_engine.py | 154 +++++++++++------- 5 files changed, 127 insertions(+), 78 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 25ef683025..d13e4d7d3c 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1578,12 +1578,13 @@ def to_literal( def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: return expected_python_type(lv.scalar.primitive.string_value) # type: ignore + def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any): attribute_list = [] for property_key, property_val in schema["properties"].items(): if property_val.get("anyOf"): property_type = property_val["anyOf"][0]["type"] - elif property_val.get("enum") : + elif property_val.get("enum"): property_type = "enum" else: property_type = property_val["type"] @@ -1593,13 +1594,17 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: # Handle dataclass and dict elif property_type == "object": if property_val.get("anyOf"): - attribute_list.append((property_key, convert_json_schema_to_python_class(property_val["anyOf"][0], schema_name, True))) + attribute_list.append( + (property_key, convert_json_schema_to_python_class(property_val["anyOf"][0], schema_name, True)) + ) elif property_val.get("additionalProperties"): attribute_list.append( (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore ) else: - attribute_list.append((property_key, convert_json_schema_to_python_class(property_val, schema_name, True))) + attribute_list.append( + (property_key, convert_json_schema_to_python_class(property_val, schema_name, True)) + ) elif property_type == "enum": attribute_list.append([property_key, str]) # type: ignore # Handle int, float, bool or str @@ -1607,12 +1612,7 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore return attribute_list -def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: # type: ignore - """ - Generate a model class based on the provided JSON Schema - :param schema: dict 
representing valid JSON schema - :param schema_name: dataclass name of return type - """ +def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typing.Any): attribute_list = [] for property_key, property_val in schema[schema_name]["properties"].items(): property_type = property_val["type"] @@ -1633,12 +1633,28 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac # Handle int, float, bool or str else: attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore + return attribute_list + +def convert_json_schema_to_python_class(schema: dict, schema_name: typing.Any, is_dataclass_json_mixin:bool=False) -> Type[dataclasses.dataclass()]: # type: ignore + """ + Generate a model class based on the provided JSON Schema + :param schema: dict representing valid JSON schema + :param schema_name: dataclass name of return type + """ + if is_dataclass_json_mixin: + attribute_list = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name) + else: + attribute_list = generate_attribute_list_from_dataclass_json(schema, schema_name) return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) -def _get_element_type(element_property: typing.Dict[str, str]) -> Type: - element_type = [e_property["type"] for e_property in element_property["anyOf"]] if element_property.get("anyOf") else element_property["type"] +def _get_element_type(element_property: typing.Dict[str, typing.Any]) -> Type: + element_type = ( + [e_property["type"] for e_property in element_property["anyOf"]] + if element_property.get("anyOf") + else element_property["type"] + ) element_format = element_property["format"] if "format" in element_property else None if type(element_type) == list: diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index d949379705..34e4f39943 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -6,8 +6,9 @@ from contextlib import contextmanager from dataclasses import dataclass, field -from dataclasses_json import DataClassJsonMixin, config +from dataclasses_json import config from marshmallow import fields +from mashumaro.mixins.json import DataClassJSONMixin from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_underlying_type @@ -26,7 +27,7 @@ def noop(): @dataclass -class FlyteFile(DataClassJsonMixin, os.PathLike, typing.Generic[T]): +class FlyteFile(os.PathLike, typing.Generic[T], DataClassJSONMixin): path: typing.Union[str, os.PathLike] = field( default=None, metadata=config(mm_field=fields.String()) ) # type: ignore diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index e039832fef..bba099a57e 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -11,15 +11,15 @@ import numpy as _np import pandas -from dataclasses_json import DataClassJsonMixin, config +from dataclasses_json import config from marshmallow import fields +from mashumaro.mixins.json import DataClassJSONMixin from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError from flytekit.loggers import logger from flytekit.models.literals import Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType -from mashumaro.mixins.json import DataClassJSONMixin T = typing.TypeVar("T") diff --git 
a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index cc61b04309..99a0e0832b 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -8,9 +8,10 @@ from typing import Dict, Generator, Optional, Type, Union import _datetime -from dataclasses_json import DataClassJsonMixin, config +from dataclasses_json import config from fsspec.utils import get_protocol from marshmallow import fields +from mashumaro.mixins.json import DataClassJSONMixin from typing_extensions import Annotated, TypeAlias, get_args, get_origin from flytekit import lazy_module @@ -22,7 +23,6 @@ from flytekit.models import types as type_models from flytekit.models.literals import Literal, Scalar, StructuredDatasetMetadata from flytekit.models.types import LiteralType, SchemaType, StructuredDatasetType -from mashumaro.mixins.json import DataClassJSONMixin if typing.TYPE_CHECKING: import pandas as pd diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 0d09ce12ce..8c7d2f3545 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -20,7 +20,6 @@ from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from mashumaro.mixins.json import DataClassJSONMixin -import mashumaro from pandas._testing import assert_frame_equal from typing_extensions import Annotated, get_args, get_origin @@ -493,7 +492,6 @@ class Foo(DataClassJSONMixin): x: int y: str - # schema = JSONSchema().dump(typing.cast(DataClassJSONMixin, Foo).schema()) from mashumaro.jsonschema import build_json_schema schema = build_json_schema(typing.cast(DataClassJSONMixin, Foo)).to_dict() @@ -729,9 +727,15 @@ class TestStructD_transformer(DataClassJSONMixin): m: typing.Dict[str, typing.List[int]] -@dataclass # to ask => not support => failed right away +@dataclass class UnsupportedSchemaType_transformer: - _a:str="Hello" + _a: str = "Hello" + + +@dataclass +class UnsupportedNestedStruct_transformer(DataClassJSONMixin): + a: int + s: UnsupportedSchemaType_transformer def test_dataclass_transformer_with_dataclassjsonmixin(): @@ -743,48 +747,17 @@ def test_dataclass_transformer_with_dataclassjsonmixin(): "type": "object", "title": "InnerStruct_transformer", "properties": { - "a": { - "type": "integer" - }, - "b": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ] - }, - "c": { - "type": "array", - "items": { - "type": "integer" - } - } + "a": {"type": "integer"}, + "b": {"anyOf": [{"type": "string"}, {"type": "null"}]}, + "c": {"type": "array", "items": {"type": "integer"}}, }, "additionalProperties": False, - "required": [ - "a", - "b", - "c" - ] + "required": ["a", "b", "c"], }, - "m": { - "type": "object", - "additionalProperties": { - "type": "string" - }, - "propertyNames": { - "type": "string" - } - } + "m": {"type": "object", "additionalProperties": {"type": "string"}, "propertyNames": {"type": "string"}}, }, "additionalProperties": False, - "required": [ - "s", - "m" - ] + "required": ["s", "m"], } tf = DataclassTransformer() @@ -802,17 +775,13 @@ def test_dataclass_transformer_with_dataclassjsonmixin(): assert t.metadata is not None assert t.metadata == schema + t = tf.get_literal_type(UnsupportedNestedStruct) + assert t is not None + assert t.simple is not None + assert t.simple == SimpleType.STRUCT + assert t.metadata is None -@pytest.mark.xfail(raises=mashumaro.exceptions.UnserializableField) 
-def test_unsupported_schema_type(): - # The code that is expected to raise the exception during class definition - @dataclass - class UnsupportedNestedStruct_transformer(DataClassJSONMixin): - a: int - s: UnsupportedSchemaType_transformer - tf = DataclassTransformer() - t = tf.get_literal_type(UnsupportedNestedStruct_transformer) def test_dataclass_int_preserving(): ctx = FlyteContext.current_context() @@ -973,20 +942,20 @@ def test_optional_flytefile_in_dataclassjsonmixin(mock_upload_dir): lt = tf.get_literal_type(TestFileStruct_optional_flytefile) lv = tf.to_literal(ctx, o, TestFileStruct_optional_flytefile, lt) - assert lv.scalar.generic["a"] == remote_path - assert lv.scalar.generic["b"] == remote_path + assert lv.scalar.generic["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["b"].fields["path"].string_value == remote_path assert lv.scalar.generic["b_prime"] is None - assert lv.scalar.generic["c"] == remote_path - assert lv.scalar.generic["d"].values[0].string_value == remote_path - assert lv.scalar.generic["e"].values[0].string_value == remote_path + assert lv.scalar.generic["c"].fields["path"].string_value == remote_path + assert lv.scalar.generic["d"].values[0].struct_value.fields["path"].string_value == remote_path + assert lv.scalar.generic["e"].values[0].struct_value.fields["path"].string_value == remote_path assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value" - assert lv.scalar.generic["f"]["a"] == remote_path - assert lv.scalar.generic["g"]["a"] == remote_path + assert lv.scalar.generic["f"]["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["g"]["a"].fields["path"].string_value == remote_path assert lv.scalar.generic["g_prime"]["a"] is None - assert lv.scalar.generic["h"] == remote_path + assert lv.scalar.generic["h"].fields["path"].string_value == remote_path assert lv.scalar.generic["h_prime"] is None - assert lv.scalar.generic["i"]["a"] == 42 - assert lv.scalar.generic["i_prime"]["a"] == 99 + assert lv.scalar.generic["i"].fields["a"].number_value == 42 + assert lv.scalar.generic["i_prime"].fields["a"].number_value == 99 ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_optional_flytefile) @@ -1226,6 +1195,41 @@ class DatasetStruct(DataClassJsonMixin): assert "parquet" == ot.b.c["hello"].file_format +@dataclass +class InnerDatasetStruct_dataclassjsonmixin(DataClassJSONMixin): + a: StructuredDataset + b: typing.List[Annotated[StructuredDataset, "parquet"]] + c: typing.Dict[str, Annotated[StructuredDataset, kwtypes(Name=str, Age=int)]] + + +def test_structured_dataset_in_dataclassjsonmixin(): + df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + People = Annotated[StructuredDataset, "parquet"] + + @dataclass + class DatasetStruct_dataclassjsonmixin(DataClassJSONMixin): + a: People + b: InnerDatasetStruct_dataclassjsonmixin + + sd = StructuredDataset(dataframe=df, file_format="parquet") + o = DatasetStruct_dataclassjsonmixin(a=sd, b=InnerDatasetStruct_dataclassjsonmixin(a=sd, b=[sd], c={"hello": sd})) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(DatasetStruct_dataclassjsonmixin) + lv = tf.to_literal(ctx, o, DatasetStruct_dataclassjsonmixin, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=DatasetStruct_dataclassjsonmixin) + + assert_frame_equal(df, ot.a.open(pd.DataFrame).all()) + assert_frame_equal(df, ot.b.a.open(pd.DataFrame).all()) + assert_frame_equal(df, ot.b.b[0].open(pd.DataFrame).all()) + 
assert_frame_equal(df, ot.b.c["hello"].open(pd.DataFrame).all()) + assert "parquet" == ot.a.file_format + assert "parquet" == ot.b.a.file_format + assert "parquet" == ot.b.b[0].file_format + assert "parquet" == ot.b.c["hello"].file_format + + # Enums should have string values class Color(Enum): RED = "red" @@ -1381,6 +1385,7 @@ class ArgsAssert(DataClassJSONMixin): x: int y: typing.Optional[str] + @dataclass class SchemaArgsAssert(DataClassJSONMixin): x: typing.Optional[ArgsAssert] @@ -1400,7 +1405,8 @@ class Bar(DataClassJSONMixin): pv = Bar(x=3) with pytest.raises( - TypeTransformerFailedError, match="Type of Val '' is not an instance of " + TypeTransformerFailedError, + match="Type of Val '' is not an instance of ", ): DataclassTransformer().assert_type(gt, pv) @@ -2004,6 +2010,32 @@ def test_schema_in_dataclass(): assert o == ot +@dataclass +class InnerResult_dataclassjsonmixin(DataClassJSONMixin): + number: int + schema: TestSchema # type: ignore + + +@dataclass +class Result_dataclassjsonmixin(DataClassJSONMixin): + result: InnerResult_dataclassjsonmixin + schema: TestSchema # type: ignore + + +def test_schema_in_dataclassjsonmixin(): + schema = TestSchema() + df = pd.DataFrame(data={"some_str": ["a", "b", "c"]}) + schema.open().write(df) + o = Result(result=InnerResult(number=1, schema=schema), schema=schema) + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(Result) + lv = tf.to_literal(ctx, o, Result, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=Result) + + assert o == ot + + def test_guess_of_dataclass(): @dataclass class Foo(DataClassJsonMixin): From c242bbdfa89ab2076a8197a276f1202e14266c86 Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Sun, 6 Aug 2023 14:22:03 +0000 Subject: [PATCH 16/38] support mahumaro with latest version Signed-off-by: hhcs9527 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6496777d1d..c260d8d647 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ # TODO: remove upper-bound after fixing change in contract "dataclasses-json>=0.5.2,<0.5.12", "marshmallow-jsonschema>=0.12.0", - "mashumaro>=3.8.1", + "mashumaro>=3.9", "marshmallow-enum", "natsort>=7.0.1", "docker-image-py>=0.1.10", From e70ae1c8bb4cce8177ce7a6cb502bc6b7891d26b Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Tue, 8 Aug 2023 00:07:22 +0000 Subject: [PATCH 17/38] add files generated by the make requirement Signed-off-by: hhcs9527 --- doc-requirements.txt | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/doc-requirements.txt b/doc-requirements.txt index 485ef715ca..7e07936e0e 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -12,7 +12,7 @@ absl-py==1.4.0 # tensorflow adlfs==2023.4.0 # via flytekit -aiobotocore==2.5.2 +aiobotocore==2.5.3 # via s3fs aiohttp==3.8.5 # via @@ -98,7 +98,7 @@ bleach==6.0.0 # via nbconvert blinker==1.6.2 # via flask -botocore==1.29.161 +botocore==1.31.17 # via # -r doc-requirements.in # aiobotocore @@ -591,6 +591,8 @@ marshmallow-enum==1.5.1 # flytekit marshmallow-jsonschema==0.13.0 # via flytekit +mashumaro==3.9 + # via flytekit matplotlib==3.7.2 # via # ipympl @@ -1284,6 +1286,7 @@ typing-extensions==4.5.0 # flytekit # great-expectations # ipython + # mashumaro # pydantic # python-utils # sqlalchemy From 7a898f07a6f341f667875abc699e55bd0a0e363e Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Tue, 15 Aug 2023 16:26:56 +0000 Subject: [PATCH 18/38] fix test issues in type engine Signed-off-by: hhcs9527 --- 
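# A hedged sketch (assuming a flytekit build that carries the change below) of the
# round trip the new "title" branch enables: the dataclass is rebuilt from the
# mashumaro schema kept in the literal type's metadata. "Point" is illustrative.
from dataclasses import dataclass

from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.core.type_engine import TypeEngine


@dataclass
class Point(DataClassJSONMixin):
    x: int
    y: str


lt = TypeEngine.to_literal_type(Point)      # SimpleType.STRUCT with a JSON schema as metadata
guessed = TypeEngine.guess_python_type(lt)  # dataclass reconstructed from metadata["title"]
print(guessed.__name__)                     # expected to print "Point"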
flytekit/core/type_engine.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d13e4d7d3c..04af4ecefe 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -54,6 +54,7 @@ T = typing.TypeVar("T") DEFINITIONS = "definitions" +TITLE = "title" class BatchSize: @@ -669,8 +670,8 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]: # type: igno if DEFINITIONS in literal_type.metadata: schema_name = literal_type.metadata["$ref"].split("/")[-1] return convert_json_schema_to_python_class(literal_type.metadata[DEFINITIONS], schema_name) - else: - schema_name = literal_type.metadata["title"] + elif TITLE in literal_type.metadata: + schema_name = literal_type.metadata[TITLE] return convert_json_schema_to_python_class(literal_type.metadata, schema_name, True) raise ValueError(f"Dataclass transformer cannot reverse {literal_type}") From adcaedf7cc665453fd8256d9d98efaa6405c49ea Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Wed, 16 Aug 2023 04:25:42 +0000 Subject: [PATCH 19/38] Update the code with advise Signed-off-by: hhcs9527 --- flytekit/core/type_engine.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 04af4ecefe..8bdcb46dc5 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -384,20 +384,13 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for " f"user defined datatypes in Flytekit" ) - if not issubclass(type(python_val), DataClassJsonMixin) and not issubclass( - type(python_val), DataClassJSONMixin - ): + if not isinstance(python_val, (DataClassJsonMixin, DataClassJSONMixin)): raise TypeTransformerFailedError( f"Dataclass {python_type} should be decorated with @dataclass_json or subclass of DataClassJSONMixin to be " f"serialized correctly" ) self._serialize_flyte_type(python_val, python_type) - - if issubclass(type(python_val), DataClassJsonMixin): - json_str = cast(DataClassJsonMixin, python_val).to_json() # type: ignore - else: - json_str = cast(DataClassJSONMixin, python_val).to_json() # type: ignore - + json_str = python_val.to_json() # type: ignore return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: @@ -650,11 +643,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: "DataClassJsonMixin to be serialized correctly" ) json_str = _json_format.MessageToJson(lv.scalar.generic) - - if issubclass(expected_python_type, DataClassJsonMixin): - dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str) # type: ignore - else: - dc = cast(DataClassJSONMixin, expected_python_type).from_json(json_str) # type: ignore + dc = expected_python_type.from_json(json_str) # type: ignore dc = self._fix_structured_dataset_type(expected_python_type, dc) return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type)) From 7be6ced74bad47bbe595b63bcb5c13198e63b9a1 Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Fri, 18 Aug 2023 01:22:16 +0000 Subject: [PATCH 20/38] support new version of mashumaro Signed-off-by: hhcs9527 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c260d8d647..06765feae7 
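# The simplified to_literal() path above reduces to: dump the dataclass with the
# mixin's to_json(), then parse the JSON string into a protobuf Struct. A
# standalone equivalent (the "Config" class is hypothetical):
from dataclasses import dataclass

from google.protobuf import json_format, struct_pb2
from mashumaro.mixins.json import DataClassJSONMixin


@dataclass
class Config(DataClassJSONMixin):
    name: str
    retries: int


cfg = Config(name="demo", retries=3)
generic = json_format.Parse(cfg.to_json(), struct_pb2.Struct())
print(generic["name"], generic["retries"])  # demo 3.0  (Struct stores numbers as doubles)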
100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ # TODO: remove upper-bound after fixing change in contract "dataclasses-json>=0.5.2,<0.5.12", "marshmallow-jsonschema>=0.12.0", - "mashumaro>=3.9", + "mashumaro>=3.9.1", "marshmallow-enum", "natsort>=7.0.1", "docker-image-py>=0.1.10", From 01c505b841113846953e202d080991c726a631ff Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 23 Aug 2023 10:25:25 +0800 Subject: [PATCH 21/38] remove requirement Signed-off-by: HH --- doc-requirements.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/doc-requirements.txt b/doc-requirements.txt index 7e07936e0e..b41608534f 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -591,8 +591,6 @@ marshmallow-enum==1.5.1 # flytekit marshmallow-jsonschema==0.13.0 # via flytekit -mashumaro==3.9 - # via flytekit matplotlib==3.7.2 # via # ipympl From a84f6c48be9ddd1c272ea3f7495a1fc10c97dbcd Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 23 Aug 2023 13:50:54 +0800 Subject: [PATCH 22/38] fix lint Signed-off-by: HH --- flytekit/core/type_engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 8bdcb46dc5..4c9f3f109b 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -12,7 +12,7 @@ import typing from abc import ABC, abstractmethod from functools import lru_cache -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast +from typing import Dict, List, NamedTuple, Optional, Type, cast from dataclasses_json import DataClassJsonMixin, dataclass_json from google.protobuf import json_format as _json_format @@ -1602,6 +1602,7 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore return attribute_list + def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typing.Any): attribute_list = [] for property_key, property_val in schema[schema_name]["properties"].items(): @@ -1625,7 +1626,8 @@ def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typin attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore return attribute_list -def convert_json_schema_to_python_class(schema: dict, schema_name: typing.Any, is_dataclass_json_mixin:bool=False) -> Type[dataclasses.dataclass()]: # type: ignore + +def convert_json_schema_to_python_class(schema: dict, schema_name: typing.Any, is_dataclass_json_mixin: bool = False) -> Type[dataclasses.dataclass()]: # type: ignore """ Generate a model class based on the provided JSON Schema :param schema: dict representing valid JSON schema From 81b8a3db62a857bd067091edd23c854ab4ed5cd4 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Mon, 21 Aug 2023 14:18:36 -0700 Subject: [PATCH 23/38] Inherit directly from DataClassJsonMixin instead of using @dataclass_json for improved static type checking (#1801) * Inherit directly from DataClassJsonMixin instead of @dataclass_json for improved static type checking As it says in the dataclasses-json README: https://github.com/lidatong/dataclasses-json/blob/89578cb9ebed290e70dba8946bfdb68ff6746755/README.md?plain=1#L111-L129, we can use inheritance for improved static type checking; this one change eliminates something like 467 pyright errors from the flytekit module Signed-off-by: Matthew Hoffman --- flytekit/core/type_engine.py | 58 +++++------------------------------- 1 file changed, 8 insertions(+), 50 deletions(-) diff --git 
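# PATCH 23's rationale, condensed into a runnable sketch (not taken from the
# patch): inheriting DataClassJsonMixin directly exposes to_json()/from_json() to
# static type checkers, while the @dataclass_json decorator adds them only at
# runtime. Both forms behave the same when executed.
from dataclasses import dataclass

from dataclasses_json import DataClassJsonMixin, dataclass_json


@dataclass_json  # decorator form: methods exist at runtime but are invisible to checkers
@dataclass
class Decorated:
    a: int


@dataclass
class Inherited(DataClassJsonMixin):  # mixin form: methods are declared on the base class
    a: int


print(Decorated(a=1).to_json())         # '{"a": 1}'
print(Inherited.from_json('{"a": 2}'))  # Inherited(a=2)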
a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 4c9f3f109b..c8724facf8 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -639,8 +639,8 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: expected_python_type, DataClassJSONMixin ): raise TypeTransformerFailedError( - f"Dataclass {expected_python_type} should be decorated with @dataclass_json or be a subclass of " - "DataClassJsonMixin to be serialized correctly" + f"Dataclass {expected_python_type} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be " + f"serialized correctly" ) json_str = _json_format.MessageToJson(lv.scalar.generic) dc = expected_python_type.from_json(json_str) # type: ignore @@ -1569,41 +1569,12 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: return expected_python_type(lv.scalar.primitive.string_value) # type: ignore -def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any): - attribute_list = [] - for property_key, property_val in schema["properties"].items(): - if property_val.get("anyOf"): - property_type = property_val["anyOf"][0]["type"] - elif property_val.get("enum"): - property_type = "enum" - else: - property_type = property_val["type"] - # Handle list - if property_type == "array": - attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore - # Handle dataclass and dict - elif property_type == "object": - if property_val.get("anyOf"): - attribute_list.append( - (property_key, convert_json_schema_to_python_class(property_val["anyOf"][0], schema_name, True)) - ) - elif property_val.get("additionalProperties"): - attribute_list.append( - (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore - ) - else: - attribute_list.append( - (property_key, convert_json_schema_to_python_class(property_val, schema_name, True)) - ) - elif property_type == "enum": - attribute_list.append([property_key, str]) # type: ignore - # Handle int, float, bool or str - else: - attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore - return attribute_list - - -def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typing.Any): +def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: # type: ignore + """ + Generate a model class based on the provided JSON Schema + :param schema: dict representing valid JSON schema + :param schema_name: dataclass name of return type + """ attribute_list = [] for property_key, property_val in schema[schema_name]["properties"].items(): property_type = property_val["type"] @@ -1624,19 +1595,6 @@ def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typin # Handle int, float, bool or str else: attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore - return attribute_list - - -def convert_json_schema_to_python_class(schema: dict, schema_name: typing.Any, is_dataclass_json_mixin: bool = False) -> Type[dataclasses.dataclass()]: # type: ignore - """ - Generate a model class based on the provided JSON Schema - :param schema: dict representing valid JSON schema - :param schema_name: dataclass name of return type - """ - if is_dataclass_json_mixin: - attribute_list = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name) - else: - attribute_list = 
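# The converter restored above consumes marshmallow-jsonschema output: a
# "definitions" mapping keyed by "<ClassName>Schema". A self-contained look at
# that input shape (the class "Foo" is hypothetical):
from dataclasses import dataclass

from dataclasses_json import DataClassJsonMixin
from marshmallow_jsonschema import JSONSchema


@dataclass
class Foo(DataClassJsonMixin):
    x: int
    y: str


schema = JSONSchema().dump(Foo.schema())
print(list(schema["definitions"]))                       # ['FooSchema']
print(schema["definitions"]["FooSchema"]["properties"])  # {'x': {...}, 'y': {...}}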
generate_attribute_list_from_dataclass_json(schema, schema_name) return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) From b57a82be8cf7b69893695a94f6eea1efb092d866 Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 26 Jul 2023 16:44:42 +0800 Subject: [PATCH 24/38] add import DataClassJSONMixin Signed-off-by: HH --- flytekit/core/type_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index c8724facf8..bc515fecc2 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -15,6 +15,7 @@ from typing import Dict, List, NamedTuple, Optional, Type, cast from dataclasses_json import DataClassJsonMixin, dataclass_json +from mashumaro.mixins.json import DataClassJSONMixin from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct from google.protobuf.json_format import MessageToDict as _MessageToDict From e3ccbf3b2649944a4da331c0bffab7de9ed0a6d2 Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 26 Jul 2023 16:45:58 +0800 Subject: [PATCH 25/38] support DataClassJSONMixin to json in DataClassTransformer Signed-off-by: HH --- flytekit/core/type_engine.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index bc515fecc2..60ab29d115 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -349,8 +349,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: if not issubclass(t, DataClassJsonMixin) and not issubclass(t, DataClassJSONMixin): raise AssertionError( - f"Dataclass {t} should be decorated with @dataclass_json or be a subclass of DataClassJsonMixin to be " - "serialized correctly" + f"Dataclass {t} should be decorated with @dataclass_json to be " f"serialized correctly" ) schema = None try: @@ -364,11 +363,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: # check if DataClass mixin from marshmallow_jsonschema import JSONSchema - schema = JSONSchema().dump(s) - else: # DataClassJSONMixin - from mashumaro.jsonschema import build_json_schema - - schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict() + schema = JSONSchema().dump(s) except Exception as e: # https://github.com/lovasoa/marshmallow_dataclass/issues/13 logger.warning( @@ -385,14 +380,15 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for " f"user defined datatypes in Flytekit" ) - if not isinstance(python_val, (DataClassJsonMixin, DataClassJSONMixin)): + if not issubclass(type(python_val), DataClassJsonMixin): raise TypeTransformerFailedError( f"Dataclass {python_type} should be decorated with @dataclass_json or subclass of DataClassJSONMixin to be " f"serialized correctly" ) self._serialize_flyte_type(python_val, python_type) - json_str = python_val.to_json() # type: ignore - return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore + return Literal( + scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct())) + ) def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: # dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is From 457d0d8eedbc245bb9c2a6ef1b3e2c08b010658d Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Wed, 26 Jul 2023 
17:01:52 +0000 Subject: [PATCH 26/38] fix lint issue files Signed-off-by: hhcs9527 --- flytekit/core/type_engine.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 60ab29d115..daec33571f 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -15,7 +15,6 @@ from typing import Dict, List, NamedTuple, Optional, Type, cast from dataclasses_json import DataClassJsonMixin, dataclass_json -from mashumaro.mixins.json import DataClassJSONMixin from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct from google.protobuf.json_format import MessageToDict as _MessageToDict @@ -349,7 +348,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: if not issubclass(t, DataClassJsonMixin) and not issubclass(t, DataClassJSONMixin): raise AssertionError( - f"Dataclass {t} should be decorated with @dataclass_json to be " f"serialized correctly" + f"Dataclass {t} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be " f"serialized correctly" ) schema = None try: @@ -363,7 +362,10 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: # check if DataClass mixin from marshmallow_jsonschema import JSONSchema - schema = JSONSchema().dump(s) + schema = JSONSchema().dump(s) + else: # DataClassJSONMixin + from mashumaro.jsonschema import build_json_schema + schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict() except Exception as e: # https://github.com/lovasoa/marshmallow_dataclass/issues/13 logger.warning( @@ -380,14 +382,20 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for " f"user defined datatypes in Flytekit" ) - if not issubclass(type(python_val), DataClassJsonMixin): + if not issubclass(type(python_val), DataClassJsonMixin) and not issubclass(type(python_val), DataClassJSONMixin): raise TypeTransformerFailedError( f"Dataclass {python_type} should be decorated with @dataclass_json or subclass of DataClassJSONMixin to be " f"serialized correctly" ) self._serialize_flyte_type(python_val, python_type) + + if issubclass(type(python_val), DataClassJsonMixin): + json_str = cast(DataClassJsonMixin, python_val).to_json() # type: ignore + else: + json_str = cast(DataClassJSONMixin, python_val).to_json() # type: ignore + return Literal( - scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct())) + scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())) # type: ignore ) def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: From 11f7616f33914988e227a7cdc125b7b0bcb756eb Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Fri, 28 Jul 2023 07:08:01 +0000 Subject: [PATCH 27/38] add new test for testing mashumaro dataclass Signed-off-by: hhcs9527 --- flytekit/core/type_engine.py | 40 ++- tests/flytekit/unit/core/test_type_engine.py | 251 +++++++++---------- 2 files changed, 150 insertions(+), 141 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index daec33571f..503ef34b4e 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1573,14 +1573,42 @@ def to_literal( def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: return expected_python_type(lv.scalar.primitive.string_value) # 
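# A hedged sketch of the failure path guarded above: a plain dataclass that uses
# neither @dataclass_json nor DataClassJSONMixin is rejected outright (assuming a
# flytekit build with this change; the class "Plain" is hypothetical).
from dataclasses import dataclass

from flytekit.core.type_engine import DataclassTransformer


@dataclass
class Plain:
    a: int = 0


try:
    DataclassTransformer().get_literal_type(Plain)
except AssertionError as e:
    print(e)  # "Dataclass ... should be decorated with @dataclass_json or mixin with DataClassJSONMixin ..."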
type: ignore +def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any): + attribute_list = [] + for property_key, property_val in schema["properties"].items(): + if property_val.get("anyOf"): + property_type = property_val["anyOf"][0]["type"] + elif property_val.get("enum") : + property_type = "enum" + else: + property_type = property_val["type"] + # Handle list + if property_type == "array": + attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore + # Handle dataclass and dict + elif property_type == "object": + if property_val.get("anyOf"): + attribute_list.append((property_key, convert_json_schema_to_python_class(property_val["anyOf"][0], schema_name, True))) + elif property_val.get("additionalProperties"): + attribute_list.append( + (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore + ) + else: + attribute_list.append((property_key, convert_json_schema_to_python_class(property_val, schema_name, True))) + elif property_type == "enum": + attribute_list.append([property_key, str]) # type: ignore + # Handle int, float, bool or str + else: + attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore + return attribute_list -def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: # type: ignore +def convert_json_schema_to_python_class(schema: Dict[str, Any], schema_name: str) -> Type[Any]: """ Generate a model class based on the provided JSON Schema :param schema: dict representing valid JSON schema :param schema_name: dataclass name of return type """ - attribute_list = [] + attribute_list: List[Tuple[str, type]] = [] for property_key, property_val in schema[schema_name]["properties"].items(): property_type = property_val["type"] # Handle list @@ -1604,12 +1632,8 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) -def _get_element_type(element_property: typing.Dict[str, typing.Any]) -> Type: - element_type = ( - [e_property["type"] for e_property in element_property["anyOf"]] - if element_property.get("anyOf") - else element_property["type"] - ) +def _get_element_type(element_property: typing.Dict[str, str]) -> Type: + element_type = element_property["type"] element_format = element_property["format"] if "format" in element_property else None if type(element_type) == list: diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 8c7d2f3545..d3049118cf 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -20,6 +20,7 @@ from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from mashumaro.mixins.json import DataClassJSONMixin +import mashumaro from pandas._testing import assert_frame_equal from typing_extensions import Annotated, get_args, get_origin @@ -486,24 +487,6 @@ class Foo(DataClassJsonMixin): _ = foo.c -def test_convert_json_schema_to_python_class_with_dataclassjsonmixin(): - @dataclass - class Foo(DataClassJSONMixin): - x: int - y: str - - from mashumaro.jsonschema import build_json_schema - - schema = build_json_schema(typing.cast(DataClassJSONMixin, Foo)).to_dict() - foo_class = convert_json_schema_to_python_class(schema, "FooSchema", is_dataclass_json_mixin=True) - foo = foo_class(x=1, y="hello") - foo.x = 2 - 
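# Why the helper above special-cases "anyOf" and "additionalProperties": mashumaro
# renders Optional[...] fields as anyOf unions and Dict[str, ...] fields as objects
# with additionalProperties. A small sketch (the class "Example" is hypothetical):
import typing
from dataclasses import dataclass

from mashumaro.jsonschema import build_json_schema
from mashumaro.mixins.json import DataClassJSONMixin


@dataclass
class Example(DataClassJSONMixin):
    b: typing.Optional[str]
    m: typing.Dict[str, str]


props = build_json_schema(Example).to_dict()["properties"]
print(props["b"])  # {'anyOf': [{'type': 'string'}, {'type': 'null'}]}
print(props["m"])  # {'type': 'object', 'additionalProperties': {'type': 'string'}, ...}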
assert foo.x == 2 - assert foo.y == "hello" - with pytest.raises(AttributeError): - _ = foo.c - - def test_list_transformer(): l0 = Literal(scalar=Scalar(primitive=Primitive(integer=3))) l1 = Literal(scalar=Scalar(primitive=Primitive(integer=4))) @@ -782,6 +765,123 @@ def test_dataclass_transformer_with_dataclassjsonmixin(): assert t.metadata is None +@dataclass +class InnerStruct_transformer(DataClassJSONMixin): + a: int + b: typing.Optional[str] + c: typing.List[int] + + +@dataclass +class TestStruct_transformer(DataClassJSONMixin): + s: InnerStruct_transformer + m: typing.Dict[str, str] + + +@dataclass +class TestStructB_transformer(DataClassJSONMixin): + s: InnerStruct_transformer + m: typing.Dict[int, str] + n: typing.Optional[typing.List[typing.List[int]]] = None + o: typing.Optional[typing.Dict[int, typing.Dict[int, int]]] = None + + +@dataclass +class TestStructC_transformer(DataClassJSONMixin): + s: InnerStruct_transformer + m: typing.Dict[str, int] + + +@dataclass +class TestStructD_transformer(DataClassJSONMixin): + s: InnerStruct_transformer + m: typing.Dict[str, typing.List[int]] + + +@dataclass # to ask => not support => failed right away +class UnsupportedSchemaType_transformer: + _a:str="Hello" + + +def test_dataclass_transformer_with_dataclassjsonmixin(): + schema = { + "type": "object", + "title": "TestStruct_transformer", + "properties": { + "s": { + "type": "object", + "title": "InnerStruct_transformer", + "properties": { + "a": { + "type": "integer" + }, + "b": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ] + }, + "c": { + "type": "array", + "items": { + "type": "integer" + } + } + }, + "additionalProperties": False, + "required": [ + "a", + "b", + "c" + ] + }, + "m": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "propertyNames": { + "type": "string" + } + } + }, + "additionalProperties": False, + "required": [ + "s", + "m" + ] + } + + tf = DataclassTransformer() + t = tf.get_literal_type(TestStruct_transformer) + assert t is not None + assert t.simple is not None + assert t.simple == SimpleType.STRUCT + assert t.metadata is not None + assert t.metadata == schema + + t = TypeEngine.to_literal_type(TestStruct_transformer) + assert t is not None + assert t.simple is not None + assert t.simple == SimpleType.STRUCT + assert t.metadata is not None + assert t.metadata == schema + + +@pytest.mark.xfail(raises=mashumaro.exceptions.UnserializableField) +def test_unsupported_schema_type(): + # The code that is expected to raise the exception during class definition + @dataclass + class UnsupportedNestedStruct_transformer(DataClassJSONMixin): + a: int + s: UnsupportedSchemaType_transformer + + tf = DataclassTransformer() + t = tf.get_literal_type(UnsupportedNestedStruct_transformer) def test_dataclass_int_preserving(): ctx = FlyteContext.current_context() @@ -891,90 +991,6 @@ class TestFileStruct(DataClassJsonMixin): assert o.i_prime == A(a=99) -@dataclass -class A_optional_flytefile(DataClassJSONMixin): - a: int - - -@dataclass -class TestFileStruct_optional_flytefile(DataClassJSONMixin): - a: FlyteFile - b: typing.Optional[FlyteFile] - b_prime: typing.Optional[FlyteFile] - c: typing.Union[FlyteFile, None] - d: typing.List[FlyteFile] - e: typing.List[typing.Optional[FlyteFile]] - e_prime: typing.List[typing.Optional[FlyteFile]] - f: typing.Dict[str, FlyteFile] - g: typing.Dict[str, typing.Optional[FlyteFile]] - g_prime: typing.Dict[str, typing.Optional[FlyteFile]] - h: typing.Optional[FlyteFile] = None - h_prime: 
typing.Optional[FlyteFile] = None - i: typing.Optional[A_optional_flytefile] = None - i_prime: typing.Optional[A_optional_flytefile] = field(default_factory=lambda: A_optional_flytefile(a=99)) - - -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") -def test_optional_flytefile_in_dataclassjsonmixin(mock_upload_dir): - mock_upload_dir.return_value = True - - remote_path = "s3://tmp/file" - with tempfile.TemporaryFile() as f: - f.write(b"abc") - f1 = FlyteFile("f1", remote_path=remote_path) - o = TestFileStruct_optional_flytefile( - a=f1, - b=f1, - b_prime=None, - c=f1, - d=[f1], - e=[f1], - e_prime=[None], - f={"a": f1}, - g={"a": f1}, - g_prime={"a": None}, - h=f1, - i=A_optional_flytefile(a=42), - ) - - ctx = FlyteContext.current_context() - tf = DataclassTransformer() - lt = tf.get_literal_type(TestFileStruct_optional_flytefile) - lv = tf.to_literal(ctx, o, TestFileStruct_optional_flytefile, lt) - - assert lv.scalar.generic["a"].fields["path"].string_value == remote_path - assert lv.scalar.generic["b"].fields["path"].string_value == remote_path - assert lv.scalar.generic["b_prime"] is None - assert lv.scalar.generic["c"].fields["path"].string_value == remote_path - assert lv.scalar.generic["d"].values[0].struct_value.fields["path"].string_value == remote_path - assert lv.scalar.generic["e"].values[0].struct_value.fields["path"].string_value == remote_path - assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value" - assert lv.scalar.generic["f"]["a"].fields["path"].string_value == remote_path - assert lv.scalar.generic["g"]["a"].fields["path"].string_value == remote_path - assert lv.scalar.generic["g_prime"]["a"] is None - assert lv.scalar.generic["h"].fields["path"].string_value == remote_path - assert lv.scalar.generic["h_prime"] is None - assert lv.scalar.generic["i"].fields["a"].number_value == 42 - assert lv.scalar.generic["i_prime"].fields["a"].number_value == 99 - - ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_optional_flytefile) - - assert o.a.path == ot.a.remote_source - assert o.b.path == ot.b.remote_source - assert ot.b_prime is None - assert o.c.path == ot.c.remote_source - assert o.d[0].path == ot.d[0].remote_source - assert o.e[0].path == ot.e[0].remote_source - assert o.e_prime == [None] - assert o.f["a"].path == ot.f["a"].remote_source - assert o.g["a"].path == ot.g["a"].remote_source - assert o.g_prime == {"a": None} - assert o.h.path == ot.h.remote_source - assert ot.h_prime is None - assert o.i == ot.i - assert o.i_prime == A_optional_flytefile(a=99) - - def test_flyte_file_in_dataclass(): @dataclass class TestInnerFileStruct(DataClassJsonMixin): @@ -1380,37 +1396,6 @@ class Bar(DataClassJsonMixin): DataclassTransformer().assert_type(gt, pv) -@dataclass -class ArgsAssert(DataClassJSONMixin): - x: int - y: typing.Optional[str] - - -@dataclass -class SchemaArgsAssert(DataClassJSONMixin): - x: typing.Optional[ArgsAssert] - - -def test_assert_dataclassjsonmixin_type(): - pt = SchemaArgsAssert - lt = TypeEngine.to_literal_type(pt) - gt = TypeEngine.guess_python_type(lt) - pv = SchemaArgsAssert(x=ArgsAssert(x=3, y="hello")) - DataclassTransformer().assert_type(gt, pv) - DataclassTransformer().assert_type(SchemaArgsAssert, pv) - - @dataclass - class Bar(DataClassJSONMixin): - x: int - - pv = Bar(x=3) - with pytest.raises( - TypeTransformerFailedError, - match="Type of Val '' is not an instance of ", - ): - DataclassTransformer().assert_type(gt, pv) - - def test_union_transformer(): assert 
UnionTransformer.is_optional_type(typing.Optional[int]) assert not UnionTransformer.is_optional_type(str) From b9cb5e9b8b367cb9ffaef1384ba9ff7da882aba8 Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Sun, 30 Jul 2023 16:36:15 +0000 Subject: [PATCH 28/38] support structure dataclass and flytescheme with DataClassJSONMixin Signed-off-by: hhcs9527 --- flytekit/types/schema/types.py | 1 + flytekit/types/structured/structured_dataset.py | 1 + 2 files changed, 2 insertions(+) diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index bba099a57e..53ff012b1e 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -20,6 +20,7 @@ from flytekit.loggers import logger from flytekit.models.literals import Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType +from mashumaro.mixins.json import DataClassJSONMixin T = typing.TypeVar("T") diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 99a0e0832b..12331a4a2a 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -23,6 +23,7 @@ from flytekit.models import types as type_models from flytekit.models.literals import Literal, Scalar, StructuredDatasetMetadata from flytekit.models.types import LiteralType, SchemaType, StructuredDatasetType +from mashumaro.mixins.json import DataClassJSONMixin if typing.TYPE_CHECKING: import pandas as pd From e989f2c1c510959fffb4829ead750b6562f66bf4 Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Mon, 31 Jul 2023 10:03:03 +0000 Subject: [PATCH 29/38] add test && fix lint Signed-off-by: hhcs9527 --- flytekit/core/type_engine.py | 35 ++- flytekit/types/schema/types.py | 1 - .../types/structured/structured_dataset.py | 1 - tests/flytekit/unit/core/test_type_engine.py | 200 +++++++++++++----- 4 files changed, 175 insertions(+), 62 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 503ef34b4e..e706497d10 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1573,12 +1573,13 @@ def to_literal( def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: return expected_python_type(lv.scalar.primitive.string_value) # type: ignore + def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any): attribute_list = [] for property_key, property_val in schema["properties"].items(): if property_val.get("anyOf"): property_type = property_val["anyOf"][0]["type"] - elif property_val.get("enum") : + elif property_val.get("enum"): property_type = "enum" else: property_type = property_val["type"] @@ -1588,13 +1589,17 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: # Handle dataclass and dict elif property_type == "object": if property_val.get("anyOf"): - attribute_list.append((property_key, convert_json_schema_to_python_class(property_val["anyOf"][0], schema_name, True))) + attribute_list.append( + (property_key, convert_json_schema_to_python_class(property_val["anyOf"][0], schema_name, True)) + ) elif property_val.get("additionalProperties"): attribute_list.append( (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore ) else: - attribute_list.append((property_key, convert_json_schema_to_python_class(property_val, schema_name, True))) + attribute_list.append( + (property_key, convert_json_schema_to_python_class(property_val, 
schema_name, True)) + ) elif property_type == "enum": attribute_list.append([property_key, str]) # type: ignore # Handle int, float, bool or str @@ -1602,13 +1607,9 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore return attribute_list -def convert_json_schema_to_python_class(schema: Dict[str, Any], schema_name: str) -> Type[Any]: - """ - Generate a model class based on the provided JSON Schema - :param schema: dict representing valid JSON schema - :param schema_name: dataclass name of return type - """ - attribute_list: List[Tuple[str, type]] = [] + +def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typing.Any): + attribute_list = [] for property_key, property_val in schema[schema_name]["properties"].items(): property_type = property_val["type"] # Handle list @@ -1629,11 +1630,23 @@ def convert_json_schema_to_python_class(schema: Dict[str, Any], schema_name: str else: attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore + +def convert_json_schema_to_python_class(schema: dict, schema_name: typing.Any, is_dataclass_json_mixin: bool = False) -> Type[dataclasses.dataclass()]: # type: ignore + """ + Generate a model class based on the provided JSON Schema + :param schema: dict representing valid JSON schema + :param schema_name: dataclass name of return type + """ + if is_dataclass_json_mixin: + attribute_list = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name) + else: + attribute_list = generate_attribute_list_from_dataclass_json(schema, schema_name) + return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) def _get_element_type(element_property: typing.Dict[str, str]) -> Type: - element_type = element_property["type"] + element_type = [e_property["type"] for e_property in element_property["anyOf"]] if element_property.get("anyOf") else element_property["type"] element_format = element_property["format"] if "format" in element_property else None if type(element_type) == list: diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 53ff012b1e..bba099a57e 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -20,7 +20,6 @@ from flytekit.loggers import logger from flytekit.models.literals import Literal, Scalar, Schema from flytekit.models.types import LiteralType, SchemaType -from mashumaro.mixins.json import DataClassJSONMixin T = typing.TypeVar("T") diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 12331a4a2a..99a0e0832b 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -23,7 +23,6 @@ from flytekit.models import types as type_models from flytekit.models.literals import Literal, Scalar, StructuredDatasetMetadata from flytekit.models.types import LiteralType, SchemaType, StructuredDatasetType -from mashumaro.mixins.json import DataClassJSONMixin if typing.TYPE_CHECKING: import pandas as pd diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index d3049118cf..79df9d0b51 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -20,7 +20,6 @@ from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from mashumaro.mixins.json import DataClassJSONMixin 
-import mashumaro from pandas._testing import assert_frame_equal from typing_extensions import Annotated, get_args, get_origin @@ -487,6 +486,25 @@ class Foo(DataClassJsonMixin): _ = foo.c +def test_convert_json_schema_to_python_class_with_dataclassjsonmixin(): + @dataclass + class Foo(DataClassJSONMixin): + x: int + y: str + + # schema = JSONSchema().dump(typing.cast(DataClassJSONMixin, Foo).schema()) + from mashumaro.jsonschema import build_json_schema + + schema = build_json_schema(typing.cast(DataClassJSONMixin, Foo)).to_dict() + foo_class = convert_json_schema_to_python_class(schema, "FooSchema", is_dataclass_json_mixin=True) + foo = foo_class(x=1, y="hello") + foo.x = 2 + assert foo.x == 2 + assert foo.y == "hello" + with pytest.raises(AttributeError): + _ = foo.c + + def test_list_transformer(): l0 = Literal(scalar=Scalar(primitive=Primitive(integer=3))) l1 = Literal(scalar=Scalar(primitive=Primitive(integer=4))) @@ -798,9 +816,15 @@ class TestStructD_transformer(DataClassJSONMixin): m: typing.Dict[str, typing.List[int]] -@dataclass # to ask => not support => failed right away +@dataclass class UnsupportedSchemaType_transformer: - _a:str="Hello" + _a: str = "Hello" + + +@dataclass +class UnsupportedNestedStruct_transformer(DataClassJSONMixin): + a: int + s: UnsupportedSchemaType_transformer def test_dataclass_transformer_with_dataclassjsonmixin(): @@ -812,48 +836,17 @@ def test_dataclass_transformer_with_dataclassjsonmixin(): "type": "object", "title": "InnerStruct_transformer", "properties": { - "a": { - "type": "integer" - }, - "b": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ] - }, - "c": { - "type": "array", - "items": { - "type": "integer" - } - } + "a": {"type": "integer"}, + "b": {"anyOf": [{"type": "string"}, {"type": "null"}]}, + "c": {"type": "array", "items": {"type": "integer"}}, }, "additionalProperties": False, - "required": [ - "a", - "b", - "c" - ] + "required": ["a", "b", "c"], }, - "m": { - "type": "object", - "additionalProperties": { - "type": "string" - }, - "propertyNames": { - "type": "string" - } - } + "m": {"type": "object", "additionalProperties": {"type": "string"}, "propertyNames": {"type": "string"}}, }, "additionalProperties": False, - "required": [ - "s", - "m" - ] + "required": ["s", "m"], } tf = DataclassTransformer() @@ -871,17 +864,13 @@ def test_dataclass_transformer_with_dataclassjsonmixin(): assert t.metadata is not None assert t.metadata == schema + t = tf.get_literal_type(UnsupportedNestedStruct) + assert t is not None + assert t.simple is not None + assert t.simple == SimpleType.STRUCT + assert t.metadata is None -@pytest.mark.xfail(raises=mashumaro.exceptions.UnserializableField) -def test_unsupported_schema_type(): - # The code that is expected to raise the exception during class definition - @dataclass - class UnsupportedNestedStruct_transformer(DataClassJSONMixin): - a: int - s: UnsupportedSchemaType_transformer - tf = DataclassTransformer() - t = tf.get_literal_type(UnsupportedNestedStruct_transformer) def test_dataclass_int_preserving(): ctx = FlyteContext.current_context() @@ -991,6 +980,90 @@ class TestFileStruct(DataClassJsonMixin): assert o.i_prime == A(a=99) +@dataclass +class A_optional_flytefile(DataClassJSONMixin): + a: int + + +@dataclass +class TestFileStruct_optional_flytefile(DataClassJSONMixin): + a: FlyteFile + b: typing.Optional[FlyteFile] + b_prime: typing.Optional[FlyteFile] + c: typing.Union[FlyteFile, None] + d: typing.List[FlyteFile] + e: typing.List[typing.Optional[FlyteFile]] + e_prime: 
typing.List[typing.Optional[FlyteFile]] + f: typing.Dict[str, FlyteFile] + g: typing.Dict[str, typing.Optional[FlyteFile]] + g_prime: typing.Dict[str, typing.Optional[FlyteFile]] + h: typing.Optional[FlyteFile] = None + h_prime: typing.Optional[FlyteFile] = None + i: typing.Optional[A_optional_flytefile] = None + i_prime: typing.Optional[A_optional_flytefile] = field(default_factory=lambda: A_optional_flytefile(a=99)) + + +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +def test_optional_flytefile_in_dataclassjsonmixin(mock_upload_dir): + mock_upload_dir.return_value = True + + remote_path = "s3://tmp/file" + with tempfile.TemporaryFile() as f: + f.write(b"abc") + f1 = FlyteFile("f1", remote_path=remote_path) + o = TestFileStruct_optional_flytefile( + a=f1, + b=f1, + b_prime=None, + c=f1, + d=[f1], + e=[f1], + e_prime=[None], + f={"a": f1}, + g={"a": f1}, + g_prime={"a": None}, + h=f1, + i=A_optional_flytefile(a=42), + ) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct_optional_flytefile) + lv = tf.to_literal(ctx, o, TestFileStruct_optional_flytefile, lt) + + assert lv.scalar.generic["a"] == remote_path + assert lv.scalar.generic["b"] == remote_path + assert lv.scalar.generic["b_prime"] is None + assert lv.scalar.generic["c"] == remote_path + assert lv.scalar.generic["d"].values[0].string_value == remote_path + assert lv.scalar.generic["e"].values[0].string_value == remote_path + assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value" + assert lv.scalar.generic["f"]["a"] == remote_path + assert lv.scalar.generic["g"]["a"] == remote_path + assert lv.scalar.generic["g_prime"]["a"] is None + assert lv.scalar.generic["h"] == remote_path + assert lv.scalar.generic["h_prime"] is None + assert lv.scalar.generic["i"]["a"] == 42 + assert lv.scalar.generic["i_prime"]["a"] == 99 + + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_optional_flytefile) + + assert o.a.path == ot.a.remote_source + assert o.b.path == ot.b.remote_source + assert ot.b_prime is None + assert o.c.path == ot.c.remote_source + assert o.d[0].path == ot.d[0].remote_source + assert o.e[0].path == ot.e[0].remote_source + assert o.e_prime == [None] + assert o.f["a"].path == ot.f["a"].remote_source + assert o.g["a"].path == ot.g["a"].remote_source + assert o.g_prime == {"a": None} + assert o.h.path == ot.h.remote_source + assert ot.h_prime is None + assert o.i == ot.i + assert o.i_prime == A_optional_flytefile(a=99) + + def test_flyte_file_in_dataclass(): @dataclass class TestInnerFileStruct(DataClassJsonMixin): @@ -1396,6 +1469,35 @@ class Bar(DataClassJsonMixin): DataclassTransformer().assert_type(gt, pv) +@dataclass +class ArgsAssert(DataClassJSONMixin): + x: int + y: typing.Optional[str] + +@dataclass +class SchemaArgsAssert(DataClassJSONMixin): + x: typing.Optional[ArgsAssert] + + +def test_assert_dataclassjsonmixin_type(): + pt = SchemaArgsAssert + lt = TypeEngine.to_literal_type(pt) + gt = TypeEngine.guess_python_type(lt) + pv = SchemaArgsAssert(x=ArgsAssert(x=3, y="hello")) + DataclassTransformer().assert_type(gt, pv) + DataclassTransformer().assert_type(SchemaArgsAssert, pv) + + @dataclass + class Bar(DataClassJSONMixin): + x: int + + pv = Bar(x=3) + with pytest.raises( + TypeTransformerFailedError, match="Type of Val '' is not an instance of " + ): + DataclassTransformer().assert_type(gt, pv) + + def test_union_transformer(): assert UnionTransformer.is_optional_type(typing.Optional[int]) 
assert not UnionTransformer.is_optional_type(str) From fcf04010534927a84a1e7cae7d085649814101f8 Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Wed, 16 Aug 2023 04:25:42 +0000 Subject: [PATCH 30/38] Update the code with advise Signed-off-by: hhcs9527 --- flytekit/core/type_engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e706497d10..2d926217f9 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -382,7 +382,9 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for " f"user defined datatypes in Flytekit" ) - if not issubclass(type(python_val), DataClassJsonMixin) and not issubclass(type(python_val), DataClassJSONMixin): + if not issubclass(type(python_val), DataClassJsonMixin) and not issubclass( + type(python_val), DataClassJSONMixin + ): raise TypeTransformerFailedError( f"Dataclass {python_type} should be decorated with @dataclass_json or subclass of DataClassJSONMixin to be " f"serialized correctly" @@ -394,9 +396,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp else: json_str = cast(DataClassJSONMixin, python_val).to_json() # type: ignore - return Literal( - scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())) # type: ignore - ) + return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: # dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is From 84c8c1457e7ba7e5dd2e4a66ae8f1c2937e9aeba Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 23 Aug 2023 09:34:49 +0800 Subject: [PATCH 31/38] fix type engine bugs Signed-off-by: HH --- flytekit/core/type_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 2d926217f9..39de4efcef 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1629,7 +1629,7 @@ def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typin # Handle int, float, bool or str else: attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore - + return attribute_list def convert_json_schema_to_python_class(schema: dict, schema_name: typing.Any, is_dataclass_json_mixin: bool = False) -> Type[dataclasses.dataclass()]: # type: ignore """ From 6d188faa91e712568310d796e53db0f14e2be4cc Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 30 Aug 2023 11:28:33 +0800 Subject: [PATCH 32/38] Split convert_json_schema_to_python_class to convert_mashumaro_json_schema_to_python_class and convert_marshmallow_json_schema_to_python_class Signed-off-by: HH --- flytekit/core/type_engine.py | 42 ++++-- tests/flytekit/unit/core/test_type_engine.py | 142 +++++-------------- 2 files changed, 63 insertions(+), 121 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 39de4efcef..64ecfb7288 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -348,7 +348,8 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: if not issubclass(t, DataClassJsonMixin) and not issubclass(t, DataClassJSONMixin): raise AssertionError( - f"Dataclass {t} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be " f"serialized correctly" 
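# Assuming a flytekit build that carries PATCH 32, the old single converter is
# split into one function per schema dialect. A hedged usage sketch (the classes
# "A" and "B" are hypothetical; the call shapes follow the tests added later in
# this patch):
from dataclasses import dataclass

from dataclasses_json import DataClassJsonMixin
from marshmallow_jsonschema import JSONSchema
from mashumaro.jsonschema import build_json_schema
from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.core.type_engine import (
    convert_marshmallow_json_schema_to_python_class,
    convert_mashumaro_json_schema_to_python_class,
)


@dataclass
class A(DataClassJsonMixin):
    x: int


@dataclass
class B(DataClassJSONMixin):
    x: int


a_cls = convert_marshmallow_json_schema_to_python_class(JSONSchema().dump(A.schema())["definitions"], "ASchema")
b_cls = convert_mashumaro_json_schema_to_python_class(build_json_schema(B).to_dict(), "BSchema")
print(a_cls(x=1), b_cls(x=2))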
+ f"Dataclass {t} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be " + f"serialized correctly" ) schema = None try: @@ -363,8 +364,9 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: from marshmallow_jsonschema import JSONSchema schema = JSONSchema().dump(s) - else: # DataClassJSONMixin + else: # DataClassJSONMixin from mashumaro.jsonschema import build_json_schema + schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict() except Exception as e: # https://github.com/lovasoa/marshmallow_dataclass/issues/13 @@ -386,8 +388,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp type(python_val), DataClassJSONMixin ): raise TypeTransformerFailedError( - f"Dataclass {python_type} should be decorated with @dataclass_json or subclass of DataClassJSONMixin to be " - f"serialized correctly" + f"Dataclass {python_type} should be decorated with @dataclass_json or inherit DataClassJSONMixin to be " f"serialized correctly" ) self._serialize_flyte_type(python_val, python_type) @@ -663,10 +664,12 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]: # type: igno if literal_type.metadata is not None: if DEFINITIONS in literal_type.metadata: schema_name = literal_type.metadata["$ref"].split("/")[-1] - return convert_json_schema_to_python_class(literal_type.metadata[DEFINITIONS], schema_name) + return convert_marshmallow_json_schema_to_python_class( + literal_type.metadata[DEFINITIONS], schema_name + ) elif TITLE in literal_type.metadata: schema_name = literal_type.metadata[TITLE] - return convert_json_schema_to_python_class(literal_type.metadata, schema_name, True) + return convert_mashumaro_json_schema_to_python_class(literal_type.metadata, schema_name) raise ValueError(f"Dataclass transformer cannot reverse {literal_type}") @@ -1590,7 +1593,7 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: elif property_type == "object": if property_val.get("anyOf"): attribute_list.append( - (property_key, convert_json_schema_to_python_class(property_val["anyOf"][0], schema_name, True)) + (property_key, convert_mashumaro_json_schema_to_python_class(property_val["anyOf"][0], schema_name)) ) elif property_val.get("additionalProperties"): attribute_list.append( @@ -1598,7 +1601,7 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: ) else: attribute_list.append( - (property_key, convert_json_schema_to_python_class(property_val, schema_name, True)) + (property_key, convert_mashumaro_json_schema_to_python_class(property_val, schema_name)) ) elif property_type == "enum": attribute_list.append([property_key, str]) # type: ignore @@ -1619,7 +1622,7 @@ def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typin elif property_type == "object": if property_val.get("$ref"): name = property_val["$ref"].split("/")[-1] - attribute_list.append((property_key, convert_json_schema_to_python_class(schema, name))) + attribute_list.append((property_key, convert_marshmallow_json_schema_to_python_class(schema, name))) elif property_val.get("additionalProperties"): attribute_list.append( (property_key, Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore[misc,index] @@ -1631,22 +1634,31 @@ def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typin attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore return attribute_list -def 
convert_json_schema_to_python_class(schema: dict, schema_name: typing.Any, is_dataclass_json_mixin: bool = False) -> Type[dataclasses.dataclass()]: # type: ignore + +def convert_marshmallow_json_schema_to_python_class(schema: dict, schema_name: typing.Any, is_dataclass_json_mixin: bool = False) -> Type[dataclasses.dataclass()]: # type: ignore + """ + Generate a model class based on the provided JSON Schema + :param schema: dict representing valid JSON schema + :param schema_name: dataclass name of return type + """ + + attribute_list = generate_attribute_list_from_dataclass_json(schema, schema_name) + return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) + + +def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typing.Any, is_dataclass_json_mixin: bool = False) -> Type[dataclasses.dataclass()]: # type: ignore """ Generate a model class based on the provided JSON Schema :param schema: dict representing valid JSON schema :param schema_name: dataclass name of return type """ - if is_dataclass_json_mixin: - attribute_list = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name) - else: - attribute_list = generate_attribute_list_from_dataclass_json(schema, schema_name) + attribute_list = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name) return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) def _get_element_type(element_property: typing.Dict[str, str]) -> Type: - element_type = [e_property["type"] for e_property in element_property["anyOf"]] if element_property.get("anyOf") else element_property["type"] + element_type = [e_property["type"] for e_property in element_property["anyOf"]] if element_property.get("anyOf") else element_property["type"] # type: ignore element_format = element_property["format"] if "format" in element_property else None if type(element_type) == list: diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 79df9d0b51..fec17a9158 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -40,7 +40,8 @@ TypeTransformer, TypeTransformerFailedError, UnionTransformer, - convert_json_schema_to_python_class, + convert_marshmallow_json_schema_to_python_class, + convert_mashumaro_json_schema_to_python_class, dataclass_from_dict, get_underlying_type, is_annotated, @@ -268,7 +269,7 @@ class Foo(DataClassJsonMixin): ctx = FlyteContext.current_context() schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema()) - foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema") + foo_class = convert_marshmallow_json_schema_to_python_class(schema["definitions"], "FooSchema") guessed_pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class]) pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo]) @@ -327,7 +328,7 @@ def test_list_of_dataclassjsonmixin_getting_python_value(): from mashumaro.jsonschema import build_json_schema schema = build_json_schema(typing.cast(DataClassJSONMixin, Foo_getting_python_value)).to_dict() - foo_class = convert_json_schema_to_python_class(schema, "FooSchema", True) + foo_class = convert_mashumaro_json_schema_to_python_class(schema, "FooSchema") guessed_pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class]) pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo_getting_python_value]) @@ -470,14 +471,14 
@@ def recursive_assert(lit: LiteralType, expected: LiteralType, expected_depth: in ) -def test_convert_json_schema_to_python_class(): +def test_convert_marshmallow_json_schema_to_python_class(): @dataclass class Foo(DataClassJsonMixin): x: int y: str schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema()) - foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema") + foo_class = convert_marshmallow_json_schema_to_python_class(schema["definitions"], "FooSchema") foo = foo_class(x=1, y="hello") foo.x = 2 assert foo.x == 2 @@ -486,7 +487,7 @@ class Foo(DataClassJsonMixin): _ = foo.c -def test_convert_json_schema_to_python_class_with_dataclassjsonmixin(): +def test_convert_mashumaro_json_schema_to_python_class(): @dataclass class Foo(DataClassJSONMixin): x: int @@ -496,7 +497,7 @@ class Foo(DataClassJSONMixin): from mashumaro.jsonschema import build_json_schema schema = build_json_schema(typing.cast(DataClassJSONMixin, Foo)).to_dict() - foo_class = convert_json_schema_to_python_class(schema, "FooSchema", is_dataclass_json_mixin=True) + foo_class = convert_mashumaro_json_schema_to_python_class(schema, "FooSchema") foo = foo_class(x=1, y="hello") foo.x = 2 assert foo.x == 2 @@ -783,94 +784,6 @@ def test_dataclass_transformer_with_dataclassjsonmixin(): assert t.metadata is None -@dataclass -class InnerStruct_transformer(DataClassJSONMixin): - a: int - b: typing.Optional[str] - c: typing.List[int] - - -@dataclass -class TestStruct_transformer(DataClassJSONMixin): - s: InnerStruct_transformer - m: typing.Dict[str, str] - - -@dataclass -class TestStructB_transformer(DataClassJSONMixin): - s: InnerStruct_transformer - m: typing.Dict[int, str] - n: typing.Optional[typing.List[typing.List[int]]] = None - o: typing.Optional[typing.Dict[int, typing.Dict[int, int]]] = None - - -@dataclass -class TestStructC_transformer(DataClassJSONMixin): - s: InnerStruct_transformer - m: typing.Dict[str, int] - - -@dataclass -class TestStructD_transformer(DataClassJSONMixin): - s: InnerStruct_transformer - m: typing.Dict[str, typing.List[int]] - - -@dataclass -class UnsupportedSchemaType_transformer: - _a: str = "Hello" - - -@dataclass -class UnsupportedNestedStruct_transformer(DataClassJSONMixin): - a: int - s: UnsupportedSchemaType_transformer - - -def test_dataclass_transformer_with_dataclassjsonmixin(): - schema = { - "type": "object", - "title": "TestStruct_transformer", - "properties": { - "s": { - "type": "object", - "title": "InnerStruct_transformer", - "properties": { - "a": {"type": "integer"}, - "b": {"anyOf": [{"type": "string"}, {"type": "null"}]}, - "c": {"type": "array", "items": {"type": "integer"}}, - }, - "additionalProperties": False, - "required": ["a", "b", "c"], - }, - "m": {"type": "object", "additionalProperties": {"type": "string"}, "propertyNames": {"type": "string"}}, - }, - "additionalProperties": False, - "required": ["s", "m"], - } - - tf = DataclassTransformer() - t = tf.get_literal_type(TestStruct_transformer) - assert t is not None - assert t.simple is not None - assert t.simple == SimpleType.STRUCT - assert t.metadata is not None - assert t.metadata == schema - - t = TypeEngine.to_literal_type(TestStruct_transformer) - assert t is not None - assert t.simple is not None - assert t.simple == SimpleType.STRUCT - assert t.metadata is not None - assert t.metadata == schema - - t = tf.get_literal_type(UnsupportedNestedStruct) - assert t is not None - assert t.simple is not None - assert t.simple == SimpleType.STRUCT - assert t.metadata is None - - 
def test_dataclass_int_preserving(): ctx = FlyteContext.current_context() @@ -1031,20 +944,35 @@ def test_optional_flytefile_in_dataclassjsonmixin(mock_upload_dir): lt = tf.get_literal_type(TestFileStruct_optional_flytefile) lv = tf.to_literal(ctx, o, TestFileStruct_optional_flytefile, lt) - assert lv.scalar.generic["a"] == remote_path - assert lv.scalar.generic["b"] == remote_path + # assert lv.scalar.generic["a"]["path"] == remote_path + # assert lv.scalar.generic["b"]["path"] == remote_path + # assert lv.scalar.generic["b_prime"] is None + # assert lv.scalar.generic["c"]["path"] == remote_path + # assert lv.scalar.generic["d"]["path"] == remote_path + # assert lv.scalar.generic["e"]["path"]== remote_path + # assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value" + # assert lv.scalar.generic["f"]["a"]["path"] == remote_path + # assert lv.scalar.generic["g"]["a"]["path"] == remote_path + # assert lv.scalar.generic["g_prime"]["a"] is None + # assert lv.scalar.generic["h"]["path"] == remote_path + # assert lv.scalar.generic["h_prime"] is None + # assert lv.scalar.generic["i"]["a"] == 42 + # assert lv.scalar.generic["i_prime"]["a"] == 99 + + assert lv.scalar.generic["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["b"].fields["path"].string_value == remote_path assert lv.scalar.generic["b_prime"] is None - assert lv.scalar.generic["c"] == remote_path - assert lv.scalar.generic["d"].values[0].string_value == remote_path - assert lv.scalar.generic["e"].values[0].string_value == remote_path + assert lv.scalar.generic["c"].fields["path"].string_value == remote_path + assert lv.scalar.generic["d"].values[0].struct_value.fields["path"].string_value == remote_path + assert lv.scalar.generic["e"].values[0].struct_value.fields["path"].string_value == remote_path assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value" - assert lv.scalar.generic["f"]["a"] == remote_path - assert lv.scalar.generic["g"]["a"] == remote_path + assert lv.scalar.generic["f"]["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["g"]["a"].fields["path"].string_value == remote_path assert lv.scalar.generic["g_prime"]["a"] is None - assert lv.scalar.generic["h"] == remote_path + assert lv.scalar.generic["h"].fields["path"].string_value == remote_path assert lv.scalar.generic["h_prime"] is None - assert lv.scalar.generic["i"]["a"] == 42 - assert lv.scalar.generic["i_prime"]["a"] == 99 + assert lv.scalar.generic["i"].fields["a"].number_value == 42 + assert lv.scalar.generic["i_prime"].fields["a"].number_value == 99 ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_optional_flytefile) @@ -1474,6 +1402,7 @@ class ArgsAssert(DataClassJSONMixin): x: int y: typing.Optional[str] + @dataclass class SchemaArgsAssert(DataClassJSONMixin): x: typing.Optional[ArgsAssert] @@ -1493,7 +1422,8 @@ class Bar(DataClassJSONMixin): pv = Bar(x=3) with pytest.raises( - TypeTransformerFailedError, match="Type of Val '' is not an instance of " + TypeTransformerFailedError, + match="Type of Val '' is not an instance of ", ): DataclassTransformer().assert_type(gt, pv) From c19bdc9b296f46fc29538bb07397e47060bba85e Mon Sep 17 00:00:00 2001 From: HH Date: Thu, 31 Aug 2023 09:37:15 +0800 Subject: [PATCH 33/38] Fix lint Signed-off-by: HH --- flytekit/core/type_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 64ecfb7288..e8a6c3f41f 100644 --- 
a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -388,7 +388,8 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp type(python_val), DataClassJSONMixin ): raise TypeTransformerFailedError( - f"Dataclass {python_type} should be decorated with @dataclass_json or inherit DataClassJSONMixin to be " f"serialized correctly" + f"Dataclass {python_type} should be decorated with @dataclass_json or inherit DataClassJSONMixin to be " + f"serialized correctly" ) self._serialize_flyte_type(python_val, python_type) From 5e7e72be99fe4c8d45c2f21f6efef708b560efac Mon Sep 17 00:00:00 2001 From: HH Date: Thu, 7 Sep 2023 11:29:07 +0800 Subject: [PATCH 34/38] remove un-relevant changes Signed-off-by: HH --- flytekit/types/structured/__init__.py | 14 --- flytekit/types/structured/snowflake.py | 102 ------------------ .../structured_dataset/test_snowflake.py | 48 --------- 3 files changed, 164 deletions(-) delete mode 100644 flytekit/types/structured/snowflake.py delete mode 100644 tests/flytekit/unit/types/structured_dataset/test_snowflake.py diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 617e4bcafa..543117c865 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -70,17 +70,3 @@ def register_bigquery_handlers(): "We won't register bigquery handler for structured dataset because " "we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery" ) - - -def register_snowflake_handlers(): - try: - from .snowflake import PandasToSnowflakeEncodingHandlers, SnowflakeToPandasDecodingHandler - - StructuredDatasetTransformerEngine.register(SnowflakeToPandasDecodingHandler()) - StructuredDatasetTransformerEngine.register(PandasToSnowflakeEncodingHandlers()) - - except ImportError: - logger.info( - "We won't register snowflake handler for structured dataset because " - "we can't find package snowflakee-connector-python" - ) diff --git a/flytekit/types/structured/snowflake.py b/flytekit/types/structured/snowflake.py deleted file mode 100644 index 9f28734f43..0000000000 --- a/flytekit/types/structured/snowflake.py +++ /dev/null @@ -1,102 +0,0 @@ -import re -import typing - -import pandas as pd -import snowflake.connector -from snowflake.connector.pandas_tools import write_pandas - -from flytekit import FlyteContext -from flytekit.models import literals -from flytekit.models.types import StructuredDatasetType -from flytekit.types.structured.structured_dataset import ( - StructuredDataset, - StructuredDatasetDecoder, - StructuredDatasetEncoder, - StructuredDatasetMetadata, -) - -SNOWFLAKE = "snowflake" - - -def get_private_key(): - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives import serialization - - import flytekit - - pk_path = flytekit.current_context().secrets.get_secrets_file(SNOWFLAKE, "rsa_key.p8") - - with open(pk_path, "rb") as key: - p_key = serialization.load_pem_private_key(key.read(), password=None, backend=default_backend()) - - return p_key.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - - -def _write_to_sf(structured_dataset: StructuredDataset): - if structured_dataset.uri is None: - raise ValueError("structured_dataset.uri cannot be None.") - - uri = structured_dataset.uri - _, user, account, database, schema, warehouse, table = re.split("\\/|://|:", uri) - df = structured_dataset.dataframe 
- - conn = snowflake.connector.connect( - user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse - ) - - write_pandas(conn, df, table) - - -def _read_from_sf( - flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata -) -> pd.DataFrame: - if flyte_value.uri is None: - raise ValueError("structured_dataset.uri cannot be None.") - - uri = flyte_value.uri - _, user, account, database, schema, warehouse, table = re.split("\\/|://|:", uri) - - conn = snowflake.connector.connect( - user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse - ) - - cs = conn.cursor() - cs.execute(f"select * from {table}") - - dff = cs.fetch_pandas_all() - print("cs", cs) - print("dff", dff) - return dff - - -class PandasToSnowflakeEncodingHandlers(StructuredDatasetEncoder): - def __init__(self): - super().__init__(pd.DataFrame, SNOWFLAKE, supported_format="") - - def encode( - self, - ctx: FlyteContext, - structured_dataset: StructuredDataset, - structured_dataset_type: StructuredDatasetType, - ) -> literals.StructuredDataset: - _write_to_sf(structured_dataset) - return literals.StructuredDataset( - uri=typing.cast(str, structured_dataset.uri), metadata=StructuredDatasetMetadata(structured_dataset_type) - ) - - -class SnowflakeToPandasDecodingHandler(StructuredDatasetDecoder): - def __init__(self): - super().__init__(pd.DataFrame, SNOWFLAKE, supported_format="") - - def decode( - self, - ctx: FlyteContext, - flyte_value: literals.StructuredDataset, - current_task_metadata: StructuredDatasetMetadata, - ) -> pd.DataFrame: - return _read_from_sf(flyte_value, current_task_metadata) diff --git a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py b/tests/flytekit/unit/types/structured_dataset/test_snowflake.py deleted file mode 100644 index c957c0bbce..0000000000 --- a/tests/flytekit/unit/types/structured_dataset/test_snowflake.py +++ /dev/null @@ -1,48 +0,0 @@ -import mock -import pandas as pd -import pytest -from typing_extensions import Annotated - -from flytekit import StructuredDataset, kwtypes, task, workflow - -pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) -my_cols = kwtypes(Name=str, Age=int) - - -@task -def gen_df() -> Annotated[pd.DataFrame, my_cols, "parquet"]: - return pd_df - - -@task -def t1(df: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]: - return StructuredDataset( - dataframe=df, uri="snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table" - ) - - -@task -def t2(sd: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame: - return sd.open(pd.DataFrame).all() - - -@workflow -def wf() -> pd.DataFrame: - df = gen_df() - sd = t1(df=df) - return t2(sd=sd) - - -@mock.patch("flytekit.types.structured.snowflake.get_private_key", return_value="pb") -@mock.patch("snowflake.connector.connect") -@pytest.mark.asyncio -async def test_sf_wf(mock_connect, mock_get_private_key): - class mock_dataframe: - def to_dataframe(self): - return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) - - mock_connect_instance = mock_connect.return_value - mock_coursor_instance = mock_connect_instance.cursor.return_value - mock_coursor_instance.fetch_pandas_all.return_value = mock_dataframe().to_dataframe() - - assert wf().equals(pd_df) From 200300d3d1ac3f5f3d390f16c3c5d9ce6e2a0b77 Mon Sep 17 00:00:00 2001 From: HH Date: Thu, 7 Sep 2023 12:22:00 +0800 Subject: [PATCH 35/38] remove un-relevant changes 
Signed-off-by: HH --- flytekit/core/type_engine.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e8a6c3f41f..235d519e75 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -849,7 +849,6 @@ def lazy_import_transformers(cls): register_arrow_handlers, register_bigquery_handlers, register_pandas_handlers, - register_snowflake_handlers, ) if is_imported("tensorflow"): @@ -871,11 +870,6 @@ def lazy_import_transformers(cls): if is_imported("numpy"): from flytekit.types import numpy # noqa: F401 - try: - register_snowflake_handlers() - except ValueError as e: - logger.debug(f"Attempted to register the Snowflake handler but failed due to: {str(e)}") - @classmethod def to_literal_type(cls, python_type: Type) -> LiteralType: """ From fb265dda40ba34b369f3c1c2c04b1630291ae48f Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 8 Sep 2023 09:28:24 +0800 Subject: [PATCH 36/38] fix the suggestion part Signed-off-by: HH --- flytekit/core/type_engine.py | 4 ++-- tests/flytekit/unit/core/test_type_engine.py | 15 --------------- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 235d519e75..9242cd3c27 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1630,7 +1630,7 @@ def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typin return attribute_list -def convert_marshmallow_json_schema_to_python_class(schema: dict, schema_name: typing.Any, is_dataclass_json_mixin: bool = False) -> Type[dataclasses.dataclass()]: # type: ignore +def convert_marshmallow_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> Type[dataclasses.dataclass()]: # type: ignore """ Generate a model class based on the provided JSON Schema :param schema: dict representing valid JSON schema @@ -1641,7 +1641,7 @@ def convert_marshmallow_json_schema_to_python_class(schema: dict, schema_name: t return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) -def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typing.Any, is_dataclass_json_mixin: bool = False) -> Type[dataclasses.dataclass()]: # type: ignore +def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> Type[dataclasses.dataclass()]: # type: ignore """ Generate a model class based on the provided JSON Schema :param schema: dict representing valid JSON schema diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index fec17a9158..e2faad607e 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -944,21 +944,6 @@ def test_optional_flytefile_in_dataclassjsonmixin(mock_upload_dir): lt = tf.get_literal_type(TestFileStruct_optional_flytefile) lv = tf.to_literal(ctx, o, TestFileStruct_optional_flytefile, lt) - # assert lv.scalar.generic["a"]["path"] == remote_path - # assert lv.scalar.generic["b"]["path"] == remote_path - # assert lv.scalar.generic["b_prime"] is None - # assert lv.scalar.generic["c"]["path"] == remote_path - # assert lv.scalar.generic["d"]["path"] == remote_path - # assert lv.scalar.generic["e"]["path"]== remote_path - # assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value" - # assert lv.scalar.generic["f"]["a"]["path"] == remote_path - # assert lv.scalar.generic["g"]["a"]["path"] == remote_path - # assert 
lv.scalar.generic["g_prime"]["a"] is None - # assert lv.scalar.generic["h"]["path"] == remote_path - # assert lv.scalar.generic["h_prime"] is None - # assert lv.scalar.generic["i"]["a"] == 42 - # assert lv.scalar.generic["i_prime"]["a"] == 99 - assert lv.scalar.generic["a"].fields["path"].string_value == remote_path assert lv.scalar.generic["b"].fields["path"].string_value == remote_path assert lv.scalar.generic["b_prime"] is None From a322dec3c74c8938fd59d763958531785376c327 Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 8 Sep 2023 15:33:52 +0800 Subject: [PATCH 37/38] add test to cover the code cov && fix some schema name in type_engine Signed-off-by: HH --- flytekit/core/type_engine.py | 7 +- tests/flytekit/unit/core/test_type_engine.py | 113 ++++++++++++++++++- 2 files changed, 116 insertions(+), 4 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 9242cd3c27..b752f17351 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1587,16 +1587,19 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: # Handle dataclass and dict elif property_type == "object": if property_val.get("anyOf"): + sub_schemea = property_val["anyOf"][0] + sub_schemea_name = sub_schemea["title"] attribute_list.append( - (property_key, convert_mashumaro_json_schema_to_python_class(property_val["anyOf"][0], schema_name)) + (property_key, convert_mashumaro_json_schema_to_python_class(sub_schemea, sub_schemea_name)) ) elif property_val.get("additionalProperties"): attribute_list.append( (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore ) else: + sub_schemea_name = property_val["title"] attribute_list.append( - (property_key, convert_mashumaro_json_schema_to_python_class(property_val, schema_name)) + (property_key, convert_mashumaro_json_schema_to_python_class(property_val, sub_schemea_name)) ) elif property_type == "enum": attribute_list.append([property_key, str]) # type: ignore diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index e2faad607e..6f67e72be8 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1,3 +1,4 @@ +import dataclasses import datetime import json import os @@ -13,7 +14,7 @@ import pyarrow as pa import pytest import typing_extensions -from dataclasses_json import DataClassJsonMixin +from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import errors_pb2 from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct @@ -289,6 +290,7 @@ class Foo(DataClassJsonMixin): assert guessed_pv[0].z.y == pv[0].z.y assert guessed_pv[0].z.z == pv[0].z.z assert pv[0] == dataclass_from_dict(Foo, asdict(guessed_pv[0])) + assert dataclasses.is_dataclass(foo_class) @dataclass @@ -348,6 +350,7 @@ def test_list_of_dataclassjsonmixin_getting_python_value(): assert guessed_pv[0].z.y == pv[0].z.y assert guessed_pv[0].z.z == pv[0].z.z assert pv[0] == dataclass_from_dict(Foo_getting_python_value, asdict(guessed_pv[0])) + assert dataclasses.is_dataclass(foo_class) def test_file_no_downloader_default(): @@ -485,6 +488,7 @@ class Foo(DataClassJsonMixin): assert foo.y == "hello" with pytest.raises(AttributeError): _ = foo.c + assert dataclasses.is_dataclass(foo_class) def test_convert_mashumaro_json_schema_to_python_class(): @@ -504,6 +508,7 @@ class Foo(DataClassJSONMixin): assert foo.y == "hello" 
with pytest.raises(AttributeError): _ = foo.c + assert dataclasses.is_dataclass(foo_class) def test_list_transformer(): @@ -1408,7 +1413,7 @@ class Bar(DataClassJSONMixin): pv = Bar(x=3) with pytest.raises( TypeTransformerFailedError, - match="Type of Val '' is not an instance of ", + match="Type of Val '' is not an instance of ", ): DataclassTransformer().assert_type(gt, pv) @@ -2220,3 +2225,107 @@ def test_get_underlying_type(t, expected): def test_dict_get(): assert DictTransformer.get_dict_types(None) == (None, None) + + +def test_DataclassTransformer_get_literal_type(): + @dataclass + class MyDataClassMashumaro(DataClassJsonMixin): + x: int + + @dataclass_json + @dataclass + class MyDataClass: + x: int + + de = DataclassTransformer() + + literal_type = de.get_literal_type(MyDataClass) + assert literal_type is not None + + literal_type = de.get_literal_type(MyDataClassMashumaro) + assert literal_type is not None + + invalid_json_str = "{ unbalanced_braces" + with pytest.raises(Exception): + Literal(scalar=Scalar(generic=_json_format.Parse(invalid_json_str, _struct.Struct()))) + + +def test_DataclassTransformer_to_literal(): + @dataclass + class MyDataClassMashumaro(DataClassJsonMixin): + x: int + + @dataclass_json + @dataclass + class MyDataClass: + x: int + + transformer = DataclassTransformer() + ctx = FlyteContext.current_context() + + my_dat_class_mashumaro = MyDataClassMashumaro(5) + my_data_class = MyDataClass(5) + + lv_mashumaro = transformer.to_literal(ctx, my_dat_class_mashumaro, MyDataClassMashumaro, MyDataClassMashumaro) + assert lv_mashumaro is not None + assert lv_mashumaro.scalar.generic["x"] == 5 + + lv = transformer.to_literal(ctx, my_data_class, MyDataClass, MyDataClass) + assert lv is not None + assert lv.scalar.generic["x"] == 5 + + +def test_DataclassTransformer_to_python_value(): + @dataclass + class MyDataClassMashumaro(DataClassJsonMixin): + x: int + + @dataclass_json + @dataclass + class MyDataClass: + x: int + + de = DataclassTransformer() + + json_str = '{ "x" : 5 }' + mock_literal = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) + + result = de.to_python_value(FlyteContext.current_context(), mock_literal, MyDataClass) + assert isinstance(result, MyDataClass) + assert result.x == 5 + + result = de.to_python_value(FlyteContext.current_context(), mock_literal, MyDataClassMashumaro) + assert isinstance(result, MyDataClassMashumaro) + assert result.x == 5 + + +def test_DataclassTransformer_guess_python_type(): + @dataclass + class DatumMashumaro(DataClassJSONMixin): + x: int + y: Color + + @dataclass_json + @dataclass + class Datum(DataClassJSONMixin): + x: int + y: Color + + transformer = DataclassTransformer() + ctx = FlyteContext.current_context() + + lt = TypeEngine.to_literal_type(Datum) + datum = Datum(5, Color.RED) + lv = transformer.to_literal(ctx, datum, Datum, lt) + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert datum.x == pv.x + assert datum.y.value == pv.y + + lt = TypeEngine.to_literal_type(DatumMashumaro) + datum_mashumaro = DatumMashumaro(5, Color.RED) + lv = transformer.to_literal(ctx, datum_mashumaro, DatumMashumaro, lt) + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert datum_mashumaro.x == pv.x + assert datum_mashumaro.y.value == pv.y From a05dd4fc9a0c610e1fbec3fe69e31fc6b482ddaf Mon Sep 17 00:00:00 2001 From: HH Date: Sat, 9 Sep 2023 08:46:06 +0800 Subject: [PATCH 38/38] remove 
 redundant branch condition in type_engine and import in dev-requirements.in

Signed-off-by: HH
---
 dev-requirements.in          | 1 -
 flytekit/core/type_engine.py | 5 +----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/dev-requirements.in b/dev-requirements.in
index 4a9df85e53..2c7ddd00c2 100644
--- a/dev-requirements.in
+++ b/dev-requirements.in
@@ -11,7 +11,6 @@ pre-commit
 codespell
 google-cloud-bigquery
 google-cloud-bigquery-storage
-snowflake-connector-python
 
 IPython
 keyrings.alt
diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index b752f17351..7ee95c772b 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -393,10 +393,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
             )
 
         self._serialize_flyte_type(python_val, python_type)
-        if issubclass(type(python_val), DataClassJsonMixin):
-            json_str = cast(DataClassJsonMixin, python_val).to_json()  # type: ignore
-        else:
-            json_str = cast(DataClassJSONMixin, python_val).to_json()  # type: ignore
+        json_str = python_val.to_json()  # type: ignore
 
         return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())))  # type: ignore
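
Note on PATCH 38/38: the removed branch was redundant because both dataclasses-json's DataClassJsonMixin and mashumaro's DataClassJSONMixin expose a to_json() method, so the single python_val.to_json() call covers either base class. A minimal standalone sketch of that equivalence (not part of the patch series; it assumes both dataclasses-json and mashumaro are installed, and the class names are illustrative only):

import json
from dataclasses import dataclass

from dataclasses_json import DataClassJsonMixin
from mashumaro.mixins.json import DataClassJSONMixin


@dataclass
class WithDataclassesJson(DataClassJsonMixin):
    x: int


@dataclass
class WithMashumaro(DataClassJSONMixin):
    x: int


# Both mixins provide to_json(), which is why the issubclass() check could be dropped.
assert json.loads(WithDataclassesJson(x=5).to_json()) == {"x": 5}
assert json.loads(WithMashumaro(x=5).to_json()) == {"x": 5}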