diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 8a92d58591..aecca2936d 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -8,13 +8,13 @@ from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union, cast from flyteidl.core import artifact_id_pb2 as art_id -from typing_extensions import get_args, get_origin, get_type_hints +from typing_extensions import get_args, get_type_hints from flytekit.core import context_manager from flytekit.core.artifact import Artifact, ArtifactIDSpecification, ArtifactQuery from flytekit.core.docstring import Docstring from flytekit.core.sentinel import DYNAMIC_INPUT_BINDING -from flytekit.core.type_engine import TypeEngine +from flytekit.core.type_engine import TypeEngine, UnionTransformer from flytekit.exceptions.user import FlyteValidationException from flytekit.loggers import logger from flytekit.models import interface as _interface_models @@ -218,7 +218,7 @@ def transform_inputs_to_parameters( inputs_with_def = interface.inputs_with_defaults for k, v in inputs_vars.items(): val, _default = inputs_with_def[k] - if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val): + if _default is None and UnionTransformer.is_optional_type(val): literal = Literal(scalar=Scalar(none_type=Void())) params[k] = _interface_models.Parameter(var=v, default=literal, required=False) else: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 9153fca032..df90991b3c 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -9,6 +9,7 @@ import json import json as _json import mimetypes +import sys import textwrap import typing from abc import ABC, abstractmethod @@ -26,6 +27,7 @@ from marshmallow_enum import EnumField, LoadDumpOptions from mashumaro.mixins.json import DataClassJSONMixin from typing_extensions import Annotated, get_args, get_origin +from typing_inspect import is_union_type from flytekit.core.annotation import FlyteAnnotation from flytekit.core.context_manager import FlyteContext @@ -547,7 +549,7 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A from flytekit.types.structured.structured_dataset import StructuredDataset # Handle Optional - if get_origin(python_type) is typing.Union and type(None) in get_args(python_type): + if UnionTransformer.is_optional_type(python_type): if python_val is None: return None return self._serialize_flyte_type(python_val, get_args(python_type)[0]) @@ -600,7 +602,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine # Handle Optional - if get_origin(expected_python_type) is typing.Union and type(None) in get_args(expected_python_type): + if UnionTransformer.is_optional_type(expected_python_type): if python_val is None: return None return self._deserialize_flyte_type(python_val, get_args(expected_python_type)[0]) @@ -694,7 +696,7 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: if val is None: return val - if get_origin(t) is typing.Union and type(None) in get_args(t): + if UnionTransformer.is_optional_type(t): # Handle optional type. e.g. Optional[int], Optional[dataclass] # Marshmallow doesn't support union type, so the type here is always an optional type. # https://github.com/marshmallow-code/marshmallow/issues/1191#issuecomment-480831796 @@ -961,6 +963,9 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: Step 5: if v is of type data class, use the dataclass transformer + + Step 6: + Pickle transformer is used """ cls.lazy_import_transformers() # Step 1 @@ -1496,7 +1501,7 @@ def __init__(self): @staticmethod def is_optional_type(t: Type[T]) -> bool: - return get_origin(t) is typing.Union and type(None) in get_args(t) + return is_union_type(t) and type(None) in get_args(t) @staticmethod def get_sub_type_in_optional(t: Type[T]) -> Type[T]: @@ -1968,7 +1973,12 @@ def _register_default_type_transformers(): [None], ) TypeEngine.register(ListTransformer()) - TypeEngine.register(UnionTransformer()) + if sys.version_info >= (3, 10): + from types import UnionType + + TypeEngine.register(UnionTransformer(), [UnionType]) + else: + TypeEngine.register(UnionTransformer()) TypeEngine.register(DictTransformer()) TypeEngine.register(TextIOTransformer()) TypeEngine.register(BinaryIOTransformer()) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py index 79d56d8354..36358ddbd8 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py @@ -1,5 +1,6 @@ import builtins import datetime +import types import typing from typing import Set @@ -11,7 +12,9 @@ numpy = lazy_module("numpy") pyarrow = lazy_module("pyarrow") -MODULES_TO_EXCLUDE_FROM_FLYTE_TYPES: Set[str] = {m.__name__ for m in [builtins, typing, datetime, pyarrow, numpy]} +MODULES_TO_EXCLUDE_FROM_FLYTE_TYPES: Set[str] = { + m.__name__ for m in [builtins, types, typing, datetime, pyarrow, numpy] +} def include_in_flyte_types(t: type) -> bool: diff --git a/pyproject.toml b/pyproject.toml index df457ac3ad..5678d342b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "s3fs>=2023.3.0,!=2024.3.1", "statsd>=3.0.0,<4.0.0", "typing_extensions", + "typing-inspect", "urllib3>=1.22,<2.0.0", ] classifiers = [ diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index b8f743118b..ecf25fb6e0 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -1,4 +1,5 @@ import os +import sys import typing from typing import Dict, List @@ -156,6 +157,7 @@ def t1() -> FlyteFile[typing.TypeVar("svg")]: assert return_type["o0"].extension() == FlyteFile[typing.TypeVar("svg")].extension() +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.") def test_parameters_and_defaults(): ctx = context_manager.FlyteContext.current_context() @@ -213,6 +215,18 @@ def z( assert not params.parameters["c"].required assert params.parameters["c"].default.scalar.none_type == Void() + def z(a: int | None = None, b: str | None = None, c: typing.List[int] | None = None) -> typing.Tuple[int, str]: + ... + + our_interface = transform_function_to_interface(z) + params = transform_inputs_to_parameters(ctx, our_interface) + assert not params.parameters["a"].required + assert params.parameters["a"].default.scalar.none_type == Void() + assert not params.parameters["b"].required + assert params.parameters["b"].default.scalar.none_type == Void() + assert not params.parameters["c"].required + assert params.parameters["c"].default.scalar.none_type == Void() + def test_parameters_with_docstring(): ctx = context_manager.FlyteContext.current_context() diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index b9c9dac8f2..b94249f804 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1368,9 +1368,13 @@ def union_type_tags_unique(t: LiteralType): return True +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.") def test_union_type(): pt = typing.Union[str, int] lt = TypeEngine.to_literal_type(pt) + pt_604 = str | int + lt_604 = TypeEngine.to_literal_type(pt_604) + assert lt == lt_604 assert lt.union_type.variants == [ LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")), LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")), @@ -1526,10 +1530,13 @@ class Bar(DataClassJSONMixin): DataclassTransformer().assert_type(gt, pv) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.") def test_union_transformer(): assert UnionTransformer.is_optional_type(typing.Optional[int]) + assert UnionTransformer.is_optional_type(int | None) assert not UnionTransformer.is_optional_type(str) assert UnionTransformer.get_sub_type_in_optional(typing.Optional[int]) == int + assert UnionTransformer.get_sub_type_in_optional(int | None) == int def test_union_guess_type(): @@ -1597,9 +1604,13 @@ def test_annotated_union_type(): assert v == "hello" +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.") def test_optional_type(): pt = typing.Optional[int] lt = TypeEngine.to_literal_type(pt) + pt_604 = int | None + lt_604 = TypeEngine.to_literal_type(pt_604) + assert lt == lt_604 assert lt.union_type.variants == [ LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")), LiteralType(simple=SimpleType.NONE, structure=TypeStructure(tag="none")), @@ -1791,9 +1802,13 @@ def test_union_of_lists(): assert v == [1, 3] +@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.") def test_list_of_unions(): pt = typing.List[typing.Union[str, int]] lt = TypeEngine.to_literal_type(pt) + pt_604 = typing.List[str | int] + lt_604 = TypeEngine.to_literal_type(pt_604) + assert lt == lt_604 # todo(maximsmol): seems like the order here is non-deterministic assert lt.collection_type.union_type.variants == [ LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")), @@ -1804,8 +1819,10 @@ def test_list_of_unions(): ctx = FlyteContextManager.current_context() lv = TypeEngine.to_literal(ctx, ["hello", 123, "world"], pt, lt) v = TypeEngine.to_python_value(ctx, lv, pt) + lv_604 = TypeEngine.to_literal(ctx, ["hello", 123, "world"], pt_604, lt_604) + v_604 = TypeEngine.to_python_value(ctx, lv_604, pt_604) assert [x.scalar.union.stored_type.structure.tag for x in lv.collection.literals] == ["str", "int", "str"] - assert v == ["hello", 123, "world"] + assert v == v_604 == ["hello", 123, "world"] def test_pickle_type():