Skip to content

Commit

Permalink
feat: Support PEP 604 union operator (#2298)
Browse files Browse the repository at this point in the history
* feat: Support PEP 604 union operator

Signed-off-by: ggydush <[email protected]>

* fix: check union early

Signed-off-by: ggydush <[email protected]>

* refactor: Change name to 604

Signed-off-by: ggydush <[email protected]>

* fix lint

Signed-off-by: ggydush <[email protected]>

* fix: Fix duplicated code

Signed-off-by: ggydush <[email protected]>

* fix: Remove code for testing

Signed-off-by: ggydush <[email protected]>

* test: Add simple tests

Signed-off-by: ggydush <[email protected]>

* Add more tests

Signed-off-by: ggydush <[email protected]>

* fix: Fix names

Signed-off-by: ggydush <[email protected]>

* fix: Lint

Signed-off-by: ggydush <[email protected]>

* fix: Fix again

Signed-off-by: ggydush <[email protected]>

* fix: Fix default

Signed-off-by: ggydush <[email protected]>

* test: Add test for parameter and defaults

Signed-off-by: ggydush <[email protected]>

* fix: Fix code coverage by ignoring

Signed-off-by: ggydush <[email protected]>

* refactor: Use is_union_type

Signed-off-by: ggydush <[email protected]>

* fix import sort

Signed-off-by: ggydush <[email protected]>

* fix: cleanup

Signed-off-by: ggydush <[email protected]>

* fix: fix

Signed-off-by: ggydush <[email protected]>

* refactor: Clean it up

Signed-off-by: ggydush <[email protected]>

* fix: Fix lint

Signed-off-by: ggydush <[email protected]>

* fix: Fix pydantic plugin test failure

Signed-off-by: ggydush <[email protected]>

* Update flytekit/core/type_engine.py

Co-authored-by: Kevin Su <[email protected]>

* Address comment

Signed-off-by: ggydush <[email protected]>

* fix: Use UnionTransformer

Signed-off-by: ggydush <[email protected]>

* fix: Fix lint

Signed-off-by: ggydush <[email protected]>

* Skip tests with | syntax on < 3.10

Signed-off-by: ggydush <[email protected]>

* fix: Fix test

Signed-off-by: ggydush <[email protected]>

* fix: More review comments

Signed-off-by: ggydush <[email protected]>

* fix: Fix lint

Signed-off-by: ggydush <[email protected]>

---------

Signed-off-by: ggydush <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Signed-off-by: Jan Fiedler <[email protected]>
  • Loading branch information
2 people authored and fiedlerNr9 committed Jul 25, 2024
1 parent e07f0fe commit eb294c9
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 10 deletions.
6 changes: 3 additions & 3 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union, cast

from flyteidl.core import artifact_id_pb2 as art_id
from typing_extensions import get_args, get_origin, get_type_hints
from typing_extensions import get_args, get_type_hints

from flytekit.core import context_manager
from flytekit.core.artifact import Artifact, ArtifactIDSpecification, ArtifactQuery
from flytekit.core.docstring import Docstring
from flytekit.core.sentinel import DYNAMIC_INPUT_BINDING
from flytekit.core.type_engine import TypeEngine
from flytekit.core.type_engine import TypeEngine, UnionTransformer
from flytekit.exceptions.user import FlyteValidationException
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
Expand Down Expand Up @@ -218,7 +218,7 @@ def transform_inputs_to_parameters(
inputs_with_def = interface.inputs_with_defaults
for k, v in inputs_vars.items():
val, _default = inputs_with_def[k]
if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val):
if _default is None and UnionTransformer.is_optional_type(val):
literal = Literal(scalar=Scalar(none_type=Void()))
params[k] = _interface_models.Parameter(var=v, default=literal, required=False)
else:
Expand Down
20 changes: 15 additions & 5 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import json
import json as _json
import mimetypes
import sys
import textwrap
import typing
from abc import ABC, abstractmethod
Expand All @@ -26,6 +27,7 @@
from marshmallow_enum import EnumField, LoadDumpOptions
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, get_args, get_origin
from typing_inspect import is_union_type

from flytekit.core.annotation import FlyteAnnotation
from flytekit.core.context_manager import FlyteContext
Expand Down Expand Up @@ -547,7 +549,7 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A
from flytekit.types.structured.structured_dataset import StructuredDataset

# Handle Optional
if get_origin(python_type) is typing.Union and type(None) in get_args(python_type):
if UnionTransformer.is_optional_type(python_type):
if python_val is None:
return None
return self._serialize_flyte_type(python_val, get_args(python_type)[0])
Expand Down Expand Up @@ -600,7 +602,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) ->
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine

# Handle Optional
if get_origin(expected_python_type) is typing.Union and type(None) in get_args(expected_python_type):
if UnionTransformer.is_optional_type(expected_python_type):
if python_val is None:
return None
return self._deserialize_flyte_type(python_val, get_args(expected_python_type)[0])
Expand Down Expand Up @@ -694,7 +696,7 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
if val is None:
return val

if get_origin(t) is typing.Union and type(None) in get_args(t):
if UnionTransformer.is_optional_type(t):
# Handle optional type. e.g. Optional[int], Optional[dataclass]
# Marshmallow doesn't support union type, so the type here is always an optional type.
# https://github.com/marshmallow-code/marshmallow/issues/1191#issuecomment-480831796
Expand Down Expand Up @@ -961,6 +963,9 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
Step 5:
if v is of type data class, use the dataclass transformer
Step 6:
Pickle transformer is used
"""
cls.lazy_import_transformers()
# Step 1
Expand Down Expand Up @@ -1496,7 +1501,7 @@ def __init__(self):

@staticmethod
def is_optional_type(t: Type[T]) -> bool:
return get_origin(t) is typing.Union and type(None) in get_args(t)
return is_union_type(t) and type(None) in get_args(t)

@staticmethod
def get_sub_type_in_optional(t: Type[T]) -> Type[T]:
Expand Down Expand Up @@ -1968,7 +1973,12 @@ def _register_default_type_transformers():
[None],
)
TypeEngine.register(ListTransformer())
TypeEngine.register(UnionTransformer())
if sys.version_info >= (3, 10):
from types import UnionType

TypeEngine.register(UnionTransformer(), [UnionType])
else:
TypeEngine.register(UnionTransformer())
TypeEngine.register(DictTransformer())
TypeEngine.register(TextIOTransformer())
TypeEngine.register(BinaryIOTransformer())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import builtins
import datetime
import types
import typing
from typing import Set

Expand All @@ -11,7 +12,9 @@
numpy = lazy_module("numpy")
pyarrow = lazy_module("pyarrow")

MODULES_TO_EXCLUDE_FROM_FLYTE_TYPES: Set[str] = {m.__name__ for m in [builtins, typing, datetime, pyarrow, numpy]}
MODULES_TO_EXCLUDE_FROM_FLYTE_TYPES: Set[str] = {
m.__name__ for m in [builtins, types, typing, datetime, pyarrow, numpy]
}


def include_in_flyte_types(t: type) -> bool:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dependencies = [
"s3fs>=2023.3.0,!=2024.3.1",
"statsd>=3.0.0,<4.0.0",
"typing_extensions",
"typing-inspect",
"urllib3>=1.22,<2.0.0",
]
classifiers = [
Expand Down
14 changes: 14 additions & 0 deletions tests/flytekit/unit/core/test_interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
import typing
from typing import Dict, List

Expand Down Expand Up @@ -156,6 +157,7 @@ def t1() -> FlyteFile[typing.TypeVar("svg")]:
assert return_type["o0"].extension() == FlyteFile[typing.TypeVar("svg")].extension()


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
def test_parameters_and_defaults():
ctx = context_manager.FlyteContext.current_context()

Expand Down Expand Up @@ -213,6 +215,18 @@ def z(
assert not params.parameters["c"].required
assert params.parameters["c"].default.scalar.none_type == Void()

def z(a: int | None = None, b: str | None = None, c: typing.List[int] | None = None) -> typing.Tuple[int, str]:
...

our_interface = transform_function_to_interface(z)
params = transform_inputs_to_parameters(ctx, our_interface)
assert not params.parameters["a"].required
assert params.parameters["a"].default.scalar.none_type == Void()
assert not params.parameters["b"].required
assert params.parameters["b"].default.scalar.none_type == Void()
assert not params.parameters["c"].required
assert params.parameters["c"].default.scalar.none_type == Void()


def test_parameters_with_docstring():
ctx = context_manager.FlyteContext.current_context()
Expand Down
19 changes: 18 additions & 1 deletion tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,9 +1368,13 @@ def union_type_tags_unique(t: LiteralType):
return True


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
def test_union_type():
pt = typing.Union[str, int]
lt = TypeEngine.to_literal_type(pt)
pt_604 = str | int
lt_604 = TypeEngine.to_literal_type(pt_604)
assert lt == lt_604
assert lt.union_type.variants == [
LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")),
LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
Expand Down Expand Up @@ -1526,10 +1530,13 @@ class Bar(DataClassJSONMixin):
DataclassTransformer().assert_type(gt, pv)


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
def test_union_transformer():
assert UnionTransformer.is_optional_type(typing.Optional[int])
assert UnionTransformer.is_optional_type(int | None)
assert not UnionTransformer.is_optional_type(str)
assert UnionTransformer.get_sub_type_in_optional(typing.Optional[int]) == int
assert UnionTransformer.get_sub_type_in_optional(int | None) == int


def test_union_guess_type():
Expand Down Expand Up @@ -1597,9 +1604,13 @@ def test_annotated_union_type():
assert v == "hello"


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
def test_optional_type():
pt = typing.Optional[int]
lt = TypeEngine.to_literal_type(pt)
pt_604 = int | None
lt_604 = TypeEngine.to_literal_type(pt_604)
assert lt == lt_604
assert lt.union_type.variants == [
LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
LiteralType(simple=SimpleType.NONE, structure=TypeStructure(tag="none")),
Expand Down Expand Up @@ -1791,9 +1802,13 @@ def test_union_of_lists():
assert v == [1, 3]


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
def test_list_of_unions():
pt = typing.List[typing.Union[str, int]]
lt = TypeEngine.to_literal_type(pt)
pt_604 = typing.List[str | int]
lt_604 = TypeEngine.to_literal_type(pt_604)
assert lt == lt_604
# todo(maximsmol): seems like the order here is non-deterministic
assert lt.collection_type.union_type.variants == [
LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")),
Expand All @@ -1804,8 +1819,10 @@ def test_list_of_unions():
ctx = FlyteContextManager.current_context()
lv = TypeEngine.to_literal(ctx, ["hello", 123, "world"], pt, lt)
v = TypeEngine.to_python_value(ctx, lv, pt)
lv_604 = TypeEngine.to_literal(ctx, ["hello", 123, "world"], pt_604, lt_604)
v_604 = TypeEngine.to_python_value(ctx, lv_604, pt_604)
assert [x.scalar.union.stored_type.structure.tag for x in lv.collection.literals] == ["str", "int", "str"]
assert v == ["hello", 123, "world"]
assert v == v_604 == ["hello", 123, "world"]


def test_pickle_type():
Expand Down

0 comments on commit eb294c9

Please sign in to comment.