Skip to content

Commit

Permalink
refactor(core): Improve task module extraction logic (#2290)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Apr 3, 2024
1 parent a240b65 commit df5dbea
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
5 changes: 4 additions & 1 deletion flytekit/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,11 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str,
if mod_name == "__main__":
if hasattr(f, "task_function"):
f = f.task_function
# If the module is __main__, we need to find the actual module name based on the file path
inspect_file = inspect.getfile(f) # type: ignore
return name, "", name, os.path.abspath(inspect_file)
file_name, _ = os.path.splitext(os.path.basename(inspect_file))
mod_name = get_full_module_path(f, file_name) # type: ignore
return name, mod_name, name, os.path.abspath(inspect_file)

mod_name = get_full_module_path(mod, mod_name)
return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod))
Expand Down
18 changes: 16 additions & 2 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.reference_entity import ReferenceSpec
from flytekit.core.task import ReferenceTask
from flytekit.core.tracker import extract_task_module
from flytekit.core.type_engine import LiteralsResolver, TypeEngine
from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase, WorkflowFailurePolicy
from flytekit.exceptions import user as user_exceptions
Expand Down Expand Up @@ -82,7 +83,7 @@
from flytekit.remote.remote_fs import get_flyte_fs
from flytekit.tools.fast_registration import fast_package
from flytekit.tools.interactive import ipython_check
from flytekit.tools.script_mode import compress_scripts, hash_file
from flytekit.tools.script_mode import _find_project_root, compress_scripts, hash_file
from flytekit.tools.translator import (
FlyteControlPlaneEntity,
FlyteLocalEntity,
Expand Down Expand Up @@ -778,7 +779,10 @@ def _serialize_and_register(
return ident

def register_task(
self, entity: PythonTask, serialization_settings: SerializationSettings, version: typing.Optional[str] = None
self,
entity: PythonTask,
serialization_settings: typing.Optional[SerializationSettings] = None,
version: typing.Optional[str] = None,
) -> FlyteTask:
"""
Register a qualified task (PythonTask) with Remote
Expand All @@ -789,6 +793,16 @@ def register_task(
:param version: version that will be used to register. If not specified, the default from the serialization settings is used
:return:
"""
# Create a default serialization settings object if none is provided;
# this makes registration easier for the user.
if serialization_settings is None:
_, _, _, module_file = extract_task_module(entity)
project_root = _find_project_root(module_file)
serialization_settings = SerializationSettings(
image_config=ImageConfig.auto_default_image(),
source_root=project_root,
)

ident = self._serialize_and_register(entity=entity, settings=serialization_settings, version=version)
ft = self.fetch_task(
ident.project,
Expand Down
4 changes: 4 additions & 0 deletions plugins/flytekit-spark/tests/test_remote_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,7 @@ def my_python_task(a: str) -> int:

# Check that the serialized python task has no mainApplicationFile field set by default.
assert serialized_spec.template.custom is None

remote.register_task(my_python_task, version="v1")
serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"]
assert serialized_spec.template.custom is None

0 comments on commit df5dbea

Please sign in to comment.