Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support pep604 union operator #2298

Merged
merged 29 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union, cast

from flyteidl.core import artifact_id_pb2 as art_id
from typing_extensions import get_args, get_origin, get_type_hints
from typing_extensions import get_args, get_type_hints

from flytekit.core import context_manager
from flytekit.core.artifact import Artifact, ArtifactIDSpecification, ArtifactQuery
from flytekit.core.docstring import Docstring
from flytekit.core.sentinel import DYNAMIC_INPUT_BINDING
from flytekit.core.type_engine import TypeEngine
from flytekit.core.type_engine import TypeEngine, UnionTransformer
from flytekit.exceptions.user import FlyteValidationException
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
Expand Down Expand Up @@ -218,7 +218,7 @@ def transform_inputs_to_parameters(
inputs_with_def = interface.inputs_with_defaults
for k, v in inputs_vars.items():
val, _default = inputs_with_def[k]
if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val):
if _default is None and UnionTransformer.is_optional_type(val):
literal = Literal(scalar=Scalar(none_type=Void()))
params[k] = _interface_models.Parameter(var=v, default=literal, required=False)
else:
Expand Down
20 changes: 15 additions & 5 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import json
import json as _json
import mimetypes
import sys
import textwrap
import typing
from abc import ABC, abstractmethod
Expand All @@ -26,6 +27,7 @@
from marshmallow_enum import EnumField, LoadDumpOptions
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, get_args, get_origin
from typing_inspect import is_union_type
ggydush marked this conversation as resolved.
Show resolved Hide resolved

from flytekit.core.annotation import FlyteAnnotation
from flytekit.core.context_manager import FlyteContext
Expand Down Expand Up @@ -547,7 +549,7 @@
from flytekit.types.structured.structured_dataset import StructuredDataset

# Handle Optional
if get_origin(python_type) is typing.Union and type(None) in get_args(python_type):
if UnionTransformer.is_optional_type(python_type):
if python_val is None:
return None
return self._serialize_flyte_type(python_val, get_args(python_type)[0])
Expand Down Expand Up @@ -600,7 +602,7 @@
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine

# Handle Optional
if get_origin(expected_python_type) is typing.Union and type(None) in get_args(expected_python_type):
if UnionTransformer.is_optional_type(expected_python_type):
if python_val is None:
return None
return self._deserialize_flyte_type(python_val, get_args(expected_python_type)[0])
Expand Down Expand Up @@ -694,7 +696,7 @@
if val is None:
return val

if get_origin(t) is typing.Union and type(None) in get_args(t):
if UnionTransformer.is_optional_type(t):
# Handle optional type. e.g. Optional[int], Optional[dataclass]
# Marshmallow doesn't support union type, so the type here is always an optional type.
# https://github.com/marshmallow-code/marshmallow/issues/1191#issuecomment-480831796
Expand Down Expand Up @@ -961,6 +963,9 @@

Step 5:
if v is of type data class, use the dataclass transformer

Step 6:
Pickle transformer is used
"""
cls.lazy_import_transformers()
# Step 1
Expand Down Expand Up @@ -1496,7 +1501,7 @@

@staticmethod
def is_optional_type(t: Type[T]) -> bool:
return get_origin(t) is typing.Union and type(None) in get_args(t)
return is_union_type(t) and type(None) in get_args(t)

@staticmethod
def get_sub_type_in_optional(t: Type[T]) -> Type[T]:
Expand Down Expand Up @@ -1968,7 +1973,12 @@
[None],
)
TypeEngine.register(ListTransformer())
TypeEngine.register(UnionTransformer())
if sys.version_info >= (3, 10):
from types import UnionType

TypeEngine.register(UnionTransformer(), [UnionType])
else:
TypeEngine.register(UnionTransformer())

Check warning on line 1981 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1981

Added line #L1981 was not covered by tests
TypeEngine.register(DictTransformer())
TypeEngine.register(TextIOTransformer())
TypeEngine.register(BinaryIOTransformer())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import builtins
import datetime
import types
import typing
from typing import Set

Expand All @@ -11,7 +12,9 @@
numpy = lazy_module("numpy")
pyarrow = lazy_module("pyarrow")

MODULES_TO_EXCLUDE_FROM_FLYTE_TYPES: Set[str] = {m.__name__ for m in [builtins, typing, datetime, pyarrow, numpy]}
MODULES_TO_EXCLUDE_FROM_FLYTE_TYPES: Set[str] = {
m.__name__ for m in [builtins, types, typing, datetime, pyarrow, numpy]
}


def include_in_flyte_types(t: type) -> bool:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dependencies = [
"s3fs>=2023.3.0,!=2024.3.1",
"statsd>=3.0.0,<4.0.0",
"typing_extensions",
"typing-inspect",
"urllib3>=1.22,<2.0.0",
]
classifiers = [
Expand Down
15 changes: 14 additions & 1 deletion tests/flytekit/unit/core/test_interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
import typing
from typing import Dict, List

Expand Down Expand Up @@ -155,7 +156,7 @@ def t1() -> FlyteFile[typing.TypeVar("svg")]:
return_type = extract_return_annotation(typing.get_type_hints(t1).get("return", None))
assert return_type["o0"].extension() == FlyteFile[typing.TypeVar("svg")].extension()


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
def test_parameters_and_defaults():
ctx = context_manager.FlyteContext.current_context()

Expand Down Expand Up @@ -213,6 +214,18 @@ def z(
assert not params.parameters["c"].required
assert params.parameters["c"].default.scalar.none_type == Void()

def z(a: int | None = None, b: str | None = None, c: typing.List[int] | None = None) -> typing.Tuple[int, str]:
ggydush marked this conversation as resolved.
Show resolved Hide resolved
...

our_interface = transform_function_to_interface(z)
params = transform_inputs_to_parameters(ctx, our_interface)
assert not params.parameters["a"].required
assert params.parameters["a"].default.scalar.none_type == Void()
assert not params.parameters["b"].required
assert params.parameters["b"].default.scalar.none_type == Void()
assert not params.parameters["c"].required
assert params.parameters["c"].default.scalar.none_type == Void()


def test_parameters_with_docstring():
ctx = context_manager.FlyteContext.current_context()
Expand Down
19 changes: 18 additions & 1 deletion tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,9 +1368,13 @@ def union_type_tags_unique(t: LiteralType):
return True


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
def test_union_type():
pt = typing.Union[str, int]
lt = TypeEngine.to_literal_type(pt)
pt_604 = str | int
lt_604 = TypeEngine.to_literal_type(pt_604)
assert lt == lt_604
assert lt.union_type.variants == [
LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")),
LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
Expand Down Expand Up @@ -1526,10 +1530,13 @@ class Bar(DataClassJSONMixin):
DataclassTransformer().assert_type(gt, pv)


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
def test_union_transformer():
assert UnionTransformer.is_optional_type(typing.Optional[int])
assert UnionTransformer.is_optional_type(int | None)
assert not UnionTransformer.is_optional_type(str)
assert UnionTransformer.get_sub_type_in_optional(typing.Optional[int]) == int
assert UnionTransformer.get_sub_type_in_optional(int | None) == int


def test_union_guess_type():
Expand Down Expand Up @@ -1597,9 +1604,13 @@ def test_annotated_union_type():
assert v == "hello"


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
def test_optional_type():
pt = typing.Optional[int]
lt = TypeEngine.to_literal_type(pt)
pt_604 = int | None
lt_604 = TypeEngine.to_literal_type(pt_604)
assert lt == lt_604
assert lt.union_type.variants == [
LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")),
LiteralType(simple=SimpleType.NONE, structure=TypeStructure(tag="none")),
Expand Down Expand Up @@ -1791,9 +1802,13 @@ def test_union_of_lists():
assert v == [1, 3]


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.")
def test_list_of_unions():
ggydush marked this conversation as resolved.
Show resolved Hide resolved
pt = typing.List[typing.Union[str, int]]
lt = TypeEngine.to_literal_type(pt)
pt_604 = typing.List[str | int]
lt_604 = TypeEngine.to_literal_type(pt_604)
assert lt == lt_604
# todo(maximsmol): seems like the order here is non-deterministic
assert lt.collection_type.union_type.variants == [
LiteralType(simple=SimpleType.STRING, structure=TypeStructure(tag="str")),
Expand All @@ -1804,8 +1819,10 @@ def test_list_of_unions():
ctx = FlyteContextManager.current_context()
lv = TypeEngine.to_literal(ctx, ["hello", 123, "world"], pt, lt)
v = TypeEngine.to_python_value(ctx, lv, pt)
lv_604 = TypeEngine.to_literal(ctx, ["hello", 123, "world"], pt_604, lt_604)
v_604 = TypeEngine.to_python_value(ctx, lv_604, pt_604)
assert [x.scalar.union.stored_type.structure.tag for x in lv.collection.literals] == ["str", "int", "str"]
assert v == ["hello", 123, "world"]
assert v == v_604 == ["hello", 123, "world"]


def test_pickle_type():
Expand Down
Loading