diff --git a/flytekit/tools/serialize_helpers.py b/flytekit/tools/serialize_helpers.py index f7937443be..69af2b96b4 100644 --- a/flytekit/tools/serialize_helpers.py +++ b/flytekit/tools/serialize_helpers.py @@ -17,6 +17,7 @@ from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.core import identifier as _identifier from flytekit.models.task import TaskSpec +from flytekit.remote.remote_callable import RemoteEntity from flytekit.tools.translator import FlyteControlPlaneEntity, Options, get_serializable @@ -40,7 +41,7 @@ def _should_register_with_admin(entity) -> bool: """ return isinstance( entity, (task_models.TaskSpec, _launch_plan_models.LaunchPlan, admin_workflow_models.WorkflowSpec) - ) + ) and not isinstance(entity, RemoteEntity) def _find_duplicate_tasks(tasks: typing.List[task_models.TaskSpec]) -> typing.Set[task_models.TaskSpec]: diff --git a/tests/flytekit/unit/cli/pyflyte/test_package.py b/tests/flytekit/unit/cli/pyflyte/test_package.py index 40d63021f2..e3ccb1d803 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_package.py +++ b/tests/flytekit/unit/cli/pyflyte/test_package.py @@ -7,12 +7,17 @@ import flytekit import flytekit.configuration import flytekit.tools.serialize_helpers +from flytekit import TaskMetadata from flytekit.clis.sdk_in_container import pyflyte from flytekit.core import context_manager from flytekit.exceptions.user import FlyteValidationException from flytekit.models.admin.workflow import WorkflowSpec +from flytekit.models.core.identifier import Identifier, ResourceType from flytekit.models.launch_plan import LaunchPlan from flytekit.models.task import TaskSpec +from flytekit.remote import FlyteTask +from flytekit.remote.interface import TypedInterface +from flytekit.remote.remote_callable import RemoteEntity sample_file_contents = """ from flytekit import task, workflow @@ -52,12 +57,25 @@ def test_get_registrable_entities(): ), ) ) - context_manager.FlyteEntities.entities = [foo, wf, "str"] + context_manager.FlyteEntities.entities = [ + foo, + wf, + "str", + FlyteTask( + id=Identifier(ResourceType.TASK, "p", "d", "n", "v"), + type="t", + metadata=TaskMetadata().to_taskmetadata_model(), + interface=TypedInterface(inputs={}, outputs={}), + custom=None, + ), + ] entities = flytekit.tools.serialize_helpers.get_registrable_entities(ctx) assert entities assert len(entities) == 3 for e in entities: + if isinstance(e, RemoteEntity): + assert False, "found unexpected remote entity" if isinstance(e, WorkflowSpec) or isinstance(e, TaskSpec) or isinstance(e, LaunchPlan): continue assert False, f"found unknown entity {type(e)}"