Skip to content

Commit

Permalink
pyflyte run imperative workflow (#1597)
Browse files Browse the repository at this point in the history
* pyflyte run imperative workflow

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Apr 20, 2023
1 parent 18b212b commit 8c05797
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
8 changes: 6 additions & 2 deletions flytekit/tools/script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from flytekit import PythonFunctionTask
from flytekit.core.tracker import get_full_module_path
from flytekit.core.workflow import WorkflowBase
from flytekit.core.workflow import ImperativeWorkflow, WorkflowBase


def compress_scripts(source_path: str, destination: str, module_name: str):
Expand Down Expand Up @@ -92,7 +92,11 @@ def copy_module_to_destination(
# Try to copy other files to destination if tasks or workflows aren't in the same file
for flyte_entity_name in mod.__dict__:
flyte_entity = mod.__dict__[flyte_entity_name]
if isinstance(flyte_entity, (PythonFunctionTask, WorkflowBase)) and flyte_entity.instantiated_in:
if (
isinstance(flyte_entity, (PythonFunctionTask, WorkflowBase))
and not isinstance(flyte_entity, ImperativeWorkflow)
and flyte_entity.instantiated_in
):
copy_module_to_destination(
original_source_path, original_destination_path, flyte_entity.instantiated_in, visited
)
Expand Down
19 changes: 18 additions & 1 deletion tests/flytekit/unit/tools/test_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@ def my_wf() -> str:
return "hello world"
"""

IMPERATIVE_WORKFLOW = """
from flytekit import Workflow, task
@task
def t1(a: int):
print(a)
wf = Workflow(name="my.imperative.workflow.example")
wf.add_workflow_input("a", int)
node_t1 = wf.add_entity(t1, a=wf.inputs["a"])
"""

T1_TASK = """
from flytekit import task
from wf2.test import t2
Expand Down Expand Up @@ -44,6 +57,9 @@ def test_deterministic_hash(tmp_path):
workflow_file = workflows_dir / "hello_world.py"
workflow_file.write_text(MAIN_WORKFLOW)

imperative_workflow_file = workflows_dir / "imperative_wf.py"
imperative_workflow_file.write_text(IMPERATIVE_WORKFLOW)

t1_dir = tmp_path / "wf1"
t1_dir.mkdir()
open(t1_dir / "__init__.py", "a").close()
Expand All @@ -58,7 +74,6 @@ def test_deterministic_hash(tmp_path):

destination = tmp_path / "destination"

print(workflows_dir)
sys.path.append(str(workflows_dir.parent))
compress_scripts(str(workflows_dir.parent), str(destination), "workflows.hello_world")

Expand All @@ -81,3 +96,5 @@ def test_deterministic_hash(tmp_path):
)
result.check_returncode()
assert len(next(os.walk(test_dir))[1]) == 3

compress_scripts(str(workflows_dir.parent), str(destination), "workflows.imperative_wf")

0 comments on commit 8c05797

Please sign in to comment.