diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 81f5cf83f1..b6f723fe0c 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -17,6 +17,7 @@ from flytekit.configuration import Config, ImageConfig, SerializationSettings from flytekit.configuration.default_images import DefaultImages from flytekit.core import context_manager, tracker +from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContext from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.type_engine import TypeEngine @@ -395,7 +396,7 @@ def get_workflow_command_base_params() -> typing.List[click.Option]: ] -def load_naive_entity(module_name: str, workflow_name: str) -> WorkflowBase: +def load_naive_entity(module_name: str, entity_name: str) -> typing.Union[WorkflowBase, PythonTask]: """ Load the workflow of a the script file. N.B.: it assumes that the file is self-contained, in other words, there are no relative imports. @@ -406,7 +407,7 @@ def load_naive_entity(module_name: str, workflow_name: str) -> WorkflowBase: with context_manager.FlyteContextManager.with_context(flyte_ctx): with module_loader.add_sys_path(os.getcwd()): importlib.import_module(module_name) - return module_loader.load_object_from_module(f"{module_name}.{workflow_name}") + return module_loader.load_object_from_module(f"{module_name}.{entity_name}") def dump_flyte_remote_snippet(execution: FlyteWorkflowExecution, project: str, domain: str): @@ -424,30 +425,49 @@ def dump_flyte_remote_snippet(execution: FlyteWorkflowExecution, project: str, d ) -def get_workflows_in_file(filename: str) -> typing.List[str]: +class Entities(typing.NamedTuple): """ - Returns a list of flyte workflow names in a file. + NamedTuple to group all entities in a file + """ + + workflows: typing.List[str] + tasks: typing.List[str] + + def all(self) -> typing.List[str]: + e = [] + e.extend(self.workflows) + e.extend(self.tasks) + return e + + +def get_entities_in_file(filename: str) -> Entities: + """ + Returns a list of flyte workflow names and list of Flyte tasks in a file. """ flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings( SerializationSettings(None) ) - module_name = os.path.splitext(filename)[0].replace(os.path.sep, ".") + module_name = os.path.splitext(os.path.relpath(filename))[0].replace(os.path.sep, ".") with context_manager.FlyteContextManager.with_context(flyte_ctx): with module_loader.add_sys_path(os.getcwd()): importlib.import_module(module_name) workflows = [] + tasks = [] module = importlib.import_module(module_name) for k in dir(module): o = module.__dict__[k] if isinstance(o, PythonFunctionWorkflow): _, _, fn, _ = tracker.extract_task_module(o) workflows.append(fn) + elif isinstance(o, PythonTask): + _, _, fn, _ = tracker.extract_task_module(o) + tasks.append(fn) - return workflows + return Entities(workflows, tasks) -def run_command(ctx: click.Context, wf_entity: PythonFunctionWorkflow): +def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]): """ Returns a function that is used to implement WorkflowCommand and execute a flyte workflow. """ @@ -456,11 +476,11 @@ def _run(*args, **kwargs): run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY] project, domain = run_level_params.get("project"), run_level_params.get("domain") inputs = {} - for input_name, _ in wf_entity.python_interface.inputs.items(): + for input_name, _ in entity.python_interface.inputs.items(): inputs[input_name] = kwargs.get(input_name) if not ctx.obj[REMOTE_FLAG_KEY]: - output = wf_entity(**inputs) + output = entity(**inputs) click.echo(output) return @@ -470,8 +490,8 @@ def _run(*args, **kwargs): # PandasToParquetDataProxyEncodingHandler(get_upload_url_fn), default_for_type=True # ) - wf = remote.register_script( - wf_entity, + remote_entity = remote.register_script( + entity, project=project, domain=domain, image_config=run_level_params.get("image_config", None), @@ -486,14 +506,14 @@ def _run(*args, **kwargs): options = Options.default_from(k8s_service_account=service_account) execution = remote.execute( - wf, + remote_entity, inputs=inputs, project=project, domain=domain, name=run_level_params.get("name"), wait=run_level_params.get("wait_execution"), options=options, - type_hints=wf_entity.python_interface.inputs, + type_hints=entity.python_interface.inputs, ) console_url = remote.generate_console_url(execution) @@ -515,10 +535,10 @@ def __init__(self, filename: str, *args, **kwargs): self._filename = filename def list_commands(self, ctx): - workflows = get_workflows_in_file(self._filename) - return workflows + entities = get_entities_in_file(self._filename) + return entities.all() - def get_command(self, ctx, workflow): + def get_command(self, ctx, exe_entity): rel_path = os.path.relpath(self._filename) if rel_path.startswith(".."): raise ValueError( @@ -526,7 +546,7 @@ def get_command(self, ctx, workflow): ) module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".") - wf_entity = load_naive_entity(module, workflow) + entity = load_naive_entity(module, exe_entity) # If this is a remote execution, which we should know at this point, then create the remote object p = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT) @@ -544,24 +564,24 @@ def get_command(self, ctx, workflow): # Add options for each of the workflow inputs params = [] - for input_name, input_type_val in wf_entity.python_interface.inputs_with_defaults.items(): - literal_var = wf_entity.interface.inputs.get(input_name) + for input_name, input_type_val in entity.python_interface.inputs_with_defaults.items(): + literal_var = entity.interface.inputs.get(input_name) python_type, default_val = input_type_val params.append( to_click_option(ctx, flyte_ctx, input_name, literal_var, python_type, default_val, get_upload_url_fn) ) cmd = click.Command( - name=workflow, + name=exe_entity, params=params, - callback=run_command(ctx, wf_entity), - help=f"Run {module}.{workflow} in script mode", + callback=run_command(ctx, entity), + help=f"Run {module}.{exe_entity} in script mode", ) return cmd class RunCommand(click.MultiCommand): """ - A click command group for registering and executing flyte workflows in a file. + A click command group for registering and executing flyte workflows & tasks in a file. """ def __init__(self, *args, **kwargs): @@ -573,11 +593,11 @@ def list_commands(self, ctx): def get_command(self, ctx, filename): ctx.obj[RUN_LEVEL_PARAMS_KEY] = ctx.params - return WorkflowCommand(filename, name=filename, help="Run a workflow in a file using script mode") + return WorkflowCommand(filename, name=filename, help="Run a [workflow|task] in a file using script mode") run = RunCommand( name="run", - help="Run command, a.k.a. script mode. It allows for a a single script to be " - + "registered and run from the command line (e.g. Jupyter notebooks).", + help="Run command: This command can execute either a workflow or a task from the commandline, for " + "fully self-contained scripts. Tasks and workflows cannot be imported from other files currently.", ) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 9dc1589c06..a00ac97222 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -509,7 +509,7 @@ def register_workflow( def register_script( self, - entity: WorkflowBase, + entity: typing.Union[WorkflowBase, PythonTask], image_config: typing.Optional[ImageConfig] = None, version: typing.Optional[str] = None, project: typing.Optional[str] = None, @@ -517,7 +517,7 @@ def register_script( destination_dir: str = ".", default_launch_plan: typing.Optional[bool] = True, options: typing.Optional[Options] = None, - ) -> FlyteWorkflow: + ) -> typing.Union[FlyteWorkflow, FlyteTask]: """ Use this method to register a workflow via script mode. :param destination_dir: @@ -525,7 +525,7 @@ def register_script( :param project: :param image_config: :param version: version for the entity to be registered as - :param entity: The workflow to be registered + :param entity: The workflow to be registered or the task to be registered :param default_launch_plan: This should be true if a default launch plan should be created for the workflow :param options: Additional execution options that can be configured for the default launchplan :return: @@ -567,6 +567,8 @@ def register_script( h.update(bytes(__version__, "utf-8")) version = base64.urlsafe_b64encode(h.digest()) + if isinstance(entity, PythonTask): + return self.register_task(entity, serialization_settings, version) return self.register_workflow(entity, serialization_settings, version, default_launch_plan, options) def register_launch_plan( diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 5d3aad70a1..8fcbd3667c 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -3,11 +3,14 @@ from click.testing import CliRunner from flytekit.clis.sdk_in_container import pyflyte +from flytekit.clis.sdk_in_container.run import get_entities_in_file + +WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") def test_pyflyte_run_wf(): runner = CliRunner() - module_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") + module_path = WORKFLOW_FILE result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False) assert result.exit_code == 0 @@ -52,3 +55,10 @@ def test_pyflyte_run_cli(): ) print(result.stdout) assert result.exit_code == 0 + + +def test_get_entities_in_file(): + e = get_entities_in_file(WORKFLOW_FILE) + assert e.workflows == ["my_wf"] + assert e.tasks == ["get_subset_df", "print_all", "show_sd"] + assert e.all() == ["my_wf", "get_subset_df", "print_all", "show_sd"]