Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle common cases of mutable default arguments explicitly #2651

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typing
from copy import deepcopy
from enum import Enum
from typing import Any, Coroutine, Dict, Hashable, List, Optional, Set, Tuple, Union, cast, get_args
from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast, get_args

from google.protobuf import struct_pb2 as _struct
from typing_extensions import Protocol
Expand Down Expand Up @@ -1116,8 +1116,13 @@ def create_and_link_node(
or UnionTransformer.is_optional_type(interface.inputs_with_defaults[k][0])
):
default_val = interface.inputs_with_defaults[k][1]
if not isinstance(default_val, Hashable):
raise _user_exceptions.FlyteAssertion("Cannot use non-hashable object as default argument")
# Reject the common mutable default arguments (list, dict, set). For why these are a pitfall, see
# https://www.pullrequest.com/blog/python-pitfalls-the-perils-of-using-lists-and-dicts-as-default-arguments/
# and https://florimond.dev/en/posts/2018/08/python-mutable-defaults-are-the-source-of-all-evil.
# As of 2024-08-05, Python native sets are not supported in Flytekit; set is still checked here for completeness.
if isinstance(default_val, list) or isinstance(default_val, dict) or isinstance(default_val, set):
raise _user_exceptions.FlyteAssertion(
f"Argument {k} for function {entity.name} is a mutable default argument, which is a python anti-pattern and not supported in flytekit tasks"
)
kwargs[k] = default_val
else:
error_msg = f"Input {k} of type {interface.inputs[k]} was not specified for function {entity.name}"
Expand Down
16 changes: 13 additions & 3 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import os
import typing
from collections import OrderedDict
Expand Down Expand Up @@ -775,7 +776,10 @@ def wf_no_input() -> typing.List[int]:
def wf_with_input() -> typing.List[int]:
return t1(a=input_val)

with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"):
with pytest.raises(
FlyteAssertion,
match=r"Argument a for function .*test_serialization\.t1 is a mutable default argument, which is a python anti-pattern and not supported in flytekit tasks"
):
get_serializable(OrderedDict(), serialization_settings, wf_no_input)

wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input)
Expand Down Expand Up @@ -810,7 +814,10 @@ def wf_no_input() -> typing.Dict[str, int]:
def wf_with_input() -> typing.Dict[str, int]:
return t1(a=input_val)

with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"):
with pytest.raises(
FlyteAssertion,
match=r"Argument a for function .*test_serialization\.t1 is a mutable default argument, which is a python anti-pattern and not supported in flytekit tasks"
):
get_serializable(OrderedDict(), serialization_settings, wf_no_input)

wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input)
Expand Down Expand Up @@ -910,7 +917,10 @@ def wf_no_input() -> typing.Optional[typing.List[int]]:
def wf_with_input() -> typing.Optional[typing.List[int]]:
return t1(a=input_val)

with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"):
with pytest.raises(
FlyteAssertion,
match=r"Argument a for function .*test_serialization\.t1 is a mutable default argument, which is a python anti-pattern and not supported in flytekit tasks"
):
get_serializable(OrderedDict(), serialization_settings, wf_no_input)

wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -543,10 +543,11 @@ def test_reregister_encoder():


def test_default_args_task():
default_val = pd.DataFrame({"name": ["Aegon"], "age": [27]})
input_val = generate_pandas()

@task
def t1(a: pd.DataFrame = pd.DataFrame()) -> pd.DataFrame:
def t1(a: pd.DataFrame = default_val) -> pd.DataFrame:
return a

@workflow
Expand All @@ -557,11 +558,16 @@ def wf_no_input() -> pd.DataFrame:
def wf_with_input() -> pd.DataFrame:
return t1(a=input_val)

with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"):
get_serializable(OrderedDict(), serialization_settings, wf_no_input)

wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input)
wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input)

assert wf_no_input_spec.template.nodes[0].inputs[
0
].binding.value.structured_dataset.metadata == StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(
format="parquet",
),
)
assert wf_with_input_spec.template.nodes[0].inputs[
0
].binding.value.structured_dataset.metadata == StructuredDatasetMetadata(
Expand All @@ -570,8 +576,12 @@ def wf_with_input() -> pd.DataFrame:
),
)

assert wf_no_input_spec.template.interface.outputs["o0"].type == LiteralType(
structured_dataset_type=StructuredDatasetType()
)
assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType(
structured_dataset_type=StructuredDatasetType()
)

pd.testing.assert_frame_equal(wf_no_input(), default_val)
pd.testing.assert_frame_equal(wf_with_input(), input_val)
Loading