Skip to content

Commit

Permalink
Add assert_type in dataclass transformer (#1149)
Browse files Browse the repository at this point in the history
* Add assert_type in dataclassTransformer

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* more tests

Signed-off-by: Kevin Su <[email protected]>

* fix lint

Signed-off-by: Kevin Su <[email protected]>

* Add one more test

Signed-off-by: Kevin Su <[email protected]>

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Sep 8, 2022
1 parent aff19cb commit eced84b
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
51 changes: 51 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,46 @@ class Test():
def __init__(self):
    """Register this transformer under its display name, handling the ``object`` type."""
    transformer_name = "Object-Dataclass-Transformer"
    super().__init__(transformer_name, object)

def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
    """
    Check that the dataclass instance ``v`` is compatible with ``expected_type``.

    If ``type(v)`` equals ``expected_type`` the check passes immediately. Otherwise the
    two dataclasses are compared field by field, for example:

        @dataclass_json
        @dataclass
        class Foo(object):
            a: int = 0

        @task
        def t1(a: Foo):
            ...

    The type of ``v`` may differ from ``expected_type`` in cases such as:
      1. The input of ``t1`` is another dataclass (Bar); an error should be raised.
      2. When using flyte remote to execute the task, ``expected_type`` is the guessed
         python type (FooSchema) by default. FooSchema is created by flytekit and is not
         equal to the user-defined dataclass (Foo), so every attribute has to be compared.

    :param expected_type: the dataclass type the value is expected to conform to.
    :param v: the dataclass instance to verify.
    :raises TypeTransformerFailedError: if ``v`` has a field that ``expected_type`` does not
        declare, or if a non-dataclass field's declared type differs from the expected one
        (after unwrapping Optional on both sides).
    """
    # Skip iterating all attributes in the dataclass if the type of v already matches the expected_type
    if type(v) == expected_type:
        return

    expected_fields_dict = {f.name: f.type for f in dataclasses.fields(expected_type)}

    for f in dataclasses.fields(type(v)):
        # A field unknown to the expected dataclass is a type mismatch; report it as such
        # instead of leaking a bare KeyError from the lookup below.
        if f.name not in expected_fields_dict:
            raise TypeTransformerFailedError(
                f"Field '{f.name}' is not present in the expected dataclass fields {list(expected_fields_dict)}"
            )

        original_type = f.type
        # Use a dedicated local so the ``expected_type`` parameter is not overwritten mid-loop.
        expected_field_type = expected_fields_dict[f.name]

        # Optional[X] is compared as X on both sides.
        if UnionTransformer.is_optional_type(original_type):
            original_type = UnionTransformer.get_sub_type_in_optional(original_type)
        if UnionTransformer.is_optional_type(expected_field_type):
            expected_field_type = UnionTransformer.get_sub_type_in_optional(expected_field_type)

        val = getattr(v, f.name)
        if dataclasses.is_dataclass(val):
            # Nested dataclass values are verified recursively against the expected field type.
            self.assert_type(expected_field_type, val)
        elif original_type != expected_field_type:
            raise TypeTransformerFailedError(
                f"Type of Val '{original_type}' is not an instance of {expected_field_type}"
            )

def get_literal_type(self, t: Type[T]) -> LiteralType:
"""
Extracts the Literal type definition for a Dataclass and returns a type Struct.
Expand Down Expand Up @@ -977,6 +1017,17 @@ class UnionTransformer(TypeTransformer[T]):
def __init__(self):
    """Register this transformer under its display name, handling ``typing.Union``."""
    transformer_name = "Typed Union"
    super().__init__(transformer_name, typing.Union)

@staticmethod
def is_optional_type(t: Type[T]) -> bool:
    """
    Return True if ``t`` is a union type that has ``None`` as one of its members.

    Both ``typing.Optional[X]`` / ``typing.Union[X, None]`` and, on Python 3.10+,
    the PEP 604 spelling ``X | None`` are recognised.
    """
    import types  # local import: only needed to look up the PEP 604 union type

    union_origins = [typing.Union]
    # types.UnionType (the origin of ``X | Y``) only exists on Python 3.10+.
    pep604_union = getattr(types, "UnionType", None)
    if pep604_union is not None:
        union_origins.append(pep604_union)

    origin = get_origin(t)
    return any(origin is candidate for candidate in union_origins) and type(None) in get_args(t)

@staticmethod
def get_sub_type_in_optional(t: Type[T]) -> Type[T]:
    """
    Return the generic Type T of the Optional type.

    The first union member that is not ``NoneType`` is returned, so the result is the
    same for ``Optional[T]`` and for ``Union[None, T]``, where ``None`` is listed first.

    :raises IndexError: if ``t`` has no type arguments at all.
    """
    type_args = get_args(t)
    none_type = type(None)
    for candidate in type_args:
        if candidate is not none_type:
            return candidate
    # No non-None member found: keep the historical behaviour of indexing the
    # first argument (raises IndexError when ``t`` is not a parameterised type).
    return type_args[0]

def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
if get_origin(t) is Annotated:
t = get_args(t)[0]
Expand Down
38 changes: 38 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
TypeEngine,
TypeTransformer,
TypeTransformerFailedError,
UnionTransformer,
convert_json_schema_to_python_class,
dataclass_from_dict,
)
Expand Down Expand Up @@ -791,6 +792,43 @@ def test_union_type():
assert v == "hello"


def test_assert_dataclass_type():
    """
    DataclassTransformer.assert_type accepts a Schema value against both the type guessed
    from its literal type and the Schema class itself, and rejects an unrelated dataclass.
    """

    @dataclass_json
    @dataclass
    class Args(object):
        x: int
        y: typing.Optional[str]

    @dataclass_json
    @dataclass
    class Schema(object):
        x: typing.Optional[Args] = None

    def check(expected, value):
        # A fresh transformer is built for every assertion.
        DataclassTransformer().assert_type(expected, value)

    literal_type = TypeEngine.to_literal_type(Schema)
    guessed_type = TypeEngine.guess_python_type(literal_type)
    schema_value = Schema(x=Args(x=3, y="hello"))

    # The value must be accepted against the guessed type as well as its own class.
    check(guessed_type, schema_value)
    check(Schema, schema_value)

    @dataclass_json
    @dataclass
    class Bar(object):
        x: int

    # Bar declares ``x: int`` while the guessed type declares a different type for ``x``.
    bar_value = Bar(x=3)
    expected_error = "Type of Val '<class 'int'>' is not an instance of <class 'types.ArgsSchema'>"
    with pytest.raises(TypeTransformerFailedError, match=expected_error):
        check(guessed_type, bar_value)


def test_union_transformer():
    """Exercise the Optional helpers exposed as static methods on UnionTransformer."""
    optional_int = typing.Optional[int]

    detected_optional = UnionTransformer.is_optional_type(optional_int)
    assert detected_optional

    detected_plain = UnionTransformer.is_optional_type(str)
    assert not detected_plain

    inner_type = UnionTransformer.get_sub_type_in_optional(optional_int)
    assert inner_type == int


def test_union_type_with_annotated():
pt = typing.Union[
Annotated[str, FlyteAnnotation({"hello": "world"})], Annotated[int, FlyteAnnotation({"test": 123})]
Expand Down

0 comments on commit eced84b

Please sign in to comment.