Skip to content

Commit

Permalink
Fix namedtuple and test
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 committed May 11, 2022
1 parent 193ac4e commit 79d9465
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
26 changes: 20 additions & 6 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,14 +423,29 @@ def dump_flyte_remote_snippet(execution: FlyteWorkflowExecution, project: str, d
)


def get_entities_in_file(filename: str) -> typing.Tuple[typing.List[str], typing.List[str]]:
class Entities(typing.NamedTuple):
"""
NamedTuple to group all entities in a file
"""

workflows: typing.List[str]
tasks: typing.List[str]

def all(self) -> typing.List[str]:
e = []
e.extend(self.workflows)
e.extend(self.tasks)
return e


def get_entities_in_file(filename: str) -> Entities:
"""
Returns a list of flyte workflow names and list of Flyte tasks in a file.
"""
flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings(
SerializationSettings(None)
)
module_name = os.path.splitext(filename)[0].replace(os.path.sep, ".")
module_name = os.path.splitext(os.path.relpath(filename))[0].replace(os.path.sep, ".")
with context_manager.FlyteContextManager.with_context(flyte_ctx):
with module_loader.add_sys_path(os.getcwd()):
importlib.import_module(module_name)
Expand All @@ -447,7 +462,7 @@ def get_entities_in_file(filename: str) -> typing.Tuple[typing.List[str], typing
_, _, fn, _ = tracker.extract_task_module(o)
tasks.append(fn)

return workflows, tasks
return Entities(workflows, tasks)


def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]):
Expand Down Expand Up @@ -518,9 +533,8 @@ def __init__(self, filename: str, *args, **kwargs):
self._filename = filename

def list_commands(self, ctx):
workflows, tasks = get_entities_in_file(self._filename)
workflows.extend(tasks)
return workflows
entities = get_entities_in_file(self._filename)
return entities.all()

def get_command(self, ctx, exe_entity):
rel_path = os.path.relpath(self._filename)
Expand Down
12 changes: 11 additions & 1 deletion tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
from click.testing import CliRunner

from flytekit.clis.sdk_in_container import pyflyte
from flytekit.clis.sdk_in_container.run import get_entities_in_file

WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py")


def test_pyflyte_run_wf():
runner = CliRunner()
module_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py")
module_path = WORKFLOW_FILE
result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False)

assert result.exit_code == 0
Expand Down Expand Up @@ -52,3 +55,10 @@ def test_pyflyte_run_cli():
)
print(result.stdout)
assert result.exit_code == 0


def test_get_entities_in_file():
e = get_entities_in_file(WORKFLOW_FILE)
assert e.workflows == ["my_wf"]
assert e.tasks == ["get_subset_df", "print_all", "show_sd"]
assert e.all() == ["my_wf", "get_subset_df", "print_all", "show_sd"]

0 comments on commit 79d9465

Please sign in to comment.