diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py index d0d7f7a229..46513553b9 100644 --- a/flytekit/clis/sdk_in_container/constants.py +++ b/flytekit/clis/sdk_in_container/constants.py @@ -7,6 +7,8 @@ CTX_PACKAGES = "pkgs" CTX_NOTIFICATIONS = "notifications" CTX_CONFIG_FILE = "config_file" +CTX_PROJECT_ROOT = "project_root" +CTX_MODULE = "module" project_option = _click.option( diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 95533fb4d5..935cfc1ad3 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -15,11 +15,17 @@ from typing_extensions import get_args from flytekit import BlobType, Literal, Scalar -from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_DOMAIN, CTX_PROJECT +from flytekit.clis.sdk_in_container.constants import ( + CTX_CONFIG_FILE, + CTX_DOMAIN, + CTX_MODULE, + CTX_PROJECT, + CTX_PROJECT_ROOT, +) from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY, get_and_save_remote_with_click_context from flytekit.configuration import ImageConfig from flytekit.configuration.default_images import DefaultImages -from flytekit.core import context_manager, tracker +from flytekit.core import context_manager from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContext from flytekit.core.data_persistence import FileAccessProvider @@ -480,14 +486,12 @@ def get_entities_in_file(filename: str) -> Entities: workflows = [] tasks = [] module = importlib.import_module(module_name) - for k in dir(module): - o = module.__dict__[k] - if isinstance(o, PythonFunctionWorkflow): - _, _, fn, _ = tracker.extract_task_module(o) - workflows.append(fn) + for name in dir(module): + o = module.__dict__[name] + if isinstance(o, WorkflowBase): + workflows.append(name) elif isinstance(o, PythonTask): - _, _, fn, _ = tracker.extract_task_module(o) - tasks.append(fn) + tasks.append(name) return Entities(workflows, tasks) @@ -542,6 +546,8 @@ def _run(*args, **kwargs): domain=domain, image_config=image_config, destination_dir=run_level_params.get("destination_dir"), + source_path=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT), + module_name=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_MODULE), ) options = None @@ -602,11 +608,16 @@ def get_command(self, ctx, exe_entity): ) project_root = _find_project_root(self._filename) + # Find the relative path for the filename relative to the root of the project. # N.B.: by construction project_root will necessarily be an ancestor of the filename passed in as # a parameter. rel_path = self._filename.relative_to(project_root) module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".") + + ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_PROJECT_ROOT] = project_root + ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_MODULE] = module + entity = load_naive_entity(module, exe_entity, project_root) # If this is a remote execution, which we should know at this point, then create the remote object diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 0fad8335c2..9851e2e98b 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -4,6 +4,7 @@ import inspect as _inspect import os import typing +from types import ModuleType from typing import Callable, Tuple, Union from flytekit.configuration.feature_flags import FeatureFlags @@ -239,6 +240,11 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, if mod_name == "__main__": return name, "", name, os.path.abspath(inspect.getfile(f)) + mod_name = get_full_module_path(mod, mod_name) + return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod)) + + +def get_full_module_path(mod: ModuleType, mod_name: str) -> str: if FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT != ".": package_root = ( FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT if FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT != "auto" else None @@ -247,4 +253,4 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, # We only replace the mod_name if it is more specific, else we already have a fully resolved path if len(new_mod_name) > len(mod_name): mod_name = new_mod_name - return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod)) + return mod_name diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 99f54b7933..f02226decc 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -577,6 +577,8 @@ def register_script( destination_dir: str = ".", default_launch_plan: typing.Optional[bool] = True, options: typing.Optional[Options] = None, + source_path: typing.Optional[str] = None, + module_name: typing.Optional[str] = None, ) -> typing.Union[FlyteWorkflow, FlyteTask]: """ Use this method to register a workflow via script mode. @@ -588,13 +590,16 @@ def register_script( :param entity: The workflow to be registered or the task to be registered :param default_launch_plan: This should be true if a default launch plan should be created for the workflow :param options: Additional execution options that can be configured for the default launchplan + :param source_path: The root of the project path + :param module_name: the name of the module :return: """ if image_config is None: image_config = ImageConfig.auto_default_image() upload_location, md5_bytes = fast_register_single_script( - entity, + source_path, + module_name, functools.partial( self.client.get_upload_signed_url, project=project or self.default_project, diff --git a/flytekit/tools/script_mode.py b/flytekit/tools/script_mode.py index f837447637..29b617824c 100644 --- a/flytekit/tools/script_mode.py +++ b/flytekit/tools/script_mode.py @@ -1,5 +1,6 @@ import gzip import hashlib +import importlib import os import shutil import tarfile @@ -10,8 +11,7 @@ from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2 from flytekit.core import context_manager -from flytekit.core.tracker import extract_task_module -from flytekit.core.workflow import WorkflowBase +from flytekit.core.tracker import get_full_module_path def compress_single_script(source_path: str, destination: str, full_module_name: str): @@ -97,16 +97,14 @@ def tar_strip_file_attributes(tar_info: tarfile.TarInfo) -> tarfile.TarInfo: def fast_register_single_script( - wf_entity: WorkflowBase, create_upload_location_fn: typing.Callable + source_path: str, module_name: str, create_upload_location_fn: typing.Callable ) -> (_data_proxy_pb2.CreateUploadLocationResponse, bytes): - _, mod_name, _, script_full_path = extract_task_module(wf_entity) - # Find project root by moving up the folder hierarchy until you cannot find a __init__.py file. - source_path = _find_project_root(script_full_path) # Open a temp directory and dump the contents of the digest. with tempfile.TemporaryDirectory() as tmp_dir: archive_fname = os.path.join(tmp_dir, "script_mode.tar.gz") - compress_single_script(source_path, archive_fname, mod_name) + mod = importlib.import_module(module_name) + compress_single_script(source_path, archive_fname, get_full_module_path(mod, mod.__name__)) flyte_ctx = context_manager.FlyteContextManager.current_context() md5, _ = hash_file(archive_fname) diff --git a/tests/flytekit/unit/cli/pyflyte/imperative_wf.py b/tests/flytekit/unit/cli/pyflyte/imperative_wf.py new file mode 100644 index 0000000000..12d7f2e3a3 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/imperative_wf.py @@ -0,0 +1,39 @@ +import typing + +from flytekit import Workflow, task + + +@task +def t1(a: str) -> str: + return a + " world" + + +@task +def t2(): + print("side effect") + + +@task +def t3(a: typing.List[str]) -> str: + return ",".join(a) + + +wf = Workflow(name="my.imperative.workflow.example") +wf.add_workflow_input("in1", str) +node_t1 = wf.add_entity(t1, a=wf.inputs["in1"]) +wf.add_workflow_output("output_from_t1", node_t1.outputs["o0"]) +wf.add_entity(t2) + +wf_in2 = wf.add_workflow_input("in2", str) +node_t3 = wf.add_entity(t3, a=[wf.inputs["in1"], wf_in2]) + +wf.add_workflow_output( + "output_list", + [node_t1.outputs["o0"], node_t3.outputs["o0"]], + python_type=typing.List[str], +) + + +if __name__ == "__main__": + print(wf) + print(wf(in1="hello", in2="foo")) diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 34f28f00b6..5bc94592b9 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -19,6 +19,7 @@ from flytekit.core.task import task WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") +IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py") DIR_NAME = os.path.dirname(os.path.realpath(__file__)) @@ -30,6 +31,16 @@ def test_pyflyte_run_wf(): assert result.exit_code == 0 +def test_imperative_wf(): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + ["run", IMPERATIVE_WORKFLOW_FILE, "wf", "--in1", "hello", "--in2", "world"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + def test_pyflyte_run_cli(): runner = CliRunner() result = runner.invoke(