Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the type of optional[int] in dataclass #1135

Merged
merged 3 commits into from
Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,9 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) ->
return python_val

def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
if t == int:
if val is None:
return val
if t == int or t == typing.Optional[int]:
return int(val)

if isinstance(val, list):
Expand Down Expand Up @@ -1309,6 +1311,11 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac
def _get_element_type(element_property: typing.Dict[str, str]) -> Type[T]:
element_type = element_property["type"]
element_format = element_property["format"] if "format" in element_property else None

if type(element_type) == list:
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
# Element type of Optional[int] is [integer, None]
return typing.Optional[_get_element_type({"type": element_type[0]})]

if element_type == "string":
return str
elif element_type == "integer":
Expand Down
32 changes: 21 additions & 11 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,14 @@ class Bar(object):
@dataclass_json
@dataclass()
class Foo(object):
u: typing.Optional[int]
v: typing.Optional[int]
w: int
x: typing.List[int]
y: typing.Dict[str, str]
z: Bar

foo = Foo(w=1, x=[1], y={"hello": "10"}, z=Bar(w=None, x=1.0, y="hello", z={"world": False}))
foo = Foo(u=5, v=None, w=1, x=[1], y={"hello": "10"}, z=Bar(w=None, x=1.0, y="hello", z={"world": False}))
generic = _json_format.Parse(typing.cast(DataClassJsonMixin, foo).to_json(), _struct.Struct())
lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))]))

Expand All @@ -170,16 +172,23 @@ class Foo(object):
schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema())
foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema")

pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class])
assert isinstance(pv, list)
assert pv[0].w == foo.w
assert pv[0].x == foo.x
assert pv[0].y == foo.y
assert pv[0].z.x == foo.z.x
assert type(pv[0].z.x) == float
assert pv[0].z.y == foo.z.y
assert pv[0].z.z == foo.z.z
assert foo == dataclass_from_dict(Foo, asdict(pv[0]))
guessed_pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class])
print("=====")
pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo])
assert isinstance(guessed_pv, list)
assert guessed_pv[0].u == pv[0].u
assert guessed_pv[0].v == pv[0].v
assert guessed_pv[0].w == pv[0].w
assert guessed_pv[0].x == pv[0].x
assert guessed_pv[0].y == pv[0].y
assert guessed_pv[0].z.x == pv[0].z.x
assert type(guessed_pv[0].u) == int
assert guessed_pv[0].v is None
assert type(guessed_pv[0].w) == int
assert type(guessed_pv[0].z.x) == float
assert guessed_pv[0].z.y == pv[0].z.y
assert guessed_pv[0].z.z == pv[0].z.z
assert pv[0] == dataclass_from_dict(Foo, asdict(guessed_pv[0]))


def test_file_no_downloader_default():
Expand Down Expand Up @@ -1174,6 +1183,7 @@ def test_pass_annotated_to_downstream_tasks():
"""
Test to confirm that the loaded dataframe is not affected and can be used in @dynamic.
"""

# pandas dataframe hash function
def hash_pandas_dataframe(df: pd.DataFrame) -> str:
return str(pd.util.hash_pandas_object(df))
Expand Down