From 79d9465ac64a4ffea425ec593b4fabbe1b669516 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Tue, 10 May 2022 21:29:25 -0700 Subject: [PATCH] Fix namedtuple and test Signed-off-by: Ketan Umare --- flytekit/clis/sdk_in_container/run.py | 26 ++++++++++++++++----- tests/flytekit/unit/cli/pyflyte/test_run.py | 12 +++++++++- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 5b142a82c3..02658454f1 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -423,14 +423,29 @@ def dump_flyte_remote_snippet(execution: FlyteWorkflowExecution, project: str, d ) -def get_entities_in_file(filename: str) -> typing.Tuple[typing.List[str], typing.List[str]]: +class Entities(typing.NamedTuple): + """ + NamedTuple to group all entities in a file + """ + + workflows: typing.List[str] + tasks: typing.List[str] + + def all(self) -> typing.List[str]: + e = [] + e.extend(self.workflows) + e.extend(self.tasks) + return e + + +def get_entities_in_file(filename: str) -> Entities: """ Returns a list of flyte workflow names and list of Flyte tasks in a file. """ flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings( SerializationSettings(None) ) - module_name = os.path.splitext(filename)[0].replace(os.path.sep, ".") + module_name = os.path.splitext(os.path.relpath(filename))[0].replace(os.path.sep, ".") with context_manager.FlyteContextManager.with_context(flyte_ctx): with module_loader.add_sys_path(os.getcwd()): importlib.import_module(module_name) @@ -447,7 +462,7 @@ def get_entities_in_file(filename: str) -> typing.Tuple[typing.List[str], typing _, _, fn, _ = tracker.extract_task_module(o) tasks.append(fn) - return workflows, tasks + return Entities(workflows, tasks) def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]): @@ -518,9 +533,8 @@ def __init__(self, filename: str, *args, **kwargs): self._filename = filename def list_commands(self, ctx): - workflows, tasks = get_entities_in_file(self._filename) - workflows.extend(tasks) - return workflows + entities = get_entities_in_file(self._filename) + return entities.all() def get_command(self, ctx, exe_entity): rel_path = os.path.relpath(self._filename) diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 5d3aad70a1..8fcbd3667c 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -3,11 +3,14 @@ from click.testing import CliRunner from flytekit.clis.sdk_in_container import pyflyte +from flytekit.clis.sdk_in_container.run import get_entities_in_file + +WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") def test_pyflyte_run_wf(): runner = CliRunner() - module_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") + module_path = WORKFLOW_FILE result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False) assert result.exit_code == 0 @@ -52,3 +55,10 @@ def test_pyflyte_run_cli(): ) print(result.stdout) assert result.exit_code == 0 + + +def test_get_entities_in_file(): + e = get_entities_in_file(WORKFLOW_FILE) + assert e.workflows == ["my_wf"] + assert e.tasks == ["get_subset_df", "print_all", "show_sd"] + assert e.all() == ["my_wf", "get_subset_df", "print_all", "show_sd"]