From 8c05797feeaf0e47a56075bf1c725c8284157660 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 21 Apr 2023 07:44:00 +0800 Subject: [PATCH] pyflyte run imperative workflow (#1597) * pyflyte run imperative workflow Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su --------- Signed-off-by: Kevin Su --- flytekit/tools/script_mode.py | 8 ++++++-- tests/flytekit/unit/tools/test_script_mode.py | 19 ++++++++++++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/flytekit/tools/script_mode.py b/flytekit/tools/script_mode.py index 221617dfe6..ecc71a2398 100644 --- a/flytekit/tools/script_mode.py +++ b/flytekit/tools/script_mode.py @@ -10,7 +10,7 @@ from flytekit import PythonFunctionTask from flytekit.core.tracker import get_full_module_path -from flytekit.core.workflow import WorkflowBase +from flytekit.core.workflow import ImperativeWorkflow, WorkflowBase def compress_scripts(source_path: str, destination: str, module_name: str): @@ -92,7 +92,11 @@ def copy_module_to_destination( # Try to copy other files to destination if tasks or workflows aren't in the same file for flyte_entity_name in mod.__dict__: flyte_entity = mod.__dict__[flyte_entity_name] - if isinstance(flyte_entity, (PythonFunctionTask, WorkflowBase)) and flyte_entity.instantiated_in: + if ( + isinstance(flyte_entity, (PythonFunctionTask, WorkflowBase)) + and not isinstance(flyte_entity, ImperativeWorkflow) + and flyte_entity.instantiated_in + ): copy_module_to_destination( original_source_path, original_destination_path, flyte_entity.instantiated_in, visited ) diff --git a/tests/flytekit/unit/tools/test_script_mode.py b/tests/flytekit/unit/tools/test_script_mode.py index c597e9bdc2..aba4e0ab17 100644 --- a/tests/flytekit/unit/tools/test_script_mode.py +++ b/tests/flytekit/unit/tools/test_script_mode.py @@ -13,6 +13,19 @@ def my_wf() -> str: return "hello world" """ +IMPERATIVE_WORKFLOW = """ +from flytekit import Workflow, task + +@task +def t1(a: int): + print(a) + + +wf = Workflow(name="my.imperative.workflow.example") +wf.add_workflow_input("a", int) +node_t1 = wf.add_entity(t1, a=wf.inputs["a"]) +""" + T1_TASK = """ from flytekit import task from wf2.test import t2 @@ -44,6 +57,9 @@ def test_deterministic_hash(tmp_path): workflow_file = workflows_dir / "hello_world.py" workflow_file.write_text(MAIN_WORKFLOW) + imperative_workflow_file = workflows_dir / "imperative_wf.py" + imperative_workflow_file.write_text(IMPERATIVE_WORKFLOW) + t1_dir = tmp_path / "wf1" t1_dir.mkdir() open(t1_dir / "__init__.py", "a").close() @@ -58,7 +74,6 @@ def test_deterministic_hash(tmp_path): destination = tmp_path / "destination" - print(workflows_dir) sys.path.append(str(workflows_dir.parent)) compress_scripts(str(workflows_dir.parent), str(destination), "workflows.hello_world") @@ -81,3 +96,5 @@ def test_deterministic_hash(tmp_path): ) result.check_returncode() assert len(next(os.walk(test_dir))[1]) == 3 + + compress_scripts(str(workflows_dir.parent), str(destination), "workflows.imperative_wf")