From 944c42eb6b24e2079533c17259ef4c065450ced2 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Mon, 9 May 2022 21:40:18 -0700 Subject: [PATCH 1/2] pyflyte run can now execute a task either locally or remote Signed-off-by: Ketan Umare --- flytekit/clis/sdk_in_container/run.py | 54 +++++++++++++++------------ flytekit/remote/remote.py | 8 ++-- 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 1b528a3563..5b142a82c3 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -16,6 +16,7 @@ from flytekit.configuration import Config, ImageConfig, SerializationSettings from flytekit.configuration.default_images import DefaultImages from flytekit.core import context_manager, tracker +from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContext from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.type_engine import TypeEngine @@ -393,7 +394,7 @@ def get_workflow_command_base_params() -> typing.List[click.Option]: ] -def load_naive_entity(module_name: str, workflow_name: str) -> WorkflowBase: +def load_naive_entity(module_name: str, entity_name: str) -> typing.Union[WorkflowBase, PythonTask]: """ Load the workflow of a the script file. N.B.: it assumes that the file is self-contained, in other words, there are no relative imports. @@ -404,7 +405,7 @@ def load_naive_entity(module_name: str, workflow_name: str) -> WorkflowBase: with context_manager.FlyteContextManager.with_context(flyte_ctx): with module_loader.add_sys_path(os.getcwd()): importlib.import_module(module_name) - return module_loader.load_object_from_module(f"{module_name}.{workflow_name}") + return module_loader.load_object_from_module(f"{module_name}.{entity_name}") def dump_flyte_remote_snippet(execution: FlyteWorkflowExecution, project: str, domain: str): @@ -422,9 +423,9 @@ def dump_flyte_remote_snippet(execution: FlyteWorkflowExecution, project: str, d ) -def get_workflows_in_file(filename: str) -> typing.List[str]: +def get_entities_in_file(filename: str) -> typing.Tuple[typing.List[str], typing.List[str]]: """ - Returns a list of flyte workflow names in a file. + Returns a list of flyte workflow names and list of Flyte tasks in a file. """ flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings( SerializationSettings(None) @@ -435,17 +436,21 @@ def get_workflows_in_file(filename: str) -> typing.List[str]: importlib.import_module(module_name) workflows = [] + tasks = [] module = importlib.import_module(module_name) for k in dir(module): o = module.__dict__[k] if isinstance(o, PythonFunctionWorkflow): _, _, fn, _ = tracker.extract_task_module(o) workflows.append(fn) + elif isinstance(o, PythonTask): + _, _, fn, _ = tracker.extract_task_module(o) + tasks.append(fn) - return workflows + return workflows, tasks -def run_command(ctx: click.Context, wf_entity: PythonFunctionWorkflow): +def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]): """ Returns a function that is used to implement WorkflowCommand and execute a flyte workflow. """ @@ -454,11 +459,11 @@ def _run(*args, **kwargs): run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY] project, domain = run_level_params.get("project"), run_level_params.get("domain") inputs = {} - for input_name, _ in wf_entity.python_interface.inputs.items(): + for input_name, _ in entity.python_interface.inputs.items(): inputs[input_name] = kwargs.get(input_name) if not ctx.obj[REMOTE_FLAG_KEY]: - output = wf_entity(**inputs) + output = entity(**inputs) click.echo(output) return @@ -468,8 +473,8 @@ def _run(*args, **kwargs): # PandasToParquetDataProxyEncodingHandler(get_upload_url_fn), default_for_type=True # ) - wf = remote.register_script( - wf_entity, + remote_entity = remote.register_script( + entity, project=project, domain=domain, image_config=run_level_params.get("image_config", None), @@ -484,14 +489,14 @@ def _run(*args, **kwargs): options = Options.default_from(k8s_service_account=service_account) execution = remote.execute( - wf, + remote_entity, inputs=inputs, project=project, domain=domain, name=run_level_params.get("name"), wait=run_level_params.get("wait_execution"), options=options, - type_hints=wf_entity.python_interface.inputs, + type_hints=entity.python_interface.inputs, ) console_url = remote.generate_console_url(execution) @@ -513,10 +518,11 @@ def __init__(self, filename: str, *args, **kwargs): self._filename = filename def list_commands(self, ctx): - workflows = get_workflows_in_file(self._filename) + workflows, tasks = get_entities_in_file(self._filename) + workflows.extend(tasks) return workflows - def get_command(self, ctx, workflow): + def get_command(self, ctx, exe_entity): rel_path = os.path.relpath(self._filename) if rel_path.startswith(".."): raise ValueError( @@ -524,7 +530,7 @@ def get_command(self, ctx, workflow): ) module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".") - wf_entity = load_naive_entity(module, workflow) + entity = load_naive_entity(module, exe_entity) # If this is a remote execution, which we should know at this point, then create the remote object p = ctx.obj[RUN_LEVEL_PARAMS_KEY].get("project") @@ -537,24 +543,24 @@ def get_command(self, ctx, workflow): # Add options for each of the workflow inputs params = [] - for input_name, input_type_val in wf_entity.python_interface.inputs_with_defaults.items(): - literal_var = wf_entity.interface.inputs.get(input_name) + for input_name, input_type_val in entity.python_interface.inputs_with_defaults.items(): + literal_var = entity.interface.inputs.get(input_name) python_type, default_val = input_type_val params.append( to_click_option(ctx, flyte_ctx, input_name, literal_var, python_type, default_val, get_upload_url_fn) ) cmd = click.Command( - name=workflow, + name=exe_entity, params=params, - callback=run_command(ctx, wf_entity), - help=f"Run {module}.{workflow} in script mode", + callback=run_command(ctx, entity), + help=f"Run {module}.{exe_entity} in script mode", ) return cmd class RunCommand(click.MultiCommand): """ - A click command group for registering and executing flyte workflows in a file. + A click command group for registering and executing flyte workflows & tasks in a file. """ def __init__(self, *args, **kwargs): @@ -566,11 +572,11 @@ def list_commands(self, ctx): def get_command(self, ctx, filename): ctx.obj[RUN_LEVEL_PARAMS_KEY] = ctx.params - return WorkflowCommand(filename, name=filename, help="Run a workflow in a file using script mode") + return WorkflowCommand(filename, name=filename, help="Run a [workflow|task] in a file using script mode") run = RunCommand( name="run", - help="Run_old command, a.k.a. script mode. It allows for a a single script to be " - + "registered and run from the command line (e.g. Jupyter notebooks).", + help="Run command: This command can execute either a workflow or a task from the commandline, for " + "fully self-contained scripts. Tasks and workflows cannot be imported from other files currently.", ) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 5802ecbbca..198cd7d576 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -507,7 +507,7 @@ def register_workflow( def register_script( self, - entity: WorkflowBase, + entity: typing.Union[WorkflowBase, PythonTask], image_config: typing.Optional[ImageConfig] = None, version: typing.Optional[str] = None, project: typing.Optional[str] = None, @@ -515,7 +515,7 @@ def register_script( destination_dir: str = ".", default_launch_plan: typing.Optional[bool] = True, options: typing.Optional[Options] = None, - ) -> FlyteWorkflow: + ) -> typing.Union[FlyteWorkflow, FlyteTask]: """ Use this method to register a workflow via script mode. :param destination_dir: @@ -523,7 +523,7 @@ def register_script( :param project: :param image_config: :param version: version for the entity to be registered as - :param entity: The workflow to be registered + :param entity: The workflow to be registered or the task to be registered :param default_launch_plan: This should be true if a default launch plan should be created for the workflow :param options: Additional execution options that can be configured for the default launchplan :return: @@ -565,6 +565,8 @@ def register_script( h.update(bytes(__version__, "utf-8")) version = base64.urlsafe_b64encode(h.digest()) + if isinstance(entity, PythonTask): + return self.register_task(entity, serialization_settings, version) return self.register_workflow(entity, serialization_settings, version, default_launch_plan, options) def register_launch_plan( From 79d9465ac64a4ffea425ec593b4fabbe1b669516 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Tue, 10 May 2022 21:29:25 -0700 Subject: [PATCH 2/2] Fix namedtuple and test Signed-off-by: Ketan Umare --- flytekit/clis/sdk_in_container/run.py | 26 ++++++++++++++++----- tests/flytekit/unit/cli/pyflyte/test_run.py | 12 +++++++++- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 5b142a82c3..02658454f1 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -423,14 +423,29 @@ def dump_flyte_remote_snippet(execution: FlyteWorkflowExecution, project: str, d ) -def get_entities_in_file(filename: str) -> typing.Tuple[typing.List[str], typing.List[str]]: +class Entities(typing.NamedTuple): + """ + NamedTuple to group all entities in a file + """ + + workflows: typing.List[str] + tasks: typing.List[str] + + def all(self) -> typing.List[str]: + e = [] + e.extend(self.workflows) + e.extend(self.tasks) + return e + + +def get_entities_in_file(filename: str) -> Entities: """ Returns a list of flyte workflow names and list of Flyte tasks in a file. """ flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings( SerializationSettings(None) ) - module_name = os.path.splitext(filename)[0].replace(os.path.sep, ".") + module_name = os.path.splitext(os.path.relpath(filename))[0].replace(os.path.sep, ".") with context_manager.FlyteContextManager.with_context(flyte_ctx): with module_loader.add_sys_path(os.getcwd()): importlib.import_module(module_name) @@ -447,7 +462,7 @@ def get_entities_in_file(filename: str) -> typing.Tuple[typing.List[str], typing _, _, fn, _ = tracker.extract_task_module(o) tasks.append(fn) - return workflows, tasks + return Entities(workflows, tasks) def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]): @@ -518,9 +533,8 @@ def __init__(self, filename: str, *args, **kwargs): self._filename = filename def list_commands(self, ctx): - workflows, tasks = get_entities_in_file(self._filename) - workflows.extend(tasks) - return workflows + entities = get_entities_in_file(self._filename) + return entities.all() def get_command(self, ctx, exe_entity): rel_path = os.path.relpath(self._filename) diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 5d3aad70a1..8fcbd3667c 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -3,11 +3,14 @@ from click.testing import CliRunner from flytekit.clis.sdk_in_container import pyflyte +from flytekit.clis.sdk_in_container.run import get_entities_in_file + +WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") def test_pyflyte_run_wf(): runner = CliRunner() - module_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") + module_path = WORKFLOW_FILE result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False) assert result.exit_code == 0 @@ -52,3 +55,10 @@ def test_pyflyte_run_cli(): ) print(result.stdout) assert result.exit_code == 0 + + +def test_get_entities_in_file(): + e = get_entities_in_file(WORKFLOW_FILE) + assert e.workflows == ["my_wf"] + assert e.tasks == ["get_subset_df", "print_all", "show_sd"] + assert e.all() == ["my_wf", "get_subset_df", "print_all", "show_sd"]