Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add assert_type in dataclass transformer #1149

Merged
merged 9 commits into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,46 @@ class Test():
def __init__(self):
    """
    Initialise the base transformer with this transformer's name and ``object`` as its second argument.
    """
    transformer_name = "Object-Dataclass-Transformer"
    super().__init__(transformer_name, object)

def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
    """
    Verify that the dataclass instance ``v`` is compatible with ``expected_type``.

    If ``type(v)`` is exactly ``expected_type`` the check passes immediately. Otherwise every
    field declared on ``type(v)`` is compared, by name, with the field of the same name on
    ``expected_type``. ``Optional[X]`` annotations are unwrapped to ``X`` on both sides before
    comparing. Field values that are themselves dataclass instances are checked recursively.

    :param expected_type: The dataclass type that ``v`` is expected to conform to.
    :param v: The dataclass instance to check.
    :raises TypeTransformerFailedError: If ``v`` has a field that ``expected_type`` does not declare,
        or if a field's declared type differs from the expected one.
    """
    # Skip iterating all attributes in the dataclass if the type of v already matches the expected_type
    if type(v) == expected_type:
        return

    # @dataclass_json
    # @dataclass
    # class Foo(object):
    #     a: int = 0
    #
    # @task
    # def t1(a: Foo):
    #     ...
    #
    # In the example above, the type of v may differ from expected_type in some cases:
    # 1. The input of t1 is an instance of another dataclass (Bar); then we should raise an error.
    # 2. When using flyte remote to execute the task, expected_type is, by default, the class
    #    returned by guess_python_type (FooSchema). FooSchema is created by flytekit and is not
    #    equal to the user-defined dataclass (Foo).
    # Therefore, we iterate over all fields of the dataclass and check that each declared field
    # type matches the corresponding field type of expected_type.

    expected_fields_dict = {f.name: f.type for f in dataclasses.fields(expected_type)}

    for f in dataclasses.fields(type(v)):
        # A field unknown to the expected type is a type mismatch; report it as such rather than
        # leaking a bare KeyError from the dictionary lookup.
        if f.name not in expected_fields_dict:
            raise TypeTransformerFailedError(
                f"Field '{f.name}' of {type(v)} does not exist in the expected type {expected_type}"
            )

        original_type = f.type
        # Use a dedicated local so the ``expected_type`` parameter is not overwritten while looping.
        expected_field_type = expected_fields_dict[f.name]

        if UnionTransformer.is_optional_type(original_type):
            original_type = UnionTransformer.get_sub_type_in_optional(original_type)
        if UnionTransformer.is_optional_type(expected_field_type):
            expected_field_type = UnionTransformer.get_sub_type_in_optional(expected_field_type)

        val = getattr(v, f.name)
        if dataclasses.is_dataclass(val):
            # Nested dataclass value: compare it against the expected nested type field by field.
            self.assert_type(expected_field_type, val)
        elif original_type != expected_field_type:
            raise TypeTransformerFailedError(
                f"Type of Val '{original_type}' is not an instance of {expected_field_type}"
            )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it easy to add a test for this line?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated it.
I tried to find an existing function that compares the attributes of two dataclasses, but I couldn't find one.


def get_literal_type(self, t: Type[T]) -> LiteralType:
"""
Extracts the Literal type definition for a Dataclass and returns a type Struct.
Expand Down Expand Up @@ -977,6 +1017,17 @@ class UnionTransformer(TypeTransformer[T]):
def __init__(self):
    """
    Initialise the base transformer with the name "Typed Union" and ``typing.Union`` as its second argument.
    """
    transformer_name = "Typed Union"
    super().__init__(transformer_name, typing.Union)

@staticmethod
def is_optional_type(t: Type[T]) -> bool:
return get_origin(t) is typing.Union and type(None) in get_args(t)

@staticmethod
def get_sub_type_in_optional(t: Type[T]) -> Type[T]:
"""
Return the generic Type T of the Optional type
"""
return get_args(t)[0]

def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
if get_origin(t) is Annotated:
t = get_args(t)[0]
Expand Down
38 changes: 38 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
TypeEngine,
TypeTransformer,
TypeTransformerFailedError,
UnionTransformer,
convert_json_schema_to_python_class,
dataclass_from_dict,
)
Expand Down Expand Up @@ -791,6 +792,43 @@ def test_union_type():
assert v == "hello"


def test_assert_dataclass_type():
    """
    A nested dataclass value must pass ``assert_type`` against both its own class and the class
    guessed from its literal type, while an unrelated dataclass must be rejected.
    """

    @dataclass_json
    @dataclass
    class Args(object):
        x: int
        y: typing.Optional[str]

    @dataclass_json
    @dataclass
    class Schema(object):
        x: typing.Optional[Args] = None

    # Round-trip the python type through its literal type to obtain the guessed class.
    literal_type = TypeEngine.to_literal_type(Schema)
    guessed_type = TypeEngine.guess_python_type(literal_type)
    nested_value = Schema(x=Args(x=3, y="hello"))

    for accepted_type in (guessed_type, Schema):
        DataclassTransformer().assert_type(accepted_type, nested_value)

    @dataclass_json
    @dataclass
    class Bar(object):
        x: int

    unrelated_value = Bar(x=3)
    expected_message = "Type of Val '<class 'int'>' is not an instance of <class 'types.ArgsSchema'>"
    with pytest.raises(TypeTransformerFailedError, match=expected_message):
        DataclassTransformer().assert_type(guessed_type, unrelated_value)


def test_union_transformer():
    """
    Exercise the Optional helpers of ``UnionTransformer`` on ``Optional[int]`` and a plain ``str``.
    """
    optional_int = typing.Optional[int]

    assert UnionTransformer.is_optional_type(optional_int)
    assert not UnionTransformer.is_optional_type(str)
    assert UnionTransformer.get_sub_type_in_optional(optional_int) == int


def test_union_type_with_annotated():
pt = typing.Union[
Annotated[str, FlyteAnnotation({"hello": "world"})], Annotated[int, FlyteAnnotation({"test": 123})]
Expand Down