From c011ef7cf47ac8ffc06c48e000cb309d9df99969 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 6 May 2022 16:21:44 -0700 Subject: [PATCH] Swap out inspect file location (#991) Signed-off-by: Yee Hing Tong --- flytekit/core/tracker.py | 6 ++--- .../flytekit/unit/core/functools/__init__.py | 0 .../unit/core/functools/decorator_source.py | 23 +++++++++++++++++++ .../unit/core/functools/decorator_usage.py | 9 ++++++++ .../core/functools/test_decorator_location.py | 17 ++++++++++++++ 5 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 tests/flytekit/unit/core/functools/__init__.py create mode 100644 tests/flytekit/unit/core/functools/decorator_source.py create mode 100644 tests/flytekit/unit/core/functools/decorator_usage.py create mode 100644 tests/flytekit/unit/core/functools/test_decorator_location.py diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 9db8233a05..0fad8335c2 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -199,7 +199,7 @@ def _resolve_abs_module_name(self, path: str, package_root: str) -> str: if "__init__.py" not in os.listdir(dirname): return basename - # Now recurse down such that we can extract the absolute module path + # Now recurse down such that we can extract the absolute module path mod_name = self._resolve_abs_module_name(dirname, package_root) final_mod_name = f"{mod_name}.{basename}" if mod_name else basename self._module_cache[path] = final_mod_name @@ -243,8 +243,8 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, package_root = ( FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT if FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT != "auto" else None ) - new_mod_name = _mod_sanitizer.get_absolute_module_name(inspect.getabsfile(f), package_root) + new_mod_name = _mod_sanitizer.get_absolute_module_name(inspect.getabsfile(mod), package_root) # We only replace the mod_name if it is more specific, else we already have a fully resolved path if len(new_mod_name) > len(mod_name): mod_name = new_mod_name - return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(f)) + return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod)) diff --git a/tests/flytekit/unit/core/functools/__init__.py b/tests/flytekit/unit/core/functools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/core/functools/decorator_source.py b/tests/flytekit/unit/core/functools/decorator_source.py new file mode 100644 index 0000000000..9c92364649 --- /dev/null +++ b/tests/flytekit/unit/core/functools/decorator_source.py @@ -0,0 +1,23 @@ +"""Script used for testing local execution of functool.wraps-wrapped tasks for stacked decorators""" + +from functools import wraps +from typing import List + + +def task_setup(function: callable = None, *, integration_requests: List = None) -> None: + integration_requests = integration_requests or [] + + @wraps(function) + def wrapper(*args, **kwargs): + # Preprocessing of task + print("preprocessing") + + # Execute function + output = function(*args, **kwargs) + + # Postprocessing of output + print("postprocessing") + + return output + + return functools.partial(task_setup, integration_requests=integration_requests) if function is None else wrapper diff --git a/tests/flytekit/unit/core/functools/decorator_usage.py b/tests/flytekit/unit/core/functools/decorator_usage.py new file mode 100644 index 0000000000..bdd4fca2d9 --- /dev/null +++ b/tests/flytekit/unit/core/functools/decorator_usage.py @@ -0,0 +1,9 @@ +from flytekit import task + +from .decorator_source import task_setup + + +@task +@task_setup +def get_data(x: int) -> int: + return x + 1 diff --git a/tests/flytekit/unit/core/functools/test_decorator_location.py b/tests/flytekit/unit/core/functools/test_decorator_location.py new file mode 100644 index 0000000000..d896d09592 --- /dev/null +++ b/tests/flytekit/unit/core/functools/test_decorator_location.py @@ -0,0 +1,17 @@ +import importlib + +from flytekit.core.tracker import extract_task_module + + +def test_dont_use_wrapper_location(): + m = importlib.import_module("tests.flytekit.unit.core.functools.decorator_usage") + get_data_task = getattr(m, "get_data") + assert "decorator_source" not in get_data_task.name + assert "decorator_usage" in get_data_task.name + + a, b, c, _ = extract_task_module(get_data_task) + assert (a, b, c) == ( + "tests.flytekit.unit.core.functools.decorator_usage.get_data", + "tests.flytekit.unit.core.functools.decorator_usage", + "get_data", + )