From 90450d1d481488a858e264f082df9baa57ec7585 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 19 Dec 2024 00:38:37 +0800 Subject: [PATCH 1/4] Fix pydantic default input Signed-off-by: Future-Outlier --- flytekit/clis/sdk_in_container/run.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 7d661c3ff8..e50c4cda1d 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -53,6 +53,7 @@ labels_callback, ) from flytekit.interaction.string_literals import literal_string_repr +from flytekit.lazy_import.lazy_module import is_imported from flytekit.loggers import logger from flytekit.models import security from flytekit.models.common import RawOutputDataConfig @@ -475,8 +476,21 @@ def to_click_option( If no custom logic exists, fall back to json.dumps. """ with FlyteContextManager.with_context(flyte_ctx.new_builder()): - encoder = JSONEncoder(python_type) - default_val = encoder.encode(default_val) + if is_imported("pydantic"): + try: + from pydantic import BaseModel as BaseModelV2 + from pydantic.v1 import BaseModel as BaseModelV1 + + if issubclass(python_type, BaseModelV2): + default_val = default_val.model_dump_json() + elif issubclass(python_type, BaseModelV1): + default_val = default_val.json() + except ImportError: + # Pydantic BaseModel v1 + default_val = default_val.json() + else: + encoder = JSONEncoder(python_type) + default_val = encoder.encode(default_val) if literal_var.type.metadata: description_extra = f": {json.dumps(literal_var.type.metadata)}" From d46d2928253994c8e9d6369d5424c6ccb5fa6961 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 19 Dec 2024 00:46:05 +0800 Subject: [PATCH 2/4] add pydantic integration test Signed-off-by: Future-Outlier --- .../integration/remote/test_remote.py | 8 ++++++++ .../remote/workflows/basic/pydantic_wf.py | 20 +++++++++++++++++++ 2 files changed, 28 insertions(+) create mode 100644 tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 5d953350a0..dadcc25101 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -111,6 +111,14 @@ def test_remote_eager_run(): # child_workflow.parent_wf asynchronously register a parent wf1 with child lp from another wf2. run("eager_example.py", "simple_eager_workflow", "--x", "3") +def test_pydantic_default_input_with_map_task(): + execution_id = run("pydantic_wf.py", "wf") + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + execution = remote.fetch_execution(name=execution_id) + execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5)) + print("Execution Error:", execution.error) + assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}" + def test_generic_idl_flytetypes(): os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "true" diff --git a/tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py b/tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py new file mode 100644 index 0000000000..d5e9c32170 --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel + +from flytekit import map_task +from typing import List +from flytekit import task, workflow + + +class MyBaseModel(BaseModel): + my_floats: List[float] = [1.0, 2.0, 5.0, 10.0] + +@task +def print_float(my_float: float): + print(f"my_float: {my_float}") + +@workflow +def wf(bm: MyBaseModel = MyBaseModel()): + map_task(print_float)(my_float=bm.my_floats) + +if __name__ == "__main__": + wf() From 9e23ba3558404afc4347174a83056f3d1bba2b28 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 2 Jan 2025 10:40:51 +0800 Subject: [PATCH 3/4] Use duck typing by Thomas's advice Signed-off-by: Future-Outlier Co-authored-by: Thomas J. Fan --- flytekit/clis/sdk_in_container/run.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index e50c4cda1d..f327e6f025 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -476,18 +476,12 @@ def to_click_option( If no custom logic exists, fall back to json.dumps. """ with FlyteContextManager.with_context(flyte_ctx.new_builder()): - if is_imported("pydantic"): - try: - from pydantic import BaseModel as BaseModelV2 - from pydantic.v1 import BaseModel as BaseModelV1 - - if issubclass(python_type, BaseModelV2): - default_val = default_val.model_dump_json() - elif issubclass(python_type, BaseModelV1): - default_val = default_val.json() - except ImportError: - # Pydantic BaseModel v1 - default_val = default_val.json() + if hasattr(default_val, "model_dump_json"): + # pydantic v2 + default_val = default_val.model_dump_json() + elif hasattr(default_val, "json"): + # pydantic v1 + default_val = default_val.json() else: encoder = JSONEncoder(python_type) default_val = encoder.encode(default_val) From 0b5ac7899a71862049931062f9beb2fc2dc2381b Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 2 Jan 2025 10:42:59 +0800 Subject: [PATCH 4/4] lint Signed-off-by: Future-Outlier --- flytekit/clis/sdk_in_container/run.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index f327e6f025..8da3739be0 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -53,7 +53,6 @@ labels_callback, ) from flytekit.interaction.string_literals import literal_string_repr -from flytekit.lazy_import.lazy_module import is_imported from flytekit.loggers import logger from flytekit.models import security from flytekit.models.common import RawOutputDataConfig