Skip to content

Commit

Permalink
Handle common cases of mutable default arguments explicitly (#2651)
Browse files Browse the repository at this point in the history
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario and eapolinario authored Aug 6, 2024
1 parent 4465249 commit 7d1227b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
11 changes: 8 additions & 3 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typing
from copy import deepcopy
from enum import Enum
from typing import Any, Coroutine, Dict, Hashable, List, Optional, Set, Tuple, Union, cast, get_args
from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast, get_args

from google.protobuf import struct_pb2 as _struct
from typing_extensions import Protocol
Expand Down Expand Up @@ -1116,8 +1116,13 @@ def create_and_link_node(
or UnionTransformer.is_optional_type(interface.inputs_with_defaults[k][0])
):
default_val = interface.inputs_with_defaults[k][1]
if not isinstance(default_val, Hashable):
raise _user_exceptions.FlyteAssertion("Cannot use non-hashable object as default argument")
# Common cases of mutable default arguments, as described in https://www.pullrequest.com/blog/python-pitfalls-the-perils-of-using-lists-and-dicts-as-default-arguments/
# or https://florimond.dev/en/posts/2018/08/python-mutable-defaults-are-the-source-of-all-evil, are not supported.
# As of 2024-08-05, Python native sets are not supported in Flytekit. However, they are included here for completeness.
if isinstance(default_val, list) or isinstance(default_val, dict) or isinstance(default_val, set):
raise _user_exceptions.FlyteAssertion(
f"Argument {k} for function {entity.name} is a mutable default argument, which is a python anti-pattern and not supported in flytekit tasks"
)
kwargs[k] = default_val
else:
error_msg = f"Input {k} of type {interface.inputs[k]} was not specified for function {entity.name}"
Expand Down
16 changes: 13 additions & 3 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import os
import typing
from collections import OrderedDict
Expand Down Expand Up @@ -775,7 +776,10 @@ def wf_no_input() -> typing.List[int]:
def wf_with_input() -> typing.List[int]:
return t1(a=input_val)

with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"):
with pytest.raises(
FlyteAssertion,
match=r"Argument a for function .*test_serialization\.t1 is a mutable default argument, which is a python anti-pattern and not supported in flytekit tasks"
):
get_serializable(OrderedDict(), serialization_settings, wf_no_input)

wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input)
Expand Down Expand Up @@ -810,7 +814,10 @@ def wf_no_input() -> typing.Dict[str, int]:
def wf_with_input() -> typing.Dict[str, int]:
return t1(a=input_val)

with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"):
with pytest.raises(
FlyteAssertion,
match=r"Argument a for function .*test_serialization\.t1 is a mutable default argument, which is a python anti-pattern and not supported in flytekit tasks"
):
get_serializable(OrderedDict(), serialization_settings, wf_no_input)

wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input)
Expand Down Expand Up @@ -910,7 +917,10 @@ def wf_no_input() -> typing.Optional[typing.List[int]]:
def wf_with_input() -> typing.Optional[typing.List[int]]:
return t1(a=input_val)

with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"):
with pytest.raises(
FlyteAssertion,
match=r"Argument a for function .*test_serialization\.t1 is a mutable default argument, which is a python anti-pattern and not supported in flytekit tasks"
):
get_serializable(OrderedDict(), serialization_settings, wf_no_input)

wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -543,10 +543,11 @@ def test_reregister_encoder():


def test_default_args_task():
default_val = pd.DataFrame({"name": ["Aegon"], "age": [27]})
input_val = generate_pandas()

@task
def t1(a: pd.DataFrame = pd.DataFrame()) -> pd.DataFrame:
def t1(a: pd.DataFrame = default_val) -> pd.DataFrame:
return a

@workflow
Expand All @@ -557,11 +558,16 @@ def wf_no_input() -> pd.DataFrame:
def wf_with_input() -> pd.DataFrame:
return t1(a=input_val)

with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"):
get_serializable(OrderedDict(), serialization_settings, wf_no_input)

wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input)
wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input)

assert wf_no_input_spec.template.nodes[0].inputs[
0
].binding.value.structured_dataset.metadata == StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(
format="parquet",
),
)
assert wf_with_input_spec.template.nodes[0].inputs[
0
].binding.value.structured_dataset.metadata == StructuredDatasetMetadata(
Expand All @@ -570,8 +576,12 @@ def wf_with_input() -> pd.DataFrame:
),
)

assert wf_no_input_spec.template.interface.outputs["o0"].type == LiteralType(
structured_dataset_type=StructuredDatasetType()
)
assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType(
structured_dataset_type=StructuredDatasetType()
)

pd.testing.assert_frame_equal(wf_no_input(), default_val)
pd.testing.assert_frame_equal(wf_with_input(), input_val)

0 comments on commit 7d1227b

Please sign in to comment.