Skip to content

Commit

Permalink
Support optional type in dataclass
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Aug 16, 2022
1 parent c4652cd commit d313077
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
2 changes: 2 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,8 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) ->
return python_val

def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
if val is None:
return val
if t == int:
return int(val)

Expand Down
8 changes: 6 additions & 2 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,14 @@ class Bar(object):
@dataclass_json
@dataclass()
class Foo(object):
u: typing.Optional[int]
v: typing.Optional[int]
w: int
x: typing.List[int]
y: typing.Dict[str, str]
z: Bar

foo = Foo(v=5, w=1, x=[1], y={"hello": "10"}, z=Bar(w=None, x=1.0, y="hello", z={"world": False}))
foo = Foo(u=5, v=None, w=1, x=[1], y={"hello": "10"}, z=Bar(w=None, x=1.0, y="hello", z={"world": False}))
generic = _json_format.Parse(typing.cast(DataClassJsonMixin, foo).to_json(), _struct.Struct())
lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))]))

Expand All @@ -173,12 +174,14 @@ class Foo(object):

pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class])
assert isinstance(pv, list)
assert pv[0].u == foo.u
assert pv[0].v == foo.v
assert pv[0].w == foo.w
assert pv[0].x == foo.x
assert pv[0].y == foo.y
assert pv[0].z.x == foo.z.x
assert type(pv[0].v) == int
assert type(pv[0].u) == int
assert pv[0].v is None
assert type(pv[0].w) == int
assert type(pv[0].z.x) == float
assert pv[0].z.y == foo.z.y
Expand Down Expand Up @@ -1178,6 +1181,7 @@ def test_pass_annotated_to_downstream_tasks():
"""
Test to confirm that the loaded dataframe is not affected and can be used in @dynamic.
"""

# pandas dataframe hash function
def hash_pandas_dataframe(df: pd.DataFrame) -> str:
return str(pd.util.hash_pandas_object(df))
Expand Down

0 comments on commit d313077

Please sign in to comment.