Skip to content

Commit

Permalink
Swap out inspect file location (#991)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored May 6, 2022
1 parent e2f6cdf commit c011ef7
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 3 deletions.
6 changes: 3 additions & 3 deletions flytekit/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _resolve_abs_module_name(self, path: str, package_root: str) -> str:
if "__init__.py" not in os.listdir(dirname):
return basename

# Now recurse down such that we can extract the absolute module path
# Now recurse down such that we can extract the absolute module path
mod_name = self._resolve_abs_module_name(dirname, package_root)
final_mod_name = f"{mod_name}.{basename}" if mod_name else basename
self._module_cache[path] = final_mod_name
Expand Down Expand Up @@ -243,8 +243,8 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str,
package_root = (
FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT if FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT != "auto" else None
)
new_mod_name = _mod_sanitizer.get_absolute_module_name(inspect.getabsfile(f), package_root)
new_mod_name = _mod_sanitizer.get_absolute_module_name(inspect.getabsfile(mod), package_root)
# We only replace the mod_name if it is more specific, else we already have a fully resolved path
if len(new_mod_name) > len(mod_name):
mod_name = new_mod_name
return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(f))
return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod))
Empty file.
23 changes: 23 additions & 0 deletions tests/flytekit/unit/core/functools/decorator_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Script used for testing local execution of functool.wraps-wrapped tasks for stacked decorators"""

from functools import wraps
from typing import List


def task_setup(function: callable = None, *, integration_requests: List = None) -> None:
integration_requests = integration_requests or []

@wraps(function)
def wrapper(*args, **kwargs):
# Preprocessing of task
print("preprocessing")

# Execute function
output = function(*args, **kwargs)

# Postprocessing of output
print("postprocessing")

return output

return functools.partial(task_setup, integration_requests=integration_requests) if function is None else wrapper
9 changes: 9 additions & 0 deletions tests/flytekit/unit/core/functools/decorator_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from flytekit import task

from .decorator_source import task_setup


@task
@task_setup
def get_data(x: int) -> int:
return x + 1
17 changes: 17 additions & 0 deletions tests/flytekit/unit/core/functools/test_decorator_location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import importlib

from flytekit.core.tracker import extract_task_module


def test_dont_use_wrapper_location():
m = importlib.import_module("tests.flytekit.unit.core.functools.decorator_usage")
get_data_task = getattr(m, "get_data")
assert "decorator_source" not in get_data_task.name
assert "decorator_usage" in get_data_task.name

a, b, c, _ = extract_task_module(get_data_task)
assert (a, b, c) == (
"tests.flytekit.unit.core.functools.decorator_usage.get_data",
"tests.flytekit.unit.core.functools.decorator_usage",
"get_data",
)

0 comments on commit c011ef7

Please sign in to comment.